OneHotArrays.jl
WIP: Add axis permutedims
This pull request adds a new feature to OneHotArray: the axis along which the vectors are one-hot can be chosen at construction time. This is done through a new constructor that takes the axis as an extra argument and returns a OneHotArray wrapped in a PermutedDimsArray.
This approach does not require invasive changes to the existing code, which could be hard to maintain and debug.
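The wrapping approach can be illustrated without the package itself. The sketch below is a minimal example, assuming a plain BitMatrix as a stand-in for the OneHotArray returned by onehotbatch; the helper onehot_matrix is hypothetical. PermutedDimsArray permutes indices lazily, so the one-hot axis moves from dimension 1 to dimension 2 without copying any data.

```julia
# Hypothetical stand-in for `onehotbatch`: a dense Bool matrix whose columns are one-hot.
function onehot_matrix(data::AbstractVector, labels::AbstractVector)
    out = falses(length(labels), length(data))
    for (j, x) in enumerate(data)
        i = findfirst(==(x), labels)
        i === nothing && throw(ArgumentError("value $x not found in labels"))
        out[i, j] = true
    end
    return out
end

base = onehot_matrix([:b, :a, :c, :b], [:a, :b, :c])  # 3×4, one-hot along dimension 1
wrapped = PermutedDimsArray(base, (2, 1))             # 4×3 lazy view, one-hot along dimension 2

@show size(wrapped)              # (4, 3)
@show parent(wrapped) === base   # true: the wrapper shares storage with `base`
```

Each access to wrapped is forwarded to base with its indices permuted, so the wrapper adds a little work per element access while the storage stays shared.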
In my testing, the wrapped OneHotArray is around 10% slower than the unwrapped one, which I think is reasonable.
The alternative to this method would be to add an axis field to the struct and change many constructors and functions to get the desired behavior, which I've already done in a separate branch. However, it seems wiser to take the simple route first and let it go through field testing than to start with the complex route, which may have various downsides.
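For comparison, the alternative of storing the axis in the struct might look roughly like the following. AxisOneHotArray is a hypothetical, stripped-down type, not part of OneHotArrays.jl: it keeps one label index per position off the one-hot axis and records which dimension is one-hot. Even this minimal version needs its own size and getindex logic, and every other method that assumes the one-hot axis is dimension 1 would need similar changes.

```julia
# Hypothetical alternative: the one-hot axis is stored as a field instead of using a wrapper.
struct AxisOneHotArray{N, M, I <: AbstractArray{<:Integer, M}} <: AbstractArray{Bool, N}
    indices::I    # label index for each position off the one-hot axis (M == N - 1)
    nlabels::Int  # length of the one-hot axis
    axis::Int     # which dimension holds the one-hot vectors
end

function AxisOneHotArray(indices::AbstractArray{<:Integer, M}, nlabels::Integer, axis::Integer) where {M}
    1 <= axis <= M + 1 || throw(ArgumentError("axis must be in 1:$(M + 1), got $axis"))
    return AxisOneHotArray{M + 1, M, typeof(indices)}(indices, nlabels, axis)
end

# Insert the label axis into the size of `indices` at position `axis`.
function Base.size(A::AxisOneHotArray)
    s = size(A.indices)
    return (s[1:A.axis - 1]..., A.nlabels, s[A.axis:end]...)
end

# An element is true when its coordinate along `axis` equals the stored label index.
function Base.getindex(A::AxisOneHotArray{N}, I::Vararg{Int, N}) where {N}
    @boundscheck checkbounds(A, I...)
    rest = (I[1:A.axis - 1]..., I[A.axis + 1:end]...)
    return A.indices[rest...] == I[A.axis]
end
```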
See also #35 for further discussion.
PR Checklist
- [x] Tests are added
- [x] Documentation, if applicable
Codecov Report
Patch coverage: 100.00% and project coverage change: +0.10 :tada:
Comparison is base (469b192) 96.37% compared to head (1f6599c) 96.47%.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main      #36      +/-   ##
==========================================
+ Coverage   96.37%   96.47%   +0.10%
==========================================
  Files           3        4       +1
  Lines         138      142       +4
==========================================
+ Hits          133      137       +4
  Misses          5        5
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/array.jl | 95.31% <100.00%> (+0.23%) | :arrow_up: |
@mcabbott I've moved the functionality into onehotbatch as you suggested.
I'm struggling to get the inferred return type of onehotbatch right: it is currently inferred as either a PermutedDimsArray or a OneHotArray.
If you have suggestions, I'll implement them.
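The inference problem can be reproduced in a few lines. In this sketch, wrap_runtime and wrap_static are hypothetical helpers and a BitMatrix stands in for the one-hot output. When the axis is a runtime Int, the compiler has to allow for both branches, so the inferred return type is not concrete; when the axis is carried in a Val type parameter, D == 1 is known at compile time and the unused branch is discarded.

```julia
# Axis chosen by a runtime Int: the compiler cannot know which branch is taken.
wrap_runtime(x::AbstractMatrix, dims::Int) = dims == 1 ? x : PermutedDimsArray(x, (2, 1))

# Axis carried in the type: `D == 1` is a compile-time constant, so the dead branch is dropped.
wrap_static(x::AbstractMatrix, ::Val{D}) where {D} = D == 1 ? x : PermutedDimsArray(x, (2, 1))

# Ask inference which return types are possible for concrete argument types.
@show Base.return_types(wrap_runtime, (BitMatrix, Int))
@show Base.return_types(wrap_static, (BitMatrix, Val{1}))
```

Base.return_types is an unexported reflection utility, so its printed output may vary between Julia versions; @code_warntype from InteractiveUtils gives a similar view.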
The following is type-stable:
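```julia
using OneHotArrays

# Type-stable variant: the target axis D is a compile-time constant carried by Val{D}.
function ohaxis(data::AbstractArray{<:Any, N}, labels, dims::Val{D} = Val(1)) where {N, D}
    out = onehotbatch(data, labels)  # one-hot along dimension 1, ndims(out) == N + 1
    if D == 1
        return out
    else
        # Move dimension 1 (the labels) to position D, keeping the data dimensions in order.
        perm = (ntuple(i -> i + 1, D - 1)..., 1, ntuple(i -> i + D, N + 1 - D)...)
        return PermutedDimsArray(out, perm)
    end
end
```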
We can offer the Val-based method for people who need type-stability, while still offering a more user-friendly keyword interface.
This issue comes up in Base too, where cat([1], [2]; dims=3) isn't type-stable, but it also accepts dims=Val(3) if required.
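The pattern of a Val-based method for type-stability plus a keyword convenience layer can be sketched generically. In this sketch move_onehot_axis is a hypothetical name, and the input is assumed to be one-hot along dimension 1, as the output of onehotbatch is; the permutation moves that axis to position D and keeps the remaining dimensions in their original order.

```julia
# Type-stable core: the target axis D is a compile-time constant carried by Val{D}.
# Lazily moves dimension 1 of `A` (assumed to be the one-hot axis) to position D.
function move_onehot_axis(A::AbstractArray{T, M}, ::Val{D}) where {T, M, D}
    1 <= D <= M || throw(ArgumentError("dims must be in 1:$M, got $D"))
    D == 1 && return A
    perm = (ntuple(i -> i + 1, D - 1)..., 1, ntuple(i -> i + D, M - D)...)
    return PermutedDimsArray(A, perm)
end

# Convenience front end: accepts dims = 2 as well as dims = Val(2).
# A plain Integer is converted to a Val at runtime, so only the Val form is type-stable.
move_onehot_axis(A::AbstractArray; dims::Union{Integer, Val} = 1) =
    move_onehot_axis(A, dims isa Val ? dims : Val(Int(dims)))
```

Calling move_onehot_axis(A; dims = Val(2)) lets the compiler see the axis, while dims = 2 is converted to Val(2) at runtime and therefore goes through dynamic dispatch, the same trade-off described for cat.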
I'll get around to it soon. Thanks for the suggestion! :)
Fixed! :) Thank you for your support; this should do it. I also had to add another onehotbatch method:
```julia
onehotbatch(data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}) = onehotbatch(collect(data), labels)
```
Not sure why, but CUDA tests were failing without it.
@darsnack Can I ask for your review as well?
While working on @mcabbott's suggestions, I seem to have found some errors, so I'm marking this as WIP for now. Please do not merge this.