TransformerEngine
A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization i...
Hello, I want to install TE using pip:

`pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable`

But I get the following error during installation:

```
Collecting git+https://github.com/NVIDIA/TransformerEngine.git@stable
  Cloning https://github.com/NVIDIA/TransformerEngine.git (to revision stable) to /tmp/pip-req-build-c6l34itl
  Running...
```
It seems that there are some breaking API changes in the main branch of `cudnn-frontend`. This causes the compilation of TE's `main` branch to fail. Some of the error messages:...
I use CUDA 12.1.1 to build TE from source. The stable, main, and v1.3 branches all install successfully, but the flash-attention installed by TE doesn't work at all. `import flash_attn_2_cuda...
Code:

```cpp
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"
#include <...>
#include <...>
#include <...>
#include <...>

using namespace transformer_engine;

void GetSelfFusedAttnForwardWorkspaceSizes(
    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type...
```
Version: latest stable Currently, the version constraint for `flash-attn` is: https://github.com/NVIDIA/TransformerEngine/blob/b8eea8aaa94bb566c3a12384eda064bda8ac4fd7/setup.py#L269 So most likely `v2.4.2` is going to be installed. However, this version seems to have some issues when imported,...
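For illustration, pip's resolver will install the newest release that still satisfies an upper-bound constraint like the one in setup.py; a minimal pure-Python sketch of that selection (the release list and bound below are hypothetical, not taken from the actual flash-attn metadata):

```python
# Sketch: pick the newest version allowed by an exclusive upper bound,
# as pip's resolver does. All version strings here are illustrative.

def parse(v: str) -> tuple:
    """Turn '2.4.2' into (2, 4, 2) for ordering comparisons."""
    return tuple(int(part) for part in v.split("."))

def newest_allowed(available, upper_exclusive):
    """Return the highest version strictly below the bound, or None."""
    ok = [v for v in available if parse(v) < parse(upper_exclusive)]
    return max(ok, key=parse) if ok else None

releases = ["2.3.6", "2.4.0", "2.4.1", "2.4.2", "2.4.3"]
print(newest_allowed(releases, "2.4.3"))  # -> 2.4.2
```

This is why, under a `< 2.4.3`-style bound, `v2.4.2` is the version most environments end up with.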
Hi, what is the correct `fp8_group` when using FSDP and tensor parallelism together? Is it all GPUs, or within the tensor-parallel groups? Thanks.
Currently, importing transformer_engine takes ~10s on my machine, and it also starts a background process pool because of all the JIT initialization, like [here](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/jit.py#L50-L54). It would be better if...
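One common mitigation for slow imports is to defer heavy submodule loading until first attribute access via a module-level `__getattr__` (PEP 562); a minimal standalone sketch, where `json` stands in for an expensive dependency and the package name `fake_te` is made up for the demo:

```python
import importlib
import sys
import types

# PEP 562-style lazy loading sketch: the submodule is imported only on
# first attribute access. 'json' stands in for a heavy dependency, and
# 'fake_te' is a made-up package name for demonstration purposes.
def _lazy_getattr(name):
    targets = {"heavy": "json"}
    if name in targets:
        mod = importlib.import_module(targets[name])
        setattr(pkg, name, mod)  # cache so later lookups skip this hook
        return mod
    raise AttributeError(name)

pkg = types.ModuleType("fake_te")
pkg.__getattr__ = _lazy_getattr
sys.modules["fake_te"] = pkg

mod = pkg.heavy             # the import happens here, on first touch
print(mod.dumps({"a": 1}))  # -> {"a": 1}
```

With this pattern, `import fake_te` itself stays cheap; the cost is paid only by code paths that actually use the heavy submodule.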

Thanks for the awesome library! I'm wondering whether there are plans to provide ops support for `grouped_gemm` as in https://github.com/tgale96/grouped_gemm/tree/main For additional context, it seems that fp8 is supported...
I've noticed that FP8 training is slower when finetuning a BERT-large model in a large multi-node setting. I have tested this on the MLPerf training benchmark. Could someone explain the underlying reasons behind...
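One factor sometimes cited is that FP8 recipes add small cross-rank reductions for amax/scale bookkeeping each step, and those latency-bound collectives get more expensive across nodes. A back-of-envelope sketch of the trade-off (every constant below is an illustrative assumption, not a measurement):

```python
# Toy latency model: step time = compute + per-layer small-message syncs.
# All numbers are illustrative assumptions, not benchmark results.
def step_time_ms(compute_ms, n_layers, sync_latency_ms, syncs_per_layer):
    return compute_ms + n_layers * syncs_per_layer * sync_latency_ms

bf16 = step_time_ms(compute_ms=100.0, n_layers=24,
                    sync_latency_ms=0.05, syncs_per_layer=0)
fp8_fast_net = step_time_ms(compute_ms=70.0, n_layers=24,
                            sync_latency_ms=0.05, syncs_per_layer=3)
fp8_slow_net = step_time_ms(compute_ms=70.0, n_layers=24,
                            sync_latency_ms=0.5, syncs_per_layer=3)
print(bf16, fp8_fast_net, fp8_slow_net)  # -> 100.0 73.6 106.0
```

Under these made-up numbers, FP8 wins within a node (73.6 ms) but loses once per-sync latency grows to multi-node levels (106.0 ms), which is one plausible shape for the slowdown being asked about.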