PDMats.jl

gemm! alias

Open • harrisonritz opened this issue • 5 comments

Love the package!

BLAS.gemm! fails for any PDMat argument unless you pass its underlying dense matrix (A.mat) instead. Something like this could be more general:

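# Unwrap PDMat arguments to their dense backing matrix before calling BLAS.
# Only PDMat is checked: other AbstractPDMat types (e.g. PDiagMat, ScalMat) have no `mat` field.
# C is overwritten in place, so it should be a plain matrix; writing into a PDMat's `mat`
# would leave its cached Cholesky factor out of date.
pd_gemm!(tA, tB, alpha, A, B, beta, C) =
    BLAS.gemm!(tA, tB, alpha,
               A isa PDMat ? A.mat : A,
               B isa PDMat ? B.mat : B,
               beta, C)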
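The same unwrapping can also be written with multiple dispatch, replacing the isa checks with methods so that supporting another wrapper type is one extra method. This is a minimal sketch: dense_or_self and pd_gemm_dispatch! are hypothetical names, not PDMats.jl functions, and it assumes a PDMat's mat field holds a plain Matrix that BLAS accepts.

```julia
using LinearAlgebra, PDMats

# Hypothetical helper (not a PDMats.jl API): unwrap via dispatch rather than `isa`.
# PDMat stores its dense matrix in the `mat` field; any other argument passes through.
dense_or_self(X::PDMat) = X.mat
dense_or_self(X) = X

# Hypothetical wrapper: C = alpha * op(A) * op(B) + beta * C, where C is a plain matrix.
pd_gemm_dispatch!(tA, tB, alpha, A, B, beta, C) =
    BLAS.gemm!(tA, tB, alpha, dense_or_self(A), dense_or_self(B), beta, C)

# Build a small positive-definite matrix: averaging with the transpose makes it
# exactly symmetric, and the 5I shift keeps it well conditioned.
X = randn(5, 5)
S = X' * X + 5I
P = PDMat((S + S') / 2)
M = randn(5, 5)
C = zeros(5, 5)

pd_gemm_dispatch!('N', 'T', 1.0, P, M, 0.0, C)   # C = P * M'
@assert C ≈ Matrix(P) * M'
```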

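In the benchmarks below, pd_gemm! runs as fast as calling BLAS.gemm! on xx.mat directly (about 506 ns median, no allocations), whereas mul! on the PDMat takes about 3.2 μs and allocates 20.58 KiB. Minimal example: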

using LinearAlgebra, PDMats, BenchmarkTools


ix = randn(20,20);
xx = PDMat(Hermitian(ix'*ix));
aa = randn(20,20);

pd_gemm!(tA, tB, alpha, A, B, beta, C) =
    BLAS.gemm!(tA, tB, alpha,
               A isa PDMat ? A.mat : A,  # unwrap PDMat to its dense backing matrix
               B isa PDMat ? B.mat : B,
               beta, C)                 # C is written in place, so keep it a plain matrix

yy = zeros(20,20);
@benchmark mul!($yy, $xx, $aa', 1.0, 1.0)

yy = zeros(20,20);
@benchmark BLAS.gemm!('N', 'T', 1.0, $xx.mat, $aa, 1.0, $yy)

yy = zeros(20,20);
@benchmark pd_gemm!('N', 'T', 1.0, $xx, $aa, 1.0, $yy)
# mul! on the PDMat
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.384 μs … 331.486 μs  ┊ GC (min … max):  0.00% … 98.48%
 Time  (median):     3.176 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   4.285 μs ±  15.201 μs  ┊ GC (mean ± σ):  22.44% ±  6.26%

            █▂                                                 
  █▂▁▁▁▁▁▁▂▇██▇▆▅▄▃▃▂▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  2.38 μs         Histogram: frequency by time         6.3 μs <

 Memory estimate: 20.58 KiB, allocs estimate: 4.

# BLAS.gemm! on xx.mat
BenchmarkTools.Trial: 10000 samples with 193 evaluations.
 Range (min … max):  505.394 ns … 993.523 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     506.477 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   518.730 ns ±  28.264 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▂▁▁▁▁▄▃▁▁▁▂▂▁▂▁▁                                             ▁
  ██████████████████▇▇▇▆▆▆▇▇▆▆▆▆▆▆▇▆▇▆▆▆▆▆▆▆▆▅▆▆▅▆▆▅▅▄▄▅▃▄▅▄▅▄▅ █
  505 ns        Histogram: log(frequency) by time        642 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

# pd_gemm! on the PDMat
BenchmarkTools.Trial: 10000 samples with 194 evaluations.
 Range (min … max):  505.371 ns … 909.577 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     506.443 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   510.625 ns ±  13.425 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▆▂▁      ▁  ▂▃▁                                              ▁
  ████████▇███▇███▇▆▆▆▇▇▇▇▇▇▇▆▆▆▆▅▆▆▆▆▆▅▄▅▅▄▅▅▄▄▄▃▅▄▃▄▅▅▃▄▄▄▂▃▃ █
  505 ns        Histogram: log(frequency) by time        573 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
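The timings only show speed. A quick check that the wrapper computes the same product as the generic mul! path could look like the sketch below. It uses a 20×20 setup similar to the issue's (symmetrized explicitly instead of wrapped in Hermitian) and a version of pd_gemm! that unwraps PDMat arguments only; both choices are assumptions of this sketch, not part of PDMats.jl.

```julia
using LinearAlgebra, PDMats

# 20×20 setup similar to the issue: a PDMat and a dense matrix.
ix = randn(20, 20)
S = ix' * ix
xx = PDMat((S + S') / 2)   # averaging with the transpose makes it exactly symmetric
aa = randn(20, 20)

# Wrapper from the proposal, unwrapping PDMat arguments only; C stays a plain matrix.
pd_gemm!(tA, tB, alpha, A, B, beta, C) =
    BLAS.gemm!(tA, tB, alpha,
               A isa PDMat ? A.mat : A,
               B isa PDMat ? B.mat : B,
               beta, C)

y_mul  = zeros(20, 20)
y_gemm = zeros(20, 20)

mul!(y_mul, xx, aa', 1.0, 1.0)                 # generic path on the PDMat
pd_gemm!('N', 'T', 1.0, xx, aa, 1.0, y_gemm)   # BLAS path on xx.mat

# Both compute xx * aa' + 0 (the outputs start at zero), so they agree up to rounding.
@assert y_mul ≈ y_gemm
```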

harrisonritz, Jan 21 '24 19:01