Zygote.jl
Free CuArrays in the reverse pass
This adds:
- A flag to `Context` to indicate that the pullback will never be called twice: set to true for `gradient`, false for `jacobian`.
- Modifications to many rules, esp. for broadcasting, so that `y = f(x)` in the forward pass has `finalize(y)` in the reverse. This increases the largest size of Flux model which can run on a given GPU.
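
A minimal CPU-only sketch of the idea, not the code in this PR. `Buffer` stands in for a GPU array, `Ctx` for the context carrying the flag, and this `maybe_final` signature is an assumption made for illustration:

```julia
# Hypothetical stand-in for a GPU array: `freed` records whether it was released.
mutable struct Buffer
    data::Vector{Float64}
    freed::Bool
end
Buffer(data) = Buffer(data, false)

# Hypothetical context carrying the "pullback runs at most once" flag.
struct Ctx
    once::Bool
end

# Release `y` only when the context promises a single reverse pass.
function maybe_final(ctx::Ctx, y::Buffer)
    if ctx.once
        empty!(y.data)
        y.freed = true
    end
    return nothing
end

# Forward pass of y = 2 .* x, returning y together with its pullback.
function double(ctx::Ctx, x::Buffer)
    y = Buffer(2 .* x.data)
    function back(dy::Vector{Float64})
        dx = 2 .* dy          # this gradient does not need y
        maybe_final(ctx, y)   # so y can be released once we get here
        return dx
    end
    return y, back
end

x = Buffer([1.0, 2.0, 3.0])

# gradient-like use: the pullback is called once, so y is released.
y_once, back_once = double(Ctx(true), x)
dx_once = back_once(ones(3))

# jacobian-like use: the pullback may be called again, so y is kept.
y_many, back_many = double(Ctx(false), x)
dx_many = back_many(ones(3))
```

By the time `back` runs, every later pullback that could have read `y` has already run, which is why releasing it there is plausible when the pullback is only called once.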
Applying such modifications everywhere led to many errors, some from rules like `y = x .+ false` which return `y === x` under Zygote. So they now require a separate macro `@adjoint_final`.
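
To illustrate the aliasing hazard in plain Julia: `add_false` below is a hypothetical rule that mimics the behaviour described above by returning its input unchanged. The identity guard in `free_output!` is only an illustration of why blind freeing is unsafe; the PR itself uses the opt-in macro instead.

```julia
# Hypothetical rule mimicking the behaviour described above: no copy is made.
add_false(x::Vector{Float64}) = x          # y === x
add_one(x::Vector{Float64}) = x .+ 1.0     # allocates a fresh array

# "Free" (here: empty) the output unless it is the very same array as the input.
function free_output!(y::Vector{Float64}, x::Vector{Float64})
    y === x && return false   # freeing y would also destroy x
    empty!(y)
    return true
end

x = [1.0, 2.0]

y = add_false(x)
aliased = y === x
freed_y = free_output!(y, x)   # skipped, so x survives

z = add_one(x)
freed_z = free_output!(z, x)   # z is independent of x and can go
```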
At present this modification is applied to all CR `rrule`s. This is probably unsafe and we should revert 2524163c8a5bd4aab9101c964d6d8d0676b501e8. Unclear how best to opt in within ChainRules. Xref https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 about the idea of a flag, but not entirely sure that's the right approach.
Explicit finalising won't work well with thunks, which doesn't matter at all yet, but might after #966.
It also does not work with second derivatives, hence is disabled. Other uses of the context flag (like testing `only_once(cfg)` & then over-writing some array) probably also need to be disabled.
Needs https://github.com/FluxML/ZygoteRules.jl/pull/23 so CI will fail. Locally, one failure, one failure to fail:
Global Params: Error During Test at /Users/me/.julia/dev/Zygote/test/features.jl:399
Got exception outside of a @test
KeyError: key :(Main.global_param) not found
Stacktrace:
[1] getindex(d::IdDict{Any, Any}, key::Any)
@ Base ./iddict.jl:108
[2] macro expansion
@ ~/.julia/dev/Zygote/test/features.jl:404 [inlined]
Compiler: Error During Test at /Users/me/.julia/dev/Zygote/test/compiler.jl:35
Unexpected Pass
Expression: trace_contains(bt, :badly, "compiler.jl", 24)
Got correct result, please change to @test if no longer broken.
Could we combine this with https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592 ?
It's possible. I think that means having two distinct structs, ZygoteRuleConfig and ZygoteOnceRuleConfig or something.
At present, BTW, most of these `maybe_final`s seem not to be called & I'm not sure why.
> It's possible. I think that means having two distinct structs, ZygoteRuleConfig and ZygoteOnceRuleConfig or something.

Or introduce another type parameter like `ZygoteRuleConfig{once} where once`?
> At present, BTW, most of these maybe_finals seem not to be called & I'm not sure why.

Do you mean the `finalize` is not called, or it is called but the memory is not freed?
> Do you mean the finalize is not called

With something like this
Zygote.maybe_final(x::CuArray) = begin CNT[]+=1; CUDA.unsafe_free!(x); nothing end
a big ResNet gradient [used to] ~~gives me CNT[] == 3 afterwards. (Thought I had this working when I opened it...)~~ [fixed in 9f01eff]
> type parameter like `ZygoteRuleConfig{once} where once`

But I don't think that fits CR's mechanism; the current struct is `<: RuleConfig{Union{HasReverseMode,NoForwardsMode}}` and the new ones would need different supertypes.
We could also think about changing it to `<: RuleConfig{Union{HasReverseMode,NoForwardsMode}, true}`, in which case matching `Context{..., true}` would be easy.
> But I don't think that fits CR's mechanism; the current struct is <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} and the new ones would need different supertypes.

Couldn't it be done like `struct ZygoteRuleConfig{P<:PullbackCapability} <: RuleConfig{Union{HasReverseMode,NoForwardsMode,P}}`?
Oh right, that ought to work.
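
A self-contained sketch of how that dispatch could look. It deliberately does not load ChainRulesCore: every type below is a local stand-in, `PullbackCapability` is the name from the proposal above, and `OncePullback`, `ManyPullback` and this `only_once` definition are hypothetical.

```julia
# Local stand-ins so the sketch runs on its own; not the ChainRulesCore definitions.
abstract type RuleConfig{T} end
struct HasReverseMode end
struct NoForwardsMode end

# Name taken from the proposal; the two subtypes are hypothetical.
abstract type PullbackCapability end
struct OncePullback <: PullbackCapability end   # pullback called at most once
struct ManyPullback <: PullbackCapability end   # pullback may be called repeatedly

# The extra capability is folded into the Union, as suggested above.
struct ZygoteRuleConfig{P<:PullbackCapability} <: RuleConfig{Union{HasReverseMode,NoForwardsMode,P}} end

# Rules can then branch on the capability through ordinary dispatch.
only_once(::RuleConfig) = false
only_once(::RuleConfig{>:OncePullback}) = true

once_cfg = ZygoteRuleConfig{OncePullback}()
many_cfg = ZygoteRuleConfig{ManyPullback}()
```

Existing rules that dispatch on `RuleConfig{>:HasReverseMode}` would still match both configs in this sketch, since the Union only grows.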
Current status is that some arrays are freed too early (e.g. with Metalhead's ResNet, at `addact(relu)`) but it's hard to isolate. Still happens if I disable all thunks. In Zygote's tests, some failures due to too-early `fill!(x, NaN)` (included here as a test), perhaps related.
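
A small sketch of why filling with NaN is a useful probe for too-early freeing, assuming a hand-written pullback rather than anything from Zygote. `poison!` and `square` are hypothetical names:

```julia
# Debugging stand-in for freeing: overwrite the array so any later read is obvious.
poison!(x::AbstractArray{<:AbstractFloat}) = fill!(x, NaN)

# y = x .^ 2, whose pullback needs x to still be intact.
function square(x::Vector{Float64})
    y = x .^ 2
    back(dy) = 2 .* x .* dy
    return y, back
end

x = [1.0, 2.0, 3.0]
y, back = square(x)

good = back(ones(3))   # x is still valid here

poison!(x)             # simulate x being released before the pullback ran
bad = back(ones(3))    # the NaNs propagate into the gradient
```

Unlike an actual free, the poisoned array stays readable, so the mistake shows up as NaN gradients rather than a crash.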