
When I ran train.py in OpenPI, I got the following error:

Status: Open. satokeeen opened this issue 2 months ago (2 comments).

Environment: Ubuntu 24.04, RTX 5090, CUDA 12.6.3

When I run

    XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp-name=my_experiment --overwrite

I get the following error:

    XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp-name=my_experiment --overwrite
    warning: The tool.uv.dev-dependencies field (used in packages/openpi-client/pyproject.toml) is deprecated and will be removed in a future release; use dependency-groups.dev instead
    /home/openpi/.venv/lib/python3.11/site-packages/tyro/_parsers.py:332: UserWarning: The field model.action-expert-variant is annotated with type typing.Literal['dummy', 'gemma_300m', 'gemma_2b', 'gemma_2b_lora'], but the default value gemma_300m_lora has type <class 'str'>. We'll try to handle this gracefully, but it may cause unexpected behavior.
      warnings.warn(message)
    11:53:51.620 [I] Running on: su-station-05 (20721:train.py:195)
    INFO:2025-09-20 11:53:51,899:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
    11:53:51.899 [I] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig' (20721:xla_bridge.py:945)
    INFO:2025-09-20 11:53:51,900:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
    11:53:51.900 [I] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory (20721:xla_bridge.py:945)
    2025-09-20 11:53:52.207761: W external/xla/xla/stream_executor/cuda/subprocess_compilation.cc:237] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 12.0
    2025-09-20 11:53:52.207773: W external/xla/xla/stream_executor/cuda/subprocess_compilation.cc:240] Used ptxas at /usr/local/cuda/bin/ptxas
    Traceback (most recent call last):
      File "/home/openpi/scripts/train.py", line 273, in <module>
        main(_config.cli())
      File "/home/openpi/scripts/train.py", line 204, in main
        rng = jax.random.key(config.seed)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 218, in key
        return _key('key', seed, impl)
               ^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 198, in _key
        return prng.random_seed(seed, impl=impl)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 534, in random_seed
        return random_seed_p.bind(seeds_arr, impl=impl)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 463, in bind
        return self.bind_with_trace(prev_trace, args, params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 468, in bind_with_trace
        return trace.process_primitive(self, args, params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 954, in process_primitive
        return primitive.impl(*args, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 546, in random_seed_impl
        base_arr = random_seed_impl_base(seeds, impl=impl)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 551, in random_seed_impl_base
        return seed(seeds)
               ^^^^^^^^^^^
      File "/home/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 767, in threefry_seed
        return _threefry_seed(seed)
               ^^^^^^^^^^^^^^^^^^^^
    jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: /usr/local/cuda/bin/ptxas ptxas too old. Falling back to the driver to compile.
    --------------------
    For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

I updated ptxas to the latest version, but the same error still occurs. I am running inside Docker, and I have confirmed that PATH is set correctly. I would appreciate any suggestions on how to solve this.
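The log names a specific binary (XLA reports "Used ptxas at /usr/local/cuda/bin/ptxas" and says that ptxas "does not support CC 12.0"), which is not necessarily the same ptxas that PATH resolves to. One way to check both binaries, and whether each one's CUDA release is new enough for the GPU, is a small stdlib-only script like the one below. It is a hypothetical diagnostic sketch, not part of OpenPI or JAX. The minimum-release table is an assumption based on my reading of NVIDIA's CUDA release notes (in particular, that compute capability 12.0 first became a ptxas target in CUDA 12.8), so verify it against the official notes. The script also assumes that ptxas --version prints a line containing "release X.Y" and that the installed nvidia-smi supports --query-gpu=compute_cap.

```python
"""Hypothetical diagnostic sketch: is the ptxas that XLA may use new enough for this GPU?"""
from __future__ import annotations

import os
import re
import shutil
import subprocess

# ASSUMPTION: minimum CUDA toolkit release whose ptxas can target each compute
# capability, based on NVIDIA's CUDA release notes as I understand them.
# Verify against the official release notes before relying on this table.
MIN_CUDA_FOR_CC = {
    (8, 0): (11, 0),
    (8, 6): (11, 1),
    (8, 9): (11, 8),
    (9, 0): (11, 8),
    (12, 0): (12, 8),  # compute capability reported in the log above
}


def parse_ptxas_release(version_text: str) -> tuple[int, int] | None:
    """Extract (major, minor) from `ptxas --version` output ("... release 12.6, ...")."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def parse_compute_cap(text: str) -> tuple[int, int]:
    """Turn a compute capability string such as "12.0" into (12, 0)."""
    major, minor = text.strip().split(".")
    return int(major), int(minor)


def ptxas_supports(release: tuple[int, int], cc: tuple[int, int]) -> bool | None:
    """True/False if `cc` is in the table above, None if the table does not cover it."""
    needed = MIN_CUDA_FOR_CC.get(cc)
    return None if needed is None else release >= needed


def candidate_ptxas_paths() -> list[str]:
    """ptxas on PATH plus the path named in the XLA log, skipping missing or duplicate files."""
    paths: list[str] = []
    seen: set[str] = set()
    for path in (shutil.which("ptxas"), "/usr/local/cuda/bin/ptxas"):
        if path and os.path.exists(path) and os.path.realpath(path) not in seen:
            seen.add(os.path.realpath(path))
            paths.append(path)
    return paths


def gpu_compute_caps() -> list[tuple[int, int]]:
    """Compute capabilities reported by nvidia-smi (empty list if nvidia-smi is missing)."""
    if shutil.which("nvidia-smi") is None:
        return []
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    ).stdout
    return [parse_compute_cap(line) for line in out.splitlines() if line.strip()]


def main() -> None:
    caps = gpu_compute_caps()
    print(f"GPU compute capabilities: {caps}")
    for path in candidate_ptxas_paths():
        proc = subprocess.run([path, "--version"], capture_output=True, text=True)
        release = parse_ptxas_release(proc.stdout)
        print(f"{path}: CUDA release {release}")
        for cc in caps:
            ok = ptxas_supports(release, cc) if release else None
            print(f"  supports CC {cc[0]}.{cc[1]}? {ok}")


if __name__ == "__main__":
    main()
```

Run it inside the same Docker container that runs scripts/train.py. If the /usr/local/cuda/bin/ptxas entry reports a release older than the table's minimum for CC 12.0, that would be consistent with the "ptxas too old" error above, regardless of which ptxas PATH resolves to.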

satokeeen, Sep 20 '25 12:09