Supporting automicrobatching on FSDP2
Draft PR for supporting automicrobatching on FSDP2
This isn't added yet because we ran into some hiccups with how FSDP2 handles state transitions. FSDP2 is stateful, and it expects the program to stop when training runs into an OOM. Since automicrobatching instead restarts training with a reduced microbatch size, `torch.distributed.fsdp._fully_shard._fsdp_common.TrainingState` can be left in an illegal state, which can cause hangs or unexpected errors. Until there is a clear API for handling this, it would be finicky to manipulate that state directly.
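
For context, here is a minimal sketch of the kind of OOM-retry loop automicrobatching relies on, not Composer's actual implementation; `run_microbatches` and the halving policy are hypothetical stand-ins for Composer's internal training-loop logic. The comments mark where an interrupted FSDP2 forward/backward leaves the problematic state behind.

```python
import torch

def train_batch_with_auto_microbatching(batch, run_microbatches, device_batch_size):
    """Halve the microbatch size and retry whenever a CUDA OOM is raised.

    `run_microbatches` is a hypothetical callable that splits the device batch
    into microbatches and runs forward/backward on each.
    """
    microbatch_size = device_batch_size
    while True:
        try:
            return run_microbatches(batch, microbatch_size)
        except RuntimeError as e:
            if 'out of memory' not in str(e) or microbatch_size == 1:
                raise
            # Free cached allocations left by the failed attempt.
            torch.cuda.empty_cache()
            # Retry with a smaller microbatch size. With FSDP2, the interrupted
            # forward/backward can leave
            # torch.distributed.fsdp._fully_shard._fsdp_common.TrainingState
            # mid-transition, which is the source of the hangs/unexpected
            # errors described above.
            microbatch_size //= 2
```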