ChainRules.jl
Pullback for `tr` produces a CPU `Diagonal` causing downstream scalar indexing on GPUs
Functions like `tr(A * B)` throw scalar indexing errors in the pullback for `*` when `A` and `B` are `CuArray`s. This is because the pullback for `tr` creates a `Diagonal`, which causes the downstream matrix multiplications to hit the LinearAlgebra definition instead of a GPU-friendly one.
Worse, it's a `Diagonal{T,Array}`. Would a `Diagonal{T,CuArray}` work?
Yeah, that seems to avoid scalar indexing.
Then it can probably reuse what `sum` does, which should also allow second derivatives:
https://github.com/JuliaDiff/ChainRules.jl/blob/a9a84ba6cb8aa9ce079af9401600e7c96a8aff3a/src/rulesets/Base/mapreduce.jl#L47
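For illustration, here is a rough sketch of the idea of keeping the diagonal's storage in the same array family as the input. It is not the ChainRules.jl rule and not what the linked `sum` code does; `tr_with_pullback` is a hypothetical helper that uses only `LinearAlgebra`, and it assumes `similar(A, n)` returns a vector of the same family as `A` (a `Vector` for an `Array`, and presumably a device vector for a `CuArray`).

```julia
using LinearAlgebra

# Hypothetical helper (an assumption for illustration, not the ChainRules.jl rule):
# returns tr(A) together with a pullback that maps the output cotangent dy
# to the cotangent with respect to A.
function tr_with_pullback(A::AbstractMatrix)
    function tr_pullback(dy)
        # Allocate the diagonal from A itself, so the storage type follows A
        # rather than always being a CPU Vector.
        d = similar(A, size(A, 1))
        fill!(d, dy)
        return Diagonal(d)
    end
    return tr(A), tr_pullback
end

A = randn(4, 4)
y, back = tr_with_pullback(A)
dA = back(2.0)   # Diagonal whose storage was allocated via similar(A, ...)
```

On the CPU this produces the same values as before; the only difference is where the storage comes from. Whether a `Diagonal` wrapping a GPU vector then dispatches well in every downstream multiply would still need checking on a real device.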
It would also be nice if the tests caught this. We have `@gpu test_rrule(tr, randn(4, 4))`, but it apparently isn't smart enough to object to the `Array`.
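One possible shape for such a check, as a sketch only: compare the array family of the input with the family of the storage inside the returned cotangent. `storage_family`, `same_family`, and `FakeGPUMatrix` are hypothetical names, not part of ChainRulesTestUtils; `FakeGPUMatrix` is a CPU stand-in for a device array so the example runs without a GPU, and `Base.typename(...).wrapper` is an internal detail used here for brevity.

```julia
using LinearAlgebra

# Hypothetical stand-in for a GPU matrix type, so the check can run on the CPU.
struct FakeGPUMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::FakeGPUMatrix) = size(A.data)
Base.getindex(A::FakeGPUMatrix, i::Int, j::Int) = A.data[i, j]

# Strip type parameters so Vector{Float64} and Matrix{Float64} both map to Array.
storage_family(x::AbstractArray) = Base.typename(typeof(x)).wrapper
# For a Diagonal, look through the wrapper at the vector it stores.
storage_family(D::Diagonal) = storage_family(D.diag)

# True when the cotangent's storage is in the same array family as the input.
same_family(x, dx) = storage_family(x) === storage_family(dx)

cpu_input = randn(4, 4)
fake_gpu_input = FakeGPUMatrix(randn(4, 4))
cpu_diag = Diagonal(ones(4))   # what a pullback that always builds a CPU Diagonal returns

ok_on_cpu = same_family(cpu_input, cpu_diag)
ok_on_fake_gpu = same_family(fake_gpu_input, cpu_diag)
```

With a check like this, a `Diagonal` backed by an `Array` passes for an `Array` input but is flagged for the stand-in device type, which is the mismatch the existing test does not object to.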