
CUDA out of memory when using PyTorch to fine-tune on an RTX 4090

ZijianWu1121 opened this issue 2 months ago · 10 comments

I am using PyTorch to fine-tune pi0_base.

The config I set is shown below:

TrainConfig(
    name="test_pi0",
    model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
    data=LeRobotR1ProDataConfig(
        repo_id="test/apple",
        base_config=DataConfig(prompt_from_task=True),
        extra_delta_transform=True,
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=20_000,
    freeze_filter=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora").get_freeze_filter(),
    ema_decay=None,
    num_workers=1, batch_size=1, fsdp_devices=1,  # for a single RTX 4090
    pytorch_weight_path="/home/test/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch"
),
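
To rule out the LoRA freeze filter not being applied on the PyTorch side, I count how many parameters are actually trainable. This is only a sketch; `model` stands for whatever `PI0Pytorch` instance the training script builds, which is my assumption about the internals:

import torch

def count_params(model: torch.nn.Module) -> None:
    """Print total vs. trainable parameter counts for a PyTorch model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"total: {total / 1e9:.2f}B, trainable: {trainable / 1e6:.1f}M "
          f"({100 * trainable / total:.1f}%)")

# If the freeze filter works, only the LoRA adapters should be trainable; if nearly
# everything is trainable, AdamW will allocate fp32 exp_avg / exp_avg_sq states for
# every parameter, which by itself would not fit on a 24 GB card.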

However, the log shows a CUDA out-of-memory error:

rist_image': 'observation.images.right_wrist_rgb', 'observation/state': 'state', 'actions': 'actions', 'prompt': 'prompt'})], outputs=()), data_transforms=Group(inputs=(R1ProInputs(model_type=<ModelType.PI0: 'pi0'>), DeltaActions(mask=(True, True, True, True, True, True, True, False))), outputs=(AbsoluteActions(mask=(True, True, True, True, True, True, True, False)), R1ProOutputs())), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x7b9f71ccf1d0>, discrete_state_input=False), PadStatesAndActions(model_action_dim=32)], outputs=()), use_quantile_norm=False, action_sequence_keys=('actions',), prompt_from_task=True, rlds_data_dir=None, behavior_dataset_root=None, action_space=None, filter_dict_path=None, episodes_index=None) (55077:data_loader.py:268)
Resolving data files: 100%|████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 940927.51it/s]
09:49:01.746 [I] local_batch_size: 1 (55077:data_loader.py:374)
INFO:2025-09-19 09:49:01,817:jax._src.xla_bridge:925: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
09:49:01.817 [I] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig' (55077:xla_bridge.py:925)
INFO:2025-09-19 09:49:01,818:jax._src.xla_bridge:925: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
09:49:01.818 [I] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory (55077:xla_bridge.py:925)
09:49:29.395 [I] Enabled gradient checkpointing for PI0Pytorch model (55077:pi0_pytorch.py:133)
09:49:29.395 [I] Enabled gradient checkpointing for memory optimization (55077:train_pytorch.py:414)
09:49:29.396 [I] Step 0 (after_model_creation): GPU memory - allocated: 7.02GB, reserved: 7.09GB, free: 0.07GB, peak_allocated: 7.02GB, peak_reserved: 7.09GB (55077:train_pytorch.py:304)
09:49:29.396 [I] Loading weights from: /home/wzj/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch (55077:train_pytorch.py:443)
09:49:30.258 [I] Loaded PyTorch weights from /home/wzj/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch (55077:train_pytorch.py:449)
09:49:30.261 [I] Running on: wzj | world_size=1 (55077:train_pytorch.py:486)
09:49:30.261 [I] Training config: batch_size=1, effective_batch_size=1, num_train_steps=20000 (55077:train_pytorch.py:489)
09:49:30.261 [I] Memory optimizations: gradient_checkpointing=True (55077:train_pytorch.py:492)
09:49:30.261 [I] LR schedule: warmup=1000, peak_lr=2.50e-05, decay_steps=30000, end_lr=2.50e-06 (55077:train_pytorch.py:493)
09:49:30.261 [I] Optimizer: AdamW, weight_decay=1e-10, clip_norm=1.0 (55077:train_pytorch.py:496)
09:49:30.261 [I] EMA is not supported for PyTorch training (55077:train_pytorch.py:499)
09:49:30.261 [I] Training precision: bfloat16 (55077:train_pytorch.py:500)
Training:   0%|          | 0/20000 [00:00<?, ?it/s]
09:50:05.882 [I] Step 0 (after_backward): GPU memory - allocated: 13.34GB, reserved: 13.55GB, free: 0.21GB, peak_allocated: 13.63GB, peak_reserved: 13.81GB (55077:train_pytorch.py:304)
Traceback (most recent call last):
  File "/home/wzj/vla/pi_b1k/openpi_private/scripts/train_pytorch.py", line 632, in <module>
    main()
  File "/home/wzj/vla/pi_b1k/openpi_private/scripts/train_pytorch.py", line 628, in main
    train_loop(config)
  File "/home/wzj/vla/pi_b1k/openpi_private/scripts/train_pytorch.py", line 549, in train_loop
    optim.step()
  File "/home/wzj/vla/pi_b1k/openpi_private/.venv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 485, in wrapper
    out = func(*args, **kwargs)
  File "/home/wzj/vla/pi_b1k/openpi_private/.venv/lib/python3.11/site-packages/torch/optim/optimizer.py", line 79, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/wzj/vla/pi_b1k/openpi_private/.venv/lib/python3.11/site-packages/torch/optim/adam.py", line 236, in step
    has_complex = self._init_group(
  File "/home/wzj/vla/pi_b1k/openpi_private/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/home/wzj/vla/pi_b1k/openpi_private/.venv/lib/python3.11/site-packages/torch/optim/adam.py", line 176, in _init_group
    state["exp_avg"] = torch.zeros_like(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 0 has a total capacity of 23.63 GiB of which 226.56 MiB is free. Including non-PyTorch memory, this process has 22.77 GiB memory in use. Of the allocated memory 21.98 GiB is allocated by PyTorch, and 320.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
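
The GPU-memory lines in the log come from the script's own step logging; the same kind of snapshot can be taken manually with standard PyTorch calls (the tag strings below are mine, not the script's):

import torch

def log_gpu_memory(tag: str) -> None:
    """Print allocated/reserved/peak CUDA memory in GB, mirroring the training log format."""
    gb = 1024**3
    allocated = torch.cuda.memory_allocated() / gb
    reserved = torch.cuda.memory_reserved() / gb
    peak = torch.cuda.max_memory_allocated() / gb
    print(f"{tag}: allocated={allocated:.2f}GB, reserved={reserved:.2f}GB, peak_allocated={peak:.2f}GB")

# Calling log_gpu_memory("before_optim_step") right before optim.step() would confirm
# that the 64 MiB exp_avg allocation is just the last straw on an already-full GPU.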

Training with JAX was fine on the same GPU. Why does the PyTorch version run out of memory?

ZijianWu1121 · Sep 19 '25