
OOM handling and recovery

Open vadimkantorov opened this issue 5 months ago • 9 comments

We just hit an OOM, which revealed that by default torchtune does not use torch.compile and does not yet use a fused linear cross-entropy loss...

I found the following reports from 2024:

  • https://www.reddit.com/r/LocalLLaMA/comments/1di0fhv/torchtune_vs_axolotl_vs_unsloth_trainer/
  • https://wandb.ai/augmxnt/train-bench/reports/Trainer-performance-comparison-torchtune-vs-axolotl-vs-Unsloth---Vmlldzo4MzU3NTAx

Are there any plans to make torchtune excellent in terms of peak GPU memory usage, with practical OOM handling?

For fine-tuning on long chain-of-thought data, OOMs caused by a few overly long examples that were not filtered out are a real issue.
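
One way to avoid this class of OOMs is to drop over-long samples before they ever reach the trainer. A minimal, hypothetical sketch (not torchtune API; the tokenizer callable and the max_seq_len budget are placeholders):

```python
# Hypothetical pre-filter (not torchtune API): drop samples whose tokenized
# length exceeds a sequence budget, so one long chain-of-thought example
# cannot blow up activation memory mid-run.
from typing import Callable, Iterable, List


def filter_by_token_count(
    samples: Iterable[str],
    tokenize: Callable[[str], List[int]],  # e.g. tokenizer.encode from any tokenizer
    max_seq_len: int = 8192,               # assumed budget; tune to your GPU
) -> List[str]:
    return [text for text in samples if len(tokenize(text)) <= max_seq_len]


if __name__ == "__main__":
    # toy whitespace "tokenizer" stands in for a real BPE tokenizer
    toy_tokenize = lambda s: s.split()
    data = ["short example", "a " * 100_000]
    print(len(filter_by_token_count(data, toy_tokenize, max_seq_len=4096)))  # -> 1
```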

Could torchtune be made OOM-crash-safe out of the box? E.g., robustly skip a batch at runtime if an OOM occurs once in a while, or aggressively CPU-offload some huge activation tensors when an OOM hits and retry the step after offloading...
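
For reference, here is a rough sketch of the batch-skipping idea on a plain PyTorch loop (not torchtune internals; `model`, `optimizer`, and `loss_fn` are placeholders):

```python
# Sketch only: catch the OOM, drop the partially built autograd state,
# release cached blocks, and move on to the next batch instead of crashing.
import torch


def train_step_with_oom_skip(model, optimizer, loss_fn, inputs, labels) -> bool:
    """Returns True if the step succeeded, False if it was skipped due to OOM."""
    try:
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
        return True
    except torch.OutOfMemoryError:
        optimizer.zero_grad(set_to_none=True)  # free any gradients already produced
        if torch.cuda.is_available():
            torch.cuda.empty_cache()           # return cached blocks to the driver
        return False
```

The hard part for torchtune is the distributed case: with FSDP every rank has to agree to skip the same step, otherwise the collectives in backward/optimizer hang, so a robust built-in version would need some cross-rank coordination.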

File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/recipes/full_finetune_distributed.py", line 919, in train                
[rank6]:     current_loss = self._loss_step(batch) * current_num_tokens                                                                       
[rank6]:                    ^^^^^^^^^^^^^^^^^^^^^^                                                                                            
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/recipes/full_finetune_distributed.py", line 822, in _loss_step           
[rank6]:     loss = self._loss_fn(outputs, labels)                                                                                            
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                            
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl            
[rank6]:     return self._call_impl(*args, **kwargs)                                                                                          
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                          
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl                    
[rank6]:     return forward_call(*args, **kwargs)                                                                                             
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                             
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torchtune/modules/loss/cross_entropy_loss.py", line 136, in forward      
[rank6]:     total_loss += self.compute_cross_entropy(                                                                                        
[rank6]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                        
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torchtune/modules/loss/cross_entropy_loss.py", line 105, in compute_cross
_entropy                                                                                                                                      
[rank6]:     return F.cross_entropy(                                                                                                          
[rank6]:            ^^^^^^^^^^^^^^^^                                                                                                          
[rank6]:   File "/mnt/fs/venv_torchtune/lib/python3.12/site-packages/torch/nn/functional.py", line 3494, in cross_entropy                     
[rank6]:     return torch._C._nn.cross_entropy_loss(                                                                                          
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                          
[rank6]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.15 GiB. GPU 6 has a total capacity of 79.10 GiB of which 3.79 GiB is 
free. Including non-PyTorch memory, this process has 75.30 GiB memory in use. Of the allocated memory 58.28 GiB is allocated by PyTorch, and 1
5.06 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_se
gments:True to avoid fragmentation.  See documentation for Memory Management
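
As an aside, the allocator hint at the end of that message can be applied without code changes by exporting PYTORCH_CUDA_ALLOC_CONF before launching the run; a minimal Python-side equivalent (assuming it runs before torch initializes CUDA) would be:

```python
# Must be set before the first CUDA allocation for the allocator to pick it up.
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import torch  # imported after the env var so the caching allocator sees it
```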

vadimkantorov · Jun 16 '25 13:06