Zygote.jl
Implement `frule_via_ad`?
At present, Zygote will use forward mode AD (outsourced to ForwardDiff) in two circumstances:
- Based on a heuristic for broadcasting
- Upon an explicit call to Zygote.forwarddiff
As shown by https://github.com/oschulz/ForwardDiffPullbacks.jl, there are a number of cases where being able to make an rrule actually run forward mode AD would be a boon for performance. One particularly salient example from Flux would be RNN pointwise broadcasts, which are currently unfused by Zygote for a massive compute + memory penalty. However, given we are simultaneously moving away from using Zygote-specific APIs downstream, defining rrule(::typeof(pointwise_op), xs...) = Zygote.forwarddiff(...) is a non-starter. Hence, my proposal is to expose the standard frule_via_ad so that downstream code can remain AD agnostic. Under the hood, this would work much the same as Zygote.forwarddiff or ForwardDiffPullbacks.fwddiff do now. It may even be possible to share some implementation details with one of those functions.
Note that this is not a request to make frule_via_ad differentiable in reverse mode. Users would still be responsible for writing their own rrules, but one could imagine swapping out Zygote for Diffractor (which already implements frule_via_ad) without making any code changes. Guarding on RuleConfig{>:HasForwardsMode} would be enough to ensure compatibility with ADs which do not support forward mode.
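To make the proposal concrete, here is a rough sketch of how a downstream library might use it, assuming the ChainRulesCore signature frule_via_ad(config, tangents, f, args...). pointwise_op and its pullback are made-up names, and only the scalar case is handled; this is illustrative, not a spec:

```julia
using ChainRulesCore

# Hypothetical elementwise op whose reverse rule defers to forward mode.
pointwise_op(x) = x * tanh(x)

# Only picked up by ADs that advertise forward-mode support.
function ChainRulesCore.rrule(config::RuleConfig{>:HasForwardsMode},
                              ::typeof(pointwise_op), x::Real)
    # Push a unit tangent through the AD's own forward mode; for a
    # scalar -> scalar function the pushforward is exactly the derivative.
    y, ẏ = frule_via_ad(config, (NoTangent(), one(x)), pointwise_op, x)
    pointwise_op_pullback(Δ) = (NoTangent(), ẏ * Δ)
    return y, pointwise_op_pullback
end
```

Whichever AD supplies the config (Zygote, Diffractor, ...) does the forward-mode work, so the rule itself stays AD agnostic.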
Zygote could surely define frule_via_ad(::Config, xdot, f, x::Real...) using Dual numbers. Maybe a little wider but arbitrary structs would be tricky. How this would interact with 2nd derivatives is not so clear to me.
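For reals only, a minimal sketch of what that Dual-number version might look like (the name is made up, the config argument is omitted, and tangent/tag handling is glossed over):

```julia
using ForwardDiff: Dual, value, partials

# Sketch: push one tangent through f by seeding each real input with a
# single-partial Dual. The first tangent (for f itself) is ignored.
function frule_via_ad_sketch(ẋs, f, xs::Real...)
    duals = map((x, ẋ) -> Dual(x, ẋ), xs, Base.tail(ẋs))
    yd = f(duals...)
    return value(yd), partials(yd, 1)
end

frule_via_ad_sketch((nothing, 1.0, 0.0), *, 3.0, 4.0)  # (12.0, 4.0)
```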
> being able to make an rrule actually run forward mode AD would be a boon for performance. ... memory penalty
I think the memory savings from ForwardDiffPullbacks.jl come from being willing to run the function twice.
Zygote should never do this, right now (unless you call Zygote.forwarddiff). Instead, its dual number broadcasting saves the partials on the forward pass, and its reverse-mode broadcasting saves the pullbacks, which close over similar data. If you need to save this data, then I don't think a fused forward mode broadcast necessarily saves memory compared to an unfused one, as the intermediate values which the un-fusing materialises would otherwise be captured in the closures / the array of partials.
Whether running twice is safe, and is more efficient than saving, seem hard to infer in general. Calling fwddiff(f).(args...) is one way to explicitly annotate that it should happen for a particular case. It's not particularly easy to apply to things like f.(g.(x .+ y)./2, h.(z)). Surely ForwardDiffPullbacks.jl could be upgraded to allow something like @fwddiff f.(g.(x .+ y)./2, h.(z)). And if Zygote had frule_via_ad, then ForwardDiffPullbacks.jl (or a similar package) could implement this via ChainRulesCore.jl not explicitly via ForwardDiff.jl. I think that's your motivation here?
Zygote's dual number broadcasting is occasionally surprising, and if we had some such approach, then possibly reverting the default to Zygote's own reverse mode would be worth considering.
Re fusion, the approach of https://github.com/JuliaDiff/Diffractor.jl/pull/68 is to fuse the result of a few simple operations like +,-,*,/ with the next function, since running these twice is cheap & safe. In many examples, this avoids a lot of the waste of un-fused broadcast.
> Maybe a little wider but arbitrary structs would be tricky.
AFAICT ForwardDiffPullbacks supports this, but I think bailing on them as not implemented could also be an option.
> I think the memory savings from ForwardDiffPullbacks.jl ... I don't think a fused forward mode broadcast necessarily saves memory compared to an unfused one, as the intermediate values which the un-fusing materialises would otherwise be captured in the closures / the array of partials.
Do the pullbacks not close over strictly more data for a fused (than unfused) broadcast with more than a couple of functions? Even if not, Zygote is already paying this price because each unfused broadcast I care about is hitting the dual number path. That's to say nothing of other fixed overheads such as GPU kernel launches.
> Whether running twice is safe, and is more efficient than saving, seem hard to infer in general.
I don't think running twice needs to be a strict requirement. Running once is fine as well, the point is to have a way of specifying "Use forward mode for this, I know what I'm doing" that doesn't require tight coupling to any particular AD. Inferring intent is thus out of scope, so it would be on the user to define combined(x, y, z) = f(g(x + y)/2, h(z)) and call combined.(x, y, z) instead of writing the fused broadcast.
Now, combined.(...) will hit the optimized + GPU friendly dual number path in Zygote, but something like https://github.com/FluxML/Flux.jl/blob/master/src/layers/recurrent.jl#L141 would not even if pulled out into a separate function, because it takes a non-singleton type (the activation function) as an argument. Unless Zygote's heuristics can be tuned to allow this, I think an explicit opt-in mechanism is the only way to go.
> Surely ForwardDiffPullbacks.jl could be upgraded to allow something like @fwddiff f.(g.(x .+ y)./2, h.(z)). And if Zygote had frule_via_ad, then ForwardDiffPullbacks.jl (or a similar package) could implement this via ChainRulesCore.jl not explicitly via ForwardDiff.jl.
Since my motivation for this is to use it in Flux, we'd have to be okay with tying ourselves to a library like ForwardDiffPullbacks.jl. That may be fine, but with the way things are going wrt AD agnosticism I just assumed it was a non-starter.
Diffractor.jl#68 is a good general-purpose solution. However, two advantages we have on the library dev side are a) deeper knowledge of which operations will or will not be more efficient (un)fused, and b) more tolerance for adding not-so-pretty code to manually fuse bits that make sense. For RNNs and loss functions in particular, the memory savings (both quantity and allocation overhead) and avoidance of multiple kernel launches far outweigh any benefits of being able to reuse intermediates from the forward pass on GPU.
> Do the pullbacks not close over strictly more data for a fused (than unfused) broadcast with more than a couple of functions?
I think they will just close over what they actually use. Which can be more (if they close over data which you have separately in the source and results array) or less (if your un-fused broadcast saves intermediates which aren't in fact required).
Maybe I'm a bit unclear what's being compared though. The only saving of pullbacks right now (in Zygote, or Diffractor#68) is for just one (un-fused) broadcast. Sometimes this does take less memory than Zygote's Dual number story for the same broadcast --- the Partials stored contain one sensitivity per input array. I did write at some point a version which saved the pullbacks for the whole fused function, but it was super-slow.
Ah right, I think the whole saved memory in pullbacks part was a distraction from the main goal. A motivating example:
```julia
@. f(g(h(x)))
# becomes
%1 = h.(x) # allocates O(length(x)), 1 kernel launch on GPU
%2 = g.(%1) # allocates + kernel launch
%3 = f.(%2) # allocates + kernel launch
# in the generated primal, whereas:
fgh(x) = f(g(h(x)))
fgh.(x) # allocates O(length(x)), 1 kernel launch
```
I've omitted the generated pullback since it's more or less the same story there.
The problem is that, currently:
```julia
higher_order(f, x, y) = f(x, y)
higher_order.(+, x, y) # falls back to slow, non-GPU compatible path
```
Because functions are not singleton types. I'm singling out functions here since they're top of mind for implementing RNN cells, but it's possible to think of cases for other non-singleton types as well. Zygote has to be conservative because it doesn't have enough information on whether this broadcasting is AD-safe, but we do and thus having a "I know what I'm doing" switch would be very desirable.
Edit:
> The caller may well know this.
This is exactly what I'm driving at. Even if the caller does know this, they can't express it since the AD doesn't provide a mechanism for doing so. Well, it kind of does, but only if you couple yourself tightly to its API. To my knowledge, frule_via_ad is the only remotely cross-library interface for this.
> I don't think running twice needs to be a strict requirement. Running once is fine as well, the point is to have a way of specifying "Use forward mode for this
This still seems slightly confusing to me. Trying to separate the pieces:
1. There's a correctness question: "May I run this unknown function twice?" Zygote assumes not, fwddiff assumes yes. Absent a sufficiently smart compiler, some annotation seems necessary to allow this.
2. There's a function cost: "Is the forward pass much more expensive than the reverse?" On the CPU tanh is like that: the gradient is trivial provided you have the result, but getting the result is quite slow. The caller may well know this.
3. There's a memory cost. The less you save, the better, obviously. Maybe sometimes you can avoid saving things by being smarter (e.g. proving that you don't need the output) but this is tricky. If you avoid saving things by running functions again later, then you must have 1, and you may need to trade this off against 2.
4. There's a how-many-launches cost. Here fusion obviously helps, whether it's a fused save-everything (always allowed), or a fused run-it-all-again calculation (depends on 1 and 2). At present, generic fused reverse mode is going to be really slow.
If I understand right, ForwardDiffPullbacks.jl isn't quite as fused as it could be for 4, as it works out the gradient for each function argument as a separate pass. fwddiff(f∘g).(x) is one gradient call, but fwddiff(f).(x,y,z) is three, I think?
I see that ForwardDiffPullbacks.jl does claim to support all static structs, sorry somehow I thought it was just numbers, tuples, etc. (Haven't looked but it sounds like that must make reconstruction assumptions.)
[Edit: one more:]
- Some broadcasts have arrays of different sizes. Any fused treatment of f.(mat .+ g.(vec .+ 1)./2) will evaluate g and its gradient N^2 times (and allocate accordingly), while an un-fused treatment will evaluate it only N times, which is great if g is expensive. Possibly you can read this as saying that broadcasting makes no guarantee about the number of executions, as in 1.
If it helps, we can act as if ForwardDiffPullbacks was never mentioned in this discussion. I brought it up solely as an example of why it might be nice to have a way to do rrule(::typeof(f), args...) = # run forward AD that a) doesn't require direct reference to Zygote or ForwardDiff APIs, b) isn't beholden to Zygote's current set of necessarily conservative heuristics.
> Because functions are not singleton types.
I think that right now, Zygote's GPU broadcast unconditionally uses the stored-Dual path. It will silently ignore things which aren't real numbers, which is #1215:
https://github.com/FluxML/Zygote.jl/blob/a133200422e4f12e1d7266f5825154054faf0d9a/src/lib/broadcast.jl#L270-L271
It's the generic case which has half a dozen checks:
https://github.com/FluxML/Zygote.jl/blob/a133200422e4f12e1d7266f5825154054faf0d9a/src/lib/broadcast.jl#L187-L188
That said, for the RNN example σ.(Wi*x .+ Wh*h .+ b) we need not broadcast a higher-order function. We could instead close over the activation function:
```julia
bcplus(Wx, Wh, b) = tanh(Wx + Wh + b)
Base.issingletontype(typeof(bcplus)) # true
bcplus.(Wi*x, Wh*h, b)
```
Even on the CPU, this should hit the broadcast_forward path, which will save an array of Duals with 3 partials each. Whether that's more efficient than the present state (where there's a custom rule for + and for tanh) I don't know, but I think it ought to save some launches.
Here, Diffractor#68 should be almost optimal, as it will fuse the +. But it won't fuse f(g(h(x))). Are there great examples of that from the wild which might guide us?
> Even if the caller does know this, they can't express it since the AD doesn't provide a mechanism for doing so.
Yes. I'm with you that there ought to be more ways to control this. It's just a bit tricky to work out precisely which bits of information it's most useful to provide, and how they ought to be provided. Writing your own function like fgh or bcplus is one way which requires no API (but does require arcane knowledge!). We could provide rules for (f∘g∘h).(x) if that would often be a useful (arcane?) API. fwddiff is another but (as you say) ties you to ForwardDiff.
Does bcplus actually close over the activation function? I'm pretty sure a closure/custom type for a partially applied function would not be a singleton type, and since the activation function is configurable we can't just inline it (we could with a generated function, but let's not go there!)
The short-term solution for RNNs and loss functions would be to define a fused function only for GPU arrays. That's better than nothing, but ideally we could avoid more GPU-specific methods in Flux itself. I suppose this is as good of an argument as any for moving more bits into NNlib ASAP...
Maybe I'm not sure what you mean by close over here? The typeof(tanh) specifies the function, and this is known to the compiler from typeof(bcplus). I guess (tanh∘+) is the more literal one, and this too has a singleton type. (Fix1 is more complicated, IIRC.)
Interesting, for whatever reason I was not able to cook up an example that was a singleton type before, but testing again both ComposedFunction and Fix1 are transitively singleton types if their arguments are.
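For the record, roughly what that check looks like (the first two lines are the cases described above and reported as true on a recent Julia; the third is my addition and is false because the captured value is not itself a singleton):

```julia
Base.issingletontype(typeof(tanh ∘ +))                   # true: both parts are singletons
Base.issingletontype(typeof(Base.Fix1(broadcast, tanh))) # true: the captured argument is a singleton
Base.issingletontype(typeof(Base.Fix1(*, 2.0)))          # false: captures a Float64 value
```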
The other thing I forgot, more closely relevant to the title: frule((dx, dy), f, x, y) and frule_via_ad don't handle chunking. On each call of f, you can push forward only one dx, dy. So for f.(x, y, z) you would need 3 evaluations to get all the sensitivities. Whereas the present Dual number broadcasting uses just one.
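To spell out the chunking point with a toy scalar example (my illustration, not Zygote's code): one Dual evaluation with several partials recovers every sensitivity at once, whereas an frule-style pushforward takes a single tangent direction per call.

```julia
using ForwardDiff: Dual, value, partials

f(x, y, z) = x * y + z

# One evaluation with 3-partial Duals gives all three sensitivities:
d = f(Dual(2.0, 1.0, 0.0, 0.0),
      Dual(3.0, 0.0, 1.0, 0.0),
      Dual(4.0, 0.0, 0.0, 1.0))
value(d)                                        # 10.0
partials(d, 1), partials(d, 2), partials(d, 3)  # (3.0, 2.0, 1.0)

# frule((ḟ, dx, dy, dz), f, x, y, z) pushes one direction per call,
# so getting the same three numbers that way takes three evaluations of f.
```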
Learned about https://github.com/JuliaNonconvex/NonconvexUtils.jl/pull/6 today. Does it fit into this picture at all?