
Out-of-memory error when training on an RTX A6000

Open · dragonbobo-no3 opened this issue 4 months ago · 3 comments

Hi, I run into this error when testing Pi0 and Pi0-FAST training on a dataset I collected in the real world. The command I run is: XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py
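For reference, the JAX documentation describes XLA_PYTHON_CLIENT_MEM_FRACTION as the share of total GPU memory that JAX preallocates when preallocation is enabled (default 75%). It does not make more memory available than the card has. The sketch below uses a hypothetical helper, jax_memory_env, which is not part of openpi. It builds a child-process environment containing that variable and, optionally, XLA_PYTHON_CLIENT_PREALLOCATE=false, so both are set before jax is imported.

```python
import os
import subprocess
from typing import Optional


def jax_memory_env(mem_fraction: Optional[float] = 0.9, preallocate: bool = True) -> dict:
    """Return a copy of the current environment with JAX GPU-memory settings.

    JAX reads these variables when its GPU backend is initialized, so they must
    already be in the environment before `import jax` runs in the child process.
    """
    env = dict(os.environ)
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError(f"mem_fraction must be in (0, 1], got {mem_fraction}")
        # Share of total GPU memory JAX preallocates (JAX default: 0.75).
        env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:g}"
    if not preallocate:
        # Allocate on demand instead of reserving one large block up front.
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    return env


def launch_training(cmd: list, mem_fraction: Optional[float] = 0.9) -> int:
    """Run a training command (e.g. ["uv", "run", "scripts/train.py"]) with these settings."""
    return subprocess.run(cmd, env=jax_memory_env(mem_fraction), check=False).returncode
```

Raising the fraction toward 1.0 only helps if the extra memory is actually free on the card. Other processes on the same GPU still count against the total.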

2025-07-18 11:07:06.132422: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 510.68MiB (rounded to 535486464)requested by op
2025-07-18 11:07:06.132630: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ****************************************************************************************************
E0718 11:07:06.132657  193878 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 535486464 bytes.
[tf-allocator-allocation-error='']
Traceback (most recent call last):
  File "/home/agx/price/openpi/scripts/train_test.py", line 36, in <module>
    test_train("/home/agx/jemodel/test/", "pi0_agileX")
  File "/home/agx/price/openpi/scripts/train_test.py", line 29, in test_train
    train.main(config)
  File "/home/agx/price/openpi/scripts/train.py", line 244, in main
    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/agx/price/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 559, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/agx/price/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 483, in wrapped_fn_impl
    out = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/agx/price/openpi/scripts/train.py", line 131, in init_train_state
    train_state = jax.jit(
                  ^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 535486464 bytes.

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
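The two sizes in the log refer to the same request. The allocator's 510.68MiB is the 535486464-byte figure expressed in mebibytes (2^20 bytes). Here is a quick check:

```python
MIB = 1024 ** 2  # one mebibyte, in bytes


def bytes_to_mib(n_bytes: int) -> float:
    """Convert a byte count to mebibytes (MiB), the unit used in the allocator warning."""
    return n_bytes / MIB


failed_request = 535_486_464  # bytes, from the RESOURCE_EXHAUSTED message
print(f"{bytes_to_mib(failed_request):.2f} MiB")  # 510.68 MiB
```

The failing request is only about half a gigabyte. This suggests that nearly all of the memory the allocator may use was already occupied (or too fragmented) when init_train_state ran. A single unusually large tensor is a less likely cause.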

My TrainConfig is below. I tried pi0.Pi0Config(), pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"), and pi0_fast.Pi0FASTConfig(), and all three fail with the same error.

TrainConfig(
    name="pi0_agileX",
    # model=pi0.Pi0Config(),
    model=pi0_fast.Pi0FASTConfig(
        action_dim=7,  # 6 joints + 1 gripper
        action_horizon=25,
        max_token_len=128,
    ),
    # model=pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
    data=LeRobotAgileXDataConfig(
        assets=AssetsConfig(assets_dir="/home/agx/jedata/test_0711a"),
        default_prompt="pick up the circular chip and place it on the yellow pot",
    ),
    policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]},
    wandb_enabled=False,
    num_train_steps=100_000,
    batch_size=4,
    log_interval=100,
    save_interval=5000,
    keep_period=20_000,
    num_workers=4,
    fsdp_devices=1,
),
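One difference is worth checking against the repository. The low-memory LoRA example configs bundled with openpi (for instance pi0_libero_low_mem_finetune) appear to do more than switch the model variants. They also pass a freeze_filter built from the same model config and set ema_decay=None. With that setup, the frozen base language-model weights get no gradients or optimizer state, and no EMA copy of the parameters is kept. The Pi0-FAST counterpart appears to use pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora") in the same way.

The fragment below is a sketch, not a confirmed fix for this issue. It assumes your openpi checkout provides Pi0Config.get_freeze_filter() and the freeze_filter and ema_decay fields on TrainConfig; verify these against the training config module before relying on it.

```python
# Sketch only: assumes TrainConfig accepts freeze_filter and ema_decay, and that
# Pi0Config exposes get_freeze_filter(), as openpi's low-memory LoRA examples appear to.
TrainConfig(
    name="pi0_agileX",
    model=pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
    # Freeze the non-LoRA weights of the Gemma backbone and action expert, so no
    # gradients or optimizer state are kept for them (only LoRA and other unfrozen params train).
    freeze_filter=pi0.Pi0Config(
        paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
    ).get_freeze_filter(),
    # Do not keep an exponential-moving-average copy of the parameters.
    ema_decay=None,
    batch_size=4,
    fsdp_devices=1,
    # data=..., policy_metadata=..., and the remaining fields as in the original config.
),
```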

The GPU usage before the error is:

| 0  NVIDIA RTX A6000          Off |   00000000:17:00.0  On |                  Off |
| 30%   39C    P2    49W / 300W |  38077MiB / 49140MiB |      4%      Default |
|                                |                      |                  N/A |
+-------------
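To put the nvidia-smi numbers in perspective, the sketch below compares the memory reported on the card with a back-of-envelope estimate of the training state. The following figures are assumptions for illustration only; they are not from this issue and were not measured from openpi:

- a parameter count of about 3.3B for Pi0 (a ~3B PaliGemma backbone plus a ~300M action expert)
- 20 bytes per parameter for full fine-tuning: fp32 weights, fp32 gradients, two fp32 AdamW moments, and one fp32 EMA copy

Activations and XLA scratch buffers come on top of this estimate.

```python
from dataclasses import dataclass

GIB = 1024 ** 3  # one gibibyte, in bytes


@dataclass
class GpuSnapshot:
    """Memory figures as reported by nvidia-smi, in MiB."""
    used_mib: int
    total_mib: int

    @property
    def free_mib(self) -> int:
        return self.total_mib - self.used_mib

    @property
    def used_fraction(self) -> float:
        return self.used_mib / self.total_mib


def training_state_gib(n_params: float, bytes_per_param: int) -> float:
    """Rough size of parameters + gradients + optimizer state, in GiB.

    Ignores activations, XLA scratch buffers, and allocator fragmentation.
    """
    return n_params * bytes_per_param / GIB


# Numbers from the nvidia-smi output in the issue.
snap = GpuSnapshot(used_mib=38077, total_mib=49140)
print(f"free: {snap.free_mib} MiB ({snap.used_fraction:.1%} used)")  # free: 11063 MiB (77.5% used)

# ASSUMPTIONS (illustrative only): ~3.3e9 parameters; full fine-tuning in fp32 with
# weights (4 B) + grads (4 B) + two AdamW moments (8 B) + an EMA copy (4 B) = 20 B/param.
ASSUMED_PARAMS = 3.3e9
full_ft = training_state_gib(ASSUMED_PARAMS, bytes_per_param=20)
card = snap.total_mib / 1024
print(f"full fine-tuning state: ~{full_ft:.0f} GiB vs. {card:.0f} GiB on the card")  # ~61 GiB vs. 48 GiB
```

Under these assumptions, the parameter and optimizer state alone exceeds the 48 GiB card. That would be consistent with init_train_state failing before the first training step. The estimate is only as good as the assumed parameter count and dtypes.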

I'm not sure how to solve this. Any advice would be appreciated!

dragonbobo-no3, Jul 18 '25 03:07