TensorNetwork
SVD on the jax backend, and thus ``split_node``, cannot be jitted when ``max_truncation_err`` is set
SVD and `split_node` work fine on the tensorflow backend with TensorFlow's jit (`tf.function`):
```python
import tensorflow as tf
import tensornetwork as tn

tn.set_default_backend("tensorflow")

@tf.function
def f(b):
    a = tn.Node(b)
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor

# Runs without error under tf.function
f(tf.ones([2, 2, 2, 2]))
```
But it fails on the jax backend:
```python
import jax
from jax import numpy as jnp
import tensornetwork as tn

tn.set_default_backend("jax")

@jax.jit
def f(b):
    a = tn.Node(b)
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor

# Raises ConcretizationTypeError while tracing
f(jnp.ones([2, 2, 2, 2]))
```
The error is raised by the SVD in `backends/numpy/decompositions.py`, at the line `num_sing_vals_keep = min(max_singular_values, num_sing_vals_err)`, as:

`ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:`
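The failing line can be reproduced in plain JAX, without TensorNetwork. The sketch below is not TensorNetwork's code: it only mirrors the quoted `min(...)` line, and the threshold `0.5` is a stand-in for the real error-based count. Under `jax.jit`, `num_sing_vals_err` is an abstract tracer, and Python's built-in `min` needs a concrete boolean to compare it with an int, so tracing fails.

```python
import jax
import jax.numpy as jnp

max_singular_values = 3


@jax.jit
def keep_count(s):
    # Under jit this is an abstract tracer, not a Python int.
    # (The `s > 0.5` threshold is a hypothetical stand-in for the error rule.)
    num_sing_vals_err = jnp.sum(s > 0.5)
    # Python's min() evaluates `num_sing_vals_err < 3` as a bool,
    # which requires a concrete value, so tracing fails here.
    return min(max_singular_values, num_sing_vals_err)


try:
    keep_count(jnp.array([2.0, 1.0, 0.1, 0.01]))
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True

print(failed)  # True
```

Even if the `min` were replaced by `jnp.minimum`, the next step (slicing the SVD factors to that traced length) would still need a concrete size.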
This error was actually expected even before I tried it: a jax-jitted function only accepts and returns arrays whose shapes are fixed at trace time, so it supports only a subset of what `tf.function` does. Since `split_node` with `max_truncation_err` returns nodes of varying shape (the final shape depends on the singular values), it seems to be incompatible with jax's jit mechanism.
Any thoughts or workarounds for this? I believe it is very common to apply `split_node` with `max_singular_values` in tensor network algorithms, and it would be great if such algorithms could be jitted.
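Another option, again a sketch under assumptions rather than a TensorNetwork feature, is to split the computation at the point where the shape gets decided: run the SVD under jit, choose the kept rank eagerly on the host from the concrete singular values, and pass that rank as a static argument to a second jitted function. Each distinct rank triggers one recompilation, which may be acceptable when bond dimensions take only a few values. The names `svd_stage`, `truncate_stage`, and `truncated_split` are hypothetical, and the truncation rule is the same assumed absolute-error criterion as above.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def svd_stage(m):
    # Shape-stable part: thin SVD, output shapes depend only on m.shape.
    return jnp.linalg.svd(m, full_matrices=False)


@partial(jax.jit, static_argnums=3)
def truncate_stage(u, s, vh, rank):
    # rank is a Python int, so these slices have static shapes.
    # Singular values are absorbed into the left factor here (a choice).
    return u[:, :rank] * s[:rank], vh[:rank, :]


def truncated_split(m, max_singular_values, max_truncation_err):
    u, s, vh = svd_stage(m)
    s_host = np.asarray(s)  # concrete values, outside any trace
    # tail[k] = norm of the singular values discarded if the first k are kept.
    tail = np.sqrt(np.cumsum((s_host ** 2)[::-1])[::-1])
    rank = int(np.sum(tail > max_truncation_err))
    # Cap by the maximum bond dimension; keep at least one value (a choice).
    rank = max(1, min(rank, max_singular_values))
    return truncate_stage(u, s, vh, rank)


m = jnp.diag(jnp.array([3.0, 2.0, 1.0, 0.1]))
left, right = truncated_split(m, 3, 1.5)
print(left.shape, right.shape)  # (4, 2) (2, 4)
```

The driver function `truncated_split` itself is not jitted, so this only helps when the surrounding algorithm can be arranged as jitted pieces glued together by Python code.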