ChainRules.jl
rrule causes scalar indexing for `repeat`
julia> func(x) = sum(repeat(x, inner = (1, 3)))
func (generic function with 1 method)
julia> func(CUDA.rand(2,3))
7.5995846f0
julia> Zygote.gradient(func,CUDA.rand(2,3))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:33
[2] assertscalar(op::String)
@ GPUArraysCore C:\Users\Luffy\.julia\packages\GPUArraysCore\rSIl2\src\GPUArraysCore.jl:78
[3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
@ GPUArrays C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:9
[4] getindex
@ C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:30 [inlined]
[5] iterate
@ .\iterators.jl:245 [inlined]
[6] (::Zygote.var"#519#525"{Tuple{Int64, Int64}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:136
[7] #2693#back
@ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:73 [inlined]
[8] Pullback
@ .\REPL[15]:1 [inlined]
[9] (::Zygote.var"#60#61"{typeof(∂(func))})(Δ::Float32)
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
[10] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
[11] top-level scope
@ REPL[17]:1
[12] top-level scope
@ C:\Users\Luffy\.julia\packages\CUDA\tTK8Y\src\initialization.jl:52
The same error occurs with `outer`:
julia> func3(x) = sum(repeat(x, outer=(2, 3)))
func3 (generic function with 1 method)
julia> Zygote.gradient(func3,CUDA.rand(2,3))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:33
[2] assertscalar(op::String)
@ GPUArraysCore C:\Users\Luffy\.julia\packages\GPUArraysCore\rSIl2\src\GPUArraysCore.jl:78
[3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
@ GPUArrays C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:9
[4] getindex
@ C:\Users\Luffy\.julia\packages\GPUArrays\gok9K\src\host\indexing.jl:30 [inlined]
[5] iterate
@ .\iterators.jl:245 [inlined]
[6] (::Zygote.var"#519#525"{Tuple{Int64, Int64}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:136
[7] #2693#back
@ C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:73 [inlined]
[8] Pullback
@ .\REPL[31]:1 [inlined]
[9] (::Zygote.var"#60#61"{typeof(∂(func3))})(Δ::Float32)
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
[10] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
[11] top-level scope
@ REPL[32]:1
[12] top-level scope
@ C:\Users\Luffy\.julia\packages\CUDA\tTK8Y\src\initialization.jl:52
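For reference, the pullback these rules have to compute is a plain sum over the copies of each entry. The following is a sketch for a matrix $x \in \mathbb{R}^{m \times n}$ with repetition factors $(a, b)$. The symbols $m, n, a, b, p, q, s, t$ are introduced here for illustration and do not come from the rule code:

```latex
% y = repeat(x; inner = (a, b)),  x \in \mathbb{R}^{m \times n},  y \in \mathbb{R}^{am \times bn}
y_{(i-1)a+p,\,(j-1)b+q} = x_{ij},
\qquad
\bar{x}_{ij} = \sum_{p=1}^{a} \sum_{q=1}^{b} \bar{y}_{(i-1)a+p,\,(j-1)b+q}

% y = repeat(x; outer = (a, b)),  y \in \mathbb{R}^{am \times bn}
y_{(s-1)m+i,\,(t-1)n+j} = x_{ij},
\qquad
\bar{x}_{ij} = \sum_{s=1}^{a} \sum_{t=1}^{b} \bar{y}_{(s-1)m+i,\,(t-1)n+j}
```

Both are reductions, so they can be expressed with whole-array operations instead of iterating over $\bar y$ element by element, which is what the closure at `Zygote/src/lib/array.jl:136` in the stacktraces above ends up doing. With $\bar y$ all ones (the gradient of `sum`), every $\bar x_{ij} = ab$: 3 for `inner = (1, 3)` and 6 for `outer = (2, 3)`.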
Xref https://github.com/JuliaDiff/ChainRules.jl/issues/383: the rules for `repeat` with `inner` etc. were copied over from Zygote as-is, and they are also very slow. My suggestion there was to rewrite them using broadcasting, which seems to work for CuArrays:
julia> gradient(x -> sum(_repeat(x, inner = (1, 3))), CUDA.rand(3,4)) # function & rules from gist
(Float32[3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0],)
julia> gradient(x -> sum(repeat(x, inner = (1, 3))), CUDA.rand(3,4)) # Base function & old rules
ERROR: Scalar indexing is disallowed.
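A minimal sketch of scalar-indexing-free pullbacks, for comparison. It does not reproduce the gist mentioned above (which is not linked here); it swaps in a reshape plus `sum(...; dims)` formulation of the same gradient instead of broadcasting. The names `repeat_inner_pullback`, `repeat_outer_pullback` and the `_loop` reference versions are hypothetical, not ChainRules.jl, Zygote, or GPUArrays API, and the sketch assumes `length(inner) == ndims(x)` (likewise for `outer`). It uses only Base; the `_loop` functions exist only as a CPU reference. That it stays on the device for a `CuArray` is an untested assumption (CUDA.jl does provide `reshape` and dimension-wise `sum`), and hooking it into an actual `rrule` is left out.

```julia
# Hypothetical helpers, not ChainRules.jl or Zygote API: pullbacks for
# `repeat(x; inner)` and `repeat(x; outer)` built from `reshape` plus `sum(...; dims)`,
# so neither helper indexes `dy` element by element.

# ∂L/∂x for y = repeat(x; inner = inner), given dy = ∂L/∂y.
function repeat_inner_pullback(dy::AbstractArray, x::AbstractArray, inner::NTuple{N,Int}) where {N}
    N == ndims(x) || throw(ArgumentError("expected length(inner) == ndims(x)"))
    # Along dimension d, y's index is (j - 1) * inner[d] + k with k ∈ 1:inner[d] varying
    # fastest, so in column-major order each axis of dy splits into (inner[d], size(x, d)).
    shape = ntuple(i -> isodd(i) ? inner[(i + 1) ÷ 2] : size(x, i ÷ 2), 2N)
    # Sum over the copy axes (odd positions); the leftover singleton axes are reshaped away.
    summed = sum(reshape(dy, shape); dims = ntuple(d -> 2d - 1, N))
    return reshape(summed, size(x))
end

# ∂L/∂x for y = repeat(x; outer = outer), given dy = ∂L/∂y.
function repeat_outer_pullback(dy::AbstractArray, x::AbstractArray, outer::NTuple{N,Int}) where {N}
    N == ndims(x) || throw(ArgumentError("expected length(outer) == ndims(x)"))
    # Along dimension d, y's index is (l - 1) * size(x, d) + j with j ∈ 1:size(x, d)
    # varying fastest, so each axis of dy splits into (size(x, d), outer[d]).
    shape = ntuple(i -> isodd(i) ? size(x, (i + 1) ÷ 2) : outer[i ÷ 2], 2N)
    # Sum over the copy axes (even positions).
    summed = sum(reshape(dy, shape); dims = ntuple(d -> 2d, N))
    return reshape(summed, size(x))
end

# Scalar-loop references for checking on the CPU only: this element-wise
# iteration is the pattern that triggers "Scalar indexing is disallowed" on a CuArray.
function repeat_inner_pullback_loop(dy, x, inner)
    dx = zeros(eltype(dy), size(x))
    for I in CartesianIndices(dy)
        dx[CartesianIndex(ntuple(d -> (I[d] - 1) ÷ inner[d] + 1, ndims(x)))] += dy[I]
    end
    return dx
end

function repeat_outer_pullback_loop(dy, x, outer)
    dx = zeros(eltype(dy), size(x))
    for I in CartesianIndices(dy)
        dx[CartesianIndex(ntuple(d -> (I[d] - 1) % size(x, d) + 1, ndims(x)))] += dy[I]
    end
    return dx
end

x = rand(2, 3)
dy = ones(size(repeat(x; inner = (1, 3))))
repeat_inner_pullback(dy, x, (1, 3))  # 2×3 matrix of 3.0, the gradient of sum(repeat(x, inner = (1, 3)))
```

The reshape trick works because `repeat` lays out copies in a fixed, regular pattern along each axis, so splitting every axis in two exposes the copies as separate dimensions that a single reduction can collapse.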
Cc @awadell1 and @torfjelde from https://github.com/JuliaGPU/GPUArrays.jl/pull/400, in case they are interested. The `repeat` code in that PR is faster than the broadcasting version.
This rule also causes an `InexactError` when `repeat` is called with the `inner` keyword:
https://discourse.julialang.org/t/zygote-inexacterror-using-repeat-with-inner-keyword/92600