OneHotArrays.jl
OneHotArrays.jl copied to clipboard
Methods for multiplication OneHotMatrix * AbstractMatrix
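Continuation of #56: allows multiplication with a `OneHotMatrix` on the left. All of these cases work (each operand plain or adjoint, in either order, via both `*` and `mul!`, checked against the `JLArrays` result):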
julia> using OneHotArrays, JLArrays, LinearAlgebra, Test
julia> for op1 in (identity, adjoint), op2 in (identity, adjoint), order in (identity, reverse)
m1, m2 = order([rand(3,3), onehotbatch([1,2,2], 1:3)])
A, B = op1(m1), op2(m2)
C = A * B
try
jlC = jl(A) * jl(B)
@test Array(jlC) ≈ C
catch e
@error "A*B failed" typeof(A) typeof(B)
end
end
julia> for op1 in (identity, adjoint), op2 in (identity, adjoint), order in (identity, reverse)
m1, m2 = order([rand(3,3), onehotbatch([1,2,2], 1:3)])
A, B = op1(m1), op2(m2)
C = A * B
try
jlC = mul!(jl(similar(C)), jl(A), jl(B))
@test Array(jlC) ≈ C
catch e
@error "mul!(C,A,B) failed" typeof(A) typeof(B)
end
end
... but there are some method ambiguities, e.g. when both operands are one-hot matrices:
julia> ohm = onehotbatch([1,2,2], 1:3);
julia> ohm * ohm
ERROR: MethodError: *(::OneHotMatrix{UInt32, Vector{UInt32}}, ::OneHotMatrix{UInt32, Vector{UInt32}}) is ambiguous.
Candidates:
*(A::OneHotMatrix, B::AbstractMatrix{<:Number})
@ OneHotArrays ~/.julia/dev/OneHotArrays/src/linalg.jl:13
*(A::AbstractMatrix, B::OneHotLike)
@ OneHotArrays ~/.julia/dev/OneHotArrays/src/linalg.jl:1
*(A::AbstractMatrix, B::Union{OneHotArray{var"#s29", 1, var"N+1", I}, Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{var"#s29", <:Any, <:Any, I}}} where {var"#s29", var"N+1", I})
@ OneHotArrays ~/.julia/dev/OneHotArrays/src/linalg.jl:7
Possible fix, define
*(::OneHotMatrix, ::Union{Base.ReshapedArray{Bool, 2, <:OneHotArray{T, <:Any, <:Any, I}} where {T, I}, OneHotMatrix})
Needs tests, and perhaps a motivating use case.
PR Checklist
- [ ] Tests are added
- [ ] Documentation, if applicable