
GPU array scalar indexing edge cases

ExpandingMan opened this issue 10 months ago · 6 comments

This issue came up during #47.

There are some rather awkward cases where it is not clear how to elide scalar indexing into GPU arrays. For example:

  using OneHotArrays, CUDA, LinearAlgebra, Test

  y = onehotbatch(ones(3), 1:2) |> cu
  y = reshape(y, 3, 2)
  gA = rand(2, 3) |> cu

  @test LinearAlgebra.mul!(similar(gA, 2, 2), gA, y) ≈ gA*y

Both sides of this test are currently broken. The failure on current main is a method ambiguity, but it's not entirely clear how to fix it, since there is no obvious answer for what should happen in this case. I doubt it is possible to cover every conceivable edge case; at some point users will have to materialize the array.
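For instance, one possible user-side workaround for the test above (a sketch only: materialize on the CPU, then move the dense result to the GPU):

  y_cpu   = reshape(onehotbatch(ones(3), 1:2), 3, 2)
  y_dense = cu(Float32.(collect(y_cpu)))  # ordinary CuArray; generic linear algebra applies
  @test LinearAlgebra.mul!(similar(gA, 2, 2), gA, y_dense) ≈ gA * y_dense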

I think what probably needs to be done here is to add documentation, and possibly convenience methods, describing the circumstances in which arrays should be materialized. Even though there's no clear answer, I'm opening this issue because the current handling is extremely wonky, and users not intimately familiar with Julia array packages will be justifiably confused.
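As a straw man, such a convenience method might look like this (the name materialize is hypothetical, not an existing OneHotArrays export):

  # Hypothetical helper: collapse a (possibly reshaped or wrapped) one-hot
  # array into an ordinary dense Array before any GPU work.
  materialize(x::AbstractArray{Bool}) = Array(x)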

ExpandingMan avatar Feb 04 '25 22:02 ExpandingMan

Please include the error messages for both failures.

Do they fail on CPU too?

mcabbott avatar Feb 04 '25 23:02 mcabbott

This is only a problem on GPU because of scalar indexing.
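E.g. the CPU analogue of the snippet above passes (a sketch; same code with the cu calls removed):

  using OneHotArrays, LinearAlgebra, Test

  y = reshape(onehotbatch(ones(3), 1:2), 3, 2)
  A = rand(2, 3)
  @test LinearAlgebra.mul!(similar(A, 2, 2), A, y) ≈ A * y  # scalar getindex is fine on the CPU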

On current main, the GPU version gives:

ERROR: MethodError: getindex(::OneHotMatrix{UInt32, CuArray{UInt32, 1, CUDA.DeviceMemory}}, ::Int64, ::Int64) is ambiguous.

Candidates:
  getindex(x::OneHotArray{var"#s3", N, var"N+1", I} where {var"#s3", var"N+1", I<:Union{AbstractArray{var"#s3", N}, var"#s3"}}, i::Int64, I::Vararg{Int64, N}) where N
    @ OneHotArrays ~/.julia/dev/OneHotArrays/src/array.jl:65
  getindex(x::OneHotArray{<:Any, N, <:Any, <:GPUArraysCore.AbstractGPUArray}, i::Int64, I::Vararg{Any, N}) where N
    @ OneHotArrays ~/.julia/dev/OneHotArrays/src/array.jl:71

Possible fix, define
  getindex(::OneHotArray{var"#s3", N, <:Any, <:Union{…}} where var"#s3", ::Int64, ::Vararg{Int64, N}) where N

Stacktrace:
  [1] _unsafe_getindex_rs
    @ ./reshapedarray.jl:276 [inlined]
  [2] _unsafe_getindex
    @ ./reshapedarray.jl:273 [inlined]
  [3] getindex
    @ ./reshapedarray.jl:261 [inlined]
  [4] _generic_matmatmul!(C::Matrix{…}, A::CuArray{…}, B::Base.ReshapedArray{…}, _add::LinearAlgebra.MulAddMul{…})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:894
  [5] generic_matmatmul!
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
  [6] _mul!
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
  [7] mul!
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
  [8] mul!
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
  [9] mul!(Y::Matrix{Float32}, A::CuArray{Float32, 2, CUDA.DeviceMemory}, B::Base.ReshapedArray{Bool, 2, OneHotMatrix{…}, Tuple{…}})
    @ OneHotArrays ~/.julia/dev/OneHotArrays/src/linalg.jl:38
 [10] *
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
 [11] *(A::CuArray{Float32, 2, CUDA.DeviceMemory}, B::Base.ReshapedArray{Bool, 2, OneHotMatrix{…}, Tuple{…}})
    @ OneHotArrays ~/.julia/dev/OneHotArrays/src/linalg.jl:8

Resolving this method ambiguity is easy, of course, but it's not clear what the alternative behavior should be. In my opinion it wasn't really a good idea for this package to start down the road of supporting ReshapedArray.

ExpandingMan avatar Feb 05 '25 22:02 ExpandingMan

FWIW, this is the historical context that led to using the wrapper: https://github.com/FluxML/Flux.jl/pull/1459#issuecomment-757343865.

One continuing tension with this package is that the code paths needed for a memory-efficient type can conflict with the ones needed for GPU support. Generally the idea was to have people use it outside of AD/GPU (i.e. during preprocessing and data loading) and materialize before using AD/GPU. That's not to say the status quo is ideal, but it sheds some light on why GPU support hasn't historically been the No. 1 priority.
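A minimal sketch of that intended workflow (assuming CUDA.jl's cu for the device transfer):

  using OneHotArrays, CUDA

  # Preprocessing / data loading: the lazy one-hot type is memory-efficient here.
  labels = onehotbatch(["cat", "dog", "cat"], ["cat", "dog"])

  # Before AD / GPU work: materialize to a dense array, then move to the device.
  labels_gpu = cu(Float32.(labels))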

ToucheSir avatar Feb 05 '25 22:02 ToucheSir

I think the getindex error is a dup of https://github.com/FluxML/OneHotArrays.jl/issues/28 . Close if you agree?

mcabbott avatar Feb 05 '25 23:02 mcabbott

I think there is a bigger issue here of what exactly to do in the edge cases this package allows, where you would seem to have to scalar-index a GPU array. However, there is redundancy between this issue and #28, so feel free to close this one if you don't think it's adding anything.

ExpandingMan avatar Feb 06 '25 20:02 ExpandingMan

Isn't this just how arrays work?

E.g. one role of Transpose is to dispatch to the BLAS with a 'T', easy. The second role is to plug into generic code which doesn't know about it at all, so that reverse(transpose(M); dims=1) just works. This second role is the only reason it has supertype AbstractMatrix.

When it wraps a GPU array, the first role is no harder, but almost none of the AbstractMatrix fallbacks will work. You get scalar indexing, and then you either fix it (supply a routine) or avoid it. Having no supertype would perhaps give a more obvious error message, like "MethodError: no method matching reverse(A::Transpose{Float64, JLArray{Float64, 2}})".
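A sketch of those two roles, using JLArrays.jl (a CPU-backed reference implementation of the GPU array interface; I'm assuming a real GPU backend behaves the same):

  using JLArrays, LinearAlgebra

  M  = JLArray(rand(Float32, 3, 3))
  Mt = transpose(M)

  # Role 1: wrapper-aware methods see the Transpose and dispatch to a
  # specialized routine, so no element-by-element access happens.
  Mt * M

  # Role 2: generic AbstractMatrix fallbacks index element by element; on a
  # GPU-style array that is scalar indexing, which errors when disallowed.
  reverse(Mt; dims = 1)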

mcabbott avatar Feb 06 '25 21:02 mcabbott