openpi
Out-of-memory error on RTX A6000
Hi, I hit this error when I test Pi0 and Pi0-FAST training on a dataset I collected on a real robot. The command I run is: XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py
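For context, here is a minimal sketch of setting the same memory variable from inside a Python launcher instead of on the command line. `XLA_PYTHON_CLIENT_MEM_FRACTION` and `XLA_PYTHON_CLIENT_PREALLOCATE` are the variables described in the JAX GPU memory allocation docs; the helper `xla_memory_env` is a hypothetical name used only for illustration and is not part of openpi.

```python
import os


def xla_memory_env(mem_fraction=None, preallocate=True):
    """Build XLA client memory variables (hypothetical helper, not part of openpi)."""
    env = {}
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of total GPU memory that JAX preallocates for its pool.
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)
    if not preallocate:
        # Allocate on demand instead of reserving a pool up front.
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    return env


# JAX reads these when it initializes the GPU backend, so the safe choice is
# to set them before `import jax` runs anywhere in the process.
os.environ.update(xla_memory_env(mem_fraction=0.9))
```

This only shows where the setting has to happen. I have not confirmed that any particular value avoids the error.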
2025-07-18 11:07:06.132422: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 510.68MiB (rounded to 535486464)requested by op
2025-07-18 11:07:06.132630: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ****************************************************************************************************
E0718 11:07:06.132657 193878 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 535486464 bytes. [tf-allocator-allocation-error='']
Traceback (most recent call last):
File "/home/agx/price/openpi/scripts/train_test.py", line 36, in <module>
test_train("/home/agx/jemodel/test/", "pi0_agileX")
File "/home/agx/price/openpi/scripts/train_test.py", line 29, in test_train
train.main(config)
File "/home/agx/price/openpi/scripts/train.py", line 244, in main
train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/agx/price/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 559, in wrapped_fn
return wrapped_fn_impl(args, kwargs, bound, memos)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/agx/price/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 483, in wrapped_fn_impl
out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/agx/price/openpi/scripts/train.py", line 131, in init_train_state
train_state = jax.jit(
^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 535486464 bytes.
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I use the TrainConfig below. I tried pi0.Pi0Config(), pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"), and pi0_fast.Pi0FASTConfig(), but all of them hit the same error.
TrainConfig(
    name="pi0_agileX",
    # model=pi0.Pi0Config(),
    model=pi0_fast.Pi0FASTConfig(
        action_dim=7,  # 6 joints + 1 gripper actions
        action_horizon=25,
        max_token_len=128,
    ),
    # model=pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
    data=LeRobotAgileXDataConfig(
        assets=AssetsConfig(assets_dir="/home/agx/jedata/test_0711a"),
        default_prompt="pick up the cicular chip and place it on the yellow pot",
    ),
    policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
    wandb_enabled=False,
    num_train_steps=100_000,
    batch_size=4,
    log_interval=100,
    save_interval=5000,
    keep_period=20_000,
    num_workers=4,
    fsdp_devices=1,
),
The GPU usage before the error is:
| 0 NVIDIA RTX A6000 Off | 00000000:17:00.0 On | Off |
| 30% 39C P2 49W / 300W | 38077MiB / 49140MiB | 4% Default |
| | | N/A |
+-------------
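To put the numbers in one place, here is a small sketch of the arithmetic using only values from the log and the nvidia-smi output above. The variable names are mine. I cannot tell from nvidia-smi alone whether the 38077 MiB is held by the training process itself or by something else on the GPU.

```python
MIB = 1024 * 1024

total_mib = 49140               # total GPU memory reported by nvidia-smi
used_mib = 38077                # memory in use just before the error
mem_fraction = 0.9              # XLA_PYTHON_CLIENT_MEM_FRACTION from my command
failed_alloc_bytes = 535486464  # size of the allocation that failed

# Size of the pool JAX is asked to reserve with this fraction.
pool_mib = mem_fraction * total_mib
# Memory nvidia-smi shows as not in use at that moment.
free_mib = total_mib - used_mib
# The failed request converted to MiB, to compare with the allocator warning.
failed_alloc_mib = failed_alloc_bytes / MIB

print(f"requested pool:    {pool_mib:.0f} MiB")
print(f"free per smi:      {free_mib} MiB")
print(f"failed allocation: {failed_alloc_mib:.2f} MiB")
```

That gives a requested pool of about 44226 MiB, 11063 MiB shown as free, and a failed allocation of 510.68 MiB, which matches the allocator warning.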
I have no idea how to solve this. Any advice would be appreciated!