AxisIndices.jl
Accessing "coordinates" or enumerating elements
I frequently find myself extracting the "coordinates" of points like this:
julia> A = AxisArray(reshape(1:12, (3,4)), 0.1:0.1:0.3, 0.01:0.01:0.04)
3×4 AxisArray{Int64,2}
• dim_1 - 0.1:0.1:0.3
• dim_2 - 0.01:0.01:0.04
        0.01   0.02   0.03   0.04
  0.1      1      4      7     10
  0.2      2      5      8     11
  0.3      3      6      9     12
julia> Iterators.product(keys.(axes(A))...) |> collect
3×4 Matrix{Tuple{Float64,Float64}}:
(0.1, 0.01) (0.1, 0.02) (0.1, 0.03) (0.1, 0.04)
(0.2, 0.01) (0.2, 0.02) (0.2, 0.03) (0.2, 0.04)
(0.3, 0.01) (0.3, 0.02) (0.3, 0.03) (0.3, 0.04)
Is there a simpler way to do this? It would also be nice to have pairs(A) dispatch so that it yields these coordinates rather than the linear indices, perhaps by having IndexStyle(A) return something different.
(I'm happy to implement this if you think it'd be useful and can tell me the interface you'd like)
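As a rough sketch of the requested behavior, assuming the per-dimension keys are available as a tuple of key vectors (which is what axes_keys(A) returns), a lazy coordinate => value iterator can be built from Base alone. The name coordpairs is hypothetical and not part of AxisIndices; a plain reshaped range stands in for the AxisArray.

```julia
# Hypothetical helper: pair each element of `A` with its key coordinates.
# `ks` is a tuple holding one key vector per dimension, in dimension order.
function coordpairs(ks::Tuple, A::AbstractArray)
    # Each key vector must match the length of its dimension.
    length.(ks) == size(A) || throw(DimensionMismatch("keys do not match size(A)"))
    # Iterators.product varies its first argument fastest, which matches
    # Julia's column-major iteration order over A, so zip lines them up.
    return (coord => val for (coord, val) in zip(Iterators.product(ks...), A))
end

A = reshape(1:12, (3, 4))
ks = (0.1:0.1:0.3, 0.01:0.01:0.04)

# Print the first column: coordinates along dim_1 with dim_2 fixed at 0.01.
for (coord, val) in Iterators.take(coordpairs(ks, A), 3)
    println(coord, " => ", val)
end
```

Because the generator inherits the array's shape, collect(coordpairs(ks, A)) returns a 3×4 matrix of pairs rather than a flat vector.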
The closest thing I have right now is this.
julia> axes_keys(AxisArray(reshape(1:12, (3,4)), 0.1:0.1:0.3, 0.01:0.01:0.04))
(0.1:0.1:0.3, 0.01:0.01:0.04)
But you'd still need to splat the result into Iterators.product and collect it, i.e. collect(Iterators.product(axes_keys(A)...)). Base currently has a semi-formal convention that keys(A) on a multidimensional array returns CartesianIndices(axes(A)). I think the trickiest part will be settling on syntax that doesn't seem counterintuitive relative to how things work in Base.
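For concreteness, here is the splatted form alongside Base's keys convention. A plain matrix and literal ranges stand in for an AxisArray and its axes_keys. The helper coord_of is hypothetical: it maps a CartesianIndex (what keys(A) yields) to the corresponding key tuple, i.e. an index => key mapping.

```julia
A = reshape(1:12, (3, 4))
# Stand-in for axes_keys(A): one key vector per dimension.
ks = (0.1:0.1:0.3, 0.01:0.01:0.04)

# The tuple of keys must be splatted into Iterators.product.
coords = collect(Iterators.product(ks...))   # 3×4 matrix of (row_key, col_key)

# Base convention: keys of a multidimensional array are CartesianIndices.
inds = keys(A)

# Hypothetical helper: translate a Cartesian index into its key tuple
# by indexing each dimension's keys with the matching index component.
coord_of(ks::Tuple, I::CartesianIndex) = map(getindex, ks, Tuple(I))

coord_of(ks, CartesianIndex(2, 3))           # key pair for row 2, column 3
```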
It almost seems like this calls for an index => key mapping rather than a key => index one. It might be worth fleshing this out with something like KeyStyle or KeyToIndexStyle traits. Alternatively, a single simple method that does just this one thing might be enough. If you have an immediate need for this, you probably have a better feel for what makes sense. I'd be happy to see what ideas you have.
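One way to read the KeyStyle suggestion is as a Holy trait (the same pattern Base uses for IndexStyle) that decides whether an array's coordinates are its keys or its Cartesian indices. This sketch is hypothetical throughout: KeyStyle, HasKeys, NoKeys, KeyedMatrix and coordinates are illustrative names, not AxisIndices API, and KeyedMatrix is only a minimal stand-in for an AxisArray.

```julia
# Hypothetical trait: does an array type carry per-dimension keys?
abstract type KeyStyle end
struct HasKeys <: KeyStyle end
struct NoKeys <: KeyStyle end

KeyStyle(::Type{<:AbstractArray}) = NoKeys()   # default for plain arrays

# Minimal stand-in for an AxisArray: a parent matrix plus a tuple of key vectors.
struct KeyedMatrix{T,K<:Tuple} <: AbstractMatrix{T}
    parent::Matrix{T}
    axkeys::K
end
Base.size(A::KeyedMatrix) = size(A.parent)
Base.getindex(A::KeyedMatrix, i::Int, j::Int) = A.parent[i, j]
KeyStyle(::Type{<:KeyedMatrix}) = HasKeys()

# Single entry point that dispatches on the trait.
coordinates(A::AbstractArray) = coordinates(KeyStyle(typeof(A)), A)
coordinates(::NoKeys, A) = keys(A)                            # Base convention
coordinates(::HasKeys, A) = Iterators.product(A.axkeys...)    # key tuples

M = reshape(collect(1:12), 3, 4)
K = KeyedMatrix(M, (0.1:0.1:0.3, 0.01:0.01:0.04))
collect(coordinates(M))   # matrix of CartesianIndex values
collect(coordinates(K))   # matrix of (row_key, col_key) tuples
```

The trait keeps plain arrays on Base's keys(A) behavior, so only types that opt in with HasKeys change what "coordinates" means.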
I ran into this again in another case: for a NamedAxisArray, I'd ideally like something like enumerate(::NamedAxisArray) -> (::NamedTuple, ::T). I'll post my current solution in a second; this is the third time I've written it, and I always forget to reply to this issue with it :/
Here's my solution. Even ignoring the lack of generality, I'm 100% sure there's a cleaner way to do this. I'm only posting this here as a hack for the next time I run into this.
EDIT: Is the ordering of axes_keys guaranteed to match the NAMES of NamedAxisArray{NAMES}?
using AxisIndices  # provides NamedAxisArray and axes_keys

# Wrap a tuple of keys in a NamedTuple whose field names are the dimension names X
nt_coord(naa::NamedAxisArray{X}, tup::Tuple) where X = NamedTuple{X}(tup)

# Evaluate ex; if it is `nothing`, return `nothing` from the enclosing function
macro ifsomething(ex)
    quote
        result = $(esc(ex))
        result === nothing && return nothing
        result
    end
end

# Iterator yielding (coordinates::NamedTuple, value) for each element
struct EnumerateNAA{NAA,II}
    naa::NAA
    inner_iter::II
    EnumerateNAA(naa::NAA, ii::II) where {NAA,II} = new{NAA,II}(naa, ii)
end

# Zip key tuples with values; both iterate in column-major order
EnumerateNAA(naa) = EnumerateNAA(naa, zip(Iterators.product(axes_keys(naa)...), naa))
Base.length(e::EnumerateNAA) = length(e.naa)

enumerate_nt(naa::NamedAxisArray) = EnumerateNAA(naa)

# `state...` lets one method handle both the first call and later calls
function Base.iterate(enaa::EnumerateNAA, state...)
    ((tup, val), next_state) = @ifsomething iterate(enaa.inner_iter, state...)
    nt = nt_coord(enaa.naa, tup)
    return ((nt, val), next_state)
end
On whether axes_keys and NAMES of NamedAxisArray{NAMES} are ordered the same: dimension names will always be in the same order as the axes.
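Since dimension names follow the axis order, the custom iterator type can likely be replaced by a plain generator. This sketch uses only Base: names and ks are stand-ins for the NAMES parameter and axes_keys(naa), and enumerate_named is a hypothetical name. Under the assumption that a NamedAxisArray iterates its values like any other array, passing it as A should behave the same way.

```julia
# Hypothetical generator-based alternative to EnumerateNAA.
# `names` holds the dimension names and `ks` one key vector per dimension,
# both in axis order.
function enumerate_named(names::NTuple{N,Symbol}, ks::NTuple{N,Any}, A::AbstractArray) where {N}
    size(A) == map(length, ks) || throw(DimensionMismatch("keys do not match size(A)"))
    # Iterators.product varies the first key fastest, matching A's column-major order.
    return ((NamedTuple{names}(coord), val) for (coord, val) in zip(Iterators.product(ks...), A))
end

A = reshape(1:12, (3, 4))
for (nt, val) in Iterators.take(enumerate_named((:x, :y), (0.1:0.1:0.3, 0.01:0.01:0.04), A), 2)
    println(nt, " => ", val)
end
```

The generator gets length and shape from zip for free, so no Base.length or Base.iterate methods need to be written by hand.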