ChainRules.jl
Prevent type-inference failure from escaping the `rrule` of `sortslices`
In Julia 1.11 something changed, I guess with inlining, constant propagation and/or unrolling.
Now `inds = ntuple(d -> d == dims ? p : (:), N)` doesn't infer.
Inference used to be able to work it out by constant-folding over `dims` and `N`,
but now it gives back a `Tuple{Union{Colon, Vector{Int64}}, Union{Colon, Vector{Int64}}}`.
I couldn't work out how to get it to infer again.
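To illustrate what `inds` is, here is the expression copied into a hypothetical stand-alone function `make_inds` (not part of ChainRules). For a matrix sorted along `dims = 1` it yields `(p, :)`, so `x[inds...]` is just `x[p, :]`. Calling `Base.return_types` with plain `Int` argument types deliberately withholds the constant values of `dims` and `N`. This gives a rough picture of what inference sees when it can't constant-fold them. The exact type printed varies between Julia versions, so treat the last line as something to inspect, not a fixed result.

```julia
# Hypothetical stand-alone copy of the expression from the rrule: select the
# permutation `p` along dimension `dims` and everything (`:`) along the others.
make_inds(p, dims, N) = ntuple(d -> d == dims ? p : (:), N)

p = [3, 1, 2]
inds = make_inds(p, 1, 2)          # (p, :) for a matrix sorted along dims = 1
x = [30 3; 10 1; 20 2]
@show x[inds...] == x[p, :]

# With `dims` and `N` known only as `Int`, there is nothing to constant-fold,
# so inference cannot tell which tuple elements are `p` and which are `:`.
@show Base.return_types(make_inds, (Vector{Int}, Int, Int))
```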
But `N` is under 5 most of the time, so recomputing the appropriate `inds` inside the pullback is cheap
(unlike recomputing `p`, which is not).
This at least avoids having a non-bitstype field on the pullback closure, and keeps the inference failure from polluting code outside the function.
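A minimal sketch of the pattern follows. It is not the actual ChainRules `rrule`. It is a plain function returning `(y, pullback)` so that it runs without ChainRulesCore. The names `sortslices_with_pullback` and `slice_inds` are hypothetical, and it assumes a matrix input so that slices can be compared with `isless`. The expensive permutation `p` is computed once and captured. `inds` is rebuilt inside the pullback, so the closure only holds concretely typed fields.

```julia
# Hypothetical helper (not ChainRules API): index tuple picking the permutation
# `p` along `dims` and every index (`:`) along the other dimensions.
slice_inds(p, dims, N) = ntuple(d -> d == dims ? p : (:), N)

# Hypothetical, simplified stand-in for the sortslices rrule, restricted to
# matrices so slices (vectors) compare lexicographically via `isless`.
function sortslices_with_pullback(x::AbstractMatrix; dims::Integer)
    # Expensive step, done once: the permutation that sorts the slices.
    p = sortperm(collect(eachslice(x; dims=dims)))
    y = x[slice_inds(p, dims, ndims(x))...]

    # The closure captures only `p` and `dims`. `inds` is rebuilt here,
    # so its hard-to-infer type never becomes a field of the closure.
    function sortslices_pullback(dy)
        inds = slice_inds(p, dims, ndims(dy))  # cheap: N is small
        dx = zero(dy)
        dx[inds...] = dy  # y = x[inds...] permutes slices, so scatter dy back
        return dx
    end
    return y, sortslices_pullback
end
```

Usage: `y, pb = sortslices_with_pullback([3 1; 1 2; 2 0]; dims=1)` gives the same `y` as `sortslices(x, dims=1)`, and `pb(dy)` sends row `i` of `dy` back to row `p[i]`.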
Together with #816, this should get 1.11 passing again.