
Performance Issue with NVIDIA Transformer Engine FP8 Linear Functions on L20

Status: Open · cslvjt opened this issue 7 months ago • 8 comments

issue

When testing the linear API provided by NVIDIA's Transformer Engine (with FP8 precision) on an L20 GPU, I found that it is significantly slower than PyTorch's built-in linear layer. Is there something I might have missed in the configuration, or are additional optimization settings required? Any suggestions or guidance would be appreciated.

code

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.common import recipe
import time

in_features = 768
out_features = 3072
hidden_size = 2048
iters = 5

with fp8_model_init(enabled=True):
    model = te.Linear(in_features, out_features, bias=True, device="cuda")
inp = torch.randn(hidden_size, in_features, device="cuda")
for name, param in model.named_parameters():
    print(f"{name}: {param.dtype}")
torch_model = torch.nn.Linear(in_features, out_features, bias=True).cuda()
for name, param in torch_model.named_parameters():
    print(f"{name}: {param.dtype}")

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

for _ in range(5):
    out = torch_model(inp)

times = []
with torch.no_grad():
    for i in range(iters):
        start = time.time()
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            out = model(inp)  # renamed `input` to `inp`
        torch.cuda.synchronize()
        end = time.time()
        times.append(end - start)
avg_time_te = sum(times) / len(times)
print(f"te linear average costime: {avg_time_te:.6f} seconds")

torch_times = []
for i in range(iters):
    start = time.time()
    with torch.no_grad():
        out = torch_model(inp)
    torch.cuda.synchronize()
    end = time.time()
    torch_times.append(end - start)
avg_time_torch = sum(torch_times) / len(torch_times)
print(f"torch linear average costime: {avg_time_torch:.6f} seconds")
```

result

[Image: results, not reproduced]

environment

[Image: environment details, not reproduced]

cslvjt · Aug 27 '25

cuDNN version is 9.9.0

cslvjt · Aug 27 '25

transformer_engine version is 2.2.0+c55e425

cslvjt · Aug 27 '25

Hello everyone, I've found another speed issue: the speedup of TE's linear layer over PyTorch's linear layer changes as the input shape changes. Here is the code:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.common import recipe
import time
import csv

def run_once(in_features, out_features, batch_size, iters=10):
    # Build a TE linear layer, an FP16 input, and an equivalent PyTorch linear layer
    model = te.Linear(in_features, out_features, bias=True, device="cuda").half()
    inp = torch.randn(batch_size, in_features, device="cuda").half()
    torch_model = torch.nn.Linear(in_features, out_features, bias=True).cuda().half()
    
    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

    # Warm-up: is_first_microbatch=True makes TE compute and cache the FP8 copy of the weights
    for _ in range(10):
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            out = model(inp, is_first_microbatch=True)

    # Timed TE runs: is_first_microbatch=False reuses the cached FP8 weights
    times = []
    with torch.no_grad():
        for i in range(iters):
            start = time.time()
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                out = model(inp, is_first_microbatch=False)
            torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
    avg_time_te = sum(times) / len(times)

    # Timed PyTorch FP16 runs
    torch_times = []
    for i in range(iters):
        start = time.time()
        with torch.no_grad():
            out = torch_model(inp)
        torch.cuda.synchronize()
        end = time.time()
        torch_times.append(end - start)
    avg_time_torch = sum(torch_times) / len(torch_times)

    return avg_time_te, avg_time_torch

if __name__ == "__main__":
    in_feature_list = [1280, 5120, 10240, 20480]
    out_features_list = [1280, 5120, 10240, 20480]
    batch_size_list = [128, 256, 512]

    results = []
    for in_features in in_feature_list:
        for out_features in out_features_list:
            for batch_size in batch_size_list:
                print(f"=============Test in_features X out_features X batch_size {in_features}x{out_features}x{batch_size}")
                te_time, torch_time = run_once(in_features, out_features, batch_size)
                print(f"te linear average costime: {te_time*1000:.6f} ms")
                print(f"torch linear average costime: {torch_time*1000:.6f} ms")
                results.append([in_features, out_features, batch_size, te_time*1000, torch_time*1000])

    # Save all timings (in ms) to a CSV file
    with open("results.csv", "w", newline='') as csvfile:
        writer = csv.writer(csvfile)
        # Header row, then one row per configuration
        writer.writerow(["in_features", "out_features", "batch_size", "te_time", "torch_time"])
        writer.writerows(results)

This is the result:

[Image: results, not reproduced]

cslvjt · Aug 28 '25

It looks like there are two issues:

  • Overheads from FP8 casts. When the problem is small, the cost of casting to FP8 outweighs the speedup from FP8 matmuls. Consider running with larger models or larger batch sizes. Also consider using fused kernels that combine the FP8 cast with another operation, e.g. te.LayerNormLinear, te.TransformerLayer, or the experimental operation fuser. To avoid repeatedly casting the weights to FP8, try storing the weights directly in FP8 by initializing the model within an fp8_model_init context.
  • CPU overhead. te.Linear is more complicated than torch.nn.Linear, and the extra logic and kernel launches can leave the runtime limited by the CPU rather than the GPU. Your benchmark is a worst-case scenario for TE because it synchronizes the GPU at every step, whereas in practice the CPU overheads can be hidden behind GPU compute from previous layers. That said, minimizing CPU overhead has been a continuous struggle and is a major priority for future optimizations.
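To see why per-step synchronization is a worst case, here is a toy timeline model in plain Python (not TE code; the function name `simulate_timeline` and the cost numbers are made up for illustration). Each step costs the CPU `cpu_launch` time units to issue work and the GPU `gpu_kernel` units to execute it. The GPU cannot start a kernel before it has been launched, and with `sync_every_step=True` the CPU waits for the GPU after every step, like calling `torch.cuda.synchronize()` inside the loop.

```python
def simulate_timeline(n_steps: int, cpu_launch: float, gpu_kernel: float,
                      sync_every_step: bool) -> float:
    """Toy model: total wall time for n_steps of (CPU launch -> GPU kernel)."""
    cpu_t = 0.0  # time at which the CPU has finished issuing the current step
    gpu_t = 0.0  # time at which the GPU has finished all work issued so far
    for _ in range(n_steps):
        cpu_t += cpu_launch                # Python logic + kernel launches
        gpu_start = max(cpu_t, gpu_t)      # kernel runs after its launch and after prior work
        gpu_t = gpu_start + gpu_kernel
        if sync_every_step:
            cpu_t = gpu_t                  # like torch.cuda.synchronize(): CPU waits for GPU
    return max(cpu_t, gpu_t)


for gpu_kernel in (1, 5):
    for sync in (True, False):
        total = simulate_timeline(10, cpu_launch=3, gpu_kernel=gpu_kernel,
                                  sync_every_step=sync)
        print(f"gpu_kernel={gpu_kernel} sync_every_step={sync}: total={total}")
```

With a 3-unit launch cost and 1-unit kernels, ten steps take 40 units when syncing every step versus 31 when launches run ahead, and the run is CPU-bound either way. With 5-unit kernels the asynchronous run takes 53 units, so almost all CPU time is hidden behind the 50 units of GPU work, while syncing every step still pays 80. Real launch queues and multi-stream execution are more complex than this model.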

timmoon10 · Aug 30 '25

This issue looks similar to this one: https://github.com/NVIDIA/TransformerEngine/issues/2053

First, I agree with Tim that the benchmark script calls torch.cuda.synchronize() after every forward pass. This forces synchronous execution between CPU and GPU, whereas in practice we can leverage the asynchronous nature of CPU/GPU execution so that CPU-side time is overlapped. So I recommend running the whole loop, synchronizing once, and then recording the end time. In training, for example, while the CPU is launching kernels for the forward pass, the GPU is typically still running kernels from the previous backward pass, so CPU-side time is not exposed end to end.
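A minimal sketch of that timing pattern as a framework-agnostic helper (the name `benchmark` and its parameters are hypothetical): warm up, start the clock, run the whole loop, synchronize once, then stop the clock. Pass `sync=torch.cuda.synchronize` when timing GPU work; the default no-op keeps it usable for plain Python callables.

```python
import time
from typing import Callable


def benchmark(fn: Callable[[], object],
              sync: Callable[[], None] = lambda: None,
              iters: int = 1000,
              warmup: int = 10) -> float:
    """Return the average wall-clock seconds per call of fn.

    The device is synchronized only before starting and after finishing the
    timed loop, so queued asynchronous work can overlap with CPU-side launches.
    """
    for _ in range(warmup):
        fn()
    sync()                       # drain warm-up work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()                     # on a GPU this mostly enqueues kernels
    sync()                       # wait once for everything queued in the loop
    return (time.perf_counter() - start) / iters
```

For the TE layer, `fn` could be a small function that runs `model(inp)` inside `te.fp8_autocast(...)`, called under `torch.no_grad()` as in the original script, so that only the placement of the synchronization changes between the two measurements.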

The FP8 cast overhead mentioned by Tim is also important. In our FP8 pretraining workloads, we found M, K, and N dimensions (batch*seqlen, hidden size, FFN output size) >= 4K to be the sweet spot where FP8 GEMMs bring a large benefit over BF16. This comparison is FP8 cast + FP8 GEMM versus BF16 GEMM.
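One way to see why a small M hurts is a rough arithmetic-intensity estimate: GEMM FLOPs (2*M*K*N) per byte touched by the FP8 casts, assuming each cast reads 2-byte BF16/FP16 and writes 1-byte FP8. This is an illustrative sketch (the function is hypothetical) that ignores amax computation, transposed copies for the backward pass, and launch costs, so it is not a performance prediction. Without cached FP8 weights the ratio simplifies to 2MN / (3(M + N)), which is independent of K and bounded by the smaller of M and N. With weights already stored in FP8 (e.g. via `fp8_model_init`), only the activation is cast and the ratio becomes 2N/3.

```python
def gemm_flops_per_cast_byte(m: int, k: int, n: int,
                             weights_cached_in_fp8: bool = False,
                             src_bytes: int = 2, fp8_bytes: int = 1) -> float:
    """Rough ratio of GEMM FLOPs to bytes moved by FP8 casts for (M x K) @ (K x N)."""
    flops = 2 * m * k * n
    cast_elems = m * k                  # activations (M x K) are cast every step
    if not weights_cached_in_fp8:
        cast_elems += k * n             # weights (K x N) are cast too unless already FP8
    cast_bytes = cast_elems * (src_bytes + fp8_bytes)  # read high precision, write FP8
    return flops / cast_bytes


for m, k, n in [(128, 1280, 1280), (512, 5120, 5120), (4096, 4096, 4096)]:
    print(f"M={m:5d} K={k:5d} N={n:5d}: "
          f"{gemm_flops_per_cast_byte(m, k, n):8.1f} FLOPs/byte (weights cast), "
          f"{gemm_flops_per_cast_byte(m, k, n, weights_cached_in_fp8=True):8.1f} (weights cached)")
```

The higher this ratio, the more GEMM work there is to amortize each byte of cast traffic, which is in line with the observation that larger M, K, and N favor FP8.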

I notice that the M dimension (batch size * seqlen) you chose is in the range 128 to 512, which is quite small. Is this your target workload? TE was originally designed for large-scale pretraining, but we realize that we also want to optimize it for other problem shapes and many more types of workloads.

Another small issue with the benchmark script is that averaging should use a larger number of iterations; for example, I would use at least iters=1000 by default. I would also recommend deleting the model and tensors at the end of run_once to avoid any side effects from GPU memory allocation and garbage collection.

On the other hand, we have identified many places where the CPU side can be accelerated, and this is definitely on our roadmap. Thanks for raising this issue! As an immediate fix, we recommend enabling CUDA Graphs wherever possible to eliminate CPU-side overhead entirely. If your training framework is NeMo/Megatron, they already have flags to turn on CUDA Graphs (but we recommend a newer version of TE, since there have been many fixes in recent weeks). Inference engines should also have CUDA Graph integration, but I am not too familiar with them.

zhongbozhu · Aug 30 '25

Thank you for your reply. I now understand why FP8 inference is slower here.

cslvjt · Sep 02 '25

I am currently using torch.profiler to measure the CPU and GPU time taken by inference. The detailed results and code are shown below. The results show that the GPU time of te.Linear with FP8 is the same as that of torch.nn.Linear with FP16. I have three questions:

1. What is the specific function of each operator in te.Linear?
2. What methods can I use to reduce the overhead on the CPU side?
3. What should I do to actually perform computations in FP8 precision? I would greatly appreciate it if you could provide sample code.

import torch
from torch.profiler import profile, record_function, ProfilerActivity
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.common import recipe

def torch_profiler(input_shape, device, weight_type):
    print("torch profiler")
    batch_size, in_feature = input_shape
    out_feature = in_feature
    input_tensor = torch.randn(batch_size, in_feature).to(device, dtype=weight_type)
    model = torch.nn.Linear(in_feature, out_feature).to(device, dtype=weight_type)
    # Warm-up
    for _ in range(10):
        model(input_tensor)
    # Profile 100 forward passes on CPU and GPU
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            for _ in range(100):
                model(input_tensor)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

def te_profiler(input_shape, device, weight_type):
    print("transformer engine profiler")
    batch_size, in_feature = input_shape
    out_feature = in_feature
    input_tensor = torch.randn(batch_size, in_feature).to(device, dtype=weight_type)
    model = te.Linear(in_feature, out_feature).to(device, dtype=weight_type)
    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
    # Warm-up under fp8_autocast
    for _ in range(10):
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            model(input_tensor)
    # Profile 100 FP8 forward passes on CPU and GPU
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            for _ in range(100):
                with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                    model(input_tensor)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

if __name__ == "__main__":
    in_feature = 1024
    out_feature = 1024
    batch_size = 1024
    device = "cuda"
    weight_type = torch.float16
    torch_profiler((batch_size, in_feature), device, weight_type)
    te_profiler((batch_size, in_feature), device, weight_type)
    
[Image: profiler output, not reproduced]

cslvjt · Sep 04 '25

I'd recommend benchmarking with a more representative workflow. A small linear layer by itself is unlikely to get full GPU utilization and it will miss out on important optimizations like kernel fusion. A good start would be this tutorial on finetuning a 7B Llama 2 model and this tutorial on inference with a 7B Gemma model.

1. What is the specific function of each operator in te.Linear?

Based on your profile, it looks like the FP8 casts and matmuls happen in _Linear. I would guess that aten::cat and aten::reciprocal are used for computing scaling factors (see this tutorial on FP8 delayed scaling).
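For reference, the scale update in delayed scaling can be sketched in plain Python, following the formula given in TE's `DelayedScaling` recipe documentation: scale = (FP8_MAX / amax) / 2^margin, where amax is taken from a history of recent absolute maxima. Its reciprocal (`scale_inv`) is what maps FP8 values back to real values, which would be consistent with `aten::reciprocal` showing up in the profile. This is a simplified model, not TE's implementation; the function name is hypothetical, and keeping the previous scale when amax is zero or non-finite is an assumption about how that edge case is handled.

```python
import math
from typing import Sequence, Tuple

E4M3_MAX = 448.0    # largest finite value in FP8 E4M3
E5M2_MAX = 57344.0  # largest finite value in FP8 E5M2


def update_fp8_scale(amax_history: Sequence[float], prev_scale: float,
                     fp8_max: float = E4M3_MAX, margin: int = 0,
                     algo: str = "max") -> Tuple[float, float]:
    """Return (scale, scale_inv) computed from a history of absolute maxima."""
    if algo == "max":
        amax = max(amax_history)        # largest amax over the recorded window
    elif algo == "most_recent":
        amax = amax_history[-1]         # only the latest step's amax
    else:
        raise ValueError(f"unknown algo: {algo}")
    if amax > 0 and math.isfinite(amax):
        scale = (fp8_max / amax) / (2 ** margin)
    else:
        scale = prev_scale              # assumption: keep the old scale on a bad amax
    return scale, 1.0 / scale


print(update_fp8_scale([1.0, 3.5, 2.0], prev_scale=1.0))
print(update_fp8_scale([1.0, 3.5, 2.0], prev_scale=1.0, algo="most_recent"))
```

With the history `[1.0, 3.5, 2.0]`, the "max" rule uses amax = 3.5 and gives a scale of 448 / 3.5 = 128, while "most_recent" uses 2.0 and gives 224. Each `margin` step halves the scale to leave headroom for values that grow between updates.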

It's harder to interpret the CPU time in the remaining operations. They look like random tensor manipulations (e.g. for PyTorch autograd) or CUDA kernel launches.

2. What methods can I use to reduce the overhead on the CPU side?

Reducing CPU overhead is tricky and @zhongbozhu has already given a few ideas. For one, try to run with big enough problem sizes so that the GPU runtime is non-trivial and can cover up the CPU overhead. A more advanced technique is to use CUDA Graphs like in this Gemma inference tutorial. This can basically eliminate CPU overheads, but can be finicky and there are some operations that don't support it.

3. What should I do to actually perform computations in FP8 precision? I would greatly appreciate it if you could provide sample code.

The tutorials have sample code. At a high level, it should be sufficient to take a Transformer model, replace its Transformer layers with TE Transformer layers, and then wrap the forward pass in a fp8_autocast context.

timmoon10 · Oct 09 '25