ChainRules.jl
ChainRules.jl copied to clipboard
`*(::AbstractVector, ::AbstractMatrix)` pullback triggers scalar indexing on the GPU
When a vector is multiplied by a matrix, the pullback of the `rrule` for `*` triggers scalar indexing on GPU arrays.
julia> using Zygote, CUDA
julia> CUDA.allowscalar(false)
julia> x, y = rand(Float32, 6), rand(Float32, 1, 5)
(Float32[0.49304312, 0.30266464, 0.016446471, 0.20248199, 0.07340324, 0.8248376], Float32[0.08382976 0.7970275 … 0.5854495 0.08664018])
julia> gradient((_x, _y) -> sum(_x * _y), x, y)
(Float32[2.457574, 2.457574, 2.457574, 2.457574, 2.457574, 2.457574], Float32[1.9128771 1.9128771 … 1.9128771 1.9128771])
julia> gradient((_x, _y) -> sum(_x * _y), cu(x), cu(y))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:87
[3] getindex(xs::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
@ GPUArrays ~/.julia/packages/GPUArrays/fqD8z/src/host/indexing.jl:9
[4] generic_matvecmul!(C::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, tA::Char, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::Base.ReshapedArray{Float32, 1, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /opt/julia/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:805
[5] mul!
@ /opt/julia/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:81 [inlined]
[6] mul!
@ /opt/julia/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
[7] *(A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, x::Base.ReshapedArray{Float32, 1, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}})
@ LinearAlgebra /opt/julia/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:56
[8] #1480
@ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Base/arraymath.jl:83 [inlined]
[9] unthunk
@ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
[10] unthunk
@ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:237 [inlined]
[11] wrap_chainrules_output
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:105 [inlined]
[12] map
@ ./tuple.jl:223 [inlined]
[13] wrap_chainrules_output
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:106 [inlined]
[14] ZBack
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206 [inlined]
[15] Pullback
@ ./REPL[8]:1 [inlined]
[16] (::typeof(∂(#3)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[17] (::Zygote.var"#60#61"{typeof(∂(#3))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45
[18] gradient(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97
[19] top-level scope
@ REPL[8]:1
[20] top-level scope
@ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
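This comes from the line at `src/rulesets/Base/arraymath.jl:83` (frame [8] in the stack trace above), which uses `vec(adjoint(::CuArray))`. That call creates a `ReshapedArray`, which in turn triggers the generic matmul fallback for `Ȳ * vec(B')`.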
I am sure a GPU expert can fix this trivially, but I have no idea how.