
Sub-optimal contraction path when using broadcasting

Open pimdh opened this issue 2 years ago • 5 comments

Hi, I'm not sure whether an immediate solution is possible, but it seems that opt_einsum first applies broadcasting and then optimizes the contraction path. This leads to sub-optimal results:

import opt_einsum
print(opt_einsum.__version__)
print(opt_einsum.contract_path("ijk,bj,bk->bi", (32, 32, 32), (10000, 32), (1, 32), optimize="optimal", shapes=True))
print(opt_einsum.contract_path("ijk,bj,k->bi", (32, 32, 32), (10000, 32), (32,), optimize="optimal", shapes=True))

Gives

v3.3.0+24.g1a984b7
([(1, 2), (0, 1)],   Complete contraction:  ijk,bj,bk->bi
         Naive scaling:  4
     Optimized scaling:  4
      Naive FLOP count:  9.830e+8
  Optimized FLOP count:  6.656e+8
   Theoretical speedup:  1.477e+0
  Largest intermediate:  1.024e+7 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3              0             bk,bj->bkj                           ijk,bkj->bi
   4           TDOT            bkj,ijk->bi                                bi->bi)
([(0, 2), (0, 1)],   Complete contraction:  ijk,bj,k->bi
         Naive scaling:  4
     Optimized scaling:  3
      Naive FLOP count:  9.830e+8
  Optimized FLOP count:  2.055e+7
   Theoretical speedup:  4.785e+1
  Largest intermediate:  3.200e+5 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              k,ijk->ij                             bj,ij->bi
   3           GEMM              ij,bj->bi                                bi->bi)

We see that in the first case the third tensor is broadcast to (b, 32), after which the optimizer decides it is best to contract the latter two tensors first. Ideally, we would strip the to-be-broadcast dim from the third tensor, which allows a much faster contraction, as shown in the second case.

Any ideas on how this could be addressed? I understand that this doesn't involve just choosing a contraction path, so might not be solvable by this library. Thanks!
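
In code, the manual fix amounts to squeezing the size-1 b axis out of the third operand and dropping its label from the subscripts. A minimal sketch, assuming NumPy arrays (the array names are just for illustration):

import numpy as np
import opt_einsum

T = np.random.randn(32, 32, 32)   # ijk
x = np.random.randn(10000, 32)    # bj
y = np.random.randn(1, 32)        # bk, with a broadcast b axis of size 1

# Slow: the optimizer treats y's b axis as if it were full-sized
out1 = opt_einsum.contract("ijk,bj,bk->bi", T, x, y, optimize="optimal")

# Fast: strip the size-1 b axis and its label by hand
out2 = opt_einsum.contract("ijk,bj,k->bi", T, x, y[0], optimize="optimal")

assert np.allclose(out1, out2)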

pimdh avatar Sep 23 '23 10:09 pimdh

opt_einsum isn't natively aware of broadcasting rules and likely gets a bit confused about the size of the dimension (i_dim=32). The easiest way to solve this is probably with preprocessing, as discussed in https://github.com/dgasmith/opt_einsum/issues/114.
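
A rough sketch of what such a preprocessing step could look like, assuming NumPy-style operands and an explicit "->" output term (squeeze_broadcast_dims is a hypothetical helper, not part of opt_einsum): drop every size-1 axis whose label has a larger size in another operand, so the path optimizer never sees the to-be-broadcast dimension.

import numpy as np
import opt_einsum

def squeeze_broadcast_dims(subscripts, *operands):
    # Assumes an explicit "->" output term and no ellipsis.
    inputs, output = subscripts.split("->")
    terms = inputs.split(",")

    # Largest size seen for every index label across all operands.
    max_size = {}
    for term, op in zip(terms, operands):
        for label, size in zip(term, op.shape):
            max_size[label] = max(max_size.get(label, 1), size)

    # Drop size-1 axes whose label is larger elsewhere, i.e. axes that would
    # only be broadcast, and remove their labels from the input subscripts.
    new_terms, new_ops = [], []
    for term, op in zip(terms, operands):
        keep = [i for i, (label, size) in enumerate(zip(term, op.shape))
                if not (size == 1 and max_size[label] > 1)]
        new_terms.append("".join(term[i] for i in keep))
        new_ops.append(op.reshape([op.shape[i] for i in keep]))
    return ",".join(new_terms) + "->" + output, new_ops

expr, ops = squeeze_broadcast_dims(
    "ijk,bj,bk->bi",
    np.random.randn(32, 32, 32),
    np.random.randn(10000, 32),
    np.random.randn(1, 32),
)
print(expr)  # ijk,bj,k->bi
out = opt_einsum.contract(expr, *ops, optimize="optimal")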

dgasmith avatar Sep 26 '24 15:09 dgasmith

+1. For my use case, I would like to fix the input array sizes to simplify dimension handling. This leads to a number of singleton dimensions and often gives sub-optimal paths from opt_einsum (manifesting as memory issues). While it is conceptually simple to strip them before the einsum, that step is nuanced and at the very least introduces code smell.

nedlrichards avatar Feb 06 '25 20:02 nedlrichards

It is helpful to know that there are multiple needs for a preprocessor. Could you provide a list of paths, especially edge cases, that we should consider for testing? The first preprocessor attempt was fairly ambitious; I think if we limit it to consolidating Hadamard and below contractions (such as the example above) it should move the needle without being overly complex.

dgasmith avatar Feb 07 '25 17:02 dgasmith

I'll try to work something up, but I suspect my use case is fairly simple from the perspective of this library.

The hardest challenge I have so far is preserving the output shape when every instance of a dimension in the inputs has size 1. If you strip this dimension from all of its occurrences in the inputs, it must also be removed from the output and then re-inflated in post-processing. That nuance turns the problem from an input-by-input treatment into one that requires considering all of the inputs together.
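
To illustrate, here is a rough sketch of the bookkeeping I have in mind, assuming NumPy arrays and an explicit output term (contract_without_singletons is a made-up helper): labels that have size 1 in every input are removed from the inputs and the output, and the removed output axes are restored with expand_dims afterwards.

import numpy as np
import opt_einsum

def contract_without_singletons(subscripts, *operands):
    # Assumes an explicit "->" output term and no ellipsis.
    inputs, output = subscripts.split("->")
    terms = inputs.split(",")

    # Labels whose size is 1 in every input where they appear.
    sizes = {}
    for term, op in zip(terms, operands):
        for label, size in zip(term, op.shape):
            sizes[label] = max(sizes.get(label, 1), size)
    singletons = {label for label, size in sizes.items() if size == 1}

    # Strip those labels (and their size-1 axes) from every input and the output.
    new_terms, new_ops = [], []
    for term, op in zip(terms, operands):
        keep = [i for i, label in enumerate(term) if label not in singletons]
        new_terms.append("".join(term[i] for i in keep))
        new_ops.append(op.reshape([op.shape[i] for i in keep]))
    new_output = "".join(l for l in output if l not in singletons)

    result = opt_einsum.contract(",".join(new_terms) + "->" + new_output, *new_ops)

    # Re-inflate: restore the removed output axes as size-1 dimensions.
    for axis, label in enumerate(output):
        if label in singletons:
            result = np.expand_dims(result, axis)
    return result

out = contract_without_singletons("aij,ajk->aik", np.ones((1, 3, 4)), np.ones((1, 4, 5)))
print(out.shape)  # (1, 3, 5) -- the size-1 'a' axis survives in the output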

nedlrichards avatar Feb 07 '25 19:02 nedlrichards

As an edge case, removing the singleton dimension here would turn an expression that otherwise raises a ValueError into a valid one.

import numpy as np

a = np.random.randn(2, 2)
b = np.random.randn(2)
c = np.random.randn(1, 2)
np.einsum('ii,i->i', a, b)  # fine: 'ii' takes the diagonal of a square array
np.einsum('ii,i->i', c, b)  # ValueError: c is (1, 2), so the repeated 'i' sizes clash;
                            # stripping c's singleton axis would silently make this the valid 'i,i->i'

nedlrichards avatar Feb 08 '25 05:02 nedlrichards