ChainRules.jl

Pullback for `tr` produces a CPU `Diagonal` causing downstream scalar indexing on GPUs

Open darsnack opened this issue 3 years ago • 3 comments

Functions like `tr(A * B)` throw scalar indexing errors in the pullback for `*` when `A` and `B` are `CuArray`s. This is because the pullback for `tr` creates a `Diagonal`, which causes the downstream matrix multiplications to hit the generic LinearAlgebra definition.

darsnack, Oct 17 '22 15:10

Worse, it's a `Diagonal{T,Array}`. Would a `Diagonal{T, CuArray}` work?

mcabbott, Oct 17 '22 15:10

Yeah, that seems to avoid scalar indexing.

darsnack, Oct 17 '22 15:10

Then it can probably reuse what the rule for `sum` does, which should also allow 2nd derivatives:

https://github.com/JuliaDiff/ChainRules.jl/blob/a9a84ba6cb8aa9ce079af9401600e7c96a8aff3a/src/rulesets/Base/mapreduce.jl#L47
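For illustration only, here is a minimal sketch of a pullback that keeps the diagonal's storage in the same array family as the input. It is not the code behind the link above and not the rule ChainRules ships: `mytr` is a hypothetical stand-in for `tr` (so the real rule is not overwritten), and the approach of allocating with `similar` is an assumption about what would avoid the CPU `Diagonal`. Only the CPU path is exercised here; with a `CuArray` input, `similar` should return a `CuVector`, which has not been checked in this sketch.

```julia
using LinearAlgebra
using ChainRulesCore

# Hypothetical stand-in for `tr`, used so this sketch does not replace the real rule.
mytr(x::AbstractMatrix) = tr(x)

function ChainRulesCore.rrule(::typeof(mytr), x::AbstractMatrix)
    function mytr_pullback(ȳ)
        Δ = unthunk(ȳ)
        # `similar(x, ...)` allocates a vector of the same array family as `x`,
        # so the diagonal is not hard-coded to a CPU `Array`.
        d = fill!(similar(x, typeof(Δ), size(x, 1)), Δ)
        return NoTangent(), Diagonal(d)
    end
    return mytr(x), mytr_pullback
end

# CPU usage: the gradient of the trace is the cotangent times the identity.
A = randn(4, 4)
y, back = rrule(mytr, A)
_, dA = back(2.0)
```

The result is still a `Diagonal`, so the structure of the gradient is preserved; only the type of its `diag` field follows the input.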

It would also be nice if the test noticed this. We have `@gpu test_rrule(tr, randn(4, 4))`, but it apparently isn't smart enough to object to the `Array`.
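One way such a check could look, as a rough sketch: compare the storage type of the returned tangent with what `similar` would allocate for the primal. The helpers `storage` and `same_backend` are hypothetical names invented for this example and are not part of ChainRulesTestUtils; only the `Diagonal` wrapper is unwrapped, and the GPU case is described in a comment rather than run.

```julia
using LinearAlgebra

# Hypothetical helper: unwrap a structured wrapper to reach its storage array.
storage(d::Diagonal) = d.diag
storage(a::AbstractArray) = a

# Hypothetical check: the tangent's storage should have the type that `similar`
# on the primal would allocate for the same element type and size.
function same_backend(x::AbstractArray, dx::AbstractArray)
    s = storage(dx)
    return s isa typeof(similar(x, eltype(s), size(s)))
end

# CPU case: a `Diagonal` backed by a `Vector` matches a `Matrix` primal.
A = randn(4, 4)
ok_diag = same_backend(A, Diagonal(ones(4)))
ok_dense = same_backend(A, ones(4, 4))
# For a `CuArray` primal, `similar` would presumably return a `CuArray`, so a
# tangent backed by a plain `Array` would fail this check.
```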

mcabbott, Oct 17 '22 15:10