quimb
ValueError when measuring MPS with jax backend
What happened?
When measuring an MPS with the jax backend, "ValueError: output array is read-only" is raised.
The current workarounds I use are either to:
- convert the backend to numpy before calling measure, or
- modify tensor_1d.py, line 3719, in the definition of the "measure" method of the MatrixProductState class: replace pi /= pi.sum() with pi = pi / pi.sum().
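The second workaround works because it swaps an in-place division for an out-of-place one. The sketch below reproduces the error with NumPy alone. It assumes (not confirmed by the traceback) that the pi array inside measure ends up as a read-only NumPy array when the tensors are jax arrays; here a plain array marked read-only with setflags stands in for it.

```python
import numpy as np

# Stand-in for the outcome probabilities inside measure (assumption: with
# the jax backend this ends up as a read-only NumPy array).
pi = np.array([0.2, 0.6, 0.2, 1.0])
pi.setflags(write=False)

in_place_error = None
try:
    # In-place division tries to write into the read-only buffer.
    pi /= pi.sum()
except ValueError as err:
    in_place_error = str(err)
    print("in-place division failed:", err)

# Out-of-place division allocates a new, writable array instead.
pi = pi / pi.sum()
print(pi)
```

The out-of-place form costs one extra allocation of a small vector, but it never writes into the original buffer. That makes it safe for read-only and immutable array types alike.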
What did you expect to happen?
No response
Minimal Complete Verifiable Example
import jax
import quimb.tensor as qtn
psi = qtn.MPS_rand_computational_state(4)
print("psi backend:", psi.backend)
psi.measure_(0)
print("numpy measure ok")
psi = qtn.MPS_rand_computational_state(4)
def to_backend(x):
    return jax.numpy.asarray(x)
psi.apply_to_arrays(to_backend)
print("psi backend:", psi.backend)
psi.measure_(0)
print("jax measure ok")
Relevant log output
psi backend: numpy
numpy measure ok
psi backend: jax
Traceback (most recent call last):
File "PATH/bug_jax_minimal.py", line 21, in <module>
psi.measure_(0)
File "PATH/lib/python3.12/site-packages/quimb/tensor/tensor_1d.py", line 379, in wrapped
return fn(self, *args, info=info, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "PATH/lib/python3.12/site-packages/quimb/tensor/tensor_1d.py", line 3719, in measure
pi /= pi.sum()
ValueError: output array is read-only
Anything else we need to know?
No response
Environment
Ubuntu 24.04.2, Python 3.12.3, quimb 1.11.2, jax 0.8.1
Thanks for the issue @thibxlv, this should be fixed by https://github.com/jcmgray/quimb/commit/fec27eb3d2801bd8268762e31f1f9c95d871cfe0. You should now also be able to supply backend_random="jax" to do the sampling itself in jax, and so, e.g., jit the whole process.