Zygote.jl
Utilize ChainRulesCore thunks
Note: Requires FluxML/ZygoteRules.jl#17
Currently, Zygote always unthunks ChainRulesCore thunks, which is wasteful and may also lead to trouble in cases where a thunk just can't be run for the given types/contents.
With
using ChainRulesCore, Zygote

foo(a, b) = a * b

function ChainRulesCore.rrule(::typeof(foo), a, b)
    y = foo(a, b)
    function foo_pullback(Ȳ)
        ∂a = @thunk (@info "Thunk ∂a"; return Ȳ * b')
        ∂b = @thunk (@info "Thunk ∂b"; return a' * Ȳ)
        return (NO_FIELDS, ∂a, ∂b)
    end
    return y, foo_pullback
end

a = rand(4,3); b = rand(3,2); Ȳ = rand(4,2);

Zygote.pullback(foo, a, b)[2](Ȳ)
let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end
we obviously get
julia> Zygote.pullback(foo, a, b)[2](Ȳ)
[ Info: Thunk ∂a
[ Info: Thunk ∂b
([...], [...])
but we currently also get
julia> let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end
[ Info: Thunk ∂a
[ Info: Thunk ∂b
([...],)
so Zygote executes both thunks even though we only require the pullback with respect to b.
With this PR, we should get
julia> let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end
[ Info: Thunk ∂b
([0.5793294513282393 0.5442572290286626; 0.748469566554706 0.8319235475298901; 0.609222600261931 0.6201995902507516],)
instead.
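The intended behaviour can be illustrated without Zygote. The following is only a sketch in plain Julia: LazyTangent and force are hypothetical stand-ins for a ChainRulesCore thunk and unthunk (they are not the real types), and evaluated is just a log added for this demo.

```julia
# Hypothetical stand-in for a ChainRulesCore thunk (assumption, Base Julia only).
struct LazyTangent{F}
    f::F
end

# Plays the role of `unthunk`: run the deferred computation.
force(t::LazyTangent) = t.f()
force(x) = x  # ordinary values pass through unchanged

# Records which deferred computations were actually run.
const evaluated = String[]

# Pullback of foo(a, b) = a * b that defers both input tangents.
function foo_pullback_sketch(a, b, Ȳ)
    ∂a = LazyTangent(() -> (push!(evaluated, "∂a"); Ȳ * b'))
    ∂b = LazyTangent(() -> (push!(evaluated, "∂b"); a' * Ȳ))
    return (∂a, ∂b)
end

a = ones(4, 3); b = ones(3, 2); Ȳ = ones(4, 2)
∂a, ∂b = foo_pullback_sketch(a, b, Ȳ)

# Only the tangent with respect to b is requested, so only ∂b is forced.
grad_b = force(∂b)
```

In this toy model the product for ∂a is never computed, which is the saving the PR aims for when a caller differentiates with respect to b only.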
CC @oxinabox , @mzgubic
I think it may actually be the unthunk in tailmemaybe(x::Tuple) that results in that error, not the one in Grads(IdDict(...)).
It looks like for this to work, we'll need to overload a lot of functions like x', Adjoint, transpose, Transpose, diag, Diagonal, permutedims, permutedims!, Ref, Vector, Matrix, Array, collect, convert, and so on, to unthunk thunks, since thunks can now appear as Ȳ in pullback functions.
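As a rough illustration of what such overloads could look like, here is a sketch in plain Julia. LazyTangent and force are hypothetical stand-ins for a thunk type and unthunk; the real methods would be defined on the ChainRulesCore types and may well look different.

```julia
# Hypothetical stand-in for a thunk type (assumption, not the ChainRulesCore API).
struct LazyTangent{F}
    f::F
end
force(t::LazyTangent) = t.f()

# Forward a few array operations by forcing the deferred value first.
Base.adjoint(t::LazyTangent) = adjoint(force(t))
Base.transpose(t::LazyTangent) = transpose(force(t))
Base.collect(t::LazyTangent) = collect(force(t))

Ȳ = LazyTangent(() -> [1.0 2.0; 3.0 4.0])
Ȳ_adj = Ȳ'            # dispatches to the adjoint method above
Ȳ_arr = collect(Ȳ)    # yields a plain Matrix
```

With methods like these, a pullback written generically (for example one that only calls Ȳ') keeps working when it is handed a lazy value instead of an array.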
Ok, with FluxML/ZygoteRules.jl#17 this passes the Zygote test suite on my system now (incl. CUDA tests).
In both the simple example above and more complex cases that I've tested, only the required thunks are unthunked. Hard to give guarantees on this, obviously.
The @adjoint_keepthunks in FluxML/ZygoteRules.jl#17 provides a way for adjoint authors to declare whether their pullback code is prepared to handle thunks or not.
~~Maybe we could add a similar mechanism to ChainRulesCore (an rrule_keepthunks or so) to provide a soft transition to (hopefully) more and more rules that support thunks in the ecosystem?~~ (see below)
Maybe we could add a similar mechanism to ChainRulesCore (an rrule_keepthunks or so) to provide a soft transition to (hopefully) more and more rules that support thunks in the ecosystem?
I had a look, the number of packages that actually define rrules seems finite (Bijectors, ChainRules, ChainRulesCore, ComponentArrays, CRlibm, DiffEqBase, DistributionsAD, LoopVectorization, NNlib, SpecialFunctions, SymbolicUtils, WebSockets and Zygote in my fairly well-filled package dir). So maybe the better approach would be:
- Do unthunk in Zygote.wrap_chainrules_input for now (already in this PR). Release a patch version of ZygoteRules with FluxML/ZygoteRules.jl#17 and a patch version of Zygote with this PR (FluxML/Zygote.jl#966); these changes should be non-breaking across the ecosystem.
- Add thunk support to functions like x', Adjoint, transpose, Transpose, diag, Diagonal, permutedims, permutedims!, Ref, Vector, Matrix, Array, collect, convert (probably a few more) in ChainRules (maybe some of them in ChainRulesCore?).
- Release ChainRulesCore v0.10, alert users that from now on we're going to take thunks seriously. :-) (Meaning, they may appear in the input to pullback functions.) With the additional thunk-supporting methods, we probably wouldn't need to change too much code in rrule-using packages like those listed above.
- Release a Zygote v0.7 or so, requiring ChainRulesCore v0.10 and removing unthunking in Zygote.wrap_chainrules_input.
This way, we can have a "soft" path to using thunks more and more. We'll have an immediate benefit since this PR can be done as patch releases of ZygoteRules and Zygote, without ecosystem propagation delay. And after a (hopefully not too long) while, rrules will be required to be thunk-aware, so no need to introduce an rrule_keepthunks. And packages with existing ZygoteRules.@adjoints can just replace them with ZygoteRules.@adjoint_keepthunks in their code if the pullback is thunk-compatible (many will be after more basic functions support thunks), bit by bit, so thunks can propagate deeper with minimal effort.
Update: see below for changed proposal
and after a (hopefully not too long) while, rrules will be required to be thunk-aware, so no need to introduce an rrule_keepthunks
I think this idea is closer to the right way. We want to overload all linear operators in Base/stdlibs to work on Zero and Thunk (and all other abstract differentials) anyway. We can probably enforce thunk support via automatically testing it in ChainRulesTestUtils.
How much of a problem is it if linear operators on a Thunk unthunk them?
I guess we can decide on a per-operator basis. I guess maybe we can have Thunk-ness propagate until + is called (which means a gradient is being accumulated) or something like that, though I worry about the overhead that carrying that computational graph around adds (but maybe it isn't so much).
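A minimal sketch of that idea, assuming a hypothetical LazyTangent wrapper in place of the real Thunk type: a linear operator returns a new lazy value, and only accumulation with + forces the chain. The method set is deliberately tiny and only meant to show the control flow.

```julia
# Hypothetical stand-in for a thunk type (assumption, Base Julia only).
struct LazyTangent{F}
    f::F
end
force(t::LazyTangent) = t.f()
force(x) = x

# A linear operator stays lazy by wrapping the input in a new deferred computation.
Base.adjoint(t::LazyTangent) = LazyTangent(() -> adjoint(force(t)))

# Accumulation is the point where the value is finally needed.
Base.:+(a::LazyTangent, b::LazyTangent) = force(a) + force(b)
Base.:+(a::LazyTangent, b::AbstractArray) = force(a) + b
Base.:+(a::AbstractArray, b::LazyTangent) = a + force(b)

runs = Ref(0)  # counts how often the underlying computation is executed
t = LazyTangent(() -> (runs[] += 1; [1.0 2.0; 3.0 4.0]))

lazy_adj = t'                   # nothing has been computed yet
runs_before = runs[]
acc = zeros(2, 2) + lazy_adj    # accumulation forces the whole chain once
```

Each lazy operator allocates another closure that captures its input, which is the bookkeeping overhead the comment above is worried about.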
What if we made @adjoint always unthunk, and didn't introduce a @adjoint_keepthunks, but instead ensured that all existing rrules supported thunks, and moved the rules you have marked here with @adjoint_keepthunks into ChainRules? And just told people wanting to keep thunks that they should write rrules instead of using @adjoint?
I guess it would be too long to wait for all of the rrules to be updated to support thunks?
cc @mzgubic @willtebbutt @nickrobinson251
How much of a problem is it if linear operators on a Thunk unthunk them?
I guess most of them will have to?
What if we made @adjoint always unthunk, and didn't introduce a @adjoint_keepthunks , but instead ensured that all existing rrules supported thunks. [...]
How about this:
- Step 1: We remove @adjoint_keepthunks from FluxML/ZygoteRules.jl#17 and add it to Zygote as an internal Zygote.@_adjoint_keepthunks for now. That way, it's easy to change all the low-level adjoints in Zygote (we need to add keepthunks to a lot of adjoints in Zygote itself to have any benefit from this, as done in this PR) and easy to change back in cases where we might have been too zealous. We can then merge and release FluxML/ZygoteRules.jl#17 and FluxML/Zygote.jl#966 as non-breaking and without introducing new APIs.
- Step 2: We release a breaking new version of ChainRulesCore that mandates that rrules are prepared to handle thunks, and update the rules in ChainRules accordingly.
- Step 3: We add all Zygote-internal @_adjoint_keepthunks rules (and maybe some more @_adjoint rules as well) as rrules to ChainRules. We test performance implications (since rrules undergo some wrapping/unwrapping and a lot of these rules are low level and will be called frequently) while we remove those Zygote-internal adjoints until none are left. Then we can get rid of Zygote.@_adjoint_keepthunks again. We encourage the ecosystem to do the same (use ChainRulesCore instead of ZygoteRules). May need to fix things like FluxML/Zygote.jl#811 first.
I guess it would be too long to wait for all of the rrules to be updated to support thunks?
I suspect switching all necessary Zygote-internal adjoints to rrules and having the ecosystem make all rrules thunk-compatible (step 3 above) will take some time. But if we do it like above, we'd get an immediate benefit. (Purely from an egoistic point of view, I have two use cases right now that would profit a lot, and since there are other people involved, living on dev-branches for Zygote and ZygoteRules would not be an option).
My main takeaway from this discussion is that not supporting thunks inside rrules is a serious deficiency, and we should work towards fixing that. It should probably be on the list for 1.0? https://github.com/JuliaDiff/ChainRules.jl/issues/408
A couple of questions:
- @oschulz Why do we need methods for Zero()? I thought we needed methods for Thunks?
- @oxinabox Could we move all the @adjoint_keepthunk methods to rrules? Since some of the existing adjoints touch the accum_param? I guess we'd have to keep the adjoints as a wrapper?
I think of Zygote as doing one of the three things to generate a pullback: hit an rrule, hit an adjoint, or do its compiler magic. Right now, we unthunk before getting out of the rrule, so the other two parts never see thunks.
I guess the minimal change we can make is:
- move unthunk from chain_rules_output to chain_rules_input
- add unthunk to @adjoint
In this way, we keep compatibility with ChainRules (since the rrules never get thunks - though that should really be supported IMO, and this unthunk can be dropped once ready), and we keep compatibility with existing adjoints. What we gain is that the compiler magic part figures out it can ignore some thunks.
What we lose (compared to this PR) is the efficiency in adjoints that have been moved to use @adjoint_keepthunks. Do we actually gain much by having them? I guess you do in your use case @oschulz?
@oschulz Why do we need methods for Zero()? I thought we needed methods for Thunks?
Ah, sorry, yes - in fact, more basic functions need to support (and in some cases pass through) both Zero() and thunks. That would make a lot of existing rrules that are written in a generic fashion thunk-compatible automatically. Others will need to be changed, of course.
Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk.
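To make the suggestion concrete, here is a sketch of what such a helper might look like. Everything in it is hypothetical: unthunking is the name proposed above and is not an existing ChainRulesCore function, and LazyTangent/force stand in for a thunk type and unthunk.

```julia
# Hypothetical stand-in for a thunk type (assumption, Base Julia only).
struct LazyTangent{F}
    f::F
end
force(t::LazyTangent) = t.f()
force(x) = x

# Hypothetical helper: wrap a pullback so that it always receives a plain value.
unthunking(back) = Ȳ -> back(force(Ȳ))

# A pullback that is not prepared for lazy input (it only accepts arrays).
scale_pullback(Ȳ::AbstractArray) = (2 .* Ȳ,)

safe_back = unthunking(scale_pullback)
result = safe_back(LazyTangent(() -> [1.0, 2.0]))
```

The wrapper is a one-liner, which is part of why one could equally well leave it to rule authors to call unthunk explicitly where needed.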
I guess the minimal change we can make is:
- move unthunk from chain_rules_output to chain_rules_input
- add unthunk to @adjoint
In this way, we keep compatibility with ChainRules
Yes, that's what this PR and FluxML/ZygoteRules.jl#17 do. Zygote needs quite a few non-unthunking internal adjoints to get an actual benefit; they are already included in this PR. At least in some use cases I tested, it works quite nicely.
Long term we should then mandate that rrules accept thunks, as suggested in JuliaDiff/ChainRules.jl#408. That's why I'd like to propose the "three-step" plan above.
@oxinabox Could we move all the @adjoint_keepthunk methods to rrules? Since some of the existing adjoints touch the accum_param?
Looking at them, I think they either are things that should just move to ChainRules.jl or they are Zygote internals that can nevertheless be written using rrules.
The only one I am unsure about is: https://github.com/FluxML/Zygote.jl/pull/966/files#diff-cd0210083ce3136f79bee6ebca2bcca77f41a14f11b5a7a65ea1cc54803164c3R103-R110 which I don't understand what it is doing.
I think @DhairyaLGandhi needs to make the call as to if to do the parts of the 3 step plan that involve changes to Zygote/Zygote rules. I am gently in favour of them. On the one hand they seem to work, and that is really useful, and fixing things sooner rather than later is better: Better an egg in the hand than two in the bush. On the other hand it is more complexity to maintain in Zygote; and I dislike adding features to ZygoteRules when we are trying to stop using it.
I am strongly in favor of making changes in ChainRules and ChainRulesCore to facilitate this, either as part of the 3-step plan or as part of going straight to a design where keeping thunks is done via rrule and @adjoint always removes them.
Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk.
I think better to leave that to the AD package. It is a simple enough function, I would rather a little code duplication than expand the API surface in a way that might be wrong/made redundant. (I was reading https://sandimetz.com/blog/2016/1/20/the-wrong-abstraction recently, I think it makes some good points)
dislike adding features to ZygoteRules when we are trying to stop using it
We wouldn't really add a feature to ZygoteRules though, right? We'd just make @adjoint unthunk in general. And Zygote.@_adjoint_keepthunks would be a purely internal and temporary thing in Zygote.
Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk. I think better to leave that to the AD package.
Ah, no, I meant to make life easier on rule writers (who would only depend on ChainRulesCore, not an AD package). I assume there will be quite a few rules that need to always unthunk, and this could be a tool to keep their code more concise.
Ah, no, I meant to make life easier on rule writers (who would only depend on ChainRulesCore not an AD package). I assume there will be quite a few rules that need to always unthunk, and this could be a tool to keep their code more concise.
I am, for the same reasons, happy for now to just leave it for the rule authors.
It is still a very short function.
And in that case when unthunking the input, often they will want to unthunk some parts but not others.
So mostly just calling unthunk as needed seems easiest.
But anyway it is easy to revise this opinion later.
We wouldn't really add a feature to ZygoteRules though, right? We'd just make @adjoint unthunk in general. And Zygote.@_adjoint_keepthunks would be a purely internal and temporary thing in Zygote.
Ah, yeah that makes sense. That makes me even more in favor of this plan
Will have to go through the discussion to make a more informed call, but why is it that we need a separate macro to define thunk-aware adjoints? Generally, I see this more as an implementation detail which shouldn't leak to the API. Apologies if it's been discussed already, I'm a little late to the thread.
I am, for the same reasons, happy for now to just leave it for the rule authors. It is still a very short function.
You're right - unthunked() may be an unnecessary complication. Explicit unthunk()s in the rule code will be clearer.
but why is it that we need a separate macro to define thunk aware adjoints? Generally, I see this more as an implementation detail which shouldn't leak to the API
Yes, hence my changed proposal to do this as a Zygote-internal @_adjoint_keepthunks.
We do need two macros (one of them non-public/API-stable) for a soft transition, since we need a lot of non-unthunking Zygote-internal adjoints right now to have any benefit from thunks, but the ecosystem will break if existing adjoints suddenly get thunks as an input. That's why I'd like to do this in steps.
Ok I've removed @adjoint_keepthunks from ZygoteRules and added an internal @_adjoint_keepthunks to Zygote itself instead. I've also added some comments to label it as temporary.
Zygote.@_adjoint_keepthunks still uses the extended ZygoteRules.gradm(ex, mut, keepthunks) introduced by FluxML/ZygoteRules.jl#17. gradm is a fairly complicated bit of code and the difference between thunking and unthunking is minimal, so I didn't want to basically duplicate it in Zygote. But it's not an exported function, so we can remove the keepthunks argument from gradm later as a non-breaking change. FluxML/ZygoteRules.jl#17 is necessary in any case, since we do need to make @adjoint unthunk for any of this to work.
Here's a little demo on how this will work once we can disable the unthunking of rrule input (step 3 in the proposal above).
With
# Requires
# * https://github.com/FluxML/ZygoteRules.jl/pull/17
# * https://github.com/FluxML/Zygote.jl/pull/966
using Zygote, LinearAlgebra, ChainRulesCore
using ChainRules: CommutativeMulNumber

# Make thunks in log more readable:
Base.show(io::IO, x::Thunk) = print(io, typeof(x).name.name)
Base.show(io::IO, x::InplaceableThunk) = print(io, typeof(x).name.name)

# Disable unthunking of rrule input (to become default behavior in the future):
function Zygote.wrap_chainrules_input(x)
    @info "wrap_chainrules_input($(x))"
    return x
end

# Same rrule as in ChainRules, just add some logging:
function ChainRulesCore.rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    @info "rrule for $(A) * $(B)"
    function times_pullback(Ȳ)
        @info "pullback of $(A) * $(B) for $(Ȳ)"
        return (
            NO_FIELDS,
            InplaceableThunk(
                @thunk((@info("!left thunk.val of pullback of $(A) * $(B) for $(Ȳ) = $(Ȳ) * $(B)"); Ȳ * B')),
                X̄ -> (@info("!left thunk.add! of pullback of $(A) * $(B) for $(Ȳ) = $(Ȳ) * $(B)"); mul!(X̄, Ȳ, B', true, true))
            ),
            InplaceableThunk(
                @thunk((@info("!right thunk.val of pullback of $(A) * $(B) for $(Ȳ) = $(A') * $(Ȳ)"); A' * Ȳ)),
                X̄ -> (@info("!right thunk.add! of pullback of $(A) * $(B) for $(Ȳ) = $(A') * $(Ȳ)"); mul!(X̄, A', Ȳ, true, true))
            )
        )
    end
    return A * B, times_pullback
end

A, B, C, D = [fill(i,1,1) for i in 2:5]
We get:
julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [6] * [4] for InplaceableThunk
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [2] * [3] for InplaceableThunk
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
([24],)
julia> Zygote.gradient(X -> sum(X * B * C * D), A)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [6] * [4] for InplaceableThunk
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [2] * [3] for InplaceableThunk
[ Info: !left thunk.val of pullback of [2] * [3] for InplaceableThunk = InplaceableThunk * [3]
[ Info: !left thunk.val of pullback of [6] * [4] for InplaceableThunk = InplaceableThunk * [4]
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
([60],)
This should save a lot of computation time in applications that don't differentiate with respect to every argument along the computational graph. Standard ML applications usually want gradients for almost everything, of course (since almost everything is a free parameter), but applications like fitting, MCMC, "scientific" ML and the like tend to be more selective in which gradient(s) they need.
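The saving can be counted in a toy model. The sketch below uses plain Julia and a shorter chain, sum(A * B * X); LazyTangent, force, mul_pullback and grad_wrt_X are hypothetical names made up for this illustration and do not mirror Zygote's internals.

```julia
# Hypothetical stand-in for a thunk type (assumption, Base Julia only).
struct LazyTangent{F}
    f::F
end
force(t::LazyTangent) = t.f()
force(x) = x

# Pullback of L * R that defers both tangents and counts how many get computed.
function mul_pullback(L, R, Ȳ, counter)
    ∂L = LazyTangent(() -> (counter[] += 1; force(Ȳ) * R'))
    ∂R = LazyTangent(() -> (counter[] += 1; L' * force(Ȳ)))
    return ∂L, ∂R
end

# Gradient of sum(A * B * X) with respect to X only.
function grad_wrt_X(A, B, X; eager::Bool)
    counter = Ref(0)
    AB = A * B
    Ȳ = ones(size(AB, 1), size(X, 2))      # tangent of sum(...)
    ∂AB, ∂X = mul_pullback(AB, X, Ȳ, counter)
    if eager
        # Emulates forcing every tangent, including those nobody asked for.
        ȲAB = force(∂AB)
        ∂A, ∂B = mul_pullback(A, B, ȲAB, counter)
        force(∂A)
        force(∂B)
    end
    return force(∂X), counter[]
end

A = fill(2.0, 1, 1); B = fill(3.0, 1, 1); X = fill(5.0, 1, 1)
g_lazy, n_lazy = grad_wrt_X(A, B, X; eager = false)
g_eager, n_eager = grad_wrt_X(A, B, X; eager = true)
```

Both calls return the same gradient, but the lazy path performs one matrix product in the backward pass while the eager path performs four.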
If we re-enable unthunking of rrule input (as is part of this PR, to keep compatibility until "step 3"):
function Zygote.wrap_chainrules_input(x)
    @info "wrap_chainrules_input($(x))"
    return Zygote.unthunk_tangent(x)
end
things aren't quite that nice if we differentiate with respect to D:
julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
[ Info: pullback of [6] * [4] for [5]
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: !left thunk.val of pullback of [6] * [4] for [5] = [5] * [4]
[ Info: pullback of [2] * [3] for [20]
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
([24],)
but we still get some immediate benefit from this PR, since with the current release of Zygote, things look like this:
julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
[ Info: pullback of [6] * [4] for [5]
[ Info: !left thunk.val of pullback of [6] * [4] for [5] = [5] * [4]
[ Info: !right thunk.val of pullback of [6] * [4] for [5] = [6] * [5]
[ Info: pullback of [2] * [3] for [20]
[ Info: !left thunk.val of pullback of [2] * [3] for [20] = [20] * [3]
[ Info: !right thunk.val of pullback of [2] * [3] for [20] = [2] * [20]
([24],)
Here are some benchmarks for the example above, with bigger arrays:
using Zygote, BenchmarkTools
A, B, C, D = [fill(i,100,100) for i in 2:5]
julia> @btime sum($A * $B * $C * $D)
1.081 ms (24 allocations: 235.59 KiB)
Using current Zygote v0.6.10:
julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
3.066 ms (222 allocations: 792.28 KiB)
With this PR (requires FluxML/ZygoteRules.jl#17):
julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
2.192 ms (245 allocations: 562.34 KiB)
And after disabling forced unthunking of rrules input (step 3 in the plan above):
julia> @inline Zygote.wrap_chainrules_input(x) = x
julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
1.413 ms (229 allocations: 416.22 KiB)
Hey @DhairyaLGandhi , it seems all that's missing before we can move with this is your blessing. Would you have time to take a look?
I'll take a look this week, thanks for this!
@DhairyaLGandhi , gentle bump - I don't mean to nag, but I would love to start using the first stage of this for some projects.
On it :)
Not nagging at all, I was trying to explore what it would entail for writing rules in zygote and what it means for the ergonomics anyway
Thanks!
There should be no difference for users writing ZygoteRules rules. Starting with step 2 in the plan above, users writing ChainRulesCore rules (and I guess the idea is that this will be the preferred way in the future, @oxinabox?) will need to start using unthunk for the tangents their pullbacks receive, if the pullback uses functionality that isn't itself thunk-compatible. That should really be all, regarding ergonomics.
@mzgubic is working as we speak on making sure all rules in ChainRules.jl support receiving thunks as inputs.
Mostly by adding missing overloads for common linear operators to ChainRulesCore, but probably in some cases via calling unthunk at the start.
With the intent that when we release ChainRulesCore 1.0 and ChainRulesTestUtils 1.0, it will be a requirement of all rules.
Will the requirement be to unthunk before anything in an rrule?
Not necessarily, only in cases where the thunk would have been unthunked more than once. See https://github.com/JuliaDiff/ChainRules.jl/pull/449 (which is nearly there)
Will the requirement be to unthunk before anything in an rrule? Not necessarily,
Also, quite a few rrules will not need to unthunk at all, as long as they're only calling functions that have methods for thunks. From what I understand, there's an ongoing effort right now to significantly increase thunk support in linear algebra, etc. So in some (maybe quite a few) cases we'll be able to pass thunks through. Correct, @oxinabox?
Yes, that's right, some thunks will pass through.
We will need to be careful that we do unthunk in cases where the thunk is used twice (to avoid the duplicate calculation). If the thunk is only used once, it will be passed through.
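The difference can be seen in a small sketch (plain Julia; LazyTangent and force are hypothetical stand-ins for a thunk and unthunk, and the two pullbacks are invented for the illustration):

```julia
# Hypothetical stand-in for a thunk type (assumption, Base Julia only).
struct LazyTangent{F}
    f::F
end
force(t::LazyTangent) = t.f()
force(x) = x

runs = Ref(0)  # counts executions of the deferred computation
make_tangent() = LazyTangent(() -> (runs[] += 1; [1.0, 2.0]))

# Forces the lazy input at each use, so the deferred work runs twice.
naive_pullback(Ȳ) = (force(Ȳ) .* 2, force(Ȳ) .+ 1)

# Forces once up front and reuses the plain value.
function careful_pullback(Ȳ)
    Ȳ_val = force(Ȳ)
    return (Ȳ_val .* 2, Ȳ_val .+ 1)
end

runs[] = 0
naive_result = naive_pullback(make_tangent())
naive_runs = runs[]

runs[] = 0
careful_result = careful_pullback(make_tangent())
careful_runs = runs[]
```

Both pullbacks return the same values; only the number of times the underlying computation runs differs.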
As you say, there is increasing support for thunks in ChainRulesCore. All of those operations essentially unthunk (or complain that thunk mutation was attempted).
I wonder if Julia shouldn't have something like a thunk as a native concept. We have something along those lines in so many places/packages.
unthunk in cases where the thunk is used twice (to avoid the duplicate calculation)
Could we use some result caching / lazy value mechanism?
Could we use some result caching / lazy value mechanism?
Might lead to even more mem allocs than we already have, though ...
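For reference, a caching lazy value could be sketched as below. This is an assumption-laden toy in plain Julia (CachedLazy and force! are hypothetical names), not a proposal for the actual ChainRulesCore design.

```julia
# Hypothetical caching variant: mutable so the result can be stored after the first run.
mutable struct CachedLazy
    f::Any        # deferred computation
    value::Any    # cached result, meaningful only once `done` is true
    done::Bool
end
CachedLazy(f) = CachedLazy(f, nothing, false)

# Run the computation on first use, return the stored result afterwards.
function force!(c::CachedLazy)
    if !c.done
        c.value = c.f()
        c.done = true
    end
    return c.value
end

runs = Ref(0)
c = CachedLazy(() -> (runs[] += 1; [1.0, 2.0]))

first_use = force!(c) .* 2    # runs the computation
second_use = force!(c) .+ 1   # served from the cache
```

The cache has to live in a mutable object, which typically means an extra heap allocation per lazy value, and the untyped fields in this sketch also defeat type inference; that is the kind of cost the allocation concern above refers to.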
Yes, thunks are an emulation of a compiler feature julia doesn't have. Haskell (and TensorFlow) have lazy evaluation where only things that are used are computed. Thunks are like 1 level of lazy evaluation.
And there is a more general concept of program slicing; program slicing would be particularly useful for AD.
Yes, thunks are an emulation of a compiler feature julia doesn't have.
~~Hm, this got me thinking - thunks are really monads, right? [...]~~
Moved to JuliaDiff/ChainRulesCore.jl#373 .
Could we use some result caching / lazy value mechanism?
Might lead to even more mem allocs than we already have, though ...
A very very ancient version of ChainRules had this. It was removed long before I was involved. Issue for adding it back is still open https://github.com/JuliaDiff/ChainRulesCore.jl/issues/7#issuecomment-491609475 I do worry about the same thing re mem allocs, here
Issue for adding it back is still open JuliaDiff/ChainRulesCore.jl#7 (comment) I do worry about the same thing re mem allocs, here
Thanks, I added a suggestion there.
@DhairyaLGandhi do we get your blessing on the "plan"?
Given having full support for thunks in all ChainRules functions is well and truly on its way, at this point the plan might be moot. And we can soon basically just merge something like this PR.
at this point the plan might be moot. And we can soon basically just merge something like this PR.
More than fine with me - we'll still need to merge FluxML/ZygoteRules.jl#17 first, though - or duplicate that functionality within Zygote (would duplicate most of ZygoteRules.gradm).
FluxML/ZygoteRules.jl#17 has been merged, I'll get this PR into shape as soon as there's a new release of ZygoteRules, so CI can run.
Ok, new release of ZygoteRules is out (had overlooked it), I'll jump back on this.
Sorry, bit overloaded but haven't forgotten about this.
Rebasing this now, but several things have changed in Zygote in the mean time ... hope it still works.
CI results don't look too horrible - at least some of the test failures look similar to the ones in #1104, so may be unrelated to this PR. Could an expert take a look?
Might be worth wrapping all tests in a @testset, so that it keeps going past the first failure?
I have one test failure that I can't figure out how to fix: The "Params nesting" testset results in a "map is not defined on dictionaries" error. It's caused by the unthunking in pullback(f, ps::Params) somehow, I think (it goes away without, but the user-facing pullback function needs to unthunk, obviously), but I don't understand why ∇map hits a Dict with it but not without.
I have one test failure that I can't get figured out (how to fix): The "Params nesting" testset results in a map is not defined on dictionaries
Some expert help would be very much appreciated. I learned quite a few things about Zygote while preparing this PR, but I don't think I've reached a full understanding of the compiler and some other parts (like nested Params).
I also do not understand nested Params, because I never used Params.
I literally had to look them up in the docs
https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1
I do recall @CarloLucibello made some changes to them recently.
(Or maybe that was Flux.Grads?)
I think this is almost there, if we can only figure the nested Params test out.
@oschulz do you mind rebasing this? I just tried some local testing and there appear to be new test failures.
Will do! Really need to bring this one home ...
@ToucheSir rebase done.
Keeping track of test failure causes as I find them:
- https://github.com/FluxML/Zygote.jl/blob/4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e/src/lib/lib.jl#L324 is causing 3 test failures. I think the signature of (back::Jnew)(...) needs to be relaxed or another method added which handles thunks.
- Similar story with the test above failing on https://github.com/FluxML/ZygoteRules.jl/blob/05cd6e1d41a363b2114fcce2a640df145006a5b7/src/adjoint.jl#L24.
- I'm not exactly sure what's going on with the increased memory usage for https://github.com/FluxML/Zygote.jl/blob/4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e/test/features.jl#L668. Perhaps thunks are preventing some kind of copy elision?
- The elements of dy need to be unthunked some time before, during or after https://github.com/FluxML/Zygote.jl/blob/4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e/src/lib/array.jl#L294 to fix this test. I list 3 options because I'm not sure whether it would be more appropriate for e.g. accum to handle this, the adjoint or something else.
- Last but not least, https://github.com/FluxML/Zygote.jl/blob/4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e/test/interface.jl#L144. Here the broken behaviour actually makes more sense to me, and it's not clear from the blame why the line in question was added. @lassepe do you recall why?
@ToucheSir that is probably not a useful test but I recall what happened there:
When you compute y, back = pullback(f, Params[x]), then evaluating back multiple times would mutate the Grad object handed out earlier. Say I compute g1 = back(1); then g1 is a Grad object which has g1[w] as the gradient of f w.r.t. w. If I now call g2 = back(nothing), this mutates g1 so that g1[w] == nothing.
I am not sure what the intended behavior actually is. It does not seem to be documented at least and it was a really bad foot gun for me. I added this test to test the "default case" to contrast it with the "fixed case" where I copy the grad object in between to avoid this unexpected coupling. So I guess I mostly added that as implicit documentation. I am sorry that this has caused confusion here. Maybe I should just open a PR to the docs instead.
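The coupling described here can be modeled in a few lines of plain Julia. This is only an analogy under stated assumptions: make_back and its shared IdDict are hypothetical and do not reproduce how Zygote builds Grads objects.

```julia
# Hypothetical model of a pullback that reuses one mutable cache across calls.
function make_back()
    cache = IdDict{Any,Any}()
    w = [1.0, 2.0]
    function back(Δ)
        cache[w] = Δ === nothing ? nothing : Δ .* w
        return cache          # the same object is handed out on every call
    end
    return w, back
end

w, back = make_back()

g1 = back(1)
g1_saved = copy(g1)      # snapshot that is decoupled from later calls
g2 = back(nothing)       # overwrites the entry that g1 also sees
```

After the second call, g1 and g2 are the same object and g1[w] is nothing, while the copied snapshot still holds the first gradient, which mirrors the "default case" versus "fixed case" contrast in the test.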
Thanks for the quick response! I think in that case we can call it a footgun and consider that fixed (unintentionally) by this PR, no need for a docs update.
I guess this PR does not fix the problem in general, only for the nothing case, right? For example, if I call g1 = back(1) and then g2 = back(2) it seems like that would still mutate g1 since both unthunk, right? I'm not raising this as an issue against this PR (I think the PR is great and I don't think it is its job to fix this behavior), I'm just trying to figure out whether this would need documentation perhaps.
Thanks for the changes! Let's see what CI says now ... :-)
For example, if I call g1 = back(1) and then g2 = back(2) it seems like that would still mutate g1 since both unthunk, right?
It appears that may be addressed as well:
using Zygote
x, y = ones(2), rand(2)
ps = Params([x, y])
l, back = pullback(() -> sum(x) + prod(y), ps)
julia> g1 = back(1.)
Grads(...)
julia> g1.grads
IdDict{Any, Any} with 4 entries:
:(Main.x) => Fill(1.0, 2)
:(Main.y) => [0.981219, 0.920662]
[0.920662, 0.981219] => [0.981219, 0.920662]
[1.0, 1.0] => Fill(1.0, 2)
julia> g2 = back(2.)
Grads(...)
julia> g1.grads
IdDict{Any, Any} with 4 entries:
:(Main.x) => Fill(1.0, 2)
:(Main.y) => [0.981219, 0.920662]
[0.920662, 0.981219] => [0.981219, 0.920662]
[1.0, 1.0] => Fill(1.0, 2)
julia> g2.grads
IdDict{Any, Any} with 4 entries:
:(Main.x) => Fill(3.0, 2)
:(Main.y) => [2.94366, 2.76199]
[0.920662, 0.981219] => [1.96244, 1.84132]
[1.0, 1.0] => Fill(2.0, 2)
There's still a lot of new test failures (compared to pre-rebase). Of course Zygote has changed a bit - help tracking them down is welcome.
@oschulz that's what https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-1065931188 is for :)
@oschulz that's what https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-1065931188 is for :)
Aw, sorry, overlooked that one - great, thanks!
Is there a plan/timeline for this PR? Not pushing, it's just that this is quite problematic for some of our (PDE-based) codes, where it is a huge blocker performance-wise.
I would still like this done. AFAIK there are no remaining blockers -- and haven't been for about 12 months.
@DhairyaLGandhi @oschulz should we have a call and work out what the way forward is with this?
I would still like this done.
Me too.
AFAIK there are no remaining blockers
I don't think there are any blockers conceptually, just some test failures we hadn't managed to figure out.
should we have a call and workout what is the way forward with this?
Gladly!
@DhairyaLGandhi ?
Sure, sounds good
Sure, sounds good
Let's find a date via Slack?
Sure
As the one who raised all these pesky failure cases in https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-1065931188, I'd be more than happy to help support this effort as well. Be that testing or trying to make sense of AD compiler output.