TransformerEngine
A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
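As a quick orientation, here is a minimal FP8 usage sketch in PyTorch following the pattern from the TransformerEngine documentation; the layer sizes and recipe settings are arbitrary and chosen only for illustration.

```python
# Minimal sketch: running a TE Linear layer under FP8 autocast (PyTorch).
# Layer sizes and the DelayedScaling recipe settings are illustrative only.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, amax_history_len=16)

layer = te.Linear(1024, 1024, bias=True).cuda()
inp = torch.randn(32, 1024, device="cuda")

# fp8_autocast switches supported TE modules to FP8 math on capable GPUs.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)

out.float().sum().backward()
```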
# Description UnfusedDotProductAttention in TE uses -10000 to fill in the attention mask, but this value is not small enough in some cases, which leads to a large diff between TE...
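To illustrate the issue (a standalone sketch, not TE code): when the unmasked logits themselves are strongly negative, a fill value of -10000 no longer dominates them, so masked positions receive non-negligible probability, whereas a -inf (or dtype-minimum) fill does not.

```python
# Standalone illustration (not TE code): why -10000 can be an insufficient mask fill.
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# One row of attention logits where the legitimate scores are already very negative.
logits = np.array([-9999.0, -9998.0, 5.0], dtype=np.float64)
mask = np.array([False, False, True])  # last position should be masked out

probs_small_fill = softmax(np.where(mask, -10000.0, logits))
probs_inf_fill = softmax(np.where(mask, -np.inf, logits))

print(probs_small_fill)  # masked position still gets ~9% of the weight
print(probs_inf_fill)    # masked position gets exactly zero
```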
# Description This PR integrates the TE/common cuBlasMp bindings into the TE/JAX CollectiveGemm custom op.
# Description This is a continuation of the efforts in #2357. [FA3](https://github.com/Dao-AILab/flash-attention/blob/fbf24f67cf7f6442c5cfb2c1057f4bfc57e72d89/hopper/flash_attn_interface.py#L269) lets users set the `num_splits` option to control the number of kernels launched for attention, which could...
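For context, `num_splits` controls FlashAttention's split-KV path, where the KV sequence is processed in independent chunks and the partial outputs are merged using their log-sum-exp statistics. The sketch below shows that merge rule in NumPy; it is an illustration of the idea, not the FA3 implementation.

```python
# Illustration of split-KV attention combination (not the FA3 kernel itself):
# each KV split produces a partial output and a log-sum-exp (LSE); merging the
# splits via the LSEs reproduces the full-softmax result exactly.
import numpy as np

def attn_partial(q, k, v):
    s = q @ k.T                                    # (Lq, Lk) raw scores
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    lse = m.squeeze(-1) + np.log(p.sum(axis=-1))   # per-query log-sum-exp
    o = p @ v / p.sum(axis=-1, keepdims=True)      # partial softmax output
    return o, lse

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(32, 8))
v = rng.normal(size=(32, 8))

# Full attention in one pass.
o_full, _ = attn_partial(q, k, v)

# Same attention computed over two KV splits, then merged via the LSEs.
o1, lse1 = attn_partial(q, k[:16], v[:16])
o2, lse2 = attn_partial(q, k[16:], v[16:])
lse = np.logaddexp(lse1, lse2)
o_merged = np.exp(lse1 - lse)[:, None] * o1 + np.exp(lse2 - lse)[:, None] * o2

assert np.allclose(o_full, o_merged)
```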
# Description This PR is one of several ongoing grouped-kernel PRs for NVFP4 that reduce CPU overhead and quantization cost. **This PR is ready for code review.** Action...
# Description This MR allows one to specify `cp_rank` when calling `get_batch_on_this_cp_rank`, which makes it possible to determine the batch for a specific rank without needing to provide the full batches for all ranks...
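For reference, TE's context-parallel helpers use a load-balanced layout for causal attention in which the sequence is split into 2*cp_size chunks and rank i keeps chunks i and (2*cp_size - 1 - i). The sketch below reproduces that slicing for a single tensor and a single requested rank; it is an illustration, not the actual `get_batch_on_this_cp_rank` code.

```python
# Illustration of load-balanced context-parallel slicing (not TE's actual code):
# the sequence dimension is split into 2 * cp_size chunks and rank r keeps
# chunks r and (2 * cp_size - 1 - r), giving every rank a balanced causal workload.
import torch

def slice_for_cp_rank(tensor, cp_size, cp_rank, seq_dim=1):
    chunks = tensor.chunk(2 * cp_size, dim=seq_dim)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=seq_dim)

tokens = torch.arange(16).reshape(1, 16)     # batch=1, seq_len=16
print(slice_for_cp_rank(tokens, cp_size=4, cp_rank=1))
# -> chunks 1 and 6: tokens [2, 3] and [12, 13]
```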
# Description Based on single-GPU profiling of the `GroupedLinear` module, this PR implements several optimizations to reduce CPU overhead due to PyTorch.
# Description PR #2148 added support for sink attention to common and PyTorch. This PR adds support for JAX. Fixes #2070. The PR includes a BEFORE/AFTER test runtime summary (grouped by function)...
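As background, here is an illustrative sketch under the common formulation of sink attention, not TE's implementation: a per-head sink logit is appended to the softmax so it can absorb probability mass, and the sink column is dropped before the weighted sum over values.

```python
# Illustrative sink-attention softmax (one head, NumPy), not TE's implementation:
# an extra "sink" logit joins the softmax and absorbs probability mass, and its
# column is discarded before weighting the values.
import numpy as np

def sink_softmax(scores, sink_logit):
    # scores: (Lq, Lk); sink_logit: scalar appended as an extra column
    ext = np.concatenate([scores, np.full((scores.shape[0], 1), sink_logit)], axis=-1)
    ext = ext - ext.max(axis=-1, keepdims=True)
    p = np.exp(ext)
    p = p / p.sum(axis=-1, keepdims=True)
    return p[:, :-1]          # drop the sink column; rows now sum to <= 1

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 6))
weights = sink_softmax(scores, sink_logit=2.0)
print(weights.sum(axis=-1))   # less than 1: the remainder went to the sink
```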
# Description In certain cases the random sign mask and the normalization applied to the RHT matrix are not cached in TE/JAX, leading to a slight perf impact because these...
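For context, here is a standalone sketch of the idea, not TE/JAX code: a randomized Hadamard transform multiplies the input by a fixed random sign vector and a normalized Hadamard matrix before quantization, and both the sign mask and the 1/sqrt(d) normalization are quantities one would want to compute once and cache.

```python
# Standalone sketch of a randomized Hadamard transform (RHT), not TE/JAX code:
# the random sign mask and the normalized Hadamard matrix are fixed per size,
# so computing them once and reusing them avoids per-call overhead.
import numpy as np

def hadamard(n):
    # Sylvester construction; n must be a power of two.
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h

d = 16
rng = np.random.default_rng(0)
signs = rng.choice([-1.0, 1.0], size=d)          # random sign mask (cacheable)
h_normalized = hadamard(d) / np.sqrt(d)          # normalized RHT matrix (cacheable)

x = rng.normal(size=(8, d))
x_rht = (x * signs) @ h_normalized               # rotate before quantization

# The transform is orthogonal, so it can be undone exactly.
assert np.allclose((x_rht @ h_normalized.T) * signs, x)
```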