
Support for opt_einsum's intermediate sharing

Open AidanGG opened this issue 5 years ago • 3 comments

I would like to make use of opt_einsum.shared_intermediates for caching intermediates across similar but different TN contractions. Here's a trivial example of a ring TN in raw opt_einsum:

import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum

factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]

def f(factors):
    with opt_einsum.shared_intermediates():
        x = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
        y = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
    return x, y

print(jax.make_jaxpr(f)(factors))

Inspecting the jaxpr with JAX shows that the entire contraction is reused; the output (i, i) returns the same intermediate twice:

{ lambda  ; a b c d e.
  let f = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] b a
      g = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
                       precision=None ] f e
      h = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] d c
      i = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = dot_general[ dimension_numbers=(((1, 0), (0, 1)), ((), ()))
                                                      precision=None ] b a
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    name=_einsum ] g h
  in (i, i) }

Now doing the same contraction in tensornetwork:

import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum
import tensornetwork as tn

factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]

def f(factors):
    with opt_einsum.shared_intermediates():
        # First copy of the ring.
        a, b, c, d, e = [tn.Node(m) for m in factors]
        a[1] ^ b[0]
        b[1] ^ c[0]
        c[1] ^ d[0]
        d[1] ^ e[0]
        e[1] ^ a[0]
        x = tn.contractors.auto([a, b, c, d, e]).get_tensor()

        # Second, identical copy of the ring built from the same arrays
        # (named p..t so as not to shadow the enclosing function f).
        p, q, r, s, t = [tn.Node(m) for m in factors]
        p[1] ^ q[0]
        q[1] ^ r[0]
        r[1] ^ s[0]
        s[1] ^ t[0]
        t[1] ^ p[0]
        y = tn.contractors.auto([p, q, r, s, t]).get_tensor()

    return x, y

print(jax.make_jaxpr(f)(factors))

The corresponding output shows that no intermediates are reused; the full contraction is traced twice:

{ lambda  ; a b c d e.
  let f = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] c b
      g = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] a e
      h = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
                       precision=None ] d g
      i = dot_general[ dimension_numbers=(((1, 0), (1, 0)), ((), ()))
                       precision=None ] f h
      j = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] c b
      k = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] e d
      l = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None ] a k
      m = dot_general[ dimension_numbers=(((0, 1), (1, 0)), ((), ()))
                       precision=None ] j l
  in (i, m) }

Having this kind of intermediate reuse functionality is important for me, and I would appreciate some discussion on how we might be able to implement it.

AidanGG avatar Sep 20 '20 14:09 AidanGG

@Thenerdstation @mganahl Any thoughts on this? It does seem like important functionality, but I don't understand our contraction algorithms well enough to comment directly.

alewis avatar Sep 28 '20 12:09 alewis

Yeah, that seems like it would be difficult to add.

Is there any reason you can't just manually reuse nodes? It's not automated, but that should work?
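For the ring example above, manual reuse might look something like the sketch below: contract the shared piece once into a plain tensor, then feed that tensor into each network. (chain and close_ring are made-up helpers for this sketch, not tensornetwork APIs, and it uses the default NumPy backend for simplicity.)

import numpy as np
import tensornetwork as tn

factors = [np.random.rand(100, 100) for _ in range(5)]

def chain(mats):
    # Contract a chain of matrices, keeping the two outer legs dangling.
    nodes = [tn.Node(m) for m in mats]
    for left, right in zip(nodes, nodes[1:]):
        left[1] ^ right[0]
    out = tn.contractors.auto(nodes,
                              output_edge_order=[nodes[0][0], nodes[-1][1]])
    return out.get_tensor()

# Contract the shared piece (here, the middle three factors) exactly once.
shared = chain(factors[1:4])

def close_ring(first, middle, last):
    # Close first -> middle -> last -> first into a ring and contract it.
    a, s, e = tn.Node(first), tn.Node(middle), tn.Node(last)
    a[1] ^ s[0]
    s[1] ^ e[0]
    e[1] ^ a[0]
    return tn.contractors.auto([a, s, e]).get_tensor()

# Reuse the precomputed tensor in both contractions.
x = close_ring(factors[0], shared, factors[4])
y = close_ring(factors[0], shared, factors[4])

Here both rings are identical, so this only mirrors the example; in the real use case the reused piece would be whatever subnetwork the similar-but-different contractions have in common.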

chaserileyroberts avatar Sep 28 '20 14:09 chaserileyroberts

What is stopping it from working right now? As I understand it, constructing a new node from the same JAX array doesn't copy the underlying array, and tn.contractors.auto should call into deterministic opt_einsum routines, unless I'm mistaken.

Either way I'll try it out by reusing the nodes and get back to you.

AidanGG avatar Sep 28 '20 19:09 AidanGG