OOM handling and recovery
We just hit an OOM, which also revealed that by default torchtune does not use torch.compile and does not yet use a fused linear cross-entropy loss.
I found the following trainer performance comparison from 2024:
- https://www.reddit.com/r/LocalLLaMA/comments/1di0fhv/torchtune_vs_axolotl_vs_unsloth_trainer/
- https://wandb.ai/augmxnt/train-bench/reports/Trainer-performance-comparison-torchtune-vs-axolotl-vs-Unsloth---Vmlldzo4MzU3NTAx
Are there any plans to make torchtune best-in-class on peak GPU memory usage, with practical OOM handling?
For fine-tuning on long chain-of-thought data, OOMs triggered by the occasional long example that was not filtered out are a real problem.
Could torchtune be made OOM-crash-safe out of the box? For example, robustly skip a batch at runtime when an OOM occurs once in a while, or aggressively CPU-offload some of the huge activation tensors when an OOM hits and retry the step after offloading? Something along the lines of the sketch below is what I have in mind.
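To make the ask concrete, here is a minimal sketch of the "skip the batch on OOM" behaviour I mean. This is not torchtune's API: the step function, the batch keys ("tokens"/"labels"), and the surrounding structure are hypothetical stand-ins; in the actual recipe, something like the `_loss_step` call plus the backward pass would presumably need to be wrapped.

```python
import torch


def oom_tolerant_step(model, loss_fn, optimizer, batch):
    """Run one training step; skip the batch instead of crashing on CUDA OOM."""
    try:
        logits = model(batch["tokens"])
        loss = loss_fn(logits, batch["labels"])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        return loss.detach()
    except torch.OutOfMemoryError:
        # Drop the partially built autograd graph / gradients and return the
        # cached blocks to the allocator, then let the caller move on.
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        return None  # caller treats None as "batch skipped"
```

In the distributed recipes this is trickier: if only one rank skips, the FSDP collectives on the other ranks will hang, so all ranks would have to agree to skip (e.g. all-reduce an "OOM happened" flag) before dropping the batch. For reference, here is the traceback we hit: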
File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/recipes/full_finetune_distributed.py", line 919, in train
[rank6]: current_loss = self._loss_step(batch) * current_num_tokens
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/recipes/full_finetune_distributed.py", line 822, in _loss_step
[rank6]: loss = self._loss_fn(outputs, labels)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank6]: return self._call_impl(*args, **kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank6]: return forward_call(*args, **kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torchtune/modules/loss/cross_entropy_loss.py", line 136, in forward
[rank6]: total_loss += self.compute_cross_entropy(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torchtune/modules/loss/cross_entropy_loss.py", line 105, in compute_cross_entropy
[rank6]: return F.cross_entropy(
[rank6]: ^^^^^^^^^^^^^^^^
[rank6]: File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/functional.py", line 3494, in cross_entropy
[rank6]: return torch._C._nn.cross_entropy_loss(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.15 GiB. GPU 6 has a total capacity of 79.10 GiB of which 3.79 GiB is free. Including non-PyTorch memory, this process has 75.30 GiB memory in use. Of the allocated memory 58.28 GiB is allocated by PyTorch, and 15.06 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management
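The allocator hint at the end of the message (`PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`) is easy enough to try; it just has to take effect before the first CUDA allocation, e.g. exported in the shell before launching the recipe, or set at the very top of the entry script:

```python
import os

# Reduce fragmentation, as suggested by the OOM message above. Must be set
# before PyTorch makes its first CUDA allocation.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
```

But that only mitigates fragmentation; it does not address the underlying "one extra-long sample blows up the step" problem, which is why built-in OOM recovery would still be valuable.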