ChainRules.jl
`rrule` for `Base.sum` should accept `init` keyword
The issue linked below contains an MWE of Zygote erroring on the `init` keyword argument to `Base.sum`. The best way to fix it would be to define an `rrule` for `sum`.
https://github.com/FluxML/Zygote.jl/issues/1279
cc @mcabbott
There's already a rule for `sum` (https://github.com/JuliaDiff/ChainRules.jl/blob/9adf759bc63432dc518ccf499d6938fc5a217113/src/rulesets/Base/mapreduce.jl#L28-L41), but it doesn't handle the `init` keyword, which was added in Julia 1.6:
```julia
julia> sum([1 2 3; 4 5 6]; init=10)
31

julia> sum([1 2 3; 4 5 6]; init=10, dims=1)
1×3 Matrix{Int64}:
 15  17  19
```
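Because `init` is added once to the total (or once to each output slice when `dims` is given, e.g. 1 + 4 + 10 = 15 above), it shifts the result without changing the derivative with respect to the array. A rule could therefore evaluate the forward pass with `init` and keep the existing pullback. The sketch below shows this in plain Julia, without ChainRulesCore; `sum_init_pullback` is a hypothetical helper, not part of ChainRules.jl.

```julia
# Hypothetical helper (not part of ChainRules.jl): forward pass and pullback
# for `sum(x; dims, init)`, written without ChainRulesCore for illustration.
function sum_init_pullback(x::AbstractArray{<:Number}; dims=:, init=zero(eltype(x)))
    # `init` is added once per output element, so it shifts the result
    # but does not change the derivative with respect to `x`.
    y = sum(x; dims=dims, init=init)
    function pullback(dy)
        # `dy` is a scalar when `dims=:`, otherwise an array with size 1 along
        # `dims`; broadcasting it against `x` spreads it over each summed slice.
        return zero(x) .+ dy
    end
    return y, pullback
end

x = [1 2 3; 4 5 6]
y, back = sum_init_pullback(x; init=10)
y           # 31
back(1.0)   # 2×3 Matrix{Float64} of ones; `init` plays no part

y1, back1 = sum_init_pullback(x; init=10, dims=1)
y1                     # [15 17 19]
back1([1.0 2.0 3.0])   # [1.0 2.0 3.0; 1.0 2.0 3.0]
```

In an actual `rrule`, the pullback would return `(NoTangent(), dx)`; `init` gets no tangent because ChainRulesCore does not differentiate keyword arguments.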
The same goes for the method that takes a function:
```julia
julia> sum(abs2, [1 2 3; 4 5 6]; init=10)
101

julia> sum(abs2, [1 2 3; 4 5 6]; init=10, dims=1)
1×3 Matrix{Int64}:
 27  39  55
```
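For `sum(f, x; init)` the same reasoning applies: `init` drops out of the derivative, leaving `f'(x[i])` for each element. Here is a sketch that hard-codes `f = abs2` on real arrays (derivative `2x[i]`), again as a hypothetical plain-Julia helper rather than ChainRules.jl's actual rule for `sum(f, x)`.

```julia
# Hypothetical helper (not part of ChainRules.jl): forward pass and pullback
# for `sum(abs2, x; dims, init)` on real arrays, where d/dx abs2(x) = 2x.
function sum_abs2_init_pullback(x::AbstractArray{<:Real}; dims=:, init=zero(eltype(x)))
    y = sum(abs2, x; dims=dims, init=init)
    # As with plain `sum`, `init` only shifts the result, so the pullback
    # is the same as it would be without `init`. Broadcasting `dy` handles
    # both the scalar (`dims=:`) and the reduced-array (`dims=1`) cases.
    pullback(dy) = 2 .* x .* dy
    return y, pullback
end

x = [1 2 3; 4 5 6]
y, back = sum_abs2_init_pullback(x; init=10)
y          # 101
back(1)    # [2 4 6; 8 10 12]

y1, back1 = sum_abs2_init_pullback(x; init=10, dims=1)
y1                # [27 39 55]
back1([1 0 0])    # [2 0 0; 8 0 0]
```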
Xref https://github.com/JuliaDiff/ChainRulesCore.jl/issues/384. The lowered form is:
```julia
julia> Meta.@lower sum(x; init=10)
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Core.tuple(:init)
│   %2 = Core.apply_type(Core.NamedTuple, %1)
│   %3 = Core.tuple(10)
│   %4 = (%2)(%3)
│   %5 = Core.kwfunc(sum)
│   %6 = (%5)(%4, sum, x)
└──      return %6
))))
```
But defining a rule as `rrule(Core.kwfunc(f), kwargs, args...)` doesn't work right now.
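To make the lowering concrete, the keyword call can be replayed by hand, step by step. The point is that the function actually being called is the keyword sorter, with the `NamedTuple` of keywords as its first argument and `sum` as its second, so a rule would have to be attached to that call rather than to `sum` itself. This is a sketch: the `Core.kwfunc` path matches the lowered code shown, and the `Core.kwcall` branch is an assumption about how newer Julia versions (1.9 and later) lower keyword calls.

```julia
x = [1 2 3; 4 5 6]

# %1-%4: build the keyword NamedTuple `(init = 10,)`.
kwargs = NamedTuple{(:init,)}((10,))

# %5: the keyword sorter. The lowering above uses `Core.kwfunc(sum)`;
# assuming newer Julia versions use `Core.kwcall` instead, pick whichever exists.
kwsorter = isdefined(Core, :kwcall) ? Core.kwcall : Core.kwfunc(sum)

# %6: call it with the keywords, the original function, and the positional args.
result = kwsorter(kwargs, sum, x)   # 31, same as sum(x; init=10)
```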