OneHotArrays.jl

variable axis on which the "one hot" property holds

Open nomadbl opened this issue 1 year ago • 3 comments

Motivation and description

I am working on a layer that produces one-hot outputs, so I am looking into using OneHotArrays.jl. My gripe is that the data type currently only supports one-hot vectors that extend along the first axis.
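To make the request concrete, here is a dense illustration in plain Julia (Base only, not OneHotArrays.jl) of the same labels encoded with the one-hot dimension first, as the package does today, and second, as requested. The variable names are only for illustration.

```julia
labels = [2, 1, 3]   # one class label per sample
L = 4                # number of classes (length of the one-hot dimension)

# Current layout: the one-hot dimension is axis 1, so each column is a one-hot vector.
A1 = [labels[j] == i for i in 1:L, j in eachindex(labels)]   # size (4, 3)

# Requested layout: the one-hot dimension is axis 2, so each row is a one-hot vector.
A2 = [labels[j] == i for j in eachindex(labels), i in 1:L]   # size (3, 4)
```

For a matrix the two layouts are just transposes of each other; with more dimensions the one-hot axis can sit at any of the N+1 positions, which is what the `axis` field proposed below is meant to capture.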

I thought I'd write up my thoughts and possible implementations of a variable axis to get feedback and context from maintainers and users here (I am very new to Julia and Flux, coming from working with Python).

Possible Implementation

Implementation path 1 (WIP): change the constructors, size, and getindex.

```julia
# `axis` records which of the N+1 dimensions carries the one-hot labels
struct OneHotArray{T<:Integer,N,var"N+1",I<:Union{T,AbstractArray{T,N}}} <: AbstractArray{Bool,var"N+1"}
  indices::I    # stored label for every position along the other N dimensions
  nlabels::Int  # length of the one-hot dimension
  axis::Int     # position of the one-hot dimension (1 reproduces the current layout)
end
OneHotArray{T,N,I}(indices, L::Int, axis::Int=1) where {T,N,I} = OneHotArray{T,N,N + 1,I}(indices, L, axis)
OneHotArray(indices::T, L::Int, axis::Int=1) where {T<:Integer} = OneHotArray{T,0,1,T}(indices, L, axis)
OneHotArray(indices::I, L::Int, axis::Int=1) where {T,N,I<:AbstractArray{T,N}} = OneHotArray{T,N,N + 1,I}(indices, L, axis)

# Splice nlabels into size(indices) at position `axis`; tuple slicing also handles N == 0
function Base.size(x::OneHotArray)
  s = size(x.indices)
  return (s[1:x.axis-1]..., x.nlabels, s[x.axis:end]...)
end

# Accepting exactly N+1 Int indices lets Base handle linear indexing and dimension checks
function Base.getindex(x::OneHotArray{<:Any,<:Any,M}, I::Vararg{Int,M}) where {M}
  @boundscheck checkbounds(x, I...)
  # Drop the one-hot coordinate (popat! would return that coordinate, not the rest)
  Ip = (I[1:x.axis-1]..., I[x.axis+1:end]...)
  return x.indices[Ip...] == I[x.axis]
end
```

The idea is to keep the sparse, index-based representation so that later operations such as optimized multiplication and backpropagation can still take advantage of it.
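A small case shows why the sparse form pays off: multiplying by a one-hot matrix only selects columns (or rows) of the other operand, so it can be computed from the stored label indices without ever building a dense Bool array. This is a plain-Julia sketch; `mul_onehot_axis1` and `mul_onehot_axis2` are hypothetical names, not functions from OneHotArrays.jl, and the dense matrices exist only to check the result.

```julia
# Hypothetical fast paths (illustrative names, not the OneHotArrays.jl API).
# They take only the label indices, never a dense Bool array.

# One-hot along axis 1: W (K×L) times X (L×B) gathers columns of W.
mul_onehot_axis1(W::AbstractMatrix, labels::AbstractVector{<:Integer}) = W[:, labels]

# One-hot along axis 2: X (B×L) times M (L×K) gathers rows of M.
mul_onehot_axis2(labels::AbstractVector{<:Integer}, M::AbstractMatrix) = M[labels, :]

# Dense references for comparison
L, B, K = 4, 3, 2
labels = [2, 4, 1]
X1 = [labels[j] == i for i in 1:L, j in 1:B]   # (L, B): one-hot along axis 1
X2 = permutedims(X1)                           # (B, L): one-hot along axis 2
W = reshape(collect(1.0:K*L), K, L)            # (K, L)
M = reshape(collect(1.0:L*K), L, K)            # (L, K)
```

The pullback follows the same pattern: the gradient with respect to W accumulates the incoming gradient's columns into the columns picked out by `labels`. A variable axis mostly changes which dimension is gathered from or scattered into.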

While working on this, I also hit upon path 2: reuse all of the original code, and use the new axis parameter to permute the underlying object, whose one-hot dimension is the first one (shape (nlabels, ...)), before computations.
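Path 2 could look roughly like the following sketch, assuming the existing axis-1 layout stays as the storage and Base's lazy `PermutedDimsArray` provides the requested layout. A dense `BitArray` stands in for the existing OneHotArray so the example runs with Base alone; `onehot_first`, `move_first_to`, and `onehot_along` are hypothetical helper names.

```julia
# Dense stand-in for the existing axis-1 layout: size (L, size(labels)...)
function onehot_first(labels::AbstractArray{<:Integer}, L::Int)
    A = falses(L, size(labels)...)
    for p in CartesianIndices(labels)
        A[labels[p], p] = true
    end
    return A
end

# Permutation that moves dimension 1 to position `axis`, keeping the others in order
move_first_to(axis::Int, nd::Int) = (2:axis..., 1, axis+1:nd...)

# Lazily view the axis-1 storage with the one-hot dimension at `axis`
function onehot_along(labels::AbstractArray{<:Integer}, L::Int, axis::Int)
    A = onehot_first(labels, L)
    return PermutedDimsArray(A, move_first_to(axis, ndims(A)))
end

labels = [2 1; 3 2]
P = onehot_along(labels, 3, 2)   # size (2, 3, 2); P[i, labels[i, j], j] is true
```

`PermutedDimsArray` does not copy (its `parent` is the original array), so storage is the same as in path 1. The cost shows up in compute: generic methods called on the wrapper will not reach the one-hot fast paths unless they are specialized for it, which is where an explicit `axis` field might pay off.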

I expect to open a PR for this soon, but I'd love to hear your thoughts: do you think the first approach is better (more memory- and compute-efficient)? On the other hand, it is probably harder to maintain and test.

nomadbl, May 23 '23 12:05