TensorNetwork
`backend.item` in MPS calculations is incompatible with autograd in JAX
In https://github.com/google/TensorNetwork/blob/master/tensornetwork/matrixproductstates/base_mps.py, line 319 (`res.append(self.backend.item(result.tensor))`) and line 479 (`return [self.backend.item(o) for o in c]`) both call `self.backend.item`. This turns each 0-d result tensor into a Python scalar, so JAX can no longer trace through it, which makes these methods incompatible with autograd in JAX (and possibly with other backends as well). I haven't checked other files, so they might have similar issues.
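A minimal sketch of the pattern in plain JAX, without TensorNetwork. `WEIGHTS`, `expectations_item` and `expectations_traceable` are hypothetical stand-ins for the per-site contraction results, not TensorNetwork code. Converting each 0-d result with `.item()` yields Python scalars, which cannot be traced (under `jax.jit` this raises an error); keeping the results as JAX arrays and stacking them keeps the whole computation differentiable.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-site data standing in for MPS contractions.
WEIGHTS = jnp.array([0.5, 2.0, -1.0])

def expectations_item(x):
    # Mirrors the pattern at lines 319/479: every 0-d result is
    # converted to a Python scalar with .item(), leaving JAX's trace.
    return [(w * x ** 2).item() for w in WEIGHTS]

def expectations_traceable(x):
    # Same values, but each 0-d result stays a JAX array and the
    # list is stacked into one array, so jit/grad can trace it.
    return jnp.stack([w * x ** 2 for w in WEIGHTS])

# Eager evaluation works for both versions.
print(expectations_item(3.0))                    # Python floats
print(expectations_traceable(3.0))               # JAX array
# Only the traceable version can be differentiated.
print(jax.jacobian(expectations_traceable)(3.0))  # d/dx of each entry
```

Applied to `base_mps.py`, the analogous (untested) change would be to append `result.tensor` directly and return the tensors in `c` without calling `self.backend.item`, at the cost of returning backend tensors instead of Python numbers.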
Here's a simple example:
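```python
import numpy as np
import jax
import jax.scipy.linalg
import tensornetwork as tn

tn.set_default_backend('jax')

Z = jax.numpy.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))

def func(x):
    # Random 4-site MPS: physical dimension 2, bond dimension 4.
    mps = tn.FiniteMPS.random([2, 2, 2, 2], [4, 4, 4], dtype=np.complex64)
    gate = jax.scipy.linalg.expm(Z * x)
    # Expectation value of `gate` on site 0 (returns one value per site).
    e = mps.measure_local_operator([gate], [0])
    # Complex scalar; note that jax.grad also needs a real-valued output
    # (e.g. jax.numpy.real(e[0])) once the .item() issue is fixed.
    return e[0]

print(func(1.0))  # output (MPS is random): (1.2248424291610718-2.9802322387695312e-08j)

vg = jax.value_and_grad(func)
print(vg(1.0))  # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
```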
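For reference, here is a minimal sketch of what `jax.value_and_grad` should produce for this kind of quantity when no `.item()` call is involved. Instead of a random MPS it assumes a fixed, hypothetical single-qubit state `PSI`, computes ⟨ψ|expm(Z x)|ψ⟩ directly with `jax.numpy`, and takes the real part so that `jax.grad` gets a real scalar. With |ψ₀|² = 0.36 and |ψ₁|² = 0.64, the closed form is E(x) = 0.36 eˣ + 0.64 e⁻ˣ, so dE/dx = 0.36 eˣ − 0.64 e⁻ˣ.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
# Hypothetical normalized state: |0.6|^2 + |0.8j|^2 = 1.
PSI = jnp.array([0.6, 0.8j], dtype=jnp.complex64)

def energy(x):
    gate = jax.scipy.linalg.expm(Z * x)
    # <psi| gate |psi>, kept as a JAX array (no .item()).
    e = jnp.vdot(PSI, gate @ PSI)
    # expm(Z * x) is Hermitian for real x, so the imaginary part is ~0;
    # taking the real part gives jax.grad the real scalar it requires.
    return jnp.real(e)

value, grad = jax.value_and_grad(energy)(1.0)
print(value, grad)
```

The MPS version would give different numbers because the state is random, but once the `.item()` calls are removed it should return a (value, gradient) pair of the same form.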