opt_einsum
Could `opt_einsum` understand repeated inputs?
Hello and thank you for the great library!
I'm curious whether `opt_einsum` can be generalized to let the user specify which inputs are the same, and use this information to produce a more efficient contraction.
Example:
import numpy as np
A = np.random.normal(size=(3, 2))
B = np.random.normal(size=(2, 2))
def f(A, B):
    AB = A @ B
    return AB @ AB.T

def f_einsum(A, B):
    return np.einsum('ij,jk,lk,zl->iz', A, B, B, A, optimize='optimal')

import opt_einsum
opt_einsum.contract_path('ij,jk,lk,zl->iz', A, B, B, A, optimize='optimal')
gives
([(1, 2), (0, 2), (0, 1)], Complete contraction: ij,jk,lk,zl->iz
Naive scaling: 5
Optimized scaling: 3
Naive FLOP count: 2.880e+2
Optimized FLOP count: 7.600e+1
Theoretical speedup: 3.789e+0
Largest intermediate: 9.000e+0 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              lk,jk->lj                         ij,zl,lj->iz
   3           GEMM              lj,ij->li                            zl,li->iz
   3           GEMM              li,zl->iz                               iz->iz)
i.e. it performs three contractions instead of two (and in this case, evaluating `f` in two contractions is indeed faster than evaluating `f_einsum`). I wonder whether it's feasible to accept a list of input identifiers (in this case `[0, 1, 1, 0]`) and leverage it to compute the contraction faster. Thank you for your consideration!
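To make the request concrete, here is a minimal sketch of the manual workaround that such a feature would automate: contract the repeated pair once and reuse the intermediate. The names `f_reuse` and `pairwise_cost` are hypothetical helpers written for this illustration, not part of `opt_einsum`. The cost estimate assumes each pairwise contraction with a summed index costs twice the product of the sizes of all indices involved, which reproduces the 7.600e+1 figure reported above.

```python
import numpy as np

def f_reuse(A, B):
    # Contract the repeated pair (A, B) once, then reuse the intermediate.
    AB = np.einsum('ij,jk->ik', A, B)
    return np.einsum('ik,zk->iz', AB, AB)

def pairwise_cost(sizes, indices):
    # Assumed cost model for one pairwise contraction with a summed index:
    # 2 * product of the sizes of all indices involved (multiply + add).
    return 2 * int(np.prod([sizes[ix] for ix in indices]))

# Index sizes for A of shape (3, 2) and B of shape (2, 2).
sizes = {'i': 3, 'j': 2, 'k': 2, 'l': 2, 'z': 3}

# The three-step path reported above: lk,jk->lj ; lj,ij->li ; li,zl->iz
cost_reported_path = (
    pairwise_cost(sizes, 'lkj')
    + pairwise_cost(sizes, 'lji')
    + pairwise_cost(sizes, 'liz')
)

# The two-step evaluation that reuses AB: ij,jk->ik ; ik,zk->iz
cost_with_reuse = pairwise_cost(sizes, 'ijk') + pairwise_cost(sizes, 'ikz')

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 2))
B = rng.normal(size=(2, 2))
reference = np.einsum('ij,jk,lk,zl->iz', A, B, B, A)

print(np.allclose(f_reuse(A, B), reference))
print(cost_reported_path, cost_with_reuse)
```

Under this cost model, the reuse path needs fewer operations than the reported three-step path, and the saving comes entirely from knowing that operands 0 and 3 (and 1 and 2) are the same array.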