Skip to content

Commit 90b1e59

Browse files
committed
Fix #18
1 parent 4b19ebc commit 90b1e59

File tree

4 files changed

+65
-12
lines changed

4 files changed

+65
-12
lines changed

src/computesim/core.nim

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# (c) 2024 Antonis Geralis
22
import threading/barrier, vectors
3+
from std/math import isPowerOfTwo
34

45
const
56
SubgroupSize* {.intdefine.} = 8
@@ -23,6 +24,8 @@ const
2324
)
2425
masks
2526

27+
static: assert isPowerOfTwo(SubgroupSize), "SubgroupSize must be a power of two"
28+
2629
type
2730
ValueType* = enum
2831
Bool, Int, Uint, Float, Double # Scalar types

src/computesim/subgroupops.nim

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ defineSubgroupOp(execShuffle):
216216
var shuffledVals {.noinit.}: array[SubgroupSize, RawValue]
217217
# First gather all shuffled values into array
218218
for threadId in threadsInGroup(group):
219-
let srcThreadId = commands[threadId].dirty
219+
let srcThreadId = commands[threadId].dirty and SubgroupSize - 1
220220
# Check if source thread is valid within the group
221221
var found = false
222222
for validId in threadsInGroup(group):
@@ -225,7 +225,7 @@ defineSubgroupOp(execShuffle):
225225
break
226226
# If source thread is valid, take its value
227227
# Otherwise use this thread's own value
228-
shuffledVals[threadId] = commands[if found: srcThreadId else: threadId].val
228+
shuffledVals[threadId] = if found: commands[srcThreadId].val else: default(RawValue)
229229

230230
let valueType = commands[firstThreadId].t
231231
# Then construct results
@@ -245,7 +245,7 @@ defineSubgroupOp(execShuffleXor):
245245
var shuffledVals {.noinit.}: array[SubgroupSize, RawValue]
246246
# First gather all shuffled values into array
247247
for threadId in threadsInGroup(group):
248-
let srcThreadId = threadId xor commands[threadId].dirty
248+
let srcThreadId = threadId xor commands[threadId].dirty and SubgroupSize - 1
249249
# Check if source thread is valid within the group
250250
var found = false
251251
for validId in threadsInGroup(group):
@@ -254,7 +254,7 @@ defineSubgroupOp(execShuffleXor):
254254
break
255255
# If source thread is valid, take its value
256256
# Otherwise use this thread's own value
257-
shuffledVals[threadId] = commands[if found: srcThreadId else: threadId].val
257+
shuffledVals[threadId] = if found: commands[srcThreadId].val else: default(RawValue)
258258

259259
let valueType = commands[firstThreadId].t
260260
# Then construct results
@@ -273,13 +273,13 @@ defineSubgroupOp(execShuffleXor):
273273
defineSubgroupOp(execShuffleDown):
274274
var shuffledVals {.noinit.}: array[SubgroupSize, RawValue]
275275
for threadId in threadsInGroup(group):
276-
let srcThreadId = threadId + commands[threadId].dirty
276+
let srcThreadId = (threadId + commands[threadId].dirty) and SubgroupSize - 1
277277
var found = false
278278
for validId in threadsInGroup(group):
279279
if validId == srcThreadId:
280280
found = true
281281
break
282-
shuffledVals[threadId] = commands[if found: srcThreadId else: threadId].val
282+
shuffledVals[threadId] = if found: commands[srcThreadId].val else: default(RawValue)
283283

284284
let valueType = commands[firstThreadId].t
285285
for threadId in threadsInGroup(group):
@@ -298,16 +298,13 @@ defineSubgroupOp(execShuffleUp):
298298
var shuffledVals {.noinit.}: array[SubgroupSize, RawValue]
299299
for threadId in threadsInGroup(group):
300300
# Convert to signed for safe subtraction
301-
let srcThreadId = if threadId >= commands[threadId].dirty:
302-
threadId - commands[threadId].dirty
303-
else:
304-
threadId
301+
let srcThreadId = (threadId - commands[threadId].dirty) and SubgroupSize - 1
305302
var found = false
306303
for validId in threadsInGroup(group):
307304
if validId == srcThreadId:
308305
found = true
309306
break
310-
shuffledVals[threadId] = commands[if found: srcThreadId else: threadId].val
307+
shuffledVals[threadId] = if found: commands[srcThreadId].val else: default(RawValue)
311308

