ChainRulesCore.jl
Summing many differentials (variadic `+`)
Zygote regularly emits things like `accum(a, b, c, d)`, which is equivalent to `+(a, b, c, d)` in ChainRules language. This means that we a) need to support this and b) could optimise it 🥳.
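On the Julia side this n-ary form is what the parser already produces. A chain such as `a + b + c` becomes a single call to `+` with three operands, not two nested binary calls. The minimal check below uses only `Base` and quoted expressions, so nothing is evaluated:

```julia
# Quote the chained sum without evaluating it.
ex = :(a + b + c)

# One :call node whose arguments are the function followed by all three operands.
@assert ex.head == :call
@assert ex.args == [:+, :a, :b, :c]

# Explicit parentheses, by contrast, produce nested binary calls.
nested = :((a + b) + c)
@assert nested.args[2] == :(a + b)
```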
A generic default implementation already exists provided that the binary `+(a, b)` is defined, and for simple types such as `Array` it is already optimised. For example:
```julia
using BenchmarkTools
a = randn(10);
b = randn(10);
c = randn(10);
@benchmark $a + $b + $c
```
yields
```
BenchmarkTools.Trial:
  memory estimate:  160 bytes
  allocs estimate:  1
  --------------
  minimum time:     51.717 ns (0.00% GC)
  median time:      52.820 ns (0.00% GC)
  mean time:        55.692 ns (1.67% GC)
  maximum time:     443.181 ns (82.46% GC)
  --------------
  samples:          10000
  evals/sample:     980
```
Only a single allocation was made, which is great. However, we don't have this optimisation for, e.g., `Composite`s. Consider:
```julia
julia> using ChainRulesCore

julia> a = Composite{Any}(a);

julia> b = Composite{Any}(b);

julia> c = Composite{Any}(c);

julia> @benchmark $a + $b + $c
BenchmarkTools.Trial:
  memory estimate:  320 bytes
  allocs estimate:  2
  --------------
  minimum time:     84.682 ns (0.00% GC)
  median time:      88.059 ns (0.00% GC)
  mean time:        93.438 ns (2.04% GC)
  maximum time:     644.918 ns (85.62% GC)
  --------------
  samples:          10000
  evals/sample:     947
```
It would be nice if we could obtain the same benefits here as in the Array case.
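The extra allocation is consistent with how a three-argument call is resolved when only the binary method exists. The minimal sketch below uses a hypothetical `Counted` type that stands in for a differential defining only `+(a, b)`. It shows Base's generic fallback splitting `x + y + z` into two binary calls. For a type like `Composite`, each of those calls builds its own result, so an intermediate value is created and then thrown away.

```julia
# Hypothetical stand-in for a differential type that defines only binary `+`.
struct Counted
    v::Float64
end

# Counter for how many times the binary method runs.
const NCALLS = Ref(0)

function Base.:+(a::Counted, b::Counted)
    NCALLS[] += 1
    return Counted(a.v + b.v)  # a real tangent type would build a new value here
end

NCALLS[] = 0
s = Counted(1.0) + Counted(2.0) + Counted(3.0)  # parsed as one call: +(x, y, z)

# Counted has no 3-argument method, so Base's generic fallback
# folds pairwise, (x + y) + z, invoking the binary method twice.
println(s.v, " ", NCALLS[])  # prints "6.0 2"
```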
Loosely related to #226 and #113.
We should make `+` and `add!!` variadic and handle this case.
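One possible shape for the `+` part is sketched below on a hypothetical `Wrapped` type, standing in for a structured differential such as `Composite`. This is not ChainRulesCore's API. The idea is to add a method for three or more arguments that combines all backing arrays in a single fused `broadcast`, so that the pairwise fallback and its intermediate results are bypassed. A variadic `add!!` would need analogous treatment and is not sketched here.

```julia
# Hypothetical wrapper standing in for a structured differential such as Composite.
struct Wrapped
    x::Vector{Float64}
end

# Binary method: the only one the generic fallback requires.
Base.:+(a::Wrapped, b::Wrapped) = Wrapped(a.x .+ b.x)

# Variadic method for 3+ arguments: one fused broadcast over every backing array,
# so a single output vector is built instead of one per pairwise `+`.
function Base.:+(a::Wrapped, b::Wrapped, c::Wrapped, rest::Wrapped...)
    return Wrapped(broadcast(+, a.x, b.x, c.x, map(w -> w.x, rest)...))
end

a = Wrapped([1.0, 2.0]); b = Wrapped([3.0, 4.0]); c = Wrapped([5.0, 6.0])
s = a + b + c      # dispatches to the variadic method above
println(s.x)       # prints "[9.0, 12.0]"
```

Because this method is more specific than Base's `+(a, b, c, xs...)` fallback, chained sums of `Wrapped` values dispatch to it without any method ambiguity.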