Question about the performance of GroupedLinear
Issue
When testing the performance of DeepSeek-V2-Lite with Megatron and TransformerEngine, I encountered an issue where GroupedLinear exhibits unusually long durations. The TEGroupedLinear forward operation typically takes about 1 ms in the nsys timeline, but there are anomalous events that exceed 200 ms. What could be causing this?
Environment
- megatron-core r0.10.0
- TransformerEngine 1.13.0+e5edd6c
- image: nvcr.io/nvidia/pytorch:24.07-py3 + cudnn-9.5.1
- 8xH800
Duration of GroupedLinear event
I cannot provide the timeline for some reason. The following table gives the durations of abnormal and normal events extracted from the nsys timeline. Why is there such a large difference in duration between nvte_multi_stream_cublas_gemm and TERowParallelGroupedLinear? And why does the start time of the abnormal nvte_multi_stream_cublas_gemm event lag behind the start time of TERowParallelGroupedLinear by about 200 ms?
Also, if I directly run TEGroupedLinear with input tensors of the same shape in a micro-benchmark, the time consumption returns to normal. Is the training workflow affecting the execution efficiency of the kernel?
| Name | Start | Duration | TID |
|---|---|---|---|
| #TEGroupLinear forward | 3.67097s | 206.707 ms | 179103 |
| ##TERowParallelGroupedLinear forward | 3.67099s | 206.660 ms | 179103 |
| ###nvte_multi_stream_cublas_gemm | 3.87704s | 387.909 μs | 179103 |
| #TEGroupLinear forward | 1.4103s | 3.373 ms | 179103 |
| ##TERowParallelGroupedLinear forward | 1.41032s | 3.327 ms | 179103 |
| ###nvte_multi_stream_cublas_gemm | 1.41077s | 1.008 ms | 179103 |
| #TEGroupLinear forward | 2.58523s | 3.103 ms | 179103 |
| ##TERowParallelGroupedLinear forward | 2.58525s | 3.055 ms | 179103 |
| ###nvte_multi_stream_cublas_gemm | 2.58579s | 1.128 ms | 179103 |
Optimization for Hopper?
Also, is there a plan to optimize GroupedLinear for the Hopper architecture? With the DeepSeek-V2 parameters, the TFLOPS on H800 show no significant improvement over A800, and overall performance is quite poor. The test results and code are as follows:
# H800
Average execution time: 0.0011188620805740357 s, tflops: 253.35369430928066
Average execution time: 0.001063387517929077 s, tflops: 133.2852966376957
# A800
Average execution time: 0.0018983731222152712 s, tflops: 149.32145752527958
Average execution time: 0.0013353574371337891 s, tflops: 106.13931283613297
from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.parallel_state import set_tensor_model_parallel_world_size
import torch
from typing import Callable
def benchmark(benchmark_func: Callable, warmup_times = 10, benchmark_times = 50):
warmup_elapsed_time_list = []
benchmark_elapsed_time_list = []
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
for _ in range(warmup_times):
start_event.record()
benchmark_func()
end_event.record()
torch.cuda.synchronize()
warmup_elapsed_time_list.append(start_event.elapsed_time(end_event))
for _ in range(benchmark_times):
start_event.record()
benchmark_func()
end_event.record()
torch.cuda.synchronize()
benchmark_elapsed_time_list.append(start_event.elapsed_time(end_event))
# print(f"warmup_elapsed_time: {warmup_elapsed_time_list}, benchmark_elapsed_time_list: {benchmark_elapsed_time_list}")
avg_secs = sum(benchmark_elapsed_time_list) / benchmark_times / 1e3
return avg_secs
def test_dsv2_lite_ep1():
num_local_experts = 64
hidden_size = 2048
moe_ffn_hidden_size = 1408
topk = 6
seqlen = 4096
# m_splits = [8, 83, 945, 0, 162, 4, 3, 0, 251, 510, 24, 140, 37, 0, 33, 10, 1, 0, 5, 0, 115, 0, 1, 1, 1, 188, 43, 1, 7, 0, 12, 0, 324, 5, 88, 0, 0, 58, 558, 219, 1296, 1155, 2, 1102, 6, 0, 115, 0, 106, 0, 10, 0, 698, 2, 594, 221, 351, 0, 1, 2, 2040, 0, 7, 743]
num_even_tokens = seqlen * topk // num_local_experts
m_splits = [num_even_tokens for _ in range(num_local_experts)]
config = TransformerConfig(num_attention_heads=1, num_layers=1)
config.params_dtype = torch.bfloat16
config.use_cpu_initialization = False
config.add_bias_linear = False
config.gradient_accumulation_fusion = True
# hack parallel states
set_tensor_model_parallel_world_size(1)
# column up linear
linear_fc1 = TEColumnParallelGroupedLinear(
num_gemms=num_local_experts,
input_size=hidden_size,
output_size=moe_ffn_hidden_size*2,
config=config,
init_method=config.init_method,
bias=config.add_bias_linear,
skip_bias_add=True,
is_expert=True,
tp_comm_buffer_name='fc1',
)
linear_fc2 = TEColumnParallelGroupedLinear(
num_gemms=num_local_experts,
input_size=moe_ffn_hidden_size,
output_size=hidden_size,
config=config,
init_method=config.init_method,
bias=config.add_bias_linear,
skip_bias_add=True,
is_expert=True,
tp_comm_buffer_name='fc2',
)
up_inputs = torch.randn((topk*seqlen, hidden_size), dtype=torch.bfloat16, device='cuda')
down_inputs = torch.randn((topk*seqlen, moe_ffn_hidden_size), dtype=torch.bfloat16, device='cuda')
def up_linear():
linear_fc1(up_inputs, m_splits)
def down_linear():
linear_fc2(down_inputs, m_splits)
avg_secs = benchmark(up_linear)
tflops = 2 * seqlen * topk * hidden_size * 2 * moe_ffn_hidden_size / avg_secs / 1e12
print(f"Average execution time: {avg_secs} s, tflops: {tflops}")
avg_secs = benchmark(down_linear)
tflops = 2 * seqlen * topk * moe_ffn_hidden_size * hidden_size / avg_secs / 1e12
print(f"Average execution time: {avg_secs} s, tflops: {tflops}")
test_dsv2_lite_ep1()
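For reference, the printed TFLOPS can be checked by hand from the formula in the script; a minimal sanity check of the first H800 number (values taken from the output above):
# Sanity check of the first H800 line above: FC1 FLOPs / measured time.
flops_fc1 = 2 * (4096 * 6) * 2048 * (2 * 1408)   # 2 * tokens * hidden_size * (2 * moe_ffn_hidden_size) ~ 2.8e11 FLOP
print(flops_fc1 / 0.0011188620805740357 / 1e12)  # -> ~253.35 TFLOPS, matching the H800 result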
Can I summarize your questions into the following two?
- Q: Why is there such a large difference in duration between `nvte_multi_stream_cublas_gemm` and `TERowParallelGroupedLinear`?
  A: It's the CPU overhead of PyTorch ops, such as `torch.split()` and `torch.empty()` (2 x `num_gemms` calls) under `fused_multi_cast_transpose`. You can capture them in nsys by adding the context `with torch.autograd.profiler.emit_nvtx(enabled=True)` to your code during profiling. It's not trivial to eliminate these overheads.
- Q: There are some abnormal iterations that consume much more time than usual, while in the micro-benchmark there is no problem.
  A: I have no idea about this issue. Maybe you can enable NVTX for torch ops using the context mentioned above and see what it is actually doing there.
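For completeness, a minimal sketch of the profiling context mentioned above. The training_step() function is a placeholder for your own forward/backward code, and the capture-range pairing with nsys is an assumption about how the profiler is launched:
import torch

def training_step():
    # Placeholder for your actual forward/backward step.
    pass

# Emit an NVTX range for every PyTorch op so host-side overheads
# (torch.split, torch.empty, ...) show up in the nsys timeline.
with torch.autograd.profiler.emit_nvtx(enabled=True):
    torch.cuda.profiler.start()   # pairs with `nsys profile --capture-range=cudaProfilerApi`
    for _ in range(10):
        training_step()
    torch.cuda.profiler.stop()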
I've had a similar problem:
environment:
- megatron-core r0.10.0
- TransformerEngine 1.12.0+7f2afaaa
- image: nvcr.io/nvidia/pytorch:24.07-py3
- 8xH800
phenomenon:
- During DeepSeek-V2-Lite training, te.pytorch.GroupedLinear.forward often shows a performance decay of 10% or more, resulting in frequent sharp jitter in training throughput; the backward pass jitters much less.
- Testing the DeepSeek-V2-Lite MoE module with input [b, s] = [6, 4096] for 10k forward passes, a performance decay of more than 30% appears about 40 times, while for the Mixtral 8x7B MoE layer it appears only 3 times per 10k forward passes.
- This phenomenon does not appear on A800 in the same environment.
Could it be a load-balance issue in MoE training? Can you try drop-and-pad by setting
moe_token_drop_policy="probs"
moe_expert_capacity_factor=1.0
moe_pad_expert_input_to_capacity=True
and see if the issue still happens?
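For concreteness, a sketch of applying the suggested drop-and-pad settings on a Megatron-core TransformerConfig, following the style of the benchmark script above (the field names are the ones quoted in this reply; the rest of the config is assumed):
from megatron.core.transformer.transformer_config import TransformerConfig

# Hedged sketch: drop-and-pad gives every expert the same number of tokens,
# removing load imbalance as a variable when timing GroupedLinear.
config = TransformerConfig(num_attention_heads=1, num_layers=1)
config.moe_token_drop_policy = "probs"            # drop lowest-probability tokens first
config.moe_expert_capacity_factor = 1.0           # capacity = average tokens per expert
config.moe_pad_expert_input_to_capacity = True    # pad each expert's input up to the capacity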
- I tested transformer_engine.pytorch.GroupedLinear individually with the tokens split evenly, and the problem above still occurs.
- I found that the inefficient computations always seem to happen in consecutive iterations.
- Below is my test script
import torch
from transformer_engine.pytorch import GroupedLinear
import time
from tqdm import tqdm
seqlen = 4096
bsz = 6
topk = 6
tokens = seqlen * bsz * topk
num_gemms = 64
assert tokens % num_gemms == 0
in_features = 2048
out_features = 1408 * 2
params_dtype = torch.bfloat16
torch.cuda.set_device(0)
device = torch.cuda.current_device()
bias = False
m_splits = [tokens // num_gemms for _ in range(num_gemms)]
inp = torch.randn([tokens, 1, in_features], dtype=params_dtype, device=device)
flops = tokens * in_features * out_features * 1.5 * 2
linear1 = GroupedLinear(
num_gemms,
in_features,
out_features,
bias=bias,
params_dtype=params_dtype,
parallel_mode="row",
device="cuda",
)
linear2 = GroupedLinear(
num_gemms,
out_features,
in_features,
bias=bias,
params_dtype=params_dtype,
parallel_mode="column",
device="cuda",
)
test_turns = 10000
tflops_list = []
for i in tqdm(range(test_turns)):
t0 = time.time()
out1 = linear1(inp, m_splits)
out2 = linear2(out1, m_splits)
torch.cuda.synchronize()
cost = time.time() - t0
tflops_list.append(flops / cost / 1e12)
sorted_tflops, sorted_idx = torch.tensor(tflops_list).sort()
avg_peak = sorted_tflops[-100:].mean().item()
low_eff_ids = sorted_idx.masked_select(sorted_tflops < avg_peak * 0.7)
print(f"###DEBUG### avg peak: {avg_peak}, low eff num: {len(low_eff_ids)}/{test_turns}, low eff idx: {sorted(low_eff_ids.tolist())}")
I can reproduce the issue. I ran two GroupedLinear layers for 1000 iterations using the above snippet and observed 3 slowdowns, at iterations 357, 533, and 710. I then checked the nsys timeline and found large CPU gaps (blank spaces) in these slow iterations. From the sampling points, it seems the root cause in all cases is __libc_malloc.
- Iter 357
- Iter 533
- Iter 710
Any ideas here? @timmoon10
@yaox12
Unsure if related, but I also notice similar slowdowns in GroupedLinear. A trace of the execution shows that, intermittently, there is a slowdown on the torch.split call.
Thanks for reporting the CPU overhead of GroupedLinear. We have observed CPU overhead when the GEMM sizes are small and the number of local experts per rank is large, such that the CPU overhead introduced by torch.split, named_parameters, and even isinstance calls becomes noticeable in the timeline. For example, having N local experts means calling isinstance N times more often, and calling such Python built-in methods is becoming a problem for MoE training.
We are working on fixes to make it less likely that this CPU overhead is exposed. In the meantime, we encourage exploring configurations such as EP, top-k, sequence length, and micro-batch size to increase the GEMM sizes and avoid exposing the CPU overhead.
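As a rough illustration of the host-side cost described above (a standalone sketch, not TransformerEngine's actual code path), timing torch.split plus per-chunk isinstance checks for 64 expert-sized splits:
import time
import torch

num_gemms = 64                  # local experts per rank, as in the reports above
m_splits = [384] * num_gemms    # even split of 24576 tokens across 64 experts
inp = torch.randn(sum(m_splits), 2048, dtype=torch.bfloat16, device="cuda")

iters = 1000
t0 = time.time()
for _ in range(iters):
    chunks = torch.split(inp, m_splits)                       # creates num_gemms views on the host
    checks = [isinstance(c, torch.Tensor) for c in chunks]    # num_gemms Python isinstance calls
elapsed_us = (time.time() - t0) / iters * 1e6
print(f"host-side overhead per iteration: {elapsed_us:.1f} us")  # pure CPU cost, no GEMM launched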