openpi
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
When I use multi-GPU training, the following error occurs:
Traceback (most recent call last):
  File "/liury/src/openpi/scripts/train.py", line 304, in
    main(_config.cli())
  File "/liury/src/openpi/scripts/train.py", line 285, in main
    train_state, info = ptrain_step(train_rng, train_state, batch)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6523): 'status': while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
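The last line of the error suggests rerunning with JAX_TRACEBACK_FILTERING=off to see the internal frames. A minimal sketch of preparing such an environment from Python is below. The helper name `debug_env` is hypothetical and not part of openpi or JAX. It only builds the environment mapping, so how the training script is launched is left as an assumption.

```python
import os


def debug_env(base=None):
    """Return a copy of an environment mapping with JAX traceback filtering off.

    `debug_env` is a hypothetical helper for illustration only. The variable
    name JAX_TRACEBACK_FILTERING and the value "off" come from the error
    message itself.
    """
    # Copy so the caller's mapping (or os.environ) is never mutated.
    env = dict(os.environ if base is None else base)
    env["JAX_TRACEBACK_FILTERING"] = "off"
    return env
```

The returned mapping could be passed as `env=` to `subprocess.run` when launching `scripts/train.py` with its usual arguments, so that the variable is already set when the training process starts.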