OneHotArrays.jl

`onehotbatch(::CuArray, ...)` moves data to host

Open mcabbott opened this issue 3 years ago • 6 comments

The lack of https://github.com/FluxML/Flux.jl/pull/1959 causes the following error, currently blocking https://github.com/FluxML/Flux.jl/pull/2025 :

julia> using CUDA, OneHotArrays, NNlibCUDA

julia> CUDA.allowscalar(false)

julia> x = [1, 3, 2];

julia> y = onehotbatch(x, (0,1,2,3))
4×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅
 1  ⋅  ⋅
 ⋅  ⋅  1
 ⋅  1  ⋅

julia> y2 = onehotbatch(x |> cu, (0,1,2,3))
ERROR: Scalar indexing is disallowed.

Edit: after #27, onehotbatch(x |> cu, 0:3) works, but other ways to specify the labels do not.

mcabbott, Jul 24 '22 17:07

I don't think that should be allowed. Taking `cu` should always happen AFTER taking `onehotbatch`. Consider a real case where the labels are a non-bits type, such as an array of strings: it doesn't make sense to take `onehotbatch` of a GPU array there. It just needs to be clarified in the docs.
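A minimal sketch of that ordering, encoding on the host and moving the result to the device afterwards. The label set and data are made up for illustration, and the commented-out GPU lines assume CUDA.jl is installed with a working GPU and that `cu` adapts the integer indices stored inside the one-hot matrix:

```julia
using OneHotArrays

# Illustrative data: string labels are a non-bits type and cannot be stored in a CuArray.
labels = ["cat", "dog", "bird"]
x = ["dog", "cat", "dog", "bird"]

# Encode on the host first. The result stores one integer index per column.
y = onehotbatch(x, labels)   # 3×4 OneHotMatrix

# Decoding on the host recovers the original labels.
decoded = onecold(y, labels)

# Assumption: CUDA.jl is installed and a GPU is available.
# Only the integer indices would need to move to the device:
# using CUDA
# y_gpu = cu(y)
```

With this ordering the labels never have to leave the host, so their element type does not matter.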

chengchingwen, Aug 03 '22 01:08

No strong opinions, was just trying to make Flux's tests pass. https://github.com/FluxML/OneHotArrays.jl/pull/17 is more code duplication than ideal.

mcabbott, Aug 03 '22 13:08

Maybe we could remove that test in Flux?

chengchingwen, Aug 03 '22 13:08

Also no strong feelings either way, but if we don't want to support this, we should add the functionality as deprecated so that it's not a breaking change on the Flux side.

ToucheSir, Aug 03 '22 22:08

As seen in #24, the current behavior is surprising when `allowscalar(true)` is set. We should allow `onehotbatch(::CuArray, ...)` whenever possible and error out otherwise.

CarloLucibello, Nov 11 '22 07:11

#27 fixes the case `onehotbatch(::CuVector{<:Integer}, ::UnitRange)`.
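To illustrate why the integer-range case can work without scalar indexing, here is a rough sketch that uses only a reduction and a broadcast. `range_onehotbatch` is a hypothetical helper written for this example, not part of OneHotArrays.jl and not necessarily how #27 is implemented. It is shown on a plain `Vector` so it runs without a GPU:

```julia
using OneHotArrays

# Hypothetical helper, not part of OneHotArrays.jl.
# For integer data and a unit range of labels, the one-hot index of each
# element is just an offset, so no per-element lookup is needed.
function range_onehotbatch(x::AbstractVector{<:Integer}, labels::AbstractUnitRange{<:Integer})
    lo, hi = extrema(x)  # a reduction rather than scalar indexing; errors on empty x
    (first(labels) <= lo && hi <= last(labels)) ||
        throw(ArgumentError("every value of x must lie within labels"))
    indices = UInt32.(x .- first(labels) .+ 1)  # one broadcast over the whole array
    return OneHotArray(indices, length(labels))
end

x = [1, 3, 2]
y = range_onehotbatch(x, 0:3)

# Compare against the regular host-side encoding from the issue description.
same = y == onehotbatch(x, (0, 1, 2, 3))
```

Labels given as an arbitrary tuple or vector need a search per element instead of an offset, which is one reason other ways of specifying the labels are harder to support on the device.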

CarloLucibello, Dec 28 '22 08:12