LinearAlgebra.jl

Feature request: multidimensional `LinearAlgebra.tr`

Open MilesCranmer opened this issue 3 years ago • 9 comments

Right now, LinearAlgebra.tr only traces second-order tensors. However, it can be very useful to trace along a specific set of axes (e.g., for batched calculations), as well as ND traces (e.g., x[i, i, i]). This can be done with Tullio.jl for static dimension specification, but when you need to specify which axes are being traced, it requires some meta-programming. So it would be nicer if tr itself could take in a dims argument, and also handle more than two axes.

Therefore, I propose to extend tr to allow for a dims argument, which can be used for arbitrary dimension arrays. The following code gives a simple working implementation:

"""
    _tr_all_dims(x::AbstractArray{T,N})

Perform the calculation `sum_{i} x[i, i, ..., i]`, i.e. the sum along the
joint diagonal of all `N` dimensions of `x`.
"""
function _tr_all_dims(x::AbstractArray{T,N}) where {T,N}
    # `ntuple` builds the index tuple without the heap allocation that
    # `fill(i, N)` would incur on every single diagonal element.
    return sum(i -> x[ntuple(_ -> i, N)...], axes(x, 1))
end

"""
    tr(x::AbstractArray; dims)

Batched trace: apply the all-dimension diagonal sum to every slice of `x`
spanned by the axes in `dims`.
"""
tr(x::AbstractArray; dims) = mapslices(_tr_all_dims, x; dims=dims)

So now we can do things like:

x = zeros(5, 5, 5)
for i=1:5
    x[i, i, i] = 1
end
# 3D trace:
tr(x; dims=(1, 2, 3))  # 5

# batched trace:
tr(x; dims=(1, 2))  # vector of 5

Let me know what you think.

Also - this implementation (with splats) is very slow, so we'd probably want to compile-in specialization for the dimensionality of the array.

Thanks, Miles

MilesCranmer avatar Oct 13 '22 21:10 MilesCranmer

Here's a significantly faster version:

# Compute sum_i x[i, i, ..., i] with an indexing expression specialized on N,
# so no index tuple or splat is constructed at runtime.
@generated function _tr_all_dims(x::AbstractArray{T, N}) where {T,N}
    # Build the literal expression :(x[i, i, ..., i]) with N copies of :i.
    x_part = :(x[$(fill(:i, N)...)])
    quote
        out = zero($T)
        for i in axes(x, 1)
            out += $x_part
        end
        out
    end
end

# Batched trace over the axes named by `dims`: each slice spanned by `dims`
# is collapsed with the @generated full-diagonal sum above.
function tr(x::AbstractArray; dims)
    return mapslices(_tr_all_dims, x; dims=dims)
end

although mapslices is quite slow for some reason.

MilesCranmer avatar Oct 13 '22 21:10 MilesCranmer

A bit faster, but still slower than `tr`:

# Trace over the dimensions encoded in the `dims` keyword, which must be a
# tuple of `Val`s so the traced axes are known at code-generation time.
@generated function _tr_dims(x::AbstractArray{T,N}; dims) where {T,N}
    # Inside a @generated body, `dims` is the *type* of the keyword argument,
    # e.g. Tuple{Val{1}, Val{2}}; recover the integer dimension numbers.
    dims = collect(val.parameters[1] for val in dims.parameters)
    # Traced dimensions receive the loop index `i`; every other dimension
    # keeps a full `:` slice.
    indices = [
        (j in dims) ? :(i) : :(:) for j in 1:N
    ]
    x_part = :(x[$(indices...)])
    # If any `:` remains, the indexed expression is an array, so accumulate
    # element-wise; otherwise it is a scalar.
    summation = if :(:) in indices
        :(out .+= $x_part)
    else
        :(out += $x_part)
    end
    quote
        # Evaluate the slice once at the first index purely to obtain a
        # correctly shaped/typed zero accumulator.
        i = first(axes(x, 1))
        out = zero($x_part)
        for i in axes(x, 1)
            $summation
        end
        out
    end
end

# Public entry point: lift each requested dimension into a `Val` so that
# `_tr_dims` sees the traced axes in the type domain and can specialize.
function LinearAlgebra.tr(x; dims)
    val_dims = Tuple(collect(Val(d) for d in dims))
    return _tr_dims(x; dims=val_dims)
end

MilesCranmer avatar Oct 13 '22 22:10 MilesCranmer

Actually for big arrays, they are basically comparable. I think the multi-dimensional trace simply has more things to do at startup.

MilesCranmer avatar Oct 13 '22 22:10 MilesCranmer

Perhaps you will want to replace `fill(i, N)` and the like by `ntuple(_ -> i, N)`. That makes the allocations disappear.

dkarrasch avatar Oct 14 '22 09:10 dkarrasch

Thanks. Here are the updated performance numbers on a second-order tensor. The baseline, with normal `tr`, is 1.18 us.

  1. 1.02 ms
