Add torch distributed checkpointing monkeypatches to enable TE checkpointing for extra_state attribute
What does this PR do?
Adds torch distributed checkpointing monkeypatches to enable TransformerEngine (TE) checkpointing for the extra_state attribute. Patches the following internal torch.distributed.checkpoint.state_dict functions:
state_dict._get_fqns = _get_fqns
state_dict._verify_options = _verify_options
state_dict._get_model_state_dict = _get_model_state_dict
state_dict._load_model_state_dict = _load_model_state_dict
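For reference, a minimal sketch of the patching mechanism, assuming PyTorch's private torch.distributed.checkpoint.state_dict module; the placeholder below just delegates to the original, whereas the real replacement implementations are the ones added in this PR:

```python
import torch.distributed.checkpoint.state_dict as state_dict

# Keep a handle to the original private helper so the placeholder can delegate.
_original_get_fqns = state_dict._get_fqns

def _get_fqns(model, name, *args, **kwargs):
    # Placeholder body: the real patch special-cases fully qualified names
    # ending in "_extra_state" so TE's non-tensor extra state round-trips
    # through distributed checkpointing.
    return _original_get_fqns(model, name, *args, **kwargs)

# Swap the private helper for the patched version.
state_dict._get_fqns = _get_fqns
```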
What issue(s) does this change relate to?
Related to: https://github.com/pytorch/pytorch/pull/125336 and https://github.com/pytorch/pytorch/issues/122946
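For context, the extra_state entries come from nn.Module.get_extra_state / set_extra_state, which TE modules use to serialize FP8 metadata; the resulting state-dict keys end in _extra_state and hold non-tensor payloads, which is what the unpatched helpers mishandle. A toy illustration (a stand-in module, not TE itself):

```python
import torch
import torch.nn as nn

class ToyModule(nn.Module):
    """Stand-in for a TE layer that keeps non-tensor metadata in extra_state."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self._meta = {'recipe': 'delayed-scaling'}  # non-tensor payload

    def get_extra_state(self):
        # Saved into the state dict under '<prefix>_extra_state'.
        return self._meta

    def set_extra_state(self, state):
        self._meta = state

m = ToyModule()
print(sorted(m.state_dict().keys()))  # ['_extra_state', 'weight']
```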
Test run:
(base) ➜ ~ mcli logs -f mpt-125m-te-ckpt-resumption-torch-monkeypatch-zcyRoH | rg loss
Train loss/train/total: 11.8668
Train loss/train/total: 11.8797
Train loss/train/total: 11.8777
Train loss/train/total: 11.4828
Train loss/train/total: 10.7657
Train loss/train/total: 10.4124
Train loss/train/total: 10.3823
Train loss/train/total: 10.1670
Train loss/train/total: 9.9711
Train loss/train/total: 9.7310
/usr/lib/python3/dist-packages/torch/distributed/fsdp/_common_utils.py:431: UserWarning: An unexpected prefix is detected. This case should only happen when using DMP with FSDP. prefix = loss_fn., submodule_name = _fsdp_wrapped_module
Train loss/train/total: 9.7091
Train loss/train/total: 9.5678
Train loss/train/total: 9.4688
Train loss/train/total: 9.4365
Train loss/train/total: 9.2994
Train loss/train/total: 9.3235
Train loss/train/total: 9.2855
Train loss/train/total: 9.1910
Train loss/train/total: 9.0667
Train loss/train/total: 9.0421
wandb: loss/train/total ███▇▅▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁
wandb: loss/train/total 9.04214
/usr/lib/python3/dist-packages/torch/distributed/fsdp/_common_utils.py:431: UserWarning: An unexpected prefix is detected. This case should only happen when using DMP with FSDP. prefix = loss_fn., submodule_name = _fsdp_wrapped_module
Train loss/train/total: 8.9817
Train loss/train/total: 8.9092
Train loss/train/total: 8.8225
Train loss/train/total: 8.7351
Train loss/train/total: 8.6218
Train loss/train/total: 8.6051
Train loss/train/total: 8.4849
Train loss/train/total: 8.3797
Train loss/train/total: 8.3112
Train loss/train/total: 8.2576
/usr/lib/python3/dist-packages/torch/distributed/fsdp/_common_utils.py:431: UserWarning: An unexpected prefix is detected. This case should only happen when using DMP with FSDP. prefix = loss_fn., submodule_name = _fsdp_wrapped_module
Train loss/train/total: 8.1670
Train loss/train/total: 8.0978
Train loss/train/total: 8.0123
Train loss/train/total: 7.9702
Train loss/train/total: 7.8903
Train loss/train/total: 7.8319
Train loss/train/total: 7.7992
Train loss/train/total: 7.7377
Train loss/train/total: 7.7233
Train loss/train/total: 7.6725
wandb: loss/train/total ██▇▇▆▆▅▅▄▄▄▃▃▃▂▂▂▁▁▁
wandb: loss/train/total 7.67255
Loss continues to trend downward after checkpoint save and resumption at timestamp 20 ✅
Checkpoint Tests:
mpt-125m-monolithic-resumption-1-node-cEiWPU: working monolithic 1-node ✅
mpt-125m-monolithic-resumption-2-node-4jvrYY: working monolithic 2-node ✅
mpt-125m-sharded-resumption-1-node-38Kidh: working sharded 1-node ✅
mpt-125m-sharded-resumption-2-node-ahSI61: working sharded 2-node ✅
mpt-125m-hybrid-resumption-1-node-7QStEE: working HSDP 1-node ✅
mpt-125m-hybrid-resumption-2-node-SazoRo: working HSDP 2-node ✅