TensorNetwork

`backend.item` in MPS calculation is incompatible with autograd in jax

Open · SUSYUSTC opened this issue 3 years ago · 2 comments

In https://github.com/google/TensorNetwork/blob/master/tensornetwork/matrixproductstates/base_mps.py, line 319 (`res.append(self.backend.item(result.tensor))`) and line 479 (`return [self.backend.item(o) for o in c]`) call `self.backend.item`, which is incompatible with autograd in JAX (and maybe with other backends too): it converts the result into a Python scalar, so the value can no longer be traced. I haven't checked the other files, so they might have similar issues. Here's a simple example:

```python
import tensornetwork as tn
import numpy as np
import jax

tn.set_default_backend('jax')
Z = jax.numpy.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))


def func(x):
    # random 4-site MPS (physical dim 2, bond dim 4), so the printed value differs between runs
    mps = tn.FiniteMPS.random([2, 2, 2, 2], [4, 4, 4], dtype=np.complex64)
    gate = jax.scipy.linalg.expm(Z * x)
    e = mps.measure_local_operator([gate], [0])
    return e[0]


print(func(1.0))                 # output: (1.2248424291610718-2.9802322387695312e-08j)
vg = jax.value_and_grad(func)
print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
```
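A minimal sketch of the kind of change that avoids the problem, written in plain JAX rather than TensorNetwork. The fixed single-site state `PSI` and the functions `measure` and `loss` are hypothetical stand-ins for `measure_local_operator`, not TensorNetwork API. The idea is to keep the measured value as a 0-d array instead of calling `.item()`, which returns a Python scalar that autodiff cannot see. Separately, `jax.value_and_grad` requires a real-valued output (unless `holomorphic=True` is used with complex inputs), so the sketch differentiates the real part of the complex expectation value.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# Pauli Z, as in the example above
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
# Hypothetical fixed single-site state standing in for the random MPS
PSI = jnp.array([0.6, 0.8], dtype=jnp.complex64)


def measure(x):
    """Toy stand-in for a local measurement: returns <psi| expm(x Z) |psi>."""
    gate = jax.scipy.linalg.expm(Z * x)
    result = jnp.vdot(PSI, gate @ PSI)  # 0-d complex array, still traced
    # result.item() here would produce a Python scalar and cut the value
    # off from autodiff, so return the 0-d array unchanged.
    return result


def loss(x):
    # jax.grad needs a real scalar output, so take the real part.
    return jnp.real(measure(x))


value, g = jax.value_and_grad(loss)(1.0)
print(value, g)
```

Because `PSI` is fixed, the expectation value here is 0.36·e^x + 0.64·e^(-x), so at x = 1 the value should be close to 1.2140 and the gradient close to 0.7431.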

SUSYUSTC, Jan 21 '22 06:01

Hi, and thanks for the message! Can you post the full error message as well? Thanks!

mganahl, Jan 21 '22 19:01

> Hi, and thanks for the message! Can you post the full error message as well? Thanks!

The full output is:

```
(0.7063742876052856-1.4842953532934189e-08j)
Traceback (most recent call last):
  File "a.py", line 17, in <module>
    print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/api.py", line 993, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/api.py", line 2313, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "a.py", line 11, in func
    e = mps.measure_local_operator([gate], [0])
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/matrixproductstates/base_mps.py", line 319, in measure_local_operator
    res.append(self.backend.item(result.tensor))
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/backends/jax/jax_backend.py", line 878, in item
    return tensor.item()
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/core.py", line 568, in __getattr__
    attr = getattr(self.aval, name)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ConcreteArray' object has no attribute 'item'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "a.py", line 17, in <module>
    print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
  File "a.py", line 11, in func
    e = mps.measure_local_operator([gate], [0])
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/matrixproductstates/base_mps.py", line 319, in measure_local_operator
    res.append(self.backend.item(result.tensor))
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/backends/jax/jax_backend.py", line 878, in item
    return tensor.item()
AttributeError: 'ConcreteArray' object has no attribute 'item'
```

SUSYUSTC, Jan 23 '22 22:01