@@ -501,28 +501,96 @@ function autodiff_thunk end
501
501
function autodiff_deferred_thunk end
502
502
503
503
"""
504
+ make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T
504
505
make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T
505
506
506
507
Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
507
508
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
509
+
510
+ Extending this method for custom types is rarely needed. For new plain array types like GPU
511
+ arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type implements
512
+ `Base.zero`.
508
513
"""
509
514
function make_zero end
510
515
511
516
"""
512
- make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
517
+ make_zero!(val::T, seen::IdDict=IdDict())::Nothing
518
+
519
+ Recursively set a variables differentiable fields to zero. Only applicable for types `T`
520
+ that are mutable or hold all differentiable values in mutable containers (e.g.,
521
+ `Tuple{Vector{Float64}}`).
513
522
514
- Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
523
+ Extending this method for custom types is rarely needed. For new plain mutable array types
524
+ like GPU arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type
525
+ implements `Base.zero` and `Base.fill!`.
515
526
"""
516
527
function make_zero! end
517
528
518
529
"""
519
- make_zero(prev::T)
530
+ isvectortype(::Type{T})::Bool
531
+
532
+ Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref)
533
+ and [`make_zero!`](@ref) recurse through an object.
534
+
535
+ By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or
536
+ `T <: Union{Array{U},GenericMemory{_,U}}` where `isscalartype(U) == true`.
537
+
538
+ A new plain array type, for example a GPU array, may extend this as follows:
539
+
540
+ ```julia
541
+ @inline EnzymeCore.isvectortype(::Type{<:GPUArray{U}}) where {U} = isscalartype(U)
542
+ ```
543
+
544
+ Such a type should implement `Base.zero` and, if mutable, `Base.fill!`. (If this is not
545
+ feasible, an alternative is to add methods `EnzymeCore.make_zero(arr::T)::T` and, if
546
+ mutable, `EnzymeCore.make_zero!(arr::T)::Nothing`; such methods will also be picked up by
547
+ recursive calls.)
520
548
521
- Helper function to recursively make zero.
549
+ Such extensions are mostly relevant for the lowest-level of abstraction of memory at which
550
+ vector space operations like addition and scalar multiplication are supported, the
551
+ prototypical case being `Array`. Regular Julia structs with vector space-like semantics
552
+ should normally not extend `isvectorspace`; `make_zero(!)` will recurse into them and act
553
+ directly on their backing arrays, just like how Enzyme treats them when differentiating. For
554
+ example, structured matrix wrappers and sparse array types that are backed by `Array` should
555
+ not extend `isvectortype`.
556
+
557
+ See also [`isscalartype`](@ref).
522
558
"""
523
- @inline function make_zero (prev:: T , :: Val{copy_if_inactive} = Val (false )) where {T, copy_if_inactive}
524
- make_zero (Core. Typeof (prev), IdDict (), prev, Val (copy_if_inactive))
525
- end
559
+ function isvectortype end
560
+
561
+ """
562
+ isscalartype(::Type{T})::Bool
563
+
564
+ Trait defining a subset of [`isvectortype`](@ref) types that should not be considered
565
+ composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero
566
+ values of the type in-place. For example, `BigFloat` is a mutable type but does not support
567
+ in-place mutation through any Julia API, and `isscalartype(BigFloat) == true` ensures that
568
+ `make_zero!` will not try to mutate `BigFloat` values.[^BigFloat]
569
+
570
+ By default, `isscalartype(T) == true` for `T <: AbstractFloat` and
571
+ `T <: Complex{<:AbstractFloat}`.
572
+
573
+ A hypothetical new real number type with Enzyme support should in most cases simply subtype
574
+ `AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate,
575
+ the function can be extended as follows:
576
+
577
+ ```julia
578
+ @inline EnzymeCore.isscalartype(::Type{<:NewReal}) = true
579
+ @inline EnzymeCore.isscalartype(::Type{<:Complex{<:NewReal}}) = true
580
+ ```
581
+
582
+ In either case, the type should implement `Base.zero`. (If this is not feasible, an
583
+ alternative is to add a method `EnzymeCore.make_zero(x::T)::T`; such a method will also be
584
+ picked up by recursive calls.)
585
+
586
+ See also [`isvectortype`](@ref).
587
+
588
+ [^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is
589
+ mentioned here only to illustrate that it would be inappropriate to use traits like
590
+ `ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing,
591
+ demonstrating the need for a dedicated `isscalartype` trait.
592
+ """
593
+ function isscalartype end
526
594
527
595
function tape_type end
528
596
0 commit comments