LinearAlgebra.jl

Feature request: multidimensional `LinearAlgebra.tr`

Open MilesCranmer opened this issue 3 years ago • 9 comments

Right now, `LinearAlgebra.tr` only traces matrices (second-order tensors). However, it can be very useful to trace along a specific set of axes (e.g., for batched calculations), and to compute N-dimensional traces (e.g., `sum_i x[i, i, i]`). Tullio.jl can do this when the traced axes are written out statically, but when the axes to trace are only known at runtime, it requires some metaprogramming. So it would be nicer if `tr` itself accepted a `dims` argument and also handled more than two axes.
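For concreteness, the two operations above can be written as explicit loops for a 3-D array. This is a minimal illustrative sketch: `trace3` and `batched_trace` are hypothetical names, and both assume 1-based indexing and that the traced axes have equal length.

```julia
# Full 3-D trace: t = sum_i x[i, i, i]
function trace3(x::AbstractArray{T,3}) where {T}
    n = size(x, 1)
    @assert size(x, 2) == n && size(x, 3) == n "traced axes must have equal length"
    t = zero(T)
    for i in 1:n
        t += x[i, i, i]
    end
    return t
end

# Batched trace over axes 1 and 2: t[k] = sum_i x[i, i, k]
function batched_trace(x::AbstractArray{T,3}) where {T}
    n = size(x, 1)
    @assert size(x, 2) == n "traced axes must have equal length"
    t = zeros(T, size(x, 3))
    for k in 1:size(x, 3), i in 1:n
        t[k] += x[i, i, k]
    end
    return t
end
```

The batched version is just the ordinary matrix trace applied to each slice `x[:, :, k]`, which is what a `dims=(1, 2)` argument would express.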

Therefore, I propose extending `tr` with a `dims` keyword argument that works for arrays of arbitrary dimension. The following is a simple working implementation:

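```julia
using LinearAlgebra

"""Perform the calculation sum_{i} x[i, i, ..., i] (assumes all axes are equal)."""
function _tr_all_dims(x::AbstractArray{T,N}) where {T,N}
    sum(i -> x[fill(i, N)...], axes(x, 1))
end

# Add a method to LinearAlgebra.tr instead of shadowing it with a new `tr`.
function LinearAlgebra.tr(x::AbstractArray; dims)
    mapslices(_tr_all_dims, x; dims=dims)
end
```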

So now we can do things like:

```julia
x = zeros(5, 5, 5)
for i in 1:5
    x[i, i, i] = 1
end
# 3D trace:
tr(x; dims=(1, 2, 3))  # 1×1×1 array containing 5.0

# batched trace:
tr(x; dims=(1, 2))  # 1×1×5 array of ones (one trace per slice x[:, :, k])
```
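Because `mapslices` keeps each reduced dimension as a size-1 axis, the results have shapes 1×1×1 and 1×1×5 rather than a scalar and a vector. A minimal sketch of recovering the scalar and vector forms with `only` and `dropdims`, calling `mapslices` directly (equivalent to the proposed `tr(x; dims=...)`) and re-declaring the `_tr_all_dims` helper from the proposal so the snippet runs on its own:

```julia
"""Perform the calculation sum_{i} x[i, i, ..., i] (copied from the proposal)."""
function _tr_all_dims(x::AbstractArray{T,N}) where {T,N}
    sum(i -> x[fill(i, N)...], axes(x, 1))
end

x = zeros(5, 5, 5)
for i in 1:5
    x[i, i, i] = 1
end

# mapslices leaves a size-1 axis wherever a slice was reduced to a scalar.
full = mapslices(_tr_all_dims, x; dims=(1, 2, 3))   # 1×1×1 array
batched = mapslices(_tr_all_dims, x; dims=(1, 2))   # 1×1×5 array

# Drop the singleton axes to get a scalar and a plain vector.
full_scalar = only(full)                      # 5.0
batched_vec = dropdims(batched; dims=(1, 2))  # length-5 Vector of ones
```

Whether a `dims`-based `tr` should drop the traced axes itself (like a reduction returning a vector) or keep them (like `sum(x; dims)`) is a design choice the proposal leaves open.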

Let me know what you think.

Also, this implementation is very slow because it splats a freshly allocated index vector (`fill(i, N)`) on every element access, so we'd probably want to specialize on the array's dimensionality at compile time.
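One possible way to get that specialization is to build the index with `ntuple(_ -> i, Val(N))`, which produces an `NTuple{N,Int}` whose length is known at compile time from the array's type parameter `N`, and index with a `CartesianIndex` instead of splatting a `Vector`. This is an unbenchmarked sketch assuming a numeric element type; the name `_tr_all_dims_fast` is hypothetical. It also checks that every traced axis matches the first one, which the original helper silently assumes.

```julia
"""Compute sum_i x[i, i, ..., i], specialized on the array dimensionality N."""
function _tr_all_dims_fast(x::AbstractArray{T,N}) where {T,N}
    ax = axes(x, 1)
    # Every traced axis must coincide with the first one.
    all(d -> axes(x, d) == ax, 1:N) ||
        throw(DimensionMismatch("all traced axes must be equal, got $(axes(x))"))
    s = zero(T)
    for i in ax
        # ntuple with Val(N) yields an NTuple{N,Int}; no heap-allocated index vector.
        s += x[CartesianIndex(ntuple(_ -> i, Val(N)))]
    end
    return s
end
```

It can be passed to `mapslices` in place of `_tr_all_dims` without other changes; for N = 2 it reduces to the ordinary matrix trace.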

Thanks, Miles

MilesCranmer, Oct 13 '22 21:10