
How to update parameters using gradients

Open lorenzoh opened this issue 3 years ago • 12 comments

I am trying to put together an example of using FastAI.jl to finetune a pretrained ResNet from torchvision, and I am unsure how to use the output of Zygote.gradient on a TorchModuleWrapper to optimize its parameters.

So the question: what is the best way to update the parameters of a nested module? Have you tried this successfully and could share a minimal example?

I've tried applying Flux.Optimise.update! to the README example by iterating over the params and gradients manually, but I am not sure whether this will work for more deeply nested structures.

using Flux, Zygote
using PyCallChainRules.Torch: TorchModuleWrapper

# `torch_module`, `loss`, `input`, and `target` are as in the README example.
model = TorchModuleWrapper(torch_module)
optim = ADAM()

for i in 1:10
    grad, = Zygote.gradient(m -> loss(m, input, target), model)
    # Update each flat parameter array with its matching gradient.
    for (p, g) in zip(model.params, grad.params)
        Flux.Optimise.update!(optim, p, g)
    end
end

The above works, while Flux.Optimise.update!(optim, model, grad) does not. Maybe an overload for Flux.Optimise.update!(::AbstractOptimiser, ::TorchModuleWrapper, _) is needed?

lorenzoh avatar Mar 18 '22 18:03 lorenzoh

This won't help FastAI, but I think Optimisers.jl might be able to handle this directly. You might be able to work around the lack of a Flux.Optimise.update!(::AbstractOptimiser, ::TorchModuleWrapper, _) overload by manually creating a Params and Grads from model.params and grad.params.
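
A minimal sketch of that workaround (untested; it assumes the model, optim, and grad from the snippet above, and relies on constructing Zygote.Grads directly from an IdDict and a Params, which is an internal constructor rather than a documented API):

using Flux, Zygote

# Build the implicit containers by hand from the wrapper's flat parameter tuple.
ps = Zygote.Params(collect(model.params))
gd = IdDict{Any,Any}()
for (p, g) in zip(model.params, grad.params)
    gd[p] = g
end
gs = Zygote.Grads(gd, ps)  # assumption: default (grads, params) constructor
Flux.Optimise.update!(optim, ps, gs)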

ToucheSir avatar Mar 18 '22 19:03 ToucheSir

If you look at https://github.com/rejuvyesh/PyCallChainRules.jl/blob/92cf084f6c7379d3fa9d5d399f1aa086bbcdc409/src/pytorch.jl#L29 together with https://github.com/rejuvyesh/PyCallChainRules.jl/blob/92cf084f6c7379d3fa9d5d399f1aa086bbcdc409/src/pytorch.jl#L41-L43, I think it shouldn't be too hard to hook into the optimiser interface. Let me check what's required for Flux.Optimise to work. Probably Flux.trainable needs to be defined?

Edit:

jlmodel = TorchModuleWrapper(torch_module)
Flux.params(jlmodel)

Already gives the correct thing, so we should probably just use the implicit-parameter style of gradient calls shown in current Flux examples.

rejuvyesh avatar Mar 18 '22 20:03 rejuvyesh

functorch (which is used behind the scenes) collapses the nested module parameters into a "flat" tuple of parameter arrays.
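
For illustration, a hedged guess at the REPL output (not verbatim; it assumes the 3-layer MLP built later in this thread, whose 3 Linear layers give 3 weight and 3 bias arrays):

julia> length(jlmodel.params)
6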

rejuvyesh avatar Mar 18 '22 20:03 rejuvyesh

Edit: ignore everything below, params works as mentioned above!

Flux.trainable defaults to functor, so unless you want it to return something other than the wrapper's params field, no additional action should be required. The problem is that update! only has overloads for single arrays and Params on the Flux side: https://github.com/FluxML/Flux.jl/blob/v0.12.9/src/optimise/train.jl#L23-L34. If everything in the wrapper's params can be mutated in place, then I think this would be the overload Lorenz was asking about:

using ChainRulesCore: Tangent
using Flux.Optimise: AbstractOptimiser

# Dispatch on the wrapper and its structural tangent, updating each flat parameter array in place.
function Flux.Optimise.update!(opt::AbstractOptimiser, model::TorchModuleWrapper, grads::Tangent{<:TorchModuleWrapper})
    for (p, g) in zip(model.params, grads.params)
        Flux.Optimise.update!(opt, p, g)
    end
end

ToucheSir avatar Mar 18 '22 20:03 ToucheSir

So the issue here is the interaction between ChainRules and Zygote's implicit parameters?

julia> gs = Flux.gradient(ps) do 
           loss(input, target)
       end
Grads(...)

julia> gs.grads
IdDict{Any, Any} with 9 entries:
  :(Main.target)                                                                                => Float32[-0.00564587 0.238695 … 0.0521358 -0.0587006; 0.128218 0.054992 … -0.047878 -0.0400576]
  :(Main.input)                                                                                 => Float32[-0.00423729 7.43475f-6 … -0.00136046 0.00378413; -0.00100987 0.0220082 … 0.00141332 -0.00220777; -0.00780671 -0.0137588 … 0.000454976 0.00193283; -0.00435127 -0.00824403 … -0.00…
  Float32[-0.415697 0.274342 … -0.299125 0.418624; 0.462995 -0.144811 … -0.450912 -0.00113821;… => nothing
  Float32[-0.219483, -0.424986, 0.194869, -0.148786, -0.185551, -0.296822, 0.302707, -0.086165… => nothing
  Float32[0.183148 0.0902577; 0.0673819 0.241829; … ; -0.0106342 0.0899293; 0.180051 0.214607]  => nothing
  Float32[-0.166111, 0.091679]                                                                  => nothing
  :(Main.jlmodel)                                                                               => (torch_stateless_module = nothing, dtype = nothing, params = (Float32[-0.000155696 0.00759759 … -0.00265382 0.0210725; -0.00102487 -0.00692717 … 0.00439671 0.00441634; 0.00025565 0.0101…
  Float32[-0.137103 0.155661 … -0.0798025 -0.137254; 0.240076 0.218976 … -0.200846 0.0126162; … => nothing
  Float32[0.060363, 0.143615, 0.107176, 0.224407, -0.144465, 0.21469, -0.0161946, -0.163965, -… => nothing

Note that the gradients seem to end up under the :(Main.jlmodel) entry (as a tangent for jlmodel, including its params field), while the implicit parameter arrays all map to nothing.

rejuvyesh avatar Mar 18 '22 20:03 rejuvyesh

using Flux
using PyCallChainRules.Torch: TorchModuleWrapper, torch

input_dim = 4
output_dim = 2
hiddendim = 16

batchsize = 8

# 3-layer MLP defined in PyTorch and wrapped for use from Julia.
torch_module = torch.nn.Sequential(
                            torch.nn.Linear(input_dim, hiddendim), torch.nn.ReLU(),
                            torch.nn.Linear(hiddendim, hiddendim), torch.nn.ReLU(),
                            torch.nn.Linear(hiddendim, output_dim)
                        )
jlmodel = TorchModuleWrapper(torch_module)

opt = Flux.ADAM(0.1)

input = randn(Float32, input_dim, batchsize)
target = randn(Float32, output_dim, batchsize)

loss(x, y) = Flux.Losses.mse(jlmodel(x), y)
ps = Flux.params(jlmodel)
@info "before" map(sum, ps)
for i in 1:1
    # Implicit-parameter gradient; as shown in the gs.grads output above, the entries
    # for ps come back as `nothing`, so this update! should leave ps unchanged.
    gs = Flux.gradient(ps) do
        loss(input, target)
    end
    @show gs.grads
    Flux.Optimise.update!(opt, ps, gs)
end
@info "after" map(sum, ps)

rejuvyesh avatar Mar 18 '22 20:03 rejuvyesh

As @ToucheSir mentioned, the explicit interface in Optimisers.jl should likely work correctly. We will probably need help from people with a better understanding of the Zygote/ChainRules interaction for this implicit-parameters issue.

rejuvyesh avatar Mar 18 '22 20:03 rejuvyesh

See the examples in #20. The explicit interface from Optimisers.jl seems to work correctly.
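
For reference, a minimal sketch of that explicit interface (the actual examples live in #20; this assumes the jlmodel, input, and target from the script above and uses Descent only for simplicity):

using Flux, Optimisers

# One explicit-mode training step: structural gradient plus Optimisers.jl state.
state = Optimisers.setup(Optimisers.Descent(0.1), jlmodel)
grad, = Flux.gradient(m -> Flux.Losses.mse(m(input), target), jlmodel)
state, jlmodel = Optimisers.update(state, jlmodel, grad)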

rejuvyesh avatar Mar 18 '22 20:03 rejuvyesh

I now see why @terasakisatoshi's https://github.com/terasakisatoshi/PCRP.jl/blob/8f88e89ae4e7e5b8f3b41c23a05950487cbe1143/playground/notebook/example_pcr.jl#L110-L120 was necessary.

rejuvyesh avatar Mar 19 '22 00:03 rejuvyesh

> I now see why @terasakisatoshi's https://github.com/terasakisatoshi/PCRP.jl/blob/8f88e89ae4e7e5b8f3b41c23a05950487cbe1143/playground/notebook/example_pcr.jl#L110-L120 was necessary.

I hope #20 will be a better solution than the one I suggested 👍

terasakisatoshi avatar Mar 19 '22 04:03 terasakisatoshi

Just so we don't lose the comment from the Julia Slack, @ToucheSir mentioned:

I see that accum_param is the only general-purpose function call that pushes implicit gradients into the cache. However, the only references I could find to it (and dependents like unwrap) were direct calls in certain rules. Does this mean that any custom rule that wants to play well with parameter tracking needs to call accum_param? Or is there some code instrumentation going on that I'm missing?

rejuvyesh avatar Mar 21 '22 20:03 rejuvyesh

I fear the answer is no ☹️. I'm not sure if there's anything we can do in Zygote to track implicit gradients for nested parameter arrays when the entire module is fed into an rrule off the bat, but I will defer to the AD experts on this.

ToucheSir avatar Mar 22 '22 04:03 ToucheSir