openpi
Error occurred while training on two graphics cards
I was training on two graphics cards (RTX 3090, 24 GB; RTX 3070, 8 GB) and encountered the following error:
....
Loading dataset shards: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 1692.09it/s]
16:21:36.928 [I] local_batch_size: 2 (21808:data_loader.py:324)
Traceback (most recent call last):
  File "/home/wm/code_of_Ken/openpi_all/openpi/scripts/train.py", line 280, in <module>
    main(_config.cli())
  File "/home/wm/code_of_Ken/openpi_all/openpi/scripts/train.py", line 226, in main
    batch = next(data_iter)
            ^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/src/openpi/training/data_loader.py", line 540, in __iter__
    yield _model.Observation.from_dict(batch), batch["actions"]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/src/openpi/models/model.py", line 118, in from_dict
    data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 122, in _astype
    return lax_numpy.astype(self, dtype, copy=copy, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5636, in astype
    result = lax_internal._convert_element_type(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1614, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 502, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 520, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4701, in _convert_element_type_bind_with_trace
    operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 525, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 1029, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wm/code_of_Ken/openpi_all/openpi/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py", line 88, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: executable is built for device CUDA:0 of type "NVIDIA GeForce RTX 3090"; cannot run it on device CUDA:1 of type "NVIDIA GeForce RTX 3070"
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
wandb:
wandb: 🚀 View run my_pi0_aloha_fold_t_shirt at: https://wandb.ai/knighthy/openpi/runs/w7sm0nhj
wandb: Find logs at: wandb/run-20251028_162134-w7sm0nhj/logs
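The last line of the traceback suggests that the image-normalization step in `Observation.from_dict` was compiled for the RTX 3090 (CUDA:0) and then dispatched to the RTX 3070 (CUDA:1), which the XLA runtime rejects, apparently because the executable is tied to the 3090's device type. A quick way to check whether the visible GPUs are all the same model is to compare their model names. The following is a minimal sketch; `group_devices_by_kind` and `is_homogeneous` are hypothetical helper names, not part of openpi or JAX. They work on plain strings so they can be tested without a GPU; on a real machine you would pass in `[d.device_kind for d in jax.devices()]`.

```python
from collections import defaultdict


def group_devices_by_kind(device_kinds):
    """Map each GPU model name to the device indices that report it.

    `device_kinds` is a list of model-name strings ordered by device index,
    e.g. [d.device_kind for d in jax.devices()] (hypothetical usage).
    """
    groups = defaultdict(list)
    for index, kind in enumerate(device_kinds):
        groups[kind].append(index)
    return dict(groups)


def is_homogeneous(device_kinds):
    """Return True when every visible device reports the same model name."""
    # An empty list or a single device counts as homogeneous.
    return len(set(device_kinds)) <= 1
```

With the two cards from this report, the list would presumably contain both "NVIDIA GeForce RTX 3090" and "NVIDIA GeForce RTX 3070", so `is_homogeneous` would return `False`.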
Have you encountered this kind of problem before? Looking forward to your reply!
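If mixing the two card models is the cause, one possible workaround, assuming the 24 GB RTX 3090 alone has enough memory for the configured batch size, is to hide the RTX 3070 so JAX sees only one device. The standard CUDA environment variable `CUDA_VISIBLE_DEVICES` does this, but it must be set before JAX initializes its GPU backend. Below is a minimal sketch; `restrict_to_gpus` is a hypothetical helper name, and it assumes ordinal 0 is the RTX 3090, as the error message reports.

```python
import os


def restrict_to_gpus(indices, env=None):
    """Limit which CUDA devices the process can see via CUDA_VISIBLE_DEVICES.

    Must run before JAX (or any other CUDA library) initializes its GPU
    backend, i.e. before the first `import jax` in the process.
    `indices` are CUDA ordinals, e.g. [0] for the device JAX reported as CUDA:0.
    `env` defaults to os.environ; a plain dict can be passed for testing.
    """
    env = os.environ if env is None else env
    value = ",".join(str(i) for i in indices)
    env["CUDA_VISIBLE_DEVICES"] = value
    return value


if __name__ == "__main__":
    # Keep only CUDA:0 (the RTX 3090 in the error message) visible.
    restrict_to_gpus([0])
    print("CUDA_VISIBLE_DEVICES =", os.environ["CUDA_VISIBLE_DEVICES"])
```

Prefixing the training command with `CUDA_VISIBLE_DEVICES=0` in the shell is equivalent and avoids editing `train.py`. Note that `nvidia-smi` numbers GPUs by PCI bus order, which can differ from CUDA's default ordering unless `CUDA_DEVICE_ORDER=PCI_BUS_ID` is also set.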