ChainRules.jl

`rrule` for `Base.sum` should accept `init` keyword

Status: Open. vpuri3 opened this issue 3 years ago • 3 comments

The issue linked below contains an MWE (minimal working example) in which Zygote errors on the `init` keyword argument to `Base.sum`. The best way to fix this would be to define an `rrule` for `sum` that accepts `init`.

https://github.com/FluxML/Zygote.jl/issues/1279

vpuri3, Aug 03 '22 02:08

cc @mcabbott

vpuri3, Aug 03 '22 02:08

There is already a rule for `sum` (https://github.com/JuliaDiff/ChainRules.jl/blob/9adf759bc63432dc518ccf499d6938fc5a217113/src/rulesets/Base/mapreduce.jl#L28-L41), but it does not handle the `init` keyword, which was added in Julia 1.6:

```julia
julia> sum([1 2 3; 4 5 6]; init=10)
31

julia> sum([1 2 3; 4 5 6]; init=10, dims=1)
1×3 Matrix{Int64}:
 15  17  19
```
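As a rough sketch of what handling `init` involves: `sum(x; dims, init)` adds `init` once to each output entry, so the cotangent for `x` is the same as without `init` (the output cotangent broadcast back to the shape of `x`), and the derivative with respect to `init` is the total of the output cotangent. The function `sum_init_pullback` below is hypothetical and is not ChainRules.jl's rule; it uses only Base. It also returns a cotangent for `init` purely for illustration, since ChainRulesCore pullbacks return tangents for positional arguments only, so a real `rrule` would accept `init` without differentiating it.

```julia
# Hypothetical helper (not part of ChainRules.jl): returns sum(x; dims, init)
# together with a pullback closure that maps the output cotangent ȳ to
# cotangents for x and for init.
function sum_init_pullback(x::AbstractArray; dims = :, init)
    y = sum(x; dims = dims, init = init)
    function sum_init_back(ȳ)
        # Each x[i] enters its output entry with coefficient 1, so x̄ is ȳ
        # broadcast back to the shape of x (works for a scalar ȳ and for
        # a ȳ with reduced dimensions, e.g. 1×3 when dims = 1).
        x̄ = zero(x) .+ ȳ
        # init is added once to every output entry, so its cotangent is
        # the total of ȳ; for dims = : this is just ȳ itself.
        init̄ = sum(ȳ)
        return x̄, init̄
    end
    return y, sum_init_back
end
```

With `x = [1 2 3; 4 5 6]`, `sum_init_pullback(x; init = 10)` returns `31` as its first element, matching the REPL output above.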

The same holds for the `sum(f, x)` method:

```julia
julia> sum(abs2, [1 2 3; 4 5 6]; init=10)
101

julia> sum(abs2, [1 2 3; 4 5 6]; init=10, dims=1)
1×3 Matrix{Int64}:
 27  39  55
```
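For `sum(f, x; init)` the chain rule adds one factor: each `f(x[i])` enters the total with coefficient 1, so the cotangent for `x[i]` is the output cotangent times `f'(x[i])`, while `init` is again added once per output entry. In the hypothetical sketch below the derivative `df` is passed in by hand (an assumption made to stay within Base); a real rule would have to obtain it through AD, for example via ChainRulesCore's `rrule_via_ad`, which this sketch does not attempt.

```julia
# Hypothetical helper (not part of ChainRules.jl). `df` is a hand-written
# derivative of `f`; a real rule would get it from AD instead.
function sum_f_init_pullback(f, df, x::AbstractArray; dims = :, init)
    y = sum(f, x; dims = dims, init = init)
    function sum_f_init_back(ȳ)
        # Chain rule: x̄[i] = ȳ (broadcast to the shape of x) * f'(x[i]).
        x̄ = ȳ .* df.(x)
        # As for plain sum, init is added once per output entry.
        init̄ = sum(ȳ)
        return x̄, init̄
    end
    return y, sum_f_init_back
end

# Usage with f = abs2, whose derivative is 2v.
x = [1 2 3; 4 5 6]
y, back = sum_f_init_pullback(abs2, v -> 2v, x; init = 10)   # y == 101
```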

mcabbott, Aug 03 '22 02:08

Cross-reference: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/384. The lowered form of a keyword call such as `sum(x; init=10)` is:

```julia
julia> Meta.@lower sum(x; init=10)
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Core.tuple(:init)
│   %2 = Core.apply_type(Core.NamedTuple, %1)
│   %3 = Core.tuple(10)
│   %4 = (%2)(%3)
│   %5 = Core.kwfunc(sum)
│   %6 = (%5)(%4, sum, x)
└──      return %6
))))
```

However, defining a rule as `rrule(Core.kwfunc(f), kwargs, args...)` doesn't work at the moment.
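To make the lowering concrete, the sketch below performs the same call by hand: the keyword arguments are packed into a `NamedTuple`, and `Core.kwfunc(sum)` is called with that tuple, the function `sum` itself, and the positional arguments. This is why a keyword-aware rule would have to attach to `Core.kwfunc(sum)` rather than to `sum`. It assumes a Julia version where `Core.kwfunc` is still available (on newer versions it returns `Core.kwcall`), and it says nothing about how ChainRulesCore or Zygote dispatch on such calls.

```julia
# Reproduce by hand what `sum(x; init = 10)` lowers to.
x = [1 2 3; 4 5 6]
kws = NamedTuple{(:init,)}((10,))   # the same value as (init = 10,)
kwsum = Core.kwfunc(sum)            # keyword-sorting entry point for sum
direct = kwsum(kws, sum, x)         # (%5)(%4, sum, x) in the lowered code
println(direct == sum(x; init = 10))
```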

mcabbott, Aug 27 '22 17:08