PyCallChainRules.jl
How to update parameters using gradients
I am trying to put together an example of using FastAI.jl to finetune a pretrained ResNet from torchvision, and I am unsure how to use the output of Zygote.gradient on a TorchModuleWrapper to optimize its parameters.
So the question: what is the best way to update the parameters of a nested module? Have you tried this successfully and could share a minimal example?
I've tried applying Flux.Optimise.update! to the README example by iterating over the params and gradients manually, but I am not sure whether this will work for more deeply nested structures.
using Flux, Zygote
using PyCallChainRules.Torch: TorchModuleWrapper

# torch_module, loss, input and target are defined as in the README example
model = TorchModuleWrapper(torch_module)
optim = ADAM()
for i in 1:10
    grad, = Zygote.gradient(m -> loss(m, input, target), model)
    # update each flat parameter array in place with its matching gradient
    for (p, g) in zip(model.params, grad.params)
        Flux.Optimise.update!(optim, p, g)
    end
end
The above works, while Flux.Optimise.update!(optim, model, grad) does not. Maybe an overload for Flux.Optimise.update!(::AbstractOptimiser, ::TorchModuleWrapper, _) is needed?
This won't help FastAI, but I think Optimisers.jl might be able to handle this directly. You might be able to work around the lack of a Flux.Optimise.update!(::AbstractOptimiser, ::TorchModuleWrapper, _) overload by manually creating a Params and Grads from model.params and grad.params.
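For concreteness, here is a rough sketch of that workaround (untested; Grads is a Zygote internal rather than public API, and model, optim, loss, input and target are the names from the snippet above):
using Zygote: Params, Grads

grad, = Zygote.gradient(m -> loss(m, input, target), model)
ps = Params(collect(model.params))
# pair each flat parameter array with its gradient by object identity, as Grads expects
gs = Grads(IdDict{Any,Any}(p => g for (p, g) in zip(model.params, grad.params)), ps)
Flux.Optimise.update!(optim, ps, gs)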
If you look at https://github.com/rejuvyesh/PyCallChainRules.jl/blob/92cf084f6c7379d3fa9d5d399f1aa086bbcdc409/src/pytorch.jl#L29 together with
https://github.com/rejuvyesh/PyCallChainRules.jl/blob/92cf084f6c7379d3fa9d5d399f1aa086bbcdc409/src/pytorch.jl#L41-L43, I think it shouldn't be too hard to hook into the Optimiser interface.
Let me check what's required for Flux.Optimise to work. Probably Flux.trainable needs to be defined?
Edit:
jlmodel = TorchModuleWrapper(torch_module)
Flux.params(jlmodel)
already gives the correct result. So we should probably just use the gradient calls with implicit parameters, as in the current Flux examples.
functorch (which is being used behind the scenes) collapses the nested parameters into a "flat" tuple of parameters.
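For illustration, with the README's torch_module this looks roughly like the following (hypothetical REPL sketch; the exact shapes depend on the module):
jlmodel = TorchModuleWrapper(torch_module)
jlmodel.params        # flat tuple of arrays, e.g. (W1, b1, W2, b2, W3, b3)
Flux.params(jlmodel)  # Params wrapping those same flat arrays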
Edit: ignore everything below, params works as mentioned above!
Flux.trainable defaults to functor, so unless you want it to return something other than TorchModuleWrapper#params then no additional action should be required. The problem is that update! only has overloads for single arrays and Params on the Flux side: https://github.com/FluxML/Flux.jl/blob/v0.12.9/src/optimise/train.jl#L23-L34. If everything in the wrapper params can be mutated in-place, then I think this would be the overload Lorenz was asking about:
using Flux, ChainRulesCore
using Flux.Optimise: AbstractOptimiser
using PyCallChainRules.Torch: TorchModuleWrapper

function Flux.Optimise.update!(opt::AbstractOptimiser, model::TorchModuleWrapper, grads::Tangent{<:TorchModuleWrapper})
    # update each wrapped parameter array in place from the structural gradient
    for (p, g) in zip(model.params, grads.params)
        Flux.Optimise.update!(opt, p, g)
    end
end
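With such a method defined, the manual loop above would collapse to something like this (a sketch only: if Zygote surfaces the structural gradient as a NamedTuple rather than a Tangent, the signature above would need to be loosened accordingly):
grad, = Zygote.gradient(m -> loss(m, input, target), model)
Flux.Optimise.update!(optim, model, grad)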
The issue here is the interaction between ChainRules and Zygote's implicit parameters?
julia> gs = Flux.gradient(ps) do
           loss(input, target)
       end
Grads(...)
julia> gs.grads
IdDict{Any, Any} with 9 entries:
:(Main.target) => Float32[-0.00564587 0.238695 … 0.0521358 -0.0587006; 0.128218 0.054992 … -0.047878 -0.0400576]
:(Main.input) => Float32[-0.00423729 7.43475f-6 … -0.00136046 0.00378413; -0.00100987 0.0220082 … 0.00141332 -0.00220777; -0.00780671 -0.0137588 … 0.000454976 0.00193283; -0.00435127 -0.00824403 … -0.00…
Float32[-0.415697 0.274342 … -0.299125 0.418624; 0.462995 -0.144811 … -0.450912 -0.00113821;… => nothing
Float32[-0.219483, -0.424986, 0.194869, -0.148786, -0.185551, -0.296822, 0.302707, -0.086165… => nothing
Float32[0.183148 0.0902577; 0.0673819 0.241829; … ; -0.0106342 0.0899293; 0.180051 0.214607] => nothing
Float32[-0.166111, 0.091679] => nothing
:(Main.jlmodel) => (torch_stateless_module = nothing, dtype = nothing, params = (Float32[-0.000155696 0.00759759 … -0.00265382 0.0210725; -0.00102487 -0.00692717 … 0.00439671 0.00441634; 0.00025565 0.0101…
Float32[-0.137103 0.155661 … -0.0798025 -0.137254; 0.240076 0.218976 … -0.200846 0.0126162; … => nothing
Float32[0.060363, 0.143615, 0.107176, 0.224407, -0.144465, 0.21469, -0.0161946, -0.163965, -… => nothing
Note that the gradients seem to be reported for jlmodel itself (under its params field), not for the implicit parameters.
using Flux
using PyCallChainRules.Torch: TorchModuleWrapper, torch

input_dim = 4
output_dim = 2
hiddendim = 16
batchsize = 8

torch_module = torch.nn.Sequential(
    torch.nn.Linear(input_dim, hiddendim), torch.nn.ReLU(),
    torch.nn.Linear(hiddendim, hiddendim), torch.nn.ReLU(),
    torch.nn.Linear(hiddendim, output_dim)
)
jlmodel = TorchModuleWrapper(torch_module)

opt = Flux.ADAM(0.1)
input = randn(Float32, input_dim, batchsize)
target = randn(Float32, output_dim, batchsize)
loss(x, y) = Flux.Losses.mse(jlmodel(x), y)

ps = Flux.params(jlmodel)
@info "before" map(sum, ps)
for i in 1:1
    gs = Flux.gradient(ps) do
        loss(input, target)
    end
    @show gs.grads
    # the implicit parameter arrays map to `nothing` in gs.grads (see the output above),
    # so this update leaves ps unchanged
    Flux.Optimise.update!(opt, ps, gs)
end
@info "after" map(sum, ps)
As @ToucheSir mentioned, the explicit interface in Optimisers.jl should likely work correctly. We will probably need help from people with a better understanding of the Zygote/ChainRules interaction for this implicit-parameters issue.
See the examples in #20. The explicit interface from Optimisers.jl seems to work correctly.
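For reference, a minimal sketch of that explicit path using the script above (untested here; the rule name may be Adam or ADAM depending on the Optimisers.jl version, and it assumes Optimisers.jl can traverse the wrapper's params field):
using Optimisers, Zygote

state = Optimisers.setup(Optimisers.Adam(1f-3), jlmodel)
grad, = Zygote.gradient(m -> Flux.Losses.mse(m(input), target), jlmodel)
state, jlmodel = Optimisers.update(state, jlmodel, grad)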
I now see why @terasakisatoshi's https://github.com/terasakisatoshi/PCRP.jl/blob/8f88e89ae4e7e5b8f3b41c23a05950487cbe1143/playground/notebook/example_pcr.jl#L110-L120 was necessary.
I hope #20 will be a better solution than I suggested 👍
Just so we don't lose the comment from the Julia Slack, @ToucheSir mentioned:
I see that accum_param is the only general-purpose function call that pushes implicit gradients into the cache. However, the only references I could find to it (and dependents like unwrap) were direct calls in certain rules. Does this mean that any custom rule that wants to play well with parameter tracking needs to call accum_param? Or is there some code instrumentation going on that I'm missing?
I fear the answer is no :frowning_face:. I'm not sure if there's anything we can do in Zygote to track implicit gradients for nested parameter arrays when the entire module is fed into a rrule off the bat, but I will defer to the AD experts on this.