"""
    tr1(x::AbstractArray{T,N}; dims)

Batched trace via `mapslices`: each slice of `x` spanned by `dims` is
reduced to `sum_i slice[i, i, ..., i]`.
"""
function tr1(x::AbstractArray{T,N}; dims) where {T,N}
    # BUGFIX: the original closure indexed the *outer* array `x` with `N`
    # indices, ignoring the slice it was handed. That only works by accident
    # when the slice is the whole array (e.g. a matrix with dims=(1, 2));
    # for batched slices it either errors or repeats the full diagonal sum.
    # Index the slice itself, whose rank equals `length(dims)`.
    return mapslices(x; dims=dims) do s
        sum(i -> s[ntuple(_ -> i, ndims(s))...], axes(s, 1))
    end
end
  2. 1.02 ms
# sum_i x[i, i, ..., i], specialized on N at generation time so the indexing
# expression is fixed and nothing is allocated per element.
@generated function _tr_all_dims(x::AbstractArray{T, N}) where {T,N}
    # Build the literal expression :(x[i, i, ..., i]) with N copies of :i.
    x_part = :(x[$(fill(:i, N)...)])
    quote
        out = zero($T)
        for i in axes(x, 1)
            out += $x_part
        end
        out
    end
end
# Batched trace: collapse every slice spanned by `dims` with the @generated
# full-diagonal sum.
tr2(x::AbstractArray; dims) = mapslices(_tr_all_dims, x; dims=dims)
  3. 2.11 us
# Trace over the dimensions encoded in the `dims` keyword, which must be a
# tuple of `Val`s so the traced axes are known at code-generation time.
@generated function _tr3_dims(x::AbstractArray{T,N}; dims) where {T,N}
    # Inside a @generated body, `dims` is the *type* of the keyword argument,
    # e.g. Tuple{Val{1}, Val{2}}; recover the integer dimension numbers.
    dims = collect(val.parameters[1] for val in dims.parameters)
    # Traced dimensions receive the loop index `i`; every other dimension
    # keeps a full `:` slice.
    indices = [
        (j in dims) ? :(i) : :(:) for j in 1:N
    ]
    x_part = :(x[$(indices...)])
    # If any `:` remains, the indexed expression is an array, so accumulate
    # element-wise; otherwise it is a scalar.
    summation = if :(:) in indices
        :(out .+= $x_part)
    else
        :(out += $x_part)
    end
    quote
        # Evaluate the slice once at the first index purely to obtain a
        # correctly shaped/typed zero accumulator.
        i = first(axes(x, 1))
        out = zero($x_part)
        for i in axes(x, 1)
            $summation
        end
        out
    end
end
# Wrap each traced dimension in a `Val` so `_tr3_dims` can specialize on it.
function tr3(x; dims)
    val_dims = Tuple(collect(Val(d) for d in dims))
    return _tr3_dims(x; dims=val_dims)
end
  4. 1.583 us (Same as 3., but selecting from a precomputed tuple of Val(i) rather than creating them on the fly)
# Same as tr3, but indexes into a pre-built tuple Val(1)..Val(N) instead of
# constructing fresh Val objects for every call.
function tr4(x::AbstractArray{T,N}; dims) where {T,N}
    vals = ntuple(Val, N)
    selected = Tuple(collect(vals[d] for d in dims))
    return _tr3_dims(x; dims=selected)
end

MilesCranmer avatar Oct 14 '22 15:10 MilesCranmer

@dkarrasch let me know what you think, and I can make a PR.

MilesCranmer avatar Nov 06 '22 17:11 MilesCranmer

I'm finding the metaprogramming stuff quite hard to read. A simple

using LinearAlgebra
import LinearAlgebra: checksquare

"""
    checksquare(A::AbstractArray)

Check that all dimensions of `A` have equal length ("hypercubic") and return
that common length; throw `DimensionMismatch` otherwise.
"""
function checksquare(A::AbstractArray{<:Any,N}) where {N}
    dims = size(A)
    if !all(==(dims[1]), dims)
        throw(DimensionMismatch("array is not square: dimensions are $(size(A))"))
    end
    return dims[1]
end

"""
    mytr(A::Array{T,N})

Trace of a hypercubic `N`-dimensional array: `sum_i A[i, i, ..., i]`.
Throws `DimensionMismatch` (via `checksquare`) if dimensions are unequal.
"""
function mytr(A::Array{T,N}) where {T,N}
    n = checksquare(A)
    acc = zero(T)
    # @simd is what lets this match/beat LinearAlgebra.tr on matrices.
    @inbounds @simd for k in 1:n
        acc += A[ntuple(_ -> k, N)...]
    end
    return acc
end

