ChainRules.jl
Attach rule to `mapfoldl_impl` not `foldl`
Closes #567, perhaps in the minimal way, by attaching these rules to internal functions which take positional arguments. The gradient for init is just @not_implemented for now.
One nice effect is that I think foldr may work too.
One weird effect is that accumulate!(f, y, x) will work, silently overwriting y. It does return a NotImplemented, maybe that helps. Xref https://github.com/JuliaDiff/ChainRules.jl/pull/521
Non-vector shapes like accumulate(f, ::Matrix) take a different path, via Iterators.accumulate, and will miss the rule. So will accumulate(f, ::Tuple). Maybe for that case Base's code is OK.
Closes https://github.com/JuliaDiff/ChainRules.jl/issues/672 . Probably closes https://github.com/FluxML/Zygote.jl/issues/1297
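For readers unfamiliar with what a reverse rule for a left fold has to compute, here is a minimal sketch in plain Julia. It is not the code in this PR: `foldl_with_pullback`, `dop` and `div_partials` are hypothetical names, the partial derivatives of `op` are supplied by hand rather than obtained from an inner AD call, and only scalar `Float64` inputs are handled.

```julia
# Hypothetical helper, not part of ChainRules.jl: a left fold over a tuple
# together with a hand-written pullback.
# `dop(a, b)` is assumed to return the partials (∂op/∂a, ∂op/∂b) of `op(a, b)`.
function foldl_with_pullback(op, dop, xs::Tuple)
    # Keep every intermediate accumulator; accs[end] is the fold's result.
    accs = accumulate(op, xs)
    function pullback(dy)
        n = length(xs)
        grads = Vector{Float64}(undef, n)   # scalar Float64 inputs only, for simplicity
        da = dy                              # cotangent of the current accumulator
        for i in n:-1:2
            # Partials of the step accs[i] = op(accs[i-1], xs[i]).
            ∂a, ∂b = dop(accs[i-1], xs[i])
            grads[i] = da * ∂b
            da *= ∂a
        end
        grads[1] = da                        # the first element seeds the accumulator
        return Tuple(grads)
    end
    return accs[end], pullback
end

# Partials of a / b with respect to a and b.
div_partials(a, b) = (1 / b, -a / b^2)

y, back = foldl_with_pullback(/, div_partials, (8.0, 2.0, 4.0))
grad = back(1.0)
```

The reverse pass walks the stored accumulators from the end, which is why the forward pass keeps all of them instead of only the final value.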
What are the TODOs that would make this no longer [draft]?
I have forgotten. But one was to decide how unhappy we are about this:
accumulate!(f, y, x) will work, silently overwriting y.
And in general about hooking on deep inside Base. I didn't see a nicer way to hook onto accumulate.
Following up on this, could the accumulate! behaviour be worked around by adding an explicit @opt_out for that function?
fff84b5 separates out the Tuple method, for which this way of writing the rule makes more sense. And handles init by a separate rule.
It is, strangely, much slower than the tagged version, inside AD.
And it breaks Yota.
The functions between foldl and mapfoldl_impl are these, shouldn't they be easy for AD to work through?
foldl(op, itr; kw...) = mapfoldl(identity, op, itr; kw...)
mapfoldl(f, op, itr; init=_InitialValue()) = mapfoldl_impl(f, op, init, itr)
julia> xn = Tuple(randn(10));
julia> xm = Tuple(rand(10,10) for _ in 1:10);
# Zygote
julia> @btime Zygote.gradient(x -> foldl(/, x), $xn);
min 47.318 ns, mean 77.223 ns (1 allocation, 16 bytes) # before
min 4.381 μs, mean 4.692 μs (37 allocations, 2.16 KiB) # after -- 100x slower
julia> @btime Zygote.gradient(x -> sum(abs2, foldl(*, x)), $xm);
min 17.667 μs, mean 77.964 μs (53 allocations, 29.52 KiB) # before
min 19.708 μs, mean 23.791 μs (69 allocations, 27.48 KiB) # after
julia> @btime Zygote.gradient(x -> Base.afoldl(/, x...), $xn); # no rule -- much slower
min 130.500 μs, mean 135.423 μs (413 allocations, 16.33 KiB)
julia> @btime Zygote.gradient(x -> sum(abs2, Base.afoldl(*, x...)), $xm);
min 143.500 μs, mean 151.413 μs (384 allocations, 40.30 KiB)
# Diffractor
julia> @btime Diffractor.gradient(x -> foldl(/, x), $xn);
min 29.271 ns, mean 30.017 ns (0 allocations) # before
min 350.632 ns, mean 400.959 ns (6 allocations, 672 bytes) # after -- 10x slower
julia> @btime Diffractor.gradient(x -> sum(abs2, foldl(*, x)), $xm);
min 13.666 μs, mean 16.422 μs (29 allocations, 25.38 KiB) # before
min 162.584 μs, mean 218.275 μs (357 allocations, 168.42 KiB); # after
julia> @btime Diffractor.gradient(x -> Base.afoldl(/, x...), $xn); # no rule -- better than Zygote
min 352.882 ns, mean 419.163 ns (6 allocations, 672 bytes)
julia> @btime Diffractor.gradient(x -> sum(abs2, Base.afoldl(/, x...)), $xm)
min 163.125 μs, mean 204.721 μs (357 allocations, 168.42 KiB)
# Yota
julia> @btime Yota.grad(x -> foldl(/, x), $xn);
min 182.790 ns, mean 657.142 ns (3 allocations, 208 bytes) # before
ERROR: No deriative rule found for op %3 = foldl(/, %2)::Float64, try defining it... # after -- fails
julia> @btime Yota.grad(x -> sum(abs2, foldl(*, x)), $xm);
min 8.583 μs, mean 50.186 μs (21 allocations, 16.19 KiB)
julia> Yota.grad(x -> Base.afoldl(/, x...), xn);
ERROR: syntax: Slot objects should not occur in an AST
# Checking pieces?
julia> yyy = Yota.YotaRuleConfig()
julia> @code_warntype rrule(yyy, foldl, /, xn) # before
julia> @code_warntype rrule(yyy, foldl, /, xn)[2](1.0)
julia> @code_warntype rrule(yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), xn) # after
julia> @code_warntype rrule(yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), xn)[2](1.0)
julia> @btime rrule($yyy, foldl, /, $xn)[2](1.0);
min 29.271 ns, mean 30.036 ns (0 allocations)
julia> @btime rrule($yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), $xn)[2](1.0);
min 29.271 ns, mean 29.753 ns (0 allocations)
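As a sanity check on the positional call used in the `rrule` timings above, the following sketch compares `foldl`, `mapfoldl` and the internal `Base.mapfoldl_impl` on the same tuple. It assumes a Julia version in which `Base.mapfoldl_impl(f, op, init, itr)` and `Base._InitialValue` exist with the forms quoted in this thread; both are internal to Base and may change between releases.

```julia
xn = Tuple(randn(10))

# Public entry points.
a = foldl(/, xn)
b = mapfoldl(identity, /, xn)

# Internal positional form that the rule is attached to (assumed signature, Base internal).
c = Base.mapfoldl_impl(identity, /, Base._InitialValue(), xn)

# With an explicit `init`, the keyword becomes the third positional argument.
d = foldl(/, xn; init = 2.0)
e = Base.mapfoldl_impl(identity, /, 2.0, xn)
```

All three forms run the same sequence of divisions, so `a`, `b` and `c` should be identical, as should `d` and `e`.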
could the accumulate! behaviour be worked around by adding an explicit @opt_out for that function?
I don't see how. I think you can only opt out of functions which have rules, and those ones need to be called to work.
~~This needs https://github.com/JuliaDiff/ChainRulesCore.jl/pull/567~~ now merged.
9af7a64 also adds mapfoldl, as map then foldl, for tuples.
Maybe also worth noting: moving the rule to _accumulate! means there is no such rule for tuples. But @less accumulate(/, (1,2,3)) shows this is pretty simple, and calls Base.afoldl. Perhaps the tuple foldl rule should be applied to Base.afoldl too (or vice versa).
Having tried to track this down a bit today, I think the slowdown is just some quirk of Zygote's handling of keywords. So it's not the rule's fault. And anything which fixes the init problem will probably hit it. Diffractor no longer sees the slowdown seen above:
using Diffractor, ChainRulesCore
ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing
# Solves same error as https://github.com/JuliaDiff/ChainRulesCore.jl/pull/503
xn = Tuple(randn(10));
@btime Diffractor.gradient(x -> foldl(/, x), $xn);
# min 29.313 ns, mean 29.545 ns (0 allocations) before (old rule on foldl)
# min 29.313 ns, mean 29.522 ns (0 allocations) after (new rule on Base.mapfoldl_impl)
@btime Diffractor.gradient(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
# min 47.625 μs, mean 53.596 μs (569 allocations, 33.16 KiB) before -- i.e. with no rule, just Base, NB μs
_foldl(op::G, itr; kw...) where {G} = _mapfoldl(identity, op, itr; kw...)
_mapfoldl(f::F, op::G, itr; init=Base._InitialValue()) where {F,G} = Base.mapfoldl_impl(f, op, init, itr)
@btime Diffractor.gradient(x -> _foldl(/, x), $xn);
# min 56.542 μs, mean 62.279 μs (672 allocations, 38.78 KiB) before -- i.e. with no rule, just Base, NB μs
import Zygote
@btime Zygote.gradient(x -> foldl(/, x), $xn);
# min 47.402 ns, mean 48.592 ns (1 allocation, 16 bytes) before
# min 4.482 μs, mean 9.120 μs (37 allocations, 2.16 KiB) after -- this I didn't like, above
# Same with Zygote#master, thus including https://github.com/FluxML/Zygote.jl/pull/1286
@btime Zygote.gradient(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
# min 152.667 μs, mean 157.707 μs (494 allocations, 26.44 KiB) before -- i.e. using no rule, just Base, NB μs
# min 47.402 ns, mean 82.826 ns (1 allocation, 16 bytes) after -- so the issue is Zygote & keywords
using Yota
@btime Yota.grad(x -> foldl(/, x), $xn);
# min 235.140 ns, mean 251.834 ns (3 allocations, 208 bytes) before
# error afterwards, doesn't track further?
ChainRulesCore.@non_differentiable Base._InitialValue()
@btime Yota.grad(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
# min 231.805 ns, mean 250.267 ns (3 allocations, 208 bytes) after
So I think we should merge this, if tests pass etc.
Zygote tries to diff through the kwsorter definition (i.e. https://docs.julialang.org/en/v1/devdocs/functions/#Keyword-arguments), which includes control flow. It's very difficult to make this type stable because it requires saving a different set of pullbacks for each branch (does anybody know how Diffractor does this?), but https://github.com/FluxML/Zygote.jl/pull/1195 might help with runtime overhead.
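To make the control-flow point concrete, here is a hand-written stand-in for the keyword handling of `mapfoldl(f, op, itr; init=...)`. The name `my_mapfoldl_sorter` is hypothetical, and the real sorter is generated by the compiler and differs in detail (for example, it also rejects unsupported keywords). The sketch only shows the branch on whether `init` was supplied, which is the part an AD tool has to trace.

```julia
# Hypothetical stand-in for the lowered keyword sorter of
#   mapfoldl(f, op, itr; init = Base._InitialValue())
# The real lowering is compiler-generated; this only imitates its shape.
function my_mapfoldl_sorter(kw::NamedTuple, f, op, itr)
    # Branch on whether the caller passed `init`. The two arms yield values of
    # different types, which is part of what makes the pullback hard to type-stabilise.
    init = haskey(kw, :init) ? kw.init : Base._InitialValue()
    return Base.mapfoldl_impl(f, op, init, itr)
end

no_init   = my_mapfoldl_sorter(NamedTuple(), identity, /, (8.0, 2.0, 4.0))
with_init = my_mapfoldl_sorter((init = 16.0,), identity, /, (8.0, 2.0, 4.0))
```

Attaching the rule to the positional `Base.mapfoldl_impl` places it below this branch, so the rule itself never sees keywords.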
After looking into Diffractor, I think whatever it does happens outside the actual AD transform (perhaps leaving control flow intact is enough), but the ability to have unused branches/blocks in the keyword sorter pruned in the final IR does wonders for type stability. Inspired by this, https://github.com/FluxML/Zygote.jl/issues/446#issuecomment-1221236153 has some thoughts on how we might do something similar there.
The remaining test failure is 1.8 on x86:
Testing rulesets/LinearAlgebra/structured.jl:
170.036836 seconds (119.56 M allocations: 3.781 GiB, 9.09% gc time, 97.04% compilation time)
Testing rulesets/LinearAlgebra/symmetric.jl:
terminate called after throwing an instance of 'std::bad_alloc'
what(): std::bad_alloc
signal (6): Aborted
in expression starting at /home/runner/work/ChainRules.jl/ChainRules.jl/test/rulesets/LinearAlgebra/symmetric.jl:1
__kernel_vsyscall at linux-gate.so.1 (unknown line)
gsignal at /lib32/libc.so.6 (unknown line)
This also happened at https://github.com/JuliaDiff/ChainRules.jl/runs/7933271950?check_suite_focus=true with https://github.com/JuliaDiff/ChainRules.jl/pull/667 (no longer needed). Or:
Testing rulesets/LinearAlgebra/symmetric.jl:
Internal error: encountered unexpected error in runtime:
OutOfMemoryError()
terminate called after throwing an instance of 'std::bad_alloc'
what(): std::bad_alloc
signal (6): Aborted
in expression starting at /home/runner/work/ChainRules.jl/ChainRules.jl/test/rulesets/LinearAlgebra/symmetric.jl:1
ERROR: LoadError: Package ChainRules errored during testing (received signal: 6)
If you think this is good to go then we can merge it. If we see it breaking things, we can revert it.