ChainRules.jl
Add rule for `mean(f, x; dims)`
We have rules for `mean`, except for `mean(f, x; dims)`, which is new as of Julia v1.3.
The implementation could look a bit like the rule for `sum(f, xs)`:
https://github.com/JuliaDiff/ChainRules.jl/blob/ce78d3d3e8aaf6303e1aa7085fdbdfc2d36d1b64/src/rulesets/Base/mapreduce.jl#L69-L94
Ideally the two rules would share code, e.g. through a helper function that, for `mean`, is given `scale = 1/size(...)` or something similar.
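A minimal sketch of that shared-code idea follows. It assumes the elementwise derivative `df` is passed in explicitly, whereas the linked `sum(f, xs)` rule obtains derivatives via AD. The names `scaled_sum_with_pullback`, `_mean_count` and `mean_with_pullback` are hypothetical and not part of ChainRules.jl. Since `mean(f, x; dims)` equals `sum(f, x; dims)` divided by the number `n` of reduced elements, its pullback is the `sum` pullback scaled by `1/n`:

```julia
using Statistics

# Hypothetical helper (not ChainRules.jl API): value and pullback of
#   scale * sum(f, x; dims)
# given an explicit elementwise derivative `df`.
# For `sum`, scale = 1; for `mean`, scale = 1 / (number of reduced elements).
function scaled_sum_with_pullback(f, df, x::AbstractArray, scale; dims=:)
    y = scale .* sum(f, x; dims=dims)
    function pullback(ȳ)
        # ȳ has the reduced shape (or is a scalar when dims=:);
        # broadcasting expands it back over the reduced dimensions.
        return scale .* ȳ .* df.(x)
    end
    return y, pullback
end

# Number of elements that `mean(f, x; dims)` averages over (hypothetical helper).
_mean_count(x, ::Colon) = length(x)
_mean_count(x, dims) = prod(size(x, d) for d in dims)

# `mean` expressed through the shared helper with scale = 1/n.
mean_with_pullback(f, df, x; dims=:) =
    scaled_sum_with_pullback(f, df, x, 1 / _mean_count(x, dims); dims=dims)
```

For example, with `x = [1.0 2.0; 3.0 4.0]`, `mean_with_pullback(abs2, t -> 2t, x; dims=1)` returns the same value as `mean(abs2, x; dims=1)`, and its pullback applied to `ones(1, 2)` gives `2x / 2`, i.e. `x`. A real rule would instead take a `RuleConfig` and call `rrule_via_ad` for `f`, as the `sum` rule does.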
Xref #529, which is trying to rework that rule.
Reopened, as I had to revert the fix.