feat: catch scalar indexing failures early
Fixes #523 (possibly also https://github.com/FluxML/Flux.jl/issues/2440).

Minimal reproducer (note that `using cuDNN` is not loaded):
using CUDA, NNlib
x = cu(rand(Float32, 4, 4, 3, 2))
w = cu(rand(Float32, 3, 3, 3, 4))
cdims = DenseConvDims(x, w)
conv(x, w, cdims)
New Error Message
ERROR: AssertionError: `conv!` requires all arguments to support fast scalar indexing. You might be missing an `using cuDNN` or `import cuDNN` statement.
Stacktrace:
[1] special_scalar_indexing_error(::Val{:conv!}, ::CuArray{Float32, 5, CUDA.DeviceMemory})
@ NNlibCUDAExt /mnt/research/lux/NNlib.jl/ext/NNlibCUDAExt/utils.jl:41
[2] (::Base.Fix1{typeof(NNlib.special_scalar_indexing_error), Val{:conv!}})(y::CuArray{Float32, 5, CUDA.DeviceMemory})
@ Base ./operators.jl:1127
[3] (::Base.var"#70#71"{Base.Fix1{typeof(NNlib.special_scalar_indexing_error), Val{:conv!}}})(::Nothing, x::CuArray{Float32, 5, CUDA.DeviceMemory})
@ Base ./tuple.jl:692
[4] BottomRF
@ ./reduce.jl:86 [inlined]
[5] afoldl(::Base.BottomRF{Base.var"#70#71"{Base.Fix1{typeof(NNlib.special_scalar_indexing_error), Val{:conv!}}}}, ::Nothing, ::CuArray{Float32, 5, CUDA.DeviceMemory}, ::CuArray{Float32, 5, CUDA.DeviceMemory}, ::CuArray{Float32, 5, CUDA.DeviceMemory})
@ Base ./operators.jl:553
[6] _foldl_impl(op::Base.BottomRF{Base.var"#70#71"{Base.Fix1{typeof(NNlib.special_scalar_indexing_error), Val{:conv!}}}}, init::Nothing, itr::Tuple{CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}})
@ Base ./reduce.jl:68
[7] foldl_impl(op::Base.BottomRF{Base.var"#70#71"{Base.Fix1{typeof(NNlib.special_scalar_indexing_error), Val{:conv!}}}}, nt::Nothing, itr::Tuple{CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}})
@ Base ./reduce.jl:48
[8] mapfoldl_impl(f::typeof(identity), op::Base.var"#70#71"{Base.Fix1{typeof(NNlib.special_scalar_indexing_error), Val{:conv!}}}, nt::Nothing, itr::Tuple{CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}})
@ Base ./reduce.jl:44
[9] mapfoldl(f::Function, op::Function, itr::Tuple{CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}}; init::Nothing)
@ Base ./reduce.jl:175
[10] mapfoldl
@ ./reduce.jl:175 [inlined]
[11] #foldl#336
@ ./reduce.jl:198 [inlined]
[12] foreach(f::Base.Fix1{typeof(NNlib.special_scalar_indexing_error), Val{:conv!}}, itr::Tuple{CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}, CuArray{Float32, 5, CUDA.DeviceMemory}})
@ Base ./tuple.jl:692
[13] assert_all_fast_scalar_indexing(::Val{:conv!}, ::CuArray{Float32, 5, CUDA.DeviceMemory}, ::Vararg{CuArray{Float32, 5, CUDA.DeviceMemory}})
@ NNlib /mnt/research/lux/NNlib.jl/src/utils.jl:167
[14] conv!(out::CuArray{Float32, 5, CUDA.DeviceMemory}, in1::CuArray{Float32, 5, CUDA.DeviceMemory}, in2::CuArray{Float32, 5, CUDA.DeviceMemory}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::@Kwargs{})
@ NNlib /mnt/research/lux/NNlib.jl/src/conv.jl:195
[15] conv!(out::CuArray{Float32, 5, CUDA.DeviceMemory}, in1::CuArray{Float32, 5, CUDA.DeviceMemory}, in2::CuArray{Float32, 5, CUDA.DeviceMemory}, cdims::DenseConvDims{3, 3, 3, 6, 3})
@ NNlib /mnt/research/lux/NNlib.jl/src/conv.jl:185
[16] conv!(y::CuArray{Float32, 4, CUDA.DeviceMemory}, x::CuArray{Float32, 4, CUDA.DeviceMemory}, w::CuArray{Float32, 4, CUDA.DeviceMemory}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::@Kwargs{})
@ NNlib /mnt/research/lux/NNlib.jl/src/conv.jl:145
[17] conv!(y::CuArray{Float32, 4, CUDA.DeviceMemory}, x::CuArray{Float32, 4, CUDA.DeviceMemory}, w::CuArray{Float32, 4, CUDA.DeviceMemory}, cdims::DenseConvDims{2, 2, 2, 4, 2})
@ NNlib /mnt/research/lux/NNlib.jl/src/conv.jl:140
[18] conv(x::CuArray{Float32, 4, CUDA.DeviceMemory}, w::CuArray{Float32, 4, CUDA.DeviceMemory}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::@Kwargs{})
@ NNlib /mnt/research/lux/NNlib.jl/src/conv.jl:88
[19] conv(x::CuArray{Float32, 4, CUDA.DeviceMemory}, w::CuArray{Float32, 4, CUDA.DeviceMemory}, cdims::DenseConvDims{2, 2, 2, 4, 2})
@ NNlib /mnt/research/lux/NNlib.jl/src/conv.jl:83
[20] top-level scope
@ REPL[38]:1
[21] top-level scope
@ none:1
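For reference, the shape of the check that produces this error, as inferred from the stacktrace above: a generic `assert_all_fast_scalar_indexing` in `src/utils.jl` walks the arguments, and the CUDA extension provides a `special_scalar_indexing_error` method for `CuArray` that points at the missing cuDNN extension. The sketch below only illustrates that structure (it assumes the generic fallback uses `ArrayInterface.fast_scalar_indexing`); it is not the exact code in this PR.

```julia
# Sketch only — inferred from the stacktrace, not copied from the PR.
using ArrayInterface

# Generic fallback: silently accept arrays with fast scalar indexing,
# otherwise raise a plain message naming the offending operation.
function special_scalar_indexing_error(::Val{op}, x::AbstractArray) where {op}
    @assert ArrayInterface.fast_scalar_indexing(typeof(x)) "`$(op)` requires all arguments to support fast scalar indexing."
    return nothing
end

# Check every argument of an operation (e.g. `conv!`) up front.
function assert_all_fast_scalar_indexing(op::Val, xs::AbstractArray...)
    foreach(Base.Fix1(special_scalar_indexing_error, op), xs)
    return nothing
end

# In ext/NNlibCUDAExt (sketch): CuArray never supports fast scalar indexing,
# so replace the generic message with a hint about the missing cuDNN extension.
# special_scalar_indexing_error(::Val{op}, ::CuArray) where {op} = throw(AssertionError(
#     "`$(op)` requires all arguments to support fast scalar indexing. " *
#     "You might be missing an `using cuDNN` or `import cuDNN` statement."))
```

As the stacktrace shows, `conv!` calls `assert_all_fast_scalar_indexing(Val(:conv!), y, x, w)` before reaching the generic fallback, so the failure surfaces with this message instead of a scalar-indexing error deep inside im2col.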
ArrayInterface.jl doesn't support Julia <1.10. Now that the LTS is 1.10, can we drop 1.9 support? #600 does this.
Yes, it seems convenient to drop Julia <1.10.
I like the idea, but couldn't this be achieved with package extensions and dispatch? That way we don't need to add a dependency.
> but couldn't this be achieved with package extensions and dispatch
Do you mean moving ArrayInterface to an extension? It only depends on Adapt + 3 std libs (the sparse arrays one could probably be removed upstream), so it is quite lightweight.
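For what it's worth, the extensions-and-dispatch alternative would look roughly like the sketch below. The trait name is made up purely for illustration; the point is just that the permissive default lives in NNlib and the CUDA-specific override lives in NNlibCUDAExt, with no new hard dependency.

```julia
# Hypothetical trait, not NNlib API — illustrates the extensions + dispatch idea.

# In NNlib: permissive default, no ArrayInterface dependency required.
supports_fast_scalar_indexing(::AbstractArray) = true

# In ext/NNlibCUDAExt (only loaded when CUDA.jl is present):
# supports_fast_scalar_indexing(::CuArray) = false
```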
bump on this