TensorNetwork
`backend.item` in MPS calculation is incompatible with autograd in jax
In https://github.com/google/TensorNetwork/blob/master/tensornetwork/matrixproductstates/base_mps.py, two lines call self.backend.item:
line 319: res.append(self.backend.item(result.tensor))
line 479: return [self.backend.item(o) for o in c]
This use of self.backend.item is incompatible with autograd in jax (and possibly with other backends too).
I haven't checked the other files, so they might have similar issues.
Here's a simple example:
import tensornetwork as tn
import numpy as np
import jax

tn.set_default_backend('jax')
Z = jax.numpy.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))

def func(x):
    mps = tn.FiniteMPS.random([2, 2, 2, 2], [4, 4, 4], dtype=np.complex64)
    gate = jax.scipy.linalg.expm(Z * x)
    e = mps.measure_local_operator([gate], [0])
    return e[0]

print(func(1.0)) # output: (1.2248424291610718-2.9802322387695312e-08j)
vg = jax.value_and_grad(func)
print(vg(1.0)) # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
Hi, and thanks for the message! Can you post the full error message as well? Thanks!
The full output is
(0.7063742876052856-1.4842953532934189e-08j)
Traceback (most recent call last):
  File "a.py", line 17, in <module>
    print(vg(1.0)) # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/api.py", line 993, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/api.py", line 2313, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "a.py", line 11, in func
    e = mps.measure_local_operator([gate], [0])
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/matrixproductstates/base_mps.py", line 319, in measure_local_operator
    res.append(self.backend.item(result.tensor))
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/backends/jax/jax_backend.py", line 878, in item
    return tensor.item()
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/core.py", line 568, in __getattr__
    attr = getattr(self.aval, name)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ConcreteArray' object has no attribute 'item'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "a.py", line 17, in <module>
    print(vg(1.0)) # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
  File "a.py", line 11, in func
    e = mps.measure_local_operator([gate], [0])
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/matrixproductstates/base_mps.py", line 319, in measure_local_operator
    res.append(self.backend.item(result.tensor))
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/backends/jax/jax_backend.py", line 878, in item
    return tensor.item()
AttributeError: 'ConcreteArray' object has no attribute 'item'
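
For what it's worth, the traceback shows the failure comes from tensor.item() being called on a value that jax is tracing for differentiation. Below is a minimal sketch of the pattern that seems to avoid it: keep the result as a jax array instead of converting it to a Python scalar. It does not use TensorNetwork at all; `expectation` and the fixed two-component state `psi` are hypothetical stand-ins for the MPS measurement, and because Z is diagonal the matrix exponential is written as an elementwise exp. Note also that jax.value_and_grad expects a real-valued scalar output by default, so the sketch returns the real part.

```python
import jax
import jax.numpy as jnp

# Diagonal of the Pauli-Z operator from the example above.
Z_DIAG = jnp.array([1.0, -1.0])

def expectation(x):
    # Hypothetical stand-in for the MPS: a fixed, normalized single-site state.
    psi = jnp.array([0.6, 0.8], dtype=jnp.complex64)
    # exp(Z * x) for diagonal Z is the elementwise exponential of the diagonal.
    gate = jnp.diag(jnp.exp(Z_DIAG * x)).astype(jnp.complex64)
    # <psi| gate |psi>, kept as a jax array (no .item() conversion).
    value = jnp.vdot(psi, gate @ psi)
    # Return a real scalar so that value_and_grad accepts the output.
    return jnp.real(value)

value, grad = jax.value_and_grad(expectation)(1.0)
print(value, grad)
```

Here the value is 0.36*exp(x) + 0.64*exp(-x) and the gradient is 0.36*exp(x) - 0.64*exp(-x), so the result can be checked by hand. Whether measure_local_operator can simply return the unconverted tensors is for the maintainers to decide; this only illustrates the idea.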