
OOM recovery under multi-node FSDP/HSDP

Open vadimkantorov opened this issue 5 months ago • 8 comments

Bug description

Does torchtitan provide any recipes for implementing batch skipping / OOM recovery in a multi-node FSDP setup?

This is very pertinent in RL/GRPO training, where response seqlens are not known a priori, so packing / clipping cannot be done ahead of time:

  • https://github.com/volcengine/verl/issues/2159

A couple of things I could think of (a minimal sketch of the second follows this list):

  • some sort of micro-batching for the backward pass
  • some generic batch skipping
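
For the batch-skipping idea, here is a rough sketch of what I have in mind (hedged: `model`, `optimizer`, `loss_fn`, and `batch` are placeholders, not torchtitan APIs). Each rank catches the OOM locally, then all ranks agree via an all-reduce before skipping, since a one-sided skip would desync FSDP's collectives. An OOM raised inside FSDP's own backward collectives can still leave peer ranks blocked, which is part of why a proper recipe is needed:

```python
import torch
import torch.distributed as dist

def train_step(model, optimizer, loss_fn, batch, device):
    # Per-rank OOM flag; reduced across ranks so everyone skips together.
    oom = torch.zeros(1, device=device)
    try:
        loss = loss_fn(model(batch["input"]), batch["target"])
        loss.backward()
    except torch.cuda.OutOfMemoryError:
        oom.fill_(1.0)
        # Drop partial gradients/activations and release cached blocks.
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
    # All ranks must agree on skipping, otherwise the next collective hangs.
    dist.all_reduce(oom, op=dist.ReduceOp.MAX)
    if oom.item() > 0:
        optimizer.zero_grad(set_to_none=True)
        return None  # batch skipped on every rank
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.detach()
```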

Some sort of memory operation tracing would also be very useful to understand what actually causes the OOM (e.g. fragmentation):

  • https://github.com/pytorch/pytorch/issues/91692#issuecomment-2996838221
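
For reference, PyTorch already exposes (private) allocator-history APIs that the linked issue discusses: `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._dump_snapshot`. A sketch of dumping a snapshot on OOM, where `run_training` is a placeholder for the actual loop:

```python
import torch

# Record allocator events (allocations, frees, segment ops) with stack traces.
torch.cuda.memory._record_memory_history(max_entries=100_000)

try:
    run_training()  # placeholder for the training loop
except torch.cuda.OutOfMemoryError:
    # The snapshot can be inspected at https://pytorch.org/memory_viz and
    # helps distinguish real memory pressure from allocator fragmentation.
    torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
    raise
```

For fragmentation specifically, setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is also worth trying.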

Versions

N/A

vadimkantorov · Jun 23 '25 16:06