Zygote.jl

how to selectively take structural gradient

Open CarloLucibello opened this issue 3 years ago • 15 comments

In Flux, we typically apply @functor to a type for 2 purposes:

  1. recursively traversing structs and mapping leaves, as done by gpu;
  2. collecting parameters in a Zygote.Params for gradient calculation (this is done by Flux.params(model)).

When we want to distinguish the two behaviors, we use Flux.trainable for the parameters collection. This is an example:
using Flux, Zygote
using Flux: @functor

struct B
   b1::Array
   b2::Array
end
@functor B

struct A
   a1::Array
   eps::Number
   b::B
end
@functor A
Flux.trainable(a::A) = (a.a1,)

a = A(rand(3),0.1,B(rand(2), rand(2)))

Flux.params(a) 
#Params([[0.2755365528802143, 0.7419122552485184, 0.048976872406773175]])

loss(a) = a.eps + sum(a.a1) + sum(a.b.b1)

Now, when one computes the gradient in the implicit form, supposedly only the gradient with respect to a.a1 should be computed. This does not currently appear to be exactly true: every gradient seems to be computed, but at least only the one with respect to a.a1 is exposed:

julia> g = gradient(() -> loss(a), Flux.params(a))
Grads(...)

julia> g[a.a1]
3-element Fill{Float64}: entries equal to 1.0

julia> g[a.b.b1]
ERROR: KeyError: key [0.7037661100448469, 0.34941543792301455] not found
Stacktrace:
 [1] getindex
   @ ./iddict.jl:93 [inlined]
 [2] getindex(gs::Zygote.Grads, x::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:279
 [3] top-level scope
   @ REPL[42]:1
 [4] top-level scope
   @ ~/.julia/packages/CUDA/lwSps/src/initialization.jl:52

julia> g = gradient(() -> loss(a), Flux.params(a)).grads
IdDict{Any, Any} with 2 entries:
  [0.275537, 0.741912, 0.0489769] => 3-element Fill{Float64}: entries equal to 1.0
  :(Main.a)                       => (a1 = nothing, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing))

With the explicit gradient, instead, everything is computed and exposed:

julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0, eps = 1.0, b = (b1 = 2-element Fill{Float64}: entries equal to 1.0, b2 = nothing)),)

This is bad, since we would like to feed this to an update! function, and it is also inefficient. How do we tell Zygote to drop some model parts from the gradient computation? I would like the following:

julia> gradient(a -> loss(a), a)
((a1 = 3-element Fill{Float64}: entries equal to 1.0),)

I see two possibilities:

  • we make gradient @functor/trainable aware
  • we pass gradient a keyword argument for gradient masking

CarloLucibello, Jul 26 '21

What do you mean by gradient masking here?

Also, this is needed because, while the gradient for an argument may not be explicitly asked for, it might be required to compute the gradient of a different argument. Forcing that to nothing would still work with accum, but it would give incorrect results.
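To make the concern concrete, here is a minimal Base-Julia sketch (not Zygote's code). The `accum` below is a hypothetical stand-in that, like Zygote's accumulation, treats `nothing` as a zero gradient. With an intermediate y = 2a and loss a + y, the derivative with respect to a is 1 + 2 = 3; forcing the gradient that reaches y to `nothing` still runs without error but returns 1.

```julia
# Hypothetical stand-in for Zygote's `accum`: `nothing` acts as a zero gradient.
accum(x, y) = x + y
accum(::Nothing, y) = y
accum(x, ::Nothing) = x
accum(::Nothing, ::Nothing) = nothing

# loss(a) = a + y, where the intermediate y = 2a depends on a.
# Hand-written reverse pass; `drop_y` mimics forcing the gradient of y to `nothing`.
function dloss_da(a; drop_y = false)
    dloss_dy = drop_y ? nothing : 1.0        # gradient flowing into the intermediate y
    direct   = 1.0                           # contribution of the direct path a -> loss
    via_y    = dloss_dy === nothing ? nothing : 2.0 * dloss_dy  # chain rule through y = 2a
    return accum(direct, via_y)
end

dloss_da(3.0)                 # 3.0, correct
dloss_da(3.0; drop_y = true)  # 1.0, no error, but wrong
```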

DhairyaLGandhi, Jul 26 '21

Preventing updates during optimization could be accomplished with a helper like https://optax.readthedocs.io/en/latest/api.html?highlight=mask#optax.masked. That said, it doesn't account for the scenario where you want to save memory by not holding onto gradients for certain parameters that won't be updated.
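An optax.masked-style helper could look roughly like the following plain-Julia sketch. `masked_sgd!` is a hypothetical name, not an Optimisers.jl or Flux API; it walks a nested NamedTuple of parameters, its gradient, and a Bool mask tree of the same shape, and only updates leaves whose mask is `true`. Note that the full gradient (including `b.b1`) still has to exist in memory, which is exactly the limitation mentioned above.

```julia
# Hypothetical masked SGD step, loosely modelled on optax.masked.
# `params` and `grads` are nested NamedTuples with arrays (or numbers) at the leaves;
# `mask` mirrors that nesting, and a Bool at any level applies to the whole subtree.

# Array leaf: update in place only if unmasked and a gradient exists.
function masked_sgd!(p::AbstractArray, g, m::Bool; lr)
    if m && g !== nothing
        p .-= lr .* g
    end
    return p
end

# Any other leaf or subtree masked by a Bool (e.g. scalar `eps`, or `b = false`):
# returned unchanged in this sketch.
masked_sgd!(p, g, m::Bool; lr) = p

# Branch: recurse field by field; a `nothing` gradient means "no gradient here".
function masked_sgd!(p::NamedTuple, g, m::NamedTuple; lr)
    g === nothing && return p
    for k in keys(p)
        masked_sgd!(p[k], g[k], m[k]; lr = lr)
    end
    return p
end

params = (a1 = [1.0, 2.0, 3.0], eps = 0.1, b = (b1 = [1.0, 1.0], b2 = [5.0, 5.0]))
grads  = (a1 = [1.0, 1.0, 1.0], eps = 1.0, b = (b1 = [1.0, 1.0], b2 = nothing))
mask   = (a1 = true, eps = false, b = false)   # only a1 is "trainable"

masked_sgd!(params, grads, mask; lr = 0.5)
params.a1    # [0.5, 1.5, 2.5]
params.b.b1  # [1.0, 1.0], unchanged even though its gradient was computed and stored
```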

ToucheSir, Jul 26 '21

Functors.trainable and ChainRulesCore.ProjectTo have quite a bit in common; it's possible they should get to know each other better. I'm not precisely sure why this doesn't work today, but with #1044 it might go something like:

julia> import Zygote.ChainRulesCore: ProjectTo

julia> ProjectTo(a::A) = (dA::NamedTuple) -> (a1 = dA.a1, eps = nothing, b = nothing);

julia> gradient(loss, a)
((a1 = Fill(1.0, 3), eps = 1.0, b = (b1 = Fill(1.0, 2), b2 = nothing)),)
((a1 = Fill(1.0, 3), eps = nothing, b = nothing),)  # is what I hoped for

mcabbott, Jul 30 '21

The wrench in the works is that Functors doesn't have a trainable method and isn't even involved when taking explicit gradients. Perhaps it could take a dep on ChainRulesCore?

ToucheSir, Jul 30 '21

I'm not precisely sure why this doesn't work today

Is it because _project is only defined for Numeric?

darsnack, Jul 30 '21

Oh right, thanks both. I guess there are many details of this union I don't see yet. But Functors/Optimisers are interested in AD with nested structs, and there might be a nice ChainRules-level way to encode things.

mcabbott, Jul 30 '21

There are two pieces here: (1) not updating non-trainable parameters, and (2) not computing gradients for non-trainable parameters.

For (1), Optimisers.jl uses Functors.jl to walk over the structure and the nested gradient tuple to apply updates. Thanks to https://github.com/FluxML/Functors.jl/pull/14, we can now limit that walk to the parameters defined by trainable. I think that pretty much takes care of (1).
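A rough plain-Julia sketch of such a trainable-limited walk, with the model and its structural gradient traversed in parallel. `Outer`, `Inner`, `trainable_fields` and `update!` are hypothetical names mirroring `A`, `B` and `trainable` from the original post; this is not the Functors.jl or Optimisers.jl implementation. Unlike a mask passed at call time, the restriction here is a static property of the type.

```julia
# Hypothetical structs mirroring `A` and `B` from the original post.
struct Inner
    b1::Vector{Float64}
    b2::Vector{Float64}
end
struct Outer
    a1::Vector{Float64}
    eps::Float64
    b::Inner
end

# Hypothetical `trainable`-like trait: the fields an optimiser is allowed to visit.
trainable_fields(::Outer) = (:a1,)
trainable_fields(::Inner) = (:b1, :b2)

update!(x, ::Nothing; lr) = x                     # no gradient: nothing to do
update!(x::AbstractArray, ::Nothing; lr) = x      # resolves the method ambiguity
update!(x::AbstractArray, dx::AbstractArray; lr) = (x .-= lr .* dx; x)

# Struct node: recurse only into trainable fields, pairing each with its tangent.
function update!(x, dx::NamedTuple; lr)
    for f in trainable_fields(x)
        update!(getfield(x, f), getproperty(dx, f); lr = lr)
    end
    return x
end

m  = Outer([1.0, 2.0, 3.0], 0.1, Inner([1.0, 1.0], [5.0, 5.0]))
dm = (a1 = [1.0, 1.0, 1.0], eps = 1.0, b = (b1 = [1.0, 1.0], b2 = nothing))
update!(m, dm; lr = 0.5)
m.a1    # [0.5, 1.5, 2.5]
m.b.b1  # [1.0, 1.0]: `b` is not in trainable_fields(::Outer), so it is never visited
```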

For (2), if f outputs a Foo and df operates on all the fields of dFoo, then I don't think you can selectively drop gradients for any fields. A more concrete example:

function make_model(some_hyperparams...)
    # do stuff with hyper-params to make W and b
    return Dense(W, b) # let's suppose only W is trainable
end

gradient(ps -> loss(make_model(ps...)(x)), ps)

Here it wouldn't make sense for the pullback of (::Dense)(x) to drop the gradient w.r.t. b: b is built from ps, so its gradient is needed to compute the gradient w.r.t. ps. We'd basically need something like ProjectTo, but dynamic to each gradient call.
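A scalar worked example of this point, in plain Julia with no AD package. The `make_model` below is a hypothetical stand-in that builds W = 2h and b = h^2 from one hyperparameter h, and the loss W*x + b stands in for `loss(Dense(W, b)(x))`. The true derivative is 2x + 2h; if the pullback dropped the gradient for b, the 2h term coming through b would silently disappear. A finite difference serves as an independent check.

```julia
# Hypothetical scalar stand-in for `make_model`: both W and b are built from h.
make_model(h) = (W = 2h, b = h^2)
loss(m, x) = m.W * x + m.b            # stand-in for loss(Dense(W, b)(x))

# Reverse pass by hand: dL/dW = x and dL/db = 1, then chain through make_model.
function dloss_dh(h, x; drop_b = false)
    dW = x
    db = drop_b ? 0.0 : 1.0           # pretend the Dense pullback dropped b's gradient
    return dW * 2 + db * 2h           # dW * dW/dh + db * db/dh
end

# Central finite difference as an independent check.
fd(h, x; δ = 1e-6) = (loss(make_model(h + δ), x) - loss(make_model(h - δ), x)) / 2δ

h, x = 1.5, 4.0
dloss_dh(h, x)                 # 11.0 = 2x + 2h
dloss_dh(h, x; drop_b = true)  # 8.0, missing the 2h that flows through b
fd(h, x)                       # ≈ 11.0
```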

Also, if this happens somewhere in the middle of the computation, then I would hope the memory gets reused once that unnecessary gradient is no longer needed by the following pullbacks. I think this is really a concern only for the inputs to the full computation.

darsnack, Jul 30 '21

This is the blessing and curse of Zygote supporting differentiation of arbitrary structs. AFAIK, there is no way to give it additional information about which fields should be accumulated (via accum) into the final gradient tuple and which can be omitted (excepting intermediate calculations that require them). I'm not sure what a general solution for this would look like. Could we make use of ChainRulesCore.Tangent somehow?

ToucheSir, Jul 30 '21

blessing and curse of Zygote supporting differentiation of arbitrary structs

Right, PyTorch autograd's requires_grad does (2), but it also prevents PyTorch's layers from being as flexible as ours. I feel like any general solution needs to be non-static, meaning that the masking info is introduced at the gradient call.

darsnack, Jul 30 '21

Yup. Now if we had a function like trainable that returned a set of property/field names instead, I wonder if we could dynamically generate tangents with only those fields when in AD. ref. https://juliadiff.org/ChainRulesCore.jl/stable/converting_zygoterules.html
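A small plain-Julia sketch of the idea: a `trainable`-like function that returns field names, used to turn a full structural tangent into one that keeps the NamedTuple layout but holds `nothing` for every non-trainable field. `Layer`, `trainable_names` and `restrict_tangent` are hypothetical names, and this only handles one level of nesting. It shows the post-processing step only; applied inside AD (for instance in a projector) it would hide the dropped fields from the caller, but by itself it does not avoid computing them.

```julia
# Hypothetical struct with one trainable array and one non-trainable scalar.
struct Layer
    w::Vector{Float64}
    eps::Float64
end

# Hypothetical `trainable`-like trait returning field *names* rather than values.
trainable_names(::Layer) = (:w,)

# Keep only the trainable fields of a structural tangent `dx` for `x`,
# putting `nothing` in the others and preserving the field order of the struct.
function restrict_tangent(x, dx::NamedTuple)
    names = fieldnames(typeof(x))
    keep  = trainable_names(x)
    vals  = map(f -> f in keep ? getproperty(dx, f) : nothing, names)
    return NamedTuple{names}(vals)
end

l  = Layer([1.0, 2.0], 0.1)
dl = (w = [1.0, 1.0], eps = 1.0)      # what a full structural gradient might contain
restrict_tangent(l, dl)               # (w = [1.0, 1.0], eps = nothing)
```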

ToucheSir, Jul 30 '21

Another (possibly complementary) approach more in line with requires_grad would be some kind of wrapper type that instructs Zygote to always insert nothing when creating the gradient tuple. This of course has all the issues commonly associated with array wrappers.
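A toy plain-Julia illustration of the wrapper idea, assuming a hypothetical `Frozen` type and a hypothetical `gradient_entry` rule; this is not a Zygote feature, just the shape of the behaviour: whatever gradient reaches a `Frozen` field, the structural gradient records `nothing` for it.

```julia
# Hypothetical wrapper marking a value as "do not differentiate".
struct Frozen{T}
    val::T
end

# Hypothetical rule for the gradient entry of one field:
# a Frozen field always gets `nothing`; any other field passes its gradient through.
gradient_entry(::Frozen, dx) = nothing
gradient_entry(x, dx) = dx

# Assemble a structural gradient for a NamedTuple-shaped model.
function structural_gradient(model::NamedTuple, dmodel::NamedTuple)
    ks = keys(model)
    return NamedTuple{ks}(map(k -> gradient_entry(model[k], dmodel[k]), ks))
end

model  = (W = [1.0 2.0; 3.0 4.0], b = Frozen([0.0, 0.0]))
dmodel = (W = [0.5 0.5; 0.5 0.5], b = [1.0, 1.0])
structural_gradient(model, dmodel)   # (W = [0.5 0.5; 0.5 0.5], b = nothing)
```

The usual wrapper friction shows up immediately: `Frozen` is not an `AbstractArray`, so code using `model.b` has to unwrap it as `model.b.val` first.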

ToucheSir, Jul 31 '21

supposedly only the gradient with respect to a.a1 should be computed. This appears to not be exactly currently true, every gradient seems to be computed

Maybe FluxML/Zygote.jl#966 could bring some improvements in that regard.

oschulz, Aug 16 '21

For my own edification, do thunks help with deeply nested struct or tangent fields? I can wrap my head around how an entire argument might be excluded from evaluation, but not a piece of one.

ToucheSir, Aug 16 '21

For my own edification, do thunks help with deeply nested struct or tangent fields?

I have to admit I'm not entirely sure myself, that is, whether it will be possible to make the pullback(s) for the struct creation smart enough.

Maybe we will need some kind of hinting procedure at some point, so the user can specify what quantities they want the gradient for, like Enzyme has.

oschulz, Aug 16 '21

@willtebbutt

oxinabox, Aug 20 '21