Zygote.jl

Remove redundant sum() rules

Open. ToucheSir opened this issue 1 year ago (10 comments).

The pullback of this rule is not itself differentiable, which interferes with nested AD (https://github.com/FluxML/Zygote.jl/issues/1450). It's also not clear to me why this rule still exists when ChainRules has a seemingly GPU-compatible one. Let's see what CI says.

PR Checklist

  • ~~Tests are added~~
  • ~~Documentation, if applicable~~

ToucheSir, Sep 01 '23

I thought this existed in order to opt out of Zygote's own rule for sum, which returns a FillArray:

julia> using Zygote

julia> gradient(sum, [2.0, 3.0])
(Fill(1.0, 2),)

We could delete that too: it sometimes saves one copy, but that rarely matters in real code, and it causes problems.
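For comparison, a pullback for sum that materialises a dense gradient simply repeats the incoming cotangent dy over the shape of x, since every element contributes to the sum with weight 1. This is a minimal sketch: sum_pullback_dense is a hypothetical name, not a Zygote or ChainRules function. A FillArrays Fill(dy, size(x)) holds the same values lazily, and avoiding that allocation is the copy the Zygote rule saves.

```julia
# Hypothetical dense pullback for y = sum(x).
# d(sum(x))/dx[i] == 1 for every i, so the cotangent for x is dy
# repeated over x's shape (here as an ordinary, allocated Array).
sum_pullback_dense(x::AbstractArray, dy::Number) = fill(dy, size(x))

# Same values as Zygote's Fill(1.0, 2), but as a plain Vector{Float64}.
g = sum_pullback_dense([2.0, 3.0], 1.0)
```

Since Fill(1.0, 2) == [1.0, 1.0] compares element-wise, the two representations agree in value and differ only in storage.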

mcabbott, Sep 02 '23

Deleting that rule fixes all but one test suite; the remaining failure is at https://github.com/FluxML/Zygote.jl/blob/612961353e4a81f9861fbca9db714e86f30ad0a3/test/lib/array.jl#L53. I'm not sure how best to fix it. Perhaps we could generalize https://github.com/FluxML/Zygote.jl/blob/612961353e4a81f9861fbca9db714e86f30ad0a3/src/lib/array.jl#L340-L342 to work on all Integers, and at the same time convert it to an rrule(::ZygoteRuleConfig, ...) for future-proofing?

ToucheSir, Sep 05 '23

We could certainly delete the rule for Bool arrays, since ChainRules already has one here:

https://github.com/JuliaDiff/ChainRules.jl/blob/ba52ec89ddd97a07e79cc35a9fa39019915d203b/src/rulesets/Base/nondiff.jl#L80

IDK what the issue with that Dict test is.

(Treating integers as differentiable was a mistake, IMO, but fixing that would be a breaking change, here or in CR.)

mcabbott, Sep 05 '23

> IDK what the issue with that Dict test is.

The old rule was arguably wrong, because it passed the gradient of the summed value straight through without any projection. If this were a scalar function, differentiating with respect to an integer argument would return a float gradient. So in my mind the test is actually capturing incorrect and inconsistent behaviour of the current rule. If we agree on that, I'll just tweak the test and we'll be back to green CI (minus the known AbstractFFT failures).
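To make the argument concrete, here is a toy projection step. project_grad is hypothetical and much simpler than ChainRulesCore's ProjectTo; it only illustrates the convention being argued for: a gradient with respect to an integer-valued argument comes back as a float, so a Dict of Ints yields a Dict of Floats.

```julia
# Hypothetical, simplified projection (NOT ChainRulesCore.ProjectTo):
# map a raw cotangent back onto the primal's structure.

# Any real primal, including an Integer, gets a floating-point cotangent,
# matching what a scalar function would return.
project_grad(::Real, dx::Real) = float(dx)

# Arrays: project element by element.
project_grad(x::AbstractArray, dx::AbstractArray) = map(project_grad, x, dx)

# Dicts: project each value against the primal value under the same key.
project_grad(x::AbstractDict, dx::AbstractDict) =
    Dict(k => project_grad(x[k], v) for (k, v) in dx)

d = project_grad(Dict(:a => 1, :b => 2), Dict(:a => 1, :b => 1))
```

Passing the summed gradient through unprojected corresponds to skipping this step, which is how the old rule could return a Dict of Ints.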

ToucheSir, Sep 05 '23

Sorry, I didn't look closely, but if the only change is that you now get a Dict of Floats instead of Ints, that seems totally fine; we just adjust the test.

mcabbott, Sep 05 '23

The one remaining failure:

sum, prod, cumsum: Test Failed at /var/lib/buildkite-agent/builds/gpuci-1/julialang/zygote-dot-jl/test/gradcheck.jl:117
  Expression: gradient(sum, [true, false, true]) == (nothing,)
   Evaluated: nothing == (nothing,)

This comes from the isnothing ternary at https://github.com/FluxML/Zygote.jl/blob/e0d3d8b1a785ec291f0a41da3f12cad51d80eb6b/src/compiler/interface.jl#L98

@mcabbott do you recall why we're collapsing to nothing here? I can't recall how we're supposed to handle nothing vs (nothing,) vs (nothing, ..., nothing) when returned from the pullback.

ToucheSir, Sep 08 '23

My memory is that Zygote is eager to collapse any tuple of nothings to just nothing, but doesn't always manage to do so. I think at least withgradient, and perhaps gradient, try to restore them and always return a tuple. But I may have forgotten things.

mcabbott, Sep 08 '23

It looks like gradient does not try to build a tuple when it gets a single nothing. Should we make it do so? A version of this problem (more aggressive collapsing of zeros after moving to CR rules) is also causing the last two (non-unbreaking) test failures in https://github.com/FluxML/Zygote.jl/pull/1328, ref. https://github.com/FluxML/Zygote.jl/actions/runs/6117262926/job/16603631586?pr=1328#step:6:747.
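One possible shape for that fix, sketched under assumptions: after running the pullback, expand a bare nothing into one nothing per argument. restore_grad_tuple and its nargs parameter are hypothetical names for illustration, not part of Zygote's interface.jl.

```julia
# Hypothetical helper: if the pullback collapsed every gradient into a
# single `nothing`, expand it back to one `nothing` per argument so the
# result always has the same length as the argument list.
function restore_grad_tuple(grads, nargs::Integer)
    grads === nothing && return ntuple(_ -> nothing, nargs)
    return grads  # already a tuple (possibly containing nothings): pass through
end

restore_grad_tuple(nothing, 1)  # one argument, as in gradient(sum, [true, false, true])
```

With a normalisation like this, gradient(sum, [true, false, true]) would presumably return (nothing,), which is what the failing test expects.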

ToucheSir, Sep 08 '23

Hi, is there any hope of merging this PR soon? Is there anything I can do to help?

FerreolS, Nov 30 '23

Maybe, if we can get some consensus on how gradient should behave when collapsing zeros. See https://github.com/FluxML/Zygote.jl/pull/1466#issuecomment-1780649701. Once that's settled, the failing test here will either pass automatically or need just a one-line tweak.

ToucheSir, Dec 01 '23