
RFC: strip most types from `gradient` output


This is a draft of a way to start addressing #1334, for comment.

It implements what I called level 2 here: https://github.com/EnzymeAD/Enzyme.jl/issues/1334#issuecomment-2016573247
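
To make the intent concrete, here is a minimal hypothetical sketch of the conversion, with invented names and signatures (not the PR's actual code): rebuild the original wrapper from its structural shadow whenever the fields line up one-to-one.

using LinearAlgebra

# Hypothetical sketch only; names and signatures are illustrative.
strip_types(x, dx) = dx  # fallback: keep the structural representation
strip_types(x::AbstractArray{<:Real}, dx::AbstractArray{<:Real}) = dx  # plain arrays pass through

# Diagonal's natural and structural representations agree, so re-wrap,
# skipping fields that carry no derivative information:
strip_types(x::Diagonal, dx::NamedTuple{(:diag,)}) =
    dx.diag === nothing ? dx : Diagonal(strip_types(x.diag, dx.diag))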

On arrays like these, nothing changes: the natural and structural representations agree:

julia> Enzyme.gradient(Reverse, first, Diagonal([1,2.]))
2×2 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅ 
  ⋅   0.0

julia> using SparseArrays, StaticArrays

julia> Enzyme.gradient(Reverse, sum, sparse([5 0 6.]))
1×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
 1.0   ⋅   1.0

julia> Enzyme.gradient(Reverse, sum, PermutedDimsArray(sparse([1 2; 3 0.]), (2,1)))
2×2 PermutedDimsArray(::SparseMatrixCSC{Float64, Int64}, (2, 1)) with eltype Float64:
 1.0  1.0
 1.0  0.0

julia> Enzyme.gradient(Reverse, first, reshape(SA[1,2,3,4.]',2,2))
2×2 reshape(adjoint(::SVector{4, Float64}), 2, 2) with eltype Float64:
 1.0  0.0
 0.0  0.0

On arrays like these, it does not know how to construct the natural representation, so it doesn't try (I know how, but the fields of the result would not line up with the existing ones; see the illustration after these examples):

julia> Enzyme.gradient(Reverse, sum, Symmetric(rand(3,3)))
(data = [1.0 2.0 2.0; 0.0 1.0 2.0; 0.0 0.0 1.0], uplo = nothing)

julia> Enzyme.gradient(Reverse, first, reshape(LinRange(1,2,4)',2,2))
(parent = (parent = (start = 1.0, stop = 0.0, len = nothing, lendiv = nothing),), dims = (nothing, nothing), mi = ())
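
To see what "fields not lining up" means for Symmetric: folding the structural data above into an actual Symmetric is possible, but its data field then differs from the structural one. This is a hypothetical projection, not something the PR performs:

julia> d = [1.0 2.0 2.0; 0.0 1.0 2.0; 0.0 0.0 1.0];  # structural `data` from above

julia> Symmetric(d + d' - Diagonal(d))  # one possible folding; note its data ≠ d
3×3 Symmetric{Float64, Matrix{Float64}}:
 1.0  2.0  2.0
 2.0  1.0  2.0
 2.0  2.0  1.0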

Arrays of non-differentiable objects cannot be wrapped back up in array structs:

julia> Enzyme.gradient(Reverse, float∘first, Diagonal([1,2,3]))
(diag = nothing,)

julia> Enzyme.gradient(Reverse, float∘first, SA[1,2,3]')
(parent = (data = (nothing, nothing, nothing),),)

make_zero uses an IdDict cache to preserve identity between different branches of the struct. At present this PR does not always do the same:

julia> mutable struct TwoThings{A,B}; a::A; b::B; end

julia> nt = (x=TwoThings(3.0, 4.0), y=TwoThings(3.0, 4.0));

julia> nt.x === nt.y
false

julia> grad = Enzyme.gradient(Reverse, nt -> nt.x.a + nt.y.a + 20nt.x.b + 20nt.y.b, nt)
(x = (a = [11.0, 11.0], b = 22.0), y = (a = [11.0, 11.0], b = 22.0))

julia> grad.x === grad.y  # new identity created
true

# example 2

julia> arrs = [[1,2.], [3,4.]];

julia> grad = Enzyme.gradient(Reverse, nt -> sum(sum(sum, x)::Float64 for x in nt), (a = arrs, b = arrs))
(a = [[2.0, 2.0], [2.0, 2.0]], b = [[2.0, 2.0], [2.0, 2.0]])

julia> grad.a === grad.b  # container array identity is not preserved
false

julia> grad.a[1] === grad.b[1]  # array of numbers
true
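
For reference, the caching pattern make_zero uses looks roughly like this; strip_types_cached is an invented name and this is a simplification, not Enzyme's actual code. Threading the same cache through the conversion would restore the aliasing above:

# Simplified sketch of IdDict-based identity preservation (invented names):
function strip_types_cached(x, dx, cache = IdDict())
    if ismutable(dx) && haskey(cache, dx)
        return cache[dx]               # aliased branches map to one output object
    end
    out = strip_types(x, dx)           # the hypothetical conversion sketched earlier
    ismutable(dx) && (cache[dx] = out) # remember mutable results for later branches
    return out
end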

A simple Flux model: no functional change, but the gradient no longer looks like the model:

julia> using Flux

julia> model = Chain(Embedding(reshape(1:6, 2,3) .+ 0.0), softmax);

julia> Enzyme.gradient(Reverse, m -> sum(abs2, m(1)), model)
(layers = ((weight = [-0.18171549534589682 0.0 0.0; 0.18171549534589682 0.0 0.0],), nothing),)

Comments:

  • I'm not sure what I think about the failure to preserve, in the gradient, the === relations between some mutable objects in the original. Some of this could be solved by adding an IdDict cache like make_zero does.

  • The function called strip_types for now probably needs to be public, so that you can call it yourself after constructing dx = make_zero(x), and so that you can overload it for your own array wrappers (a possible overload is sketched after this list).

  • Projecting things like Symmetric to their covariant representation probably needs to be opt-in, by somehow telling gradient that you want this. (That's level 4 here: https://github.com/EnzymeAD/Enzyme.jl/issues/1334#issuecomment-2016573247 .) Could be implemented as additional methods of this function, something like strip_types(x, dx, Val(true), cache)?

  • Surely all the code is in the wrong place, and needs tests.
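
If strip_types does become public, an overload for a custom array wrapper might look like the following. This assumes a two-argument strip_types(x, dx) owned by Enzyme; both are assumptions of this sketch, not the decided API:

# Hypothetical user-side overload; Enzyme.strip_types is assumed, not real API.
struct MyScaled{T} <: AbstractVector{T}
    data::Vector{T}
    scale::T
end
Base.size(x::MyScaled) = size(x.data)
Base.getindex(x::MyScaled, i::Int) = x.scale * x.data[i]

# Re-wrap the structural gradient only when both fields carry derivatives:
Enzyme.strip_types(x::MyScaled, dx::NamedTuple{(:data, :scale)}) =
    (dx.data === nothing || dx.scale === nothing) ? dx : MyScaled(dx.data, dx.scale)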

mcabbott · Mar 26 '24, 05:03