ChainRules.jl

Scalar indexing error using prod(...)

Open · mentics opened this issue 2 years ago · 1 comment

This might not be a bug, but at the very least the error is mysterious.

I'm using Flux, and my loss function (which Zygote differentiates, so it goes through ChainRules.jl) calls prod. Surprisingly, training ran for several iterations before failing with a scalar indexing error.

I'm guessing the error happens when the rule enters the branch that handles zeros, because it finds a zero somewhere in the input. That would explain why several iterations succeed before the error appears. I don't know whether the code can be changed to avoid scalar indexing, whether it could give a more informative error, or whether this is just something I need to understand better.

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore C:\Users\joel\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:103
  [3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
    @ GPUArrays C:\Users\joel\.julia\packages\GPUArrays\5XhED\src\host\indexing.jl:9
  [4] maybeview
    @ .\views.jl:149 [inlined]
  [5] ∇prod_dims!(dx::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, #unused#::Val{1}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ ChainRules C:\Users\joel\.julia\packages\ChainRules\9sNmB\src\rulesets\Base\mapreduce.jl:287
  [6] ∇prod_dims(vald::Val{1}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ ChainRules C:\Users\joel\.julia\packages\ChainRules\9sNmB\src\rulesets\Base\mapreduce.jl:278
  [7] (::ChainRules.var"#1683#1686"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Int64, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})()
    @ ChainRules C:\Users\joel\.julia\packages\ChainRules\9sNmB\src\rulesets\Base\mapreduce.jl:265
  [8] unthunk
    @ C:\Users\joel\.julia\packages\ChainRulesCore\0t04l\src\tangent_types\thunks.jl:204 [inlined]
  [9] unthunk
    @ C:\Users\joel\.julia\packages\ChainRulesCore\0t04l\src\tangent_types\thunks.jl:237 [inlined]
 [10] wrap_chainrules_output
    @ C:\Users\joel\.julia\packages\Zygote\4rucm\src\compiler\chainrules.jl:110 [inlined]
 [11] map
    @ .\tuple.jl:274 [inlined]
 [12] wrap_chainrules_output
    @ C:\Users\joel\.julia\packages\Zygote\4rucm\src\compiler\chainrules.jl:111 [inlined]
 [13] ZBack
    @ C:\Users\joel\.julia\packages\Zygote\4rucm\src\compiler\chainrules.jl:211 [inlined]
 [14] (::Zygote.var"#kw_zpullback#53"{ChainRules.var"#prod_pullback#1684"{Int64, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})(dy::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\4rucm\src\compiler\chainrules.jl:237
...

mentics, Oct 24 '23 16:10

I forgot how this works, but as you observe, the code that handles zeros carefully sits behind the test any(iszero, x), precisely so that an x with no zeros works without error on a GPU.
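To make the two paths concrete, here is a minimal CPU sketch of a pullback for prod(x; dims=1) with the same shape of logic: a broadcast-only fast path when there are no zeros, and an element-by-element slow path otherwise. The function name prod_dims1_grad is hypothetical, and this is not ChainRules' actual code (the stack trace shows the real rule lives in ∇prod_dims!); the naive inner loop is chosen for clarity, not speed.

```julia
# Hypothetical sketch of a pullback for prod(x; dims=1); not ChainRules' code.
function prod_dims1_grad(x::AbstractMatrix, dy::AbstractMatrix)
    if !any(iszero, x)
        # Fast path: pure broadcasting, so it would also run on a GPU array.
        y = prod(x; dims=1)           # 1×n row of column products
        return dy .* y ./ x           # d/dx[i,j] of prod(x[:,j]) is y[j] / x[i,j]
    end
    # Slow path: for each element, multiply every other entry in its column.
    # This uses scalar indexing, which GPU arrays disallow by default.
    dx = similar(x, float(eltype(x)))
    for j in axes(x, 2), i in axes(x, 1)
        p = one(eltype(dx))
        for k in axes(x, 1)
            k == i || (p *= x[k, j])
        end
        dx[i, j] = dy[1, j] * p
    end
    return dx
end

x  = [0.0 2.0; 3.0 4.0]   # one exact zero, so the slow path runs
dx = prod_dims1_grad(x, ones(1, 2))
```

Because which path runs depends on the values in x rather than its type, the same code can succeed for several iterations and then fail once a zero appears.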

It wouldn't be crazy to add an explicit message there, such as x isa AbstractGPUArray && @error "rule for prod found zeros..." maxlog=1. Otherwise the cause of the error is rather mysterious: for example, it may work for the first iteration but fail later. Scalar indexing errors usually depend only on the types, not the values.

The alternatives are to always use the broadcasting branch on GPU arrays (and accept that getting NaN is your signal that something is wrong), or to write a GPU kernel that handles zeros correctly (perhaps using KernelAbstractions to stay device-agnostic).
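To see why the broadcasting branch signals trouble with NaN, consider the y ./ x formula on a column that contains a zero: the product y is zero, so the entry at the zero becomes 0/0. The function name below is hypothetical; it is a one-line sketch of a broadcast-only pullback for prod(x; dims=1), assuming the gradient is computed as dy .* y ./ x.

```julia
# Hypothetical broadcast-only pullback for prod(x; dims=1).
prod_grad_broadcast_only(x, dy) = dy .* prod(x; dims=1) ./ x

x  = [0.0 2.0; 3.0 4.0]
dx = prod_grad_broadcast_only(x, ones(1, 2))
# dx[1, 1] is 0/0, i.e. NaN, although the true derivative there is 3.0
```

Entries in columns without zeros are still correct, so the NaN appears only where the input contained a zero.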

That's on the ChainRules side. Can you share more about your actual use case? Inserting something like clamp.(x, 0.001, 0.99) before the prod is one way you might avoid the problem.
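A quick check of the clamp suggestion: after clamping, no element can be exactly zero, so any(iszero, ...) is false and the zero-handling branch is never entered. The matrix values below are made up for illustration, and Float32 literals are used only to match the element type in the stack trace.

```julia
x  = Float32[0.0 0.5; 1.0 0.2]       # made-up values, including an exact zero
xc = clamp.(x, 0.001f0, 0.99f0)      # bounds from the suggestion above
any(iszero, x), any(iszero, xc)      # (true, false)
```

The trade-off is that the clamped values differ slightly from the originals at the boundaries, which may or may not matter for a given loss.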

mcabbott, Oct 24 '23 16:10