openpi
All buffers must have the same element type and count: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well)
My command on a two-GPU machine:
XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 \
uv run python scripts/train.py pi0_fast_base-genesis_ee_position_2000_20fps \
--exp-name=exp_lora --overwrite --fsdp_devices 2
Error message:
E1111 20:20:54.172290 207019 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: INVALID_ARGUMENT: All buffers must have the same element type and count
E1111 20:20:54.172289 207016 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: INVALID_ARGUMENT: All buffers must have the same element type and count
0%| | 0/100000 [00:06<?, ?it/s]
Traceback (most recent call last):
File "/home/yanan/robotics/openpi/scripts/train.py", line 286, in <module>
main(_config.cli())
File "/home/yanan/robotics/openpi/scripts/train.py", line 264, in main
train_state, info = ptrain_step(train_rng, train_state, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: All buffers must have the same element type and count: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
It stays stuck at this point forever, while the GPU usage is
Any solution?
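The traceback note above says JAX hides its internal frames unless JAX_TRACEBACK_FILTERING=off is set, so rerunning with it may show where the mismatched buffers come from. Below is a minimal sketch that rebuilds the same command and environment with that variable added. It assumes you launch from the openpi repository root, and `build_train_invocation` is a hypothetical helper, not part of openpi.

```python
import os


def build_train_invocation(
    config: str, exp_name: str, fsdp_devices: int, mem_fraction: float = 0.95
) -> tuple[list[str], dict[str, str]]:
    """Hypothetical helper: return (argv, env) for re-running scripts/train.py.

    Mirrors the command from this report and additionally disables JAX's
    traceback filtering so internal frames appear in the error output.
    """
    env = dict(os.environ)
    # Same memory setting as the original command.
    env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)
    # Keep JAX's internal frames, as suggested by the error message.
    env["JAX_TRACEBACK_FILTERING"] = "off"
    argv = [
        "uv", "run", "python", "scripts/train.py", config,
        f"--exp-name={exp_name}", "--overwrite",
        "--fsdp_devices", str(fsdp_devices),
    ]
    return argv, env
```

Passing the result to `subprocess.run(argv, env=env, check=True)` would launch the run. This is equivalent to prefixing the original shell command with `JAX_TRACEBACK_FILTERING=off`.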