[BUG] Token routing probability all-gather precision in token_dispatcher causes differing results between EP ranks
Describe the bug I am running a small MoE model on two H100s with 2-way EP (DP) and 1-way TP/PP. I feed the same token sequence into the model on both DP ranks and expect the activations after the MoE layer to be identical, since the tokens from both sequences should be routed to the same experts.
However, the outputs after token unpermutation differ between ranks by 0.1% to 2%. This is concerning because the error continues to compound over many layers.
I've tracked the source of the error to the communication of the routing probabilities. Specifically, converting the routing probabilities to torch.float32 before performing the all-gather eliminates this error.
To Reproduce
- Instantiate an MoE model with 2-way EP/DP sharding using MoEAllGatherTokenDispatcher.
- Feed the same tokens into both DP shards.
- Record the MoE outputs on both GPUs and observe that they differ (a sketch of this check follows the list).
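A hedged sketch of the comparison step (not the exact script I used): each rank runs the same tokens through the MoE block, then the outputs are all-gathered so every rank can check whether they are bit-identical across the world group. The helper name is hypothetical.

import torch
import torch.distributed as dist


def outputs_match_across_ranks(moe_output: torch.Tensor) -> bool:
    """All-gather `moe_output` over WORLD and compare every rank against rank 0."""
    world_size = dist.get_world_size()
    gathered = torch.empty(
        world_size, *moe_output.shape, device=moe_output.device, dtype=moe_output.dtype
    )
    dist.all_gather_into_tensor(gathered, moe_output.contiguous())
    return all(torch.equal(gathered[0], gathered[r]) for r in range(1, world_size))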
Expected behavior
The MoE output should be identical on both DP ranks.
Stack trace/logs N/A
Environment (please complete the following information):
- Megatron-LM commit ID: aa719a0b0145481fb9212c577ee9a3f000fd16da + internal patches
- PyTorch version: 2.5.1
- CUDA version: 12.2
- NCCL version: 2.21.5
Proposed fix
Add probs = probs.to(torch.float32) before the all-gather on this line to perform the all-gather in float32.
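A minimal sketch of the proposed change, assuming probs holds the bf16 routing probabilities right before the dispatcher's all-gather (the helper below and its name are hypothetical, not the actual Megatron-LM call site):

import torch
import torch.distributed as dist


def all_gather_probs_fp32(probs: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather routing probabilities after promoting them to float32."""
    probs = probs.to(torch.float32)  # proposed one-line fix: communicate and combine in fp32
    world_size = dist.get_world_size(group=group)
    gathered = torch.empty(
        world_size * probs.shape[0], *probs.shape[1:],
        device=probs.device, dtype=probs.dtype,
    )
    dist.all_gather_into_tensor(gathered, probs.contiguous(), group=group)
    return gathered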
Additional context N/A
Thanks for reporting the issue!
I suspect the discrepancy is due to different accumulation orders in the reduction during token combination. We've received feedback from other customers suggesting that this reduction should use fp32.
There's an internal MR to promote the routing and weighted-averaging data types to prevent precision loss, which we'll merge ASAP. Update: the MR has been merged.
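A rough single-process illustration of the accumulation-order point above (an assumption about the mechanism, not code from the Megatron-LM source): summing the same top-k expert contributions in two different orders gives bit-different results in bf16, while accumulating in fp32 keeps the combine far more consistent.

import torch

torch.manual_seed(0)
expert_out = torch.randn(6, 4096).bfloat16()   # top-6 expert outputs for one token
probs = torch.rand(6).bfloat16()
probs = probs / probs.sum()                    # normalized routing weights

order_a = [0, 1, 2, 3, 4, 5]                   # one accumulation order
order_b = [5, 3, 1, 0, 2, 4]                   # another accumulation order

def combine(order, dtype):
    # Weighted sum of expert outputs, accumulated in `dtype` in the given order
    acc = torch.zeros(4096, dtype=dtype)
    for i in order:
        acc += probs[i].to(dtype) * expert_out[i].to(dtype)
    return acc

bf16_diff = (combine(order_a, torch.bfloat16) - combine(order_b, torch.bfloat16)).abs().max()
fp32_diff = (combine(order_a, torch.float32) - combine(order_b, torch.float32)).abs().max()
print(f"bf16 max |diff| across orders: {bf16_diff.item():.3e}")  # usually non-zero
print(f"fp32 max |diff| across orders: {fp32_diff.item():.3e}")  # usually far smaller, often 0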
Hi @yanring , after manually setting the probs to fp64, I still have precision issues with EP. Do you have any suggestion on what else needs to be promoted to higher precision? Thanks!
I assume you're using the arg --moe-router-dtype, right? Did you also enable TP? By the way, could you try the allgather dispatcher to see if the issue still exists?
Hi @yanring Yeah, I set moe-router-dtype to fp64. I ran the following test on the master branch, but it failed.
Basically, the test checks whether the output is the same for a 1024-length random sequence and its truncated version (setting input[90:] = 0):
input = torch.randn(1024, 1, 4096).cuda().bfloat16()
input_trunc90 = input.clone()
input_trunc90[90:] = 0
...
assert torch.allclose(notrunc_output[:90], trunc_output[:90], atol=1e-3, rtol=1e-3)
Can you please help take a look? Thanks!
test.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import time
import pytest
import torch
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.training.initialize import _set_random_seed
from tests.unit_tests.test_utilities import Utils
class TestMoELayerDispatcherDiscrepancy:
    def setup_method(self, method):
        pass

    # @pytest.mark.parametrize("num_moe_experts", [8, 16, 32, 64])
    @pytest.mark.parametrize("num_moe_experts", [64])
    @pytest.mark.parametrize("grouped_gemm", [False])
    # @pytest.mark.parametrize("tp_size,ep_size", [(4, 1)])
    # @pytest.mark.parametrize("tp_size,ep_size", [(1, 1)])
    @pytest.mark.parametrize("tp_size,ep_size", [(1, 4)])
    # @pytest.mark.parametrize("tp_size,ep_size", [(1, 1), (1, 2), (1, 4)])
    @pytest.mark.parametrize("bf16", [True])
    # @pytest.mark.parametrize("bf16", [True, False])
    # @pytest.mark.parametrize("moe_length", [1])
    @pytest.mark.parametrize("moe_length", [4])
    # @pytest.mark.parametrize("moe_length", [1, 2, 3, 4])
    @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"])
    @pytest.mark.internal
    def test_truncation(
        self, num_moe_experts, grouped_gemm, tp_size, ep_size, bf16, moe_length, moe_token_dispatcher_type
    ):
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
        )

        # Init input and layer
        _set_random_seed(seed_=123, data_parallel_random_init=False)
        input = torch.randn(1024, 1, 4096).cuda().bfloat16()
        input_trunc90 = input.clone()
        input_trunc90[90:] = 0

        self.transformer_config = TransformerConfig(
            num_layers=1,
            hidden_size=4096,
            num_attention_heads=16,
            num_moe_experts=num_moe_experts,
            use_cpu_initialization=False,
            moe_token_dispatcher_type=moe_token_dispatcher_type,
            # moe_shared_expert_intermediate_size =
            moe_router_topk=6,
            moe_aux_loss_coeff=0.01,
            moe_grouped_gemm=grouped_gemm,
            moe_router_dtype="fp64",
            add_bias_linear=False,
            tensor_model_parallel_size=tp_size,
            expert_model_parallel_size=ep_size,
            sequence_parallel=False,
            bf16=bf16,
        )
        transformer_layer_spec = get_gpt_layer_local_spec(
            num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
        )
        layer = (
            TransformerLayer(self.transformer_config, transformer_layer_spec.submodules)
            .cuda()
            .bfloat16()
        )
        # Apply the same MoE layer moe_length times so any discrepancy can compound
        # moe_layer = [layer.mlp, layer.mlp2, layer.mlp3, layer.mlp4]
        moe_layer = [layer.mlp] * moe_length
        # moe_layer2 = layer.mlp2
        # moe_layer3 = layer.mlp3
        # moe_layer4 = layer.mlp4
        layer.eval()

        # Test allgather dispatcher
        import sys

        for idx in range(moe_length):
            moe_layer[idx].config.moe_token_dispatcher_type = "allgather"
        with torch.no_grad():
            # notrunc_output = moe_layer(input)[0]
            xx = input
            for idx in range(moe_length):
                xx = moe_layer[idx](xx)
                # sys.stderr.write(f"notrunc: xx shape: {xx.shape}\n")
                xx = xx[0]
            notrunc_output = xx

        # Allgather the output to check if it's the same on all ranks
        notrunc_output_ag_shape = (torch.distributed.get_world_size(), *(notrunc_output.shape))
        notrunc_output_ag = torch.zeros(
            notrunc_output_ag_shape, device=notrunc_output.device, dtype=notrunc_output.dtype
        )
        torch.distributed.all_gather_into_tensor(
            notrunc_output_ag, notrunc_output, group=torch.distributed.group.WORLD
        )
        # Check if output is the same across all ranks
        if parallel_state.get_data_parallel_rank() == 0:
            for i in range(1, parallel_state.get_tensor_model_parallel_world_size()):
                if not torch.equal(notrunc_output_ag[0], notrunc_output_ag[i]):
                    print(f"Allgather output differs at rank {torch.distributed.get_rank()}")
                    raise ValueError(f"Allgather output differs at rank {i}")
            print("Allgather output is the same across all ranks", flush=True)
        torch.cuda.synchronize()

        # Test alltoall dispatcher
        # moe_layer.config.moe_token_dispatcher_type = "alltoall"
        with torch.no_grad():
            xx = input_trunc90
            # for _ in range(4):
            for idx in range(moe_length):
                xx = moe_layer[idx](xx)
                # sys.stderr.write(f"notrunc: xx shape: {xx.shape}\n")
                xx = xx[0]
            # sys.stderr.write(f"trunc: xx shape: {xx.shape}\n")
            trunc_output = xx

        # Allgather the output to check if it's the same in all ranks
        trunc_output_ag_shape = (torch.distributed.get_world_size(), *(trunc_output.shape))
        trunc_output_ag = torch.zeros(
            trunc_output_ag_shape, device=trunc_output.device, dtype=trunc_output.dtype
        )
        torch.distributed.all_gather_into_tensor(
            trunc_output_ag, trunc_output, group=torch.distributed.group.WORLD
        )
        # Check if output is the same across all ranks
        if parallel_state.get_data_parallel_rank() == 0:
            for i in range(1, parallel_state.get_tensor_model_parallel_world_size()):
                if not torch.equal(trunc_output_ag[0], trunc_output_ag[i]):
                    print(f"A2A output differs at rank {torch.distributed.get_rank()}")
                    raise ValueError(f"A2A output differs at rank {i}")
            print("A2A output is the same across all ranks", flush=True)
        torch.cuda.synchronize()

        # if torch.distributed.get_rank() == 0:
        #     from IPython import embed; embed()
        # else:
        #     import time; time.sleep(1000000)

        # The shared prefix of the truncated and untruncated runs should match
        assert torch.allclose(notrunc_output[:90], trunc_output[:90], atol=1e-3, rtol=1e-3)
        print("Allgather and A2A output is the same", flush=True)

        Utils.destroy_model_parallel()
Marking as stale. No activity in 60 days.
This issue was closed because it has been inactive for 7 days since being marked as stale.
@Ir1d if you are still facing precision problems with EP, please feel free to create a new issue with a repro