Skip to content

Commit 3719d51

Browse files
authored
Add shader test for nan/inf detection (#473)
1 parent eced145 commit 3719d51

File tree

1 file changed

+249
-0
lines changed

1 file changed

+249
-0
lines changed

tests/test_not_finite.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
"""
2+
A collection of tests related to non-finite values in shaders, like nan and inf.
3+
4+
See:
5+
* https://en.wikipedia.org/wiki/NaN
6+
* https://github.com/gpuweb/gpuweb/pull/2311#issuecomment-1973533433
7+
8+
"""
9+
10+
import ctypes
11+
12+
import numpy as np
13+
14+
from wgpu.utils.compute import compute_with_buffers
15+
from pytest import skip
16+
from testutils import can_use_wgpu_lib
17+
18+
19+
if not can_use_wgpu_lib:
20+
skip("Skipping tests that need the wgpu lib", allow_module_level=True)
21+
22+
23+
def test_finite_using_nequal():
24+
# Just to demonstrate that this does not work.
25+
# The compiler filters optimizes away the check.
26+
27+
shader = """
28+
@group(0)
29+
@binding(0)
30+
var<storage,read> values: array<f32>;
31+
32+
fn is_nan(v:f32) -> bool {
33+
return v != v;
34+
}
35+
36+
fn is_inf(v:f32) -> bool {
37+
return v != 0.0 && v * 2.0 == v;
38+
}
39+
40+
fn is_finite(v:f32) -> bool {
41+
return v == v && v * 2.0 != v;
42+
}
43+
44+
fn to_real(v:f32) -> f32 {
45+
return select(0.0, v, is_finite(v));
46+
}
47+
48+
"""
49+
50+
detect_finites("nequal", shader, False, False)
51+
52+
53+
def test_finite_using_min_max():
54+
# This obfuscates the check for equality enough for the compiler
55+
# not to optimize it away.
56+
#
57+
# However, if fastmath is enabled, depending on the hardare/compiler,
58+
# the loaded value may not actually be a nan/inf anymore.
59+
60+
shader = """
61+
@group(0)
62+
@binding(0)
63+
var<storage,read> values: array<f32>;
64+
65+
fn is_nan(v:f32) -> bool {
66+
return min(v, 1.0) == 1.0 && max(v, -1.0) == -1.0;
67+
}
68+
69+
fn is_inf(v:f32) -> bool {
70+
return v != 0.0 && v * 2.0 == v;
71+
}
72+
73+
fn is_finite(v:f32) -> bool {
74+
return !is_nan(v) && !is_inf(v);
75+
}
76+
77+
fn to_real(v:f32) -> f32 {
78+
return select(0.0, v, is_finite(v));
79+
}
80+
81+
"""
82+
83+
detect_finites("min-max", shader, True, True)
84+
85+
86+
def test_finite_using_uint():
87+
# This is the most reliable approach.
88+
89+
shader = """
90+
@group(0)
91+
@binding(0)
92+
var<storage,read> values: array<u32>;
93+
94+
fn is_nan(v:u32) -> bool {
95+
let mask = 0x7F800000u;
96+
let v_is_pos_inf = v == 0x7F800000u;
97+
let v_is_neg_inf = v == 0xFF800000u;
98+
let v_is_finite = (v & mask) != mask;
99+
return !v_is_finite & !(v_is_pos_inf | v_is_neg_inf);
100+
}
101+
102+
fn is_inf(v:u32) -> bool {
103+
let v_is_pos_inf = v == 0x7F800000u;
104+
let v_is_neg_inf = v == 0xFF800000u;
105+
return v_is_pos_inf | v_is_neg_inf;
106+
}
107+
108+
fn is_finite(v:u32) -> bool {
109+
return (v & 0x7F800000u) != 0x7F800000u;
110+
}
111+
112+
fn to_real(v:u32) -> f32 {
113+
return select(0.0, bitcast<f32>(v), is_finite(v));
114+
}
115+
"""
116+
117+
detect_finites("uint", shader, True, True)
118+
119+
120+
def detect_finites(title, shader, expect_detection_nan, expect_detection_inf):
121+
122+
base_shader = """
123+
124+
@group(0)
125+
@binding(1)
126+
var<storage,read_write> result_nan: array<i32>;
127+
128+
@group(0)
129+
@binding(2)
130+
var<storage,read_write> result_inf: array<i32>;
131+
132+
@group(0)
133+
@binding(3)
134+
var<storage,read_write> result_finite: array<i32>;
135+
136+
@group(0)
137+
@binding(4)
138+
var<storage,read_write> result_real: array<f32>;
139+
140+
@compute
141+
@workgroup_size(1)
142+
fn main(@builtin(global_invocation_id) index: vec3<u32>) {
143+
let i = i32(index.x);
144+
let value = values[i];
145+
146+
result_nan[i] = i32(is_nan(value));
147+
result_inf[i] = i32(is_inf(value));
148+
result_finite[i] = i32(is_finite(value));
149+
result_real[i] = to_real(value);
150+
151+
}
152+
153+
"""
154+
155+
# Create data in blocks of 10: zeros, nans, infs, random reals
156+
parts = [
157+
[0.0] * 10,
158+
[
159+
float("nan"),
160+
np.nan,
161+
np.nan,
162+
np.nan,
163+
np.nan,
164+
np.nan,
165+
np.nan,
166+
np.nan,
167+
np.nan,
168+
np.nan,
169+
],
170+
[
171+
float("-inf"),
172+
float("inf"),
173+
-np.inf,
174+
np.inf,
175+
np.inf,
176+
np.inf,
177+
np.inf,
178+
np.inf,
179+
np.inf,
180+
np.inf,
181+
],
182+
np.random.uniform(-1e9, 1e9, 10),
183+
]
184+
values = np.concatenate(parts, dtype=np.float32)
185+
186+
# Check length
187+
assert values.shape == (40,)
188+
n = len(values)
189+
190+
# Create reference bool arrays
191+
is_nan_ref = np.zeros((n,), bool)
192+
is_nan_ref[10:20] = True
193+
is_inf_ref = np.zeros((n,), bool)
194+
is_inf_ref[20:30] = True
195+
is_finite_ref = np.ones((n,), bool)
196+
is_finite_ref[10:30] = False
197+
198+
# Get reference real array
199+
real_ref = values.copy()
200+
real_ref[~is_finite_ref] = 0
201+
202+
# Compute!
203+
out = compute_with_buffers(
204+
{0: (ctypes.c_float * n)(*values)},
205+
{
206+
1: n * ctypes.c_int32,
207+
2: n * ctypes.c_int32,
208+
3: n * ctypes.c_int32,
209+
4: n * ctypes.c_float,
210+
},
211+
shader + base_shader,
212+
)
213+
is_nan = out[1]
214+
is_inf = out[2]
215+
is_finite = out[3]
216+
real = out[4]
217+
218+
# Check that numpy detects ok
219+
assert np.all(np.isnan(values) == is_nan_ref)
220+
assert np.all(np.isinf(values) == is_inf_ref)
221+
assert np.all(np.isfinite(values) == is_finite_ref)
222+
223+
# Check that our shader does too
224+
detected_nan = bool(np.all(is_nan == is_nan_ref))
225+
detected_inf = bool(np.all(is_inf == is_inf_ref))
226+
detected_finite = bool(np.all(is_finite == is_finite_ref))
227+
good_reals = bool(np.all(real == real_ref))
228+
229+
# Print, for when run as a script
230+
checkmark = lambda x: "x✓"[x] # noqa
231+
print(
232+
f"{title:>10}: {checkmark(detected_nan)} is_nan {checkmark(detected_inf)} is_inf {checkmark(detected_finite)} is_finite {checkmark(good_reals)} good_reals"
233+
)
234+
235+
# Test
236+
if expect_detection_nan:
237+
assert detected_nan
238+
if expect_detection_inf:
239+
assert detected_inf
240+
if expect_detection_nan and expect_detection_inf:
241+
assert detected_finite
242+
assert good_reals
243+
244+
245+
if __name__ == "__main__":
246+
247+
test_finite_using_nequal()
248+
test_finite_using_min_max()
249+
test_finite_using_uint()

0 commit comments

Comments
 (0)