OneHotArrays.jl

WIP: Add axis permutedims

Open • nomadbl opened this pull request • 8 comments

This pull request introduces a new feature to OneHotArray: the axis along which the vectors are one-hot can be chosen at construction time. This is done through a new constructor that takes the axis as an extra argument and returns a OneHotArray wrapped in a PermutedDimsArray.

Implemented this way, the feature avoids invasive code changes that may be hard to maintain and debug. In my testing, the wrapped OneHotArray is about 10% slower, which I think is reasonable. The alternative is to add an axis field to the struct and change many constructors and functions to get the desired behavior; I've already done that in a separate branch. However, it seems better to take the simple route first and let it go through field testing, rather than start with the complex route, which may have various downsides.
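To illustrate the wrapping approach, here is a minimal sketch in plain Julia. It does not use OneHotArrays.jl: `dense_onehotbatch` is a hypothetical stand-in that builds a dense `BitMatrix` that is one-hot along dim 1, and Base's `PermutedDimsArray` then presents it lazily with the one-hot axis moved to dim 2, without copying the data.

```julia
# Hypothetical stand-in for OneHotArrays.onehotbatch (not the package's code):
# returns a length(labels) × length(data) BitMatrix, one-hot along dim 1.
function dense_onehotbatch(data::AbstractVector, labels::AbstractVector)
    out = falses(length(labels), length(data))
    for (j, x) in enumerate(data)
        i = findfirst(isequal(x), labels)
        i === nothing && throw(ArgumentError("value $x not found in labels"))
        out[i, j] = true
    end
    return out
end

oh = dense_onehotbatch([:b, :a, :c, :b], [:a, :b, :c])  # 3×4, one-hot along dim 1

# Lazy wrapper: dim 1 of the result is dim 2 of `oh` and vice versa,
# so each row of `ohT` is a one-hot vector. No data is copied.
ohT = PermutedDimsArray(oh, (2, 1))                     # 4×3, one-hot along dim 2
```

Every index into `ohT` is remapped onto `oh`'s layout, so this design trades a small per-access cost for leaving the underlying array type untouched.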

See also #35 for further discussion

PR Checklist

  • ☑ Tests are added
  • ☑ Documentation, if applicable

nomadbl, May 25 '23 12:05

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.10 🎉

Comparison is base (469b192) 96.37% compared to head (1f6599c) 96.47%.


@@            Coverage Diff             @@
##             main      #36      +/-   ##
==========================================
+ Coverage   96.37%   96.47%   +0.10%     
==========================================
  Files           3        4       +1     
  Lines         138      142       +4     
==========================================
+ Hits          133      137       +4     
  Misses          5        5              
Impacted Files    Coverage Δ
src/array.jl      95.31% <100.00%> (+0.23%) ⬆️

... and 1 file with indirect coverage changes


codecov-commenter, May 25 '23 12:05

@mcabbott I've moved the functionality into onehotbatch as you suggested. I'm struggling to get the output type of onehotbatch inferred correctly: it is currently either a PermutedDimsArray or a OneHotArray, depending on the axis. If you have suggestions, I'll implement them.

nomadbl, May 26 '23 09:05

The following is type-stable:

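using OneHotArrays

function ohaxis(data::AbstractArray{<:Any, N}, labels, dims::Val{D} = Val(1)) where {N, D}
    out = onehotbatch(data, labels)  # one-hot along dim 1, ndims(out) == N + 1
    if D == 1
        return out
    else
        # PermutedDimsArray puts dim perm[k] of `out` at dim k, so perm[D] == 1
        # moves the one-hot axis to position D.
        perm = (ntuple(i -> i + 1, D - 1)..., 1, ntuple(i -> i + D, N + 1 - D)...)
        return PermutedDimsArray(out, perm)
    end
end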

We can offer the Val-based method for people who need type stability, while still offering a more user-friendly keyword interface.
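A minimal sketch of that split in plain Julia, assuming the input array is one-hot along dim 1 (as `onehotbatch` produces). `onehot_perm` and `move_onehot_axis` are hypothetical names, not part of OneHotArrays.jl. The `Val` method carries the target axis `D` in its type, so the compiler knows whether it returns the array itself or a `PermutedDimsArray`; the keyword method converts an integer `dims` to `Val` at run time and therefore gives up that inference.

```julia
# Hypothetical helper: permutation for PermutedDimsArray that moves dim 1 of an
# M-dimensional array to position D. PermutedDimsArray(A, perm) places dim
# perm[k] of A at dim k of the result, so we need perm[D] == 1.
onehot_perm(::Val{D}, ::Val{M}) where {D, M} =
    (ntuple(i -> i + 1, Val(D - 1))..., 1, ntuple(i -> i + D, Val(M - D))...)

# Val method: D is a type parameter, so the return type is inferable.
function move_onehot_axis(A::AbstractArray{<:Any, M}, ::Val{D}) where {M, D}
    1 <= D <= M || throw(ArgumentError("dims must be in 1:$M, got $D"))
    D == 1 && return A
    return PermutedDimsArray(A, onehot_perm(Val(D), Val(M)))
end

# Keyword method: friendlier, but an Int `dims` is only known at run time,
# so the result type is a Union unless the caller passes `dims = Val(D)`.
move_onehot_axis(A::AbstractArray; dims::Union{Integer, Val} = 1) =
    move_onehot_axis(A, dims isa Val ? dims : Val(Int(dims)))
```

With this split, `move_onehot_axis(A; dims = 2)` is the convenient call, while `move_onehot_axis(A, Val(2))` is the one to reach for in performance-sensitive code.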

darsnack, Jul 23 '23 13:07

This issue comes up in Base too, where cat([1], [2]; dims=3) isn't type-stable, but it also accepts dims=Val(3) if required.
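The Base behaviour can be checked directly: both calls below build the same 1×1×2 array, but with `dims = Val(3)` the concatenation dimension is part of the argument's type, so the dimensionality of the result can be determined at compile time rather than from a run-time integer.

```julia
a = cat([1], [2]; dims = 3)       # dims passed as a run-time Int
b = cat([1], [2]; dims = Val(3))  # dims carried in the type
```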

mcabbott, Jul 24 '23 00:07

I'll get around to it soon. Thanks for the suggestion! :)

nomadbl, Jul 24 '23 08:07

Fixed! :) Thank you for your support; this should do it. I also had to add another method:

onehotbatch(data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}) = onehotbatch(collect(data), labels)

Not sure why, but CUDA tests were failing without it.
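The added method follows a common Julia dispatch pattern: a more specific method for ranges materializes the input with `collect` and forwards to the general method. A minimal sketch of that pattern, using a hypothetical `encode` function in place of `onehotbatch` (the cause of the CUDA failures is not known from this thread, and the sketch does not address it):

```julia
# Hypothetical general method: maps each value to its position in `labels`.
encode(data::AbstractVector{<:Integer}, labels::AbstractUnitRange{<:Integer}) =
    [findfirst(==(x), labels) for x in data]

# More specific method for ranges (AbstractRange <: AbstractVector), so dispatch
# picks it for range inputs: materialize into a Vector, then forward.
encode(data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}) =
    encode(collect(data), labels)
```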

nomadbl, Jul 26 '23 08:07

@darsnack Can I ask for your review as well?

nomadbl, Jul 26 '23 08:07

While working on @mcabbott's suggestions, I seem to have found errors, so I'm marking this as WIP for now. Do not merge this.

nomadbl, Jul 26 '23 12:07