ChainRules.jl
Rules for `eachslice` with multiple `dims`
With Julia 1.9, `eachslice` can properly drop dimensions when a reduction function is applied over some of them (at least based on JuliaLang/julia#16606). However, after implementing this, it seems that multiple `dims` are not handled yet. I stumbled upon @mcabbott's advice ("the gradient rule for eachslice is unable to handle case right now, please make an issue") to file an issue, so here it is. If I'm correct, this has not been opened yet.
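For context, here is a small sketch of the forward behaviour being described, assuming Julia ≥ 1.9 (where `eachslice` accepts a tuple of `dims`). The array `A` and the choice `dims=(1, 3)` are illustrative only. Taking a gradient through `map(sum, eachslice(A; dims=(1, 3)))` is the kind of call that runs into the rule limitation quoted above.

```julia
# Requires Julia >= 1.9, where `eachslice` accepts a tuple of `dims`.
A = reshape(collect(1.0:24.0), 2, 3, 4)

# Slice along dims 1 and 3: a 2×4 collection whose elements are the
# length-3 vectors A[i, :, k].
S = eachslice(A; dims=(1, 3))

# Reducing each slice gives a 2×4 matrix; the reduced dimension (2) is dropped.
M = map(sum, S)

# Same numbers as summing over dim 2 and dropping it explicitly.
M_ref = dropdims(sum(A; dims=2); dims=2)
```

The forward pass works; it is only the reverse rule for the multi-`dims` case that is missing.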
I just encountered this issue. Is there a way to work around this?
Ideally someone extends the rule. But until then, can you reshape before & after to fit the single-`dims` case? Or, less efficiently, make a comprehension which just calls `view`.
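A minimal sketch of both suggested workarounds, assuming a 3-d array sliced over its trailing dims `(2, 3)` and reduced with `sum`. The function names `slicesum_reshape` and `slicesum_view` are hypothetical. Because the sliced dims are adjacent and trailing, a plain `reshape` merges them into one; non-adjacent dims would presumably need a `permutedims` first. Whether an AD package then differentiates these cleanly depends on the existing single-`dims` and `view` rules, which is an assumption here.

```julia
# Workaround 1: merge the trailing dims (2, 3) into one with `reshape`,
# slice along that single merged dim, then `reshape` the result back.
function slicesum_reshape(A::AbstractArray{<:Any,3})
    n1, n2, n3 = size(A)
    B = reshape(A, n1, n2 * n3)          # column j holds A[:, j2, j3] with j = j2 + n2*(j3 - 1)
    v = map(sum, eachslice(B; dims=2))   # only the single-`dims` form of eachslice
    return reshape(v, n2, n3)            # same column-major order, so (j2, j3) lines up
end

# Workaround 2: a comprehension that calls `view` directly
# (simpler, but possibly slower to differentiate).
slicesum_view(A::AbstractArray{<:Any,3}) =
    [sum(view(A, :, j, k)) for j in axes(A, 2), k in axes(A, 3)]

A = reshape(collect(1.0:24.0), 2, 3, 4)
```

Both should match `dropdims(sum(A; dims=1); dims=1)`, i.e. what `map(sum, eachslice(A; dims=(2, 3)))` computes on Julia ≥ 1.9.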