
SVD on the jax backend (and thus ``split_node``) cannot be jitted when ``max_truncation_err`` is set

Open: refraction-ray opened this issue 3 years ago • 0 comments

SVD and ``split_node`` work fine on the tensorflow backend with tensorflow jit (``tf.function``):

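```python
import tensorflow as tf
import tensornetwork as tn

tn.set_default_backend("tensorflow")

@tf.function
def f(b):
    a = tn.Node(b)
    # Split the rank-4 node into two nodes, truncating by singular-value error.
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor

f(tf.ones([2, 2, 2, 2]))
```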

But the same code fails on the jax backend:

```python
import jax
from jax import numpy as jnp
import tensornetwork as tn

tn.set_default_backend("jax")

@jax.jit
def f(b):
    a = tn.Node(b)
    # Same split as above; raises during tracing under jax.jit.
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor

f(jnp.ones([2, 2, 2, 2]))
```

The error is raised in the ``svd`` function of ``backends/numpy/decompositions.py`` (which the jax backend uses here), at the line ``num_sing_vals_keep = min(max_singular_values, num_sing_vals_err)``, as ``ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected``.
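For reference, the failure can be reproduced without TensorNetwork. The sketch below uses a hypothetical helper, ``truncated_left_factor`` (not a TensorNetwork function), that mimics the error-based truncation pattern at the line quoted above, assuming the number of kept values is derived from the cumulative norm of the smallest singular values. It runs eagerly, but under ``jax.jit`` the ``min(...)`` call needs a concrete boolean from a traced comparison and raises a ``ConcretizationTypeError`` (or a subclass of it, depending on the jax version).

```python
import jax
import jax.numpy as jnp


def truncated_left_factor(x, max_truncation_err, max_singular_values):
    """Hypothetical eager-only sketch of error-based truncation."""
    u, s, _ = jnp.linalg.svd(x, full_matrices=False)
    # Norm of the j smallest singular values, j = 1, 2, ... (s is descending).
    trunc_errs = jnp.sqrt(jnp.cumsum(jnp.square(s[::-1])))
    # How many values must be kept so the discarded norm stays within
    # max_truncation_err; under jit this is a tracer, not a Python int.
    num_sing_vals_err = jnp.count_nonzero(trunc_errs > max_truncation_err)
    # Python's min() needs a concrete bool from the comparison: fails under jit.
    num_keep = int(min(max_singular_values, num_sing_vals_err))
    # Even with jnp.minimum instead, this slice has a data-dependent shape.
    return u[:, :num_keep] * s[:num_keep]


x = jnp.diag(jnp.array([3.0, 2.0, 0.1, 0.01]))
print(truncated_left_factor(x, 0.5, 4).shape)  # eager: (4, 2)

try:
    jax.jit(truncated_left_factor, static_argnums=2)(x, 0.5, 4)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed with", type(err).__name__)
```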

I actually expected this error before trying: a jax-jitted function can only accept and return arrays of fixed shape, so ``jax.jit`` supports only a subset of what ``tf.function`` can do. Since ``split_node`` with ``max_truncation_err`` returns nodes whose shape depends on the singular values (i.e. on the input data), it seems to be incompatible with jax's jit mechanism.
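One possible workaround, sketched under the assumption that a fixed upper bound ``max_singular_values`` on the bond dimension is acceptable: keep the static shape set by that bound and zero out, rather than slice away, the singular values that ``max_truncation_err`` would discard. ``masked_split`` is a hypothetical helper acting on a plain matrix, not a TensorNetwork API; it splits ``sqrt(s)`` evenly between the two factors, similar to how ``split_node`` distributes the singular values.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=2)
def masked_split(x, max_truncation_err, max_singular_values):
    """Hypothetical fixed-shape truncated split: x ~= left @ right."""
    u, s, vh = jnp.linalg.svd(x, full_matrices=False)
    # Same error criterion as the dynamic version, but kept as a traced count.
    trunc_errs = jnp.sqrt(jnp.cumsum(jnp.square(s[::-1])))
    num_keep = jnp.count_nonzero(trunc_errs > max_truncation_err)
    # Static upper bound on the bond dimension (plain Python ints under jit).
    k_max = min(max_singular_values, s.shape[0])
    u, s, vh = u[:, :k_max], s[:k_max], vh[:k_max, :]
    # Zero the values error-based truncation would drop; shapes stay k_max.
    s = jnp.where(jnp.arange(k_max) < num_keep, s, 0.0)
    # Distribute sqrt(s) to both factors.
    sqrt_s = jnp.sqrt(s)
    left = u * sqrt_s
    right = sqrt_s[:, None] * vh
    # Effective bond dimension, returned as a traced integer.
    return left, right, jnp.minimum(num_keep, k_max)


x = jnp.diag(jnp.array([3.0, 2.0, 0.1, 0.01]))
left, right, bond_dim = masked_split(x, 0.5, 3)
print(left.shape, right.shape, int(bond_dim))  # (4, 3) (3, 4) 2
```

The trade-off is that downstream contractions still carry ``max_singular_values`` columns, some of them zero, so this gives up the compute savings of truncation in exchange for a jittable fixed shape; the returned count can still report the effective bond dimension.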

Any thoughts or workarounds? I believe applying ``split_node`` with ``max_singular_values`` is very common in tensor network algorithms, and it would be great if such algorithms could be jitted.

refraction-ray, Nov 18 '21 01:11