OneHotArrays.jl
`onehotbatch(::CuArray, ...)` moves data to host
The lack of https://github.com/FluxML/Flux.jl/pull/1959 causes the following error, which currently blocks https://github.com/FluxML/Flux.jl/pull/2025:
```julia
julia> using CUDA, OneHotArrays, NNlibCUDA

julia> CUDA.allowscalar(false)

julia> x = [1, 3, 2];

julia> y = onehotbatch(x, (0,1,2,3))
4×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅
 1  ⋅  ⋅
 ⋅  ⋅  1
 ⋅  1  ⋅

julia> y2 = onehotbatch(x |> cu, (0,1,2,3))
ERROR: Scalar indexing is disallowed.
```
Edit: after #27, `onehotbatch(x |> cu, 0:3)` works, but other ways of specifying the labels do not.
I don't think that should be allowed. Moving data to the GPU with `cu` should always happen AFTER calling `onehotbatch`. Consider a realistic case where the labels are non-bits types, such as an array of strings: it doesn't make sense to call `onehotbatch` on a GPU array. This just needs to be clarified in the docs.
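For example, a sketch of the ordering being argued for here: encode on the CPU first, then move the result to the device. (Assumes CUDA and OneHotArrays are installed and a GPU is available; the string labels are illustrative.)

```julia
using CUDA, OneHotArrays

labels = ["cat", "dog", "bird"]   # non-bits labels live on the CPU
x = ["dog", "bird", "dog"]

y = onehotbatch(x, labels)        # one-hot encode on the CPU first...
y_gpu = cu(y)                     # ...then move the result to the GPU
```

The `OneHotMatrix` wraps a vector of indices, so `cu` only needs to transfer that small bits-type vector; the string labels never have to touch the GPU.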
No strong opinions, was just trying to make Flux's tests pass. https://github.com/FluxML/OneHotArrays.jl/pull/17 is more code duplication than ideal.
Maybe we could remove that test in Flux?
Also no strong feelings either way, but if we don't want to support this, we should keep the functionality as deprecated so that it's not a breaking change on the Flux side.
As seen in #24, the current behavior is surprising when `allowscalar(true)`. We should allow `onehotbatch(::CuArray, ...)` whenever possible and error out otherwise.
#27 fixes the case `onehotbatch(::CuVector{<:Integer}, ::UnitRange)`.
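Concretely, after #27 a call with integer device data and `UnitRange` labels should work without scalar indexing, while other label types still fall back to the erroring path. A minimal sketch (requires a CUDA-capable GPU):

```julia
using CUDA, OneHotArrays
CUDA.allowscalar(false)

x = cu([1, 3, 2])            # CuVector of integers stays on the device
y = onehotbatch(x, 0:3)      # UnitRange labels: the case covered by #27

# Tuple labels still hit scalar indexing, as in the original report:
# onehotbatch(x, (0,1,2,3))  # ERROR: Scalar indexing is disallowed.
```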