ChainRules.jl

`*(::AbstractVector, ::AbstractMatrix)` pullback triggers scalar indexing on the GPU

Open darsnack opened this issue 3 years ago • 1 comment

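When multiplying a vector by a (single-row) matrix, the rrule for `*(::AbstractVector, ::AbstractMatrix)` triggers scalar indexing if the arguments are GPU arrays. The same gradient works on the CPU: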

julia> using Zygote, CUDA

julia> CUDA.allowscalar(false)

julia> x, y = rand(Float32, 6), rand(Float32, 1, 5)
(Float32[0.49304312, 0.30266464, 0.016446471, 0.20248199, 0.07340324, 0.8248376], Float32[0.08382976 0.7970275 … 0.5854495 0.08664018])

julia> gradient((_x, _y) -> sum(_x * _y), x, y)
(Float32[2.457574, 2.457574, 2.457574, 2.457574, 2.457574, 2.457574], Float32[1.9128771 1.9128771 … 1.9128771 1.9128771])

julia> gradient((_x, _y) -> sum(_x * _y), cu(x), cu(y))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:87
  [3] getindex(xs::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/fqD8z/src/host/indexing.jl:9
  [4] generic_matvecmul!(C::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, tA::Char, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::Base.ReshapedArray{Float32, 1, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /opt/julia/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:805
  [5] mul!
    @ /opt/julia/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:81 [inlined]
  [6] mul!
    @ /opt/julia/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
  [7] *(A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, x::Base.ReshapedArray{Float32, 1, LinearAlgebra.Adjoint{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}})
    @ LinearAlgebra /opt/julia/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:56
  [8] #1480
    @ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Base/arraymath.jl:83 [inlined]
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
 [10] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:237 [inlined]
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:105 [inlined]
 [12] map
    @ ./tuple.jl:223 [inlined]
 [13] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:106 [inlined]
 [14] ZBack
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206 [inlined]
 [15] Pullback
    @ ./REPL[8]:1 [inlined]
 [16] (::typeof(∂(#3)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#60#61"{typeof(∂(#3))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45
 [18] gradient(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97
 [19] top-level scope
    @ REPL[8]:1
 [20] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

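This comes from the pullback in ChainRules (src/rulesets/Base/arraymath.jl:83, frame [8] in the stack trace above), which computes `Ȳ * vec(B')`. Calling `vec(adjoint(::CuArray))` creates a `Base.ReshapedArray` wrapping an `Adjoint`, and multiplying by that wrapper falls back to the generic matrix-vector multiply (`generic_matvecmul!`), which indexes the GPU array element by element.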
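The wrapper type can be inspected on the CPU without a GPU. The following is a minimal sketch, not the fix adopted by ChainRules: `lazy` and `dense` are hypothetical names, the arrays are CPU stand-ins for the ones in the report, and the `conj(vec(B))` alternative is only an assumption about one way to avoid the wrapper. It relies on `B` having a single row, in which case `vec(B')` and `conj(vec(B))` hold the same entries. Whether this stays on fast paths for `CuArray` is not verified here.

```julia
using LinearAlgebra

# CPU stand-ins for the arrays in the report; B has a single row.
B = rand(Float32, 1, 5)
Ȳ = rand(Float32, 6, 5)    # cotangent of x * B, same shape as the 6×5 product

# What the pullback builds: a lazy reshape of a lazy adjoint.
lazy = vec(B')
@show typeof(lazy)

# Hypothetical alternative (an assumption, not the upstream fix):
# for a one-row B, conj(vec(B)) has the same entries without the wrappers.
dense = conj(vec(B))
@show typeof(dense)

# Both forms give the same gradient contribution for x.
@assert lazy == dense
@assert Ȳ * lazy ≈ Ȳ * dense
```

On the CPU both products succeed because scalar indexing is cheap there; the difference only matters once the arrays live on the GPU and the generic fallback is disallowed.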

darsnack, Nov 11 '22 20:11

I am sure a GPU expert can fix this trivially, but I have no idea how.

oxinabox, Dec 09 '22 19:12