
ValueError when measuring MPS with jax backend

Open • thibxlv opened this issue 1 month ago • 1 comment

What happened?

When measuring an MPS with the jax backend, "ValueError: output array is read-only" is raised.

My current workaround is either to:

  • convert the backend to numpy before measuring, or
  • modify tensor_1d.py at line 3719, in the definition of the measure method of the MatrixProductState class, replacing pi /= pi.sum() with pi = pi / pi.sum().
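Why the second workaround helps can be sketched with plain numpy. Assuming (this is not confirmed in the thread) that pi reaches line 3719 as a numpy array backed by a read-only buffer, which numpy can produce when it views a jax array, the in-place pi /= pi.sum() tries to write into that buffer, whereas pi = pi / pi.sum() allocates a fresh array. The helpers normalize_inplace and normalize_out_of_place are hypothetical, and the read-only buffer is simulated with setflags(write=False) instead of an actual jax array.

```python
import numpy as np


def normalize_inplace(pi):
    # Writes the result back into pi's own buffer (fails if it is read-only)
    pi /= pi.sum()
    return pi


def normalize_out_of_place(pi):
    # Allocates a new array; the input buffer is never written to
    return pi / pi.sum()


# Simulate a read-only buffer, standing in for a numpy view of a jax array
pi = np.array([1.0, 3.0])
pi.setflags(write=False)

try:
    normalize_inplace(pi)
    inplace_ok = True
except ValueError as err:
    inplace_ok = False
    print("in-place failed:", err)

normalized = normalize_out_of_place(pi)
print("out-of-place result:", normalized)
```

The out-of-place form costs one extra small allocation, which is negligible for a single-site probability vector.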

What did you expect to happen?

No response

Minimal Complete Verifiable Example

import jax
import quimb.tensor as qtn

psi = qtn.MPS_rand_computational_state(4)
print("psi backend:", psi.backend)  # default numpy backend

psi.measure_(0)  # works with numpy arrays
print("numpy measure ok")

psi = qtn.MPS_rand_computational_state(4)

def to_backend(x):
    return jax.numpy.asarray(x)  # convert one tensor array to jax

psi.apply_to_arrays(to_backend)  # modifies every tensor's array in place
print("psi backend:", psi.backend)

psi.measure_(0)  # raises ValueError: output array is read-only
print("jax measure ok")

Relevant log output

psi backend: numpy
numpy measure ok
psi backend: jax
Traceback (most recent call last):
  File "PATH/bug_jax_minimal.py", line 21, in <module>
    psi.measure_(0)
  File "PATH/lib/python3.12/site-packages/quimb/tensor/tensor_1d.py", line 379, in wrapped
    return fn(self, *args, info=info, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "PATH/lib/python3.12/site-packages/quimb/tensor/tensor_1d.py", line 3719, in measure
    pi /= pi.sum()
ValueError: output array is read-only

Anything else we need to know?

No response

Environment

Ubuntu 24.04.2, Python 3.12.3, quimb 1.11.2, jax 0.8.1

thibxlv, Dec 01 '25 11:12

Thanks for the issue @thibxlv; this should be fixed by https://github.com/jcmgray/quimb/commit/fec27eb3d2801bd8268762e31f1f9c95d871cfe0. You should now also be able to supply backend_random="jax" to do the sampling itself in jax, and so, e.g., jit the whole process.

jcmgray, Dec 02 '25 01:12
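The point about sampling in jax so that the whole process can be jitted can be illustrated outside quimb. The following is a minimal sketch, assuming only that jax is installed; it does not use quimb or backend_random and is not quimb's implementation. The helper sample_outcome is hypothetical: it normalizes a probability vector out of place and draws an outcome with jax.random.categorical, so both steps stay traceable by jax.jit.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sample_outcome(key, weights):
    """Hypothetical helper: normalize weights and draw one outcome index."""
    # Out-of-place normalization: builds a new array, never mutates `weights`
    pi = weights / weights.sum()
    # categorical samples index i with probability softmax(logits)[i],
    # so passing log(pi) samples i with probability pi[i]
    outcome = jax.random.categorical(key, jnp.log(pi))
    return outcome, pi


key = jax.random.PRNGKey(0)
outcome, pi = sample_outcome(key, jnp.array([1.0, 3.0]))
print("outcome:", int(outcome), "probabilities:", pi)
```

Passing the random key explicitly, rather than relying on a global numpy RNG, is what lets jax trace and compile the sampling step together with the normalization.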