
`backend.item` in MPS calculations is incompatible with autograd in JAX

Open · SUSYUSTC opened this issue 3 years ago · 2 comments

In https://github.com/google/TensorNetwork/blob/master/tensornetwork/matrixproductstates/base_mps.py, line 319 (`res.append(self.backend.item(result.tensor))`) and line 479 (`return [self.backend.item(o) for o in c]`) both call `self.backend.item`, which converts the result to a Python scalar and is therefore incompatible with autograd in JAX (and possibly with other backends as well). I haven't checked the other files, so they might have similar issues. Here's a simple example:

```python
import tensornetwork as tn
import numpy as np
import jax
import jax.scipy.linalg  # provides expm

tn.set_default_backend('jax')
# Pauli Z as a complex JAX array.
Z = jax.numpy.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))


def func(x):
    # Random 4-site MPS: physical dimension 2, bond dimension 4.
    mps = tn.FiniteMPS.random([2, 2, 2, 2], [4, 4, 4], dtype=np.complex64)
    gate = jax.scipy.linalg.expm(Z * x)
    # The measurement result is converted with self.backend.item inside base_mps.py.
    e = mps.measure_local_operator([gate], [0])
    return e[0]


print(func(1.0))                 # output: (1.2248424291610718-2.9802322387695312e-08j)
vg = jax.value_and_grad(func)
print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
```
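The underlying problem can be reproduced without TensorNetwork. The sketch below uses plain JAX; the single-site state `psi` and the helpers `expectation_array` / `expectation_item` are hypothetical stand-ins for the random MPS and for the measurement methods, not TensorNetwork code. Keeping the expectation value as a 0-d JAX array lets `jax.grad` trace through it, whereas `.item()` demands a concrete Python scalar, which cannot exist inside `jax.jit` (the exact exception type depends on the JAX version, so the sketch only records its name). Note also that `jax.grad` only accepts real-valued outputs, so the complex `e[0]` returned by `func` above would additionally need something like `jnp.real(...)` even if `item` were no longer called.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# Pauli Z, as in the reproduction above.
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
# Hypothetical fixed single-site state (|0> + 2|1>)/sqrt(5), standing in for the random MPS.
psi = jnp.array([1.0, 2.0], dtype=jnp.complex64) / jnp.sqrt(5.0)


def expectation_array(x):
    """<psi| expm(Z x) |psi>, kept as a 0-d JAX array so JAX can trace it."""
    gate = jax.scipy.linalg.expm(Z * x)
    return jnp.vdot(psi, gate @ psi)  # vdot conjugates its first argument


def expectation_item(x):
    """Same value, but converted to a Python scalar, mimicking backend.item."""
    return expectation_array(x).item()


# jax.grad requires a real-valued output, so differentiate the real part.
grad_fn = jax.grad(lambda x: jnp.real(expectation_array(x)))
print("gradient at x=1:", grad_fn(1.0))

# Outside of any transformation, .item() works and returns a Python complex.
print("value via .item():", expectation_item(1.0))

# Under jax.jit the traced value has no concrete data, so .item() cannot succeed.
jit_item_error = None
try:
    jax.jit(expectation_item)(1.0)
except Exception as err:  # the exception type differs between JAX versions
    jit_item_error = type(err).__name__
print(".item() under jax.jit raised:", jit_item_error)
```

Since `expm(Z x) = diag(e^x, e^-x)`, the value is `(e^x + 4e^-x)/5` and the gradient at `x = 1` should come out close to `(e - 4/e)/5 ≈ 0.2494`, while the `jax.jit` call fails at the `.item()` conversion.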

SUSYUSTC · Jan 21 '22 06:01