
Error occurred while training on two graphics cards

Open sunmoon2018 opened this issue 1 month ago • 0 comments

I was training on two graphics cards (RTX 3090, 24 GB; RTX 3070, 8 GB) and encountered the following error:

....
Loading dataset shards: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 1692.09it/s]
16:21:36.928 [I] local_batch_size: 2 (21808:data_loader.py:324)
Traceback (most recent call last):
  File "/home/wm/code_of_Ken/openpi_all/openpi/scripts/train.py", line 280, in <module>
    main(_config.cli())
  File "/home/wm/code_of_Ken/openpi_all/openpi/scripts/train.py", line 226, in main
    batch = next(data_iter)
            ^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/src/openpi/training/data_loader.py", line 540, in __iter__
    yield _model.Observation.from_dict(batch), batch["actions"]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/src/openpi/models/model.py", line 118, in from_dict
    data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 122, in _astype
    return lax_numpy.astype(self, dtype, copy=copy, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5636, in astype
    result = lax_internal._convert_element_type(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1614, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 502, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 520, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4701, in _convert_element_type_bind_with_trace
    operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 525, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 1029, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py", line 88, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: executable is built for device CUDA:0 of type "NVIDIA GeForce RTX 3090"; cannot run it on device CUDA:1 of type "NVIDIA GeForce RTX 3070"

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
wandb:
wandb: 🚀 View run my_pi0_aloha_fold_t_shirt at: https://wandb.ai/knighthy/openpi/runs/w7sm0nhj
wandb: Find logs at: wandb/run-20251028_162134-w7sm0nhj/logs
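The final error line says the executable was built for the RTX 3090 and cannot run on the RTX 3070, which suggests the two cards are being treated as one pool even though they are different device types. A minimal sketch for confirming this from Python is below. The helper names `group_by_kind`, `is_homogeneous`, and `main` are hypothetical and not part of openpi; the only JAX features assumed are `jax.devices()` and the `id` and `device_kind` attributes of each device. Restricting the run to one card with `CUDA_VISIBLE_DEVICES` is a possible workaround I have not verified against this training script.

```python
import os
from collections import defaultdict


def group_by_kind(devices):
    """Group device ids by the `device_kind` string each device reports."""
    groups = defaultdict(list)
    for dev in devices:
        groups[dev.device_kind].append(dev.id)
    return dict(groups)


def is_homogeneous(devices):
    """Return True when all devices report the same kind (or there are none)."""
    return len(group_by_kind(devices)) <= 1


def main():
    # Possible workaround (assumption, untested here): expose only one GPU.
    # This has to be set before JAX initializes its backend, e.g.
    #   os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))

    import jax  # imported here so the helpers above work without JAX installed

    devices = jax.devices()
    print(group_by_kind(devices))
    if not is_homogeneous(devices):
        print("Mixed device kinds detected; this may be what triggers the error.")
```

Calling `main()` on the machine above would be expected to list the two cards under separate kinds, but I have not included its output since I have only checked the helpers with stand-in objects.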

Have you encountered this kind of problem before? Looking forward to your reply!

sunmoon2018, Oct 28 '25 08:10