AxisKeys.jl
Keys can be lost in `mapreduce(..., cat, ...)`
I want to slice a 3D array into matrices, multiply each matrix by another matrix, and then cat the result back into a 3D array.
I do it this way because the dimension shared by the matrices has non-overlapping keys, and I want to find the overlapping keys (those with non-missing values) for each slice.
The problem is that the final concatenated array doesn't preserve the axiskeys of the dimension I sliced over; they default to `OneTo`.
MWE:
```julia
julia> KA1 = KeyedArray(ones(2, 3), w=['a', 'b'], x=[:a, :b, :c]);

julia> KA2 = KeyedArray(ones(3, 4, 2), x=[:a, :b, :c], y=0.:3., z=["foo", "bar"]);

julia> mapreduce((x, y) -> cat(x, y; dims=:z), axiskeys(KA2, :z)) do z
           KA2_slice = KA2(z=z)
           return KA1 * KA2_slice
       end
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓   w ∈ 2-element Vector{Char}
→   y ∈ 4-element StepRangeLen{Float64,...}
◪   z ∈ 2-element OneTo{Int}
And data, 2×4×2 Array{Float64, 3}:
[:, :, 1] ~ (:, :, 1):
        (0.0)  (1.0)  (2.0)  (3.0)
 ('a')   3.0    3.0    3.0    3.0
 ('b')   3.0    3.0    3.0    3.0

[:, :, 2] ~ (:, :, 2):
        (0.0)  (1.0)  (2.0)  (3.0)
 ('a')   3.0    3.0    3.0    3.0
 ('b')   3.0    3.0    3.0    3.0
```
I tried using an `Interval` at the `KA2_slice = KA2(z=z)` step, but Julia doesn't seem to support multiplying arrays with different numbers of dimensions, even when the extra one is a trailing singleton dimension:
```julia
julia> ones(2, 3) * ones(3, 2, 1)
ERROR: MethodError: no method matching *(::Matrix{Float64}, ::Array{Float64, 3})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  *(::StridedMatrix{T}, ::StridedVector{S}) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:Real} at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:44
  *(::StridedMatrix{var"#s832"} where var"#s832"<:Union{Float32, Float64}, ::StridedMatrix{var"#s831"} where var"#s831"<:Union{Float32, Float64, ComplexF32, ComplexF64}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:158
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[21]:1
```
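For plain arrays, the `MethodError` arises because `*` is only defined for vectors and matrices, so a 3-D argument has no method even if its last dimension has length 1. A minimal sketch using only `Base` (no keys involved): drop the trailing singleton with `dropdims`, multiply, and `reshape` the result back if a third dimension is needed. On its own this does nothing to preserve keys.

```julia
A = ones(2, 3)
B = ones(3, 2, 1)                 # 3-D array with a trailing singleton dimension

# `*` is defined for matrices, so drop the singleton third dimension first...
C = A * dropdims(B; dims=3)       # 2×2 Matrix{Float64}

# ...and restore it afterwards if downstream code expects three dimensions.
C3 = reshape(C, size(C)..., 1)    # 2×2×1 Array{Float64, 3}
```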
My current workaround is to use `wrapdims` afterward to re-key the array.
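A minimal sketch of that kind of re-keying, assuming the MWE's `KA1` and `KA2`: run the same `mapreduce`, then rebuild a `KeyedArray` from the plain data with keys copied from the inputs. This sketch uses the `KeyedArray(data; name=keys...)` constructor from the MWE rather than `wrapdims`, and `Array(res)` only to strip the wrappers. The lambda arguments are renamed to `(a, b)` just to avoid confusion with dimension names. Nothing checks that the hand-copied keys actually line up with the data.

```julia
using AxisKeys

KA1 = KeyedArray(ones(2, 3), w=['a', 'b'], x=[:a, :b, :c])
KA2 = KeyedArray(ones(3, 4, 2), x=[:a, :b, :c], y=0.:3., z=["foo", "bar"])

# Slice along z, multiply each slice, concatenate along a dimension named :z.
# As in the MWE, the keys of :z come back as OneTo.
res = mapreduce((a, b) -> cat(a, b; dims=:z), axiskeys(KA2, :z)) do z
    KA1 * KA2(z=z)
end

# Re-key: copy the plain data out and attach keys taken from the inputs.
rekeyed = KeyedArray(Array(res);
    w=axiskeys(KA1, :w), y=axiskeys(KA2, :y), z=axiskeys(KA2, :z))

slice_bar = rekeyed(z="bar")  # lookup by the original z key works again
```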
This is tricky. My first thought is that, setting this package aside, you probably want to reshape and call `*` once instead of making slices. One package that wraps this neatly is:
```julia
julia> using TensorCore

julia> KA1 ⊡ KA2
2×4×2 Array{Float64, 3}:
[:, :, 1] =
 3.0  3.0  3.0  3.0
 3.0  3.0  3.0  3.0

[:, :, 2] =
 3.0  3.0  3.0  3.0
 3.0  3.0  3.0  3.0
```
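Spelled out by hand, the reshape-then-`*` idea looks roughly like the sketch below. It assumes the MWE's arrays, that the shared dimension `x` is the last dimension of `KA1` and the first of `KA2`, and that both carry identical `x` keys; the sketch only asserts this and does not align non-overlapping keys. Keys for the surviving dimensions are attached by hand afterwards.

```julia
using AxisKeys

KA1 = KeyedArray(ones(2, 3), w=['a', 'b'], x=[:a, :b, :c])
KA2 = KeyedArray(ones(3, 4, 2), x=[:a, :b, :c], y=0.:3., z=["foo", "bar"])

# Assumption: x is dim 2 of KA1 and dim 1 of KA2, with identical keys.
@assert axiskeys(KA1, :x) == axiskeys(KA2, :x)

# Flatten (y, z) into one column dimension, multiply once, restore the shape.
nx, ny, nz = size(KA2)
data = reshape(Array(KA1) * reshape(Array(KA2), nx, ny * nz), size(KA1, 1), ny, nz)

# Attach keys for the dimensions that survive the contraction.
contracted = KeyedArray(data;
    w=axiskeys(KA1, :w), y=axiskeys(KA2, :y), z=axiskeys(KA2, :z))
```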
These packages are unaware of each other, but the reshaping done there uses `axes`, and thus with #6 it almost succeeds:
```julia
julia> KA1 ⊡ KA2
3-dimensional KeyedArray(...) with keys:
↓   2-element Vector{Char}
→   4-element StepRangeLen{Float64,...}
◪   2-element Vector{String}
And data, 2×4×2 reshape(::NamedDimsArray{(:w, :_), Float64, 2, Matrix{Float64}}, 2, 4, 2) with eltype Float64:
[:, :, 1] ~ (:, :, "foo"):
         →   _
 ↓   w     (0.0)  (1.0)  (2.0)  (3.0)
   ('a')    3.0    3.0    3.0    3.0
   ('b')    3.0    3.0    3.0    3.0

[:, :, 2] ~ (:, :, "bar"):
         →   _
 ↓   w     (0.0)  (1.0)  (2.0)  (3.0)
   ('a')    3.0    3.0    3.0    3.0
   ('b')    3.0    3.0    3.0    3.0
```
The second thought is that, regardless of where slices come from, it would be nice to propagate properties better. Instead of `cat(xs...; dims=:z)`, you almost want `cat(xs...; z = KA2.z)`? Not sure that can be done. The closest thing that works now is to wrap a comprehension and call `stack`, which has methods for keys etc.:
```julia
julia> [KA1 * KA2(z=z) for z in KA2.z]  # KA2.z === axiskeys(KA2, :z)
2-element Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}:
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]

julia> comp = [KA1 * KA2[z=i] for i in axes(KA2, :z)]
2-element Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}:
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]

julia> using LazyStack  # a package AxisKeys knows about

julia> stack(comp)  # 3rd axis still _ ∈ 2-element OneTo
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓   w ∈ 2-element Vector{Char}
→   y ∈ 4-element StepRangeLen{Float64,...}
◪   _ ∈ 2-element OneTo{Int}
And data, 2×4×2 stack(::Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}) with eltype Float64:
[:, :, 1] ~ (:, :, 1):
...

julia> stack(wrapdims(comp, z=KA2.z))
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓   w ∈ 2-element Vector{Char}
→   y ∈ 4-element StepRangeLen{Float64,...}
◪   z ∈ 2-element Vector{String}
And data, 2×4×2 stack(::Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}) with eltype Float64:
[:, :, 1] ~ (:, :, "foo"):
```
Again thinking about #6, I think `comp` could plausibly be made to wrap automatically like this, since #6 makes `axes(KA2, :z)` a special type.
Thanks for the suggestions!
To clarify, is there currently no support (on `master` in any package) for tensor multiplication that preserves axiskeys? And is #6 (or something similar) your preferred solution?