
Add torch distributed checkpointing monkeypatches to enable TE checkpointing for extra_state attribute


What does this PR do?

Add torch distributed checkpointing monkeypatches to enable TransformerEngine (TE) checkpointing of the `extra_state` attribute. Patches the following private functions in `torch.distributed.checkpoint.state_dict`:

        state_dict._get_fqns = _get_fqns
        state_dict._verify_options = _verify_options
        state_dict._get_model_state_dict = _get_model_state_dict
        state_dict._load_model_state_dict = _load_model_state_dict
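
For context, TE layers serialize fp8 metadata through `get_extra_state()`/`set_extra_state()`, so their state-dict keys end in `_extra_state` and name no actual parameter or buffer, which trips up torch's attribute-walking FQN resolution under FSDP. Below is a minimal, hypothetical sketch of the monkeypatch pattern for one of the four functions; the patched bodies carried by this PR are more complete, and the fallback-to-original structure and textual prefix stripping here are illustrative only, not this PR's exact code:

        # Illustrative sketch only: special-case TE `_extra_state` keys during
        # FQN resolution, deferring to torch's implementation otherwise.
        import torch.nn as nn
        from torch.distributed.checkpoint import state_dict as dcp_state_dict

        _original_get_fqns = dcp_state_dict._get_fqns  # keep torch's version

        def _get_fqns(model: nn.Module, name: str):
            if name.endswith('_extra_state'):
                # The key names no parameter/buffer, so attribute-walking
                # resolution can fail; strip the FSDP wrapper prefix textually.
                return {name.replace('_fsdp_wrapped_module.', '')}
            return _original_get_fqns(model, name)

        dcp_state_dict._get_fqns = _get_fqns  # install the patch

Assigning over the module attribute works because the other `state_dict` helpers look up `_get_fqns` through module globals at call time, so the patched version is picked up everywhere inside `torch.distributed.checkpoint.state_dict`.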

What issue(s) does this change relate to?

Related to: https://github.com/pytorch/pytorch/pull/125336 and https://github.com/pytorch/pytorch/issues/122946

Test run:

(base) ➜  ~ mcli logs -f mpt-125m-te-ckpt-resumption-torch-monkeypatch-zcyRoH | rg loss
	 Train loss/train/total: 11.8668
	 Train loss/train/total: 11.8797
	 Train loss/train/total: 11.8777
	 Train loss/train/total: 11.4828
	 Train loss/train/total: 10.7657
	 Train loss/train/total: 10.4124
	 Train loss/train/total: 10.3823
	 Train loss/train/total: 10.1670
	 Train loss/train/total: 9.9711
	 Train loss/train/total: 9.7310
/usr/lib/python3/dist-packages/torch/distributed/fsdp/_common_utils.py:431: UserWarning: An unexpected prefix is detected. This case  should only happen when using DMP with FSDP. prefix = loss_fn., submodule_name = _fsdp_wrapped_module
	 Train loss/train/total: 9.7091
	 Train loss/train/total: 9.5678
	 Train loss/train/total: 9.4688
	 Train loss/train/total: 9.4365
	 Train loss/train/total: 9.2994
	 Train loss/train/total: 9.3235
	 Train loss/train/total: 9.2855
	 Train loss/train/total: 9.1910
	 Train loss/train/total: 9.0667
	 Train loss/train/total: 9.0421
wandb:                     loss/train/total ███▇▅▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁
wandb:                     loss/train/total 9.04214
/usr/lib/python3/dist-packages/torch/distributed/fsdp/_common_utils.py:431: UserWarning: An unexpected prefix is detected. This case  should only happen when using DMP with FSDP. prefix = loss_fn., submodule_name = _fsdp_wrapped_module
	 Train loss/train/total: 8.9817
	 Train loss/train/total: 8.9092
	 Train loss/train/total: 8.8225
	 Train loss/train/total: 8.7351
	 Train loss/train/total: 8.6218
	 Train loss/train/total: 8.6051
	 Train loss/train/total: 8.4849
	 Train loss/train/total: 8.3797
	 Train loss/train/total: 8.3112
	 Train loss/train/total: 8.2576
/usr/lib/python3/dist-packages/torch/distributed/fsdp/_common_utils.py:431: UserWarning: An unexpected prefix is detected. This case  should only happen when using DMP with FSDP. prefix = loss_fn., submodule_name = _fsdp_wrapped_module
	 Train loss/train/total: 8.1670
	 Train loss/train/total: 8.0978
	 Train loss/train/total: 8.0123
	 Train loss/train/total: 7.9702
	 Train loss/train/total: 7.8903
	 Train loss/train/total: 7.8319

	 Train loss/train/total: 7.7992
	 Train loss/train/total: 7.7377
	 Train loss/train/total: 7.7233
	 Train loss/train/total: 7.6725
wandb:                     loss/train/total ██▇▇▆▆▅▅▄▄▄▃▃▃▂▂▂▁▁▁
wandb:                     loss/train/total 7.67255

Loss continues to decrease monotonically after checkpointing and resumption at timestamp 20 ✅

Checkpoint Tests:

  • mpt-125m-monolithic-resumption-1-node-cEiWPU: monolithic, 1-node ✅
  • mpt-125m-monolithic-resumption-2-node-4jvrYY: monolithic, 2-node ✅
  • mpt-125m-sharded-resumption-1-node-38Kidh: sharded, 1-node ✅
  • mpt-125m-sharded-resumption-2-node-ahSI61: sharded, 2-node ✅
  • mpt-125m-hybrid-resumption-1-node-7QStEE: HSDP, 1-node ✅
  • mpt-125m-hybrid-resumption-2-node-SazoRo: HSDP, 2-node ✅

j316chuck · May 16 '24 21:05