diff --git a/docs/src/internal_api.md b/docs/src/internal_api.md index c97b6e034b..2f957203f0 100644 --- a/docs/src/internal_api.md +++ b/docs/src/internal_api.md @@ -7,7 +7,7 @@ without deprecation. ```@autodocs -Modules = [Enzyme.Compiler] +Modules = [Enzyme.Compiler, Enzyme.Compiler.RecursiveMaps] Order = [:module, :type, :constant, :macro, :function] Filter = t -> !(t === Enzyme.Compiler.CheckNan) ``` diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index ef955ebd9b..8c15f21cb7 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,50 +32,11 @@ end end end -@inline function Enzyme.EnzymeCore.make_zero( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}} - return Base.zero(prev)::FT -end -@inline function Enzyme.EnzymeCore.make_zero( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - return Base.zero(prev)::FT -end - -@inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} - return Base.zero(prev)::FT -end -@inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - new = Base.zero(prev)::FT - seen[prev] = new - return new -end - -@inline function Enzyme.EnzymeCore.make_zero!( - prev::FT, seen -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end -@inline function Enzyme.EnzymeCore.make_zero!( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - Enzyme.EnzymeCore.make_zero!(prev, nothing) - return nothing +# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct, +# but in case their dedicated `zero` and `fill!` methods are more efficient than +# `make_zero(!)`s recursion, we opt into treating them as leaves. +@inline function Enzyme.EnzymeCore.isvectortype(::Type{<:StaticArray{S, T}}) where {S, T} + return isbitstype(T) && Enzyme.EnzymeCore.isscalartype(T) end end diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f949664b6a..2dbfdecd2d 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -506,28 +506,131 @@ function autodiff_thunk end function autodiff_deferred_thunk end """ - make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T - -Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies -what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. + make_zero(prev::T; copy_if_inactive = Val(false), runtime_inactive = Val(false))::T + make_zero(prev::T, ::Val{copy_if_inactive}[, ::Val{runtime_inactive}])::T + make_zero( + ::Type{T}, seen::IdDict, prev::T; + copy_if_inactive = Val(false), runtime_inactive = Val(false), + )::T + make_zero( + ::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}[, ::Val{runtime_inactive}] + )::T + +Recursively make a copy of the value `prev::T` in which all differentiable values are zeroed. + +The argument `copy_if_inactive` specifies what to do if the type `T` or any +of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s +instance (if `Val(false)`, the default) or make a copy (if `Val(true)`). + +The argument `runtime_inactive` specifies whether this function should respect runtime +semantics when determining if a type is guaranteed inactive. If `Val(false)`, only the +methods of `EnzymeRules.inactive_type` that were defined at the time of precompiling +`Enzyme` will be taken into account when determining a type's activity. If `Val(true)`, new +or changed methods of `EnzymeRules.inactive_type` will be taken into account as per usual +Julia semantics. + +`copy_if_inactive` and `runtime_inactive` may be provided as either positional or keywords +arguments, but not a combination. + +Extending this method for custom types is rarely needed. If you implement a new type, such +as a GPU array type, for which `make_zero` should directly invoke `zero` for scalar eltypes, +it is sufficient to implement `Base.zero` and make sure your type subtypes `DenseArray`. (If +subtyping `DenseArray` is not appropriate, extend [`EnzymeCore.isvectortype`](@ref) +instead.) """ function make_zero end """ - make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing + make_zero!(val::T, [seen::IdDict]; runtime_inactive = Val(false))::Nothing + make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive})::Nothing + +Recursively set a variable's differentiable values to zero. Only applicable for types `T` +that are mutable or hold all differentiable values in mutable storage (e.g., +`Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over +parts of `val` that are guaranteed to be inactive. + +The argument `runtime_inactive` specifies whether this function should respect runtime +semantics when determining if a type is guaranteed inactive. If `Val(false)`, only the +methods of `EnzymeRules.inactive_type` that were defined at the time of precompiling +`Enzyme` will be taken into account when determining a type's activity. If `Val(true)`, new +or changed methods of `EnzymeRules.inactive_type` will be taken into account as per usual +Julia semantics. + +`runtime_inactive` may be given as either a positional or a keyword argument. -Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. +Extending this method for custom types is rarely needed. If you implement a new mutable +type, such as a GPU array type, for which `make_zero!` should directly invoke +`fill!(x, false)` for scalar eltypes, it is sufficient to implement `Base.zero`, +`Base.fill!`, and make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is +not appropriate, extend [`EnzymeCore.isvectortype`](@ref) instead.) """ function make_zero! end """ - make_zero(prev::T) + isvectortype(::Type{T})::Bool -Helper function to recursively make zero. -""" -@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive} - make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive)) +Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref) +and [`make_zero!`](@ref) recurse through an object. + +By default, `isvectortype(T) == true` when `isscalartype(T) == true` or when +`T <: DenseArray{U}` where `U` is a bitstype and `isscalartype(U) == true`. + +A new vector type, such as a GPU array type, should normally subtype `DenseArray` and +inherit `isvectortype` that way. However if this is not appropariate, `isvectortype` may be +extended directly as follows: + +```julia +@inline function EnzymeCore.isvectortype(::Type{T}) where {T<:NewArray} + U = eltype(T) + return isbitstype(U) && EnzymeCore.isscalartype(U) end +``` + +In either case, the type should implement `Base.zero` and, if mutable, `Base.fill!`. + +Extending `isvectortype` is mostly relevant for the lowest-level of abstraction of memory at +which vector space operations like addition and scalar multiplication are supported, the +prototypical case being `Array`. Regular Julia structs with vector space-like semantics +should normally not extend `isvectorspace`; `make_zero(!)` will recurse into them and act +directly on their backing arrays, just like how Enzyme treats them when differentiating. For +example, structured matrix wrappers and sparse array types that are backed by `Array` should +not extend `isvectortype`. + +See also [`isscalartype`](@ref). +""" +function isvectortype end + +""" + isscalartype(::Type{T})::Bool + +Trait defining a subset of [`isvectortype`](@ref) types that should not be considered +composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero +values of the type in-place. For example, `BigFloat` is a mutable type but does not support +in-place mutation through any Julia API, and `isscalartype(BigFloat) == true` ensures that +`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat] + +By default, `isscalartype(T) == true` and `isscalartype(Complex{T}) == true` for concrete +types where `T <: AbstractFloat`. + +A hypothetical new real number type with Enzyme support should usually subtype +`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate, +the function can be extended as follows: + +```julia +@inline EnzymeCore.isscalartype(::Type{NewReal}) = true +@inline EnzymeCore.isscalartype(::Type{Complex{NewReal}}) = true +``` + +In either case, the type should implement `Base.zero`. + +See also [`isvectortype`](@ref). + +[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is +mentioned here only to demonstrate that it would be inappropriate to use traits like +`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing, showing the +need for a dedicated `isscalartype` trait. +""" +function isscalartype end function tape_type end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index b80075df64..acfefa4b73 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -463,12 +463,8 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) # compute the correct complex derivative in reverse mode by propagating the conjugate return values # then subtracting twice the imaginary component to get the correct result - for (k, v) in seen - Compiler.recursive_accumulate(k, v, refn_seed) - end - for (k, v) in seen2 - Compiler.recursive_accumulate(k, v, imfn_seed) - end + Compiler.accumulate_seen!(refn_seed, seen) + Compiler.accumulate_seen!(imfn_seed, seen2) fused = fuse_complex_results(results, args...) diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index 61d2f35ab7..2f2b8e6ec8 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -427,6 +427,11 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_n return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T} + rt = Enzyme.Compiler.active_reg_inner(T, (), world) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + """ Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) diff --git a/src/compiler.jl b/src/compiler.jl index df233cda67..e1dce04c69 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -315,7 +315,7 @@ const JuliaGlobalNameMap = Dict{String,Any}( include("absint.jl") include("llvm/transforms.jl") include("llvm/passes.jl") -include("typeutils/make_zero.jl") +include("typeutils/recursive_maps.jl") function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 04aca1a66a..1fdd874378 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -253,47 +253,6 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, shadow) end - -@inline function accumulate_into( - into::RT, - seen::IdDict, - from::RT, -)::Tuple{RT,RT} where {RT<:Array} - if Enzyme.Compiler.guaranteed_const(RT) - return (into, from) - end - if !haskey(seen, into) - seen[into] = (into, from) - for i in eachindex(from) - tup = accumulate_into(into[i], seen, from[i]) - @inbounds into[i] = tup[1] - @inbounds from[i] = tup[2] - end - end - return seen[into] -end - -@inline function accumulate_into( - into::RT, - seen::IdDict, - from::RT, -)::Tuple{RT,RT} where {RT<:AbstractFloat} - if !haskey(seen, into) - seen[into] = (into + from, RT(0)) - end - return seen[into] -end - -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT} - if Enzyme.Compiler.guaranteed_const(RT) - return (into, from) - end - if !haskey(seen, into) - throw(AssertionError("Unknown type to accumulate into: $RT")) - end - return seen[into] -end - function EnzymeRules.reverse( config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, @@ -302,15 +261,8 @@ function EnzymeRules.reverse( x::Annotation{Ty}, ) where {RT,Ty} if EnzymeRules.needs_shadow(config) - if EnzymeRules.width(config) == 1 - accumulate_into(x.dval, IdDict(), shadow) - else - for i = 1:EnzymeRules.width(config) - accumulate_into(x.dval[i], IdDict(), shadow[i]) - end - end + Compiler.accumulate_into!(x.dval, shadow) end - return (nothing,) end diff --git a/src/typeutils/make_zero.jl b/src/typeutils/make_zero.jl deleted file mode 100644 index 5c7b49a749..0000000000 --- a/src/typeutils/make_zero.jl +++ /dev/null @@ -1,587 +0,0 @@ -@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{FT,N}, -)::Array{FT,N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{Complex{FT},N}, -)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end - - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - x::GenericMemory{kind, FT}, -)::GenericMemory{kind, FT} where {FT<:AbstractFloat,kind} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::GenericMemory{kind, Complex{FT}}, -)::GenericMemory{kind, Complex{FT}} where {FT<:AbstractFloat,kind} - return Base.zero(x) -end -end - - -@inline function EnzymeCore.make_zero( - ::Type{Array{FT,N}}, - seen::IdDict, - prev::Array{FT,N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{Array{Complex{FT},N}}, - seen::IdDict, - prev::Array{Complex{FT},N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - ::Type{GenericMemory{kind, FT}}, - seen::IdDict, - prev::GenericMemory{kind, FT}, - ::Val{copy_if_inactive} = Val(false), -)::GenericMemory{kind, FT} where {copy_if_inactive,FT<:AbstractFloat,kind} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{GenericMemory{kind, Complex{FT}}}, - seen::IdDict, - prev::GenericMemory{kind, Complex{FT}}, - ::Val{copy_if_inactive} = Val(false), -)::GenericMemory{kind, Complex{FT}} where {copy_if_inactive,FT<:AbstractFloat,kind} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{Complex{RT}}, - seen::IdDict, - prev::Complex{RT}, - ::Val{copy_if_inactive} = Val(false), -)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return Complex{RT}(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Array} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:GenericMemory} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Tuple} - return ntuple(length(prev)) do i - Base.@_inline_meta - EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) - end -end - -@inline function EnzymeCore.make_zero( - ::Type{NamedTuple{A,RT}}, - seen::IdDict, - prev::NamedTuple{A,RT}, - ::Val{copy_if_inactive} = Val(false), -)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - prevtup = RT(prev) - TT = Core.Typeof(prevtup) # RT can be abstract - return NamedTuple{A,RT}(EnzymeCore.make_zero(TT, seen, prevtup, Val(copy_if_inactive))) -end - -@inline function EnzymeCore.make_zero( - ::Type{Core.Box}, - seen::IdDict, - prev::Core.Box, - ::Val{copy_if_inactive} = Val(false), -) where {copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - prev2 = prev.contents - res = Core.Box() - seen[prev] = res - res.contents = EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)) - return res -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT} - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - if haskey(seen, prev) - return seen[prev] - end - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - if ismutable(prev) - y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT - seen[prev] = y - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - T = Core.Typeof(xi) - xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - if Base.isconst(RT, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) - else - setfield!(y, i, xi) - end - end - end - return y - end - if nf == 0 - return prev - end - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) - flds[i] = xi - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) - seen[prev] = y - return y -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - return zero(T) -end - -function make_zero_immutable!( - prev::Complex{T}, - seen::S, -)::Complex{T} where {T<:AbstractFloat,S} - return zero(Complex{T}) -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} - if guaranteed_const_nongen(T, nothing) - return prev # unreachable from make_zero! - end - ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - p = prev[i] - SBT = Core.Typeof(p) - if guaranteed_const_nongen(SBT, nothing) - p # covered by several tests even if not shown in coverage - elseif !ismutabletype(SBT) - make_zero_immutable!(p, seen) - else - EnzymeCore.make_zero!(p, seen) - p - end - end -end - -function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - if guaranteed_const_nongen(NamedTuple{a,b}, nothing) - return prev # unreachable from make_zero! - end - NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i - Base.@_inline_meta - p = prev[a[i]] - SBT = Core.Typeof(p) - if guaranteed_const_nongen(SBT, nothing) - p # covered by several tests even if not shown in coverage - elseif !ismutabletype(SBT) - make_zero_immutable!(p, seen) - else - EnzymeCore.make_zero!(p, seen) - p - end - end) -end - - -function make_zero_immutable!(prev::T, seen::S)::T where {T,S} - if guaranteed_const_nongen(T, nothing) - return prev # unreachable from make_zero! - end - @assert !ismutabletype(T) - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - ST = Core.Typeof(xi) - flds[i] = if guaranteed_const_nongen(ST, nothing) - xi - elseif !ismutabletype(ST) - make_zero_immutable!(xi, seen) - else - EnzymeCore.make_zero!(xi, seen) - xi - end - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)::T -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - prev[] = zero(T) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - prev[] = zero(Complex{T}) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{T,N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(Complex{T})) - return nothing -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,kind,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,kind,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(Complex{T})) - return nothing -end -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, -)::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - continue - elseif !ismutabletype(SBT) - @inbounds prev[I] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - end - end - return nothing -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, Complex{T}}, -)::Nothing where {T<:AbstractFloat, kind} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - continue - elseif !ismutabletype(SBT) - @inbounds prev[I] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - end - end - return nothing -end -end - - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - pv = prev[] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - return nothing - elseif !ismutabletype(SBT) - prev[] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - if prev in seen - return nothing - end - push!(seen, prev) - pv = prev.contents - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - return nothing - elseif !ismutabletype(SBT) - prev.contents = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::T, seen::S)::Nothing where {T,S} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - if nf == 0 - return nothing - end - push!(seen, prev) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - SBT = Core.Typeof(xi) - activitystate = active_reg_inner(SBT, (), nothing) - if activitystate == AnyState # guaranteed_const - continue - elseif ismutabletype(T) && !ismutabletype(SBT) - yi = make_zero_immutable!(xi, seen) - if Base.isconst(T, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi) - else - setfield!(prev, i, yi) - end - elseif activitystate == DupState - EnzymeCore.make_zero!(xi, seen) - else - msg = "cannot set $xi to zero in-place, as it contains differentiable values in immutable positions" - throw(ArgumentError(msg)) - end - end - end - return nothing -end - -@inline EnzymeCore.make_zero!(prev) = EnzymeCore.make_zero!(prev, Base.IdSet()) diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index 039f7d3d0c..c8588df548 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -1,86 +1,117 @@ -# Recursively return x + f(y), where y is active, otherwise x +using .RecursiveMaps: RecursiveMaps, recursive_map, recursive_map!, recursive_map_inner -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T,F,F2} - if forcelhs(T) - return x +""" + recursive_add(x::T, y::T, f = identity, forcelhs = guaranteed_const) + +Recursively construct `z::T` such that `zi = xi + f(yi)` where `zi`, `xi`, and `yi` are +corresponding values from `z`, `x`, and `y`. In other words, this is a recursive +generalization of `x .+ f.(y)`. + +The function `f` must return values of the same type as its argument. + +The optional argument `forcelhs` takes a function such that if `forcelhs(S) == true`, values +`zi::S` will be set to `zi = xi`. The default returns true for non-differentiable (inactive) +types, such that `zi = xi + f(yi)` applies to differentiable values, while `zi = xi` applies +to non-differentiable values. If a custom callable is passed, it is combined with the +default, as `recursive_add` is not generally capable of traversing inactive objects. +""" +function recursive_add( + x::T, y::T, f::F = identity, forcelhs::L = guaranteed_const + ) where {T, F, L} + function addf(xi::S, yi::S) where {S} + @assert EnzymeCore.isvectortype(S) + return (xi + f(yi))::S end - splatnew(T, ntuple(Val(fieldcount(T))) do i - Base.@_inline_meta - prev = getfield(x, i) - next = getfield(y, i) - recursive_add(prev, next, f, forcelhs) - end) + config = RecursiveMaps.InactiveConfig(forcelhs) + return recursive_map(addf, (x, y), config)::T end -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T<:AbstractFloat,F,F2} - if forcelhs(T) - return x - end - return x + f(y) +""" + accumulate_seen!(f, seen::IdDict; runtime_inactive = Val(false)) + accumulate_seen!(f, seen::IdDict, ::Val{runtime_inactive}) + accumulate_seen!( + f, seen::IdDict, config::RecursiveMaps.InactiveConfig = RecursiveMaps.InactiveConfig() + ) + +Recursively accumulate from values into keys, generalizing `key .+= f.(value)` to arbitrary +types. This accumulation is applied to each key-value pair in `seen::IdDict` where each key +is of a mutable or non-isbits vector type and the corresponding value is of the same type +and structure. Typically `seen` is populated by `make_zero`/`recursive_map`, mapping parts +of its input to the corresponding parts of the returned value. + +The recursion stops at objects of types that are themselves cached by +`make_zero`/`recursive_map`, as these objects should have their own entries in `seen`. The +recursion also stops at inactive objects that would be skipped by +`make_zero`/`recursive_map`. + +If the optional argument `::Val{runtime_inactive}` was passed to `make_zero`, or +`config::RecursiveMaps.InactiveConfig` was passed to `recursive_map`, the same value should +be passed to `accumulate_seen` to ensure consistency. +""" +function accumulate_seen! end + +function accumulate_seen!(f::F, seen::IdDict, args::Vararg{Any, M}; kws...) where {F, M} + accumulate_seen!(f, seen, RecursiveMaps.make_zero_config!(args...; kws...)) + return nothing end -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T<:Complex,F,F2} - if forcelhs(T) - return x +function accumulate_seen!(f::F, seen::IdDict, config::RecursiveMaps.InactiveConfig) where {F} + cachedconfig = RecursiveMaps.InactiveConfig(config, RecursiveMaps.iscachedtype) + for (k, v) in seen + _accumulate_seen_item!(f, k, v, config, cachedconfig) end - return x + f(y) + return nothing end -@inline mutable_register(::Type{T}) where {T<:Integer} = true -@inline mutable_register(::Type{T}) where {T<:AbstractFloat} = false -@inline mutable_register(::Type{Complex{T}}) where {T<:AbstractFloat} = false -@inline mutable_register(::Type{T}) where {T<:Tuple} = false -@inline mutable_register(::Type{T}) where {T<:NamedTuple} = false -@inline mutable_register(::Type{Core.Box}) = true -@inline mutable_register(::Type{T}) where {T<:Array} = true -@inline mutable_register(::Type{T}) where {T} = ismutabletype(T) - -# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) -@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F = identity) where {T,F} - if !mutable_register(T) - for I in eachindex(x) - prev = x[I] - @inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register) - end +function _accumulate_seen_item!(f::F, k::T, v::T, config, cachedconfig) where {F, T} + function addf!!(ki::S, vi::S) where {S} + @assert EnzymeCore.isvectortype(S) + return (ki .+ f.(vi))::S + end + function addf!!(ki::S, _ki::S, vi::S) where {S} + @assert !EnzymeCore.isscalartype(S) + @assert EnzymeCore.isvectortype(S) + @assert ki === _ki + ki .+= f.(vi) + return ki::S end + RecursiveMaps.check_nonactive(T, config) + if !RecursiveMaps.isinactivetype(T, config) + newk = recursive_map_inner(nothing, addf!!, Some(k), (k, v), cachedconfig) + @assert newk === k + end + return nothing end +""" + accumulate_into!(into::T, from::T) -# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) -@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F = identity) where {F} - recursive_accumulate(x.contents, y.contents, seen, f) -end +Recursively accumulate from `from` into `into` and zero `from`, such that `into_i += from_i` +and `from_i = 0`, where `into_i` and `from_i` are corresponding values within `into` and +`from`. In other words, this is a recursive generalization of -@inline function recursive_accumulate(x::T, y::T, f::F = identity) where {T,F} - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) +```julia +into .+= from +from .= 0 +``` - for i = 1:nf - if isdefined(x, i) - xi = getfield(x, i) - ST = Core.Typeof(xi) - if !mutable_register(ST) - @assert ismutable(x) - yi = getfield(y, i) - nexti = recursive_add(xi, yi, f, mutable_register) - setfield!(x, i, nexti) - end - end +The accumulation and zeroing is only applied to differentiable values; non-differentiable +values within both `into` and `from` are left untouched. +""" +function accumulate_into!(into::T, from::T) where {T} + # may not show in coverage but both base cases are covered via deepcopy custom rule tests + function accumulate_into!!(into_i::S, from_i::S) where {S} + @assert EnzymeCore.isvectortype(S) + return (into_i + from_i)::S + end + function accumulate_into!!(into_i::S, _into_i::S, from_i::S) where {S} + @assert !EnzymeCore.isscalartype(S) + @assert EnzymeCore.isvectortype(S) + @assert into_i === _into_i + into_i .+= from_i + return into_i::S end + recursive_map!(accumulate_into!!, into, (into, from)) + make_zero!(from) + return nothing end diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl new file mode 100644 index 0000000000..1024cad717 --- /dev/null +++ b/src/typeutils/recursive_maps.jl @@ -0,0 +1,683 @@ +module RecursiveMaps + +using EnzymeCore: EnzymeCore, isvectortype, isscalartype +using ..Compiler: guaranteed_const, guaranteed_const_nongen, guaranteed_nonactive, + guaranteed_nonactive_nongen + +### Config type for setting inactive/nonactive options +""" + config = InactiveConfig( + extra = (T -> false); copy_if_inactive = Val(false), runtime_inactive = Val(false) + ) + config = InactiveConfig{copy_if_inactive::Bool, runtime_inactive::Bool}(extra) + newconfig = InactiveConfig(config::InactiveConfig, extra) + +Config type for specifying which parts of objects should be skipped by `recursive_map{!}`. + +At a minimum, parts that Enzyme always considers inactive are skipped. An inactive type is a +type for which Enzyme can prove that a differentiable value can never be reached from any +instance of the type. + +The optional argument `extra` takes a function defining additional types that should be +skipped regardless of their nominal activity. `extra` should be a plain function +or callable of a singleton type, not a closure or otherwise stateful callable; this is to +ensure that an `InactiveConfig` instance is fully specified by its type. + +The parameter `copy_if_inactive` specifies whether `recursive_map{!}` should share (if +`Val(false)`, the default) or deep-copy (if `Val(true)`) inactive/skipped parts from inputs +to outputs. + +The parameter `runtime_inactive` specifies whether `recursive_map{!}` should respect runtime +semantics when determining if a type is guaranteed inactive. If `Val(false)`, guaranteed +inactivity is determined by `active_reg_nothrow`, which is a generated function and thus +frozen in the precompilation world age; this means that methods added to +`EnzymeRules.inactive_type` after `Enzyme` precompilation are not respected. If `Val(true)`, +the generated function is not used and changes to the `EnzymeRules.inactive_type` method +table are picked up through invalidation as usual. + +Using `runtime_inactive = Val(false)` may be preferred in interactive sessions or if +`EnzymeRules.inactive_type` is extended in downstream packages or package extensions. +However, performance may suffer if the activity of every type cannot be resolved at compile +time, so `runtime_inactive = Val(true)` is preferable when possible and is the default. + +The updating constructor `InactiveConfig(config::InactiveConfig, extra)` returns a new +config that extends `config` with an additional `extra` function. +""" +struct InactiveConfig{copy_if_inactive, runtime_inactive, E} + extra::E + function InactiveConfig{C, R}(extra::E) where {C, R, E} + @assert Base.issingletontype(E) + return new{C::Bool, R::Bool, E}(extra) + end +end + +function InactiveConfig( + extra::E = (_ -> (@nospecialize; false)); + copy_if_inactive::Val{C} = Val(false), runtime_inactive::Val{R} = Val(false), + ) where {E, C, R} + return InactiveConfig{C, R}(extra) +end + +function InactiveConfig(config::InactiveConfig{C, R}, extra::E) where {C, R, E} + @inline combinedextra(::Type{T}) where {T} = (config.extra(T) || extra(T)) + return InactiveConfig{C, R}(combinedextra) +end + +function isinactivetype(::Type{T}, config::InactiveConfig{C, false}) where {T, C} + return guaranteed_const(T) || config.extra(T) # call guaranteed_const first, as this is a constant at runtime +end +function isinactivetype(::Type{T}, config::InactiveConfig{C, true}) where {T, C} + return config.extra(T) || guaranteed_const_nongen(T, nothing) # call config.extra first, as guaranteed_const_nongen may incur runtime dispatch +end + +function isnonactivetype(::Type{T}, config::InactiveConfig{C, false}) where {T, C} + return guaranteed_nonactive(T) || config.extra(T) # call guaranteed_const first, as this is a constant at runtime +end +function isnonactivetype(::Type{T}, config::InactiveConfig{C, true}) where {T, C} + return config.extra(T) || guaranteed_nonactive_nongen(T, nothing) # call config.extra first, as guaranteed_nonactive_nongen may incur runtime dispatch +end + +### traits defining active leaf types for recursive_map +@inline EnzymeCore.isvectortype(::Type{T}) where {T} = isscalartype(T) +@inline function EnzymeCore.isvectortype(::Type{<:DenseArray{U}}) where {U} + return isbitstype(U) && isscalartype(U) +end + +@inline EnzymeCore.isscalartype(::Type) = false +@inline EnzymeCore.isscalartype(::Type{T}) where {T <: AbstractFloat} = isconcretetype(T) +@inline function EnzymeCore.isscalartype(::Type{Complex{T}}) where {T <: AbstractFloat} + return isconcretetype(T) +end + +### recursive_map: walk arbitrary objects and map a function over scalar and vector leaves +""" + newy = recursive_map( + [seen::Union{Nothing, IdDict},] + f, + [y::T,] + xs::NTuple{N, T}, + config::InactiveConfig = InactiveConfig(), + )::T + +Recurse through `N` objects `xs = (x1::T, x2::T, ..., xN::T)` of the same type, mapping the +function `f` over every differentiable value encountered and constructing a new object +`newy::T` from the resulting values `newy_i = f(x1_i, ..., xN_i)`. + +The trait [`EnzymeCore.isvectortype`](@ref) determines which values are considered leaf +nodes at which to terminate recursion and invoke `f`. See the docstring for +[`EnzymeCore.isvectortype`](@ref) and the related [`EnzymeCore.isscalartype`](@ref) for more +information. + +An existing object `y::T` may be passed, in which case it is updated "partially-in-place": +any parts of `y` that are mutable or non-differentiable are reused in the returned object +`newy`, while immutable differentiable parts are handled out-of-place as if `y` were not +passed. If `T` itself is a mutable type, `y` is modified fully in-place and returned, such +that `newy === y`. + +The recursion and mapping operate on the structure of `T` as defined by struct fields and +plain array elements, not on the values provided through iteration or array interfaces. For +example, given a structured matrix wrapper or sparse array type, this function recurses into +the struct type and operates on the plain arrays held within, rather than operating on the +array that the type notionally represents. + +# Arguments + +* `seen::Union{IdDict, Nothing}` (optional): Dictionary for tracking object identity as + needed to reproduce the object reference graph topology of the `xs` when constructing `y`, + including cycles (i.e., recursive substructures) and convergent paths. If not provided, an + `IdDict` will be allocated internally if required. + + If `nothing` is provided, object identity tracking is turned off. Objects with multiple + references are then duplicated such that the graph of object references within `newy` + becomes a tree. Note that any cycles in the `xs` will result in infinite recursion + and stack overflow. + +* `f`: Function mapping leaf nodes within the `xs` to the corresponding leaf node in `newy`, + that is, `newy_i = f(x1_i::U, ..., xN_i::U)::U`. The function `f` must be applicable to + the type of every leaf node, and must return a value of the same type as its arguments. + + When an existing object `y` is provided and contains leaf nodes of a non-isbits non-scalar + type `U`, `f` should also have a partially-in-place method + `newy_i = f(y_i::U, x1_i::U, ..., xN_i::U)::U` that modifies and reuses any mutable parts + of `y_i`; in particular, if `U` is a mutable type, this method should return + `newy_i === y_i`. + + If a non-isbits leaf type `U` must always be handled using the out-of-place signature, + define the method `EnzymeCore.isscalartype(::Type{U}) = true`. + + See [`EnzymeCore.isvectortype`](@ref) and [`EnzymeCore.isscalartype`](@ref) for more + details about leaf types and scalar types. + +* `y::T` (optional): Instance from which to reuse mutable and non-differentiable parts when + mapping (partially) in-place. + +* `xs::NTuple{N, T}`: Tuple of `N` objects of the same type `T`. + + The first object `x1 = first(xs)` is the reference for graph structure and + non-differentiable values when constructing the returned object. In particular: + * When `y` is not provided, non-differentiable values within `newy` are shared with/copied + from `x1`. + * When `y` is provided, its non-differentiable values are kept unchanged, unless they are + uninitialized, in which case they are shared with/copied from from `x1`. + * The graph topology of object references in `x1` is the one which is reproduced in the + returned object. Hence, for each instance of cycles and converging paths within `x1`, + the same structure must be present in the other objects `x2, ..., xN`, otherwise the + corresponding values in `newy` would not be uniquely defined. However, `x2, ..., xN` may + contain additional cycles and converging paths that are not present in `x1`; these do + not affect the structure of `newy`. + * If any values within `x1` are not initialized (that is, struct fields are undefined or + array elements are unassigned), they remain uninitialized in `newy`. If any such values + are mutable and `y` is provided, the corresponding value in `y` must not already be + initialized, since initialized values cannot be nulled. Conversely, for every value in + `x1` that is initialized, the corresponding values in `x2, ..., xN` must also be + initialized, such that the corresponding value in `newy` can be computed. However, + `x2, ..., xN` may have initialized values where `x1` has uninitialized values; these + will remain uninitialized in `newy`. + +* `config::InactiveConfig` (optional): Config object detailing how to deal with + non-differentiable (inactive) parts. The config specifies whether non-differentiable parts + should be shared or deep-copied from `x1` to `newy`, and whether any additional types + should be skipped in addition to those Enzyme always considers inactive. See + [`InactiveConfig`](@ref) for details. +""" +function recursive_map end + +const Maybe{T} = Union{Nothing, Some{T}} + +## entry points: set default arguments, deal with nothing/Some +function recursive_map(f::F, xs::NTuple, config::InactiveConfig = InactiveConfig()) where {F} + return recursive_map_main(f, nothing, xs, config) +end + +function recursive_map( + f::F, y::T, xs::NTuple{N, T}, config::InactiveConfig = InactiveConfig() + ) where {F, N, T} + return recursive_map_main(f, Some(y), xs, config) +end + +function recursive_map( + seen::Union{Nothing, IdDict}, + f::F, + xs::NTuple, + config::InactiveConfig = InactiveConfig(), + ) where {F} + return recursive_map_main(seen, f, nothing, xs, config) +end + +function recursive_map( + seen::Union{Nothing, IdDict}, + f::F, + y::T, + xs::NTuple{N, T}, + config::InactiveConfig = InactiveConfig(), + ) where {F, N, T} + return recursive_map_main(seen, f, Some(y), xs, config) +end + +## main dispatcher: allocate IdDict if needed, exit early if possible +function recursive_map_main( + f::F, maybe_y::Maybe{T}, xs::NTuple{N, T}, config::InactiveConfig + ) where {F, N, T} + newy = if isinactivetype(T, config) + recursive_map_inactive(nothing, maybe_y, xs, config) + elseif isvectortype(T) || isbitstype(T) + recursive_map_inner(nothing, f, maybe_y, xs, config) + else + recursive_map_inner(IdDict(), f, maybe_y, xs, config) + end + return newy::T +end + +## recursive methods +function recursive_map_main( + seen::Union{Nothing, IdDict}, + f::F, + maybe_y::Maybe{T}, + xs::NTuple{N, T}, + config::InactiveConfig, + ) where {F, N, T} + # determine whether to continue recursion, copy/share, or retrieve from cache + newy = if isinactivetype(T, config) + recursive_map_inactive(seen, maybe_y, xs, config) + elseif isbitstype(T) # no object identity to to track in this branch + recursive_map_inner(nothing, f, maybe_y, xs, config) + elseif hascache(seen, xs) + getcached(seen, xs) + else + recursive_map_inner(seen, f, maybe_y, xs, config) + end + return newy::T +end + +@inline function recursive_map_inner( + seen, f::F, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T} + # forward to appropriate handler for leaf vs. mutable vs. immutable type + @assert !isabstracttype(T) + @assert isconcretetype(T) + newy = if isvectortype(T) + recursive_map_leaf(seen, f, maybe_y, xs) + elseif ismutabletype(T) + recursive_map_mutable(seen, f, maybe_y, xs, config) + else + recursive_map_immutable(seen, f, maybe_y, xs, config) + end + return newy::T +end + +@inline function recursive_map_mutable( + seen, f::F, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T} + @assert ismutabletype(T) + if isnothing(maybe_y) && !(T <: DenseArray) && all(isbitstype, fieldtypes(T)) + # fast path for out-of-place handling when all fields are bitstypes, which rules + # out undefined fields and circular references + newy = recursive_map_new(seen, f, nothing, xs, config) + maybecache!(seen, newy, xs) + else + newy = if isnothing(maybe_y) + _similar(first(xs)) + else + something(maybe_y) + end + maybecache!(seen, newy, xs) + recursive_map_mutable_inner!(seen, f, newy, maybe_y, xs, config) + end + return newy::T +end + +@inline function recursive_map_mutable_inner!( + seen, f::F, newy::T, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T <: DenseArray} + if isbitstype(eltype(T)) + if isnothing(maybe_y) + broadcast!(newy, xs...) do xs_i... + recursive_map_main(nothing, f, nothing, xs_i, config) + end + else + broadcast!(newy, something(maybe_y), xs...) do y_i, xs_i... + recursive_map_main(nothing, f, Some(y_i), xs_i, config) + end + end + else + @inbounds for i in eachindex(newy, xs...) + recursive_map_item!(i, seen, f, newy, maybe_y, xs, config) + end + end + return nothing +end + +@generated function recursive_map_mutable_inner!( + seen, f::F, newy::T, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T} + return quote + @inline + Base.Cartesian.@nexprs $(fieldcount(T)) i -> @inbounds begin + recursive_map_item!(i, seen, f, newy, maybe_y, xs, config) + end + return nothing + end +end + +@inline function recursive_map_immutable( + seen, f::F, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T} + @assert !ismutabletype(T) + nf = fieldcount(T) + if nf == 0 # nothing to do (also no known way to hit this branch) + newy = recursive_map_inactive(seen, maybe_y, xs, config) + else + newy = if isinitialized(first(xs), nf) # fast path when all fields are defined + check_allinitialized(Base.tail(xs), nf) + recursive_map_new(seen, f, maybe_y, xs, config) + else + recursive_map_immutable_inner(seen, f, maybe_y, xs, config) + end + # maybecache! _should_ be a no-op here; call it anyway for consistency + maybecache!(seen, newy, xs) + end + return newy::T +end + +@generated function recursive_map_immutable_inner( + seen, f::F, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T} + nf = fieldcount(T) + return quote + @inline + x1, xtail = first(xs), Base.tail(xs) + fields = Vector{Any}(undef, $(nf - 1)) + Base.Cartesian.@nexprs $(nf - 1) i -> begin # unrolled loop over struct fields + @inbounds if isinitialized(x1, i) + check_allinitialized(xtail, i) + fields[i] = recursive_map_item(i, seen, f, maybe_y, xs, config) + else + return new_structv(T, fields, i - 1) + end + end + @assert !isinitialized(x1, $nf) + return new_structv(T, fields, $(nf - 1)) + end +end + +@generated function recursive_map_new( + seen, f::F, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T} + # direct construction of fully initialized non-cyclic structs + nf = fieldcount(T) + return quote + @inline + fields = Base.@ntuple $nf i -> @inbounds begin + recursive_map_item(i, seen, f, maybe_y, xs, config) + end + newy = $(Expr(:splatnew, :T, :fields)) + return newy::T + end +end + +Base.@propagate_inbounds function recursive_map_item!( + i, seen, f::F, newy::T, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T} + if isinitialized(first(xs), i) + check_allinitialized(Base.tail(xs), i) + setitem!(newy, i, recursive_map_item(i, seen, f, maybe_y, xs, config)) + elseif !isnothing(maybe_y) + check_initialized(something(maybe_y), i, false) + end + return nothing +end + +Base.@propagate_inbounds function recursive_map_item( + i, seen, f::F, maybe_y::Maybe{T}, xs::NTuple{N, T}, config + ) where {F, N, T} + # recurse into the xs and apply recursive_map to items with index i + maybe_y_i = if isnothing(maybe_y) || !isinitialized(something(maybe_y), i) + nothing + else + Some(getitem(something(maybe_y), i)) + end + return recursive_map_barrier(seen, f, maybe_y_i, config, getitems(xs, i)...) +end + +# function barrier such that abstractly typed items trigger minimal runtime dispatch +# the idea is that SROA can eliminate the xs_i tuple in the above function, since it's +# splatted directly into a call; thus, instead of a dynamic dispatch to the Tuple +# constructor followed by a dynamic dispatch to recursive_map, we only incur a single +# dynamic dispatch to recursive_map_barrier +function recursive_map_barrier( + seen, f::F, maybe_y_i::Maybe{ST}, config::InactiveConfig, xs_i::Vararg{ST, N} + ) where {F, N, ST} + return recursive_map_main(seen, f, maybe_y_i, xs_i, config)::ST +end + +## recursion base case handlers +@inline function recursive_map_leaf( + seen, f::F, maybe_y::Maybe{T}, xs::NTuple{N, T} + ) where {F, N, T} + # apply the mapped function to leaf values + if isnothing(maybe_y) || isbitstype(T) || isscalartype(T) + newy = f(xs...)::T + else # !isbitstype(T) + y = something(maybe_y) + newy = f(y, xs...)::T + if ismutabletype(T) + @assert newy === y + end + end + maybecache!(seen, newy, xs) + return newy::T +end + +@inline function recursive_map_inactive( + seen, maybe_y::Maybe{T}, (x1,)::NTuple{N, T}, ::InactiveConfig{copy_if_inactive} + ) where {N, T, copy_if_inactive} + newy = if !isnothing(maybe_y) + something(maybe_y) + elseif copy_if_inactive && !isbitstype(T) + if isnothing(seen) + deepcopy(x1) + else + Base.deepcopy_internal(x1, seen) + end + else + x1 + end + return newy::T +end + +### recursive_map!: fully in-place wrapper around recursive_map +""" + recursive_map!( + [seen::Union{Nothing, IdDict},] + f!!, + y::T, + xs::NTuple{N, T}, + isinactivetype::InactiveConfig = InactiveConfig(), + )::Nothing + +Recurse through `N` objects `xs = (x1::T, x2::T, ..., xN::T)` of the same type, mapping the +function `f!!` over every differentiable value encountered and updating `y::T` in-place with +the resulting values. + +This is a simple wrapper that verifies that `T` is a type where all differentiable values +can be updated in-place, calls `recursive_map`, and verifies that the returned value is +indeed identically the same object `y`. See [`recursive_map`](@ref) for details. +""" +function recursive_map! end + +function recursive_map!( + f!!::F, y::T, xs::NTuple{N, T}, config::InactiveConfig = InactiveConfig() + ) where {F, N, T} + check_nonactive(T, config) + newy = recursive_map(f!!, y, xs, config) + @assert newy === y + return nothing +end + +function recursive_map!( + seen::Union{Nothing, IdDict}, + f!!::F, + y::T, + xs::NTuple{N, T}, + config::InactiveConfig = InactiveConfig(), + ) where {F, N, T} + check_nonactive(T, config) + newy = recursive_map(seen, f!!, y, xs, config) + @assert newy === y + return nothing +end + +### recursive_map helpers +@generated function new_structv(::Type{T}, fields::Vector{Any}, nfields_) where {T} + return quote + @inline + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, fields, nfields_)::T + end +end + +@inline _similar(::T) where {T} = ccall(:jl_new_struct_uninit, Any, (Any,), T)::T +@inline _similar(x::T) where {T <: DenseArray} = similar(x)::T +Base.@propagate_inbounds isinitialized(x, i) = isdefined(x, i) +Base.@propagate_inbounds isinitialized(x::DenseArray, i) = isassigned(x, i) +Base.@propagate_inbounds getitem(x, i) = getfield(x, i) +Base.@propagate_inbounds getitem(x::DenseArray, i) = x[i] +Base.@propagate_inbounds setitem!(x, i, v) = setfield_force!(x, i, v) +Base.@propagate_inbounds setitem!(x::DenseArray, i, v) = (x[i] = v; nothing) + +Base.@propagate_inbounds function setfield_force!(x::T, i, v) where {T} + if Base.isconst(T, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i - 1, v) + else + setfield!(x, i, v) + end + return nothing +end + +Base.@propagate_inbounds function getitems( + (x1, xtail...)::Tuple{T, T, Vararg{T, N}}, i + ) where {T, N} + return (getitem(x1, i), getitems(xtail, i)...) +end + +Base.@propagate_inbounds getitems((x1,)::Tuple{T}, i) where {T} = (getitem(x1, i),) + +## cache (seen) helpers +@inline function iscachedtype(::Type{T}) where {T} + # cache all mutable types and any non-isbits types that are also leaf types + return ismutabletype(T) || ((!isbitstype(T)) && isvectortype(T)) +end + +@inline shouldcache(::IdDict, ::Type{T}) where {T} = iscachedtype(T) +@inline shouldcache(::Nothing, ::Type{T}) where {T} = false + +@inline function maybecache!(seen, newy::T, (x1, xtail...)::NTuple{N, T}) where {N, T} + if shouldcache(seen, T) + seen[x1] = if (N == 1) + newy + else + (newy, xtail...) + end + end + return nothing +end + +@inline function hascache(seen, (x1,)::NTuple{N, T}) where {N, T} + return shouldcache(seen, T) ? haskey(seen, x1) : false +end + +@inline function getcached(seen::IdDict, (x1, xtail...)::NTuple{N, T}) where {N, T} + newy = if (N == 1) + seen[x1]::T + else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented + cache = seen[x1]::NTuple{N, T} + cachedtail = cache[2:end] + check_identical(cachedtail, xtail) # check compatible structure + cache[1] + end + return newy::T +end + +## argument validation +Base.@propagate_inbounds function check_initialized(x, i, initialized = true) + if isinitialized(x, i) != initialized + throw_initialized() # TODO: hit this when VectorSpace implemented + end + return nothing +end + +Base.@propagate_inbounds function check_allinitialized( # TODO: hit this when VectorSpace implemented + (x1, xtail...)::Tuple{T, T, Vararg{T, N}}, i, initialized = true + ) where {T, N} + check_initialized(x1, i, initialized) + check_allinitialized(xtail, i, initialized) + return nothing +end + +Base.@propagate_inbounds function check_allinitialized( + (x1,)::Tuple{T}, i, initialized = true + ) where {T} + check_initialized(x1, i, initialized) + return nothing +end + +Base.@propagate_inbounds check_allinitialized(::Tuple{}, i, initialized = true) = nothing + +@inline function check_identical(u, v) # TODO: hit this when VectorSpace implemented + if u !== v + throw_identical() + end + return nothing +end + +@inline function check_nonactive(::Type{T}, config) where {T} + if !isnonactivetype(T, config) + throw_nonactive() + end + return nothing +end + +# TODO: hit all of these via check_* when VectorSpace implemented +@noinline function throw_initialized() + msg = "recursive_map(!) called on objects whose undefined fields/unassigned elements " + msg *= "don't line up" + throw(ArgumentError(msg)) +end + +@noinline function throw_identical() + msg = "recursive_map(!) called on objects whose structure don't match" + throw(ArgumentError(msg)) +end + +@noinline function throw_nonactive() + msg = "recursive_map! called on objects containing immutable differentiable values" + throw(ArgumentError(msg)) +end + +### EnzymeCore.make_zero(!) implementation +@inline function EnzymeCore.make_zero(prev::T, args::Vararg{Any, M}; kws...) where {T, M} + config = make_zero_config(args...; kws...) + new = if iszero(M) && isempty(kws) && !isinactivetype(T, config) && isvectortype(T) # fallback + # isinactivetype has precedence over isvectortype for consistency with recursive_map + convert(T, zero(prev)) # convert because zero(prev)::T may not hold when eltype(T) is abstract + else + recursive_map(_make_zero!!, (prev,), config)::T + end + return new::T +end + +@inline function EnzymeCore.make_zero!(val::T, allargs::Vararg{Any, M}; kws...) where {T, M} + @assert !isscalartype(T) # not appropriate for in-place handler + seen, args = if (M > 0) && (first(allargs) isa IdDict) + first(allargs), Base.tail(allargs) + else + nothing, allargs + end + config = make_zero_config!(args...; kws...) + if iszero(M) && isempty(kws) && !isinactivetype(T, config) && isvectortype(T) # fallback + # isinactivetype has precedence over isvectortype for consistency with recursive_map + fill!(val, false) + elseif isnothing(seen) + recursive_map!(_make_zero!!, val, (val,), config) + else + recursive_map!(seen, _make_zero!!, val, (val,), config) + end + return nothing +end + +# map make_zero(!) args/kws to config +@inline make_zero_config(C) = InactiveConfig(; copy_if_inactive = C) +@inline make_zero_config(C, R) = InactiveConfig(; copy_if_inactive = C, runtime_inactive = R) +@inline make_zero_config(; kws...) = InactiveConfig(; kws...) + +@inline make_zero_config!(R) = InactiveConfig(; runtime_inactive = R) +@inline function make_zero_config!(; runtime_inactive = nothing) + if isnothing(runtime_inactive) + return InactiveConfig() + else + return InactiveConfig(; runtime_inactive) + end +end + +# the mapped function: assert leaf type and call back into single-arg make_zero(!) +function _make_zero!!(prev::T) where {T} + @assert isvectortype(T) # otherwise infinite loop + return EnzymeCore.make_zero(prev)::T +end + +function _make_zero!!(val::T, _val::T) where {T} + @assert !isscalartype(T) # not appropriate for in-place handler + @assert isvectortype(T) # otherwise infinite loop + @assert val === _val + EnzymeCore.make_zero!(val) + return val::T +end + +# alternative entry point for passing custom IdDict +@inline function EnzymeCore.make_zero( + ::Type{T}, seen::IdDict, prev::T, args::Vararg{Any, M}; kws... + ) where {T, M} + new = recursive_map(seen, _make_zero!!, (prev,), make_zero_config(args...; kws...)) + return new::T +end + +end # module RecursiveMaps diff --git a/test/Project.toml b/test/Project.toml index fbc6d754fe..667d94ba1f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/make_zero.jl b/test/make_zero.jl deleted file mode 100644 index cbe2f2159f..0000000000 --- a/test/make_zero.jl +++ /dev/null @@ -1,725 +0,0 @@ -module MakeZeroTests - -using Enzyme -using StaticArrays -using Test - -# Universal getters/setters for built-in and custom containers/wrappers -getx(w::Base.RefValue) = w[] -getx(w::Core.Box) = w.contents -getx(w) = first(w) -gety(w) = last(w) - -setx!(w::Base.RefValue, x) = (w[] = x) -setx!(w::Core.Box, x) = (w.contents = x) -setx!(w, x) = (w[begin] = x) -sety!(w, y) = (w[end] = y) - -# non-isbits MArray doesn't support setindex!, so requires a little hack -function setx!(w::MArray{S,T}, x) where {S,T} - if isbitstype(T) - w[begin] = x - else - w.data = (x, Base.tail(w.data)...) - end - return x -end - -function sety!(w::MArray{S,T}, y) where {S,T} - if isbitstype(T) - w[end] = y - else - w.data = (Base.front(w.data)..., y) - end - return y -end - -struct Empty end - -mutable struct MutableEmpty end - -Base.:(==)(::MutableEmpty, ::MutableEmpty) = true - -struct Wrapper{T} - x::T -end - -Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) -getx(a::Wrapper) = a.x - -mutable struct MutableWrapper{T} - x::T -end - -Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) - -getx(a::MutableWrapper) = a.x -setx!(a::MutableWrapper, x) = (a.x = x) - -struct DualWrapper{Tx,Ty} - x::Tx - y::Ty -end - -DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) - -function Base.:(==)(a::DualWrapper, b::DualWrapper) - return (a === b) || ((a.x == b.x) && (a.y == b.y)) -end - -getx(a::DualWrapper) = a.x -gety(a::DualWrapper) = a.y - -mutable struct MutableDualWrapper{Tx,Ty} - x::Tx - y::Ty -end - -MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) - -function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) - return (a === b) || ((a.x == b.x) && (a.y == b.y)) -end - -getx(a::MutableDualWrapper) = a.x -gety(a::MutableDualWrapper) = a.y - -setx!(a::MutableDualWrapper, x) = (a.x = x) -sety!(a::MutableDualWrapper, y) = (a.y = y) - -struct Incomplete{T} - s::String - x::Float64 - w::T - z # not initialized - Incomplete(s, x, w) = new{typeof(w)}(s, x, w) -end - -function Base.:(==)(a::Incomplete, b::Incomplete) - (a === b) && return true - ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false - if isdefined(a, :z) && isdefined(b, :z) - (a.z == b.z) || return false - elseif isdefined(a, :z) || isdefined(b, :z) - return false - end - return true -end - -mutable struct MutableIncomplete{T} - s::String - const x::Float64 - y::Float64 - z # not initialized - w::T - function MutableIncomplete(s, x, y, w) - ret = new{typeof(w)}(s, x, y) - ret.w = w - return ret - end -end - -function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) - (a === b) && return true - if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) - return false - end - if isdefined(a, :z) && isdefined(b, :z) - (a.z == b.z) || return false - elseif isdefined(a, :z) || isdefined(b, :z) - return false - end - return true -end - -mutable struct CustomVector{T} <: AbstractVector{T} - data::Vector{T} -end - -Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) - -function Enzyme.EnzymeCore.make_zero( - ::Type{CV}, seen::IdDict, prev::CV, ::Val{copy_if_inactive} -) where {CV<:CustomVector{<:AbstractFloat},copy_if_inactive} - @info "make_zero(::CustomVector)" - if haskey(seen, prev) - return seen[prev] - end - new = CustomVector(zero(prev.data))::CV - seen[prev] = new - return new -end - -function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}, seen)::Nothing - @info "make_zero!(::CustomVector)" - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev.data, false) - return nothing -end - -function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) - return Enzyme.EnzymeCore.make_zero!(prev, nothing) -end - -struct WithIO{F} # issue 2091 - v::Vector{Float64} - callback::F - function WithIO(v, io) - callback() = println(io, "hello") - return new{typeof(callback)}(v, callback) - end -end - -macro test_noerr(expr) - return quote - @test_nowarn try - # catch errors to get failed test instead of "exception outside of a @test" - $(esc(expr)) - catch e - showerror(stderr, e) - end - end -end - -const scalartypes = [Float32, ComplexF32, Float64, ComplexF64] - -const inactivetup = ("a", Empty(), MutableEmpty()) -const inactivearr = [inactivetup] - -const wrappers = [ - (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true), - (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true), - (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true), - - (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false), - (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false), - - (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true), - (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true), - (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true), - - (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false), - (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false), - (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false), - (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false), - - (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true), - (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true), - (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true), - - (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial), - (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial), - - (name="@NamedTuple{x,y}", f=@NamedTuple{x,y} ∘ tuple, N=2, mutable=false, typed=false), - (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false), - - (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true), - - (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted), - (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial), - - (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false), - (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false), - - # StaticArrays extension - (name="SVector{1,X}", f=SVector{1} ∘ tuple, N=1, mutable=false, typed=true), - (name="SVector{1,Any}", f=SVector{1,Any} ∘ tuple, N=1, mutable=false, typed=false), - (name="MVector{1,X}", f=MVector{1} ∘ tuple, N=1, mutable=true, typed=true), - (name="MVector{1,Any}", f=MVector{1,Any} ∘ tuple, N=1, mutable=true, typed=false), - (name="SVector{2,promote_type(X,Y)}", f=SVector{2} ∘ tuple, N=2, mutable=false, typed=:promoted), - (name="SVector{2,Any}", f=SVector{2,Any} ∘ tuple, N=2, mutable=false, typed=false), - (name="MVector{2,promote_type(X,Y)}", f=MVector{2} ∘ tuple, N=2, mutable=true, typed=:promoted), - (name="MVector{2,Any}", f=MVector{2,Any} ∘ tuple, N=2, mutable=true, typed=false), -] - -@static if VERSION < v"1.11-" -else -_memory(x::Vector) = Memory{eltype(x)}(x) -push!( - wrappers, - (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true), - (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false), - (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted), - (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false), -) -end - -function test_make_zero() - @testset "scalars" begin - @testset "$T" for T in scalartypes - x = oneunit(T) - x_makez = make_zero(x) - @test typeof(x_makez) === T # correct type - @test x_makez == zero(T) # correct value - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - end - end - @testset "nested types" begin - @testset "$T in $(wrapper.name)" for - T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) - x = oneunit(T) - w = wrapper.f(x) - w_makez = make_zero(w) - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(getx(w_makez)) === T # correct type - @test getx(w_makez) == zero(T) # correct value - @test getx(w) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - @testset "doubly included in $(dualwrapper.name)" for - dualwrapper in filter(w -> (w.N == 2), wrappers) - w_inner = wrapper.f(x) - d_outer = dualwrapper.f(w_inner, w_inner) - d_outer_makez = make_zero(d_outer) - @test typeof(d_outer_makez) === typeof(d_outer) # correct type - @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type - @test typeof(getx(getx(d_outer_makez))) === T # correct type - @test getx(d_outer_makez) === gety(d_outer_makez) # correct topology - @test getx(getx(d_outer_makez)) == zero(T) # correct value - @test getx(d_outer) === gety(d_outer) # no mutation of original - @test getx(d_outer) === w_inner # no mutation of original - @test getx(w_inner) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - d_inner = dualwrapper.f(x, x) - w_outer = wrapper.f(d_inner) - w_outer_makez = make_zero(w_outer) - @test typeof(w_outer_makez) === typeof(w_outer) # correct type - @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type - @test typeof(getx(getx(w_outer_makez))) === T # correct type - @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct topology - @test getx(getx(w_outer_makez)) == zero(T) # correct value - @test getx(w_outer) === d_inner # no mutation of original - @test getx(d_inner) === gety(d_inner) # no mutation of original - @test getx(d_inner) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - if wrapper.mutable && !dualwrapper.mutable - # some code paths can only be hit with three layers of wrapping: - # mutable(immutable(mutable(scalar))) - @testset "all wrapped in $(outerwrapper.name)" for - outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) - w_inner = wrapper.f(x) - d_middle = dualwrapper.f(w_inner, w_inner) - w_outer = outerwrapper.f(d_middle) - w_outer_makez = make_zero(w_outer) - @test typeof(w_outer_makez) === typeof(w_outer) # correct type - @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type - @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type - @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type - @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct topology - @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value - @test getx(w_outer) === d_middle # no mutation of original - @test getx(d_middle) === gety(d_middle) # no mutation of original - @test getx(d_middle) === w_inner # no mutation of original - @test getx(w_inner) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - end - end - end - end - end - @testset "inactive" begin - @testset "in $(wrapper.name)" for wrapper in wrappers - if wrapper.N == 1 - w = wrapper.f(inactivearr) - w_makez = make_zero(w) - if wrapper.typed == true - @test w_makez === w # preserved wrapper identity if guaranteed const - end - @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - @test getx(w) === inactivearr # no mutation of original - else # wrapper.N == 2 - @testset "multiple references" begin - w = wrapper.f(inactivearr, inactivearr) - w_makez = make_zero(w) - if wrapper.typed == true - @test w_makez === w # preserved wrapper identity if guaranteed const - end - @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === gety(w_makez) # preserved topology - @test getx(w_makez) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - @test getx(w) === gety(w) # no mutation of original - @test getx(w) === inactivearr # no mutation of original - end - @testset "alongside active" begin - a = [1.0] - w = wrapper.f(a, inactivearr) - w_makez = make_zero(w) - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(getx(w_makez)) === typeof(a) # correct type - @test getx(w_makez) == [0.0] # correct value - @test gety(w_makez) === inactivearr # preserved inactive identity - @test inactivearr[1] === inactivetup # preserved inactive value - @test getx(w) === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - @test gety(w) === inactivearr # no mutation of original - if wrapper.typed == :partial - # above: untyped active / typed inactive - # below: untyped inactive / typed active - w = wrapper.f(inactivearr, a) - w_makez = make_zero(w) - @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === inactivearr # preserved inactive identity - @test inactivearr[1] === inactivetup # preserved inactive value - @test typeof(gety(w_makez)) === typeof(a) # correct type - @test gety(w_makez) == [0.0] # correct value - @test getx(w) === inactivearr # no mutation of original - @test gety(w) === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - end - end - end - @testset "copy_if_inactive $value" for (value, args) in [ - ("unspecified", ()), - ("= false", (Val(false),)), - ("= true", (Val(true),)), - ] - a = [1.0] - w = Any[a, inactivearr, inactivearr] - w_makez = make_zero(w, args...) - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(w_makez[1]) === typeof(a) # correct type - @test w_makez[1] == [0.0] # correct value - @test w_makez[2] === w_makez[3] # correct topology (topology should propagate even when copy_if_inactive = Val(true)) - @test w[1] === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - @test w[2] === w[3] # no mutation of original - @test w[2] === inactivearr # no mutation of original - @test inactivearr[1] === inactivetup # no mutation of original - if args == (Val(true),) - @test typeof(w_makez[2]) === typeof(inactivearr) # correct type - @test w_makez[2] == inactivearr # correct value - @test w_makez[2][1] !== inactivetup # correct identity - else - @test w_makez[2] === inactivearr # correct value/type/identity - end - end - end - @testset "heterogeneous containers" begin - scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) - wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) - mwraps, mwrapsz = MutableWrapper.(scalars), MutableWrapper.(scalarsz) - items = (inactivetup..., scalars..., wraps..., mwraps...) - itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) - labels = Symbol.("i" .* string.(1:length(items))) - @testset "$name" for (name, c, cz) in [ - ("Tuple", Tuple(items), Tuple(itemsz)), - ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), - ("Array", collect(items), collect(itemsz)), - ] - c_makez = make_zero(c) - @test typeof(c_makez) === typeof(c) # correct type - @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type - @test c_makez == cz # correct value - @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities - @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original - @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original - end - end - @testset "circular references" begin - @testset "$(wrapper.name)" for wrapper in ( - filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) - ) - a = [1.0] - if wrapper.N == 1 - w = wrapper.f(nothing) - setx!(w, (w, a)) - else - w = wrapper.f(nothing, a) - setx!(w, w) - end - w_makez = @test_noerr make_zero(w) - if wrapper.N == 1 - xz, yz = getx(w_makez) - x, y = getx(w) - else - xz, yz = getx(w_makez), gety(w_makez) - x, y = getx(w), gety(w) - end - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(xz) === typeof(w) # correct type - @test typeof(yz) === typeof(a) # correct type - @test xz === w_makez # correct self-reference - @test yz == [0.0] # correct value - @test x === w # no mutation of original - @test y === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - end - @testset "bring your own IdDict" begin - a = [1.0] - seen = IdDict() - a_makez = make_zero(typeof(a), seen, a) - @test typeof(a_makez) === typeof(a) # correct type - @test a_makez == [0.0] # correct value - @test a[1] === 1.0 # no mutation of original - @test haskey(seen, a) # original added to IdDict - @test seen[a] === a_makez # original points to zeroed value - end - @testset "custom leaf type" begin - a = [1.0] - v = CustomVector(a) - # include optional arg Val(false) to avoid calling the custom method directly; - # it should still be invoked - v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) - @test typeof(v_makez) === typeof(v) # correct type - @test typeof(v_makez.data) === typeof(a) # correct type - @test v_makez == CustomVector([0.0]) # correct value - @test v.data === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - @testset "undefined fields/unassigned elements" begin - @testset "array w inactive/active/mutable/unassigned" begin - a = [1.0] - values = ("a", 1.0, a) - arr = Vector{Any}(undef, 4) - arr[1:3] .= values - arr_makez = make_zero(arr) - @views begin - @test typeof(arr_makez) === typeof(arr) # correct type - @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type - @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value - @test !isassigned(arr_makez, 4) # propagated undefined - @test all(arr[1:3] .=== values) # no mutation of original - @test !isassigned(arr, 4) # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - end - @testset "struct w inactive/active/mutable/undefined" begin - a = [1.0] - incomplete = Incomplete("a", 1.0, a) - incomplete_makez = make_zero(incomplete) - @test typeof(incomplete_makez) === typeof(incomplete) # correct type - @test typeof(incomplete_makez.w) === typeof(a) # correct type - @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined - @test a[1] === 1.0 # no mutation of original - end - @testset "mutable struct w inactive/const active/active/mutable/undefined" begin - a = [1.0] - incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) - incomplete_makez = make_zero(incomplete) - @test typeof(incomplete_makez) === typeof(incomplete) # correct type - @test typeof(incomplete_makez.w) === typeof(a) # correct type - @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined - @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original - @test incomplete.w === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - end - end - @testset "containing IO" begin # issue #2091 - f = WithIO([1.0, 2.0], stdout) - df = @test_noerr make_zero(f) - @test df.v == [0.0, 0.0] - @test df.callback === f.callback - end - return nothing -end - -function test_make_zero!() - @testset "nested types" begin - @testset "$T in $(wrapper.name)" for - T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) - x = oneunit(T) - if wrapper.mutable - w = wrapper.f(x) - make_zero!(w) - @test typeof(getx(w)) === T # preserved type - @test getx(w) == zero(T) # correct value - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - end - @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( - filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) - ) - w_inner = wrapper.f(x) - d_outer = dualwrapper.f(w_inner, w_inner) - make_zero!(d_outer) - @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type - @test typeof(getx(getx(d_outer))) === T # preserved type - @test getx(getx(d_outer)) == zero(T) # correct value - @test getx(d_outer) === gety(d_outer) # preserved topology - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - if wrapper.mutable - @test getx(d_outer) === w_inner # preserved identity - end - d_inner = dualwrapper.f(x, x) - w_outer = wrapper.f(d_inner) - make_zero!(w_outer) - @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type - @test typeof(getx(getx(w_outer))) === T # preserved type - @test getx(getx(w_outer)) == zero(T) # correct value - @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - if dualwrapper.mutable - @test getx(w_outer) === d_inner # preserved identity - end - if wrapper.mutable && !dualwrapper.mutable - # some code paths can only be hit with three layers of wrapping: - # mutable(immutable(mutable(scalar))) - @assert !dualwrapper.mutable # sanity check - @testset "all wrapped in $(outerwrapper.name)" for - outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) - w_inner = wrapper.f(x) - d_middle = dualwrapper.f(w_inner, w_inner) - w_outer = outerwrapper.f(d_middle) - make_zero!(w_outer) - @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type - @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type - @test typeof(getx(getx(getx(w_outer)))) === T # preserved type - @test getx(getx(getx(w_outer))) == zero(T) # correct value - @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology - @test getx(getx(w_outer)) === w_inner # preserved identity - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - end - end - end - end - end - @testset "inactive" begin - @testset "in $(wrapper.name)" for - wrapper in filter(w -> (w.mutable || (w.typed == true)), wrappers) - if wrapper.N == 1 - w = wrapper.f(inactivearr) - make_zero!(w) - @test getx(w) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - else # wrapper.N == 2 - @testset "multiple references" begin - w = wrapper.f(inactivearr, inactivearr) - make_zero!(w) - @test getx(w) === gety(w) # preserved topology - @test getx(w) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - end - @testset "alongside active" begin - a = [1.0] - w = wrapper.f(a, inactivearr) - make_zero!(w) - @test getx(w) === a # preserved identity - @test a[1] === 0.0 # correct value - @test gety(w) === inactivearr # preserved inactive identity - @test inactivearr[1] === inactivetup # preserved inactive value - end - end - end - end - @testset "heterogeneous containers" begin - mwraps = MutableWrapper.(oneunit.(scalartypes)) - mwrapsz = MutableWrapper.(zero.(scalartypes)) - items = (inactivetup..., mwraps...) - itemsz = (inactivetup..., mwrapsz...) - labels = Symbol.("i" .* string.(1:length(items))) - @testset "$name" for (name, c, cz) in [ - ("Tuple", Tuple(items), Tuple(itemsz)), - ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), - ("Array", collect(items), collect(itemsz)), - ] - make_zero!(c) - @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities - @test c == cz # correct value - end - end - @testset "circular references" begin - @testset "$(wrapper.name)" for wrapper in ( - filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) - ) - a = [1.0] - if wrapper.N == 1 - w = wrapper.f(nothing) - setx!(w, (w, a)) - else - w = wrapper.f(nothing, a) - setx!(w, w) - end - @test_noerr make_zero!(w) - if wrapper.N == 1 - x, y = getx(w) - else - x, y = getx(w), gety(w) - end - @test x === w # preserved self-referential identity - @test y === a # preserved identity - @test a[1] === 0.0 # correct value - end - end - @testset "bring your own IdSet" begin - a = [1.0] - seen = Base.IdSet() - make_zero!(a, seen) - @test a[1] === 0.0 # correct value - @test (a in seen) # object added to IdSet - end - @testset "custom leaf type" begin - a = [1.0] - v = CustomVector(a) - # bringing own IdSet to avoid calling the custom method directly; - # it should still be invoked - @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, Base.IdSet()) - @test v.data === a # preserved identity - @test a[1] === 0.0 # correct value - end - @testset "undefined fields/unassigned elements" begin - @testset "array w inactive/active/mutable/unassigned" begin - a = [1.0] - values = ("a", 1.0, a) - arr = Vector{Any}(undef, 4) - arr[1:3] .= values - make_zero!(arr) - @views begin - @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved types - @test arr[1:3] == ["a", 0.0, [0.0]] # correct value - @test arr[3] === a # preserved identity - @test !isassigned(arr, 4) # preserved unassigned - end - end - @testset "struct w inactive/active/mutable/undefined" begin - a = [1.0] - incompletearr = [Incomplete("a", 1.0, a)] - make_zero!(incompletearr) - @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined - @test incompletearr[1].w === a # preserved identity - end - @testset "mutable struct w inactive/const active/active/mutable/undefined" begin - a = [1.0] - incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) - make_zero!(incomplete) - @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined - @test incomplete.w === a # preserved identity - end - @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin - # old implementation triggered #1935 - # new implementation would work regardless due to limited use of justActive - a = [1.0] - incomplete = Incomplete("a", 1.0, a) - incompletetuparr = [(incomplete,)] - make_zero!(incompletetuparr) - @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type - @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value - @test incompletetuparr[1][1].w === a # preserved identity - end - end - @testset "active/mixed type error" begin - @test_throws ArgumentError make_zero!((1.0,)) - @test_throws ArgumentError make_zero!((1.0, [1.0])) - @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 - end - @testset "containing IO" begin # issue #2091 - f = WithIO([1.0, 2.0], stdout) - fwrapped = [f] - @test_noerr make_zero!(fwrapped) - @test fwrapped[1] === f - @test fwrapped[1].v == [0.0, 0.0] - end - return nothing -end - -@testset "make_zero" test_make_zero() -@testset "make_zero!" test_make_zero!() - -end # module MakeZeroTests diff --git a/test/recursive_maps.jl b/test/recursive_maps.jl new file mode 100644 index 0000000000..a45bd71fc3 --- /dev/null +++ b/test/recursive_maps.jl @@ -0,0 +1,991 @@ +module RecursiveMapTests + +using Enzyme +using JLArrays +using Logging +using StaticArrays +using Test + +# Universal getters/setters for built-in and custom containers/wrappers +getx(w::Base.RefValue) = w[] +getx(w::Core.Box) = w.contents +getx(w::JLArray) = JLArrays.@allowscalar first(w) +gety(w::JLArray) = JLArrays.@allowscalar last(w) +getx(w) = first(w) +gety(w) = last(w) + +setx!(w::Base.RefValue, x) = (w[] = x) +setx!(w::Core.Box, x) = (w.contents = x) +setx!(w, x) = (w[begin] = x) +sety!(w, y) = (w[end] = y) + +# non-isbits MArray doesn't support setindex!, so requires a little hack +function setx!(w::MArray{S, T}, x) where {S, T} + if isbitstype(T) + w[begin] = x + else + w.data = (x, Base.tail(w.data)...) + end + return x +end + +function sety!(w::MArray{S, T}, y) where {S, T} + if isbitstype(T) + w[end] = y + else + w.data = (Base.front(w.data)..., y) + end + return y +end + +struct Empty end + +mutable struct MutableEmpty end + +Base.:(==)(::MutableEmpty, ::MutableEmpty) = true + +struct Wrapper{T} + x::T +end + +Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) +getx(a::Wrapper) = a.x + +mutable struct MutableWrapper{T} + x::T +end + +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) + +getx(a::MutableWrapper) = a.x +setx!(a::MutableWrapper, x) = (a.x = x) + +struct DualWrapper{Tx, Ty} + x::Tx + y::Ty +end + +DualWrapper{T}(x::T, y) where {T} = DualWrapper{T, typeof(y)}(x, y) + +function Base.:(==)(a::DualWrapper, b::DualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::DualWrapper) = a.x +gety(a::DualWrapper) = a.y + +mutable struct MutableDualWrapper{Tx, Ty} + x::Tx + y::Ty +end + +MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T, typeof(y)}(x, y) + +function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::MutableDualWrapper) = a.x +gety(a::MutableDualWrapper) = a.y + +setx!(a::MutableDualWrapper, x) = (a.x = x) +sety!(a::MutableDualWrapper, y) = (a.y = y) + +struct Incomplete{T, U} + s::String + x::Float64 + w::T + y::U # possibly not initialized + z # not initialized + Incomplete(s, x, w) = new{typeof(w), Any}(s, x, w) + Incomplete(s, x, w, y) = new{typeof(w), typeof(y)}(s, x, w, y) +end + +function Base.:(==)(a::Incomplete, b::Incomplete) + (a === b) && return true + ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false + if isdefined(a, :y) && isdefined(b, :y) + (a.w == b.w) || return false + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + elseif isdefined(a, :y) || isdefined(b, :y) + return false + end + return true +end + +mutable struct MutableIncomplete{T} + s::String + const x::Float64 + y::Float64 + z # not initialized + w::T + function MutableIncomplete(s, x, y, w) + ret = new{typeof(w)}(s, x, y) + ret.w = w + return ret + end +end + +function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) + (a === b) && return true + if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) + return false + end + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct CustomVector{T} + data::Vector{T} +end + +Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) + +function Enzyme.EnzymeCore.isvectortype(::Type{CustomVector{T}}) where {T} + return Enzyme.EnzymeCore.isscalartype(T) +end + +function Enzyme.EnzymeCore.make_zero(prev::CV) where {CV <: CustomVector{<:AbstractFloat}} + @info "make_zero(::CustomVector)" + return CustomVector(zero(prev.data))::CV +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) + @info "make_zero!(::CustomVector)" + fill!(prev.data, false) + return nothing +end + +struct WithIO{F} # issue 2091 + v::Vector{Float64} + callback::F + function WithIO(v, io) + callback() = println(io, "hello") + return new{typeof(callback)}(v, callback) + end +end + +macro test_noerr(expr) + return quote + @test_nowarn try + # catch errors to get failed test instead of "exception outside of a @test" + $(esc(expr)) + catch e + showerror(stderr, e) + end + end +end + +const scalartypes = [Float32, ComplexF32, Float64, ComplexF64, BigFloat, Complex{BigFloat}] + +const inactivebits = (1, Empty()) +const inactivetup = (inactivebits, "a", MutableEmpty()) +const inactivearr = [inactivetup] + +#! format: off +const wrappers = [ + (name = "Tuple{X}", f = tuple, N = 1, mutable = false, typed = true, bitsonly = false), + (name = "@NamedTuple{x::X}", f = (NamedTuple{(:x,)} ∘ tuple), N = 1, mutable = false, typed = true, bitsonly = false), + (name = "struct{X}", f = Wrapper, N = 1, mutable = false, typed = true, bitsonly = false), + + (name = "@NamedTuple{x}", f = (@NamedTuple{x} ∘ tuple), N = 1, mutable = false, typed = false, bitsonly = false), + (name = "struct{Any}", f = Wrapper{Any}, N = 1, mutable = false, typed = false, bitsonly = false), + + (name = "Array{X}", f = (x -> [x]), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "Base.RefValue{X}", f = Ref, N = 1, mutable = true, typed = true, bitsonly = false), + (name = "mutable struct{X}", f = MutableWrapper, N = 1, mutable = true, typed = true, bitsonly = false), + + (name = "Array{Any}", f = (x -> Any[x]), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Base.RefValue{Any}", f = Ref{Any}, N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Core.Box", f = Core.Box, N = 1, mutable = true, typed = false, bitsonly = false), + (name = "mutable struct{Any}", f = MutableWrapper{Any}, N = 1, mutable = true, typed = false, bitsonly = false), + + (name = "Tuple{X, Y}", f = tuple, N = 2, mutable = false, typed = true, bitsonly = false), + (name = "@NamedTuple{x::X, y::Y}", f = (NamedTuple{(:x, :y)} ∘ tuple), N = 2, mutable = false, typed = true, bitsonly = false), + (name = "struct{X, Y}", f = DualWrapper, N = 2, mutable = false, typed = true, bitsonly = false), + + (name = "@NamedTuple{x, y::Y}", f = ((x, y) -> @NamedTuple{x, y::typeof(y)}((x, y))), N = 2, mutable = false, typed = :partial, bitsonly = false), + (name = "struct{Any, Y}", f = DualWrapper{Any}, N = 2, mutable = false, typed = :partial, bitsonly = false), + + (name = "@NamedTuple{x, y}", f = (@NamedTuple{x, y} ∘ tuple), N = 2, mutable = false, typed = false, bitsonly = false), + (name = "struct{Any}", f = DualWrapper{Any, Any}, N = 2, mutable = false, typed = false, bitsonly = false), + + (name = "mutable struct{X, Y}", f = MutableDualWrapper, N = 2, mutable = true, typed = true, bitsonly = false), + + (name = "Array{promote_type(X, Y)}", f = ((x, y) -> [x, y]), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "mutable struct{Any, Y}", f = MutableDualWrapper{Any}, N = 2, mutable = true, typed = :partial, bitsonly = false), + + (name = "Array{Any}", f = ((x, y) -> Any[x, y]), N = 2, mutable = true, typed = false, bitsonly = false), + (name = "mutable struct{Any, Any}", f = MutableDualWrapper{Any,Any}, N = 2, mutable = true, typed = false, bitsonly = false), + + # StaticArrays extension + (name = "SVector{1, X}", f = (SVector{1} ∘ tuple), N = 1, mutable = false, typed = true, bitsonly = false), + (name = "SVector{1, Any}", f = (SVector{1, Any} ∘ tuple), N = 1, mutable = false, typed = false, bitsonly = false), + (name = "MVector{1, X}", f = (MVector{1} ∘ tuple), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "MVector{1, Any}", f = (MVector{1, Any} ∘ tuple), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "SVector{2, promote_type(X, Y)}", f = (SVector{2} ∘ tuple), N = 2, mutable = false, typed = :promoted, bitsonly = false), + (name = "SVector{2, Any}", f = (SVector{2, Any} ∘ tuple), N = 2, mutable = false, typed = false, bitsonly = false), + (name = "MVector{2, promote_type(X, Y)}", f = (MVector{2} ∘ tuple), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "MVector{2, Any}", f = (MVector{2, Any} ∘ tuple), N = 2, mutable = true, typed = false, bitsonly = false), + + # GPUArrays extension + (name = "JLArray{X}", f = (x -> JLArray([x])), N = 1, mutable = true, typed = true, bitsonly = true), + (name = "JLArray{promote_type(X, Y)}", f = ((x, y) -> JLArray([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = true), +] + +@static if VERSION < v"1.11-" +else +_memory(x::Vector) = Memory{eltype(x)}(x) +push!( + wrappers, + (name = "Memory{X}", f = (x -> _memory([x])), N = 1, mutable = true, typed = true, bitsonly = false), + (name = "Memory{Any}", f = (x -> _memory(Any[x])), N = 1, mutable = true, typed = false, bitsonly = false), + (name = "Memory{promote_type(X, Y)}", f = ((x, y) -> _memory([x, y])), N = 2, mutable = true, typed = :promoted, bitsonly = false), + (name = "Memory{Any}", f = ((x, y) -> _memory(Any[x, y])), N = 2, mutable = true, typed = false, bitsonly = false), +) +end +#! format: on + +function test_make_zero() + @testset "scalars" begin + @testset "$T" for T in scalartypes + x = oneunit(T) + x_makez = make_zero(x) + @test typeof(x_makez) === T # correct type + @test x_makez == zero(T) # correct value + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for T in scalartypes, wrapper in filter( + w -> (w.N == 1), wrappers + ) + (!wrapper.bitsonly || isbitstype(T)) || continue + x = oneunit(T) + w = wrapper.f(x) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === T # correct type + @test getx(w_makez) == zero(T) # correct value + @test getx(w) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in filter( + w -> (w.N == 2), wrappers + ) + (!dualwrapper.bitsonly || isbitstype(T)) || continue + w_inner = wrapper.f(x) + if !dualwrapper.bitsonly || isbits(w_inner) + d_outer = dualwrapper.f(w_inner, w_inner) + d_outer_makez = make_zero(d_outer) + @test typeof(d_outer_makez) === typeof(d_outer) # correct type + @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type + @test typeof(getx(getx(d_outer_makez))) === T # correct type + @test getx(d_outer_makez) === gety(d_outer_makez) # correct layout + @test getx(getx(d_outer_makez)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # no mutation of original + @test getx(d_outer) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + d_inner = dualwrapper.f(x, x) + if !wrapper.bitsonly || isbits(d_inner) + w_outer = wrapper.f(d_inner) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type + @test typeof(getx(getx(w_outer_makez))) === T # correct type + @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct layout + @test getx(getx(w_outer_makez)) == zero(T) # correct value + @test getx(w_outer) === d_inner # no mutation of original + @test getx(d_inner) === gety(d_inner) # no mutation of original + @test getx(d_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + if wrapper.mutable && !dualwrapper.mutable && !dualwrapper.bitsonly + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for outerwrapper in filter( + w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers + ) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type + @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type + @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type + @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct layout + @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value + @test getx(w_outer) === d_middle # no mutation of original + @test getx(d_middle) === gety(d_middle) # no mutation of original + @test getx(d_middle) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in wrappers + if wrapper.N == 1 + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive) + w_makez = make_zero(w) + if wrapper.typed in (true, :promoted) + if w isa JLArray # needs JLArray activity + @test_broken w_makez === w + else + @test w_makez === w # preserved wrapper identity if guaranteed const + end + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactive # preserved identity + @test getx(w) === inactive # no mutation of original + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end + @testset "mixed" begin + for (inactive, mixed, condition) in [ + (inactivebits, (1.0, inactivebits), true), + (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(mixed) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(mixed) # correct type + @test getx(w_makez)[1] === 0.0 # correct value + @test getx(w_makez)[2] === inactive # preserved inactive identity + @test getx(w) === mixed # no mutation of original + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved inactive value + @test mixed[1] === 1.0 # no mutation of original + @test mixed[2] === inactivearr # no mutation of original + end + end + end + else # wrapper.N == 2 + @testset "multiple references" begin + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive, inactive) + w_makez = make_zero(w) + if wrapper.typed in (true, :promoted) + if w isa JLArray # needs JLArray activity + @test_broken w_makez === w + else + @test w_makez === w # preserved wrapper identity if guaranteed const + end + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === gety(w_makez) # preserved layout + @test getx(w_makez) === inactive # preserved identity + @test getx(w) === gety(w) # no mutation of original + @test getx(w) === inactive # no mutation of original + if inactive === inactive + @test inactivearr[1] === inactivetup # preserved value + end + end + end + if !wrapper.bitsonly + @testset "mixed" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(a) # correct type + @test getx(w_makez) == [0.0] # correct value + @test gety(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test getx(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test gety(w) === inactivearr # no mutation of original + if wrapper.typed == :partial + # above: untyped active / typed inactive + # below: untyped inactive / typed active + w = wrapper.f(inactivearr, a) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test typeof(gety(w_makez)) === typeof(a) # correct type + @test gety(w_makez) == [0.0] # correct value + @test getx(w) === inactivearr # no mutation of original + @test gety(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + end + end + end + #! format: off + @testset "copy_if_inactive $value" for (value, args, kwargs) in [ + ("unspecified", (), (;)), + ("= false", (Val(false),), (;)), + ("= false (kwarg)", (), (; copy_if_inactive = Val(false))), + ("= true", (Val(true),), (;)), + ("= true (kwarg)", (), (; copy_if_inactive = Val(true))), + ] + a = [1.0] + w = Any[a, inactivearr, inactivearr] + w_makez = make_zero(w, args...; kwargs...) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(w_makez[1]) === typeof(a) # correct type + @test w_makez[1] == [0.0] # correct value + @test w_makez[2] === w_makez[3] # correct layout (layout should propagate even when copy_if_inactive = Val(true)) + @test w[1] === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test w[2] === w[3] # no mutation of original + @test w[2] === inactivearr # no mutation of original + @test inactivearr[1] === inactivetup # no mutation of original + if (args == (Val(true),)) || (kwargs == (; copy_if_inactive = Val(true))) + @test typeof(w_makez[2]) === typeof(inactivearr) # correct type + @test w_makez[2] == inactivearr # correct value + @test w_makez[2][1] !== inactivetup # correct identity + else + @test w_makez[2] === inactivearr # correct value/type/identity + end + end + #! format: on + end + @testset "heterogeneous containers" begin + scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) + wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) + mwraps, mwrapsz = MutableWrapper.(scalars), MutableWrapper.(scalarsz) + items = (inactivetup..., scalars..., wraps..., mwraps...) + itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + #! format: off + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + c_makez = make_zero(c) + @test typeof(c_makez) === typeof(c) # correct type + @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type + @test c_makez == cz # correct value + @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities + @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original + @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original + end + #! format: on + end + @testset "heterogeneous float arrays" begin + b1r, b2r = big"1.0", big"2.0" + b1i, b2i = big"1.0" * im, big"2.0" * im + ar = AbstractFloat[1.0f0, 1.0, b1r, b1r, b2r] + ai = Complex{<:AbstractFloat}[1.0f0im, 1.0im, b1i, b1i, b2i] + for (a, btype) in [(ar, typeof(b1r)), (ai, typeof(b1i))] + a_makez = make_zero(a) + @test a_makez[1] === zero(a[1]) + @test a_makez[2] === zero(a[2]) + @test typeof(a_makez[3]) === btype + @test a_makez[3] == 0 + @test a_makez[4] === a_makez[3] + @test typeof(a_makez[5]) === btype + @test a_makez[5] == 0 + @test a_makez[5] !== a_makez[3] + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in filter( + w -> (w.mutable && (w.typed in (:partial, false))), wrappers + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + w_makez = @test_noerr make_zero(w) + if wrapper.N == 1 + xz, yz = getx(w_makez) + x, y = getx(w) + else + xz, yz = getx(w_makez), gety(w_makez) + x, y = getx(w), gety(w) + end + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(xz) === typeof(w) # correct type + @test typeof(yz) === typeof(a) # correct type + @test xz === w_makez # correct self-reference + @test yz == [0.0] # correct value + @test x === w # no mutation of original + @test y === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + a_makez = make_zero(typeof(a), seen, a) + @test typeof(a_makez) === typeof(a) # correct type + @test a_makez == [0.0] # correct value + @test a[1] === 1.0 # no mutation of original + @test haskey(seen, a) # original added to IdDict + @test seen[a] === a_makez # original points to zeroed value + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # include optional arg Val(false) to avoid calling the custom method directly; + # it should still be invoked + v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) + @test typeof(v_makez) === typeof(v) # correct type + @test typeof(v_makez.data) === typeof(a) # correct type + @test v_makez == CustomVector([0.0]) # correct value + @test v.data === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + @testset "runtime_inactive" begin + # verify that MutableWrapper is seen as active by both variants + a = MutableWrapper(1.0) + @assert !EnzymeRules.inactive_type(typeof(a)) + a_makez = make_zero(a, Val(false), Val(false)) + @assert a_makez == MutableWrapper(0.0) + a_makez = make_zero(a; runtime_inactive = Val(false)) + @assert a_makez == MutableWrapper(0.0) + a_makez = make_zero(a, Val(false), Val(true)) + @assert a_makez == MutableWrapper(0.0) + a_makez = make_zero(a; runtime_inactive = Val(true)) + @assert a_makez == MutableWrapper(0.0) + + # mark MutableWrapper as inactive + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true + + # runtime_inactive == false => redefined inactive_type should have no effect + a_makez = @invokelatest make_zero(a, Val(false), Val(false)) + @test a_makez == MutableWrapper(0.0) + a_makez = @invokelatest make_zero(a; runtime_inactive = Val(false)) + @test a_makez == MutableWrapper(0.0) + + # runtime_inactive == true => redefined inactive_type should take effect + # MutableWrapper considered inactive and treated according to copy_if_inactive + a_makez = @invokelatest make_zero(a, Val(false), Val(true)) + @test a_makez === a + a_makez = @invokelatest make_zero( + a; copy_if_inactive = Val(false), runtime_inactive = Val(true) + ) + @test a_makez === a + a_makez = @invokelatest make_zero(a, Val(true), Val(true)) + @test a_makez !== a + @test a_makez == MutableWrapper(1.0) + a_makez = @invokelatest make_zero( + a; copy_if_inactive = Val(true), runtime_inactive = Val(true) + ) + @test a_makez !== a + @test a_makez == MutableWrapper(1.0) + + # mark MutableWrapper as active again + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false + + # verify that MutableWrapper is seen as active by both variants + a_makez = @invokelatest make_zero(a, Val(false), Val(false)) + @test a_makez == MutableWrapper(0.0) + a_makez = @invokelatest make_zero(a; runtime_inactive = Val(false)) + @test a_makez == MutableWrapper(0.0) + a_makez = @invokelatest make_zero(a, Val(false), Val(true)) + @test a_makez == MutableWrapper(0.0) + a_makez = @invokelatest make_zero(a; runtime_inactive = Val(true)) + @test a_makez == MutableWrapper(0.0) + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + arr_makez = make_zero(arr) + @views begin + @test typeof(arr_makez) === typeof(arr) # correct type + @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type + @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value + @test !isassigned(arr_makez, 4) # propagated undefined + @test all(arr[1:3] .=== values) # no mutation of original + @test !isassigned(arr, 4) # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + @testset "single undefined" begin + incomplete = Incomplete("a", 1.0, a, nothing) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0], nothing) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + @testset "multiple undefined" begin + incomplete = Incomplete("a", 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=# 1.0, 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined + @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original + @test incomplete.w === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + df = @test_noerr make_zero(f) + @test df.v == [0.0, 0.0] + @test df.callback === f.callback + end + return nothing +end + +function test_make_zero!() + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for T in scalartypes, wrapper in filter( + w -> (w.N == 1), wrappers + ) + (!wrapper.bitsonly || isbitstype(T)) || continue + x = oneunit(T) + if wrapper.mutable + w = wrapper.f(x) + make_zero!(w) + @test typeof(getx(w)) === T # preserved type + @test getx(w) == zero(T) # correct value + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( + filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) + ) + (!dualwrapper.bitsonly || isbitstype(T)) || continue + w_inner = wrapper.f(x) + if !dualwrapper.bitsonly || isbits(w_inner) + d_outer = dualwrapper.f(w_inner, w_inner) + make_zero!(d_outer) + @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type + @test typeof(getx(getx(d_outer))) === T # preserved type + @test getx(getx(d_outer)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # preserved layout + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if wrapper.mutable + @test getx(d_outer) === w_inner # preserved identity + end + end + d_inner = dualwrapper.f(x, x) + if !wrapper.bitsonly || isbits(d_inner) + w_outer = wrapper.f(d_inner) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type + @test typeof(getx(getx(w_outer))) === T # preserved type + @test getx(getx(w_outer)) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved layout + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if dualwrapper.mutable + @test getx(w_outer) === d_inner # preserved identity + end + end + if wrapper.mutable && !dualwrapper.mutable && !dualwrapper.bitsonly + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for outerwrapper in filter( + w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers + ) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type + @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type + @test typeof(getx(getx(getx(w_outer)))) === T # preserved type + @test getx(getx(getx(w_outer))) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved layout + @test getx(getx(w_outer)) === w_inner # preserved identity + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in filter( + w -> (w.mutable || (w.typed == true)), wrappers + ) + if wrapper.N == 1 + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive) + make_zero!(w) + @test getx(w) === inactive # preserved identity + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end + @testset "mixed" begin + for (inactive, mixed, condition) in [ + (inactivebits, (1.0, inactivebits), wrapper.mutable), + (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(mixed) + make_zero!(w) + @test getx(w)[1] === 0.0 + @test getx(w)[2] === inactive + if inactive === inactivearr + @test getx(w) === mixed # preserved identity + @test inactivearr[1] === inactivetup # preserved value + end + end + end + else # wrapper.N == 2 + @testset "multiple references" begin + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive, inactive) + make_zero!(w) + @test getx(w) === gety(w) # preserved layout + @test getx(w) === inactive # preserved identity + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end + end + if !wrapper.bitsonly + @testset "mixed" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + make_zero!(w) + @test getx(w) === a # preserved identity + @test a[1] === 0.0 # correct value + @test gety(w) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + end + end + end + end + end + @testset "heterogeneous containers" begin + mwraps = MutableWrapper.(oneunit.(scalartypes)) + mwrapsz = MutableWrapper.(zero.(scalartypes)) + items = (inactivetup..., mwraps...) + itemsz = (inactivetup..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + #! format: off + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + make_zero!(c) + @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities + @test c == cz # correct value + end + #! format: on + end + @testset "heterogeneous float arrays" begin + b1r, b2r = big"1.0", big"2.0" + b1i, b2i = big"1.0" * im, big"2.0" * im + ar = AbstractFloat[1.0f0, 1.0, b1r, b1r, b2r] + ai = Complex{<:AbstractFloat}[1.0f0im, 1.0im, b1i, b1i, b2i] + for (a, btype) in [(ar, typeof(b1r)), (ai, typeof(b1i))] + a1, a2 = a[1], a[2] + make_zero!(a) + @test a[1] === zero(a1) + @test a[2] === zero(a2) + @test typeof(a[3]) === btype + @test a[3] == 0 + @test a[4] === a[3] + @test typeof(a[5]) === btype + @test a[5] == 0 + @test a[5] !== a[3] + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in filter( + w -> (w.mutable && (w.typed in (:partial, false))), wrappers + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + @test_noerr make_zero!(w) + if wrapper.N == 1 + x, y = getx(w) + else + x, y = getx(w), gety(w) + end + @test x === w # preserved self-referential identity + @test y === a # preserved identity + @test a[1] === 0.0 # correct value + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + make_zero!(a, seen) + @test a[1] === 0.0 # correct value + @test haskey(seen, a) # object added to IdDict + @test seen[a] === a # object points to zeroed value, i.e., itself + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # bringing own IdDict to avoid calling the custom method directly; + # it should still be invoked + @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, IdDict()) + @test v.data === a # preserved identity + @test a[1] === 0.0 # correct value + end + @testset "runtime_inactive" begin + # verify that MutableWrapper is seen as active by both variants + a = MutableWrapper(1.0) + @assert !EnzymeRules.inactive_type(typeof(a)) + a.x = 1.0 + make_zero!(a, Val(false)) + @assert a == MutableWrapper(0.0) + a.x = 1.0 + make_zero!(a; runtime_inactive = Val(false)) + @assert a == MutableWrapper(0.0) + a.x = 1.0 + make_zero!(a, Val(true)) + @assert a == MutableWrapper(0.0) + a.x = 1.0 + make_zero!(a; runtime_inactive = Val(true)) + @assert a == MutableWrapper(0.0) + + # mark MutableWrapper as inactive + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true + + # runtime_inactive == false => redefined inactive_type should have no effect + a.x = 1.0 + @invokelatest make_zero!(a, Val(false)) + @test a == MutableWrapper(0.0) + a.x = 1.0 + @invokelatest make_zero!(a; runtime_inactive = Val(false)) + @test a == MutableWrapper(0.0) + + # runtime_inactive == true => redefined inactive_type should take effect + # MutableWrapper considered inactive and won't be zeroed + a.x = 1.0 + @invokelatest make_zero!(a, Val(true)) + @test a == MutableWrapper(1.0) + a.x = 1.0 + @invokelatest make_zero!(a; runtime_inactive = Val(true)) + @test a == MutableWrapper(1.0) + + # mark MutableWrapper as active again + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false + + # verify that MutableWrapper is seen as active by both variants + a.x = 1.0 + @invokelatest make_zero!(a, Val(true)) + @test a == MutableWrapper(0.0) + a.x = 1.0 + @invokelatest make_zero!(a; runtime_inactive = Val(true)) + @test a == MutableWrapper(0.0) + a.x = 1.0 + @invokelatest make_zero!(a, Val(false)) + @test a == MutableWrapper(0.0) + a.x = 1.0 + @invokelatest make_zero!(a; runtime_inactive = Val(false)) + @test a == MutableWrapper(0.0) + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + make_zero!(arr) + @views begin + @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved types + @test arr[1:3] == ["a", 0.0, [0.0]] # correct value + @test arr[3] === a # preserved identity + @test !isassigned(arr, 4) # preserved unassigned + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incompletearr = [Incomplete("a", 1.0, a)] + make_zero!(incompletearr) + @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined + @test incompletearr[1].w === a # preserved identity + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=# 1.0, 1.0, a) + make_zero!(incomplete) + @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined + @test incomplete.w === a # preserved identity + end + @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin + # old implementation of make_zero! triggered #1935 + # new implementation would work regardless due to limited use of justActive + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incompletetuparr = [(incomplete,)] + make_zero!(incompletetuparr) + @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type + @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value + @test incompletetuparr[1][1].w === a # preserved identity + end + end + @testset "active/mixed type error" begin + @test_throws ArgumentError make_zero!((1.0,)) + @test_throws ArgumentError make_zero!((1.0, [1.0])) + @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 + end + @testset "containing IO" begin # issue #2091 + f = WithIO([1.0, 2.0], stdout) + fwrapped = [f] + @test_noerr make_zero!(fwrapped) + @test fwrapped[1] === f + @test fwrapped[1].v == [0.0, 0.0] + end + return nothing +end + +# because this is wrapped in a module, we should only run a single top-level testset +# otherwise a failed test in the first set will prevent the second from running +@testset "recursive maps" begin + @testset "make_zero" test_make_zero() + @testset "make_zero!" test_make_zero!() +end + +end # module RecursiveMapTests diff --git a/test/runtests.jl b/test/runtests.jl index 5f909abdb1..81a383f586 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,7 +75,7 @@ include("abi.jl") include("typetree.jl") include("passes.jl") include("optimize.jl") -include("make_zero.jl") +include("recursive_maps.jl") include("rules.jl") include("rrules.jl") @@ -440,6 +440,25 @@ make3() = (1.0, 2.0, 3.0) da = [2.7] @test autodiff(Forward, sumdeepcopy, Duplicated(a, da))[1] ≈ 2.7 + # Nested containers to test nontrivial recursion in deepcopy reverse rule + b = [[3.14]] + db = [[0.0]] + sumdeepcopy_nested(x) = sum(sum, deepcopy(x)) + autodiff(Reverse, sumdeepcopy_nested, Duplicated(b, db)) + @test db[1][1] ≈ 1.0 + + c_inner = [3.14] + dc_inner = [0.0] + c = [c_inner, c_inner] + dc = [dc_inner, dc_inner] + autodiff(Reverse, sumdeepcopy_nested, Duplicated(c, dc)) + @test dc[1] === dc[2] + @test dc[1][1] ≈ 2.0 + + d = [(3.14,)] + dd = [(0.0,)] + autodiff(Reverse, sumdeepcopy_nested, Duplicated(d, dd)) + @test dd[1][1] ≈ 1.0 end @testset "Deferred and deferred thunk" begin @@ -533,94 +552,70 @@ end @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active, Active(z)) @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sum, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 1.0 + function reverse_holomorphic_array_tests( + f, val, dval_expected; val_expected = val, ret = Active, mapf = true + ) + vals = ComplexF64[val] + dvals = ComplexF64[zero(val)] + autodiff(ReverseHolomorphic, f, ret, Duplicated(vals, dvals)) + @test vals[1] ≈ val_expected + @test dvals[1] ≈ dval_expected - sumsq(x) = sum(x .* x) + # Use tuple to test out-of-place accumulate_seen! base case + tvals = [(ComplexF64(val),)] + dtvals = [(ComplexF64(zero(val)),)] + ft = mapf ? v -> first(map(f, v)) : f + autodiff(ReverseHolomorphic, ft, ret, Duplicated(tvals, dtvals)) + @test tvals[1][1] ≈ val_expected + @test dtvals[1][1] ≈ dval_expected + end - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sum" reverse_holomorphic_array_tests(sum, 3.4 + 2.7im, 1.0) + + sumsq(x) = sum(x .* x) + @testset "sumsq" reverse_holomorphic_array_tests(sumsq, 3.4 + 2.7im, 2(3.4 + 2.7im)) sumsq2(x) = sum(abs2.(x)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq2, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sumsq2" reverse_holomorphic_array_tests(sumsq2, 3.4 + 2.7im, 2(3.4 + 2.7im)) sumsq2C(x) = Complex{Float64}(sum(abs2.(x))) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq2C, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 3.4 - 2.7im - - sumsq3(x) = sum(x .* conj(x)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq3, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 3.4 - 2.7im - - sumsq3R(x) = Float64(sum(x .* conj(x))) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq3R, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sumsq2C" reverse_holomorphic_array_tests(sumsq2C, 3.4 + 2.7im, 3.4 - 2.7im) + + sumsq3(x) = sum(x .* conj.(x)) + @testset "sumsq3" reverse_holomorphic_array_tests(sumsq3, 3.4 + 2.7im, 3.4 - 2.7im) + + sumsq3R(x) = Float64(sum(x .* conj.(x))) + @testset "sumsq3R" reverse_holomorphic_array_tests(sumsq3R, 3.4 + 2.7im, 2(3.4 + 2.7im)) function setinact(z) - z[1] *= 2 + z[1] = 2 .* z[1] # works for both [x] and [(x,)] nothing end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - + @testset "setinact" reverse_holomorphic_array_tests( + setinact, 3.4 + 2.7im, 0.0; val_expected = 2(3.4 + 2.7im), ret = Const, mapf = false + ) function setinact2(z) - z[1] *= 2 + z[1] = 2 .* z[1] # works for both [x] and [(x,)] return 0.0+1.0im end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact2, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact2, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - + @testset "setinact2 Const" reverse_holomorphic_array_tests( + setinact2, 3.4 + 2.7im, 0.0; val_expected = 2(3.4 + 2.7im), ret = Const, mapf = false + ) + @testset "setinact2 Active" reverse_holomorphic_array_tests( + setinact2, 3.4 + 2.7im, 0.0; val_expected = 2(3.4 + 2.7im), ret = Active, mapf = false + ) function setact(z) - z[1] *= 2 - return z[1] + z[1] = 2 .* z[1] # works for both [x] and [(x,)] + return z[1][1] # returns scalar for both [x] and [(x,)] end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setact, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setact, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 2.0 + @testset "setact Const" reverse_holomorphic_array_tests( + setact, 3.4 + 2.7im, 0.0; val_expected = 2(3.4 + 2.7im), ret = Const, mapf = false + ) + @testset "setact Active" reverse_holomorphic_array_tests( + setact, 3.4 + 2.7im, 2.0; val_expected = 2(3.4 + 2.7im), ret = Active, mapf = false + ) function upgrade(z) z = ComplexF64(z)