ChainRulesCore.jl

Summing many differentials (variadic +)

Open willtebbutt opened this issue 4 years ago • 1 comments

Zygote regularly emits things like

accum(a, b, c, d)

which is equivalent to

+(a, b, c, d)

in ChainRules language. This means that we a) need to support this and b) could optimise it 🥳.
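To make the n-ary call concrete, here is a minimal sketch of what happens when a type only defines binary `+`. The `Counted` type and the `BINARY_CALLS` counter are hypothetical and exist only for this illustration; they are not part of ChainRulesCore or Zygote. The sketch relies on Base's generic fallback, which reduces an n-ary `+` through repeated binary additions.

```julia
# Hypothetical type used only to count binary additions.
struct Counted
    value::Float64
end

# Counter for the number of pairwise additions performed.
const BINARY_CALLS = Ref(0)

function Base.:+(x::Counted, y::Counted)
    BINARY_CALLS[] += 1          # record each pairwise addition
    return Counted(x.value + y.value)
end

a, b, c, d = Counted(1.0), Counted(2.0), Counted(3.0), Counted(4.0)

# `a + b + c + d` parses to the single call `+(a, b, c, d)`.
# With no specialised n-ary method, Base folds it into binary additions.
total = +(a, b, c, d)

println(total.value)      # sum of the four values
println(BINARY_CALLS[])   # number of binary additions the fallback performed
```

Each of those binary additions produces an intermediate result, which is where the extra temporaries come from when the wrapped values are arrays.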

A simple default implementation already exists provided that +(a, b) is defined, and for simple things like Arrays this is optimised already. For example:

using BenchmarkTools
a = randn(10);
b = randn(10);
c = randn(10);
@benchmark $a + $b + $c

yields

BenchmarkTools.Trial:
  memory estimate:  160 bytes
  allocs estimate:  1
  --------------
  minimum time:     51.717 ns (0.00% GC)
  median time:      52.820 ns (0.00% GC)
  mean time:        55.692 ns (1.67% GC)
  maximum time:     443.181 ns (82.46% GC)
  --------------
  samples:          10000
  evals/sample:     980

Only a single temporary was allocated, which is great. However, we don't have this optimisation for e.g. Composites. Consider

julia> using ChainRulesCore

julia> a = Composite{Any}(a);

julia> b = Composite{Any}(b);

julia> c = Composite{Any}(c);

julia> @benchmark $a + $b + $c
BenchmarkTools.Trial:
  memory estimate:  320 bytes
  allocs estimate:  2
  --------------
  minimum time:     84.682 ns (0.00% GC)
  median time:      88.059 ns (0.00% GC)
  mean time:        93.438 ns (2.04% GC)
  maximum time:     644.918 ns (85.62% GC)
  --------------
  samples:          10000
  evals/sample:     947

It would be nice if we could obtain the same benefits here as in the Array case.
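One way this could look, sketched on a hypothetical `ToyComposite` type rather than on the real `Composite`: add a method for three or more arguments that forwards all the backing fields to a single n-ary `+`, so that array fields hit the existing Array optimisation. The type name and its `backing` field are assumptions made for this example, and it assumes every argument has the same field names.

```julia
# Hypothetical stand-in for a structural differential backed by a NamedTuple.
# This is not ChainRulesCore's Composite; it only illustrates the dispatch idea.
struct ToyComposite{T<:NamedTuple}
    backing::T
end

# Binary case: add field by field.
Base.:+(x::ToyComposite, y::ToyComposite) =
    ToyComposite(map(+, x.backing, y.backing))

# Three or more arguments: pass every backing to one `map`, so each field is
# summed with a single n-ary `+` call instead of a chain of binary ones.
function Base.:+(x::ToyComposite, y::ToyComposite, z::ToyComposite,
                 rest::ToyComposite...)
    backings = map(r -> r.backing, rest)   # tuple of the remaining NamedTuples
    return ToyComposite(map(+, x.backing, y.backing, z.backing, backings...))
end

x = ToyComposite((w = [1.0, 2.0], b = [0.5]))
y = ToyComposite((w = [3.0, 4.0], b = [1.5]))
z = ToyComposite((w = [5.0, 6.0], b = [2.0]))

s = x + y + z   # dispatches to the method for three or more arguments
```

Whether this actually matches the Array allocation count would need to be checked with `@benchmark`, as in the examples above.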

Loosely related to #226 and #113

willtebbutt Jan 12 '21 13:01

We should make + and add!! variadic and handle this case.

oxinabox Jan 12 '21 13:01
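The suggestion to make `add!!` variadic could be sketched roughly as below. `toy_add!!` is a hypothetical helper named to make clear it is not ChainRulesCore's `add!!`. It assumes the first argument is a plain numeric `Array` whose element type can hold the result and that all arrays have the same shape. The real function would need more careful checks on whether the destination can be mutated.

```julia
# Hypothetical variadic accumulate-in-place helper (not ChainRulesCore's add!!).

# Fallback: nothing to mutate, so defer to n-ary `+`.
toy_add!!(x, ys...) = +(x, ys...)

# Numeric arrays: accumulate every term into `x` with one fused broadcast,
# so no temporary array is allocated for the intermediate sums.
toy_add!!(x::Array{<:Number}, ys::AbstractArray{<:Number}...) =
    broadcast!(+, x, x, ys...)

x = [1.0, 2.0]
r = toy_add!!(x, [3.0, 4.0], [5.0, 6.0])   # mutates and returns `x`

n = toy_add!!(1.0, 2.0, 3.0)               # immutable inputs use the fallback
```

The design choice here is to reuse the first argument as the destination, following the existing binary convention that `add!!` may mutate its first argument.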