Memory Exhausted Issue when Training with Pi0 Fine Tuned Lora
Hi,
I am trying to get the Pi0 fine-tuned LoRA model running on my computer so I can train it further on my own dataset. The model weights appear to load correctly, but as soon as the first training step runs I get an out-of-memory error.
My system specs are as follows:
OS: Ubuntu 20.04.6 LTS
GPU: Nvidia RTX 4090
RAM: 64GB
Custom Config for Aloha Robot:
TrainConfig(
    name="pi0_aloha_battery_place",
    # model=pi0.Pi0Config(),
    model=pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora"),
    data=LeRobotAlohaDataConfig(
        repo_id="spokkali/clothesdataset",
        assets=AssetsConfig(
            assets_dir="s3://openpi-assets/checkpoints/pi0_fast_base/assets",
            asset_id="trossen",
        ),
        default_prompt="place the battery in the open slot",
        repack_transforms=_transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "images": {
                            "cam_high": "observation.images.cam_high",
                            "cam_left_wrist": "observation.images.cam_left_wrist",
                            "cam_right_wrist": "observation.images.cam_right_wrist",
                        },
                        "state": "observation.state",
                        "actions": "action",
                    }
                )
            ]
        ),
        base_config=DataConfig(
            local_files_only=False,  # Set to True for local-only datasets.
        ),
    ),
    # weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
    num_train_steps=20,
    freeze_filter=pi0_fast.Pi0FASTConfig(
        paligemma_variant="gemma_2b_lora"
    ).get_freeze_filter(),
    ema_decay=None,
),
Error Message:
2025-03-14 17:38:23.202537: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] **********************************************______________________________________________________
E0314 17:38:23.202632 50556 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 17512255040 bytes. [tf-allocator-allocation-error='']
0%| | 0/20 [00:12<?, ?it/s]
Traceback (most recent call last):
File "/home/test/work/openpi/openpi/scripts/train.py", line 273, in <module>
main(_config.cli())
File "/home/test/work/openpi/openpi/scripts/train.py", line 254, in main
train_state, info = ptrain_step(train_rng, train_state, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 17512255040 bytes.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these
I've tried running the following command both with and without the XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 prefix: XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_aloha_battery_place --exp-name=test_pi0 --overwrite
Any help would be greatly appreciated. Thanks!
I'm struggling with the same problem. I wonder whether training on an RTX 3090 is feasible or simply impossible.
@sanjaypokkali for FAST, the recommendation is to use a shorter sequence length if you are memory constrained:
Could you try setting the model config to https://github.com/Physical-Intelligence/openpi/blob/main/src/openpi/training/config.py#L557-L561 ?
@mhko1998 an RTX 3090 should in theory work with the same low-memory finetune config (we only tested on a 4090, but the memory capacity is the same and bfloat16 is supported on the 3090). Could you try the same config and let me know if it works?
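For reference, a minimal sketch of the kind of model-config change being suggested, assuming the max_token_len field that later comments in this thread refer to (check the linked config.py lines for the actual low-memory config):

    # Sketch only, not the exact upstream config: cap the sequence length so the
    # activation/KV buffers fit in 24 GB. Verify the field name against
    # src/openpi/training/config.py before relying on it.
    model=pi0_fast.Pi0FASTConfig(
        paligemma_variant="gemma_2b_lora",
        max_token_len=180,  # shorter sequence length => less activation memory
    ),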
Hi, @haohuanw
I have the same issue with an RTX 4090. I run: XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero_low_mem_finetune --exp-name=test --overwrite
Regarding https://github.com/Physical-Intelligence/openpi/blob/main/src/openpi/training/config.py#L557-L561, do you mean to set a lower value for max_token_len?
Same question as above, @haohuanw. Also, since I am using the dual-arm Aloha setup, are you suggesting I use the same action_history and max_token_len? If I want to use both arms, my action length should be 14, not 7.
Same here; the process even terminates ungracefully with no error code. I have no idea why :(
I'm running into the same issue. Did you manage to solve it?
I tried lowering max_token_len to 180 and got this warning:
WARNING:root:Token length (208) exceeds max length (180), truncating. Consider increasing the max_token_len in your model config if this happens frequently.
WARNING:root:Token length (204) exceeds max length (180), truncating. Consider increasing the max_token_len in your model config if this happens frequently
I saw in the code comments that this means the input tokens get truncated. I experimented with the value and found that I need a token length of 240 to avoid the warning, but at 240 tokens I hit the out-of-memory error again. Is there anything else you recommend I do, @haohuanw?
Hi, @haohuanw
I have the same issue with an RTX 4090. I run: XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero_low_mem_finetune --exp-name=test --overwrite
Regarding https://github.com/Physical-Intelligence/openpi/blob/main/src/openpi/training/config.py#L557-L561, do you mean to set a lower value for max_token_len?
Yes, please try setting a lower max_token_len.
I tried lowering max_token_len to 180 and got this warning: WARNING:root:Token length (208) exceeds max length (180), truncating. Consider increasing the max_token_len in your model config if this happens frequently. WARNING:root:Token length (204) exceeds max length (180), truncating. Consider increasing the max_token_len in your model config if this happens frequently.
I saw in the code comments that this means the input tokens get truncated. I experimented with the value and found that I need a token length of 240 to avoid the warning, but at 240 tokens I hit the out-of-memory error again. Is there anything else you recommend I do, @haohuanw?
Yeah, so the options would be to increase the max token length (which means more memory will be needed) or to zero-pad the actions if they don't always exceed the max token length.
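For illustration only, zero-padding an action chunk to a fixed width can be done with plain numpy; the pad_to_dim helper below is hypothetical and is not the openpi transform API:

    import numpy as np

    def pad_to_dim(x: np.ndarray, target_dim: int) -> np.ndarray:
        # Zero-pad the last axis of x up to target_dim; no-op if already wide enough.
        pad = target_dim - x.shape[-1]
        if pad <= 0:
            return x
        return np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])

    # e.g. a (50, 7) single-arm action chunk padded to a 14-dim dual-arm layout
    actions = np.zeros((50, 7), dtype=np.float32)
    padded = pad_to_dim(actions, 14)  # -> shape (50, 14)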
The same issue occurred for me when trying to train the pi0 model on LIBERO. I reduced the batch_size to 16 and also tried multi-GPU, but neither solved the issue.
I am using RTX A5000.
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4831838208 bytes.
-----------------------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
@jaelyontway when you are using multiple GPUs, have you turned on FSDP? You can turn it on by setting the fsdp_devices value to the number of GPUs you are using.
When you train the pi0 model, which config are you trying to use?
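For concreteness, a minimal sketch of the FSDP setting being discussed, using the fsdp_devices field that later comments mention (verify the exact field against src/openpi/training/config.py):

    # Sketch only: with 4 local GPUs, shard params/optimizer state instead of replicating.
    TrainConfig(
        name="pi0_libero",
        batch_size=32,
        fsdp_devices=4,  # 1 = plain data parallelism; >1 enables FSDP across that many GPUs
    ),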
Hi @haohuanw
I've been working on fine-tuning pi0_libero using only the libero_spatial data. I've set my own repo_id, asset_dir, and enabled local_files_only inside of config.py.
I enabled fsdp in training/config.py for multi-gpu.
I tested multiple configurations. For example, with XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 and fsdp_devices: int=4, I tried batch_size values of 32, 16, and 4.
In all cases, each GPU reached ~90% memory before hitting an OOM error, consistently showing: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 150994944 bytes.
The only setup that seems to work is XLA_PYTHON_CLIENT_MEM_FRACTION=false with fsdp_devices: 1 and batch_size: 1. However, this config is extremely slow.
I also tried enabling gradient checkpointing as suggested online, but it did not help.
I also tried enabling gradient checkpointing as suggested online, but it did not help.
Gradient checkpointing should already be enabled by default: https://github.com/Physical-Intelligence/openpi/blob/main/src/openpi/models/gemma.py#L351
If possible, could you share a full training log with FSDP that I could take a look at?
If possible, could you share a full training log with FSDP that I could take a look at?
Of course! Thank you!
My command was XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_libero --exp-name=pi0_finetune_libero_spatial --overwrite and the configuration was a batch size of 32 and fsdp devices of 4.
The log was:
Resolving data files: 100%|██████████████████████████████████████████████████████████| 432/432 [00:00<00:00, 9021.98it/s]
2025-05-09 21:39:18.586269: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1746844758.604279 665494 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746844758.609459 665494 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746844758.622590 665494 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746844758.622610 665494 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746844758.622612 665494 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746844758.622614 665494 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
2025-05-09 21:39:26.685992: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1746844766.702888 665678 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746844766.707974 665678 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746844766.720731 665678 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746844766.720750 665678 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746844766.720752 665678 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746844766.720754 665678 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
21:39:30.772 [I] Initialized data loader:
[0].images['base_0_rgb']: (32, 224, 224, 3)@float32
[0].images['left_wrist_0_rgb']: (32, 224, 224, 3)@float32
[0].images['right_wrist_0_rgb']: (32, 224, 224, 3)@float32
[0].image_masks['base_0_rgb']: (32,)@bool
[0].image_masks['left_wrist_0_rgb']: (32,)@bool
[0].image_masks['right_wrist_0_rgb']: (32,)@bool
[0].state: (32, 32)@float32
[0].tokenized_prompt: (32, 48)@int32
[0].tokenized_prompt_mask: (32, 48)@bool
[1]: (32, 50, 32)@float32 (664670:train.py:227)
21:39:31.344 [I] Sharding .params['PaliGemma']['img']['Transformer']['encoderblock']['MlpBlock_0']['Dense_0']['kernel'].value of shape (27, 1152, 4304) (510.68 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.344 [I] Sharding .params['PaliGemma']['img']['Transformer']['encoderblock']['MlpBlock_0']['Dense_1']['kernel'].value of shape (27, 4304, 1152) (510.68 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.345 [I] Sharding .params['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['key']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.345 [I] Sharding .params['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['out']['kernel'].value of shape (27, 16, 72, 1152) (136.69 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.345 [I] Sharding .params['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['query']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.345 [I] Sharding .params['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['value']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.346 [I] Sharding .params['PaliGemma']['img']['head']['kernel'].value of shape (1152, 2048) (9.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.346 [I] Sharding .params['PaliGemma']['llm']['embedder']['input_embedding'].value of shape (257152, 2048) (2009.00 MiB) along axis 0 (664670:sharding.py:89)
21:39:31.346 [I] Sharding .params['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum']['w'].value of shape (18, 8, 256, 2048) (288.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.346 [I] Sharding .params['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum_1']['w'].value of shape (18, 8, 256, 1024) (144.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.346 [I] Sharding .params['PaliGemma']['llm']['layers']['attn']['kv_einsum']['w'].value of shape (18, 2, 1, 2048, 256) (72.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.347 [I] Sharding .params['PaliGemma']['llm']['layers']['attn']['kv_einsum_1']['w'].value of shape (18, 2, 1, 1024, 256) (36.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.347 [I] Sharding .params['PaliGemma']['llm']['layers']['attn']['q_einsum']['w'].value of shape (18, 8, 2048, 256) (288.00 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.347 [I] Sharding .params['PaliGemma']['llm']['layers']['attn']['q_einsum_1']['w'].value of shape (18, 8, 1024, 256) (144.00 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.347 [I] Sharding .params['PaliGemma']['llm']['layers']['mlp']['gating_einsum'].value of shape (18, 2, 2048, 16384) (4608.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.347 [I] Sharding .params['PaliGemma']['llm']['layers']['mlp']['linear'].value of shape (18, 16384, 2048) (2304.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.347 [I] Sharding .params['PaliGemma']['llm']['layers']['mlp_1']['gating_einsum'].value of shape (18, 2, 1024, 4096) (576.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.347 [I] Sharding .params['PaliGemma']['llm']['layers']['mlp_1']['linear'].value of shape (18, 4096, 1024) (288.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.348 [I] Sharding .params['action_time_mlp_in']['kernel'].value of shape (2048, 1024) (8.00 MiB) along axis 0 (664670:sharding.py:89)
21:39:31.348 [I] Sharding .params['action_time_mlp_out']['kernel'].value of shape (1024, 1024) (4.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.348 [I] Sharding .opt_state[1][0].mu['PaliGemma']['img']['Transformer']['encoderblock']['MlpBlock_0']['Dense_0']['kernel'].value of shape (27, 1152, 4304) (510.68 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.349 [I] Sharding .opt_state[1][0].mu['PaliGemma']['img']['Transformer']['encoderblock']['MlpBlock_0']['Dense_1']['kernel'].value of shape (27, 4304, 1152) (510.68 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.349 [I] Sharding .opt_state[1][0].mu['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['key']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.349 [I] Sharding .opt_state[1][0].mu['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['out']['kernel'].value of shape (27, 16, 72, 1152) (136.69 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.349 [I] Sharding .opt_state[1][0].mu['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['query']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.349 [I] Sharding .opt_state[1][0].mu['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['value']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.350 [I] Sharding .opt_state[1][0].mu['PaliGemma']['img']['head']['kernel'].value of shape (1152, 2048) (9.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.350 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['embedder']['input_embedding'].value of shape (257152, 2048) (2009.00 MiB) along axis 0 (664670:sharding.py:89)
21:39:31.350 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum']['w'].value of shape (18, 8, 256, 2048) (288.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.350 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum_1']['w'].value of shape (18, 8, 256, 1024) (144.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.350 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['attn']['kv_einsum']['w'].value of shape (18, 2, 1, 2048, 256) (72.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.351 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['attn']['kv_einsum_1']['w'].value of shape (18, 2, 1, 1024, 256) (36.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.351 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['attn']['q_einsum']['w'].value of shape (18, 8, 2048, 256) (288.00 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.351 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['attn']['q_einsum_1']['w'].value of shape (18, 8, 1024, 256) (144.00 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.351 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['mlp']['gating_einsum'].value of shape (18, 2, 2048, 16384) (4608.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.351 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['mlp']['linear'].value of shape (18, 16384, 2048) (2304.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.351 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['mlp_1']['gating_einsum'].value of shape (18, 2, 1024, 4096) (576.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.351 [I] Sharding .opt_state[1][0].mu['PaliGemma']['llm']['layers']['mlp_1']['linear'].value of shape (18, 4096, 1024) (288.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.352 [I] Sharding .opt_state[1][0].mu['action_time_mlp_in']['kernel'].value of shape (2048, 1024) (8.00 MiB) along axis 0 (664670:sharding.py:89)
21:39:31.352 [I] Sharding .opt_state[1][0].mu['action_time_mlp_out']['kernel'].value of shape (1024, 1024) (4.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.352 [I] Sharding .opt_state[1][0].nu['PaliGemma']['img']['Transformer']['encoderblock']['MlpBlock_0']['Dense_0']['kernel'].value of shape (27, 1152, 4304) (510.68 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.353 [I] Sharding .opt_state[1][0].nu['PaliGemma']['img']['Transformer']['encoderblock']['MlpBlock_0']['Dense_1']['kernel'].value of shape (27, 4304, 1152) (510.68 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.353 [I] Sharding .opt_state[1][0].nu['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['key']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.353 [I] Sharding .opt_state[1][0].nu['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['out']['kernel'].value of shape (27, 16, 72, 1152) (136.69 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.353 [I] Sharding .opt_state[1][0].nu['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['query']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.353 [I] Sharding .opt_state[1][0].nu['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['value']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.354 [I] Sharding .opt_state[1][0].nu['PaliGemma']['img']['head']['kernel'].value of shape (1152, 2048) (9.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.354 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['embedder']['input_embedding'].value of shape (257152, 2048) (2009.00 MiB) along axis 0 (664670:sharding.py:89)
21:39:31.354 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum']['w'].value of shape (18, 8, 256, 2048) (288.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.354 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum_1']['w'].value of shape (18, 8, 256, 1024) (144.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.354 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['attn']['kv_einsum']['w'].value of shape (18, 2, 1, 2048, 256) (72.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.355 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['attn']['kv_einsum_1']['w'].value of shape (18, 2, 1, 1024, 256) (36.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.355 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['attn']['q_einsum']['w'].value of shape (18, 8, 2048, 256) (288.00 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.355 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['attn']['q_einsum_1']['w'].value of shape (18, 8, 1024, 256) (144.00 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.355 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['mlp']['gating_einsum'].value of shape (18, 2, 2048, 16384) (4608.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.355 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['mlp']['linear'].value of shape (18, 16384, 2048) (2304.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.355 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['mlp_1']['gating_einsum'].value of shape (18, 2, 1024, 4096) (576.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.356 [I] Sharding .opt_state[1][0].nu['PaliGemma']['llm']['layers']['mlp_1']['linear'].value of shape (18, 4096, 1024) (288.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.356 [I] Sharding .opt_state[1][0].nu['action_time_mlp_in']['kernel'].value of shape (2048, 1024) (8.00 MiB) along axis 0 (664670:sharding.py:89)
21:39:31.356 [I] Sharding .opt_state[1][0].nu['action_time_mlp_out']['kernel'].value of shape (1024, 1024) (4.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.356 [I] Sharding .ema_params['PaliGemma']['img']['Transformer']['encoderblock']['MlpBlock_0']['Dense_0']['kernel'].value of shape (27, 1152, 4304) (510.68 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.357 [I] Sharding .ema_params['PaliGemma']['img']['Transformer']['encoderblock']['MlpBlock_0']['Dense_1']['kernel'].value of shape (27, 4304, 1152) (510.68 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.357 [I] Sharding .ema_params['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['key']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.357 [I] Sharding .ema_params['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['out']['kernel'].value of shape (27, 16, 72, 1152) (136.69 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.357 [I] Sharding .ema_params['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['query']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.358 [I] Sharding .ema_params['PaliGemma']['img']['Transformer']['encoderblock']['MultiHeadDotProductAttention_0']['value']['kernel'].value of shape (27, 1152, 16, 72) (136.69 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.358 [I] Sharding .ema_params['PaliGemma']['img']['head']['kernel'].value of shape (1152, 2048) (9.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.358 [I] Sharding .ema_params['PaliGemma']['llm']['embedder']['input_embedding'].value of shape (257152, 2048) (2009.00 MiB) along axis 0 (664670:sharding.py:89)
21:39:31.358 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum']['w'].value of shape (18, 8, 256, 2048) (288.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.358 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['attn']['attn_vec_einsum_1']['w'].value of shape (18, 8, 256, 1024) (144.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.359 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['attn']['kv_einsum']['w'].value of shape (18, 2, 1, 2048, 256) (72.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.359 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['attn']['kv_einsum_1']['w'].value of shape (18, 2, 1, 1024, 256) (36.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.359 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['attn']['q_einsum']['w'].value of shape (18, 8, 2048, 256) (288.00 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.359 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['attn']['q_einsum_1']['w'].value of shape (18, 8, 1024, 256) (144.00 MiB) along axis 2 (664670:sharding.py:89)
21:39:31.359 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['mlp']['gating_einsum'].value of shape (18, 2, 2048, 16384) (4608.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.359 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['mlp']['linear'].value of shape (18, 16384, 2048) (2304.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.359 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['mlp_1']['gating_einsum'].value of shape (18, 2, 1024, 4096) (576.00 MiB) along axis 3 (664670:sharding.py:89)
21:39:31.360 [I] Sharding .ema_params['PaliGemma']['llm']['layers']['mlp_1']['linear'].value of shape (18, 4096, 1024) (288.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.360 [I] Sharding .ema_params['action_time_mlp_in']['kernel'].value of shape (2048, 1024) (8.00 MiB) along axis 0 (664670:sharding.py:89)
21:39:31.360 [I] Sharding .ema_params['action_time_mlp_out']['kernel'].value of shape (1024, 1024) (4.00 MiB) along axis 1 (664670:sharding.py:89)
21:39:31.361 [I] Created BasePyTreeCheckpointHandler: pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=None (664670:base_pytree_checkpoint_handler.py:332)
21:39:31.398 [I] Restoring checkpoint from /home/mingyo/.cache/openpi/openpi-assets/checkpoints/pi0_base/params. (664670:checkpointer.py:256)
21:39:51.084 [I] [process=0] /jax/checkpoint/read/bytes_per_sec: 627.5 MiB/s (total bytes: 12.1 GiB) (time elapsed: 19 seconds) (per-host) (664670:base_pytree_checkpoint_handler.py:113)
21:39:51.086 [I] Finished restoring checkpoint from /home/mingyo/.cache/openpi/openpi-assets/checkpoints/pi0_base/params. (664670:checkpointer.py:259)
21:39:51.087 [I] [process=0][thread=MainThread] Skipping global process sync, barrier name: Checkpointer:restore (664670:multihost.py:293)
/mnt/data1/jaelyn/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py:1153: UserWarning: Some donated buffers were not usable: ShapedArray(float32[27,1152,4304]), ShapedArray(float32[27,4304,1152]), ShapedArray(float32[27,1152,16,72]), ShapedArray(float32[27,16,72,1152]), ShapedArray(float32[27,1152,16,72]), ShapedArray(float32[27,1152,16,72]), ShapedArray(float32[1152,2048]), ShapedArray(float32[257152,2048]), ShapedArray(float32[18,8,256,2048]), ShapedArray(float32[18,8,256,1024]), ShapedArray(float32[18,2,1,2048,256]), ShapedArray(float32[18,2,1,1024,256]), ShapedArray(float32[18,8,2048,256]), ShapedArray(float32[18,8,1024,256]), ShapedArray(float32[18,2,2048,16384]), ShapedArray(float32[18,16384,2048]), ShapedArray(float32[18,2,1024,4096]), ShapedArray(float32[18,4096,1024]), ShapedArray(float32[2048,1024]), ShapedArray(float32[1024,1024]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
2025-05-09 21:40:18.882376: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 144.00MiB (rounded to 150994944)requested by op
2025-05-09 21:40:18.882776: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ****************************************************************************************************
E0509 21:40:18.882857 665155 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 150994944 bytes. [tf-allocator-allocation-error='']
2025-05-09 21:40:19.222695: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_2_bfc) ran out of memory trying to allocate 144.00MiB (rounded to 150994944)requested by op
2025-05-09 21:40:19.223083: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ****************************************************************************************************
E0509 21:40:19.223168 665161 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 150994944 bytes. [tf-allocator-allocation-error='']
2025-05-09 21:40:19.262726: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_3_bfc) ran out of memory trying to allocate 144.00MiB (rounded to 150994944)requested by op
2025-05-09 21:40:19.263106: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ****************************************************************************************************
E0509 21:40:19.263177 665164 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 150994944 bytes. [tf-allocator-allocation-error='']
2025-05-09 21:40:19.264283: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_1_bfc) ran out of memory trying to allocate 144.00MiB (rounded to 150994944)requested by op
2025-05-09 21:40:19.264696: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ****************************************************************************************************
E0509 21:40:19.264765 665158 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 150994944 bytes. [tf-allocator-allocation-error='']
Traceback (most recent call last):
File "/mnt/data1/jaelyn/openpi/scripts/train.py", line 273, in <module>
main(_config.cli())
File "/mnt/data1/jaelyn/openpi/scripts/train.py", line 229, in main
train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/data1/jaelyn/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 559, in wrapped_fn
return wrapped_fn_impl(args, kwargs, bound, memos)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/data1/jaelyn/openpi/.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 483, in wrapped_fn_impl
out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/mnt/data1/jaelyn/openpi/scripts/train.py", line 125, in init_train_state
train_state = jax.jit(
^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 150994944 bytes.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I am facing the same problem, and even after I kill the process it still takes up too much memory. I don't know how to tackle it. 😿
Hi @haohuanw
I've been working on fine-tuning pi0_libero using only the libero_spatial data. I've set my own repo_id, asset_dir, and enabled local_files_only inside of config.py.
I enabled fsdp in training/config.py for multi-gpu.
I tested multiple configurations. For example, with XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 and fsdp_devices: int=4, I tried batch_size values of 32, 16, and 4. In all cases, each GPU reached ~90% memory before hitting an OOM error, consistently showing: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 150994944 bytes.
The only setup that seems to work is XLA_PYTHON_CLIENT_MEM_FRACTION=false with fsdp_devices: 1 and batch_size: 1. However, this config is extremely slow.
I also tried enabling gradient checkpointing as suggested online, but it did not help.
I'm having the same problem on a 3090 and a 4090 with 24GB of VRAM. It only works when I run the fine-tuning on a V100. BTW, when you run the command with XLA_PYTHON_CLIENT_MEM_FRACTION=false, the training happens on the CPU, which is why it's super slow.
I was able to train on a 4090 without any issues. The max batch_size I could use on a single GPU was 20. Also, if you have multiple GPUs, data sharding (not FSDP) should work out of the box. I was able to train on 4x4090 with a batch size of 64.
I was able to train on a 4090 without any issues. The max batch_size I could use on a single GPU was 20. Also, if you have multiple GPUs, data sharding (not FSDP) should work out of the box. I was able to train on 4x4090 with a batch size of 64.
I was just able to launch the fine-tuning on the 4090. To make it work, I created my own training configuration and replaced the --overwrite flag with --resume when running train.py. After that, I was able to start training with a batch size of 16. However, I'm not 100% sure why it's working now.
@jaelyontway Hi, I’m experiencing the same issue. The code only runs properly when I set XLA_PYTHON_CLIENT_MEM_FRACTION=false. Have you found a solution to this problem?
Ran into the same issue. Setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 works for me. Fine-tuned (LoRA) on a single RTX 4090 with batch size 32.
Hi @bixie and @hhcaz, try running something like this: CUDA_VISIBLE_DEVICES=3 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_libero --exp-name=my_experiment --overwrite
It must be run as a single line. You can assign more GPUs via CUDA_VISIBLE_DEVICES=____
Had the same problem when trying to fine-tune pi0-FAST on a 4090. I got it working by reducing the batch size to 24, reducing max_token_len to 90 (which apparently works for me without truncation warnings), and running with XLA_PYTHON_CLIENT_MEM_FRACTION=0.95.
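Pulling the recipes in this thread together, here is a hedged sketch of one single-4090 setup reported to fit in 24 GB. The field names (max_token_len, batch_size, fsdp_devices) are taken from the comments above, the working values depend on your dataset, and <your_config_name> is a placeholder, so treat this as a starting point rather than a supported config:

    # Sketch only: low-memory pi0-FAST LoRA fine-tuning on a 24 GB card.
    model=pi0_fast.Pi0FASTConfig(
        paligemma_variant="gemma_2b_lora",
        max_token_len=90,   # must still cover your longest prompt + action tokens
    ),
    batch_size=24,
    freeze_filter=pi0_fast.Pi0FASTConfig(
        paligemma_variant="gemma_2b_lora"
    ).get_freeze_filter(),
    ema_decay=None,

    # Launch as a single line, e.g.:
    # XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 uv run scripts/train.py <your_config_name> --exp-name=low_mem --overwrite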