CUDA out of memory when using PyTorch to finetune on an RTX 4090
I am using PyTorch to finetune pi0_base.
The config I set is shown below:
TrainConfig(
    name="test_pi0",
    model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
    data=LeRobotR1ProDataConfig(
        repo_id="test/apple",
        base_config=DataConfig(prompt_from_task=True),
        extra_delta_transform=True,
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=20_000,
    freeze_filter=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora").get_freeze_filter(),
    ema_decay=None,
    num_workers=1,
    batch_size=1,
    fsdp_devices=1,  # for a single RTX 4090
    pytorch_weight_path="/home/test/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch",
),
But the log shows a CUDA out of memory error:
rist_image': 'observation.images.right_wrist_rgb', 'observation/state': 'state', 'actions': 'actions', 'prompt': 'prompt'})], outputs=()), data_transforms=Group(inputs=(R1ProInputs(model_type=<ModelType.PI0: 'pi0'>), DeltaActions(mask=(True, True, True, True, True, True, True, False))), outputs=(AbsoluteActions(mask=(True, True, True, True, True, True, True, False)), R1ProOutputs())), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x7b9f71ccf1d0>, discrete_state_input=False), PadStatesAndActions(model_action_dim=32)], outputs=()), use_quantile_norm=False, action_sequence_keys=('actions',), prompt_from_task=True, rlds_data_dir=None, behavior_dataset_root=None, action_space=None, filter_dict_path=None, episodes_index=None) (55077:data_loader.py:268)
Resolving data files: 100%|████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 940927.51it/s]
09:49:01.746 [I] local_batch_size: 1 (55077:data_loader.py:374)
INFO:2025-09-19 09:49:01,817:jax._src.xla_bridge:925: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
09:49:01.817 [I] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig' (55077:xla_bridge.py:925)
INFO:2025-09-19 09:49:01,818:jax._src.xla_bridge:925: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
09:49:01.818 [I] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory (55077:xla_bridge.py:925)
09:49:29.395 [I] Enabled gradient checkpointing for PI0Pytorch model (55077:pi0_pytorch.py:133)
09:49:29.395 [I] Enabled gradient checkpointing for memory optimization (55077:train_pytorch.py:414)
09:49:29.396 [I] Step 0 (after_model_creation): GPU memory - allocated: 7.02GB, reserved: 7.09GB, free: 0.07GB, peak_allocated: 7.02GB, peak_reserved: 7.09GB (55077:train_pytorch.py:304)
09:49:29.396 [I] Loading weights from: /home/wzj/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch (55077:train_pytorch.py:443)
09:49:30.258 [I] Loaded PyTorch weights from /home/wzj/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch (55077:train_pytorch.py:449)
09:49:30.261 [I] Running on: wzj | world_size=1 (55077:train_pytorch.py:486)
09:49:30.261 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=20000 (55077:train_pytorch.py:489)
09:49:30.261 [I] Memory optimizations: gradient_checkpointing=True (55077:train_pytorch.py:492)
09:49:30.261 [I] LR schedule: warmup=1000, peak_lr=2.50e-05, decay_steps=30000, end_lr=2.50e-06 (55077:train_pytorch.py:493)
09:49:30.261 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (55077:train_pytorch.py:496)
09:49:30.261 [I] EMA is not supported for PyTorch training (55077:train_pytorch.py:499)
09:49:30.261 [I] Training precision: bfloat16 (55077:train_pytorch.py:500)
Training: 0%| | 0/20000 [00:00<?, ?it/s]09:50:05.882 [I] Step 0 (after_backward): GPU memory - allocated: 13.34GB, reserved: 13.55GB, free: 0.21GB, peak_allocated: 13.63GB, peak_reserved: 13.81GB (55077:train_pytorch.py:304)
Traceback (most recent call last):
File "/home/wzj/vla/pi_b1k/openpi_private/scripts/train_pytorch.py", line 632, in
It was OK when I used JAX to train. Why does the PyTorch version run out of memory?
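In case it helps with debugging, here is a minimal sketch of the memory checks I can run around a training step. This is plain PyTorch (the `PYTORCH_CUDA_ALLOC_CONF` allocator setting and `torch.cuda.*` memory queries), not openpi-specific code, and `train_step` below is just a placeholder:

import os

# Assumption: PyTorch >= 2.0. Let the caching allocator grow segments instead of
# reserving fixed-size blocks, which can reduce fragmentation-driven OOMs.
# Must be set before CUDA is initialized, hence before importing torch.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch


def report_gpu_memory(tag: str) -> None:
    """Print allocated/reserved/peak CUDA memory in GB for the current device."""
    gib = 1024 ** 3
    print(
        f"[{tag}] allocated={torch.cuda.memory_allocated() / gib:.2f}GB "
        f"reserved={torch.cuda.memory_reserved() / gib:.2f}GB "
        f"peak_allocated={torch.cuda.max_memory_allocated() / gib:.2f}GB"
    )


# Hypothetical usage around one training step (train_step is a placeholder):
# report_gpu_memory("before_step")
# loss = train_step(batch)
# report_gpu_memory("after_backward")
# print(torch.cuda.memory_summary())  # full allocator breakdown, per block size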