TransformerEngine

Training the 1B model on H800 resulted in a decrease in throughput

Open · forevergj opened this issue 9 months ago · 3 comments

Using FP8 to train a 1B model on H800 results in significantly lower throughput than FP16. The PyTorch profiler shows a large performance gap in the Linear layers. What is the reason for this performance gap?

[Screenshot: PyTorch profiler timeline, MSAMP+TransformerEngine]

forevergj, May 07 '24 02:05

Adding @timmoon10 for visibility. The slowdown comes from CPU overheads, as the timeline you posted shows. About 30% of the CPU time is spent in MSAMP routines, which we can't really help with, but significant time is still spent in TE-specific layers. We are working on fixing those; see e.g. PR #820, which is a first step of that effort.

ptrendx, May 09 '24 18:05


Even without MSAMP, te.Linear still takes longer than nn.Linear. Is that because FP8 only delivers a performance benefit at larger shapes? What are the differences between the te.Linear and nn.Linear APIs? Does te.Linear perform quantization and dequantization during forward and backward propagation?

forevergj, May 13 '24 12:05

torch.nn.Linear's forward pass requires just a single cuBLAS GEMM kernel. te.Linear is more complicated since it also launches kernels for casting to FP8, caching FP8 scales for the backward pass, and updating FP8 scales. These kernels are fast, so they usually take longer to schedule (on CPU) than to actually execute (on GPU). We've also found that PyTorch operations can add significant CPU overhead, especially when using its Python API.
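The launch-overhead argument can be illustrated with a toy timeline model. This is a minimal sketch, not a measurement: the kernel breakdown for the FP8 path and every timing number are hypothetical assumptions chosen only to show the effect. It assumes the CPU issues kernels asynchronously and in order, and that the GPU starts each kernel once it has been issued and the previous kernel has finished.

```python
from dataclasses import dataclass


@dataclass
class Kernel:
    name: str
    cpu_launch_us: float  # hypothetical host-side time to schedule the kernel
    gpu_exec_us: float    # hypothetical device-side execution time


def step_time_us(kernels: list[Kernel]) -> float:
    """Return the time at which the last kernel finishes on the GPU.

    The CPU issues kernels back to back; a kernel can start on the GPU only
    after it has been issued and after the previous kernel has finished.
    """
    cpu_clock = 0.0
    gpu_clock = 0.0
    for k in kernels:
        cpu_clock += k.cpu_launch_us           # kernel becomes visible to the GPU here
        start = max(cpu_clock, gpu_clock)      # the GPU idles if the launch arrives late
        gpu_clock = start + k.gpu_exec_us
    return gpu_clock


def fp16_linear_like(scale: float) -> list[Kernel]:
    # Hypothetical: one GEMM kernel whose runtime grows with problem size.
    return [Kernel("gemm", 10.0, 30.0 * scale)]


def fp8_linear_like(scale: float) -> list[Kernel]:
    # Hypothetical breakdown: extra cast/scale kernels around an FP8 GEMM
    # that is assumed to be twice as fast as the FP16 GEMM.
    return [
        Kernel("cast_input_to_fp8", 10.0, 2.0 * scale),
        Kernel("cast_weight_to_fp8", 10.0, 2.0 * scale),
        Kernel("fp8_gemm", 10.0, 15.0 * scale),
        Kernel("update_fp8_scales", 10.0, 2.0),  # assumed size-independent
    ]


if __name__ == "__main__":
    for scale in (1.0, 10.0):
        print(f"scale={scale}: FP16 path {step_time_us(fp16_linear_like(scale))} us, "
              f"FP8 path {step_time_us(fp8_linear_like(scale))} us")
```

With these made-up numbers, the small case takes 40 us for the single GEMM versus 47 us for the FP8 path, even though the FP8 GEMM itself is twice as fast, because the GPU keeps waiting on CPU launches. At 10x the work the GPU stays busy and the FP8 path comes out ahead (310 us vs 202 us).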

Solutions:

  • Increase the model size or data size until the GEMM kernels cover up the CPU kernel launches. This isn't feasible if the model is fixed or if you're hitting memory constraints.
  • Capture a CUDA graph with te.make_graphed_callables. When the graph is replayed, it will launch all the kernels with none of the CPU overhead. This requires static compute graphs and fixed data sizes, and I am not sure how well MSAMP supports it.
  • Consider using compound TE modules like te.LayerNormLinear, te.LayerNormMLP, or te.TransformerLayer. These modules fuse the FP8 casts with other operations like LayerNorm, so the number of extra kernel launches is smaller.
  • As @ptrendx mentioned, we are slowly chipping away at other CPU overheads in TE. Unfortunately a lot of this is just the cost of using PyTorch.
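For the first option, one rough estimate of how large the GEMM must be is to require its GPU time, 2·M·N·K / P for an (M×K)·(K×N) product at achieved throughput P, to be at least the CPU time spent launching the layer's kernels. The sketch below solves for the smallest token count M under that simplification; the hidden size, throughput, and launch overhead in the example are hypothetical assumptions, not measured H800 numbers.

```python
def gemm_time_ns(m: int, n: int, k: int, flops_per_ns: int) -> float:
    """GPU time of an (m x k) @ (k x n) GEMM at a given achieved throughput."""
    return 2 * m * n * k / flops_per_ns


def min_tokens_to_hide_overhead(cpu_overhead_ns: int, flops_per_ns: int, n: int, k: int) -> int:
    """Smallest m such that the GEMM runs at least as long as the CPU launch overhead."""
    numerator = cpu_overhead_ns * flops_per_ns
    denominator = 2 * n * k
    # Integer ceiling division: m >= numerator / denominator
    return -(-numerator // denominator)


if __name__ == "__main__":
    # Hypothetical: 4 launches x 10 us of CPU time, 400 TFLOP/s achieved
    # (= 400_000 FLOP/ns), and a square 2048 x 2048 weight.
    m = min_tokens_to_hide_overhead(cpu_overhead_ns=40_000, flops_per_ns=400_000,
                                    n=2048, k=2048)
    print(m)  # 1908
```

Here m is the number of tokens fed to the GEMM (micro-batch size times sequence length). Under these assumptions, a layer processing fewer than about 1900 tokens per GEMM would tend to be launch-bound, which is consistent with the advice that larger models or batches help hide TE's extra kernel launches.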

timmoon10, May 14 '24 21:05