TensorNetwork
TensorNetwork copied to clipboard
Support for opt_einsum's intermediate sharing
I would like to make use of `opt_einsum.shared_intermediates` for caching intermediates across similar but different tensor network (TN) contractions. Here's a trivial example of a ring TN in raw `opt_einsum`:
import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum
# Five random 1000x1000 matrices: the factors of the ring tensor network.
factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]


def f(factors):
    """Contract the same five-factor ring twice inside one sharing context.

    Args:
        factors: Sequence of five matrices, unpacked as the operands of the
            ``"ab,bc,cd,de,ea"`` contraction.

    Returns:
        The pair ``(x, y)`` produced by the two identical
        ``opt_einsum.contract`` calls.
    """
    # Both calls run inside a single shared_intermediates() context so the
    # second contraction can reuse what the first one already computed.
    with opt_einsum.shared_intermediates():
        x = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
        y = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
    return x, y


# Trace f and print its jaxpr to see which operations are actually emitted.
print(jax.make_jaxpr(f)(factors))
Using JAX to inspect the jaxpr indicates that the result of the first contraction is reused entirely for the second one (both outputs are the same variable, `(i, i)`):
{ lambda ; a b c d e.
let f = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
precision=None ] b a
g = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
precision=None ] f e
h = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
precision=None ] d c
i = xla_call[ backend=None
call_jaxpr={ lambda ; a b.
let c = dot_general[ dimension_numbers=(((1, 0), (0, 1)), ((), ()))
precision=None ] b a
in (c,) }
device=None
donated_invars=(False, False)
name=_einsum ] g h
in (i, i) }
Now doing a similar thing in `tensornetwork`:
import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum
import tensornetwork as tn
# Five random 1000x1000 matrices: the factors of the ring tensor network.
factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]


def f(factors):
    """Build and contract the same five-node ring twice with tensornetwork.

    Both rings wrap the same ``factors`` and are contracted with
    ``tn.contractors.auto`` inside a single
    ``opt_einsum.shared_intermediates()`` context.

    Args:
        factors: Sequence of five matrices; each one is wrapped in a
            ``tn.Node`` for each of the two rings.

    Returns:
        The pair ``(x, y)`` of tensors obtained from the two contractions.
    """
    with opt_einsum.shared_intermediates():
        # First ring: connect axis 1 of each node to axis 0 of the next,
        # closing the loop from e back to a.
        a, b, c, d, e = [tn.Node(m) for m in factors]
        a[1] ^ b[0]
        b[1] ^ c[0]
        c[1] ^ d[0]
        d[1] ^ e[0]
        e[1] ^ a[0]
        x = tn.contractors.auto([a, b, c, d, e]).get_tensor()

        # Second ring: fresh nodes over the same tensors, wired identically.
        # NOTE: the local name `f` shadows this function inside its own body.
        f, g, h, i, j = [tn.Node(m) for m in factors]
        f[1] ^ g[0]
        g[1] ^ h[0]
        h[1] ^ i[0]
        i[1] ^ j[0]
        j[1] ^ f[0]
        y = tn.contractors.auto([f, g, h, i, j]).get_tensor()
    return x, y


# Trace f and print its jaxpr to see which operations are actually emitted.
print(jax.make_jaxpr(f)(factors))
The corresponding output indicates that intermediate reuse is not happening:
{ lambda ; a b c d e.
let f = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
precision=None ] c b
g = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
precision=None ] a e
h = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
precision=None ] d g
i = dot_general[ dimension_numbers=(((1, 0), (1, 0)), ((), ()))
precision=None ] f h
j = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
precision=None ] c b
k = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
precision=None ] e d
l = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
precision=None ] a k
m = dot_general[ dimension_numbers=(((0, 1), (1, 0)), ((), ()))
precision=None ] j l
in (i, m) }
Having this kind of intermediate reuse functionality is important for me, and I would appreciate some discussion on how we might be able to implement it.