Checkpointer memory leak
TL;DR: version 0.11.8 works; the latest release is leaking (I haven't tested the versions in between).
I originally opened an issue about this at https://github.com/google-deepmind/gemma/issues/354 but have since narrowed it down to the current latest orbax-checkpoint (v0.11.19):
Running lora.py:
python -m kauldron.main --cfg=lora.py --cfg.workdir=/tmp/ckpt_lora
Prints out warnings like
replica_slices.py:419] Transferring arrays to host memory with options: use_replica_parallel=True, enable_pinned_host_transfer=True
W external/xla/xla/stream_executor/integrations/stream_executor_allocator.cc:66] could not allocate pinned host of size: 4294967296
...
RAM (not VRAM) grows rapidly and indefinitely.
Disabling the checkpointer in the config eliminates the leak.
Versions:
- gemma: 3.0.2
- jax: 0.6.2
- orbax-checkpoint: 0.11.19
- kauldron: 1.2.2
- NVIDIA driver: 570.124.06
- CUDA: 12.8
After a quick search through the issues here, I saw that #1713 refers to `use_replica_parallel`, which, according to the CHANGELOG, was re-enabled in version 0.11.9. I installed 0.11.8 instead, re-ran the experiment, and voilà: checkpointing works normally again.
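For anyone else hitting this, the workaround was simply pinning the older release with `pip install orbax-checkpoint==0.11.8`. A quick sanity check that the downgrade is actually the version being picked up in your environment (a minimal snippet, assuming `orbax.checkpoint` exposes `__version__` as in recent releases):

```python
# Quick check that the pinned orbax-checkpoint release is the one in use.
# Assumes orbax.checkpoint exposes __version__ (it has in recent releases).
import orbax.checkpoint as ocp

print(ocp.__version__)  # expect "0.11.8" after the downgrade
```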
I also noticed that when training GEMMA3_1B_IT, the printed message says `use_replica_parallel=False, enable_pinned_host_transfer=False` (both are False), so the issue seems to be related to one or both of these settings.
@gspschmid Could you PTAL? `enable_pinned_host_transfer` is only set to True for GPUs at the moment.
Both of these settings are useful performance-wise, but not critical. You can disable them if you'd prefer to stay on a newer version.
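For reference, a rough way to turn them off is to re-register the `jax.Array` type handler with both options disabled. The sketch below is untested and assumes your installed `ArrayHandler` still accepts `use_replica_parallel` and `enable_pinned_host_transfer` as constructor arguments (the exact signature has moved around between releases, so please check it in your version first):

```python
# Untested sketch: re-register the jax.Array type handler with
# replica-parallel saving and pinned-host transfers disabled.
# Assumption: `use_replica_parallel` and `enable_pinned_host_transfer`
# are ArrayHandler constructor arguments in your orbax-checkpoint version;
# verify against the installed signature before relying on this.
import jax
import orbax.checkpoint as ocp

ocp.type_handlers.register_type_handler(
    jax.Array,
    ocp.type_handlers.ArrayHandler(
        use_replica_parallel=False,
        enable_pinned_host_transfer=False,
    ),
    override=True,  # replace the handler orbax registers by default
)
```

Run this before the checkpointer is constructed so the override applies to subsequent saves.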
Hello @sirfz, could you provide a couple more details about the exact configuration and hardware you were running? Specifically:
- How much host-side RAM did your machine have?
- Which GPU were you using, and with how much VRAM?
- Which model were you finetuning (`lora.py` specifies the model to be Gemma3_4B, but your GitHub issue also mentions GEMMA3_1B_IT)?
I tried to reproduce this issue on a DGX A100 machine with 1 TiB of RAM and eight 80 GB A100 GPUs.
- A run using all 8 GPUs and Gemma3_4B finishes without any errors.
- When using just 1 GPU, I get a device OOM error (not host).
- Reducing the batch size to 2 allows the job to complete without any errors. This seems to be because when using > 1 GPU, training is run in a data-parallel manner, and the batch size parameter in `lora.py` specifies the global batch size across all data-parallel ranks.
- I saw this same behavior regardless of whether I ran with orbax 0.11.19 or 0.11.8, and also regardless of what `use_replica_parallel` and `enable_pinned_host_transfer` were set to.
Regarding host-side memory usage, frequent checkpointing seems to cause more volatility, but evidence of an unbounded leak / increase is somewhat unclear.
- I monitored RAM usage as I ran the program for ~1.5 hours. When checkpointing every 100 steps, there seems to be higher memory pressure and usage trends upwards, but it appears to remain bounded below ~37 GiB. See this plot (the vertical red lines correspond to when checkpoints are saved); a minimal sketch of this kind of memory sampling follows below the list.
- With checkpointing disabled, memory usage is much lower (~10 GiB) and much more stable.
- This behavior again seems consistent regardless of Orbax version or `enable_pinned_host_transfer` / `use_replica_parallel`.
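In case it's useful for reproducing these measurements, here is a minimal sketch of the kind of host-memory sampling involved (not the exact tooling used). It logs machine-wide used RAM and the training process's RSS once per second, which is enough to spot an unbounded upward trend across checkpoint saves:

```python
# Minimal host-memory sampler (a sketch, not the exact tooling used):
# point it at the training process's PID and it logs machine-wide used RAM
# and that process's RSS once per second.
import sys
import time

import psutil


def sample(pid: int, interval_s: float = 1.0) -> None:
    proc = psutil.Process(pid)
    while proc.is_running():
        used_gib = psutil.virtual_memory().used / 2**30
        rss_gib = proc.memory_info().rss / 2**30
        print(f"t={time.time():.0f} host_used={used_gib:.2f} GiB rss={rss_gib:.2f} GiB")
        time.sleep(interval_s)


if __name__ == "__main__":
    sample(int(sys.argv[1]))
```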
Lastly, it seems unlikely that pinned host memory allocations are specifically the issue. Tracking the behavior of the pinned host memory allocator shows that the only allocations occur during the very first checkpoint save and training step. The same pinned host memory buffers are reused for subsequent checkpoints, so the amount of pinned host memory allocated does not appear to grow arbitrarily.
In summary, we do see some increased volatility in host memory usage with frequent checkpointing. Some more details about the model you were finetuning and your hardware could help us better understand the issue / try to come up with a more accurate reproduction. Thanks for your help!
Thank you for the thorough response. I've moved on from this project since posting this issue, but here's the info I have:
- GCP instance a2-highgpu-1g
- 12 vCPUs with 85 GB RAM
- 1× NVIDIA A100 (40 GB)
GEMMA_1B has no issues and runs normally with both orbax versions mentioned in my OP. However, when I switched to 4B, the issue happened consistently every time, even on different instances (of the same type). Only downgrading resolved the problem.
I observed this on my custom experiment (pretty much the vanilla config, just with a different dataset), but the vanilla config I linked to from the gemma repo also reproduced the problem for me.
Edit: this happens on the initial checkpoint, before training even starts, so it's not an issue with checkpoint frequency. For me it would just kill the machine and never take off.
Thanks for the info; unfortunately I haven't been able to reproduce the issue. Rerunning with Orbax 0.11.19 on a machine with a 40GB A100 and 125 GiB of host RAM gave me similar results to those I described above. I see a device out-of-memory error with Gemma 4B and batch size 16 on the first training step, and training / checkpointing progresses normally (past the first training step) if I reduce the batch size to 2. There are some spikes in host memory usage, but no clear signs of a memory leak. Pinned host memory allocations seem to be behaving as expected as well (allocations only occur on the first training step and checkpoint save).
If anyone else encounters this issue, perhaps trying a lower batch size could be a good start. Thanks!
Good to hear, I'll report back if I get to test this again at some point. Thank you for taking the time to debug this.