TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Results: 414 TransformerEngine issues, sorted by recently updated

# Description UnfusedDotProductAttention in TE uses -10000 to fill in the attention mask, but that value is not small enough in some cases, which leads to a large diff between TE...
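A minimal PyTorch sketch (not TE code) of the failure mode: a finite fill value such as -10000 only suppresses masked positions when the real scores are much larger than -10000. The scores below are exaggerated to make the leak obvious.

```python
import torch

scores = torch.tensor([[-9998.0, -9999.0, -50.0]])   # last position should be masked
mask = torch.tensor([[False, False, True]])

finite = scores.masked_fill(mask, -10000.0).softmax(dim=-1)
exact = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)

print(finite)  # masked position still receives ~9% of the probability mass
print(exact)   # masked position receives exactly zero
```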


# Description This PR integrates TE/common cuBlasMp bindings into the TE/JAX CollectiveGemm custom op.

2.10.0

# Description This is a continuation of the efforts in #2357. [FA3](https://github.com/Dao-AILab/flash-attention/blob/fbf24f67cf7f6442c5cfb2c1057f4bfc57e72d89/hopper/flash_attn_interface.py#L269) allows users to use the `num_splits` option to control the number of kernels launched for attention, which could...

2.10.0
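For context, `num_splits` controls how the KV sequence is partitioned across attention kernels (flash-decoding style). The PyTorch sketch below is purely illustrative of that split-and-merge idea and is unrelated to FA3's actual kernels.

```python
import torch

def attention_reference(q, k, v):
    # Plain softmax attention over the full KV sequence.
    s = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
    return s.softmax(dim=-1) @ v

def attention_split_kv(q, k, v, num_splits):
    # Process the KV sequence in `num_splits` chunks and merge the partial
    # outputs with their log-sum-exp statistics, as a split-KV kernel would.
    outs, lses = [], []
    for k_c, v_c in zip(k.chunk(num_splits, dim=-2), v.chunk(num_splits, dim=-2)):
        s = (q @ k_c.transpose(-1, -2)) / q.shape[-1] ** 0.5
        lse = torch.logsumexp(s, dim=-1, keepdim=True)   # [..., q_len, 1]
        outs.append((s - lse).exp() @ v_c)               # chunk-normalized partial output
        lses.append(lse)
    lse_all = torch.logsumexp(torch.stack(lses), dim=0)  # global statistics
    weights = torch.stack([(l - lse_all).exp() for l in lses])
    return (weights * torch.stack(outs)).sum(dim=0)

q = torch.randn(2, 4, 8, 64)    # [batch, heads, q_len, head_dim]
k = torch.randn(2, 4, 128, 64)  # [batch, heads, kv_len, head_dim]
v = torch.randn(2, 4, 128, 64)
print(torch.allclose(attention_reference(q, k, v),
                     attention_split_kv(q, k, v, num_splits=4), atol=1e-5))
```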

# Description This PR is one of several ongoing grouped-kernel PRs for NVFP4, aimed at reducing CPU overhead and quantization cost. **This PR is ready for code review** Action...

# Description This MR allows specifying the `cp_rank` passed to `get_batch_on_this_cp_rank`, which lets one determine the batches for a specific rank without needing to provide full batches for all ranks...
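An illustrative sketch of the kind of per-rank slicing involved: the load-balanced 2*cp_size chunking commonly used for causal attention with context parallelism. The helper name and exact layout here are assumptions for illustration, not the TE/Megatron implementation.

```python
import torch

def slice_batch_for_cp_rank(tensor, cp_size, cp_rank, seq_dim=1):
    # Hypothetical helper, not the library code: split the sequence into
    # 2*cp_size chunks and give rank r chunks r and 2*cp_size-1-r, the usual
    # load-balanced layout for causal attention under context parallelism.
    chunks = tensor.chunk(2 * cp_size, dim=seq_dim)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=seq_dim)

full = torch.arange(16).view(1, 16)   # [batch, seq]
for rank in range(2):                 # cp_size = 2
    print(rank, slice_batch_for_cp_rank(full, cp_size=2, cp_rank=rank))
# rank 0 gets chunks 0 and 3, rank 1 gets chunks 1 and 2
```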


# Description Based on single-GPU profiling of the `GroupedLinear` module, implement some optimizations to reduce CPU overhead due to PyTorch.
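For reference, this kind of measurement can be reproduced with `torch.profiler`. The snippet below profiles a generic stand-in (a Python loop of small matmuls, the pattern a grouped GEMM fuses), not TE's `GroupedLinear` benchmark.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Many small ops issued from Python: the per-op dispatch cost shows up as CPU time.
weights = [torch.randn(256, 256) for _ in range(64)]
inputs = [torch.randn(32, 256) for _ in range(64)]

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        outs = [x @ w for x, w in zip(inputs, weights)]

# Sorting by total CPU time highlights per-op launch/dispatch overhead.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```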

# Description PR #2148 added support for sink attention to common and PyTorch. This PR adds support for JAX. Fixes #2070. [Before/after test runtime summary omitted.]

2.10.0
attention
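A minimal sketch of the sink-attention mechanism itself, shown in PyTorch for brevity (the PR targets TE/JAX) and not taken from TE's implementation: a learnable per-head sink logit joins the softmax normalization but contributes no value.

```python
import torch

def attention_with_sink(q, k, v, sink_logit):
    # Append one sink logit per head to the score matrix, softmax over the
    # extended row, then drop the sink column so it absorbs probability mass
    # without adding anything to the output.
    s = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5        # [b, h, q_len, kv_len]
    sink = sink_logit.view(1, -1, 1, 1).expand(*s.shape[:-1], 1)
    probs = torch.cat([s, sink], dim=-1).softmax(dim=-1)       # softmax includes the sink
    return probs[..., :-1] @ v                                 # sink column contributes no value

q = torch.randn(2, 4, 8, 64)
k = torch.randn(2, 4, 16, 64)
v = torch.randn(2, 4, 16, 64)
sink_logit = torch.zeros(4, requires_grad=True)                # one learnable sink per head
out = attention_with_sink(q, k, v, sink_logit)                 # [2, 4, 8, 64]
```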

# Description In certain cases the random sign mask and the normalization applied to the RHT matrix are not cached in TE/JAX, leading to a slight perf impact be these...
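A rough sketch of the caching being discussed, using a hypothetical helper rather than TE/JAX code: build the sign-randomized, normalized Hadamard matrix once per (dim, seed) and reuse it on later calls.

```python
import math
from functools import lru_cache
import torch

@lru_cache(maxsize=None)
def rht_matrix(dim: int, seed: int) -> torch.Tensor:
    # Hypothetical helper: the cached object is diag(signs) @ H / sqrt(dim),
    # so the sign mask and normalization are only computed once per (dim, seed).
    assert dim & (dim - 1) == 0, "Hadamard size must be a power of two"
    h = torch.ones(1, 1)
    while h.shape[0] < dim:                      # Sylvester construction of H
        h = torch.cat([torch.cat([h, h], dim=1),
                       torch.cat([h, -h], dim=1)], dim=0)
    g = torch.Generator().manual_seed(seed)
    signs = torch.randint(0, 2, (dim,), generator=g) * 2 - 1   # random +/-1 per row
    return (signs.float().unsqueeze(1) * h) / math.sqrt(dim)

x = torch.randn(8, 64)
y = x @ rht_matrix(64, seed=0)   # second call with (64, 0) hits the cache
```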