
Support for opt_einsum's intermediate sharing


I would like to use opt_einsum.shared_intermediates to cache intermediates across similar but distinct TN contractions. Here's a trivial example of a ring TN in raw opt_einsum (a variation with only partially overlapping contractions follows the jaxpr below):

import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum

# Five 1000x1000 matrices that will form the ring ab,bc,cd,de,ea.
factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]

def f(factors):
    with opt_einsum.shared_intermediates():
        # The same ring contraction twice within one sharing context.
        x = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
        y = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
    return x, y

print(jax.make_jaxpr(f)(factors))

Inspecting the jaxpr with JAX shows that the second contraction is served entirely from the shared cache; the result i is simply returned twice:

{ lambda  ; a b c d e.
  let f = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] b a
      g = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
                       precision=None ] f e
      h = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] d c
      i = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = dot_general[ dimension_numbers=(((1, 0), (0, 1)), ((), ()))
                                                      precision=None ] b a
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    name=_einsum ] g h
  in (i, i) }
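
For context, what I'm really after is sharing across contractions that only partially overlap. Here is a small variation of the snippet above (my own sketch, continuing with the same factors and imports); whether the pairwise products are actually served from the cache depends on the contraction paths opt_einsum picks lining up:

def g(factors):
    with opt_einsum.shared_intermediates():
        # Closed ring, as before.
        x = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
        # Open chain over the first four factors, which can share its
        # pairwise products with the ring when the paths happen to match.
        y = opt_einsum.contract("ab,bc,cd,de->ae", *factors[:4])
    return x, y

print(jax.make_jaxpr(g)(factors))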

Now doing a similar thing in tensornetwork:

import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum
import tensornetwork as tn

factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]

def f(factors):
    with opt_einsum.shared_intermediates():
        # Build the ring as a network of nodes and contract it.
        a, b, c, d, e = [tn.Node(m) for m in factors]
        a[1] ^ b[0]
        b[1] ^ c[0]
        c[1] ^ d[0]
        d[1] ^ e[0]
        e[1] ^ a[0]
        x = tn.contractors.auto([a, b, c, d, e]).get_tensor()

        # Build an identical second copy of the ring from the same tensors.
        f, g, h, i, j = [tn.Node(m) for m in factors]
        f[1] ^ g[0]
        g[1] ^ h[0]
        h[1] ^ i[0]
        i[1] ^ j[0]
        j[1] ^ f[0]
        y = tn.contractors.auto([f, g, h, i, j]).get_tensor()

    return x, y

print(jax.make_jaxpr(f)(factors))

The corresponding jaxpr shows that no intermediates are reused and both contractions are executed in full:

{ lambda  ; a b c d e.
  let f = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] c b
      g = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] a e
      h = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
                       precision=None ] d g
      i = dot_general[ dimension_numbers=(((1, 0), (1, 0)), ((), ()))
                       precision=None ] f h
      j = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] c b
      k = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] e d
      l = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None ] a k
      m = dot_general[ dimension_numbers=(((0, 1), (1, 0)), ((), ()))
                       precision=None ] j l
  in (i, m) }
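
One way to see where the sharing layer is (or is not) engaged: shared_intermediates yields its cache, so it can be inspected directly. If the TensorNetwork contraction never routes through opt_einsum's sharing backend, I would expect the cache to stay empty, while a raw opt_einsum.contract call inside the same context populates it (sketch, continuing with the same factors and imports):

with opt_einsum.shared_intermediates() as cache:
    a, b, c, d, e = [tn.Node(m) for m in factors]
    a[1] ^ b[0]
    b[1] ^ c[0]
    c[1] ^ d[0]
    d[1] ^ e[0]
    e[1] ^ a[0]
    tn.contractors.auto([a, b, c, d, e])
    print(len(cache))  # expected to stay at 0 if nothing is intercepted

    opt_einsum.contract("ab,bc,cd,de,ea", *factors)
    print(len(cache))  # should now be non-zero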

Having this kind of intermediate reuse is important for my use case, and I would appreciate some discussion of how we might implement it. As far as I can tell, the path contractors use opt_einsum only to compute a contraction order and then carry out the pairwise contractions through the backend, so the sharing cache, which intercepts opt_einsum's own calls, never sees them.
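
To kick off the discussion, here is a very rough, hypothetical sketch of the kind of hook I have in mind: a memoised pairwise contraction keyed on the identity of the operands and the contracted axes, in the spirit of opt_einsum's sharing cache. None of these names exist in tensornetwork today; the real thing would presumably live behind the backend or inside contract_between.

import jax.numpy as jnp

# Hypothetical sketch only; these names are not part of tensornetwork.
_shared_cache = {}

def cached_tensordot(x, y, axes):
    # axes can be an int or a pair of int/sequence entries; normalise it to
    # something hashable for use as a cache key.
    axes_key = axes if isinstance(axes, int) else tuple(
        a if isinstance(a, int) else tuple(a) for a in axes)
    # Keying on object identity mirrors opt_einsum.shared_intermediates and
    # carries the same caveat: the operands must stay alive for as long as
    # the cache entry is expected to remain valid.
    key = (id(x), id(y), axes_key)
    if key not in _shared_cache:
        _shared_cache[key] = jnp.tensordot(x, y, axes)
    return _shared_cache[key]

If the pairwise contractions performed by the path contractors were routed through something like this whenever a sharing context is active, two identical networks built from the same jnp arrays, as in the example above, would hit the cache at every step.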

AidanGG · Sep 20 '20