Zygote.jl
How to selectively take a structural gradient
In Flux, we typically apply @functor to a type for two purposes:
- recursively traversing structs and mapping leaves, as done by gpu
- collecting parameters in a Zygote.Params for gradient calculation (this is done by Flux.params(model))

When we want to distinguish the two behaviors, we use Flux.trainable for the parameter collection. This is an example:
```julia
using Flux, Zygote
using Flux: @functor

struct B
    b1::Array
    b2::Array
end
@functor B

struct A
    a1::Array
    eps::Number
    b::B
end
@functor A

# only a1 is declared trainable; eps and b should be ignored by Flux.params
Flux.trainable(a::A) = (a.a1,)

a = A(rand(3), 0.1, B(rand(2), rand(2)))
Flux.params(a)
# Params([[0.2755365528802143, 0.7419122552485184, 0.048976872406773175]])

loss(a) = a.eps + sum(a.a1) + sum(a.b.b1)
```
Now when one computes the gradient in the implicit form, supposedly only the gradient with respect to a.a1 should be computed. This does not appear to be exactly true at the moment: every gradient seems to be computed, but at least only the one with respect to a.a1 is exposed:
```julia
julia> g = gradient(() -> loss(a), Flux.params(a))
Grads(...)

julia> g[a.a1]
3-element Fill{Float64}: entries equal to 1.0

julia> g[a.b.b1]
ERROR: KeyError: key [0.7037661100448469, 0.34941543792301455] not found
Stacktrace:
 [1] getindex
   @ ./iddict.jl:93 [inlined]
 [2] getindex(gs::Zygote.Grads, x::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:279
 [3] top-level scope
   @ REPL[42]:1
 [4] top-level scope
   @ ~/.julia/packages/CUDA/lwSps/src/initialization.jl:52

julia> g = gradient(() -> loss(a), Flux.params(a)).grads
IdDict{Any, Any} with 2 entries:
  [0.275537, 0.741912, 0.0489769] => 3-element Fill{Float64}: entries equal to 1.0
  :(Main.a) => (a1 = nothing, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing))
```
With the explicit gradient, instead, everything is computed and exposed:
```julia
julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing)),)
```
This is problematic, since we would like to feed this to an update! function, and it is also inefficient. How do we tell Zygote to drop some model parts from the gradient computation? I would like the following:
```julia
julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0),)
```
I see two possibilities:
- we make gradient @functor/trainable aware
- we pass gradient a keyword argument for the gradient masking
What do you mean by gradient masking here?
Also, some care is needed here: while the gradient for an argument may not be explicitly asked for, it might still be required to compute the gradient of a different argument. Forcing it to nothing would still work with accum but would give incorrect results.
Preventing updates during optimization could be accomplished with a helper like https://optax.readthedocs.io/en/latest/api.html?highlight=mask#optax.masked. That said, it doesn't account for the scenario where you want to save memory by not holding onto gradients for certain parameters that won't be updated.
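For illustration, a masking step along those lines could be written as a small helper that drops the update for masked-off leaves. This is a hypothetical sketch (apply_mask is not an existing Flux/Optimisers function), and as noted it only prevents updates, it does not avoid computing or storing the gradients:

```julia
# Hypothetical masking helper in the spirit of optax.masked (not an existing API).
# `mask` mirrors the gradient structure; leaves marked false get their update dropped.
apply_mask(mask::NamedTuple, grad::NamedTuple) = map(apply_mask, mask, grad)
apply_mask(mask::Bool, grad) = mask ? grad : nothing

# Usage sketch for the A/B example above: keep only the a1 gradient.
mask = (a1 = true, eps = false, b = false)
grad = (a1 = [1.0, 1.0, 1.0], eps = 1.0, b = (b1 = [1.0, 1.0], b2 = nothing))
apply_mask(mask, grad)  # (a1 = [1.0, 1.0, 1.0], eps = nothing, b = nothing)
```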
Functors.trainable and ChainRulesCore.ProjectTo have quite a bit in common; it's possible they should get to know each other better. I'm not precisely sure why this doesn't work today, but with #1044 it might go something like:
```julia
julia> import Zygote.ChainRulesCore: ProjectTo

julia> ProjectTo(a::A) = dA::NamedTuple -> ((; a1 = dA.a1, eps = nothing, b = nothing),);

julia> gradient(loss, a)
((a1 = Fill(1.0, 3), eps = 1.0, b = (b1 = Fill(1.0, 2), b2 = nothing)),)  # actual output
((a1 = Fill(1.0, 3), eps = nothing, b = nothing),)                        # is what I hoped for
```
The wrench in the works is that Functors doesn't have a trainable
method and isn't even involved when taking explicit gradients. Perhaps it could take a dep on ChainRulesCore?
I'm not precisely sure why this doesn't work today
Is it because _project
is only defined for Numeric
?
Oh right, thanks both. I guess there are many details of this union I don't see yet. But Functors/Optimisers are interested in AD with nested structs, and there might be a nice ChainRules-level way to encode things.
There are two pieces here: (1) not updating non-trainable parameters, and (2) not computing gradients for non-trainable parameters.
For (1), Optimisers.jl uses Functors.jl to walk over the structure and the nested gradient tuple to apply updates. Thanks to https://github.com/FluxML/Functors.jl/pull/14, we can now limit that walk to the parameters defined by trainable. I think that pretty much takes care of (1).
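Roughly, from the user side it could look like this sketch (assuming a recent Optimisers.jl API; exact names, and the form trainable must return, may differ across versions):

```julia
using Optimisers  # sketch only; assumes a recent Optimisers.jl

# setup walks the struct with Functors.jl and creates optimiser state;
# update then only modifies leaves reachable through trainable.
opt_state = Optimisers.setup(Optimisers.Descent(0.1), a)
grads = gradient(loss, a)[1]
opt_state, a = Optimisers.update(opt_state, a, grads)
# a.a1 has been updated; a.eps and a.b are left untouched
```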
For (2), if f
outputs a Foo
and df
operates on all the fields of dFoo
, then I don't think you can selectively drop gradients for any fields. A more concrete example:
```julia
function make_model(some_hyperparams...)
    # do stuff with hyper-params to make W and b
    return Dense(W, b) # let's suppose only W is trainable
end

gradient(ps -> loss(make_model(ps...)(x)), ps)
```
Here it wouldn't make sense for the pullback of (::Dense)(x)
to drop the gradients w.r.t. b
. We'd basically need something like ProjectTo
but dynamic to each gradient
call.
Also, if this is somewhere in the middle of the computation, then I would hope the memory gets re-used once that unnecessary gradient is no longer needed by the following pullbacks. I think this is really only a concern for the inputs to the full computation.
This is the blessing and curse of Zygote supporting differentiation of arbitrary structs. AFAIK, there is no way to provide it additional information about what fields should be accum
ed into the final gradient tuple and which can be omitted (excepting intermediate calculations which require them). I'm not sure what a general solution for this would look like—could we make use of ChainRulesCore.Tangent
somehow?
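For what it's worth, a Tangent can already represent a structural gradient that carries only some of the fields; a small illustration for the A above (if I have the semantics right, fields not stored behave as zero):

```julia
using ChainRulesCore

# A structural tangent for A carrying only an a1 component.
dA = Tangent{A}(a1 = [1.0, 1.0, 1.0])
dA.a1   # [1.0, 1.0, 1.0]
dA.eps  # ZeroTangent(); fields not stored act as zeros
```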
blessing and curse of Zygote supporting differentiation of arbitrary structs
Right, PyTorch autograd's requires_grad does (2), but it also prevents PyTorch's layers from being as flexible as ours. I feel like any general solution needs to be non-static, meaning that the masking info is introduced at the gradient call.
Yup. Now if we had a function like trainable
that returned a set of property/field names instead, I wonder if we could dynamically generate tangents with only those fields when in AD. ref. https://juliadiff.org/ChainRulesCore.jl/stable/converting_zygoterules.html
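To spell that out, a hypothetical sketch (neither trainable_fields nor filter_tangent exists anywhere today, and this only filters after the fact rather than avoiding the computation):

```julia
# Hypothetical: a trainable-fields function returning property names for A.
trainable_fields(::A) = (:a1,)

# Keep only the tangent components for those fields, replacing the rest with nothing.
function filter_tangent(x, dx::NamedTuple)
    keep = trainable_fields(x)
    return NamedTuple{keys(dx)}(map(k -> k in keep ? getfield(dx, k) : nothing, keys(dx)))
end

dx = gradient(loss, a)[1]
filter_tangent(a, dx)  # (a1 = Fill(1.0, 3), eps = nothing, b = nothing)
```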
Another (possibly complementary) approach more in line with requires_grad
would be some kind of wrapper type that instructs Zygote to always insert nothing
when creating the gradient tuple. This of course has all the issues commonly associated with array wrappers.
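A minimal sketch of that wrapper idea, using an rrule so that nothing is ever accumulated for the wrapped value (hypothetical: NoGrad and unwrap are made up here, and every use site would have to call unwrap explicitly):

```julia
using ChainRulesCore

# Hypothetical wrapper marking a value as "never differentiate".
struct NoGrad{T}
    value::T
end

unwrap(x::NoGrad) = x.value
unwrap(x) = x

# The pullback discards the incoming cotangent, so Zygote records no gradient
# for anything stored inside a NoGrad.
function ChainRulesCore.rrule(::typeof(unwrap), x::NoGrad)
    unwrap_pullback(dy) = (NoTangent(), NoTangent())
    return x.value, unwrap_pullback
end
```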
supposedly only the gradient with respect to a.a1 should be computed. This does not appear to be exactly true at the moment: every gradient seems to be computed
Maybe FluxML/Zygote.jl#966 could bring some improvements in that regard.
For my own edification, do thunks help with deeply nested struct or tangent fields? I can wrap my head around how an entire argument might be excluded from evaluation, but not a piece of one.
For my own edification, do thunks help with deeply nested struct or tangent fields?
I have to admit I'm not entirely sure myself, nor whether it will be possible to make the pullback(s) for the struct creation smart enough.
Maybe we will need some kind of hinting procedure at some point, so the user can specify what quantities they want the gradient for, like Enzyme has.
@willtebbutt
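For reference, the kind of hinting Enzyme offers looks roughly like this (a sketch using Enzyme.jl's activity annotations; exact call details may vary by version):

```julia
using Enzyme

f(x, y) = sum(x) * y

x, dx = rand(3), zeros(3)
# Duplicated carries a shadow buffer that receives the gradient;
# Const declares that no gradient is wanted (or computed) for y.
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx), Const(2.0))
# dx now holds the gradient w.r.t. x; nothing was accumulated for y
```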