ComponentArrays.jl

Reexport `inv` and `diag` etc from `LinearAlgebra`?

Open Yuan-Ru-Lin opened this issue 2 years ago • 2 comments

Is there any reason not to reexport `inv`, `diag`, etc.? They can be implemented quite easily as follows:

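```julia
using ComponentArrays
import LinearAlgebra: inv, diag
# Rows of inv(x) are labelled by the columns of x and vice versa, so the axes are swapped.
inv(x::ComponentMatrix)::ComponentMatrix = ComponentMatrix(inv(getdata(x)), reverse(getaxes(x))...)
# Keeps the row axis; only unambiguous when the row and column axes coincide.
diag(x::ComponentMatrix)::ComponentVector = ComponentVector(diag(getdata(x)), getaxes(x)[1])
```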

Yuan-Ru-Lin, Feb 05 '23 14:02

Well, we probably wouldn't want to export them, so as not to clog people's namespace (if they want `diag`, they can do `using LinearAlgebra`). But I think the bigger part of the question is whether we should specifically overload these functions to return ComponentArrays instead of falling back to plain Arrays. It tends to be a huge headache to chase down every method of every function like this and special-case it, so I just don't do that. Instead, I rely on functions properly calling `similar`, `convert`, or `copy`, so that ComponentArrays are created automatically when they should be.
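A quick check of the behaviour described above, assuming a recent ComponentArrays release: functions that build their result with `similar`, `copy`, or broadcasting hand back a `ComponentVector` carrying the original axes, with no special-casing needed.

```julia
using ComponentArrays

x = ComponentVector(a = 1.0, b = 2.0)

y = similar(x)   # same axes as x; the values are uninitialised
z = copy(x)      # same axes and values as x
w = 2 .* x       # broadcasting also keeps the ComponentArray wrapper

w.b              # 4.0, accessed by label
```

This is why the package can stay generic: any function whose output is allocated through these hooks inherits the labels for free.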

With that said, implementing things like `inv` or `diag` is actually a little trickier than the above. For example, if you had

```julia
ab = ComponentVector(a=1, b=2)
xy = ComponentVector(x=3, y=4)
abxy = ab * xy'
```

what would you expect the axes of `diag(abxy)` to look like? There's no particular reason they should be `a, b` rather than `x, y`.
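To make the ambiguity concrete, here is a small sketch (assuming a recent ComponentArrays release) that builds the same outer product explicitly from the underlying data, so both axes are known, and then labels the diagonal both ways. The names `row_ax`, `col_ax`, `by_rows`, and `by_cols` are purely illustrative.

```julia
using ComponentArrays, LinearAlgebra

ab = ComponentVector(a = 1, b = 2)
xy = ComponentVector(x = 3, y = 4)

# Build the outer product explicitly: rows are labelled (a, b), columns (x, y).
row_ax = getaxes(ab)[1]
col_ax = getaxes(xy)[1]
abxy = ComponentArray(getdata(ab) * getdata(xy)', row_ax, col_ax)

d = diag(getdata(abxy))              # plain Vector with entries 3 and 8

# Both labellings are equally defensible; the data alone cannot pick one.
by_rows = ComponentArray(d, row_ax)  # by_rows.a == 3, by_rows.b == 8
by_cols = ComponentArray(d, col_ax)  # by_cols.x == 3, by_cols.y == 8
```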

jonniedie, Feb 12 '23 04:02

I agree there is ambiguity in the case of `diag`. Maybe we could have something like a `ComponentSquareMatrix` that enforces an identical `Axis` in both dimensions?
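One way the `ComponentSquareMatrix` idea could look, as a minimal sketch only: the type name is taken from the suggestion, but the struct, its constructor check, and the `diag` method below are hypothetical and not part of ComponentArrays. The invariant (row axis equals column axis) is checked once at construction, after which labelling the diagonal is unambiguous.

```julia
using ComponentArrays, LinearAlgebra

# Hypothetical wrapper type: a ComponentMatrix whose row and column axes
# are guaranteed to be identical.
struct ComponentSquareMatrix{T,M<:AbstractMatrix{T}} <: AbstractMatrix{T}
    data::M
    function ComponentSquareMatrix(A::ComponentMatrix)
        ax_rows, ax_cols = getaxes(A)
        ax_rows == ax_cols ||
            throw(ArgumentError("row and column axes must be identical"))
        return new{eltype(A),typeof(A)}(A)
    end
end

# Minimal AbstractArray interface: forward to the wrapped ComponentMatrix.
Base.size(S::ComponentSquareMatrix) = size(S.data)
Base.getindex(S::ComponentSquareMatrix, i::Int, j::Int) = S.data[i, j]

# With a single shared axis, the diagonal has exactly one sensible labelling.
function LinearAlgebra.diag(S::ComponentSquareMatrix)
    return ComponentArray(diag(getdata(S.data)), getaxes(S.data)[1])
end

A = ComponentArray([1.0 2.0; 3.0 4.0], Axis(a = 1, b = 2), Axis(a = 1, b = 2))
S = ComponentSquareMatrix(A)
d = diag(S)   # labelled by (a, b): d.a == 1.0, d.b == 4.0
```

The design trade-off is that the check moves from every call site into the constructor; the cost is yet another wrapper type that other methods would need to know about.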

As for `inv`, the LinearAlgebra internals do seem to call `convert`, but the result still falls back to an ordinary `Matrix`:

```julia
julia> using LinearAlgebra, ComponentArrays

julia> M = [1 2; 3 4];

julia> inv(ComponentMatrix(M, Axis(a=1, b=2), Axis(a=1, b=2)))
2×2 Matrix{Float64}:
 -2.0   1.0
  1.5  -0.5
```

while `@edit inv(M)` leads to

```julia
# ...
function inv(A::StridedMatrix{T}) where T
    checksquare(A)
    S = typeof((one(T)*zero(T) + one(T)*zero(T))/one(T))
    AA = convert(AbstractArray{S}, A)
    if istriu(AA)
        Ai = triu!(parent(inv(UpperTriangular(AA))))
    elseif istril(AA)
        Ai = tril!(parent(inv(LowerTriangular(AA))))
    else
        Ai = inv!(lu(AA))
        Ai = convert(typeof(parent(Ai)), Ai)
    end
    return Ai
end
# ...
```
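Until something like this lives in the package, a caller can re-attach axes to the plain `Matrix` that `inv` returns. A minimal sketch with a hypothetical helper name, `inv_with_axes` (a user-side workaround, not a ComponentArrays API): it inverts the underlying data and swaps the axes, because if `A` maps quantities labelled by its column axis to quantities labelled by its row axis, `inv(A)` maps them back.

```julia
using ComponentArrays, LinearAlgebra

# Hypothetical helper: invert the raw data, then attach the swapped axes.
function inv_with_axes(A::ComponentMatrix)
    row_ax, col_ax = getaxes(A)
    return ComponentArray(inv(getdata(A)), col_ax, row_ax)
end

A = ComponentArray([1 2; 3 4], Axis(a = 1, b = 2), Axis(x = 1, y = 2))
Ai = inv_with_axes(A)
# Rows of Ai are now labelled (x, y) and columns (a, b);
# getdata(Ai) ≈ [-2.0 1.0; 1.5 -0.5]
```

Defining a free function avoids overloading `inv` on a type the caller does not own, which sidesteps the "chase down every method" concern for the package itself.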

Should I open two separate issues for these?

Yuan-Ru-Lin, Feb 17 '23 03:02