
Attach rule to `mapfoldl_impl` not `foldl`

mcabbott opened this issue 3 years ago • 10 comments

Closes #567, perhaps in the minimal way, by attaching these rules to the internal functions, which take positional arguments. The gradient for init is just @not_implemented for now.

One nice effect is that I think foldr may work too.

One weird effect is that accumulate!(f, y, x) will work, silently overwriting y. It does return a NotImplemented, which may help. Xref https://github.com/JuliaDiff/ChainRules.jl/pull/521

Non-vector shapes like accumulate(f, ::Matrix) take a different path, via Iterators.accumulate, and will miss the rule. So will accumulate(f, ::Tuple). Maybe for that case Base's code is OK.
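For reference, the bookkeeping such a rule has to do is not complicated. A hand-written toy sketch (not the PR's actual rule; dop is an assumed helper returning the partials of op(a, b), where the real rule would call back into the AD):

```julia
# Toy sketch of a pullback for accumulate, where y[i] = op(y[i-1], x[i]).
# `dop(a, b)` is an assumed helper returning (∂op/∂a, ∂op/∂b).
function accumulate_pullback(op, dop, x::AbstractVector)
    n = length(x)
    y = similar(x)
    y[1] = x[1]
    for i in 2:n
        y[i] = op(y[i-1], x[i])
    end
    function back(ȳ)
        x̄ = similar(x)
        ā = zero(eltype(x))          # cotangent carried back from later steps
        for i in n:-1:2
            t = ȳ[i] + ā             # total cotangent reaching y[i]
            ∂a, ∂b = dop(y[i-1], x[i])
            x̄[i] = t * ∂b
            ā = t * ∂a
        end
        x̄[1] = ȳ[1] + ā
        return x̄
    end
    return y, back
end
```

For example, with op = * and dop(a, b) = (b, a), accumulate_pullback(*, (a, b) -> (b, a), [2.0, 3.0, 4.0]) gives y == [2.0, 6.0, 24.0], and back(ones(3)) == [16.0, 10.0, 6.0], matching the gradient of sum(cumprod(x)).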


Closes https://github.com/JuliaDiff/ChainRules.jl/issues/672 . Probably closes https://github.com/FluxML/Zygote.jl/issues/1297

mcabbott avatar Jan 15 '22 14:01 mcabbott

What are the TODOs that would make this no longer [draft]?

oxinabox avatar Feb 15 '22 18:02 oxinabox

I have forgotten. But one was to decide how unhappy we are about this:

accumulate!(f, y, x) will work, silently overwriting y.

And, more generally, how unhappy we are about hooking deep inside Base. I didn't see a nicer way to hook onto accumulate.

mcabbott avatar Feb 15 '22 18:02 mcabbott

Following up on this, could the accumulate! behaviour be worked around by adding an explicit @opt_out for that function?

ToucheSir avatar Jun 21 '22 02:06 ToucheSir

fff84b5 separates out the Tuple method, for which this way of writing the rule makes more sense, and handles init by a separate rule.

It is, strangely, much slower than the tagged version inside AD.

And it breaks Yota.

The functions between foldl and mapfoldl_impl are these; shouldn't they be easy for AD to work through?

foldl(op, itr; kw...) = mapfoldl(identity, op, itr; kw...)
mapfoldl(f, op, itr; init=_InitialValue()) = mapfoldl_impl(f, op, init, itr)
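These are thin wrappers, so the rule on mapfoldl_impl sees everything foldl sees. And the reverse pass itself is only simple bookkeeping; a toy sketch (hand-written, not the PR's rule; dop is an assumed helper returning the partials of op, where the real rule would ask the AD):

```julia
# Toy sketch of the left-fold pullback: save each intermediate accumulator
# on the forward pass, then replay in reverse. `dop(a, b)` is an assumed
# helper returning (∂op/∂a, ∂op/∂b).
function fold_pullback(op, dop, x::AbstractVector)
    n = length(x)
    accs = similar(x)                # accs[i] = foldl(op, x[1:i])
    accs[1] = x[1]
    for i in 2:n
        accs[i] = op(accs[i-1], x[i])
    end
    function back(ȳ)
        x̄ = similar(x)
        ā = ȳ                        # cotangent of the running accumulator
        for i in n:-1:2
            ∂a, ∂b = dop(accs[i-1], x[i])
            x̄[i] = ā * ∂b
            ā = ā * ∂a
        end
        x̄[1] = ā
        return x̄
    end
    return accs[n], back
end
```

For op = / with dop(a, b) = (1/b, -a/b^2), fold_pullback(/, dop, [1.0, 2.0, 4.0]) returns 0.125, and back(1.0) == [0.125, -0.0625, -0.03125].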
julia> xn = Tuple(randn(10));

julia> xm = Tuple(rand(10,10) for _ in 1:10);

# Zygote

julia> @btime Zygote.gradient(x -> foldl(/, x), $xn);
  min 47.318 ns, mean 77.223 ns (1 allocation, 16 bytes)  # before
  min 4.381 μs, mean 4.692 μs (37 allocations, 2.16 KiB)  # after -- 100x slower

julia> @btime Zygote.gradient(x -> sum(abs2, foldl(*, x)), $xm);
  min 17.667 μs, mean 77.964 μs (53 allocations, 29.52 KiB)  # before
  min 19.708 μs, mean 23.791 μs (69 allocations, 27.48 KiB)  # after

julia> @btime Zygote.gradient(x -> Base.afoldl(/, x...), $xn);  # no rule -- much slower
  min 130.500 μs, mean 135.423 μs (413 allocations, 16.33 KiB)

julia> @btime Zygote.gradient(x -> sum(abs2, Base.afoldl(*, x...)), $xm);
  min 143.500 μs, mean 151.413 μs (384 allocations, 40.30 KiB)

# Diffractor

julia> @btime Diffractor.gradient(x -> foldl(/, x), $xn);
  min 29.271 ns, mean 30.017 ns (0 allocations)              # before
  min 350.632 ns, mean 400.959 ns (6 allocations, 672 bytes) # after -- 10x slower

julia> @btime Diffractor.gradient(x -> sum(abs2, foldl(*, x)), $xm);
  min 13.666 μs, mean 16.422 μs (29 allocations, 25.38 KiB)      # before
  min 162.584 μs, mean 218.275 μs (357 allocations, 168.42 KiB); # after

julia> @btime Diffractor.gradient(x -> Base.afoldl(/, x...), $xn);  # no rule -- better than Zygote
  min 352.882 ns, mean 419.163 ns (6 allocations, 672 bytes)

julia> @btime Diffractor.gradient(x -> sum(abs2, Base.afoldl(/, x...)), $xm)
  min 163.125 μs, mean 204.721 μs (357 allocations, 168.42 KiB)

# Yota

julia> @btime Yota.grad(x -> foldl(/, x), $xn);
  min 182.790 ns, mean 657.142 ns (3 allocations, 208 bytes)  # before
ERROR: No deriative rule found for op %3 = foldl(/, %2)::Float64, try defining it... # after -- fails

julia> @btime Yota.grad(x -> sum(abs2, foldl(*, x)), $xm);
  min 8.583 μs, mean 50.186 μs (21 allocations, 16.19 KiB)

julia> Yota.grad(x -> Base.afoldl(/, x...), xn);
ERROR: syntax: Slot objects should not occur in an AST

# Checking pieces?

julia> yyy = Yota.YotaRuleConfig()

julia> @code_warntype rrule(yyy, foldl, /, xn)  # before
julia> @code_warntype rrule(yyy, foldl, /, xn)[2](1.0)

julia> @code_warntype rrule(yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), xn)  # after
julia> @code_warntype rrule(yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), xn)[2](1.0)

julia> @btime rrule($yyy, foldl, /, $xn)[2](1.0);
  min 29.271 ns, mean 30.036 ns (0 allocations)

julia> @btime rrule($yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), $xn)[2](1.0);
  min 29.271 ns, mean 29.753 ns (0 allocations)

could the accumulate! behaviour be worked around by adding an explicit @opt_out for that function?

I don't see how. I think you can only opt out of functions which have rules, and those ones need to be called to work.

~~This needs https://github.com/JuliaDiff/ChainRulesCore.jl/pull/567~~ now merged.

9af7a64 also adds mapfoldl, as map then foldl, for tuples.

Maybe also worth noting: moving the rule to _accumulate! means no such rule for tuples. But @less accumulate(/, (1,2,3)) shows this is pretty simple, and calls Base.afoldl. Perhaps the tuple foldl rule should be applied to Base.afoldl too (or vice versa).
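For illustration, Base's tuple method is roughly equivalent to this one-liner (a simplified stand-in, not Base's actual code, which builds each prefix with afoldl):

```julia
# Simplified stand-in for Base's tuple accumulate: each element of the
# result is a left fold of a prefix of the input tuple.
tuple_accumulate(op, t::Tuple) = ntuple(i -> foldl(op, t[1:i]), length(t))
```

So tuple_accumulate(/, (1.0, 2.0, 4.0)) == (1.0, 0.5, 0.125), agreeing with accumulate(/, (1.0, 2.0, 4.0)).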

mcabbott avatar Jul 19 '22 02:07 mcabbott

Trying a bit to track this down today: I think the slowdown is just some quirk of Zygote's handling of keywords, so it's not the rule's fault, and anything which fixes the init problem will probably hit it. Diffractor no longer shows the slowdown seen above:

using Diffractor, ChainRulesCore
ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing
# Solves same error as   https://github.com/JuliaDiff/ChainRulesCore.jl/pull/503
xn = Tuple(randn(10));

@btime Diffractor.gradient(x -> foldl(/, x), $xn);
#   min 29.313 ns, mean 29.545 ns (0 allocations)  before (old rule on foldl)
#   min 29.313 ns, mean 29.522 ns (0 allocations)  after (new rule on Base.mapfoldl_impl)

@btime Diffractor.gradient(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
#   min 47.625 μs, mean 53.596 μs (569 allocations, 33.16 KiB)  before -- i.e. with no rule, just Base, NB μs
_foldl(op::G, itr; kw...) where {G} = _mapfoldl(identity, op, itr; kw...)
_mapfoldl(f::F, op::G, itr; init=Base._InitialValue()) where {F,G} = Base.mapfoldl_impl(f, op, init, itr)
@btime Diffractor.gradient(x -> _foldl(/, x), $xn);
#   min 56.542 μs, mean 62.279 μs (672 allocations, 38.78 KiB)  before -- i.e. with no rule, just Base, NB μs

import Zygote
@btime Zygote.gradient(x -> foldl(/, x), $xn);
#   min 47.402 ns, mean 48.592 ns (1 allocation, 16 bytes)  before
#   min 4.482 μs, mean 9.120 μs (37 allocations, 2.16 KiB)  after -- this I didn't like, above
# Same with Zygote#master, thus including https://github.com/FluxML/Zygote.jl/pull/1286

@btime Zygote.gradient(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
#   min 152.667 μs, mean 157.707 μs (494 allocations, 26.44 KiB)  before -- i.e. using no rule, just Base, NB μs
#   min 47.402 ns, mean 82.826 ns (1 allocation, 16 bytes)  after -- so the issue is Zygote & keywords

using Yota
@btime Yota.grad(x -> foldl(/, x), $xn);
#   min 235.140 ns, mean 251.834 ns (3 allocations, 208 bytes)  before
# error afterwards, doesn't track further?
ChainRulesCore.@non_differentiable Base._InitialValue()
@btime Yota.grad(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
#  min 231.805 ns, mean 250.267 ns (3 allocations, 208 bytes)  after

So I think we should merge this, if tests pass etc.

mcabbott avatar Aug 19 '22 05:08 mcabbott

Zygote tries to diff through the kwsorter definition (i.e. https://docs.julialang.org/en/v1/devdocs/functions/#Keyword-arguments), which includes control flow. It's very difficult to make this type stable because it requires saving a different set of pullbacks for each branch (does anybody know how Diffractor does this?), but https://github.com/FluxML/Zygote.jl/pull/1195 might help with runtime overhead.
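For anyone unfamiliar, a hand-written analogue of the kwsorter looks roughly like this (all names here are made up; the real one is compiler-generated). The branch choosing a default for init is the control flow in question:

```julia
# Hand-written analogue of the compiler-generated keyword sorter for
# mapfoldl. All names are stand-ins; the point is the branch on whether
# `init` was supplied, which the AD must differentiate through.
struct FakeInitialValue end   # stand-in for Base._InitialValue

function mapfoldl_kwsorter(kw::NamedTuple, f, op, itr)
    init = haskey(kw, :init) ? kw.init : FakeInitialValue()  # the branch
    return mapfoldl_impl_stub(f, op, init, itr)
end

# Positional-argument core, analogous to where this PR attaches the rule:
mapfoldl_impl_stub(f, op, ::FakeInitialValue, itr) = foldl(op, Iterators.map(f, itr))
mapfoldl_impl_stub(f, op, init, itr) = foldl(op, Iterators.map(f, itr); init = init)
```

E.g. mapfoldl_kwsorter(NamedTuple(), identity, +, [1, 2, 3]) == 6, and mapfoldl_kwsorter((init = 10,), identity, +, [1, 2, 3]) == 16.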

ToucheSir avatar Aug 19 '22 14:08 ToucheSir

After looking into Diffractor, I think whatever it does happens outside the actual AD transform (perhaps leaving control flow intact is enough), but the ability to have unused branches/blocks in the keyword sorter pruned in the final IR does wonders for type stability. Inspired by this, https://github.com/FluxML/Zygote.jl/issues/446#issuecomment-1221236153 has some thoughts on how we might do something similar there.

ToucheSir avatar Aug 20 '22 05:08 ToucheSir

The remaining test failure is 1.8 on x86:

Testing rulesets/LinearAlgebra/structured.jl:
170.036836 seconds (119.56 M allocations: 3.781 GiB, 9.09% gc time, 97.04% compilation time)
Testing rulesets/LinearAlgebra/symmetric.jl:
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc

signal (6): Aborted
in expression starting at /home/runner/work/ChainRules.jl/ChainRules.jl/test/rulesets/LinearAlgebra/symmetric.jl:1
__kernel_vsyscall at linux-gate.so.1 (unknown line)
gsignal at /lib32/libc.so.6 (unknown line)

Also happened https://github.com/JuliaDiff/ChainRules.jl/runs/7933271950?check_suite_focus=true with https://github.com/JuliaDiff/ChainRules.jl/pull/667 (no longer needed). Or:

Testing rulesets/LinearAlgebra/symmetric.jl:
Internal error: encountered unexpected error in runtime:
OutOfMemoryError()
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc

signal (6): Aborted
in expression starting at /home/runner/work/ChainRules.jl/ChainRules.jl/test/rulesets/LinearAlgebra/symmetric.jl:1
ERROR: LoadError: Package ChainRules errored during testing (received signal: 6)

mcabbott avatar Aug 20 '22 23:08 mcabbott

Can this be merged?

It's not the last word, as noted above, but it is a step forwards.

mcabbott avatar Aug 25 '22 14:08 mcabbott

If you think this is good to go, then we can merge it. If we see it is breaking things, we can revert it.

oxinabox avatar Mar 07 '23 06:03 oxinabox