
Make new optimizer more extensible, easier to integrate downstream for FSDP

muellerzr opened this issue 6 months ago • 6 comments

Description

This PR makes it easier for users to use FSDP with MS-AMP from their existing optimizers. This is especially beneficial for library authors: currently we need to go through quite a bit of work to get the FSDP version of these optimizers working when a user passes in optim.Adam.

Instead, we delegate FSDPAdamW to an OptimWrapper, which calls the underlying optimizer as a passthrough. This makes it easier to add any logic that should run before or after the optimizer's own step, and the wrapper takes in an already-constructed Optimizer rather than relying on inheritance.
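
For context, a passthrough wrapper of this shape could look roughly like the sketch below (my own illustration of the pattern, not MS-AMP's actual implementation; the pre_step/post_step hooks and exact method set are assumptions):

import torch

class OptimWrapper:
    """Delegates to an already-constructed optimizer, with hooks for extra
    logic before/after each step (illustrative sketch only)."""

    def __init__(self, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer

    def pre_step(self):
        # Placeholder: e.g. prepare FP8 gradients before the real step
        pass

    def post_step(self):
        # Placeholder: e.g. re-cast master weights after the real step
        pass

    def step(self, closure=None):
        self.pre_step()
        loss = self.optimizer.step(closure)
        self.post_step()
        return loss

    def zero_grad(self, set_to_none=True):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def __getattr__(self, name):
        # Everything else (param_groups, state_dict, ...) passes through
        return getattr(self.optimizer, name)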

Let me know what you all think about this. I'm currently integrating FSDP and DeepSpeed with MS-AMP into Accelerate and found this to be a critical pain point, since our users pass in normal PyTorch optimizers and don't create special versions themselves.

@tocean @wkcn let me know what you two think :)

New FSDP workflow:

# Build the model and a standard PyTorch optimizer as usual
model, optimizer = ...
# Initialize MS-AMP for FSDP with FP8 (E4M3) weights
model, optimizer = msamp.initialize(model, optimizer, use_fsdp=True, weight_qtype=Dtypes.kfloat8_e4m3)
# Shard the model with MS-AMP's FP8-aware FSDP wrapper
model = FP8FullyShardedDataParallel(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)
# Wrap the existing optimizer rather than constructing an FSDP-specific one
optimizer = FSDPAdamW(optimizer)
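
From a downstream library's point of view, this means we can wrap whatever optimizer the user handed us, along the lines of the following (hypothetical helper; the msamp.optim import path is an assumption):

import torch
from msamp.optim import FSDPAdamW  # import path assumed for illustration

def prepare_optimizer(optimizer: torch.optim.Optimizer, use_fsdp: bool):
    # Wrap the user's existing optimizer instead of asking them to build
    # an FSDP-specific one themselves
    return FSDPAdamW(optimizer) if use_fsdp else optimizer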

muellerzr · Aug 15 '24 15:08