Commit c2f05d4

Add recursive_map and base make_zero(!) on it
1 parent 26ca6fe commit c2f05d4

8 files changed, +1277 -565 lines changed

ext/EnzymeStaticArraysExt.jl

+9-5
@@ -32,11 +32,15 @@ end
 end
 end
 
-@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray}
-    return Base.zero(x)
-end
-@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray}
-    return Base.zero(x)
+# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct,
+# but in case their dedicated `zero` and `fill!` methods are more efficient than
+# `make_zero(!)`'s generic recursion, we opt into treating them as leaves when they have
+# isbits eltypes (non-isbits eltypes excluded as the dedicated `zero` and `fill!` methods
+# don't support those).
+@inline function Enzyme.EnzymeCore.isvectortype(
+    ::Type{<:Union{SArray{S,T},MArray{S,T}}}
+) where {S,T}
+    return isbitstype(T) && Enzyme.Compiler.RecursiveMap.isscalartype(T)
 end
 
 end

lib/EnzymeCore/src/EnzymeCore.jl

+75-7
@@ -501,28 +501,96 @@ function autodiff_thunk end
 function autodiff_deferred_thunk end
 
 """
+    make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T
     make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T
 
 Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
 what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
+
+Extending this method for custom types is rarely needed. For new plain array types like GPU
+arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type implements
+`Base.zero`.
 """
 function make_zero end
 
 """
-    make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
+    make_zero!(val::T, seen::IdDict=IdDict())::Nothing
+
+Recursively set a variable's differentiable fields to zero. Only applicable for types `T`
+that are mutable or hold all differentiable values in mutable containers (e.g.,
+`Tuple{Vector{Float64}}`).
 
-Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
+Extending this method for custom types is rarely needed. For new plain mutable array types
+like GPU arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type
+implements `Base.zero` and `Base.fill!`.
 """
 function make_zero! end
 
 """
-    make_zero(prev::T)
+    isvectortype(::Type{T})::Bool
+
+Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref)
+and [`make_zero!`](@ref) recurse through an object.
+
+By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or
+`T <: Union{Array{U},GenericMemory{_,U}}` where `isscalartype(U) == true`.
+
+A new plain array type, for example a GPU array, may extend this as follows:
+
+```julia
+@inline EnzymeCore.isvectortype(::Type{<:GPUArray{U}}) where {U} = isscalartype(U)
+```
+
+Such a type should implement `Base.zero` and, if mutable, `Base.fill!`. (If this is not
+feasible, an alternative is to add methods `EnzymeCore.make_zero(arr::T)::T` and, if
+mutable, `EnzymeCore.make_zero!(arr::T)::Nothing`; such methods will also be picked up by
+recursive calls.)
 
-Helper function to recursively make zero.
+Such extensions are mostly relevant for the lowest level of abstraction of memory at which
+vector space operations like addition and scalar multiplication are supported, the
+prototypical case being `Array`. Regular Julia structs with vector space-like semantics
+should normally not extend `isvectortype`; `make_zero(!)` will recurse into them and act
+directly on their backing arrays, just like how Enzyme treats them when differentiating. For
+example, structured matrix wrappers and sparse array types that are backed by `Array` should
+not extend `isvectortype`.
+
+See also [`isscalartype`](@ref).
 """
-@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive}
-    make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive))
-end
+function isvectortype end
+
+"""
+    isscalartype(::Type{T})::Bool
+
+Trait defining a subset of [`isvectortype`](@ref) types that should not be considered
+composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero
+values of the type in-place. For example, `BigFloat` is a mutable type but does not support
+in-place mutation through any Julia API, and `isscalartype(BigFloat) == true` ensures that
+`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat]
+
+By default, `isscalartype(T) == true` for `T <: AbstractFloat` and
+`T <: Complex{<:AbstractFloat}`.
+
+A hypothetical new real number type with Enzyme support should in most cases simply subtype
+`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate,
+the function can be extended as follows:
+
+```julia
+@inline EnzymeCore.isscalartype(::Type{<:NewReal}) = true
+@inline EnzymeCore.isscalartype(::Type{<:Complex{<:NewReal}}) = true
+```
+
+In either case, the type should implement `Base.zero`. (If this is not feasible, an
+alternative is to add a method `EnzymeCore.make_zero(x::T)::T`; such a method will also be
+picked up by recursive calls.)
+
+See also [`isvectortype`](@ref).
+
+[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is
+    mentioned here only to illustrate that it would be inappropriate to use traits like
+    `ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing,
+    demonstrating the need for a dedicated `isscalartype` trait.
+"""
+function isscalartype end
 
 function tape_type end
 
src/compiler.jl

+1-2
@@ -1132,8 +1132,6 @@ struct Tape{TapeTy,ShadowTy,ResT}
     shadow_return::ShadowTy
 end
 
-include("make_zero.jl")
-
 function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, world)
     funcspec = my_methodinstance(typeof(f), tt, world)
     nested_codegen!(mode, mod, funcspec, world)
@@ -7610,6 +7608,7 @@ end
 end
 
 # Recursively return x + f(y), where y is active, otherwise x
+include("recursive_map.jl")
 
 @inline function recursive_add(
     x::T,
