composer
[WIP] Auto-microbatch fix
What does this PR do?
Fix auto-microbatching under FSDP. Before this change, Composer registered sync_hook through module.register_forward_hook and module.register_full_backward_hook. Those hooks fire AFTER the forward and backward passes of the original module (not the FSDP wrapper).
Issue with the previous solution:
Say the model's forward pass runs like this:
fsdp_module_0 -> fsdp_module_1 -> fsdp_module_2
Suppose an OOM happens on rank 0 between fsdp_module_0 and fsdp_module_1. Rank 0 starts the OOM-sync allReduce. Rank 1 keeps going and runs fsdp_module_1, which starts an allGather. The ranks are now in mismatched collectives (rank 0 in allReduce vs. rank 1 in allGather).
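A toy model of the collectives each rank issues can make the mismatch concrete. This is only a sketch in plain Python, not Composer or FSDP code: it assumes each FSDP module issues exactly one all_gather per forward, that the sync hook is a single all_reduce, and that a rank hitting OOM reports it with one more all_reduce and then stops. The names `collective_trace` and `first_mismatch` are hypothetical. It also models the variant where the sync runs before each all_gather, for comparison.

```python
from typing import List, Optional


def collective_trace(num_modules: int, hook: str,
                     oom_after: Optional[int] = None) -> List[str]:
    """Ordered collectives one rank issues during a (simplified) forward pass.

    hook="post": the sync all_reduce runs after each module.
    hook="pre":  the sync all_reduce runs before each module's all_gather.
    oom_after:   index of the module after which this rank hits OOM,
                 reports it with an all_reduce, and stops.
    """
    if hook not in ("pre", "post"):
        raise ValueError("hook must be 'pre' or 'post'")
    trace: List[str] = []
    for i in range(num_modules):
        if hook == "pre":
            trace.append("all_reduce")   # sync before the all_gather
        trace.append("all_gather")       # FSDP gathers the module's parameters
        if hook == "post":
            trace.append("all_reduce")   # sync after the module finished
        if oom_after == i:
            trace.append("all_reduce")   # OOM handler reports the failure
            break
    return trace


def first_mismatch(a: List[str], b: List[str]) -> Optional[int]:
    """Index of the first step where two ranks issue different collectives."""
    for step, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return step
    return None


# Rank 0 hits OOM between module 0 and module 1; rank 1 is healthy.
post_rank0 = collective_trace(3, "post", oom_after=0)
post_rank1 = collective_trace(3, "post")
pre_rank0 = collective_trace(3, "pre", oom_after=0)
pre_rank1 = collective_trace(3, "pre")

print(first_mismatch(post_rank0, post_rank1))  # step where the ranks diverge
print(first_mismatch(pre_rank0, pre_rank1))    # None: every step lines up
```

In the "post" ordering, rank 0's OOM all_reduce lines up against rank 1's next all_gather. In the "pre" ordering, it lines up against rank 1's next sync all_reduce.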
fix
The fix is simple: register the hook as a pre-forward and pre-backward hook instead. OOM detection then runs before any FSDP allGather rather than after it.
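For illustration, here is a minimal sketch of how pre-hooks order relative to a module's own work in PyTorch. It is not the code in this PR: `Block` is a hypothetical stand-in for a wrapped module, and the two hooks only record an event where Composer's sync_hook would run its OOM check. It assumes PyTorch 2.0 or later, where `register_full_backward_pre_hook` is available, and runs on CPU without FSDP.

```python
import torch
import torch.nn as nn

events = []  # records the order in which hooks and the forward body run


class Block(nn.Module):
    """Hypothetical stand-in for a module that FSDP would wrap."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        events.append("forward")
        return self.linear(x)


def pre_forward_sync(module, args):
    # Placeholder for the OOM sync; runs before the module's forward.
    events.append("pre_forward_sync")


def pre_backward_sync(module, grad_output):
    # Placeholder for the OOM sync; runs before the module's
    # input gradients are computed.
    events.append("pre_backward_sync")


block = Block()
block.register_forward_pre_hook(pre_forward_sync)
block.register_full_backward_pre_hook(pre_backward_sync)

x = torch.randn(2, 4, requires_grad=True)
out = block(x)
out.sum().backward()

print(events)
```

The point of the sketch is only the ordering: both hooks fire before the module does its work in the corresponding pass, which is where the description above places the OOM check.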
test
unit test
python -m composer.cli.launcher -n 2 -m pytest -m gpu tests/trainer/test_fsdp.py -k test_fsdp_auto_microbatch