Zygote.jl
Remove redundant sum() rules
The pullback is non-differentiable, which messes with nested AD (https://github.com/FluxML/Zygote.jl/issues/1450). It's also not clear to me why this rule still exists when ChainRules has a seemingly GPU-compatible one. Let's see what CI says.
PR Checklist
- ~~Tests are added~~
- ~~Documentation, if applicable~~
I thought this existed in order to opt out of the Zygote rule for `sum`, which makes a FillArray.
julia> gradient(sum, [2.0, 3.0])
(Fill(1.0, 2),)
We could delete that too; it sometimes saves one copy, but that rarely matters in real code, and it causes problems.
Deleting that rule fixes all but one test set: https://github.com/FluxML/Zygote.jl/blob/612961353e4a81f9861fbca9db714e86f30ad0a3/test/lib/array.jl#L53. Not sure how best to fix it. Perhaps we could generalize https://github.com/FluxML/Zygote.jl/blob/612961353e4a81f9861fbca9db714e86f30ad0a3/src/lib/array.jl#L340-L342 to work on all `Integer`s and convert it to an `rrule(::ZygoteRuleConfig, ...)` for future-proofing at the same time?
We could certainly delete the rule for bool arrays, as there's one here:
https://github.com/JuliaDiff/ChainRules.jl/blob/ba52ec89ddd97a07e79cc35a9fa39019915d203b/src/rulesets/Base/nondiff.jl#L80
IDK what the issue with that Dict test is.
(Considering integers to be differentiable was a mistake, IMO, but it would be a breaking change to fix that, here or in CR.)
> IDK what the issue with that Dict test is.
The old rule was arguably wrong, because it was passing through the gradient for the summed value without doing any form of projection. If this were a scalar function, asking to differentiate wrt an integer argument would return a float gradient. So in my mind the test is actually capturing incorrect and inconsistent behaviour of the current rule. If we agree on that, I'll just tweak the test and we'll be back on green CI (minus known AbstractFFT failures).
Sorry I didn't look closely, but if the change is just that you now get a Dict of Floats rather than Ints, then that seems totally fine; we just adjust the test.
The one remaining failure:
sum, prod, cumsum: Test Failed at /var/lib/buildkite-agent/builds/gpuci-1/julialang/zygote-dot-jl/test/gradcheck.jl:117
Expression: gradient(sum, [true, false, true]) == (nothing,)
Evaluated: nothing == (nothing,)
Which comes from the `isnothing` ternary on https://github.com/FluxML/Zygote.jl/blob/e0d3d8b1a785ec291f0a41da3f12cad51d80eb6b/src/compiler/interface.jl#L98
@mcabbott do you recall why we're collapsing to `nothing` here? I can't recall how we're supposed to handle `nothing` vs `(nothing,)` vs `(nothing, ..., nothing)` when returned from the pullback.
My memory is that Zygote is eager to collapse any tuple of nothings to just `nothing`, but doesn't always manage to do so. I think at least `withgradient` and perhaps `gradient` try to restore them & always make a tuple. But I may have forgotten things.
It looks like `gradient` is not trying to make a tuple when it gets a single `nothing`. Should we make it do so? A version of this problem (more aggressive collapsing of zeros after moving to CR rules) is also causing the last two (non-unbreaking) test failures in https://github.com/FluxML/Zygote.jl/pull/1328, ref. https://github.com/FluxML/Zygote.jl/actions/runs/6117262926/job/16603631586?pr=1328#step:6:747.
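To make the question concrete, here is a rough sketch of what "always make a tuple" could look like. It is plain Julia and does not touch Zygote's internals: `restore_tuple` and `collapsed_pullback` are hypothetical names made up for illustration, and the argument count is passed in by hand on the assumption that the caller knows how many arguments were differentiated.

```julia
# Hypothetical helper, not part of Zygote: expand a collapsed `nothing`
# into one `nothing` per differentiated argument.
restore_tuple(grad::Nothing, nargs::Integer) = ntuple(_ -> nothing, nargs)

# A gradient that is already a tuple passes through unchanged.
restore_tuple(grad::Tuple, nargs::Integer) = grad

# Stand-in for a pullback whose result was collapsed to a bare `nothing`.
collapsed_pullback(dy) = nothing

# One differentiated argument, as in `gradient(sum, [true, false, true])`.
result = restore_tuple(collapsed_pullback(1.0), 1)
```

With something along these lines, the failing comparison against `(nothing,)` would see a one-element tuple instead of a bare `nothing`. Whether that is the behaviour we actually want from `gradient` is exactly the open question.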
Hi, is there any hope of merging this PR soon? Is there anything I can do to help with that?
Maybe, if we can get some consensus on the behaviour of `gradient` around collapsing zeros. See https://github.com/FluxML/Zygote.jl/pull/1466#issuecomment-1780649701. Once that's been established, the failing test here will either pass automatically or require just a one-line tweak to start passing.