Flux.jl
Gradient Interface Design
We expose a pullback/vjp-based API for gradients (y, back = forward(f, x); x̄ = back(ȳ)), with x̄ = gradient(f, x) as simple syntactic sugar on top of this. This interface is pretty awesome – gradient aligns nicely with the mathematical and intuitive notions of a derivative operator, it naturally expresses nested derivatives, and you can build pretty much any other AD-related functionality (checkpointing, forward mode, gradient hooks, etc.) on top of pullbacks, without having to go into AD internals. So far I haven't come across anything that pullbacks can't do straightforwardly; in one case the PyTorch-style back! may be slightly more convenient, but it's overall more cumbersome and requires more knowledge of internals.
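As a rough sketch of how these pieces relate (Zygote currently calls the first function pullback, earlier versions called it forward; my_gradient is just an illustrative name):

using Zygote

f(x) = 3x^2 + 2x

# The forward pass returns the value and a pullback closure mapping
# output cotangents ȳ back to input cotangents x̄.
y, back = Zygote.pullback(f, 2.0)    # y == 16.0
x̄ = back(1.0)                        # (14.0,) since f′(x) = 6x + 2

# `gradient` is syntactic sugar on top of the pullback:
my_gradient(f, x) = Zygote.pullback(f, x)[2](1.0)
my_gradient(f, 2.0)                  # (14.0,)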
However, a challenge of the "mathematical" gradient operator is that it's cumbersome to pass in all our parameter arrays explicitly (gradient(resnet, W1, b1, W2, b2, ...)). So we need a way to take gradients of large models without that ceremony. ~~There are currently two ideas about how to do this: the structural approach and the implicit approach.~~
Edit: Since writing this I have convinced myself that we can get the convenience of implicit params by slightly generalising the structural approach. I think this gives us a clear path forward, though unfortunately it does mean additional API churn.
Structural Gradients
The first approach (which Zygote will support whatever happens, and could be added to Flux) is to take the gradients w.r.t. some structure containing all of the parameters. The structure could be a dictionary or list, but it's usually convenient to combine the weight structure with the definition of the forward pass. This is effectively a closure, which we refer to as a "layer". Layers can contain other layers (we often call a compound layer a "model", but there's no fundamental difference). Taking a gradient looks like this:
m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x, y = ...
m̄ = gradient(m -> loss(m(x), y), m)
This looks pretty weird at first but makes a lot of sense once it clicks. One then carries out the update step m .+= m̄.
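For concreteness, a rough sketch of what m̄ looks like (the data and loss here are made up; the field names inside the gradient follow the layer structs' fields, which differ across Flux versions):

using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x = rand(Float32, 10)
y = Float32[1, 0]
loss(ŷ, y) = sum(abs2, ŷ .- y)

# m̄ mirrors the structure of m: a NamedTuple with a `layers` field holding
# one gradient (or `nothing`) per layer, each itself a NamedTuple of fields.
m̄, = gradient(m -> loss(m(x), y), m)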
Implicit Gradients
The implicit approach is what Flux supports natively, though it works in Zygote as well. In this case we ask for gradients of a shapeless set of parameters, which are implicitly used at some point during the forward pass. In this case we have something more like:
m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x, y = ...
θ = params(m)
θ̄ = gradient(() -> loss(m(x), y), θ)
θ̄ is a dictionary from param to gradient. One then loops over the parameters, doing p .+= θ̄[p].
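A sketch of that loop, written as plain gradient descent with a (hypothetical) step size η; the p .+= θ̄[p] above is the same idea with the step folded into the gradient:

η = 0.01
for p in θ
    p .-= η .* θ̄[p]   # update each parameter array in place
end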
Do we need implicit parameters?
Implicit parameters have some downsides. They feel somewhat less clean and functional than structural ones. They do not support scalars or immutable arrays well, which are needed for more restrictive backends like TPUs; and supporting both approaches means having more than one way to do things.
However, implicit parameters have a huge advantage: they make it easy to write "script-like" models. I see them as being a lot like global variables: sure they're a bit unclean, but sometimes it's just convenient to build up a model gradually in a notebook, without a lot of structure (and if I have one non-negotiable rule of API design, it's that you should never have to define a struct/class or satisfy an interface to use a library). Our VAE model is a nice example of this style which I think would be made significantly more cumbersome otherwise.
A potential solution is to make it easier to define "anonymous layers". In my ideal world these would also just be closures, but ~~unfortunately this isn't workable~~ (see discussion below) – closures don't store parameters captured from global scope as fields, which makes those parameters invisible to structural AD. Functions that return closures would be completely fine, but the distinction is too subtle / too tied to implementation details.
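To illustrate the distinction (a sketch; make_layer is just an ad-hoc example, not an API):

W = rand(2, 3)
f = x -> W * x               # refers to the *global* W; the closure stores no
                             # fields, so structural AD sees no parameters in f

make_layer(W) = x -> W * x   # a closure over a *local* W is stored as a field,
g = make_layer(rand(2, 3))   # so gradient(g -> sum(g(x)), g) yields (W = ...,)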
Other concerns
A couple of other small subtleties.
In the implicit style, parameter identity matters, which means we can reuse parameters when creating a model. For example:
d = Dense(10, 10, relu)
m = Chain(d, d, softmax)
dm = gradient(...)
In the implicit parameter style the weights of d are shared, but in the structural version we get two separate gradients for d, one for each position in the chain, and we'd have to construct the chain inside the forward pass to get the gradient we want. Similar issues come up in nested AD. I don't think either behaviour is more correct or better – both are well-defined and predictable – but they are different.
[This is further complicated by the fact that in-place updates mean the weight is effectively shared even in the structural case, just with weird semantics in optimiser state.]
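A sketch of the difference, using sum(m(x)) as a stand-in loss:

using Flux

d = Dense(10, 10, relu)
m = Chain(d, d, softmax)
x = rand(Float32, 10)

# Implicit: params(m) collects d's weights once (by identity), so both uses
# of d accumulate into a single shared gradient.
θ̄ = gradient(() -> sum(m(x)), params(m))

# Structural: m̄.layers[1] and m̄.layers[2] are separate gradients, one per
# position in the chain; adding them recovers the tied-weight gradient.
m̄, = gradient(m -> sum(m(x)), m)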
This has some relevance to RNNs, which currently weight-tie the initial and current state fields. I don't think this needs to result in user-facing changes, though. It's also possible to make a nice pure-functional RNN interface (see e.g. JAX); you just can't abstract over RNNs quite like we currently do, and state needs to be a bit more explicitly managed (which isn't necessarily a deal breaker, but worth considering).
[To be clear though, while the structural approach is more functional, it does not force differentiated programs to be functional themselves, so something very like our current RNN design is still possible.]
With JAX/stax, there's no straightforward way to tie weights right now short of writing your own init_fn and apply_fn (and, separately, no one that I know of has tried building RNNs so we don't really know how nice the interface will feel for that). Having an intuitive way to do this is pretty important, and to me (and given Julia's pervasive reference semantics) your example in "Other concerns" should definitely have the effect of tying the weights of d.
Ok, I think I finally have this figured out. The trick is that you take the gradient of the module as if it were any other struct. Then global variables – fields of the module – can be treated just like any other parameter / struct field. This means that we can go full on with closures being layers, and vice versa.
This doesn't directly address James' concern, except in that the "custom layer" is now a one-line lambda and everything just works. Might be worth having some challenge cases, but I think it will look pretty good.
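For example, taking the gradient directly with respect to a closure already reflects its captured arrays structurally (a sketch; dense here is an ad-hoc helper, not a Flux API):

using Zygote

dense(W, b) = x -> W * x .+ b          # a "layer" that is literally a closure
layer = dense(rand(2, 3), rand(2))
x = rand(3)

l̄, = gradient(l -> sum(l(x)), layer)   # l̄.W and l̄.b mirror the captured arrays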
Would this allow for model/layer inheritance?
In what sense? This does make it ridiculously easy to extend layer behaviour, e.g.
l1 = Dense(10, 5)
l2 = x -> relu.(l1(x))
and l2 can now be used anywhere l1 could have been. But I'm not sure if you mean something else by "inheritance".
That does look really nice.
I'm thinking something more along the lines of building a library like this: https://github.com/rusty1s/pytorch_geometric
The code example shows an edge convolution layer inheriting from a message passing layer. This could be done by composition, but sometimes inheritance (or traits) works better. Another example would be having a user specialize an "abstract transformer" type.
In the general sense it would be differentiating (differentiable) programs written with full use of the type system.
Yeah, we are of course limited by what Julia can express here, so we still can't do Python-style inheritance exactly. But it'd be easy to do this via composition (which can mean a library layer wrapping a custom one, as well as the other way around), e.g.
mlp = Chain(Dense(...), ...)
model = MessagePassing() do x
    mlp(x) |> softmax
end
(I'm just making something up here, because I don't actually know what MessagePassing needs – you'd probably need to pass two functions in – but hopefully that gives a flavour).
That makes sense. What about making MessagePassing an abstract type? I'm not sure if that would make sense in this instance, but say, more generally, you have a function transform1 that calls transform2, both of which take abstractlayer1. Then a user subtypes abstractlayer1 as layer1, includes some learnable parameters, and maybe overloads one or more of the functions in the lattice.
I think the important point from an AD perspective is that if you can write it down, it will work :) Feel free to open a new issue to discuss specifics of layer design. I'm not personally a big fan of using abstract types to simulate inheritance, but I'm happy to discuss that / how it interacts with AD anyway.
Question about this proposal in terms of functional AD: how does x̄ = gradient(f, x) know what x is unless it's defined previously? Compare to JAX, which has a very nice df = jax.grad(f) that returns a function computing the gradient of f (by default w.r.t. the first argument). This can be changed to other arguments, e.g. dfdy = jax.grad(f, argnums=1).
This is nice because I can define a function, and then give it to jax.grad to get the gradient function, and do this recursively for higher order gradients.
In the proposal how would this functional gradient look?
df = x->gradient(f,x)
It's not obvious to me that this is doing the same thing. For instance, would this compile to a function that computes the gradient of f wrt x in the same sense that the jax version does?
Yeah, both ways of writing this are equivalent (and identical in terms of performance etc). You can implement grad as
grad(f) = x -> gradient(f, x)[1]
And implement gradient back on top of this as
gradient(f, x) = grad(f)(x)
(modulo tweaks for argnum, multiple arguments etc.)
I don't like argnum because you can do this easily with a lambda, e.g. rather than grad(*, argnum=2)(x, y) you can do gradient(y -> x*y, y).
I also prefer not to provide curried versions of functions where possible, but that's really more of a style thing. If we can come up with a good name for it I wouldn't oppose having a grad equivalent in Zygote.
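On the higher-order question: nesting works with either spelling, e.g. (a quick sketch):

using Zygote

grad(f) = x -> gradient(f, x)[1]

f(x) = x^3
df  = grad(f)     # x -> 3x^2
d2f = grad(df)    # x -> 6x
d2f(2.0)          # 12.0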
I've been playing with this interface recently for my own stuff, using
gradient(m -> m(input), m)
where m is my model defined with Chain, Dense etc., and gradient gives me a NamedTuple as proposed above. However, in my case I have some post-processing of the gradients (a sort of policy gradient), and it seems a bit inconvenient when the gradients are stored as a structure. But I do believe the explicit structural gradient is pretty natural, given that the gradient is the adjoint of the model parameters in some sense.
So I'm wondering if we could make Zygote return a Grad{T} type, where T is the original type, instead of returning a NamedTuple. Grad{T} could be used to dispatch on the methods that T has, while still being something that can be added to the original. Then we could easily use multiple dispatch to traverse the nodes in the model, e.g.
foo(::Grad{<:Dense}) = ... # blabla
foo(x::Grad{<:Chain}) = foreach(foo, x.layers)
Does anyone know if this hit a dead end or was ruled out as too complex to implement? I think the "closures are layers" aspect would help immensely for https://github.com/FluxML/ML-Coordination-Tracker/issues/10, while the "structural approach" would make loadparams! and co. a lot nicer to use.
Not a dead end; it's my preference to use this interface more and make it stable. We have https://github.com/FluxML/XLA.jl/pull/5/files#diff-2da3a01fb49af8d3ca12681d630be0e89f22536d6cb8322f6f6d239699bfd28f and FluxML/Optimisers.jl#3 which just need the rules to be ported over now. We can move ahead with this fairly swiftly.
@MikeInnes Thanks for all your work!
In JAX one can compute gradients with respect to nested dictionaries, a simple example is in the README here:
https://github.com/anhinga/jax-pytree-example
I wonder how difficult would it be to do something similar in Zygote.jl
Have you tried doing this? All the pieces appear to be in place, and if you run into issues please file them at Zygote. That said, since your JAX example only uses string keys, it would be far more efficient to use namedtuples for the same purpose in Julia.
Yeah, it works well already, e.g.
julia> gradient(x -> x["foo"]["bar"]^2, Dict("foo" => Dict("bar" => 5)))
(Dict("foo" => Dict("bar" => 10)),)
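For comparison, the NamedTuple version suggested above should work the same way (a sketch, not run here):

gradient(x -> x.foo.bar^2, (foo = (bar = 5,),))   # expect ((foo = (bar = 10,),),)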
@ToucheSir @MikeInnes Thanks!
(My mistake was trying to coerce a dictionary into Params.)