OOM recovery under multi-node FSDP/HSDP
Bug description
Does torchtitan provide any recipes for implementing batch skipping / OOM recovery in a multi-node FSDP setup?
This is especially pertinent in RL/GRPO training, where response seqlens are not known a priori, so packing / clipping can't fully bound memory (rough sketch of the kind of coordinated skipping I have in mind below):
- https://github.com/volcengine/verl/issues/2159
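To make the question concrete, here is a minimal sketch (not an existing torchtitan API) of what I mean by coordinated OOM recovery: every rank votes on whether an OOM happened so that all ranks skip the same step, since FSDP collectives would otherwise get mismatched. `compute_loss` and the process-group setup are placeholders / assumptions.

```python
import torch
import torch.distributed as dist

def train_step(model, optimizer, batch, compute_loss, device):
    """One step that skips the whole global batch if any rank OOMs."""
    oom = torch.zeros(1, device=device)
    loss = None
    try:
        loss = compute_loss(model, batch)  # placeholder for the actual loss computation
        loss.backward()
    except torch.cuda.OutOfMemoryError:
        # Free whatever this rank can before coordinating with the others.
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        oom.fill_(1)

    # All ranks must agree on whether to skip. Caveat: if a rank OOMs while its
    # peers are already blocked inside an FSDP all-gather / reduce-scatter, they
    # will sit there until the NCCL timeout, so this only recovers cleanly from
    # OOMs that happen outside collectives.
    dist.all_reduce(oom, op=dist.ReduceOp.MAX)
    if oom.item() > 0:
        optimizer.zero_grad(set_to_none=True)
        return None  # batch skipped on every rank

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.detach()
```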
A couple of things I could think of:
- some sort of micro-batching for the backward pass (see the sketch after this list)
- some generic batch skipping
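For the micro-batching idea, a sketch of what I imagine, assuming variable-length RL samples with an `input_ids` field and a per-sample loss function (both assumptions): greedily pack samples under a token budget and accumulate gradients across micro-batches before a single optimizer step.

```python
import torch

def accumulate_backward(model, samples, loss_fn, max_tokens_per_microbatch=8192):
    """Split a batch of variable-length samples into token-budgeted micro-batches
    and accumulate gradients across them."""
    micro_batches, current, current_tokens = [], [], 0
    for sample in samples:
        n = sample["input_ids"].numel()
        if current and current_tokens + n > max_tokens_per_microbatch:
            micro_batches.append(current)
            current, current_tokens = [], 0
        current.append(sample)
        current_tokens += n
    if current:
        micro_batches.append(current)

    total = sum(len(mb) for mb in micro_batches)
    for mb in micro_batches:
        # Scale so the accumulated gradient matches a single full-batch backward.
        loss = sum(loss_fn(model, s) for s in mb) / total
        # Gradients are reduce-scattered on every backward here; with FSDP the
        # sync could be deferred to the last micro-batch at the cost of holding
        # unsharded gradients in memory.
        loss.backward()
```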
Some sort of memory-operation tracing would also be very useful for understanding the cause of the OOM (e.g. fragmentation):
- https://github.com/pytorch/pytorch/issues/91692#issuecomment-2996838221
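For reference, something like the following already gets part of the way there using PyTorch's underscore-prefixed (not yet stable) caching-allocator snapshot APIs; `run_training` is a placeholder for the actual loop. The dumped pickle can be inspected at https://pytorch.org/memory_viz to see live blocks vs. cached/fragmented segments.

```python
import torch

# Start recording allocation/free events with stack traces.
torch.cuda.memory._record_memory_history(max_entries=100_000)
try:
    run_training()  # placeholder for the actual train loop
except torch.cuda.OutOfMemoryError:
    # Dump the allocator state at the moment of the OOM for offline inspection.
    torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
    raise
finally:
    # Stop recording.
    torch.cuda.memory._record_memory_history(enabled=None)
```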
Versions
N/A