tensorcircuit

Vmap isn't working when using JAX-based MPSCircuits

Open • Muzhou-Ma opened this issue 1 year ago • 3 comments

Issue Description

While testing #213, I found that vmap isn't working when using a JAX-based MPSCircuit: the program is not parallelized and uses only one CPU core. The bug seems to be caused by MPSCircuit; with a JAX-based ordinary Circuit, everything works fine.
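To narrow down where batching breaks, it can help to vmap the core MPS step on its own. The following is a toy pure-JAX sketch, not tensorcircuit's MPSCircuit code. It applies a parametrized two-qubit gate (a hypothetical ZZ rotation, `zz_rotation`) to |+>|+>, splits the result into two site tensors with an SVD, and batches the whole thing over angles with `jax.jit(jax.vmap(...))`. `CHI` is a hypothetical fixed bond dimension. JAX needs static shapes under vmap and jit, so the sketch truncates to a fixed size rather than by a singular-value threshold. The thread does not say how MPSCircuit truncates, so this illustrates the constraint and is not a diagnosis.

```python
import jax
import jax.numpy as jnp

CHI = 2  # hypothetical fixed bond dimension; a fixed value keeps every shape static


def zz_rotation(theta):
    # Hypothetical parametrized gate exp(-i * theta / 2 * Z⊗Z), diagonal in the computational basis.
    zz = jnp.array([1.0, -1.0, -1.0, 1.0])
    return jnp.diag(jnp.exp(-0.5j * theta * zz))


def apply_and_split(theta):
    # Two-qubit product state |+>|+> as a length-4 vector.
    plus = jnp.ones(2) / jnp.sqrt(2.0)
    psi = jnp.kron(plus, plus).astype(jnp.complex64)
    psi = zz_rotation(theta) @ psi
    # View the state as a (left qubit) x (right qubit) matrix and split it with an SVD.
    u, s, vh = jnp.linalg.svd(psi.reshape(2, 2), full_matrices=False)
    # Keep exactly CHI singular values so output shapes never depend on the data.
    # Here CHI equals the full rank, so nothing is discarded; in a longer chain
    # the same slice would perform the truncation.
    return u[:, :CHI], s[:CHI], vh[:CHI, :]


thetas = jnp.linspace(0.0, jnp.pi, 5)
u, s, vh = jax.jit(jax.vmap(apply_and_split))(thetas)
print(u.shape, s.shape, vh.shape)  # (5, 2, 2) (5, 2) (5, 2, 2)
```

For this state the two singular values are |cos(θ/2)| and |sin(θ/2)| (sorted in descending order), which gives a quick sanity check on the batched output.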

Muzhou-Ma • May 24 '24 15:05

Can be reproduced. It may be due to the same issue with QR and SVD; these operations might not support vmap.

Update: nope, JAX can vmap QR and SVD. The cause of the vmap failure in MPSCircuit requires further investigation.
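The claim that JAX can vmap QR and SVD is easy to check in isolation. This is a pure-JAX sketch, independent of tensorcircuit; the batch size and matrix shapes are arbitrary choices. It maps `jnp.linalg.qr` and `jnp.linalg.svd` over a batch of complex matrices and reconstructs each matrix from its factors.

```python
import jax
import jax.numpy as jnp

# A batch of 8 random 6x4 complex matrices, batch axis first.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
batch = jax.random.normal(k1, (8, 6, 4)) + 1j * jax.random.normal(k2, (8, 6, 4))

# vmap the reduced QR and the thin SVD over the leading batch axis.
qr_batched = jax.vmap(lambda m: jnp.linalg.qr(m, mode="reduced"))
svd_batched = jax.vmap(lambda m: jnp.linalg.svd(m, full_matrices=False))

q, r = qr_batched(batch)
u, s, vh = svd_batched(batch)

# Reconstruct every matrix from its factors to check the batched results.
qr_recon = q @ r
svd_recon = u @ (s[..., :, None] * vh)  # scale the rows of vh by the singular values

print(q.shape, r.shape)  # (8, 6, 4) (8, 4, 4)
print(u.shape, s.shape, vh.shape)  # (8, 6, 4) (8, 4) (8, 4, 4)
print(jnp.allclose(qr_recon, batch, atol=1e-4), jnp.allclose(svd_recon, batch, atol=1e-4))
```

Both functions already accept stacked matrices directly; `jax.vmap` is used here only to mirror the question of whether these decompositions can be vmapped.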

refraction-ray • May 25 '24 04:05

With the tf backend, vmap works, but CPU utilization is very low: only around 150% for my test example.

refraction-ray • May 25 '24 05:05

> tf backend vmap is ok but with very low CPU utilization, only around 150% for my test example

The tf backend seems to emit a warning about QR decomposition; perhaps this causes the low CPU utilization.

Muzhou-Ma • May 25 '24 05:05