RFC: strip most types from `gradient` output
This is a draft of a way to start addressing #1334, for comment.
It implements what I called level 2 here: https://github.com/EnzymeAD/Enzyme.jl/issues/1334#issuecomment-2016573247
On arrays like these, there is no change, since the natural and structural representations agree:
```julia
julia> Enzyme.gradient(Reverse, first, Diagonal([1,2.]))
2×2 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅
  ⋅   0.0

julia> using SparseArrays, StaticArrays

julia> Enzyme.gradient(Reverse, sum, sparse([5 0 6.]))
1×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
 1.0   ⋅   1.0

julia> Enzyme.gradient(Reverse, sum, PermutedDimsArray(sparse([1 2; 3 0.]), (2,1)))
2×2 PermutedDimsArray(::SparseMatrixCSC{Float64, Int64}, (2, 1)) with eltype Float64:
 1.0  1.0
 1.0  0.0

julia> Enzyme.gradient(Reverse, first, reshape(SA[1,2,3,4.]',2,2))
2×2 reshape(adjoint(::SVector{4, Float64}), 2, 2) with eltype Float64:
 1.0  0.0
 0.0  0.0
```
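To spell out the terminology (this snippet is illustrative, not Enzyme code): the structural representation is a NamedTuple mirroring the struct's fields, while the natural representation reuses the input's own array type. For a `Diagonal` they carry the same information:

```julia
using LinearAlgebra

x = Diagonal([1.0, 2.0])

# Structural representation: one entry per field of the struct.
structural = (diag = [1.0, 0.0],)

# Natural representation: the same array type as the input.
natural = Diagonal([1.0, 0.0])

# For Diagonal the two agree entry-for-entry:
natural.diag == structural.diag  # true
```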
On arrays like these, it does not know how to construct the natural representation, so it doesn't try. (I do know how, but the fields of the result would not line up with the existing ones.)
```julia
julia> Enzyme.gradient(Reverse, sum, Symmetric(rand(3,3)))
(data = [1.0 2.0 2.0; 0.0 1.0 2.0; 0.0 0.0 1.0], uplo = nothing)

julia> Enzyme.gradient(Reverse, first, reshape(LinRange(1,2,4)',2,2))
(parent = (parent = (start = 1.0, stop = 0.0, len = nothing, lendiv = nothing),), dims = (nothing, nothing), mi = ())
```
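To illustrate why the fields would not line up (my reconstruction, not Enzyme code): for `sum` over a `Symmetric`, the natural gradient is all ones, but the structural `data` field counts each stored off-diagonal entry twice, so simply re-wrapping it in `Symmetric` gives the wrong matrix:

```julia
using LinearAlgebra

# Structural gradient of sum over a 3×3 Symmetric, as printed above:
structural_data = [1.0 2.0 2.0; 0.0 1.0 2.0; 0.0 0.0 1.0]

# Each stored off-diagonal entry appears at two positions of the
# matrix, hence the 2s. The natural gradient of sum is just ones:
natural = ones(3, 3)

# Re-wrapping the structural data does not reproduce it:
Matrix(Symmetric(structural_data)) == natural  # false
```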
Arrays of non-differentiable objects cannot be wrapped up in array structs:
```julia
julia> Enzyme.gradient(Reverse, float∘first, Diagonal([1,2,3]))
(diag = nothing,)

julia> Enzyme.gradient(Reverse, float∘first, SA[1,2,3]')
(parent = (data = (nothing, nothing, nothing),),)
```
`make_zero` uses an `IdDict` cache to preserve identity between different branches of the struct. At present this does not:
```julia
julia> mutable struct TwoThings{A,B}; a::A; b::B; end

julia> nt = (x=TwoThings(3.0, 4.0), y=TwoThings(3.0, 4.0));

julia> nt.x === nt.y
false

julia> grad = Enzyme.gradient(Reverse, nt -> nt.x.a + nt.y.a + 20nt.x.b + 20nt.y.b, nt)
(x = (a = [11.0, 11.0], b = 22.0), y = (a = [11.0, 11.0], b = 22.0))

julia> grad.x === grad.y  # new identity created
true

# example 2

julia> arrs = [[1,2.], [3,4.]];

julia> grad = Enzyme.gradient(Reverse, nt -> sum(sum(sum, x)::Float64 for x in nt), (a = arrs, b = arrs))
(a = [[2.0, 2.0], [2.0, 2.0]], b = [[2.0, 2.0], [2.0, 2.0]])

julia> grad.a === grad.b  # container array identity is not preserved
false

julia> grad.a[1] === grad.b[1]  # array of numbers
true
```
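For comparison, an `IdDict`-based recursion of the kind `make_zero` uses could look roughly like this (`zero_like` is a hypothetical name, and this is a much-simplified sketch, not Enzyme's code):

```julia
# Hypothetical sketch of identity-preserving recursion over a nested struct.
function zero_like(x, cache = IdDict())
    if x isa AbstractFloat
        return zero(x)
    elseif x isa Array
        haskey(cache, x) && return cache[x]    # aliased arrays share one result
        y = map(el -> zero_like(el, cache), x)
        cache[x] = y
        return y
    elseif x isa Union{Tuple, NamedTuple}
        return map(el -> zero_like(el, cache), x)
    else
        return nothing                         # non-differentiable leaf
    end
end

arrs = [[1.0, 2.0], [3.0, 4.0]]
z = zero_like((a = arrs, b = arrs))
z.a === z.b  # true: the cache preserves the aliasing that `gradient` above loses
```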
A simple Flux model: no functional change, but the gradient now looks different from the model:
```julia
julia> using Flux

julia> model = Chain(Embedding(reshape(1:6, 2,3) .+ 0.0), softmax);

julia> Enzyme.gradient(Reverse, m -> sum(abs2, m(1)), model)
(layers = ((weight = [-0.18171549534589682 0.0 0.0; 0.18171549534589682 0.0 0.0],), nothing),)
```
Comments:

- I'm not sure what I think about the failure to preserve `===` relations between some mutable objects in the original gradient. Some of this could be solved by adding an `IdDict` cache like `make_zero` does.
- The function called `strip_types` for now probably needs to be public, so that you can call it yourself after constructing `dx = make_zero(x)`, and so that you can overload it for your array wrappers.
- Projecting things like `Symmetric` to their covariant representation probably needs to be opt-in, by somehow telling `gradient` that you want this. (That's level 4 here: https://github.com/EnzymeAD/Enzyme.jl/issues/1334#issuecomment-2016573247 .) Could be implemented as additional methods of this function, something like `strip_types(x, dx, Val(true), cache)`?
- Surely all the code is in the wrong place, and needs tests.
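As a sketch of what overloading might look like for a custom wrapper (the name `strip_types` is from this PR, but the two-argument signature and the method bodies here are my assumptions):

```julia
using LinearAlgebra

# Hypothetical fallback: leave the structural form alone...
strip_types(x, dx) = dx

# ...but rebuild the natural form where the wrapper makes that safe:
strip_types(x::Diagonal, dx::NamedTuple{(:diag,)}) = Diagonal(dx.diag)

x = Diagonal([1.0, 2.0])
dx = (diag = [1.0, 0.0],)
strip_types(x, dx)  # a Diagonal, matching the input's type
```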