Skip to content

Commit ca3dd49

Browse files
authored
Batch duplicated fn abi (#2375)
* Batch duplicated fn abi * fix
1 parent 658ffd5 commit ca3dd49

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

src/compiler.jl

+12-8
Original file line numberDiff line numberDiff line change
@@ -5074,19 +5074,23 @@ end
50745074
end
50755075
elseif !(FA <: Const)
50765076
argexpr = :(fn.dval)
5077-
if isboxed
5078-
push!(types, Any)
5079-
elseif width == 1
5077+
F_ABI = F
5078+
if width == 1
50805079
if (FA <: MixedDuplicated)
5081-
push!(types, Base.RefValue{F})
5080+
push!(types, Any)
50825081
else
5083-
push!(types, F)
5082+
push!(types, F_ABI)
50845083
end
50855084
else
5086-
if (FA <: BatchMixedDuplicated)
5087-
push!(types, NTuple{width,Base.RefValue{F}})
5085+
if F_ABI <: BatchMixedDuplicated
5086+
F_ABI = Base.RefValue{F_ABI}
5087+
end
5088+
F_ABI = NTuple{width, F_ABI}
5089+
isboxedvec = GPUCompiler.deserves_argbox(F_ABI)
5090+
if isboxedvec
5091+
push!(types, Any)
50885092
else
5089-
push!(types, NTuple{width,F})
5093+
push!(types, F_ABI)
50905094
end
50915095
end
50925096
push!(ccexprs, argexpr)

test/abi.jl

+26
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,29 @@ end
597597
end
598598

599599
include("usermixed.jl")
600+
601+
mutable struct EmptyStruct end
602+
603+
function (uf::EmptyStruct)(du, u, v)
604+
@inbounds du[1] = @inbounds u[1]
605+
return nothing
606+
end
607+
608+
@testset "Batch Duplicated Fn" begin
609+
610+
a = EmptyStruct()
611+
u0 = [1.0 0.5; 0.5 1.0]
612+
du = similar(u0)
613+
614+
batched_result = ([0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0])
615+
batched_seed = ([1.0 0.0; 0.0 0.0], [2.0 0.0; 1.0 0.0])
616+
617+
f!_and_df! = BatchDuplicated(a, ntuple(_ -> Enzyme.make_zero(a), Val(length(batched_result))))
618+
x_and_tx = BatchDuplicated(u0, batched_seed)
619+
y_and_ty = BatchDuplicated(du, batched_result)
620+
621+
autodiff(Forward, f!_and_df!, Const, y_and_ty, x_and_tx, f!_and_df!)
622+
623+
@test batched_result[1][1] 1.0
624+
@test batched_result[2][1] 2.0
625+
end

0 commit comments

Comments
 (0)