
variable axis on which the "one hot" property holds

Open nomadbl opened this issue 2 years ago • 3 comments

Motivation and description

I am working on a layer that produces one-hot outputs, so I am looking into using OneHotArrays.jl. My gripe is that currently the datatype only supports one-hot vectors extending along the first axis.

I thought I'd write up my thoughts and possible implementations of a variable axis, to get some feedback and context from the maintainers and users here (I am very new to Julia and Flux, coming from Python).

Possible Implementation

Implementation path 1 (WIP): change the constructors, size, and getindex:

struct OneHotArray{T<:Integer,N,var"N+1",I<:Union{T,AbstractArray{T,N}}} <: AbstractArray{Bool,var"N+1"}
  indices::I
  nlabels::Int
  axis::Int
end
OneHotArray{T,N,I}(indices, L::Int, axis::Int=1) where {T,N,I} = OneHotArray{T,N,N + 1,I}(indices, L, axis)
OneHotArray(indices::T, L::Int, axis::Int=1) where {T<:Integer} = OneHotArray{T,0,1,T}(indices, L, axis)
OneHotArray(indices::I, L::Int, axis::Int=1) where {T,N,I<:AbstractArray{T,N}} = OneHotArray{T,N,N + 1,I}(indices, L, axis)

Base.size(x::OneHotArray) = Tuple(insert!(collect(size(x.indices)), x.axis, x.nlabels))

function Base.getindex(x::OneHotArray, I::Vararg{Int,N}) where {N}
  N == ndims(x) || throw(DimensionMismatch("OneHotArray has $(ndims(x)) dimensions but $N indices were given."))
  @boundscheck all(1 .<= I .<= size(x)) || throw(BoundsError(x, I))
  Iv = collect(I)
  hot = popat!(Iv, x.axis)        # index along the one-hot axis
  return x.indices[Iv...] == hot  # true only where the stored index matches
end

The idea with this is to maintain the sparse nature of the representation for later optimized multiplications, backprop, etc.
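To make the intended behaviour concrete, here is a minimal self-contained mock of the same idea (the struct and all names here are illustrative only, independent of the real package types):

```julia
# Mock of the path-1 idea: store integer indices plus the hot axis.
struct AxisOneHot{N} <: AbstractArray{Bool,N}
    indices::Array{Int}
    nlabels::Int
    axis::Int
end
AxisOneHot(indices::AbstractArray{Int,M}, L::Int, axis::Int=1) where {M} =
    AxisOneHot{M+1}(collect(indices), L, axis)

# The label axis is spliced into the size of the underlying index array.
Base.size(x::AxisOneHot) = Tuple(insert!(collect(size(x.indices)), x.axis, x.nlabels))

function Base.getindex(x::AxisOneHot{N}, I::Vararg{Int,N}) where {N}
    @boundscheck all(1 .<= I .<= size(x)) || throw(BoundsError(x, I))
    Iv = collect(I)
    hot = popat!(Iv, x.axis)          # position along the one-hot axis
    return x.indices[Iv...] == hot    # true only at the stored index
end

x = AxisOneHot([2, 1, 3], 4, 2)       # 3×4, one-hot along the second axis
@assert size(x) == (3, 4)
@assert x[1, 2] && !x[1, 1]
```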

While working on this I also hit upon path 2, to reuse all the original code, but use the new axis parameter to do appropriate permutations of the underlying (1,...) dimensional object before computations.

I expect to do a PR of this soon, but I'd love to hear your thoughts: do you think the first approach is better (more memory- and compute-efficient)? It is probably also harder to maintain and test.

nomadbl · May 23 '23 12:05

Maybe the first question should be: What comes after this layer?

In this package, I think the efficient methods are that *(::Matrix, ::OneHotMatrix) is just indexing (used by Flux.Embedding), and argmax(::OneHotMatrix) just reads the indices. Common uses like Flux.crossentropy(::Matrix, ::OneHotMatrix) don't do anything special; they use broadcasting, which uses getindex. They could specialise, but they are never the bottleneck.
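For instance, a quick illustration of both of those efficient paths (assuming the package is loaded):

```julia
using OneHotArrays

A = rand(3, 4)
x = onehotbatch([2, 2, 4], 1:4)   # 4×3 one-hot matrix
@assert A * x == A[:, [2, 2, 4]]  # * is just column gathering
@assert onecold(x) == [2, 2, 4]   # reads back the stored indices
```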

A variant of path 2 is also just to wrap PermutedDimsArray(OneHotArray(.... If the next operation is broadcasting, this should be about as good. Downstream operations could also specialise on e.g. PermutedDimsArray{..., (2,1,3), ..., <:OneHotArray} to target particular dims.
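A sketch of that wrapping idea for the matrix case (the variable names are illustrative):

```julia
using OneHotArrays

x  = onehotbatch([1, 1, 2], 1:4)   # 4×3, hot along the first axis
xt = PermutedDimsArray(x, (2, 1))  # lazy 3×4 view, hot along the second axis
@assert size(xt) == (3, 4)
@assert xt[2, 1] == x[1, 2]        # same elements, permuted indexing
```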

mcabbott · May 23 '23 16:05

I expect this layer to be followed by *(...) primarily. I'm not sure exactly what you mean by broadcasting; could you link to a function definition or explain?

I'll try both ways since I'm close to done on path 1, and path 2 seems simple at first glance.

In terms of tests for correctness I'm assuming to use the Flux and OneHotArrays tests as a first step, and add tests as necessary. Do you have some kind of benchmark that would be useful to compare path 1 vs 2 in terms of speed?
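In case it helps, one simple micro-benchmark shape would be to time the matmul each path has to serve (using BenchmarkTools is my assumption here, not an official suite from this repo):

```julia
using BenchmarkTools, OneHotArrays

A = rand(Float32, 128, 1000)
x = onehotbatch(rand(1:1000, 64), 1:1000)  # 1000×64 one-hot matrix
@btime $A * $x;          # specialised path: column gathering
@btime $A * Matrix($x);  # dense fallback, for comparison
```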

Thanks for the help :)

nomadbl · May 23 '23 21:05

> I expect this layer to be followed by *(...) primarily.

Since * is only for matrices & vectors, perhaps you just want transpose(onehotbatch([1,1,2], 1:4))? That doesn't specialise but it could:

julia> @which rand(3,4) * onehotbatch([1,1,2], 1:4)
*(A::AbstractMatrix, B::Union{OneHotArray{var"#s13", 1, var"N+1", I}, Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{var"#s13", <:Any, <:Any, I}}} where {var"#s13", var"N+1", I})
     @ OneHotArrays ~/.julia/packages/OneHotArrays/T3yiq/src/linalg.jl:7

julia> @which transpose(onehotbatch([1,1,2], 1:4)) * rand(4,3)
*(A::AbstractMatrix, B::AbstractMatrix)
     @ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:108
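Such a specialisation might look roughly like this (a sketch only; the helper below is hypothetical, not part of OneHotArrays.jl, and relies on onecold recovering the hot positions when the labels are 1:n):

```julia
using LinearAlgebra, OneHotArrays

# Hypothetical helper: transpose(one-hot) * B is just row gathering.
function onehot_t_mul(xt::Transpose{Bool,<:OneHotMatrix}, B::AbstractMatrix)
    x = parent(xt)
    return B[onecold(x), :]  # pick row indices[i] of B for each column i of x
end

x = onehotbatch([1, 1, 2], 1:4)
B = rand(4, 3)
@assert onehot_t_mul(transpose(x), B) ≈ transpose(x) * B
```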

> I'm not sure exactly what you mean by broadcasting

Operations like .* are broadcasting; see e.g. the broadcasting section of the Julia manual.
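Concretely, the broadcast that e.g. a cross-entropy does looks like this (illustrative numbers):

```julia
using OneHotArrays

y = onehotbatch([1, 2], 1:3)     # 3×2 one-hot targets
p = [0.7 0.1; 0.2 0.8; 0.1 0.1]  # e.g. softmax outputs, columns sum to 1
loss = -sum(y .* log.(p))        # y is read elementwise via getindex
@assert loss ≈ -(log(0.7) + log(0.8))
```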

mcabbott · May 24 '23 01:05