TransformerEngine
[PyTorch] Support dtype casting in fused adam
Description
Currently, FusedAdam updates the params in-place. This PR adds dtype casting to the FusedAdam kernel: in addition to updating the master params in-place, it can also update extra model params, which may be of bf16, fp16, or fp8 type.
Update: I have validated convergence with GPT training in Megatron-LM. The losses before and after enabling this feature are bit-wise identical.
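For illustration, here is a rough PyTorch emulation of what one fused step does when an extra low-precision model param is attached (this is not the kernel itself; tensor names and hyperparameters are made up for the example): the Adam update is applied to the fp32 master param, and the result is cast down into the bf16 model param in the same pass.

```python
import torch

# Illustrative tensors only; in the real kernel these live on the GPU and are updated in one pass.
master_param = torch.randn(1024, dtype=torch.float32)   # fp32 master weight
model_param = master_param.to(torch.bfloat16)           # bf16 model weight (extra param)
exp_avg = torch.zeros_like(master_param)
exp_avg_sq = torch.zeros_like(master_param)
grad = torch.randn_like(master_param)

lr, beta1, beta2, eps, step = 1e-3, 0.9, 0.999, 1e-8, 1

# Adam update on the fp32 master weight (what FusedAdam already does in-place)...
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq / (1 - beta2**step)).sqrt().add_(eps)
master_param.add_((exp_avg / (1 - beta1**step)) / denom, alpha=-lr)

# ...followed by the dtype cast into the model weight, which this PR fuses into the same kernel.
model_param.copy_(master_param.to(torch.bfloat16))
```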
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactor
Changes
Please list the changes introduced in this PR:
- Add dtype casting to the FusedAdam kernel so that, in addition to updating the master params in-place, it can also update extra model params.
- Support extra model params of bf16, fp16, and fp8 types.
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
@timmoon10 Could you please take a look? The corresponding changes to Megatron-LM are in our internal GitLab MR#1736.
/te-ci pytorch
Hi @timmoon10, I encountered an issue when trying to update scale_inv inside the Adam kernel using `*scale_inv_ptr = 1.0f / scale`: the loss was no longer bit-wise aligned. The reason is that TE/PyTorch is compiled with --use_fast_math, which lowers the reciprocal into a single MUFU.RCP instruction, producing an approximate rather than a correctly rounded result.
To achieve bit-wise alignment of the loss, I had to update scale_inv outside the Adam kernel, which leads to suboptimal performance. Do you have any suggestions to address this?
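For context, the current workaround is roughly the following (names are illustrative, not the actual TE internals): the reciprocal is computed on the PyTorch side, which is not compiled with --use_fast_math and therefore avoids the approximate MUFU.RCP path, at the cost of an extra non-fused step after the optimizer update.

```python
import torch

scale = torch.tensor([448.0], dtype=torch.float32)  # illustrative FP8 scaling factor

# Inside the fused kernel, 1.0f / scale would lower to MUFU.RCP under --use_fast_math (approximate).
# Computing it here keeps the result exactly reproducible, so the loss stays bit-wise aligned,
# but it requires a separate step outside the Adam kernel.
scale_inv = 1.0 / scale
```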
/te-ci pytorch
/te-ci pytorch
Based on a discussion with @ptrendx, I think we should give more thought to the API. While this is primarily targeting Megatron-LM, it's important that other TE users can use it easily without relying on Mcore infrastructure.
@ptrendx's preferred API is for the optimizer to hold the model weights (including Float8Tensors) and to treat the master weights as optimizer state (similar to exp_avg and exp_avg_sq). This is similar to Option 1 in https://github.com/NVIDIA/TransformerEngine/pull/977#discussion_r1687308954. The workflow should look like:
```python
model = MyModel()  # Mix of fp32, bf16, fp8 params
optim = FusedAdam(model.parameters(), dtype=torch.float32)  # Create FP32 master weights for each non-fp32 param
optim.step()
# optim.state[bf16_param]["exp_avg"] is fp32 tensor
# optim.state[bf16_param]["exp_avg_sq"] is fp32 tensor
# optim.state[bf16_param]["master_param"] is fp32 tensor
# optim.state[fp32_param]["master_param"] is None
```
This API is more natural for standard PyTorch workflows and it doesn't require maintaining separate model weights/master weights like in Megatron-LM. That said, I can see value in keeping master_weights as an optional kwarg since Megatron-LM already allocates them:
```python
model = MyModel()  # Mix of fp32, bf16, fp8 params
master_weights = [param.float() for param in model.parameters()]
optim = FusedAdam(model.parameters(), dtype=torch.float32, master_weights=master_weights)
# optim.state[param]["master_param"] is the corresponding tensor from master_weights
```
Hi @timmoon10, I have modified the FusedAdam API based on your suggestions. I have already tested the changes in Megatron-LM, and the training loss matches the previous results exactly. However, there are still some issues to discuss:
- I have required that `master_weights` be provided by the user as a list of tensors. If the user does not provide `master_weights` (i.e., `master_weights=None`), only the model weights are updated. Is this approach reasonable?
- In Megatron-LM, `master_weights` are created in the `__init__` method of the distributed optimizer (dist opt), while FusedAdam is created earlier. Therefore, I had to initially set `master_weights` to None, and then modify `optimizer.master_weights` in the `__init__` method of dist opt with the following code:
```python
# create optimizer
optimizer = FusedAdam(param_groups, ..., master_weights=None)
optimizer = DistributedOptimizer(optimizer, *other_args)

# inside __init__ of dist opt
master_weights = list(itertools.chain(*self.shard_fp32_from_float16_groups))
self.optimizer.master_weights = master_weights  # self.optimizer is FusedAdam
```
This usage is somewhat awkward, though not entirely unusual. Do you have any suggestions?
- Kunlun is currently implementing MX-FP16. After some discussion, we believe it is more reasonable to create `master_weights` inside FusedAdam: `exp_avg`, `exp_avg_sq`, and `master_weight` are all optimizer states, and since `exp_avg` and `exp_avg_sq` are created and updated within FusedAdam, `master_weight` should be handled the same way (see the sketch below). However, this change would also conflict with the design logic of Megatron.
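For reference, a minimal sketch of what creating the master weights inside the optimizer could look like (hypothetical code, not the existing FusedAdam): the fp32 `master_param` is allocated lazily in `self.state` on the first step, exactly like `exp_avg`/`exp_avg_sq`, and the low-precision model param is refreshed from it after each update.

```python
import torch
from torch.optim import Optimizer


class AdamWithMasterWeights(Optimizer):
    """Illustrative only: keeps an fp32 master copy of low-precision params as optimizer state."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, dict(lr=lr, betas=betas, eps=eps))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    # Allocate optimizer state lazily, just like exp_avg / exp_avg_sq.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
                    # fp32 master copy only for non-fp32 params; None otherwise.
                    state["master_param"] = (
                        p.detach().clone().float() if p.dtype != torch.float32 else None
                    )
                state["step"] += 1
                beta1, beta2 = group["betas"]
                grad = p.grad.float()
                state["exp_avg"].mul_(beta1).add_(grad, alpha=1 - beta1)
                state["exp_avg_sq"].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (state["exp_avg_sq"] / (1 - beta2 ** state["step"])).sqrt().add_(group["eps"])
                update = (state["exp_avg"] / (1 - beta1 ** state["step"])) / denom
                master = state["master_param"] if state["master_param"] is not None else p
                master.add_(update, alpha=-group["lr"])
                if state["master_param"] is not None:
                    # Cast the fp32 master weight back into the low-precision model param.
                    p.copy_(master.to(p.dtype))
```

In a setup like Megatron-LM, the externally sharded master weights would have to overwrite these lazily created state entries, which is exactly the design conflict mentioned above.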
@timmoon10 Could you please take a look?
/te-ci pytorch
/te-ci pytorch