Zygote.jl

Free CuArrays in the reverse pass

Open mcabbott opened this issue 2 years ago • 6 comments

This adds:

  • A flag to Context to indicate that the pullback will never be called twice -- set to true for gradient, false for jacobian
  • Modifications to many rules, esp. for broadcasting, so that y=f(x) in the forward pass has finalize(y) in the reverse. This increases the largest size of Flux model which can run on a given GPU.
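A minimal sketch of the two points above, assuming nothing about Zygote's internals: `SketchContext`, `release!` and `exp_with_pullback` are hypothetical names, and emptying a plain `Vector` stands in for freeing a `CuArray`. It only illustrates why the flag has to distinguish a gradient-like single call from a jacobian-like repeated call.

```julia
# Hypothetical stand-in for Zygote's Context; `Once` plays the role of the new flag.
struct SketchContext{Once} end
only_once(::SketchContext{Once}) where {Once} = Once

# Stand-in for freeing GPU memory: emptying a Vector makes the early release observable.
release!(y::Vector) = (empty!(y); nothing)

# Forward pass of y = exp.(x), returning y and a pullback that closes over y.
function exp_with_pullback(cx::SketchContext, x::Vector{Float64})
    y = exp.(x)
    function back(dy)
        dx = dy .* y                   # the pullback needs y exactly here
        only_once(cx) && release!(y)   # safe only if `back` is never called again
        return dx
    end
    return y, back
end

# Gradient-like use: one pullback call, after which y has been released.
y1, back1 = exp_with_pullback(SketchContext{true}(), [0.0, 1.0])
dx1 = back1([1.0, 1.0])

# Jacobian-like use: the pullback is called once per output, so y must survive.
y2, back2 = exp_with_pullback(SketchContext{false}(), [0.0, 1.0])
row1 = back2([1.0, 0.0])
row2 = back2([0.0, 1.0])
```

With the flag set to `true`, a second call to `back1` would see an empty `y`, which is the failure mode the flag is meant to rule out.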

Applying such modifications everywhere led to many errors, some from rules like y = x .+ false, which returns y === x under Zygote, so that finalising y also destroys x. The modifications are therefore opt-in, via a separate macro @adjoint_final.
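The aliasing hazard can be reproduced without Zygote or a GPU. This is a sketch only: `add_false_forward` is a hypothetical rule that mimics the shortcut described above by returning its input, and `release!` is a stand-in for finalize or CUDA.unsafe_free!.

```julia
# Hypothetical rule with the shortcut described above: adding `false` changes nothing,
# so the forward pass hands back its input and y === x.
add_false_forward(x::Vector{Float64}) = x

# Stand-in for finalize / unsafe_free!: emptying the Vector makes the effect visible.
release!(y::Vector) = (empty!(y); nothing)

x = [1.0, 2.0, 3.0]
y = add_false_forward(x)
aliased = y === x          # no new array was allocated
release!(y)                # "finalising" y in the reverse pass ...
x_destroyed = isempty(x)   # ... also destroys x, which other pullbacks may still need
```

Because a rule cannot know in general whether its output aliases its input, opting in rule by rule is the conservative choice.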

At present this modification is applied to all CR rrules. This is probably unsafe and we should revert 2524163c8a5bd4aab9101c964d6d8d0676b501e8 . Unclear how best to opt-in within ChainRules. Xref https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 about the idea of a flag, but not entirely sure that's the right approach.

Explicit finalising won't work well with thunks. That doesn't matter yet, but it might after #966.

It also does not work with second derivatives, so it is disabled for them. Other uses of the context flag (like testing only_once(cfg) and then overwriting some array) probably also need to be disabled in that case.

Needs https://github.com/FluxML/ZygoteRules.jl/pull/23 so CI will fail. Locally, one failure, one failure to fail:

Global Params: Error During Test at /Users/me/.julia/dev/Zygote/test/features.jl:399
  Got exception outside of a @test
  KeyError: key :(Main.global_param) not found
  Stacktrace:
    [1] getindex(d::IdDict{Any, Any}, key::Any)
      @ Base ./iddict.jl:108
    [2] macro expansion
      @ ~/.julia/dev/Zygote/test/features.jl:404 [inlined]

Compiler: Error During Test at /Users/me/.julia/dev/Zygote/test/compiler.jl:35
 Unexpected Pass
 Expression: trace_contains(bt, :badly, "compiler.jl", 24)
 Got correct result, please change to @test if no longer broken.

mcabbott, Dec 18 '22 21:12

Could we combine this with https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 ?

chengchingwen, Dec 19 '22 09:12

It's possible. I think that means having two distinct structs, ZygoteRuleConfig and ZygoteOnceRuleConfig or something.

At present, BTW, most of these maybe_final calls seem not to be reached, and I'm not sure why.

mcabbott, Dec 19 '22 14:12

> It's possible. I think that means having two distinct structs, ZygoteRuleConfig and ZygoteOnceRuleConfig or something.

Or introduce another type parameter like ZygoteRuleConfig{once} where once?

> At present, BTW, most of these maybe_finals seem not to be called & I'm not sure why.

Do you mean the finalize is not called, or it is called but the memory is not freed?

chengchingwen, Dec 19 '22 14:12

> Do you mean the finalize is not called

With something like this

Zygote.maybe_final(x::CuArray) = begin CNT[]+=1; CUDA.unsafe_free!(x); nothing end

a big ResNet gradient used to give me CNT[] == 3 afterwards. (I thought I had this working when I opened the PR.) This is now fixed in 9f01eff.

> type parameter like ZygoteRuleConfig{once} where once

But I don't think that fits CR's mechanism; the current struct is <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} and the new ones would need different supertypes.

We could also think about changing it to <: RuleConfig{Union{HasReverseMode,NoForwardsMode}, true}, in which case matching Context{..., true} would be easy.

mcabbott, Dec 19 '22 14:12

> But I don't think that fits CR's mechanism; the current struct is <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} and the new ones would need different supertypes.

Couldn't it be done like struct ZygoteRuleConfig{P<:PullbackCapability} <: RuleConfig{Union{HasReverseMode,NoForwardsMode,P}}?
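A self-contained sketch of that idea, using stand-in types rather than ChainRulesCore itself. Every `Sketch...` name and `may_free` are hypothetical, and the once/many capability types are assumed here only for illustration. It shows that a type parameter placed inside the Union of the supertype can be picked up by dispatch with a `>:` bound.

```julia
# Stand-ins for ChainRulesCore's RuleConfig and its feature types (hypothetical names).
abstract type SketchRuleConfig{T} end
struct SketchHasReverseMode end
struct SketchNoForwardsMode end

# Assumed capability types describing how often a pullback may be called.
abstract type SketchPullbackCapability end
struct SketchOncePullback <: SketchPullbackCapability end
struct SketchManyPullback <: SketchPullbackCapability end

# The capability P becomes one more member of the Union in the supertype.
struct SketchZygoteConfig{P<:SketchPullbackCapability} <:
       SketchRuleConfig{Union{SketchHasReverseMode,SketchNoForwardsMode,P}} end

# Rules can opt in to freeing only when the config advertises a once-only pullback.
may_free(::SketchRuleConfig) = false
may_free(::SketchRuleConfig{>:SketchOncePullback}) = true

once_cfg = SketchZygoteConfig{SketchOncePullback}()
many_cfg = SketchZygoteConfig{SketchManyPullback}()
```

Here `may_free(once_cfg)` selects the more specific method and `may_free(many_cfg)` falls back to the generic one, so a single struct can cover both cases.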

chengchingwen, Dec 20 '22 12:12

Oh right, that ought to work.

Current status is that some arrays are freed too early (e.g. with Metalhead's ResNet, at addact(relu)), but the cause is hard to isolate. It still happens if I disable all thunks. In Zygote's tests, there are some failures due to a too-early fill!(x, NaN) (included here as a test), which are perhaps related.

mcabbott, Dec 20 '22 14:12