LinearAlgebra.jl
Feature request: multidimensional `LinearAlgebra.tr`
Right now, `LinearAlgebra.tr` only traces second-order tensors (matrices). However, it can be very useful to trace along a specific set of axes (e.g., for batched calculations), as well as to take N-dimensional traces (e.g., summing x[i, i, i]). Tullio.jl can do this when the traced axes are fixed in the source code, but when the axes to trace are only chosen at runtime, it requires some metaprogramming. So it would be nicer if `tr` itself could take a `dims` argument and also handle more than two axes.
Therefore, I propose extending `tr` with a `dims` argument that works for arrays of arbitrary dimension. The following code gives a simple working implementation:
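```julia
import LinearAlgebra: tr  # needed to extend LinearAlgebra.tr rather than shadow it

"""Perform the calculation sum_{i} x[i, i, ..., i]"""
function _tr_all_dims(x::AbstractArray{T,N}) where {T,N}
    sum(i -> x[fill(i, N)...], axes(x, 1))
end

# Trace each slice spanned by `dims`; mapslices keeps size 1 along `dims`
function tr(x::AbstractArray; dims)
    mapslices(_tr_all_dims, x; dims=dims)
end
```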
So now we can do things like:
```julia
x = zeros(5, 5, 5)
for i = 1:5
    x[i, i, i] = 1
end
# 3D trace:
tr(x; dims=(1, 2, 3))  # 1×1×1 array holding 5.0
# batched trace:
tr(x; dims=(1, 2))     # 1×1×5 array of ones, one trace per slice x[:, :, k]
```
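Because `mapslices` keeps the traced dimensions with size 1, the calls above return arrays rather than a plain scalar and a plain vector. A minimal sketch of a wrapper that drops those singleton dimensions, assuming the same `_tr_all_dims` helper; `tr_dims` is a hypothetical name used here so the sketch does not extend `LinearAlgebra.tr`:

```julia
# Same helper as in the proposal: sum_i x[i, i, ..., i]
function _tr_all_dims(x::AbstractArray{T,N}) where {T,N}
    sum(i -> x[fill(i, N)...], axes(x, 1))
end

# Hypothetical wrapper: trace over `dims`, then drop the size-1 dimensions
# that mapslices leaves behind. Tracing every dimension returns a scalar.
function tr_dims(x::AbstractArray; dims)
    r = mapslices(_tr_all_dims, x; dims=dims)
    return length(dims) == ndims(x) ? only(r) : dropdims(r; dims=dims)
end

x = zeros(5, 5, 5)
for i = 1:5
    x[i, i, i] = 1
end
tr_dims(x; dims=(1, 2, 3))  # 5.0
tr_dims(x; dims=(1, 2))     # 5-element Vector{Float64} of ones
```

Whether a built-in `tr(x; dims)` should drop dimensions or keep them (as `sum(x; dims)` does) is a design choice; keeping them is consistent with `sum` and `mapslices`.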
Let me know what you think.
Also, this implementation (with splatting) is very slow, so we would probably want to specialize at compile time on the dimensionality of the array.
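One possible way to get that specialization, sketched here as an assumption rather than a benchmarked fix: build the diagonal index with `ntuple(_ -> i, Val(N))`, whose length is known at compile time, and wrap it in a `CartesianIndex`, instead of splatting a freshly allocated `Vector` from `fill`. The name `_tr_all_dims_fast` and the axis check are hypothetical additions:

```julia
# Sketch: sum_i x[i, i, ..., i] with the index tuple length fixed by the
# type parameter N, so no Vector is allocated and splatted per element.
function _tr_all_dims_fast(x::AbstractArray{T,N}) where {T,N}
    ax = axes(x, 1)
    # Hypothetical check: every traced axis must match the first one
    all(d -> axes(x, d) == ax, 1:N) ||
        throw(DimensionMismatch("all traced axes must have the same indices"))
    s = zero(T)
    for i in ax
        # ntuple with Val(N) yields an NTuple{N,Int} known at compile time
        s += x[CartesianIndex(ntuple(_ -> i, Val(N)))]
    end
    return s
end

x = zeros(5, 5, 5)
for i = 1:5
    x[i, i, i] = 1
end
_tr_all_dims_fast(x)           # 5.0
_tr_all_dims_fast(x[:, :, 2])  # 1.0
```

It can be passed to `mapslices(_tr_all_dims_fast, x; dims=dims)` in place of `_tr_all_dims`.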
Thanks, Miles