LinearAlgebra: improve type inference in Symmetric/Hermitian matmul
Matrix multiplication for wrapper types such as `Hermitian` currently uses an unwrapping mechanism that assigns a character based on the type of the wrapper. However, this character isn't determined by the type alone: for `Hermitian`/`Symmetric`, it also depends on the `uplo` field, which usually isn't known at compile time.
https://github.com/JuliaLang/julia/blob/6023ad6718514c15b3297197757ae3d93b85270b/stdlib/LinearAlgebra/src/LinearAlgebra.jl#L523-L525
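To make the problem concrete, here is a minimal, self-contained sketch of the single-character scheme. The hypothetical `legacy_wrapper_char` and `legacy_wrap` play the roles of `LinearAlgebra.wrapper_char` and `LinearAlgebra.wrap`, but they are simplified stand-ins written for illustration (only plain, `Symmetric` and `Hermitian` matrices are handled), not the stdlib code. The case of the returned letter encodes `uplo`, so the compiler only ever sees a `Char`.

```julia
using LinearAlgebra

# Hypothetical stand-ins (not the stdlib code): a single Char whose case
# encodes `uplo`, so its value depends on a runtime field.
legacy_wrapper_char(A::Symmetric) = A.uplo == 'U' ? 'S' : 's'
legacy_wrapper_char(A::Hermitian) = A.uplo == 'U' ? 'H' : 'h'
legacy_wrapper_char(::AbstractMatrix) = 'N'

# Rewrap the parent according to the character. Since `c` is only known to
# be a `Char`, inference must keep every branch, so the result is a `Union`.
function legacy_wrap(P::AbstractMatrix, c::Char)
    uplo = isuppercase(c) ? :U : :L
    if c == 'S' || c == 's'
        return Symmetric(P, uplo)
    elseif c == 'H' || c == 'h'
        return Hermitian(P, uplo)
    else
        return P
    end
end

roundtrip_legacy(A) = legacy_wrap(parent(A), legacy_wrapper_char(A))
```

At runtime the round trip is correct, but `Base.return_types(roundtrip_legacy, (Symmetric{Float64,Matrix{Float64}},))` should report a non-concrete result, because the branch taken in `legacy_wrap` depends on a value that inference cannot see.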
An example of a badly inferred function call because of this:
```julia
julia> @descend_code_warntype (A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(Symmetric(rand(2,2)))
(::var"#1#2")(A) @ Main REPL[2]:1
┌ Warning: couldn't retrieve source of (::var"#1#2")(A) @ Main REPL[2]:1
└ @ TypedSyntax ~/.julia/packages/TypedSyntax/cH1Nu/src/node.jl:36
Variables
#self#::Core.Const(var"#1#2"())
A::Symmetric{Float64, Matrix{Float64}}
Body::Union{Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Symmetric{Float64, Matrix{Float64}}, Transpose{Float64, Matrix{Float64}}, Matrix{Float64}}
@ REPL[2]:1 within `#1`
1 ─ %1 = LinearAlgebra.wrap::Core.Const(LinearAlgebra.wrap)
│ %2 = Main.parent::Core.Const(parent)
│ %3 = (%2)(A)::Matrix{Float64}
│ %4 = LinearAlgebra.wrapper_char::Core.Const(LinearAlgebra.wrapper_char)
│ %5 = (%4)(A)::Char
│ %6 = (%1)(%3, %5)::Union{Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Symmetric{Float64, Matrix{Float64}}, Transpose{Float64, Matrix{Float64}}, Matrix{Float64}}
└── return %6
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.
• %3 = parent(::Symmetric{Float64, Matrix{Float64}})::Matrix{Float64}
  %5 = wrapper_char(::Symmetric{Float64, Matrix{Float64}})::Char
  %6 = wrap(::Matrix{Float64},::Char)::…
```
The return type is inferred as a large `Union`, which complicates type inference downstream. Often the impact of the resulting runtime dispatch is minimal thanks to function barriers, but we may avoid the runtime dispatch altogether.
This PR separates the `uplo` information from the character that encodes the wrapper type, storing both in a newly defined struct. With this approach, the wrapper type may be constant-propagated even if `uplo` isn't, and the return type may be inferred concretely. After this PR:
```julia
julia> @descend_code_warntype (A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(Symmetric(rand(2,2)))
(::var"#3#4")(A) @ Main REPL[3]:1
┌ Warning: couldn't retrieve source of (::var"#3#4")(A) @ Main REPL[3]:1
└ @ TypedSyntax ~/.julia/packages/TypedSyntax/cH1Nu/src/node.jl:36
Variables
#self#::Core.Const(var"#3#4"())
A::Symmetric{Float64, Matrix{Float64}}
Body::Symmetric{Float64, Matrix{Float64}}
@ REPL[3]:1 within `unknown scope`
1 ─ %1 = LinearAlgebra.wrap::Core.Const(LinearAlgebra.wrap)
│ %2 = Main.parent::Core.Const(parent)
│ %3 = (%2)(A)::Matrix{Float64}
│ %4 = LinearAlgebra.wrapper_char::Core.Const(LinearAlgebra.wrapper_char)
│ %5 = (%4)(A)::Core.PartialStruct(LinearAlgebra.WrapperChar, Any[Core.Const('S'), Bool])
│ %6 = (%1)(%3, %5)::Symmetric{Float64, Matrix{Float64}}
└── return %6
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.
• %3 = parent(::Symmetric{Float64, Matrix{Float64}})::Matrix{Float64}
  %5 = wrapper_char(::Symmetric{Float64, Matrix{Float64}})::Core.PartialStruct(LinearAlgebra.WrapperChar, Any[Core.Const('S'), Bool])
  %6 = < constprop > wrap(::Matrix{Float64},::Core.PartialStruct(LinearAlgebra.WrapperChar, Any[Core.Const('S'), Bool]))::…
```
This change should be compatible with existing code, as the new struct subtypes `AbstractChar` and may be converted to and compared with a `Char` as before.
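The idea can be sketched outside the stdlib as follows. `SplitWrapperChar`, `split_wrapper_char` and `split_wrap` are hypothetical names, and this is not the definition of `LinearAlgebra.WrapperChar` that appears in the output above: the wrapper kind and the `uplo` flag live in separate fields, and the type implements the minimal `AbstractChar` interface (`codepoint` plus a `UInt32` constructor), so `Char(w)` and `w == 's'` keep the meaning of the old single character. The `Base.@constprop :aggressive` annotation is an assumption of this sketch, added so that it does not rely on the default constant-propagation heuristics.

```julia
using LinearAlgebra

# Hypothetical stand-in (not the stdlib's definition): the wrapper kind and
# the uplo flag are stored in separate fields.
struct SplitWrapperChar <: AbstractChar
    kind::Char   # 'N', 'S' or 'H': fixed by the wrapper type
    upper::Bool  # uplo == 'U'; a runtime value for Symmetric/Hermitian
end

# Minimal AbstractChar interface. `codepoint` maps back to the legacy single
# character (lowercase means uplo == 'L'); the UInt32 constructor goes the
# other way.
Base.codepoint(w::SplitWrapperChar) = codepoint(w.upper ? w.kind : lowercase(w.kind))
function SplitWrapperChar(u::UInt32)
    c = Char(u)
    return SplitWrapperChar(uppercase(c), !islowercase(c))
end

split_wrapper_char(A::Symmetric) = SplitWrapperChar('S', A.uplo == 'U')
split_wrapper_char(A::Hermitian) = SplitWrapperChar('H', A.uplo == 'U')
split_wrapper_char(::AbstractMatrix) = SplitWrapperChar('N', true)

# The branch on `w.kind` can be resolved during inference when `kind` is a
# known constant, even if `w.upper` is not; `upper` only picks the `uplo`
# argument, which does not change the return type.
Base.@constprop :aggressive function split_wrap(P::AbstractMatrix, w::SplitWrapperChar)
    uplo = w.upper ? :U : :L
    if w.kind == 'S'
        return Symmetric(P, uplo)
    elseif w.kind == 'H'
        return Hermitian(P, uplo)
    else
        return P
    end
end

roundtrip_split(A) = split_wrap(parent(A), split_wrapper_char(A))
```

Here `split_wrapper_char(Symmetric(M, :L))` still compares equal to `'s'`, and `Base.return_types(roundtrip_split, (Symmetric{Float64,Matrix{Float64}},))` is expected to give the concrete `Symmetric{Float64, Matrix{Float64}}`, since only `kind` needs to be a compile-time constant. How much of this the compiler achieves without the explicit annotation may vary between Julia versions.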
Fixes https://github.com/JuliaLang/julia/issues/53951.
Should we even backport this to v1.10?
So, the main idea is not to confuse the compiler unnecessarily with uppercase/lowercase H or S when, for the result type, that distinction is irrelevant anyway, right? Because that distinction is just the value of a field, and not encoded into the type?
Yes, that's the idea.
Backporting to v1.10 might require some manual intervention, but should be a good idea.
The last few commits improve type stability and ensure constant propagation in various checks in the matmul functions. ~~Introduces a new function `_in` that parallels `in` but uses two-valued logic and is defined recursively. This allows checks like `tA in ('T', 'N', 'C')` to be evaluated at compile time, which should remove branches in the code. Ideally, `in` should already be doing this, but I don't know enough about the compile-time implications in the general case. For our specific case, this shouldn't matter much. Also defines a function `all_in` that acts like `all(in(..))` but ensures constant propagation by unrolling the loop over the arguments.~~ This isn't used anymore, as `all(map(in(..), ...))` achieves constant propagation without the need for special helper functions.
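A small sketch of the `all(map(in(..), ...))` pattern, using a hypothetical helper `chars_valid` that is not a function from the PR. `in(collection)` returns the curried predicate `x -> x in collection`, and mapping it over a tuple of arguments produces a fixed-length tuple of `Bool`s, so there is no loop over a runtime-sized container for the compiler to reason about.

```julia
# Hypothetical helper (not a function from the PR): checks that both
# transpose characters are among 'N', 'T' and 'C'. Mapping the curried
# `in(...)` predicate over a 2-tuple keeps the number of checks fixed at
# compile time.
chars_valid(tA::AbstractChar, tB::AbstractChar) = all(map(in(('N', 'T', 'C')), (tA, tB)))

# A wrapper with literal arguments; `@code_typed literal_check()` is one way
# to see whether the check was folded to a constant on a given Julia version.
literal_check() = chars_valid('N', 'T')
```

Whether the check folds away entirely depends on constant propagation in the Julia version at hand; at runtime it behaves like checking each character with `in`.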
After the recent set of commits, this fixes #53951. With these changes:
```julia
julia> using LinearAlgebra
julia> using BenchmarkTools
julia> A = Hermitian([1.0 2.0; 2.0 3.0])
2×2 Hermitian{Float64, Matrix{Float64}}:
1.0 2.0
2.0 3.0
julia> B = [4.0 5.0; 6.0 7.0]
2×2 Matrix{Float64}:
4.0 5.0
6.0 7.0
julia> Y = similar(B)
2×2 Matrix{Float64}:
6.0e-323 6.4e-323
NaN 0.0
julia> @btime mul!($Y, $A, $B)
127.843 ns (0 allocations: 0 bytes)
2×2 Matrix{Float64}:
16.0 19.0
26.0 31.0
julia> Badj = B'
2×2 adjoint(::Matrix{Float64}) with eltype Float64:
4.0 6.0
5.0 7.0
julia> @btime mul!($Y, $A, $Badj)
44.311 ns (0 allocations: 0 bytes)
2×2 Matrix{Float64}:
14.0 20.0
23.0 33.0
```
I love it. I always dreamt of the day when that character stuff would be inferred, or constant-propagated far enough. Is this ready to go now? I think we should merge this first and then the "stabilize MulAddMul strategically" PR, to give this one a chance of being backported to v1.10, though I'm not sure whether that's too ambitious.
Yes, this is ready from my side.