Zygote.jl

Utilize ChainRulesCore thunks

Open oschulz opened this issue 3 years ago • 76 comments

Note: Requires FluxML/ZygoteRules.jl#17

Currently, Zygote always unthunks ChainRulesCore thunks, which is wasteful and may also lead to trouble in cases where a thunk just can't be run for the given types/contents.

With

using ChainRulesCore, Zygote
foo(a, b) = a * b
function ChainRulesCore.rrule(::typeof(foo), a, b)
    y = foo(a, b)
    function foo_pullback(Ȳ)
        ∂a = @thunk (@info "Thunk ∂a"; return Ȳ * b')
        ∂b = @thunk (@info "Thunk ∂b"; return a' * Ȳ)
        return (NO_FIELDS, ∂a, ∂b)
    end
    return y, foo_pullback
end
a = rand(4,3); b = rand(3,2); Ȳ = rand(4,2);
Zygote.pullback(foo, a, b)[2](Ȳ)
let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end

we obviously get

julia> Zygote.pullback(foo, a, b)[2](Ȳ)
[ Info: Thunk ∂a
[ Info: Thunk ∂b
([...], [...])

but we currently also get

julia> let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end
[ Info: Thunk ∂a
[ Info: Thunk ∂b
([...],)

so Zygote also executes both thunks if we only require the pullback with respect to b.

With this PR, we should get

julia> let a = a; Zygote.pullback(b -> foo(a, b), b)[2](Ȳ) end
[ Info: Thunk ∂b
([0.5793294513282393 0.5442572290286626; 0.748469566554706 0.8319235475298901; 0.609222600261931 0.6201995902507516],)

instead.
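For readers unfamiliar with thunks, the behaviour described above can be mimicked in plain Julia. This is only a minimal sketch: LazyValue, force and pullback_sketch are hypothetical names invented for illustration and are not ChainRulesCore or Zygote API (the real counterparts are @thunk and unthunk).

```julia
# Hypothetical stand-in for a ChainRulesCore thunk: a wrapped zero-argument closure.
struct LazyValue{F}
    f::F
end

# Hypothetical counterpart of `unthunk`: run the deferred computation.
force(x::LazyValue) = x.f()
force(x) = x  # plain values pass through unchanged

evals = Ref(0)  # counts how many deferred computations actually ran

a = [1.0 2.0; 3.0 4.0]
b = [5.0 6.0; 7.0 8.0]
dy = [1.0 0.0; 0.0 1.0]

# Mirrors foo_pullback above, with LazyValue in place of @thunk.
function pullback_sketch(dy)
    da = LazyValue(() -> (evals[] += 1; dy * b'))
    db = LazyValue(() -> (evals[] += 1; a' * dy))
    return (da, db)
end

da, db = pullback_sketch(dy)
grad_b = force(db)  # only the gradient with respect to b is requested
println("deferred computations run: ", evals[])
```

Since da is never forced, its matrix product is never computed. That is the saving this PR is after.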

CC @oxinabox , @mzgubic

oschulz avatar May 07 '21 11:05 oschulz

I think it may actually be the unthunk in tailmemaybe(x::Tuple) that results in that error, not the one in Grads(IdDict(...)).

oschulz avatar May 07 '21 16:05 oschulz

It looks like for this to work, we'll need to overload a lot of functions like x', Adjoint, transpose, Transpose, diag, Diagonal, permutedims, permutedims!, Ref, Vector, Matrix, Array, collect, convert, and so on, to unthunk thunks, since thunks can now appear as inputs to pullback functions.
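As a rough illustration of what such overloads could look like, here is a sketch on a hypothetical LazyValue type (not the actual ChainRulesCore Thunk, and not the methods added in this PR). Each method simply forces the deferred value and forwards to the ordinary implementation.

```julia
using LinearAlgebra

# Hypothetical stand-in for a thunk.
struct LazyValue{F}
    f::F
end
force(x::LazyValue) = x.f()

# Forward a few common operations by forcing the lazy value first.
Base.adjoint(x::LazyValue) = adjoint(force(x))
Base.transpose(x::LazyValue) = transpose(force(x))
Base.collect(x::LazyValue) = collect(force(x))
LinearAlgebra.diag(x::LazyValue) = diag(force(x))

lazy = LazyValue(() -> [1 2; 3 4])

adj = lazy'          # `'` lowers to `adjoint`, so the overload applies
d = diag(lazy)
c = collect(lazy)
```

With methods like these in place, generic pullback code that calls x' or diag(x) keeps working when x happens to be lazy.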

oschulz avatar May 07 '21 21:05 oschulz

Ok, with FluxML/ZygoteRules.jl#17 this passes the Zygote test suite on my system now (incl. CUDA tests).

In both the simple example above and more complex cases that I've tested only the required thunks are unthunked. Hard to give guarantees on this, obviously.

The @adjoint_keepthunks in FluxML/ZygoteRules.jl#17 provides a way for adjoint authors to declare whether their pullback code is prepared to handle thunks or not.

~~Maybe we could add a similar mechanism to ChainRulesCore (an rrule_keepthunks or so) to provide a soft transition to (hopefully) more and more rules that support thunks in the ecosystem?~~ (see below)

oschulz avatar May 08 '21 03:05 oschulz

Maybe we could add a similar mechanism to ChainRulesCore (an rrule_keepthunks or so) to provide a soft transition to (hopefully) more and more rules that support thunks in the ecosystem?

I had a look, the number of packages that actually define rrules seems finite (Bijectors, ChainRules, ChainRulesCore, ComponentArrays, CRlibm, DiffEqBase, DistributionsAD, LoopVectorization, NNlib, SpecialFunctions, SymbolicUtils, WebSockets and Zygote in my fairly well-filled package dir). So maybe the better approach would be:

  • Do unthunk in Zygote.wrap_chainrules_input for now (already in this PR). Release a patch version of ZygoteRules with FluxML/ZygoteRules.jl#17 and a patch version of Zygote with this PR (FluxML/Zygote.jl#966); these changes should be non-breaking across the ecosystem.

  • Add thunk-support to functions like x', Adjoint, transpose, Transpose, diag, Diagonal, permutedims, permutedims!, Ref, Vector, Matrix, Array, collect, convert (probably a few more) in ChainRules (maybe some of them in ChainRulesCore?).

  • Release ChainRulesCore v0.10, alert users that from now on we're going to take thunks seriously. :-) (Meaning, they may appear in the input to pullback functions). With the additional thunk-supporting methods, we probably wouldn't need to change too much code in rrule-using packages like those listed above.

  • Release a Zygote v0.7 or so, requiring ChainRulesCore v0.10 and removing unthunking in Zygote.wrap_chainrules_input.

This way, we can have a "soft" path to using thunks more and more. We'll have an immediate benefit since this PR can be done as patch releases of ZygoteRules and Zygote, without ecosystem propagation delay. And after a (hopefully not too long) while, rrules will be required to be thunk-aware, so no need to introduce an rrule_keepthunks. And packages with existing ZygoteRules.@adjoints can just replace them with ZygoteRules.@adjoint_keepthunks in their code if the pullback is thunk-compatible (many will be after more basic functions support thunks) - bit by bit, so thunks can propagate deeper with minimal effort.

Update: see below for changed proposal

oschulz avatar May 08 '21 10:05 oschulz

and after a (hopefully not too long while), rrules will be required to be thunk-aware, so no need to introduce an rrule_keepthunks

I think this idea is closer to the right way. We want to overload all linear operators in Base/stdlibs to work on Zero and Thunk (and all other abstract differentials) anyway. We can probably enforce thunk support via automatically testing it in ChainRulesTestUtils.

How much of a problem is it if linear operators on a Thunk unthunk them? I guess we can decide on a per operator basis. I guess maybe we can have Thunk-ness propagate until + is called (which means a gradient is being accumulated) or something like that, though I worry about the overhead that adds by carrying that computational graph around (but maybe it isn't so much).
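A sketch of the "propagate until accumulation" idea, under the assumption of a hypothetical LazyValue type and a hypothetical accum_sketch function standing in for gradient accumulation (neither is ChainRulesCore or Zygote API):

```julia
# Hypothetical stand-in for a thunk.
struct LazyValue{F}
    f::F
end
force(x::LazyValue) = x.f()
force(x) = x

evals = Ref(0)  # counts how often the innermost computation runs

# A linear operator that stays lazy: it wraps the call in a new deferred closure.
Base.transpose(x::LazyValue) = LazyValue(() -> transpose(force(x)))

# Hypothetical accumulation step: this is where lazy values finally get forced.
accum_sketch(x, y) = force(x) + force(y)

lazy = LazyValue(() -> (evals[] += 1; [1 2; 3 4]))
lazy_t = transpose(lazy)          # still lazy, nothing computed yet
evals_before = evals[]
total = accum_sketch(lazy_t, [10 10; 10 10])
evals_after = evals[]
```

Each lazy operator here allocates a new closure that captures the previous one, which is the "carrying the graph around" overhead mentioned above.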

What if we made @adjoint always unthunk, and didn't introduce a @adjoint_keepthunks, but instead ensured that all existing rrules supported thunks, and moved the rules you have marked here with @adjoint_keepthunks into ChainRules? And just told people wanting to keep thunks that they should write rrules instead of using @adjoint? I guess it would be too long to wait for all of the rrules to be updated to support thunks?

cc @mzgubic @willtebbutt @nickrobinson251

oxinabox avatar May 10 '21 19:05 oxinabox

How much of a problem is it if linear operators on a Thunk unthunk them?

I guess most of them will have to?

What if we made @adjoint always unthunk, and didn't introduce a @adjoint_keepthunks , but instead ensured that all existing rrules supported thunks. [...]

How about this:

  • Step 1: We remove @adjoint_keepthunks from FluxML/ZygoteRules.jl#17 and add it to Zygote as an internal Zygote.@_adjoint_keepthunks for now. That way, it's easy to change all the low-level adjoints in Zygote (we need to add keepthunks to a lot of adjoints in Zygote itself to have any benefit from this, as done in this PR) and easy to change back in cases where we might have been too zealous. We can then merge and release FluxML/ZygoteRules.jl#17 and FluxML/Zygote.jl#966 as non-breaking and without introducing new APIs.

  • Step 2: We release a breaking new version of ChainRulesCore that mandates that rrules are prepared to handle thunks, and update the rules in ChainRules accordingly.

  • Step 3: We add all Zygote-internal @_adjoint_keepthunks rules (and maybe some more @_adjoint rules as well) as rrules to ChainRules. We test performance implications (since rrules undergo some wrapping/unwrapping and a lot of these rules are low level and will be called frequently) while we remove those Zygote-internal adjoints until none are left. Then we can get rid of Zygote.@_adjoint_keepthunks again. We encourage the ecosystem to do the same (use ChainRulesCore instead of ZygoteRules). May need to fix things like FluxML/Zygote.jl#811 first.

I guess it would be too long to wait for all of the rrules to be updated to support thunks?

I suspect switching all necessary Zygote-internal adjoints to rrules and having the ecosystem make all rrules thunk-compatible (step 3 above) will take some time. But if we do it like above, we'd get an immediate benefit. (Purely from an egoistic point of view, I have two use cases right now that would profit a lot, and since there are other people involved, living on dev-branches for Zygote and ZygoteRules would not be an option).

oschulz avatar May 11 '21 10:05 oschulz

My main takeaway from this discussion is that not supporting thunks inside rrules is a serious deficiency, and we should work towards fixing that. It should probably be on the list for 1.0? https://github.com/JuliaDiff/ChainRules.jl/issues/408

A couple of questions:

  • @oschulz Why do we need methods for Zero()? I thought we needed methods for Thunks?
  • @oxinabox Could we move all the @adjoint_keepthunk methods to rrules? Since some of the existing adjoints touch the accum_param? I guess we'd have to keep the adjoints as a wrapper?

I think of Zygote as doing one of the three things to generate a pullback: hit an rrule, hit an adjoint, or do its compiler magic. Right now, we unthunk before getting out of the rrule, so the other two parts never see thunks.

I guess the minimal change we can make is:

  • move unthunk from chain_rules_output to chain_rules_input
  • add unthunk to @adjoint

In this way, we keep compatibility with ChainRules (since the rrules never get thunks - though that should really be supported IMO, and this unthunk can be dropped once ready), and we keep compatibility with existing adjoints. What we gain is that the compiler magic part figures out it can ignore some thunks.
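The difference between forcing on output and forcing on input can be sketched in plain Julia. Everything below (LazyValue, force, mul_pullback, force_outputs, force_inputs) is a hypothetical illustration, not the actual Zygote wrapping code.

```julia
# Hypothetical stand-in for a thunk.
struct LazyValue{F}
    f::F
end
force(x::LazyValue) = x.f()
force(x) = x

evals = Ref(0)  # counts deferred matrix products that actually ran

# Pullback of a * b that returns both partials lazily.
mul_pullback(a, b) = dy -> (LazyValue(() -> (evals[] += 1; dy * b')),
                            LazyValue(() -> (evals[] += 1; a' * dy)))

# Policy 1: force everything a rule returns (roughly the current behaviour).
force_outputs(back) = dy -> map(force, back(dy))

# Policy 2: force only what a rule receives; outputs stay lazy.
force_inputs(back) = dy -> back(force(dy))

a = [1.0 2.0]                   # 1×2
b = reshape([3.0, 4.0], 2, 1)   # 2×1
dy = fill(1.0, 1, 1)            # 1×1 seed

evals[] = 0
eager_da, eager_db = force_outputs(mul_pullback(a, b))(dy)
eager_count = evals[]

evals[] = 0
lazy_da, lazy_db = force_inputs(mul_pullback(a, b))(dy)
db_val = force(lazy_db)         # downstream code only needs the partial for b
lazy_count = evals[]
```

Under the second policy the rule still only ever sees plain values, but partials that nobody consumes are never computed.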

What we lose (compared to this PR) is the efficiency in adjoints that have been moved to use @adjoint_keepthunks. Do we actually gain much by having them? I guess you do in your use case @oschulz ?

mzgubic avatar May 11 '21 10:05 mzgubic

@oschulz Why do we need methods for Zero()? I thought we needed methods for Thunks?

Ah, sorry, yes - in fact, more basic functions need to support (and in some cases pass through) both Zero() and thunks. That would make a lot of existing rrules that are written in a generic fashion thunk-compatible automatically. Others will need to be changed, of course.

Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk.
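Such a helper could be very small. A sketch with hypothetical names (LazyValue, force, unthunking_sketch; none of these exist in ChainRulesCore):

```julia
# Hypothetical stand-in for a thunk.
struct LazyValue{F}
    f::F
end
force(x::LazyValue) = x.f()
force(x) = x

# Hypothetical helper in the spirit of the proposed `unthunking(back)`:
# wraps a pullback so that every input is forced before the pullback sees it.
unthunking_sketch(back) = (dys...) -> back(map(force, dys)...)

# A pullback that was written for plain vectors only.
plain_pullback(dy::AbstractVector) = (2 .* dy,)

wrapped = unthunking_sketch(plain_pullback)
result = wrapped(LazyValue(() -> [1.0, 2.0]))
```

A rule would then end with return y, unthunking_sketch(back) instead of calling force inside the pullback body.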

I guess the minimal change we can make is:

  • move unthunk from chain_rules_output to chain_rules_input
  • add unthunk to @adjoint

In this way, we keep compatibility with ChainRules

Yes, that's what this PR and FluxML/ZygoteRules.jl#17 do. Zygote needs quite a few non-unthunking internal adjoints to get an actual benefit; they are already included in this PR. At least in some use cases I tested, it works quite nicely.

Long term we should then mandate that rrules accept thunks, as suggested in JuliaDiff/ChainRules.jl#408. That's why I'd like to propose the "three-step" plan above.

oschulz avatar May 11 '21 11:05 oschulz

@oxinabox Could we move all the @adjoint_keepthunk methods to rrules? Since some of the existing adjoints touch the accum_param?

Looking at them, I think they either are things that should just move to ChainRules.jl or they are Zygote internals that can nevertheless be written using rrules.

The only one I am unsure about is: https://github.com/FluxML/Zygote.jl/pull/966/files#diff-cd0210083ce3136f79bee6ebca2bcca77f41a14f11b5a7a65ea1cc54803164c3R103-R110 which I don't understand what it is doing.


I think @DhairyaLGandhi needs to make the call as to whether to do the parts of the 3 step plan that involve changes to Zygote/ZygoteRules. I am gently in favour of them. On the one hand they seem to work, and that is really useful, and fixing things sooner rather than later is better: better an egg in the hand than two in the bush. On the other hand it is more complexity to maintain in Zygote, and I dislike adding features to ZygoteRules when we are trying to stop using it.

I am strongly in favor of making changes in ChainRules and ChainRulesCore to facilitate this, either as part of the 3 step plan or as part of going straight to a setup where keeping thunks is done via rrule and @adjoint always removes them.


Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk.

I think better to leave that to the AD package. It is a simple enough function, I would rather a little code duplication than expand the API surface in a way that might be wrong/made redundant. (I was reading https://sandimetz.com/blog/2016/1/20/the-wrong-abstraction recently, I think it makes some good points)

oxinabox avatar May 11 '21 12:05 oxinabox

dislike adding features to ZygoteRules when we are trying to stop using it

We wouldn't really add a feature to ZygoteRules though, right? We'd just make @adjoint unthunk in general. And Zygote.@_adjoint_keepthunks would be a purely internal and temporary thing in Zygote.

oschulz avatar May 11 '21 13:05 oschulz

Maybe we could offer a function unthunking in ChainRulesCore so that rrules can return y, unthunking(back) for rrules that need to unthunk.

I think better to leave that to the AD package.

Ah, no, I meant to make life easier on rule writers (who would only depend on ChainRulesCore not an AD package). I assume there will be quite a few rules that need to always unthunk, and this could be a tool to keep their code more concise.

oschulz avatar May 11 '21 13:05 oschulz

Ah, no, I meant to make life easier on rule writers (who would only depend on ChainRulesCore not an AD package). I assume there will be quite a few rules that need to always unthunk, and this could be a tool to keep their code more concise.

I am, for the same reasons, happy for now to just leave it for the rule authors. It is still a very short function. And in that case when unthunking the input, often they will want to unthunk some parts but not others. So mostly just calling unthunk as needed seems easiest. But anyway it is easy to revise this opinion later.

oxinabox avatar May 11 '21 14:05 oxinabox

We wouldn't really add a feature to ZygoteRules though, right? We'd just make @adjoint unthunk in general. And Zygote.@_adjoint_keepthunks would be a purely internal and temporary thing in Zygote.

Ah, yeah that makes sense. That makes me even more in favor of this plan

oxinabox avatar May 11 '21 14:05 oxinabox

Will have to go through the discussion to make a more informed call, but why is it that we need a separate macro to define thunk-aware adjoints? Generally, I see this more as an implementation detail which shouldn't leak to the API. Apologies if it's been discussed already, I'm a little late to the thread.

DhairyaLGandhi avatar May 11 '21 15:05 DhairyaLGandhi

I am, for the same reasons, happy for now to just leave it for the rule authors. It is still a very short function.

You're right - unthunking() may be an unnecessary complication. Explicit unthunk()s in the rule code will be clearer.

oschulz avatar May 11 '21 15:05 oschulz

but why is it that we need a separate macro to define thunk aware adjoints? Generally, I see this more as an implementation detail which shouldn't leak to the API

Yes, hence my changed proposal to do this as a Zygote-internal @_adjoint_keepthunks.

We do need two macros (one of them non-public/API-stable) for a soft transition, since we need a lot of non-unthunking Zygote-internal adjoints right now to have any benefit from thunks, but the ecosystem will break if existing adjoints suddenly get thunks as an input. That's why I'd like to do this in steps.

oschulz avatar May 11 '21 15:05 oschulz

Ok I've removed @adjoint_keepthunks from ZygoteRules and added an internal @_adjoint_keepthunks to Zygote itself instead. I've also added some comments to label it as temporary.

Zygote.@_adjoint_keepthunks still uses the extended ZygoteRules.gradm(ex, mut, keepthunks) introduced by FluxML/ZygoteRules.jl#17. gradm is a fairly complicated bit of code and the difference between thunking and unthunking is minimal, so I didn't want to basically duplicate it in Zygote. But it's not an exported function, so we can remove the keepthunks argument from gradm later as a non-breaking change. FluxML/ZygoteRules.jl#17 is necessary in any case, since we do need to make @adjoint unthunk for any of this to work.

oschulz avatar May 11 '21 21:05 oschulz

Here's a little demo on how this will work once we can disable the unthunking of rrule input (step 3 in the proposal above).

With

# Requires
# * https://github.com/FluxML/ZygoteRules.jl/pull/17
# * https://github.com/FluxML/Zygote.jl/pull/966

using Zygote, LinearAlgebra, ChainRulesCore
using ChainRules: CommutativeMulNumber


# Make thunks in log more readable:
Base.show(io::IO, x::Thunk) = print(io, typeof(x).name.name)
Base.show(io::IO, x::InplaceableThunk) = print(io, typeof(x).name.name)


# Disable unthunking of rrule input (to become default behavior in the future):
function Zygote.wrap_chainrules_input(x)
    @info "wrap_chainrules_input($(x))"
    return x
end


# Same rrule as in ChainRules, just add some logging:
function ChainRulesCore.rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    @info "rrule for $(A) * $(B)"
    function times_pullback(Ȳ)
        @info "pullback of $(A) * $(B) for $(Ȳ)"
        return (
            NO_FIELDS,
            InplaceableThunk(
                @thunk((@info("!left thunk.val of pullback of $(A) * $(B) for $(Ȳ) = $(Ȳ) * $(B)"); Ȳ * B')),
                X̄ -> (@info("!left thunk.add! of pullback of $(A) * $(B) for $(Ȳ) = $(Ȳ) * $(B)");mul!(X̄, Ȳ, B', true, true))
            ),
            InplaceableThunk(
                @thunk((@info("!right thunk.val of pullback of $(A) * $(B) for $(Ȳ) = $(A') * $(Ȳ)"); A' * Ȳ)),
                X̄ -> (@info("!right thunk.add! of pullback of $(A) * $(B) for $(Ȳ) = $(A') * $(Ȳ)");mul!(X̄, A', Ȳ, true, true))
            )
        )
    end
    return A * B, times_pullback
end


A, B, C, D = [fill(i,1,1) for i in 2:5]

We get:

julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [6] * [4] for InplaceableThunk
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [2] * [3] for InplaceableThunk
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
([24],)

julia> Zygote.gradient(X -> sum(X * B * C * D), A)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [6] * [4] for InplaceableThunk
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: pullback of [2] * [3] for InplaceableThunk
[ Info: !left thunk.val of pullback of [2] * [3] for InplaceableThunk = InplaceableThunk * [3]
[ Info: !left thunk.val of pullback of [6] * [4] for InplaceableThunk = InplaceableThunk * [4]
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
([60],)

This should save a lot of computation time in applications that don't differentiate with respect to every argument along the computational graph. Standard ML applications usually want gradients for almost everything, of course (since almost everything is a free parameter), but applications like fitting, MCMC, "scientific" ML and the like tend to be more selective in which gradient(s) they need.

If we re-enable unthunking of rrule input (as is part of this PR, to keep compatibility until "step 3")

function Zygote.wrap_chainrules_input(x)
    @info "wrap_chainrules_input($(x))"
    return Zygote.unthunk_tangent(x)
end

things aren't quite that nice if we differentiate with respect to D

julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: wrap_chainrules_input(1×1 Fill{Int64}: entries equal to 1)
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
[ Info: pullback of [6] * [4] for [5]
[ Info: wrap_chainrules_input(InplaceableThunk)
[ Info: !left thunk.val of pullback of [6] * [4] for [5] = [5] * [4]
[ Info: pullback of [2] * [3] for [20]
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
([24],)

but we still get some immediate benefit from this PR, since with the current release of Zygote, things look like this:

julia> Zygote.gradient(X -> sum(A * B * C * X), D)
[ Info: rrule for [2] * [3]
[ Info: rrule for [6] * [4]
[ Info: rrule for [24] * [5]
[ Info: pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1
[ Info: !left thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = 1×1 Fill{Int64}: entries equal to 1 * [5]
[ Info: !right thunk.val of pullback of [24] * [5] for 1×1 Fill{Int64}: entries equal to 1 = [24] * 1×1 Fill{Int64}: entries equal to 1
[ Info: pullback of [6] * [4] for [5]
[ Info: !left thunk.val of pullback of [6] * [4] for [5] = [5] * [4]
[ Info: !right thunk.val of pullback of [6] * [4] for [5] = [6] * [5]
[ Info: pullback of [2] * [3] for [20]
[ Info: !left thunk.val of pullback of [2] * [3] for [20] = [20] * [3]
[ Info: !right thunk.val of pullback of [2] * [3] for [20] = [2] * [20]
([24],)

oschulz avatar May 14 '21 13:05 oschulz

Here are some benchmarks for the example above, with bigger arrays:

using Zygote, BenchmarkTools

A, B, C, D = [fill(i,100,100) for i in 2:5]
julia> @btime sum($A * $B * $C * $D)
  1.081 ms (24 allocations: 235.59 KiB)

Using current Zygote v0.6.10:

julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
  3.066 ms (222 allocations: 792.28 KiB)

With this PR (requires FluxML/ZygoteRules.jl#17):

julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
  2.192 ms (245 allocations: 562.34 KiB)

And after disabling forced unthunking of rrules input (step 3 in the plan above):

julia> @inline Zygote.wrap_chainrules_input(x) = x

julia> @btime Zygote.gradient(X -> sum($A * $B * $C * X), $D)
  1.413 ms (229 allocations: 416.22 KiB)

oschulz avatar May 14 '21 15:05 oschulz

Hey @DhairyaLGandhi , it seems all that's missing before we can move with this is your blessing. Would you have time to take a look?

oschulz avatar May 30 '21 18:05 oschulz

I'll take a look this week, thanks for this!

DhairyaLGandhi avatar May 30 '21 18:05 DhairyaLGandhi

@DhairyaLGandhi , gentle bump - I don't mean to nag, but I would love to start using the first stage of this for some projects.

oschulz avatar Jun 10 '21 15:06 oschulz

On it :)

Not nagging at all, I was trying to explore what it would entail for writing rules in Zygote and what it means for the ergonomics anyway.

DhairyaLGandhi avatar Jun 14 '21 05:06 DhairyaLGandhi

Thanks!

There should be no difference for users writing ZygoteRules rules. Starting with step 2 in the plan above, users writing ChainRulesCore rules (and I guess the idea is that this will be the preferred way in the future, @oxinabox?) will need to start using unthunk for the tangents their pullbacks receive, if the pullback uses functionality that isn't itself thunk-compatible. That should really be all, regarding ergonomics.

oschulz avatar Jun 14 '21 08:06 oschulz

@mzgubic is working as we speak on making sure all rules in ChainRules.jl support receiving thunks as inputs, mostly by adding missing overloads for common linear operators to ChainRulesCore, but probably in some cases via calling unthunk at the start. The intent is that when we release ChainRulesCore 1.0 and ChainRulesTestUtils 1.0, it will be a requirement of all rules.

oxinabox avatar Jun 14 '21 14:06 oxinabox

Will the requirement be to unthunk before anything in an rrule?

DhairyaLGandhi avatar Jun 18 '21 15:06 DhairyaLGandhi

Not necessarily, only in cases where the thunk would have been unthunked more than once. See https://github.com/JuliaDiff/ChainRules.jl/pull/449 (which is nearly there)

mzgubic avatar Jun 18 '21 15:06 mzgubic

Will the requirement be to unthunk before anything in an rrule? Not necessarily,

Also, quite a few rrules will not need to unthunk at all, as long as they're only calling on functions that have methods for thunks. From what I understand, there's an ongoing effort right now to significantly increase thunk-support in linear algebra, etc. So in some (maybe quite a few) cases we'll be able to pass thunks through. Correct, @oxinabox ?

oschulz avatar Jun 18 '21 15:06 oschulz

Yes, that's right, some thunks will pass through.

We will need to be careful that we do unthunk in cases where the thunk is used twice (to avoid the duplicate calculation). If the thunk is only used once, it will be passed through.
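The "used twice" case can be sketched as follows, again with a hypothetical LazyValue type and force function rather than the real ChainRulesCore Thunk and unthunk:

```julia
# Hypothetical stand-in for a thunk (no caching, like the thunks discussed here).
struct LazyValue{F}
    f::F
end
force(x::LazyValue) = x.f()
force(x) = x

evals = Ref(0)  # counts how often the deferred computation runs
make_dy() = LazyValue(() -> (evals[] += 1; [1.0, 2.0]))

# Naive pullback: uses dy twice, so a lazy dy is computed twice.
naive_pullback(dy) = (force(dy) .* 2, force(dy) .* 3)

# Careful pullback: forces once up front, then reuses the plain value.
function careful_pullback(dy)
    dy_val = force(dy)
    return (dy_val .* 2, dy_val .* 3)
end

evals[] = 0
naive_result = naive_pullback(make_dy())
naive_count = evals[]

evals[] = 0
careful_result = careful_pullback(make_dy())
careful_count = evals[]
```

Both variants return the same values; only the number of times the deferred work runs differs.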

As you say, there is increasing support for thunks in ChainRulesCore. All of those operations essentially unthunk (or complain that thunk mutation was attempted).

mzgubic avatar Jun 18 '21 15:06 mzgubic

I wonder if Julia shouldn't have something like a thunk as a native concept. We have something along those lines in so many places/packages.

oschulz avatar Jun 18 '21 15:06 oschulz

unthunk in cases where the thunk is used twice (to avoid the duplicate calculation)

Could we use some result caching / lazy value mechanism?

oschulz avatar Jun 18 '21 15:06 oschulz

Could we use some result caching / lazy value mechanism?

Might lead to even more mem allocs than we already have, though ...
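A caching variant could look roughly like the sketch below. CachedLazy and force! are hypothetical names; this is not how ChainRulesCore implements thunks. The mutable struct is what makes the allocation concern concrete, since a mutable object generally has to live on the heap.

```julia
# Hypothetical memoizing lazy value: computes at most once, then reuses the result.
mutable struct CachedLazy{F}
    f::F
    value::Any     # untyped slot for simplicity; costs type stability
    done::Bool
    CachedLazy(f::F) where {F} = new{F}(f, nothing, false)
end

function force!(x::CachedLazy)
    if !x.done
        x.value = x.f()
        x.done = true
    end
    return x.value
end

evals = Ref(0)
lazy = CachedLazy(() -> (evals[] += 1; [1.0, 2.0] .* 3))

first_use = force!(lazy)
second_use = force!(lazy)   # served from the cache, no recomputation
```

The second call returns the very same array object as the first, so callers must not mutate it unless that sharing is intended.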

oschulz avatar Jun 18 '21 15:06 oschulz

Yes, thunks are an emulation of a compiler feature Julia doesn't have. Haskell (and TensorFlow) have lazy evaluation where only things that are used are computed. Thunks are like 1 level of lazy evaluation.

And there is a more general concept of program slicing; program slicing would be particularly useful for AD.

oxinabox avatar Jun 18 '21 15:06 oxinabox

Yes, thunks are an emulation of a compiler feature julia doesn't have.

~~Hm, this got me thinking - thunks are really monads, right? [...]~~

Moved to JuliaDiff/ChainRulesCore.jl#373 .

oschulz avatar Jun 19 '21 10:06 oschulz

Could we use some result caching / lazy value mechanism?

Might lead to even more mem allocs than we already have, though ...

A very very ancient version of ChainRules had this. It was removed long before I was involved. The issue for adding it back is still open (https://github.com/JuliaDiff/ChainRulesCore.jl/issues/7#issuecomment-491609475). I do worry about the same thing re mem allocs here.

oxinabox avatar Jun 19 '21 22:06 oxinabox

Issue for adding it back is still open JuliaDiff/ChainRulesCore.jl#7 (comment). I do worry about the same thing re mem allocs here.

Thanks, I added a suggestion there.

oschulz avatar Jun 20 '21 12:06 oschulz

@DhairyaLGandhi do we get your blessing on the "plan"?

oschulz avatar Jun 24 '21 10:06 oschulz

Given having full support for thunks in all ChainRules functions is well and truly on its way, at this point the plan might be moot. And we can soon basically just merge something like this PR.

oxinabox avatar Jun 24 '21 11:06 oxinabox

at this point the plan might be moot. And we can soon basically just merge something like this PR.

More than fine with me - we'll still need to merge FluxML/ZygoteRules.jl#17 first, though - or duplicate that functionality within Zygote (would duplicate most of ZygoteRules.gradm).

oschulz avatar Jun 24 '21 11:06 oschulz

FluxML/ZygoteRules.jl#17 has been merged, I'll get this PR into shape as soon as there's a new release of ZygoteRules, so CI can run.

oschulz avatar Oct 01 '21 13:10 oschulz

Ok, new release of ZygoteRules is out (had overlooked it), I'll jump back on this.

oschulz avatar Oct 10 '21 17:10 oschulz

Sorry, bit overloaded but haven't forgotten about this.

oschulz avatar Oct 15 '21 20:10 oschulz

Rebasing this now, but several things have changed in Zygote in the mean time ... hope it still works.

oschulz avatar Oct 17 '21 09:10 oschulz

CI results don't look too horrible - at least some of the test failures look similar to the ones in #1104, so may be unrelated to this PR. Could an expert take a look?

oschulz avatar Oct 17 '21 19:10 oschulz

Might be worth wrapping all tests in a @testset, so that it keeps going past the first failure?

mcabbott avatar Oct 17 '21 22:10 mcabbott

I have one test failure that I can't figure out how to fix: the "Params nesting" testset results in a "map is not defined on dictionaries" error. It's caused by the unthunking in pullback(f, ps::Params) somehow, I think (it goes away without it, but the user-facing pullback function needs to unthunk, obviously), but I don't understand why ∇map hits a Dict with it but not without.

oschulz avatar Oct 18 '21 19:10 oschulz

I have one test failure that I can't figure out how to fix: the "Params nesting" testset results in a "map is not defined on dictionaries" error

Some expert help would be very much appreciated. I learned quite a few things about Zygote while preparing this PR, but I don't think I've reached a full understanding of the compiler and some other parts (like nested Params).

oschulz avatar Oct 19 '21 14:10 oschulz

I also do not understand nested Params, because I never used Params. I literally had to look them up in the docs https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1

I do recall @CarloLucibello made some changes to them recently. (Or maybe that was Flux.Grads ?)

oxinabox avatar Oct 19 '21 16:10 oxinabox

I think this is almost there, if we can only figure the nested Params test out.

oschulz avatar Nov 01 '21 18:11 oschulz

@oschulz do you mind rebasing this? I just tried some local testing and there appear to be new test failures.

Will do! Really need to bring this one home ...

oschulz avatar Mar 11 '22 09:03 oschulz

@ToucheSir rebase done.

oschulz avatar Mar 11 '22 10:03 oschulz

Keeping track of test failure causes as I find them:

  • https://github.com/FluxML/Zygote.jl/blob/4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e/src/lib/lib.jl#L324 is causing 3 test failures. I think the signature of (back::Jnew)(...) needs to be relaxed or another method added which handles thunks.
  • Similar story with the test above failing on https://github.com/FluxML/ZygoteRules.jl/blob/05cd6e1d41a363b2114fcce2a640df145006a5b7/src/adjoint.jl#L24.
  • I'm not exactly sure what's going on with the increased memory usage for https://github.com/FluxML/Zygote.jl/blob/4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e/test/features.jl#L668. Perhaps thunks are preventing some kind of copy elision?
  • The elements of dy need to be unthunked some time before, during or after https://github.com/FluxML/Zygote.jl/blob/4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e/src/lib/array.jl#L294 to fix this test. I list 3 options because I'm not sure whether it would be more appropriate for e.g. accum to handle this, the adjoint or something else.
  • Last but not least, https://github.com/FluxML/Zygote.jl/blob/4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e/test/interface.jl#L144. Here the broken behaviour actually makes more sense to me, and it's not clear from the blame why the line in question was added. @lassepe do you recall why?

ToucheSir avatar Mar 12 '22 17:03 ToucheSir

@ToucheSir that is probably not a useful test but I recall what happened there:

When you compute y, back = pullback(f, Params([x])), then evaluating back multiple times would mutate the Grads object handed out earlier. Say I compute g1 = back(1); then g1 is a Grads object which has g1[w] as the gradient of f w.r.t. w. If I now call g2 = back(nothing), this mutates g1 so that g1[w] == nothing.

I am not sure what the intended behavior actually is. It does not seem to be documented, at least, and it was a really bad footgun for me. I added this test to cover the "default case" and contrast it with the "fixed case", where I copy the Grads object in between to avoid this unexpected coupling. So I guess I mostly added it as implicit documentation. I am sorry that this has caused confusion here. Maybe I should just open a PR to the docs instead.
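A minimal sketch of the "fixed case", assuming `copy` is defined for the Grads object as the test relies on. The names `w` and `x1` are only illustrative, and whether the uncopied object is actually mutated depends on the Zygote version in use.

```julia
using Zygote

w, x1 = rand(2), rand(2)

# Pullback of a zero-argument closure with respect to the implicit parameter w.
_, back = pullback(() -> sum(w .* x1), Params([w]))

# Take a snapshot right away so a later call to `back` cannot affect it.
g1 = copy(back(1))

# Second call; without the copy above this could overwrite the entries of g1.
g2 = back(nothing)

# The snapshot still holds the gradient from the first call.
g1[w]
```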

lassepe avatar Mar 12 '22 19:03 lassepe

Thanks for the quick response! I think in that case we can call it a footgun and consider that fixed (unintentionally) by this PR, no need for a docs update.

ToucheSir avatar Mar 12 '22 19:03 ToucheSir

I guess this PR does not fix the problem in general, only for the nothing case, right? For example, if I call g1 = back(1) and then g2 = back(2) it seems like that would still mutate g1 since both unthunk, right? I'm not raising this as an issue against this PR (I think the PR is great and I don't think it is its job to fix this behavior), I'm just trying to figure out whether this would need documentation perhaps.

lassepe avatar Mar 13 '22 16:03 lassepe

Thanks for the changes! Let's see what CI says now ... :-)

oschulz avatar Mar 13 '22 18:03 oschulz

For example, if I call g1 = back(1) and then g2 = back(2) it seems like that would still mutate g1 since both unthunk, right?

It appears that may be addressed as well:

using Zygote

x, y = ones(2), rand(2)
ps = Params([x, y])

l, back = pullback(() -> sum(x) + prod(y), ps)

julia> g1 = back(1.)
Grads(...)

julia> g1.grads
IdDict{Any, Any} with 4 entries:
  :(Main.x)            => Fill(1.0, 2)
  :(Main.y)            => [0.981219, 0.920662]
  [0.920662, 0.981219] => [0.981219, 0.920662]
  [1.0, 1.0]           => Fill(1.0, 2)

julia> g2 = back(2.)
Grads(...)

julia> g1.grads
IdDict{Any, Any} with 4 entries:
  :(Main.x)            => Fill(1.0, 2)
  :(Main.y)            => [0.981219, 0.920662]
  [0.920662, 0.981219] => [0.981219, 0.920662]
  [1.0, 1.0]           => Fill(1.0, 2)

julia> g2.grads
IdDict{Any, Any} with 4 entries:
  :(Main.x)            => Fill(3.0, 2)
  :(Main.y)            => [2.94366, 2.76199]
  [0.920662, 0.981219] => [1.96244, 1.84132]
  [1.0, 1.0]           => Fill(2.0, 2)

ToucheSir avatar Mar 13 '22 19:03 ToucheSir

There are still a lot of new test failures (compared to pre-rebase). Of course Zygote has changed a bit, so help tracking them down is welcome.

oschulz avatar Mar 14 '22 08:03 oschulz

@oschulz that's what https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-1065931188 is for :)

ToucheSir avatar Mar 14 '22 14:03 ToucheSir

@oschulz that's what https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-1065931188 is for :)

Aw, sorry, overlooked that one - great, thanks!

oschulz avatar Mar 14 '22 14:03 oschulz

Is there a plan/timeline for this PR? Not pushing, it's just that this is quite problematic for some of our (PDE-based) codes, where it is a huge blocker performance-wise.

mloubout avatar Aug 01 '22 12:08 mloubout

I would still like this done. AFAIK there are no remaining blockers -- and haven't been for about 12 months.

@DhairyaLGandhi @oschulz should we have a call and work out what the way forward is with this?

oxinabox avatar Aug 01 '22 12:08 oxinabox

I would still like this done.

Me too.

AFAIK there are no remaining blockers

I don't think there are any blockers conceptually, just some test failures we hadn't managed to figure out.

should we have a call and work out what the way forward is with this?

Gladly!

oschulz avatar Aug 01 '22 17:08 oschulz

@DhairyaLGandhi ?

oxinabox avatar Aug 08 '22 18:08 oxinabox

Sure, sounds good

DhairyaLGandhi avatar Aug 08 '22 19:08 DhairyaLGandhi

Sure, sounds good

Let's find a date via Slack?

oschulz avatar Aug 08 '22 21:08 oschulz

Sure

DhairyaLGandhi avatar Aug 08 '22 22:08 DhairyaLGandhi

As the one who raised all these pesky failure cases in https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-1065931188, I'd be more than happy to help support this effort as well, be that testing or trying to make sense of AD compiler output.

ToucheSir avatar Aug 08 '22 23:08 ToucheSir