Support for opt_einsum's intermediate sharing
I would like to make use of opt_einsum.shared_intermediates for caching intermediates across similar but different TN contractions. Here's a trivial example of a ring TN in raw opt_einsum:
import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum
factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]
def f(factors):
    with opt_einsum.shared_intermediates():
        x = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
        y = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
    return x, y
print(jax.make_jaxpr(f)(factors))
Inspecting the jaxpr with JAX shows that the whole contraction is reused: the final tensor i is returned for both outputs.
{ lambda ; a b c d e.
  let f = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] b a
      g = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
                       precision=None ] f e
      h = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] d c
      i = xla_call[ backend=None
                    call_jaxpr={ lambda ; a b.
                                 let c = dot_general[ dimension_numbers=(((1, 0), (0, 1)), ((), ()))
                                                      precision=None ] b a
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    name=_einsum ] g h
  in (i, i) }
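As an aside, the "similar but different" case mentioned above follows the same pattern in raw opt_einsum, reusing the imports and factors from the snippet above. How much actually gets cached depends on the contraction paths opt_einsum picks, and the second expression below is just an illustrative stand-in:
def g(factors):
    with opt_einsum.shared_intermediates():
        # Full ring trace, as above.
        x = opt_einsum.contract("ab,bc,cd,de,ea", *factors)
        # A different contraction over the first four factors; only the
        # intermediates that happen to coincide with the ring's can be reused.
        y = opt_einsum.contract("ab,bc,cd,de", *factors[:4])
    return x, y

print(jax.make_jaxpr(g)(factors))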
Now doing a similar thing in tensornetwork:
import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum
import tensornetwork as tn
factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]
def f(factors):
    with opt_einsum.shared_intermediates():
        a, b, c, d, e = [tn.Node(m) for m in factors]
        a[1] ^ b[0]
        b[1] ^ c[0]
        c[1] ^ d[0]
        d[1] ^ e[0]
        e[1] ^ a[0]
        x = tn.contractors.auto([a, b, c, d, e]).get_tensor()
        f, g, h, i, j = [tn.Node(m) for m in factors]
        f[1] ^ g[0]
        g[1] ^ h[0]
        h[1] ^ i[0]
        i[1] ^ j[0]
        j[1] ^ f[0]
        y = tn.contractors.auto([f, g, h, i, j]).get_tensor()
    return x, y
print(jax.make_jaxpr(f)(factors))
The corresponding output indicates that intermediate reuse is not happening:
{ lambda ; a b c d e.
  let f = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] c b
      g = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] a e
      h = dot_general[ dimension_numbers=(((1,), (1,)), ((), ()))
                       precision=None ] d g
      i = dot_general[ dimension_numbers=(((1, 0), (1, 0)), ((), ()))
                       precision=None ] f h
      j = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] c b
      k = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))
                       precision=None ] e d
      l = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None ] a k
      m = dot_general[ dimension_numbers=(((0, 1), (1, 0)), ((), ()))
                       precision=None ] j l
  in (i, m) }
Having this kind of intermediate reuse functionality is important for me, and I would appreciate some discussion on how we might be able to implement it.
@Thenerdstation @mganahl Any thoughts on this? It does seem like important functionality, but I don't understand our contraction algorithms well enough to comment directly.
Yeah, that seems like it would be difficult to add.
Is there any reason you can't just manually reuse nodes? It's not automated, but that should work?
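A minimal sketch of what manual reuse could look like, assuming it means contracting the shared part once and feeding the resulting tensor into both networks (ring_from_intermediate and the particular split are illustrative here, not an existing API):
import numpy as np
import jax.numpy as jnp
import tensornetwork as tn

factors = [jnp.asarray(np.random.rand(1000, 1000)) for _ in range(5)]

def ring_from_intermediate(ab_tensor, c_t, d_t, e_t):
    # Rebuild the ring with a precomputed intermediate standing in for the
    # contraction of the first two factors.
    ab, c, d, e = tn.Node(ab_tensor), tn.Node(c_t), tn.Node(d_t), tn.Node(e_t)
    ab[1] ^ c[0]
    c[1] ^ d[0]
    d[1] ^ e[0]
    e[1] ^ ab[0]
    return tn.contractors.auto([ab, c, d, e]).get_tensor()

# Contract the shared part (the first two factors) exactly once ...
a, b = tn.Node(factors[0]), tn.Node(factors[1])
a[1] ^ b[0]
ab = tn.contract_between(a, b, output_edge_order=[a[0], b[1]]).get_tensor()

# ... and reuse the resulting tensor in both (here identical) contractions.
x = ring_from_intermediate(ab, factors[2], factors[3], factors[4])
y = ring_from_intermediate(ab, factors[2], factors[3], factors[4])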
What is stopping it from working right now? As far as I can tell, constructing a new node from the same JAX array doesn't copy the underlying array, and tn.contractors.auto should call into deterministic opt_einsum routines, unless I'm mistaken.
Either way, I'll try it out by reusing the nodes and get back to you.
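For what it's worth, the no-copy assumption is easy to sanity-check with the NumPy backend (whether the same holds for the JAX backend is an assumption here):
import numpy as np
import tensornetwork as tn

arr = np.random.rand(10, 10)
node1 = tn.Node(arr)
node2 = tn.Node(arr)

# If Node does not copy its input, both nodes should view the same buffer,
# so this is expected to print True.
print(np.shares_memory(node1.get_tensor(), node2.get_tensor()))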