AxisKeys.jl

Keys can be lost in `mapreduce(..., cat, ...)`

Status: Open. bencottier opened this issue 3 years ago • 2 comments

I want to slice a 3D array into matrices, multiply each matrix by another matrix, and then cat the result back into a 3D array.

I do it this way because the dimension shared by the matrices has non-overlapping keys, and for each slice I want to find the overlapping keys (those with non-missing values).

The problem is that the keys of the dimension I sliced over are lost in the final concatenated array: that axis defaults to `OneTo`.

MWE:

julia> KA1 = KeyedArray(ones(2, 3), w=['a', 'b'], x=[:a, :b, :c]);

julia> KA2 = KeyedArray(ones(3, 4, 2), x=[:a, :b, :c], y=0.:3., z=["foo", "bar"]);

julia> mapreduce((x, y) -> cat(x, y; dims=:z), axiskeys(KA2, :z)) do z
           KA2_slice = KA2(z=z)
           return KA1 * KA2_slice
       end
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓   w ∈ 2-element Vector{Char}
→   y ∈ 4-element StepRangeLen{Float64,...}
◪   z ∈ 2-element OneTo{Int}
And data, 2×4×2 Array{Float64, 3}:
[:, :, 1] ~ (:, :, 1):
         (0.0)  (1.0)  (2.0)  (3.0)
  ('a')    3.0    3.0    3.0    3.0
  ('b')    3.0    3.0    3.0    3.0

[:, :, 2] ~ (:, :, 2):
         (0.0)  (1.0)  (2.0)  (3.0)
  ('a')    3.0    3.0    3.0    3.0
  ('b')    3.0    3.0    3.0    3.0

I tried selecting with an `Interval` at the `KA2_slice = KA2(z=z)` step, but Julia doesn't seem to support multiplying arrays with different numbers of dimensions, even when the extra dimension is a trailing singleton:

julia> ones(2, 3) * ones(3, 2, 1)
ERROR: MethodError: no method matching *(::Matrix{Float64}, ::Array{Float64, 3})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  *(::StridedMatrix{T}, ::StridedVector{S}) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}, S<:Real} at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:44
  *(::StridedMatrix{var"#s832"} where var"#s832"<:Union{Float32, Float64}, ::StridedMatrix{var"#s831"} where var"#s831"<:Union{Float32, Float64, ComplexF32, ComplexF64}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:158
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[21]:1
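For plain arrays, the usual way around this `MethodError` is to drop the singleton dimension before multiplying. A minimal sketch with ordinary `Array`s (keys left out; `A`, `B`, `C`, `C3` are illustrative names, not from the thread), using Base's `dropdims`. I haven't checked whether this route keeps keys on a `KeyedArray`:

```julia
A = ones(2, 3)
B = ones(3, 2, 1)  # a slice that kept a trailing length-1 dimension

# A * B would throw a MethodError: `*` has no method for Matrix × 3-D Array.
# dropdims removes the length-1 third dimension, leaving a 3×2 Matrix.
C = A * dropdims(B; dims=3)  # 2×2 Matrix, every entry 3.0

# If the singleton dimension is needed afterwards, it can be put back:
C3 = reshape(C, size(C)..., 1)  # 2×2×1 Array
```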

My current workaround is to use wrapdims afterward to re-key the array.
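A sketch of what that workaround might look like, assuming the `KA1`/`KA2` definitions from the MWE. Here `parent` is applied twice to unwrap the `KeyedArray` and then the inner `NamedDimsArray`, and `wrapdims` re-attaches all three names, with the `z` keys taken from `KA2`. The names `res` and `rekeyed` are illustrative:

```julia
using AxisKeys

KA1 = KeyedArray(ones(2, 3), w=['a', 'b'], x=[:a, :b, :c])
KA2 = KeyedArray(ones(3, 4, 2), x=[:a, :b, :c], y=0.:3., z=["foo", "bar"])

# Same slice-multiply-cat as the MWE; the z keys come back as OneTo.
res = mapreduce((a, b) -> cat(a, b; dims=:z), axiskeys(KA2, :z)) do z
    KA1 * KA2(z=z)
end

# Strip the wrappers and re-key, taking the z keys from the input array.
rekeyed = wrapdims(parent(parent(res)),
                   w=axiskeys(res, :w), y=axiskeys(res, :y), z=axiskeys(KA2, :z))
```

After this, `rekeyed(z="bar")` selects by the original key again.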

bencottier, Oct 08 '21 10:10

This is tricky. My first thought is that, setting this package aside, you probably want to reshape and call `*` once instead of making slices. One package that wraps this neatly is TensorCore:

julia> using TensorCore

julia> KA1 ⊡ KA2
2×4×2 Array{Float64, 3}:
[:, :, 1] =
 3.0  3.0  3.0  3.0
 3.0  3.0  3.0  3.0

[:, :, 2] =
 3.0  3.0  3.0  3.0
 3.0  3.0  3.0  3.0
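For reference, the reshape-then-multiply idea can be written with plain arrays in Base Julia. This is a sketch of the general technique (contract the last index of `A` with the first index of `B`), not necessarily how TensorCore implements `⊡`; `A`, `B`, `C` are illustrative names and the keys are dropped:

```julia
A = [1.0 2.0 3.0; 4.0 5.0 6.0]           # w × x, like KA1 without keys
B = reshape(collect(1.0:24.0), 3, 4, 2)  # x × y × z, like KA2 without keys

# Merge B's trailing (y, z) dims into one column dimension, do a single
# matrix multiply over the shared x dimension, then split (y, z) apart again.
C = reshape(A * reshape(B, size(B, 1), :), size(A, 1), size(B, 2), size(B, 3))

# Each z-slice matches the per-slice product from the MWE's approach.
all(C[:, :, k] ≈ A * B[:, :, k] for k in axes(B, 3))  # true
```

Because Julia arrays are column-major, flattening `(y, z)` and unflattening them with the same sizes keeps every column in place, so one big multiply replaces the loop over slices.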

These packages are unaware of each other, but the reshaping done in TensorCore goes through `axes`, so with #6 it almost succeeds:

julia> KA1 ⊡ KA2
3-dimensional KeyedArray(...) with keys:
↓   2-element Vector{Char}
→   4-element StepRangeLen{Float64,...}
◪   2-element Vector{String}
And data, 2×4×2 reshape(::NamedDimsArray{(:w, :_), Float64, 2, Matrix{Float64}}, 2, 4, 2) with eltype Float64:
[:, :, 1] ~ (:, :, "foo"):
     → _
↓ w          (0.0)  (1.0)  (2.0)  (3.0)
      ('a')    3.0    3.0    3.0    3.0
      ('b')    3.0    3.0    3.0    3.0

[:, :, 2] ~ (:, :, "bar"):
     → _
↓ w          (0.0)  (1.0)  (2.0)  (3.0)
      ('a')    3.0    3.0    3.0    3.0
      ('b')    3.0    3.0    3.0    3.0

My second thought is that, regardless of where the slices come from, it would be nice to propagate names and keys better. Instead of `cat(xs...; dims=:z)` you almost want `cat(xs...; z=KA2.z)`, but I'm not sure that can be done. The closest thing that works now is to wrap a comprehension (with `wrapdims`) and call `stack`, which has methods for keys etc.:

julia> [KA1 * KA2(z=z) for z in KA2.z]  # KA2.z === axiskeys(KA2, :z)
2-element Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}:
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]

julia> comp = [KA1 * KA2[z=i] for i in axes(KA2, :z)]
2-element Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}:
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]
 [3.0 3.0 3.0 3.0; 3.0 3.0 3.0 3.0]

julia> using LazyStack  # a package AxisKeys knows about

julia> stack(comp)  # 3rd axis still _ ∈ 2-element OneTo
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓   w ∈ 2-element Vector{Char}
→   y ∈ 4-element StepRangeLen{Float64,...}
◪   _ ∈ 2-element OneTo{Int}
And data, 2×4×2 stack(::Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}) with eltype Float64:
[:, :, 1] ~ (:, :, 1):
...

julia> stack(wrapdims(comp, z=KA2.z))
3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
↓   w ∈ 2-element Vector{Char}
→   y ∈ 4-element StepRangeLen{Float64,...}
◪   z ∈ 2-element Vector{String}
And data, 2×4×2 stack(::Vector{KeyedArray{Float64, 2, NamedDimsArray{(:w, :y), Float64, 2, Matrix{Float64}}, Tuple{Vector{Char}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}}) with eltype Float64:
[:, :, 1] ~ (:, :, "foo"):
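The same comprehension-plus-`stack` pattern works for plain arrays, where there are no keys to lose. A minimal sketch, assuming Julia 1.9 or later (where Base exports its own `stack`); `A`, `B`, `slices`, `C` are illustrative names:

```julia
# Requires Julia 1.9+, where `stack` is exported from Base.
A = ones(2, 3)
B = ones(3, 4, 2)

# One matrix product per slice along the third dimension...
slices = [A * B[:, :, k] for k in axes(B, 3)]  # Vector of 2×4 matrices
# ...then glue them back together along a new trailing dimension.
C = stack(slices)  # 2×4×2 Array{Float64, 3}
```

In the keyed version, it is the `wrapdims(comp, z=KA2.z)` step that carries the `z` keys through to the stacked result.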

Again thinking about #6, `comp` could plausibly be made to wrap like this automatically, since #6 makes `axes(KA2, :z)` a special type.

mcabbott, Oct 08 '21 11:10

Thanks for the suggestions!

To clarify, is there currently no support (on master in any package) for tensor multiplication that preserves axiskeys? And is #6 (or similar) your preferred solution?

bencottier, Oct 11 '21 11:10