Skip to content

Second attempt to improve code clarity #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions src/computesim.nim
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ type
const
MaxConcurrentWorkGroups {.intdefine.} = 2

proc subgroupProc[A, B, C](wg: WorkGroupContext; numActiveThreads: uint32; barrier: BarrierHandle,
proc subgroupProc[A, B, C](wg: WorkGroupContext; subgroupId, numActiveThreads: uint32; barrier: BarrierHandle,
compute: ThreadGenerator[A, B, C]; buffers: A; shared: ptr B; args: C) =
var threads = default(SubgroupThreads)
var threadContexts {.noinit.}: ThreadContexts
let startIdx = wg.gl_SubgroupID * SubgroupSize
let startIdx = subgroupId * SubgroupSize
# Initialize coordinates from startIdx
var x = startIdx mod wg.gl_WorkGroupSize.x
var y = (startIdx div wg.gl_WorkGroupSize.x) mod wg.gl_WorkGroupSize.y
Expand Down Expand Up @@ -154,39 +154,37 @@ proc subgroupProc[A, B, C](wg: WorkGroupContext; numActiveThreads: uint32; barri
for threadId in 0..<numActiveThreads:
threads[threadId] = compute(buffers, shared, args)
# Run threads in lockstep
runThreads(threads, wg, threadContexts, numActiveThreads, barrier)
runThreads(threads, wg, threadContexts, subgroupId, numActiveThreads, barrier)

proc workGroupProc[A, B, C](
workgroupID: UVec3,
wg: WorkGroupContext,
compute: ThreadGenerator[A, B, C],
ssbo: A, smem: ptr B, args: C) =
# Auxiliary proc for work group management
var wg = wg # Shadow for modification
wg.gl_WorkGroupID = workgroupID
let threadsInWorkgroup = wg.gl_WorkGroupSize.x * wg.gl_WorkGroupSize.y * wg.gl_WorkGroupSize.z
let numSubgroups = ceilDiv(threadsInWorkgroup, SubgroupSize)
wg.gl_NumSubgroups = numSubgroups
# Initialize local shared memory
var barrier = createBarrier(numSubgroups)
var barrier = createBarrier(wg.gl_NumSubgroups)
# Create master for managing threads
var master = createMaster(activeProducer = true)
# Calculate total threads in this workgroup
let threadsInWorkgroup = wg.gl_WorkGroupSize.x * wg.gl_WorkGroupSize.y * wg.gl_WorkGroupSize.z
var remainingThreads = threadsInWorkgroup
master.awaitAll:
for subgroupId in 0..<numSubgroups:
wg.gl_SubgroupID = subgroupId
for subgroupId in 0..<wg.gl_NumSubgroups:
# Calculate number of active threads in this subgroup
let threadsInSubgroup = min(remainingThreads, SubgroupSize)
master.spawn subgroupProc(wg, threadsInSubgroup, barrier.getHandle(), compute, ssbo, smem, args)
master.spawn subgroupProc(wg, subgroupId, threadsInSubgroup, barrier.getHandle(), compute, ssbo, smem, args)
dec remainingThreads, threadsInSubgroup

proc runCompute[A, B, C](
numWorkGroups, workGroupSize: UVec3,
compute: ThreadGenerator[A, B, C],
ssbo: A, smem: B, args: C) =
let wg = WorkGroupContext(
let threadsPerWorkgroup = workGroupSize.x * workGroupSize.y * workGroupSize.z
let numSubgroups = ceilDiv(threadsPerWorkgroup, SubgroupSize)
var wg = WorkGroupContext(
gl_NumWorkGroups: numWorkGroups,
gl_WorkGroupSize: workGroupSize
gl_WorkGroupSize: workGroupSize,
gl_NumSubgroups: numSubgroups
)
let totalGroups = wg.gl_NumWorkGroups.x * wg.gl_NumWorkGroups.y * wg.gl_NumWorkGroups.z
let numBatches = ceilDiv(totalGroups, MaxConcurrentWorkGroups)
Expand All @@ -203,7 +201,8 @@ proc runCompute[A, B, C](
master.awaitAll:
var groupIdx: uint32 = 0
while currentGroup < endGroup:
master.spawn workGroupProc(uvec3(wgX, wgY, wgZ), wg, compute, ssbo, addr smemArr[groupIdx], args)
wg.gl_WorkGroupID = uvec3(wgX, wgY, wgZ)
master.spawn workGroupProc(wg, compute, ssbo, addr smemArr[groupIdx], args)
# Increment coordinates, wrapping when needed
inc wgX
if wgX >= wg.gl_NumWorkGroups.x:
Expand Down
2 changes: 1 addition & 1 deletion src/computesim/core.nim
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ proc wait*(m: BarrierHandle) {.inline.} =

type
ThreadClosure* = iterator (iterArg: SubgroupResult, wg: WorkGroupContext,
thread: ThreadContext, threadId: uint32): SubgroupCommand
thread: ThreadContext, subgroupId, threadId: uint32): SubgroupCommand
SubgroupResults* = array[SubgroupSize, SubgroupResult]
SubgroupCommands* = array[SubgroupSize, SubgroupCommand]
SubgroupThreadIDs* = array[SubgroupSize + 1, uint32]
Expand Down
9 changes: 5 additions & 4 deletions src/computesim/lockstep.nim
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ template shouldShowDebugOutput(debug: untyped) =
workGroup.gl_WorkgroupID.x == debugWorkgroupX and
workGroup.gl_WorkgroupID.y == debugWorkgroupY and
workGroup.gl_WorkgroupID.z == debugWorkgroupZ and
workGroup.gl_SubgroupID == debugSubgroupID
subgroupID == debugSubgroupID
else:
false

Expand All @@ -33,7 +33,7 @@ type
running, halted, atSubBarrier, atBarrier, finished

proc runThreads*(threads: SubgroupThreads; workGroup: WorkGroupContext,
threadContexts: ThreadContexts; numActiveThreads: uint32; b: BarrierHandle) =
threadContexts: ThreadContexts; subgroupId, numActiveThreads: uint32; b: BarrierHandle) =
var
anyThreadsActive = true
allThreadsHalted = false
Expand Down Expand Up @@ -64,7 +64,8 @@ proc runThreads*(threads: SubgroupThreads; workGroup: WorkGroupContext,
threadStates[threadId] == running or canReconverge or canPassBarrier:
madeProgress = true
{.cast(gcsafe).}:
commands[threadId] = threads[threadId](results[threadId], workGroup, threadContexts[threadId], threadId)
commands[threadId] = threads[threadId](results[threadId], workGroup,
threadContexts[threadId], subgroupId, threadId)
if finished(threads[threadId]):
threadStates[threadId] = finished
elif commands[threadId].kind == barrier:
Expand Down Expand Up @@ -137,7 +138,7 @@ proc runThreads*(threads: SubgroupThreads; workGroup: WorkGroupContext,
let firstThreadId = threadGroups[groupIdx][1]
let opKind = commands[firstThreadId].kind
let opId = commands[firstThreadId].id
case opKind:
case opKind
of subgroupBroadcast:
execSubgroupOp(execBroadcast)
of subgroupBroadcastFirst:
Expand Down
7 changes: 4 additions & 3 deletions src/computesim/transform.nim
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ proc generateWorkGroupTemplates(wgSym: NimNode): NimNode =
"gl_WorkGroupID",
"gl_WorkGroupSize",
"gl_NumWorkGroups",
"gl_NumSubgroups",
"gl_SubgroupID"
"gl_NumSubgroups"
])

proc generateThreadTemplates(threadSym: NimNode): NimNode =
Expand Down Expand Up @@ -272,6 +271,7 @@ macro computeShader*(prc: untyped): untyped =
# Create symbols for both contexts
let wgSym = genSym(nskParam, "wg")
let threadSym = genSym(nskParam, "thread")
let sidSym = genSym(nskParam, "subgroupId")
let tidSym = genSym(nskParam, "threadId")
# Generate template declarations for both contexts
let wgTemplates = generateWorkGroupTemplates(wgSym)
Expand All @@ -280,7 +280,8 @@ macro computeShader*(prc: untyped): untyped =
result = quote do:
proc `procName`(): ThreadClosure =
iterator (`iterArg`: SubgroupResult, `wgSym`: WorkGroupContext,
`threadSym`: ThreadContext, `tidSym`: uint32): SubgroupCommand =
`threadSym`: ThreadContext, `sidSym`, `tidSym`: uint32): SubgroupCommand =
template gl_SubgroupID(): uint32 {.used.} = `sidSym`
template gl_SubgroupInvocationID(): uint32 {.used.} = `tidSym`
`threadTemplates`
`wgTemplates`
Expand Down