TransformerEngine
[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization
1. Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`, removing the explicit padding/unpadding in the MoE experts module, improving performance and reducing peak GPU memory usage.
2. Added tests for the fused permute/pad and unpermute/unpad operations.
Description
This PR optimizes FP8 MoE permute and pad operations by:
- Fusing `moe_permute_with_probs` + `Fp8Padding` into `moe_permute_and_pad_with_probs`
- Fusing `moe_unpermute` + `Fp8Unpadding` into `moe_unpermute` with a `pad_offsets` argument
- Thereby removing the explicit padding/unpadding steps in the MoE experts module (see the round-trip sketch after this list)
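For orientation, here is a minimal round-trip sketch of the two fused entry points. The argument names and the 5-tuple return ordering are inferred from the Megatron-LM usage shown later in this description and may not match the final API exactly.

```python
# Hypothetical round trip through the fused ops; signatures and return
# ordering are inferred from the Megatron-LM diff below, not authoritative.
import torch
from transformer_engine.pytorch.permutation import (
    moe_permute_and_pad_with_probs,  # new fused permute + FP8 pad
    moe_unpermute,                   # now accepts pad_offsets
)

num_tokens, hidden, num_experts, topk, align_size = 1024, 4096, 8, 2, 16
tokens = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
probs = torch.rand(num_tokens, num_experts, device="cuda")

# Boolean [num_tokens, num_experts] top-k routing mask.
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool, device="cuda")
routing_map.scatter_(1, probs.topk(topk, dim=1).indices, True)
tokens_per_expert = routing_map.sum(dim=0)

# One fused op: permute tokens by expert and pad every expert bucket to a
# multiple of align_size, replacing the separate Fp8Padding pass.
(permuted, permuted_probs, reverse_map, pad_offsets,
 padded_tokens_per_expert) = moe_permute_and_pad_with_probs(
    tokens, probs, routing_map, tokens_per_expert, align_size
)

# ... grouped expert GEMMs consume `permuted` here ...

# Fused unpermute + unpad: pad_offsets lets the kernel skip the padded rows.
restored = moe_unpermute(permuted, reverse_map, merging_probs=permuted_probs,
                         pad_offsets=pad_offsets)
```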
Results:
- 1.1x~1.6x speedup for the fused permute-and-pad operation
- 1.7x~3x speedup for the fused unpermute-and-unpad operation (measured by `tests/pytorch/test_permutation.py`)
- Verified in end-to-end FP8 model training with the Megatron framework: +0.4% MFU uplift and ~1 GB peak GPU memory reduction in a typical ~600B-parameter setup
Performance data
Tests covering a wide range of model training configurations were performed, comparing the fused operations ("Fused:") against the original version ("Orig:"). Running times (in milliseconds) are summarized in the table below, together with the speedup, computed as the original running time divided by the fused running time. All tests were carried out with the `tests/pytorch/test_permutation.py` benchmark script.
Usage in Megatron-LM
Megatron-LM/megatron/core/transformer/moe/moe_utils.py: Added Support for Fused Operations
```python
# Added fused function import
from megatron.core.extensions.transformer_engine import (
    ...,
    fused_permute_and_pad_with_probs,  # [!code ++]
)

def permute(
    ...,
    tokens_per_expert: Optional[torch.Tensor] = None,  # [!code ++]
    align_size: int = -1,  # [!code ++]
):
    ...
    if fused and probs is not None:
        if not HAVE_TE or fused_permute_with_probs is None:
            raise ValueError(
                "fused_permute_with_probs is not available. Please install TE >= 2.1.0."
            )
        if tokens_per_expert is not None and align_size > 0:  # [!code ++]
            # Use fused permute+pad operation  # [!code ++]
            return fused_permute_and_pad_with_probs(  # [!code ++]
                tokens, probs, routing_map, tokens_per_expert, align_size  # [!code ++]
            )  # [!code ++]
        else:
            # Fallback to original implementation
            ...

def unpermute(
    ...,
    pad_offsets: Optional[torch.Tensor] = None,  # [!code ++]
):
    ...
    return fused_unpermute(
        ...,
        pad_offsets=pad_offsets,  # [!code ++]
    )
```
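The `fused_permute_and_pad_with_probs` symbol imported above is presumably a thin guard-wrapper around the new TE function. A minimal sketch of what that shim could look like (this wrapper is an assumption for illustration, not part of the quoted diff):

```python
# Sketch of the Megatron-side shim (assumed, not shown in the diff above).
try:
    from transformer_engine.pytorch.permutation import moe_permute_and_pad_with_probs

    def fused_permute_and_pad_with_probs(
        tokens, probs, routing_map, tokens_per_expert, align_size
    ):
        """Permute tokens by expert and pad each expert bucket to a multiple
        of align_size in one pass; returns the reverse mapping and the
        pad_offsets later consumed by unpermute."""
        return moe_permute_and_pad_with_probs(
            tokens, probs, routing_map, tokens_per_expert, align_size
        )

except ImportError:
    # Older TE: callers detect None and fall back to the unfused path.
    fused_permute_and_pad_with_probs = None
```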
Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py: Scheduler Integration
```python
class _DeepepManager(_DispatchManager):
    def __init__(self, ...):
        ...
        self.pad_offsets = None  # [!code ++] Store padding offsets

    def get_permuted_hidden_states_by_experts(self, ...):
        ...
        if self.config.moe_permute_padding_for_fp8:  # [!code ++]
            # Use fused path  # [!code ++]
            (  # [!code ++]
                hidden_states,  # [!code ++]
                permuted_probs,  # [!code ++]
                self.reversed_mapping_for_combine,  # [!code ++]
                self.pad_offsets,  # [!code ++]
                self.tokens_per_expert,  # [!code ++]
            ) = permute(  # [!code ++]
                hidden_states,  # [!code ++]
                self.dispatched_routing_map,  # [!code ++]
                probs=self.dispatched_probs,  # [!code ++]
                fused=self.permute_fusion,  # [!code ++]
                tokens_per_expert=self.tokens_per_expert,  # [!code ++]
                align_size=get_fp8_align_size(self.config.fp8_recipe),  # [!code ++]
            )  # [!code ++]
        else:
            # Original path
            ...

    def get_restored_hidden_states_by_experts(self, ...):
        ...
        hidden_states = unpermute(
            ...,
            pad_offsets=self.pad_offsets if self.config.moe_permute_padding_for_fp8 else None,  # [!code ++]
        )
        ...
```
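`get_fp8_align_size` is referenced above but not shown in the diff. A plausible sketch, assuming FP8 GEMMs require per-expert token counts divisible by 16 while MXFP8 block scaling requires 32 (the actual helper in Megatron-LM may differ):

```python
def get_fp8_align_size(fp8_recipe) -> int:
    """Alignment (in tokens) that each expert bucket is padded to.

    Assumption: MXFP8 recipes scale in 32-element blocks, so pad to 32;
    other FP8 recipes only need GEMM dimensions divisible by 16.
    """
    return 32 if "mxfp8" in str(fp8_recipe).lower() else 16
```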
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactoring
Changes
- Added the `moe_permute_and_pad_with_probs` API for fused permute and pad, and modified the `moe_unpermute` API with a `pad_offsets` argument for fused unpermute and unpad, in `transformer_engine/pytorch/permutation.py`
- Added tests in `tests/pytorch/test_permutation.py` (a sketch of the equivalence these tests check follows this list)
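A condensed sketch of the equivalence those tests check: the fused op should reproduce permute followed by explicit FP8 padding. The `moe_permute_with_probs` and `Fp8Padding` signatures and return orderings below are assumptions based on TE's existing API and this PR's description.

```python
# Sketch of the fused-vs-unfused equivalence exercised by the new tests.
# Argument names and return orderings are assumptions, not authoritative.
import torch
from transformer_engine.pytorch import Fp8Padding
from transformer_engine.pytorch.permutation import (
    moe_permute_with_probs,
    moe_permute_and_pad_with_probs,
)

def check_fused_matches_unfused(tokens, probs, routing_map,
                                tokens_per_expert, align_size):
    num_experts = routing_map.shape[1]

    # Reference path: permute first, then pad each expert bucket explicitly
    # (assuming Fp8Padding pads to the same alignment as align_size).
    ref, ref_probs, _ = moe_permute_with_probs(tokens, probs, routing_map)
    ref_padded, _ = Fp8Padding(num_experts)(ref, tokens_per_expert.tolist())

    # Fused path under test.
    out, out_probs, _, _, _ = moe_permute_and_pad_with_probs(
        tokens, probs, routing_map, tokens_per_expert, align_size
    )
    torch.testing.assert_close(out, ref_padded)
```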
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes