OneHotArrays.jl

`onecold` errors with Reactant arrays

Open · mkschleg opened this issue 8 months ago · 1 comment

`onecold` does not work with Reactant arrays: it throws a `MethodError` instead of returning the column labels. Julia version 1.11.4.

In a fresh environment:

julia> using Pkg
julia> Pkg.add(["MLDataDevices", "Reactant", "OneHotArrays"])
julia> using Reactant, MLDataDevices, OneHotArrays

julia> onecold([true false false; false true true])
3-element Vector{Int64}:
 1
 2
 2

julia> const dev = reactant_device()
(::ReactantDevice{Missing, Missing, Missing}) (generic function with 1 method)

julia> onecold([true false false; false true true]|>dev)
ERROR: MethodError: no method matching vec(::Tuple{Int64})
The function `vec` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  vec(::StaticArraysCore.SizedArray{S, T, N, M} where {T, N, M}) where S
   @ StaticArrays ~/.julia/packages/StaticArrays/LSPcF/src/SizedArray.jl:171
  vec(::SparseArrays.AbstractSparseVector)
   @ SparseArrays ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/SparseArrays/src/sparsevector.jl:1128
  vec(::LinearAlgebra.Adjoint{<:Real, <:AbstractVector})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.11.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:374
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/AebXg/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::typeof(vec), ::Tuple{Int64})
    @ Reactant ~/.julia/packages/Reactant/AebXg/src/utils.jl:790
  [3] getindex
    @ ~/.julia/packages/Reactant/AebXg/src/TracedRArray.jl:176 [inlined]
  [4] getindex(none::Reactant.TracedRArray{Bool, 2}, none::Tuple{Int64, CartesianIndex{1}})
    @ Reactant ./<missing>:0
  [5] getindex
    @ ~/.julia/packages/Reactant/AebXg/src/TracedRArray.jl:167 [inlined]
  [6] call_with_reactant(::typeof(getindex), ::Reactant.TracedRArray{Bool, 2}, ::Int64, ::CartesianIndex{1})
    @ Reactant ~/.julia/packages/Reactant/AebXg/src/utils.jl:0
  [7] make_mlir_fn(f::typeof(getindex), args::Tuple{ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Int64, CartesianIndex{1}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/AebXg/src/TracedUtils.jl:303
  [8] make_mlir_fn
    @ ~/.julia/packages/Reactant/AebXg/src/TracedUtils.jl:178 [inlined]
  [9] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(getindex), args::Tuple{ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Int64, CartesianIndex{1}}, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::IdDict{Reactant.Sharding.Mesh, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation}}; optimize::Bool, shardy_passes::Symbol, no_nan::Bool, backend::String, fn_kwargs::Tuple{}, raise::Bool, input_shardings::Nothing, output_shardings::Nothing, do_transpose::Bool, runtime::Val{:PJRT})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:840
 [10] compile_mlir!
    @ ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:784 [inlined]
 [11] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Int64, CartesianIndex{1}}; client::Nothing, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:1992
 [12] compile_xla
    @ ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:1974 [inlined]
 [13] compile(f::Function, args::Tuple{ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Int64, CartesianIndex{1}}; sync::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:2040
 [14] compile
    @ ~/.julia/packages/Reactant/AebXg/src/Compiler.jl:2039 [inlined]
 [15] getindex(::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ::Int64, ::CartesianIndex{1})
    @ Reactant ~/.julia/packages/Reactant/AebXg/src/ConcreteRArray.jl:271
 [16] findminmax!(f::typeof(identity), op::typeof(isless), Rval::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Rind::Matrix{CartesianIndex{2}}, A::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}})
    @ Base ./reducedim.jl:1039
 [17] _findmax(f::typeof(identity), A::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, region::Int64)
    @ Base ./reducedim.jl:1209
 [18] _findmax
    @ ./reducedim.jl:1176 [inlined]
 [19] findmax
    @ ./reducedim.jl:1175 [inlined]
 [20] argmax
    @ ./reducedim.jl:1274 [inlined]
 [21] _fast_argmax
    @ ~/.julia/packages/OneHotArrays/rXTnu/src/onehot.jl:167 [inlined]
 [22] onecold(y::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, labels::UnitRange{Int64})
    @ OneHotArrays ~/.julia/packages/OneHotArrays/rXTnu/src/onehot.jl:161
 [23] onecold(y::ConcretePJRTArray{Bool, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}})
    @ OneHotArrays ~/.julia/packages/OneHotArrays/rXTnu/src/onehot.jl:158
 [24] top-level scope
    @ REPL[9]:1

This might be an issue with Reactant.jl rather than OneHotArrays.jl, but I'm new to this part of the ecosystem, so let me know if I should file it there instead.

— mkschleg, Mar 28 '25 16:03

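I get the same error when calling `argmax` directly, so I think this is a Reactant.jl problem. The stack trace above shows `onecold` reaching `argmax` through `_fast_argmax` (frames [20]–[21]):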

julia> using Reactant, MLDataDevices, OneHotArrays

julia> argmax([true false false; false true true]; dims=1)
1×3 Matrix{CartesianIndex{2}}:
 CartesianIndex(1, 1)  CartesianIndex(2, 2)  CartesianIndex(2, 3)

julia> argmax([true false false; false true true] |> reactant_device(); dims=1)
ERROR: MethodError: no method matching vec(::Tuple{Int64})

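A possible workaround for now is the following, although it may not work on GPU. It applies `argmax` to each column separately instead of using the `dims=1` reduction that fails above: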

julia> myonecold(x) = map(argmax, eachcol(x))
myonecold (generic function with 1 method)

julia> myonecold([true false false; false true true])
3-element Vector{Int64}:
 1
 2
 2

julia> myonecold([true false false; false true true] |> reactant_device())
3-element Vector{Int64}:
 1
 2
 2

— mcabbott, Mar 28 '25 16:03