ChainRules.jl

rrule causes scalar indexing for `repeat`

Open YichengDWu opened this issue 1 year ago • 3 comments

julia> func(x) = sum(repeat(x, inner = (1, 3)))
func (generic function with 1 method)

julia> func(CUDA.rand(2,3))
7.5995846f0

julia> Zygote.gradient(func,CUDA.rand(2,3))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] assertscalar(op::String)
    @ GPUArraysCore C:\Users\Luffy\.julia\packages\GPUArraysCore\rSIl2\src\GPUArraysCore.jl:78
  [3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
    @ GPUArrays C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:9
  [4] getindex
    @ C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:30 [inlined]
  [5] iterate
    @ .\iterators.jl:245 [inlined]
  [6] (::Zygote.var"#519#525"{Tuple{Int64, Int64}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:136
  [7] #2693#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:73 [inlined]
  [8] Pullback
    @ .\REPL[15]:1 [inlined]
  [9] (::Zygote.var"#60#61"{typeof(∂(func))})(Δ::Float32)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [10] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
 [11] top-level scope
    @ REPL[17]:1
 [12] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\tTK8Y\src\initialization.jl:52

YichengDWu · Jul 13 '22 02:07

The same happens with `outer`:

julia> func3(x) = sum(repeat(x, outer=(2, 3)))
func3 (generic function with 1 method)

julia> Zygote.gradient(func3,CUDA.rand(2,3))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] assertscalar(op::String)
    @ GPUArraysCore C:\Users\Luffy\.julia\packages\GPUArraysCore\rSIl2\src\GPUArraysCore.jl:78
  [3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
    @ GPUArrays C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:9
  [4] getindex
    @ C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:30 [inlined]
  [5] iterate
    @ .\iterators.jl:245 [inlined]
  [6] (::Zygote.var"#519#525"{Tuple{Int64, Int64}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:136
  [7] #2693#back
    @ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:73 [inlined]
  [8] Pullback
    @ .\REPL[31]:1 [inlined]
  [9] (::Zygote.var"#60#61"{typeof(∂(func3))})(Δ::Float32)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [10] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
 [11] top-level scope
    @ REPL[32]:1
 [12] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\tTK8Y\src\initialization.jl:52 

YichengDWu · Jul 13 '22 02:07

Xref https://github.com/JuliaDiff/ChainRules.jl/issues/383: the rules for `repeat` with `inner` etc. were copied over from Zygote unchanged, and they are also very slow. My suggestion there was to rewrite them using broadcasting, which seems to work for CuArrays:

julia> gradient(x -> sum(_repeat(x, inner = (1, 3))), CUDA.rand(3,4))  # function & rules from gist
(Float32[3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0],)

julia> gradient(x -> sum(repeat(x, inner = (1, 3))), CUDA.rand(3,4))  # Base function & old rules
ERROR: Scalar indexing is disallowed.

Cc @awadell1 and @torfjelde from https://github.com/JuliaGPU/GPUArrays.jl/pull/400 in case they are interested. The code there for repeat is faster than the broadcasting version.
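
For readers following along, here is a minimal sketch of why an iteration-free rule is possible at all. It is not the code from the gist or from the GPUArrays.jl PR. The helper names `inner_pullback` and `outer_pullback` are hypothetical, and the sketch assumes a 2-d input with 2-tuples for `inner` / `outer`. The pullback only needs `reshape` and `sum`, both of which run on GPU arrays without scalar indexing:

```julia
# Hypothetical helpers, for illustration only (not part of ChainRules.jl or Zygote).
# `sz` is size(x) of the original 2-d input; `dy` is the cotangent of repeat(x; ...).

# y = repeat(x; inner=inner): each x[a, b] fills a contiguous
# inner[1] × inner[2] block of y. Reshaping exposes the within-block
# offsets as dims 1 and 3, so summing over them accumulates the gradient.
function inner_pullback(dy::AbstractMatrix, sz::NTuple{2,Int}, inner::NTuple{2,Int})
    dy4 = reshape(dy, inner[1], sz[1], inner[2], sz[2])
    return reshape(sum(dy4; dims=(1, 3)), sz)
end

# y = repeat(x; outer=outer): whole copies of x are tiled, so the
# copy index lands in dims 2 and 4 after reshaping.
function outer_pullback(dy::AbstractMatrix, sz::NTuple{2,Int}, outer::NTuple{2,Int})
    dy4 = reshape(dy, sz[1], outer[1], sz[2], outer[2])
    return reshape(sum(dy4; dims=(2, 4)), sz)
end

x  = rand(Float32, 2, 3)
dy = ones(Float32, size(repeat(x; inner=(1, 3))))  # cotangent of sum(...)
dx = inner_pullback(dy, size(x), (1, 3))           # every entry is 3.0f0
```

With `dy` all ones this gives a 2×3 array of threes, the same pattern as the `_repeat` gradient shown above. Whether this is as fast as the dedicated kernels in the GPUArrays.jl PR is not something this sketch addresses.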

mcabbott · Jul 13 '22 03:07

This rule also causes the following error:

https://discourse.julialang.org/t/zygote-inexacterror-using-repeat-with-inner-keyword/92600

mcabbott · Jan 06 '23 15:01