312309
let valueType = commands[firstThreadId].t
313310
for threadId in threadsInGroup(group):

tests/config.nims

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
--define:"ThreadPoolSize=10"
88
# --define:"SubgroupSize=32"
99
--define:debugSubgroup
10-
switch("define", "debugSubgroupID:1")
10+
# switch("define", "debugSubgroupID:1")
1111

1212
when not defined(windows):
1313
--debugger:"native"

tests/tshuffle.nim

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import std/math, computesim
2+
3+
proc calculate(output: ptr seq[int32]; numElements: uint32) {.computeShader.} =
4+
let tid = gl_GlobalInvocationID.x
5+
var value: int32 = 0
6+
# Loop from 1 to 2
7+
for i in 1'i32..2:
8+
if (tid.int32 + i) mod 3 == 0: # skip iteration
9+
continue
10+
value = subgroupShuffle(tid.int32 + 1, tid.uint32 + i.uint32)
11+
# Store the result in the SSBO
12+
output[tid] = value
13+
14+
const
15+
NumElements = 64'u32
16+
WorkGroupSize = 32'u32 # Force underutilization of hardware subgroups
17+
18+
proc main() =
19+
# Set up compute dimensions
20+
let numWorkGroups = uvec3(ceilDiv(NumElements, WorkGroupSize), 1, 1)
21+
let workGroupSize = uvec3(WorkGroupSize, 1, 1)
22+
23+
# Initialize buffer
24+
let output = newSeq[int32](NumElements)
25+
26+
# Run reduction on CPU
27+
runComputeOnCpu(
28+
numWorkGroups = numWorkGroups,
29+
workGroupSize = workGroupSize,
30+
compute = calculate,
31+
ssbo = addr output,
32+
args = NumElements
33+
)
34+
35+
assert output == @[3'i32, 0, 0, 6, 0, 0, 1, 1, 0, 12, 0, 0, 15, 0, 9, 10,
36+
0, 0, 21, 0, 0, 24, 0, 18, 27, 0, 0, 30, 0, 0, 25, 25,
37+
0, 36, 0, 0, 39, 0, 33, 34, 0, 0, 45, 0, 0, 48, 0, 42,
38+
51, 0, 0, 54, 0, 0, 49, 49, 0, 60, 0, 0, 63, 0, 57, 58]
39+
40+
# Debug Output:
41+
# - SubgroupID 0
42+
# [Shuffle #1] inputs {t0: 1, t1: 2, t3: 4, t4: 5, t6: 7, t7: 8} | shuffled: [2, 0, 5, 0, 8, 1]
43+
# [Shuffle #1] inputs {t0: 1, t2: 3, t3: 4, t5: 6, t6: 7} | shuffled: [3, 0, 6, 0, 1]
44+
# - SubgroupID 1
45+
# [Shuffle #1] inputs {t1: 10, t2: 11, t4: 13, t5: 14, t7: 16} | shuffled: [11, 0, 14, 0, 0]
46+
# [Shuffle #1] inputs {t0: 9, t1: 10, t3: 12, t4: 13, t6: 15, t7: 16} | shuffled: [0, 12, 0, 15, 9, 10]
47+
# - SubgroupID 2
48+
# [Shuffle #1] inputs {t0: 17, t2: 19, t3: 20, t5: 22, t6: 23} | shuffled: [0, 20, 0, 23, 0]
49+
# [Shuffle #1] inputs {t1: 18, t2: 19, t4: 21, t5: 22, t7: 24} | shuffled: [0, 21, 0, 24, 18]
50+
# Output Buffer:
51+
# [3, 0, 0, 6, 0, 0, 1, 1, 0, 12, 0, 0, 15, 0, 9, 10, 0, 0, 21, 0, 0, 24, 0, 18, 27, 0, 0, ...
52+
# - Matches GLSL shader output.
53+
main()

0 commit comments

Comments
 (0)