MS-AMP
Make new optimizer more extensible, easier to integrate downstream for FSDP
Description
This PR makes it easier for users to use FSDP with MS-AMP starting from their existing optimizers. This is especially beneficial for library authors: currently we have to go through quite a few steps to get the FSDP version of these optimizers working when a user passes in `optim.Adam`.
Instead, `FSDPAdamW` now delegates to an `OptimWrapper`, which calls an underlying optimizer as a passthrough. This makes it easier to add any logic that needs to run before/after the optimizer step, and the wrapper takes in an already-constructed `Optimizer` rather than relying on inheritance.
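For reference, here's a minimal sketch of the passthrough pattern described above. This is not the actual MS-AMP `OptimWrapper`/`FSDPAdamW` code; the `_pre_step`/`_post_step` hook names are hypothetical and only illustrate where before/after logic would slot in:

```python
import torch


class OptimWrapper:
    """Wraps an already-constructed optimizer and forwards calls to it,
    leaving room for extra logic before and after each step."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim  # delegate to the user's optimizer instead of inheriting

    # Passthrough accessors so the wrapper behaves like the wrapped optimizer.
    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def state(self):
        return self.optim.state

    def zero_grad(self, set_to_none: bool = True):
        self.optim.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def step(self, closure=None):
        self._pre_step()                 # e.g. prepare FSDP/FP8 gradients
        loss = self.optim.step(closure)  # passthrough to the wrapped optimizer
        self._post_step()                # e.g. post-process updated weights
        return loss

    def _pre_step(self):
        pass

    def _post_step(self):
        pass


class FSDPAdamW(OptimWrapper):
    """FSDP-flavored wrapper: accepts any constructed Adam/AdamW-style optimizer."""

    def _pre_step(self):
        # FSDP/FP8-specific gradient handling would go here.
        pass
```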
Let me know what you think about this. I'm currently integrating FSDP and DeepSpeed with MS-AMP into Accelerate and found this to be a critical pain point, since our users pass in normal PyTorch optimizers and don't create special versions themselves.
@tocean @wkcn let me know what you two think :)
New working FSDP usage:

```python
import msamp
from msamp.common.dtype import Dtypes
from msamp.fsdp import FP8FullyShardedDataParallel  # import paths are my assumption
from msamp.optim import FSDPAdamW                   # FSDPAdamW is introduced by this PR

model, optimizer = ...  # your model and a standard PyTorch optimizer
model, optimizer = msamp.initialize(model, optimizer, use_fsdp=True, weight_qtype=Dtypes.kfloat8_e4m3)
model = FP8FullyShardedDataParallel(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)
optimizer = FSDPAdamW(optimizer)
```
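For a downstream library, this means the user's existing optimizer instance just gets wrapped in place; an illustrative snippet:

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # what a user typically passes in
optimizer = FSDPAdamW(optimizer)  # the library wraps it; no special optimizer class needed up front
```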
Sorry for the extraneous pushes while I was figuring something out. Good to go now :)
You can see our new accelerate benchmarking scripts here: https://github.com/huggingface/accelerate/tree/muellerzr-msamp-ds-fsdp/benchmarks/fp8/ms_amp
@tocean @wkcn any particular issues with this? :)
(Ideally it'd be great to include this in the next accelerate release on the 1st :) )
@muellerzr Thanks for your contribution!
The PR looks good to me. Sorry that I am not at Microsoft and do not have authorization to review and merge the pull request.
Ack okay, I suppose we'll have to wait for @tocean /@abuccts /@guoshzhao to take a look. Thanks for the flag 🤗