performs even better than LinearAlgebra.tr on matrices (due to the @simd annotation, which we should also apply to the current method in dense.jl; done in JuliaLang/julia#47585), and its performance depends only on the number of diagonal elements. I tested on 400x400 and 400x400x400 arrays. As for the batched computation, mapslices seems to add massive overhead; on the other hand, even if we have this machinery, I'm not sure we should cook up some complicated, super-specialized code here just for tr. @mbauman, do you have some good advice here?

dkarrasch avatar Nov 15 '22 20:11 dkarrasch

It's not so complicated to make CartesianIndices do this.

# Trace of `A` over `dims` (default `:` = all dimensions); the work happens
# in the `_mytr` methods, which dispatch on the type of `dims`.
function tr5(A::AbstractArray; dims=:)
    return _mytr(dims, A)
end
# Batched trace: sum A over the shared diagonal index of the traced `dims`,
# keeping every other dimension intact.
function _mytr(dims::Tuple{Integer, Vararg{Integer}}, A::AbstractArray)
    dimaxes = map(d -> axes(A,d), dims)
    # A diagonal only exists if all traced dimensions share the same axes.
    allequal(dimaxes) || _tr_error(dimaxes)
    # mask[d] is true for the dimensions that are *kept* (not traced).
    mask = ntuple(d -> !(d in dims), ndims(A))
    # Output keeps the untraced axes; traced dimensions collapse to length 1
    # so they can be removed with dropdims at the end.
    B = similar(A, ifelse.(mask, axes(A), (Base.OneTo(1),)))
    for I in CartesianIndices(B)
        t = zero(eltype(A))
        @inbounds @simd for j in first(dimaxes)
            # Replace the traced coordinates of I (all 1) with the shared
            # diagonal index j.
            K = CartesianIndex(ifelse.(mask, Tuple(I), j))
            t += A[K]
        end
        B[I] = t
    end
    dropdims(B; dims)
end
# Out-of-line error path so the hot loops in `_mytr` stay small and inlinable.
@noinline _tr_error(ax) = throw(DimensionMismatch(
    "traced dimensions must agree, but got $(ax)"))
# Full trace over all dimensions: sum_i A[i, i, ..., i].
function _mytr(::Colon, A::AbstractArray)
    allequal(axes(A)) || _tr_error(axes(A))
    acc = zero(eltype(A))
    @inbounds @simd for k in axes(A, 1)
        acc += A[ntuple(_ -> k, ndims(A))...]
    end
    return acc
end
# A trace over a single dimension is just a sum along it; the collapsed
# dimension is then dropped from the result.
function _mytr(dims::Integer, A::AbstractArray)
    return dropdims(sum(A; dims=dims); dims=dims)
end
# A vector's trace over all dimensions degenerates to a plain sum.
function _mytr(::Colon, v::AbstractVector)
    return sum(v)
end
# NOTE(review): exact duplicate of the `_tr_error` defined above — a harmless
# identical redefinition, but one of the two should be deleted.
@noinline _tr_error(ax) = throw(DimensionMismatch("traced dimensions must agree, but got $(ax)"))

# --- quick correctness check and micro-benchmarks (uses BenchmarkTools' @btime) ---

x = rand(1:10, 3,3,2)
# Batched 2-D trace of each 3x3 slice; returns a 2-element vector.
tr5(x, dims=(1,2))

# Compare the three batched implementations on a 100^3 array.
let n = 100
    dims = (2,3)
    x = randn(n,n,n)
    a = @btime tr2($x; dims=$dims) # mapslices
    b = @btime tr3($x; dims=$dims) # @generated
    c = @btime tr5($x; dims=$dims) # CartesianIndices
    a ≈ b ≈ c
end
# 1.213 ms (229 allocations: 83.02 KiB)
# 16.708 μs (106 allocations: 89.42 KiB)
# 3.792 μs (4 allocations: 992 bytes)

# Full trace over all dimensions of the same array.
let n = 100
    x = randn(n,n,n)
    a = @btime _tr_all_dims($x) # @generated
    b = @btime tr5($x) # CartesianIndices
    a ≈ b
end
# 101.183 ns (0 allocations: 0 bytes)
# 41.582 ns (0 allocations: 0 bytes)

The big question seems to be whether LinearAlgebra should handle higher-rank objects. dot accepts anything, that might be the only function which goes beyond matrices.

mcabbott avatar Nov 16 '22 00:11 mcabbott

Nice!

The big question seems to be whether LinearAlgebra should handle higher-rank objects. dot accepts anything, that might be the only function which goes beyond matrices.

I could see this being attractive for operations which have distinct meaning when used on higher dimensional arrays, like dot ( $\checkmark$ ), norm ( $\checkmark$ ), or tr ( $\times$ ) - for those operations you couldn't use a 2D version from LinearAlgebra. But most other operations where you would simply want to vectorize it over a batch axis, the user could just loop it themselves.

MilesCranmer avatar Nov 16 '22 00:11 MilesCranmer