TransformerEngine
Improved performance of mxfp8 cast kernels
Description
Modified and tuned the fused MXFP8 cast kernels for better performance. This CUDA kernel template implements optimized scaling along the columnwise dimension, the rowwise dimension, or both. It uses compile-time specialization and SIMD (vectorized) instructions to maximize compute performance. Data is transferred asynchronously into shared memory via TMA (the Tensor Memory Accelerator) and processed in stages, with the prefetch of the next stage overlapping computation on the current one. Intermediate results (e.g., activations) are cached in shared memory as needed.
Kernel Logic Overview
The kernel is a function template designed to handle multiple scaling scenarios, leveraging compile-time logic to generate specialized, optimized code paths. It consists of a common processing part shared across all scenarios, along with three specialized branches:
1. Columnwise scaling
2. Rowwise scaling
3. Both columnwise and rowwise scaling
Each branch introduces specific optimizations based on the template parameters provided at instantiation.
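A minimal host-side C++ sketch of this compile-time specialization, assuming boolean template parameters named ROWWISE, COLWISE, and IS_DBIAS (hypothetical names, not necessarily the kernel's actual parameters). The `if constexpr` branches are resolved at instantiation, so each variant contains only the code for its scenario. The dBias placement mirrors the rules listed under dBias Computation.

```cpp
#include <cassert>

// Hypothetical model: the template parameter names and ChunkResult are
// illustrative, not TransformerEngine's actual kernel signature.
struct ChunkResult {
  bool did_colwise = false;
  bool did_rowwise = false;
  bool dbias_in_colwise = false;
  bool dbias_in_rowwise = false;
};

template <bool ROWWISE, bool COLWISE, bool IS_DBIAS>
ChunkResult process_chunk() {
  static_assert(ROWWISE || COLWISE, "at least one scaling direction is required");
  ChunkResult r;
  // Common part (offsets, barriers, staged TMA transfers) would go here.
  if constexpr (COLWISE) {
    r.did_colwise = true;
    // dBias is reduced in the columnwise phase whenever that phase exists.
    if constexpr (IS_DBIAS) r.dbias_in_colwise = true;
  }
  if constexpr (ROWWISE) {
    r.did_rowwise = true;
    // Only the rowwise-only variant reduces dBias in the rowwise phase.
    if constexpr (IS_DBIAS && !COLWISE) r.dbias_in_rowwise = true;
  }
  return r;
}
```

Branches that are not taken are discarded at compile time, so, for example, `process_chunk<true, false, true>()` contains no columnwise code at all.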
Since the kernel is compute-bound, performance is improved by favoring SIMD (vectorized) instructions over scalar ones. For example, operations like amax computation and element rescaling (multiplication) are performed on two elements simultaneously using a single instruction.
Common Part
The common part of the kernel performs the following tasks:
- Defines block and thread offsets.
- Manages memory barriers.
- Handles data transfers and updates.
Data transfer between global memory and shared memory (SHMEM) is performed using the TMA engine. SHMEM serves as the main workspace where input, output, and intermediate cached data (e.g., activations) reside. Due to the limited size of shared memory, the assigned data block (chunk) is partitioned into smaller tiles processed iteratively in stages.
Due to the asynchronous nature of the TMA engine, the kernel overlaps computation and memory transfer. At each iteration (except the last), data for the next stage is prefetched while the current stage's data is being processed.
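The ordering of the staged loop can be modeled on the host as a trace of events. This is a simplified sketch: `staged_schedule` is a hypothetical helper, "load" stands for an asynchronous TMA copy, and "wait" stands for the barrier that signals its completion. Buffer management (e.g., double buffering so that a prefetch never overwrites a stage that is still being processed) is omitted.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Returns the order of load/wait/compute events for `num_stages` stages.
std::vector<std::string> staged_schedule(int num_stages) {
  std::vector<std::string> trace;
  if (num_stages <= 0) return trace;
  trace.push_back("load 0");  // prologue: issue the copy for the first stage
  for (int s = 0; s < num_stages; ++s) {
    if (s + 1 < num_stages) {
      // Prefetch the next stage first so its copy overlaps the work below.
      // The last iteration has nothing left to prefetch.
      trace.push_back("load " + std::to_string(s + 1));
    }
    trace.push_back("wait " + std::to_string(s));     // barrier: stage s has arrived
    trace.push_back("compute " + std::to_string(s));  // process stage s in SHMEM
  }
  return trace;
}
```

For three stages the trace is: load 0, load 1, wait 0, compute 0, load 2, wait 1, compute 1, wait 2, compute 2.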
dBias Computation
Thread-wise Reduction
dBias only needs to be computed once per input, so it is computed during one of the scaling phases:
- Columnwise scaling: dBias is computed in the columnwise phase.
- Rowwise scaling: dBias is computed in the rowwise phase.
- Both scalings: dBias is computed in the columnwise phase.
Block-wise Reduction
Depending on where thread-wise reduction occurred:
- When performed during the columnwise phase, accumulated values are stored directly into the workspace.
- When performed during the rowwise phase, elements must be un-swizzled to ensure proper indexing. Additionally, partial dBias results across waves are further reduced using an intermediate buffer that shares memory with the input buffer.
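A simplified model of the two-level dBias reduction: dBias is the per-column sum of the incoming gradient, so each row-block writes partial column sums to a workspace, and a second pass (corresponding to the separate Reduce dBias kernel listed in the benchmark scenario) sums the partials. The function names and the [num_row_blocks x cols] workspace layout are assumptions for illustration; un-swizzling and the per-wave intermediate buffer are not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Pass 1 (hypothetical model of the fused kernel): each group of `block_rows`
// rows writes one partial sum per column into the workspace.
std::vector<float> partial_dbias(const std::vector<float>& grad, std::size_t rows,
                                 std::size_t cols, std::size_t block_rows) {
  const std::size_t num_row_blocks = (rows + block_rows - 1) / block_rows;
  std::vector<float> workspace(num_row_blocks * cols, 0.0f);
  for (std::size_t r = 0; r < rows; ++r) {
    const std::size_t rb = r / block_rows;  // row-block that owns row r
    for (std::size_t c = 0; c < cols; ++c) {
      workspace[rb * cols + c] += grad[r * cols + c];
    }
  }
  return workspace;
}

// Pass 2 (model of the separate reduction kernel): sum partials per column.
std::vector<float> reduce_dbias(const std::vector<float>& workspace,
                                std::size_t num_row_blocks, std::size_t cols) {
  std::vector<float> dbias(cols, 0.0f);
  for (std::size_t rb = 0; rb < num_row_blocks; ++rb) {
    for (std::size_t c = 0; c < cols; ++c) dbias[c] += workspace[rb * cols + c];
  }
  return dbias;
}
```

For a 4x2 gradient {1, 2, 3, 4, 5, 6, 7, 8} with two rows per block, the workspace holds {4, 6, 12, 14} and the final dBias is {16, 20}.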
Scaling Scenarios
1. Columnwise Scaling
Thread-to-data mapping: One thread processes one column (threads per block = number of columns).
Processing: Each thread moves vertically down its assigned column, reading one element at a time from SHMEM into registers (REG).
Computation:
- Compute the amax while reading.
- After reading 32 elements, calculate the scaling factor and rescale all cached elements.
- Store the rescaled elements back into SHMEM.
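The per-column steps can be sketched on the host for one 32-element block. This sketch follows the shared-exponent rule of the OCP Microscaling (MX) specification, X = floor(log2(amax)) - 8 for E4M3 (whose largest power of two is 2^8), and saturates scaled values to ±448. The kernel's exact exponent rounding may differ, and the final conversion to FP8 is omitted. `mx_shared_exponent` and `scale_column_block` are hypothetical names.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>

constexpr int kBlock = 32;          // MXFP8: 32 elements share one scale
constexpr float kE4M3Max = 448.0f;  // largest finite E4M3 value
constexpr int kE4M3Emax = 8;        // exponent of the largest E4M3 power of two (2^8)

// Shared exponent X per the OCP MX rule, clamped to the E8M0 range.
int mx_shared_exponent(float amax) {
  if (amax == 0.0f) return -127;  // all-zero block: any scale is exact; pick the smallest
  const int e = std::ilogb(amax) - kE4M3Emax;  // ilogb(x) == floor(log2(x)) for finite x > 0
  return std::clamp(e, -127, 127);             // E8M0 encodes 2^-127 .. 2^127
}

// Model of one thread walking down its column: track amax, then rescale.
int scale_column_block(const std::array<float, kBlock>& in,
                       std::array<float, kBlock>& out) {
  float amax = 0.0f;
  for (float v : in) amax = std::max(amax, std::fabs(v));  // amax while reading
  const int x = mx_shared_exponent(amax);
  for (int i = 0; i < kBlock; ++i) {
    // Divide by 2^X and saturate to the E4M3 range (FP8 rounding omitted).
    out[i] = std::clamp(std::ldexp(in[i], -x), -kE4M3Max, kE4M3Max);
  }
  return x;  // a real kernel stores the biased E8M0 byte (x + 127)
}
```

For example, a block whose amax is 1.0 gets X = -8, so 1.0 and -0.5 become 256 and -128. A value of 1.9 in such a block would scale to 486.4 and saturate to 448, which is why implementations may round the exponent differently.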
2. Rowwise Scaling
Thread-to-data mapping: One thread processes 32 consecutive (rowwise) elements, i.e., one MXFP8 block per thread, so no inter-thread communication is needed. For example, for a chunk width of 128 elements, four threads per row are needed.
Processing: Each thread traverses its assigned block horizontally, reading and writing four elements at a time (PACK_SIZE = 4). The block is treated as a closed ring: when a thread reaches the rightmost element of its block, it wraps around to the leftmost element and continues processing. This ring-based traversal, together with swizzling, helps minimize shared memory bank conflicts.
NOTE: The PACK_SIZE parameter can be fine-tuned to slightly improve performance in some scenarios, although it may degrade performance in others. For example, with the --fast_math flag enabled, setting PACK_SIZE=8 has the following performance impact relative to PACK_SIZE=4:
Computation:
- Compute amax during reads.
- Rescale cached elements (stored in registers).
- Store the rescaled elements back into SHMEM, applying an un-swizzling pattern to restore original indices.
A few examples of thread-to-data mapping between shared memory and registers (lower bar) while iterating through waves:
- tid = 0 (tid_y = 0; tid_x = 0): belongs to Group 0, initial offset = 0 elements
- tid = 11 (tid_y = 2; tid_x = 3): belongs to Group 2, initial offset = 8 elements
- tid = 26 (tid_y = 6; tid_x = 2): belongs to Group 6, initial offset = 24 elements
Thread-to-data mapping for a single warp during a single wave: the warp spans all 32 banks of shared memory, which avoids bank conflicts when writing FP8 data during a wave. Since four FP8 elements fit into a single 4-byte bank word, vectorized writes (4×FP8) align naturally with the memory banks.
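The ring traversal and its effect on bank conflicts can be checked with a small host-side model. Assumptions (not taken from the kernel source): a warp covers 8 rows x 4 threads with tid_y = tid / 4 and tid_x = tid % 4, each 128-element FP8 output row occupies 128 bytes with no padding, and the initial offset is (tid_y * PACK_SIZE) mod 32, a formula inferred from the three examples above (tid 0 gives 0, tid 11 gives 8, tid 26 gives 24). The model counts how many distinct 4-byte banks the 32 threads of a warp hit in one wave.

```cpp
#include <cassert>
#include <set>

constexpr int kPackSize = 4;         // FP8 elements per vectorized access (4 bytes)
constexpr int kElemsPerThread = 32;  // one MXFP8 block per thread
constexpr int kThreadsPerRow = 4;    // 128-element chunk width / 32
constexpr int kRowBytes = 128;       // one 128-element FP8 row
constexpr int kNumBanks = 32;        // 4-byte shared-memory banks

// Starting offset (in elements) of a thread inside its 32-element block.
int initial_offset(int tid, bool ring) {
  const int tid_y = tid / kThreadsPerRow;
  return ring ? (tid_y * kPackSize) % kElemsPerThread : 0;
}

// Offset inside the thread's block at a given wave; wraps around like a ring.
int wave_offset(int tid, int wave, bool ring) {
  return (initial_offset(tid, ring) + wave * kPackSize) % kElemsPerThread;
}

// Bank touched by a thread's 4-byte (4 x FP8) write in a given wave.
int bank_of(int tid, int wave, bool ring) {
  const int tid_y = tid / kThreadsPerRow;
  const int tid_x = tid % kThreadsPerRow;
  const int byte_addr =
      tid_y * kRowBytes + tid_x * kElemsPerThread + wave_offset(tid, wave, ring);
  return (byte_addr / 4) % kNumBanks;
}

// Distinct banks hit by the 32 threads of a warp in one wave (32 = conflict-free).
int distinct_banks(int wave, bool ring) {
  std::set<int> banks;
  for (int tid = 0; tid < 32; ++tid) banks.insert(bank_of(tid, wave, ring));
  return static_cast<int>(banks.size());
}
```

Under these assumptions, every wave with ring offsets touches all 32 banks, while starting every thread at offset 0 would put the whole warp on only 4 banks (an 8-way conflict), because rows of 128 bytes all map to the same bank pattern.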
3. Both Columnwise and Rowwise Scaling (2x)
Scaling across both dimensions is performed sequentially in two phases:
1st Phase (Columnwise Scaling)
Same as the columnwise-only case, with one difference: if the IS_DACT or IS_ACT template parameter is set, activations are computed and cached in a dedicated buffer within shared memory.
2nd Phase (Rowwise Scaling)
Instead of recomputing activations, the kernel reuses cached activations from the first phase. Only the amax computation and rescaling are performed; no dBias computation is required, as it was already completed in the first phase. The activations cache buffer shares memory with the input buffer of the corresponding stage.
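A host-side sketch of why the cache pays off: it counts activation evaluations with and without reusing the phase-1 results. The tanh approximation of GeLU is an assumption here, `two_phase` is a hypothetical helper, and the real kernel keeps the cache in shared memory rather than in a `std::vector`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// GeLU, tanh approximation (assumed; the kernel's exact variant may differ).
float gelu_tanh(float x) {
  const float k = 0.7978845608f;  // sqrt(2 / pi)
  return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

struct TwoPhaseResult {
  float rowwise_amax = 0.0f;  // stand-in for the phase-2 amax computation
  long act_calls = 0;         // number of activation evaluations
};

TwoPhaseResult two_phase(const std::vector<float>& in, bool use_cache) {
  TwoPhaseResult r;
  std::vector<float> cache(in.size());  // models the SHMEM activation cache
  // Phase 1 (columnwise): evaluate the activation once per element and cache it.
  for (std::size_t i = 0; i < in.size(); ++i) {
    cache[i] = gelu_tanh(in[i]);
    ++r.act_calls;
  }
  // Phase 2 (rowwise): needs the same activated values for amax and rescaling.
  for (std::size_t i = 0; i < in.size(); ++i) {
    float v = cache[i];
    if (!use_cache) {
      v = gelu_tanh(in[i]);  // recompute instead of reusing the cache
      ++r.act_calls;
    }
    r.rowwise_amax = std::max(r.rowwise_amax, std::fabs(v));
  }
  return r;
}
```

Reusing the cache halves the number of activation evaluations while producing identical phase-2 inputs; the cost is the extra shared memory, which the kernel offsets by aliasing the cache with the stage's input buffer.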
Notes:
- Vectorization is key to performance, particularly for amax computation and rescaling.
- Swizzling and un-swizzling patterns are carefully used to minimize shared memory bank conflicts.
- Shared memory reuse (between the activations cache, intermediate buffers, and input/output buffers) makes efficient use of the limited shared memory.
Benchmark scenario:
Input tensor: 4096x13312, BF16
Output tensor: 4096x13312, MXFP8 E4M3
Activation type (if used): GeLU/dGeLU
Kernel 1: Fused Cast MXFP8
Kernel 2: Reduce dBias
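For scale, the benchmark tensor sizes can be worked out directly. This is plain arithmetic under the stated shapes, assuming one E8M0 scale byte per 32 elements per scaling direction; `benchmark_sizes` is a hypothetical helper. The last two fields compare the chunk count at the 128x128 block size introduced by this PR (which does not apply to CAST_DBIAS) against the previous 64x64.

```cpp
#include <cassert>
#include <cstdint>

struct BenchmarkSizes {
  std::int64_t elements;
  std::int64_t bf16_input_bytes;  // 2 bytes per BF16 element
  std::int64_t fp8_output_bytes;  // 1 byte per FP8 element, per scaling direction
  std::int64_t scale_bytes;       // one E8M0 byte per 32 elements, per direction
  std::int64_t chunks_128x128;    // chunks at the new block size
  std::int64_t chunks_64x64;      // chunks at the previous block size
};

// Assumes rows and cols are multiples of 128.
constexpr BenchmarkSizes benchmark_sizes(std::int64_t rows, std::int64_t cols) {
  const std::int64_t n = rows * cols;
  return BenchmarkSizes{n, 2 * n, n, n / 32,
                        (rows / 128) * (cols / 128), (rows / 64) * (cols / 64)};
}
```

For 4096x13312 this gives 54,525,952 elements (109,051,904 bytes of BF16 input), 1,703,936 scale bytes per direction, and 3,328 chunks at 128x128 versus 13,312 at 64x64.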
Performance was measured with NVIDIA Nsight Systems version 2025.2.1.130-252135690618v0 using the corresponding test suites. Runtime is measured in microseconds.
NOTE: CAST+DBIAS+DACT with --fast_math shows relatively high speedup factors because previously there was no option to compile cast.cu with this flag enabled; therefore, this metric does not provide a direct performance comparison.
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [x] Code refactoring
Changes
- Increased the block size from 64x64 to 128x128 (except for CAST_DBIAS)
- Modified the scheme of how tensor elements are processed. ROWWISE: 32 elements per thread (i.e., a single scaling factor)
- Added micro-optimizations (e.g., using MUL2 instructions; fusing MUL and CVT instructions directly in PTX)
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
/te-ci
/te-ci L0 L1
Hi,
Regarding the recent changes to the JAX unit test, I think it's fine to proceed with the current approach for now so that we can upstream this PR soon.
Looking ahead, I would prefer that we align the JAX-based implementation with the TE kernel for the following reasons:
- We want to benchmark end-to-end workloads both with and without TE custom calls. If the TE kernel stores intermediate results in the input dtype while the JAX-based implementation uses FP32, the comparison will no longer be apples-to-apples.
- Since we have already agreed to store intermediate results in the input dtype going forward (and other JAX-based frameworks, such as MaxText, are also adopting this approach), I don't see a compelling reason to do otherwise for our JAX-based implementation.
What do you think? @jberchtold-nvidia @ptrendx
Sounds good, I agree