TransformerEngine

[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization

Open xiaoxi-wangfj opened this issue 4 months ago • 4 comments

1. Fused moe_permute_with_probs + Fp8Padding and moe_unpermute + Fp8Unpadding, which removes the explicit padding/unpadding in the MoE experts module, improves performance, and reduces peak GPU memory usage. 2. Added tests for the fused permute/pad and unpermute/unpad operations.

Description

This PR optimizes FP8 MoE permute and pad operations by:

  1. Fusing moe_permute_with_probs + Fp8Padding into moe_permute_and_pad_with_probs
  2. Fusing moe_unpermute + Fp8Unpadding into moe_unpermute with pad_offsets argument
  3. As a result, removing the explicit padding/unpadding steps in the MoE experts module
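To make the semantics of the two fused steps concrete, here is a minimal pure-Python reference on nested lists. It is only a sketch under stated assumptions: the function names `permute_and_pad` and `unpermute_and_unpad` are hypothetical, routing probabilities are omitted (copies of a token are simply summed on the way back), and the padded layout (zero rows appended after each expert's chunk) is an illustrative choice, not a description of the TransformerEngine kernels.

```python
from typing import List, Tuple

Rows = List[List[float]]


def permute_and_pad(
    tokens: Rows, routing_map: List[List[bool]], align_size: int
) -> Tuple[Rows, List[int], List[int]]:
    """Group token rows by expert and pad each group to a multiple of align_size.

    routing_map[t][e] is True when token t is routed to expert e.
    Returns (rows, src_index, padded_counts); src_index holds the source
    token of each output row, or -1 for a padding row (hypothetical layout).
    """
    hidden = len(tokens[0])
    num_experts = len(routing_map[0])
    rows: Rows = []
    src_index: List[int] = []
    padded_counts: List[int] = []
    for e in range(num_experts):
        picked = [t for t in range(len(tokens)) if routing_map[t][e]]
        # Round the group size up to the next multiple of align_size.
        padded = -(-len(picked) // align_size) * align_size
        for t in picked:
            rows.append(list(tokens[t]))
            src_index.append(t)
        for _ in range(padded - len(picked)):
            rows.append([0.0] * hidden)
            src_index.append(-1)
        padded_counts.append(padded)
    return rows, src_index, padded_counts


def unpermute_and_unpad(rows: Rows, src_index: List[int], num_tokens: int) -> Rows:
    """Drop padding rows and sum every copy of a token back into its slot."""
    hidden = len(rows[0])
    restored = [[0.0] * hidden for _ in range(num_tokens)]
    for row, t in zip(rows, src_index):
        if t < 0:
            continue  # padding row: skipped instead of being sliced off separately
        for j, value in enumerate(row):
            restored[t][j] += value
    return restored


tokens = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
routing_map = [[True, False], [True, True], [False, True]]
rows, src_index, padded_counts = permute_and_pad(tokens, routing_map, align_size=4)
restored = unpermute_and_unpad(rows, src_index, num_tokens=len(tokens))
```

In the unfused flow, the padding and the later removal of padding rows are separate passes over the data; the sketch does each inside the same loop that moves the rows, which is the idea behind the fusion.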

Results:

  • 1.1x~1.6x speedup for fused permute-and-pad operations
  • 1.7x~3x speedup for fused unpermute-and-unpad operations (measured by tests/pytorch/test_permutation.py)
  • Verified in end-to-end FP8 model training with the Megatron framework: +0.4% MFU uplift and ~1GB peak GPU memory reduction in a typical ~600B-parameter setup.

Performance data

Tests covering a wide range of model training configurations were run to compare the fused operations ("Fused:") with the original version ("Orig:"). Running times (in milliseconds) are summarized in the table below, together with the speedup (the original running time divided by the fused running time). All tests were carried out with the tests/pytorch/test_permutation.py benchmark script.
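For reference, the speedup figure is plain arithmetic on a pair of timings. The snippet below is a small sketch; the helper name `speedup` and the timings are hypothetical and are not measurements from this PR.

```python
def speedup(orig_ms: float, fused_ms: float) -> float:
    """Speedup of the fused op: original running time over fused running time."""
    if orig_ms <= 0 or fused_ms <= 0:
        raise ValueError("running times must be positive")
    return orig_ms / fused_ms


# Hypothetical timings in milliseconds, for illustration only.
print(f"{speedup(3.0, 2.0):.2f}x")  # prints "1.50x"
```

A value above 1 means the fused operation is faster than the original.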

[Image: Fused-perm-pad]

The usage in Megatron-LM

  1. Megatron-LM/megatron/core/transformer/moe/moe_utils.py: Added support for fused operations

```python
# Added fused function import
from megatron.core.extensions.transformer_engine import (
    ...,
    fused_permute_and_pad_with_probs,  # [!code ++]
)


def permute(
    ...,
    tokens_per_expert: Optional[torch.Tensor] = None,  # [!code ++]
    align_size: int = -1,  # [!code ++]
):
    ...
    if fused and probs is not None:
        if not HAVE_TE or fused_permute_with_probs is None:
            raise ValueError(
                "fused_permute_with_probs is not available. Please install TE >= 2.1.0."
            )
        if tokens_per_expert is not None and align_size > 0:  # [!code ++]
            # Use fused permute+pad operation [!code ++]
            return fused_permute_and_pad_with_probs(tokens, probs, routing_map, tokens_per_expert, align_size)  # [!code ++]
        else:
            # Fallback to original implementation
            ...


def unpermute(
    ...,
    pad_offsets: Optional[torch.Tensor] = None,  # [!code ++]
):
    return fused_unpermute(
        ...,
        pad_offsets=pad_offsets,  # [!code ++]
    )
```
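The `tokens_per_expert` and `align_size` arguments above determine how many rows each expert receives after padding. The following sketch shows the rounding arithmetic only. The helper names are hypothetical, `align_size=16` is an illustrative value rather than the output of `get_fp8_align_size`, and treating a non-positive `align_size` as "no padding" is an assumption that mirrors the `align_size > 0` check in `permute`.

```python
from typing import List


def pad_to_multiple(count: int, align_size: int) -> int:
    """Round count up to the next multiple of align_size."""
    if align_size <= 0:
        # Assumption: non-positive align_size (the -1 default) means no padding.
        return count
    return (count + align_size - 1) // align_size * align_size


def padded_tokens_per_expert(tokens_per_expert: List[int], align_size: int) -> List[int]:
    """Per-expert row counts after padding each expert's chunk."""
    return [pad_to_multiple(count, align_size) for count in tokens_per_expert]


counts = [5, 16, 0, 33]
padded = padded_tokens_per_expert(counts, align_size=16)
extra_rows = [p - c for p, c in zip(padded, counts)]
print(padded)      # [16, 16, 0, 48]
print(extra_rows)  # [11, 0, 0, 15]
```

An expert whose count is already a multiple of the alignment, including an expert with no tokens, gets no extra rows in this sketch.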

  2. Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py: Scheduler integration

```python
class _DeepepManager(_DispatchManager):
    def __init__(...):
        self.pad_offsets = None  # [!code ++] Store padding offsets

    def get_permuted_hidden_states_by_experts(...):
        ...
        if self.config.moe_permute_padding_for_fp8:  # [!code ++]
            # Use fused path [!code ++]
            (                                                          # [!code ++]
                hidden_states,                                         # [!code ++]
                permuted_probs,                                        # [!code ++]
                self.reversed_mapping_for_combine,                     # [!code ++]
                self.pad_offsets,                                      # [!code ++]
                self.tokens_per_expert,                                # [!code ++]
            ) = permute(                                               # [!code ++]
                hidden_states,                                         # [!code ++]
                self.dispatched_routing_map,                           # [!code ++]
                probs=self.dispatched_probs,                           # [!code ++]
                fused=self.permute_fusion,                             # [!code ++]
                tokens_per_expert=self.tokens_per_expert,              # [!code ++]
                align_size=get_fp8_align_size(self.config.fp8_recipe), # [!code ++]
            )                                                          # [!code ++]
        else:
            # Original path
            ...

    def get_restored_hidden_states_by_experts(...):
        hidden_states = unpermute(
            ...,
            pad_offsets=self.pad_offsets if self.config.moe_permute_padding_for_fp8 else None,  # [!code ++]
        )
        ...
```

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or new content)

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactoring

Changes

  • Added the moe_permute_and_pad_with_probs API for fused permute and pad, and extended the moe_unpermute API with a pad_offsets argument for fused unpermute and unpad, both in transformer_engine/pytorch/permutation.py
  • Added tests in tests/pytorch/test_permutation.py

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [ ] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [ ] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] New and existing unit tests pass locally with my changes

xiaoxi-wangfj, Jul 03 '25 10:07