TransformerEngine
A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization...
Hey, I'm using the `te_gemm` function defined in the PyTorch extensions [here](https://github.com/cli99/TransformerEngine/blob/6b21f606f2459d49c2113d69236d68d334edeb4c/transformer_engine/pytorch/csrc/extensions/gemm.cu#L10), and I'm trying to apply a scaling factor to the output. My gemm inputs are in fp8e4m3 and...
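A conceptual sketch (not the actual `te_gemm` signature, whose arguments are defined in the linked extension) of how FP8 GEMM scaling factors usually compose: each FP8 operand carries a per-tensor dequantization factor (`scale_inv`), and any extra output scale is just one more multiplier folded in after the matmul. The function name and the `out_scale` parameter below are illustrative assumptions.

```python
import torch

def emulated_fp8_gemm(a_fp8, a_scale_inv, b_fp8, b_scale_inv, out_scale=1.0):
    """Emulate the scaling semantics of an FP8 GEMM in plain PyTorch.

    a_fp8, b_fp8 : quantized operands stored in torch.float8_e4m3fn
    *_scale_inv  : per-tensor dequantization factors (1 / quantization scale)
    out_scale    : illustrative extra factor applied to the GEMM output
    """
    # Dequantize to a wider dtype before the matmul; a real FP8 GEMM fuses
    # this scaling into the kernel epilogue instead of materializing copies.
    a = a_fp8.to(torch.float32) * a_scale_inv
    b = b_fp8.to(torch.float32) * b_scale_inv
    return (a @ b) * out_scale

# Quantize two random matrices to FP8 (e4m3) with simple per-tensor scales.
x = torch.randn(16, 32)
w = torch.randn(32, 64)
x_scale = 448.0 / x.abs().max()  # 448 is the e4m3 max-normal value
w_scale = 448.0 / w.abs().max()
x_fp8 = (x * x_scale).to(torch.float8_e4m3fn)
w_fp8 = (w * w_scale).to(torch.float8_e4m3fn)

y = emulated_fp8_gemm(x_fp8, 1.0 / x_scale, w_fp8, 1.0 / w_scale, out_scale=0.5)
print(y.shape)  # torch.Size([16, 64])
```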
The CMake configuration failed with the following error. The same error is observed on both the *stable* and *main* branches.
```text
-- JAX support: OFF
-- Configuring done
CMake Error at...
```
According to #438, we should be able to use both BF16 and FP8 autocasts. In our specific setting, our module consists of some linear layers that are `torch.nn.Linear` and some...
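A minimal sketch of how the two autocasts can be nested for a module that mixes plain `torch.nn.Linear` with Transformer Engine layers, assuming a Hopper/Ada GPU and the public `te.fp8_autocast` / `DelayedScaling` API; the module and recipe settings here are illustrative.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# A module mixing plain torch.nn.Linear with te.Linear layers.
class MixedBlock(torch.nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.torch_linear = torch.nn.Linear(hidden, hidden)  # stays in BF16
        self.te_linear = te.Linear(hidden, hidden)           # eligible for FP8

    def forward(self, x):
        return self.te_linear(self.torch_linear(x))

model = MixedBlock(1024).cuda()
x = torch.randn(8, 1024, device="cuda")
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

# The outer BF16 autocast covers the torch.nn layers; the inner fp8_autocast
# only affects Transformer Engine modules such as te.Linear.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = model(x)
```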
Support main_grad and fuse_wgrad_accumulation
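For context, an illustrative sketch of the pattern this PR title refers to (attribute and function names are assumptions, not the PR's actual code): each parameter owns a persistent FP32 `main_grad` buffer, and the weight-gradient GEMM accumulates into it directly instead of producing a fresh `.grad` tensor.

```python
import torch

# Illustrative pattern: a BF16 parameter paired with an FP32 main_grad buffer.
weight = torch.nn.Parameter(torch.randn(64, 64, dtype=torch.bfloat16))
weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)

def accumulate_wgrad(inp: torch.Tensor, grad_out: torch.Tensor) -> None:
    # wgrad = grad_out^T @ inp, accumulated in place into the FP32 buffer,
    # which is what a fused wgrad-accumulation epilogue would do in one GEMM.
    weight.main_grad.add_(grad_out.t().float() @ inp.float())

inp = torch.randn(32, 64, dtype=torch.bfloat16)
grad_out = torch.randn(32, 64, dtype=torch.bfloat16)
accumulate_wgrad(inp, grad_out)
```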
Hi, I'm seeing higher losses using `te.Linear` over `nn.Linear` directly in transformer models such as Llama, which I assume is expected given the nature of FP8. However, I don't...
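One way to isolate whether the gap comes from FP8 itself or from the module swap is to compare the two layers outside of `fp8_autocast`, where `te.Linear` runs in BF16. A minimal sketch, assuming CUDA-resident parameters and the `weight`/`bias` attribute names of `te.Linear`:

```python
import torch
import transformer_engine.pytorch as te

torch.manual_seed(0)
hidden = 1024
x = torch.randn(16, hidden, device="cuda", dtype=torch.bfloat16)

ref = torch.nn.Linear(hidden, hidden, device="cuda", dtype=torch.bfloat16)
test = te.Linear(hidden, hidden, params_dtype=torch.bfloat16)

# Copy weights so both layers compute the same function.
with torch.no_grad():
    test.weight.copy_(ref.weight)
    test.bias.copy_(ref.bias)

# Outside fp8_autocast, te.Linear runs in BF16; a large gap here would point
# at the module swap rather than FP8 quantization.
diff = (test(x) - ref(x)).abs().max()
print(f"max abs difference without FP8: {diff.item():.3e}")
```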
In the recent PyTorch 2.2.0 release, NVFuser was deprecated in TorchScript with this [warning](https://github.com/pytorch/pytorch/blob/v2.2.0/torch/csrc/jit/python/init.cpp#L759-L762). See this [commit](https://github.com/pytorch/pytorch/commit/e6b5e0ecc609c15bfee5b383fe5c55fbdfda68ff). We are running into test failures in TransformerEngine when running the following...
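One way to keep such tests green across versions is to gate the TorchScript + NVFuser path on the PyTorch version. A minimal sketch, assuming `pytest` and the `packaging` package are available; the test name is illustrative and does not refer to an actual TE test.

```python
import pytest
import torch
from packaging import version

TORCH_GE_2_2 = version.parse(torch.__version__) >= version.parse("2.2.0")

@pytest.mark.skipif(
    TORCH_GE_2_2,
    reason="NVFuser in TorchScript is deprecated starting with PyTorch 2.2.0",
)
def test_nvfuser_jit_path():
    # Placeholder body standing in for tests that exercise the
    # TorchScript + NVFuser code path (name and body are illustrative).
    ...
```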
Setting `ub_overlap_rs_dgrad` to True in Megatron-LM raises a "Caught signal 8 (Floating point exception: integer divide by zero)" error, which was eventually traced to a problem with...
TorchDynamo has known limitations for `autograd.Function` implementations and `autograd.graph` hooks. Activation recompute utilizes *both* of those mechanisms, so this PR disables TorchDynamo on `te.distributed.checkpoint()` via the `@no_torch_dynamo()` decorator.
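A minimal sketch of how such a decorator can be built on `torch._dynamo.disable`, falling back to a no-op on PyTorch builds without Dynamo; this is a conceptual illustration, not necessarily TE's exact implementation.

```python
import torch

def no_torch_dynamo():
    """Illustrative decorator: exclude a function from TorchDynamo tracing."""
    def decorator(func):
        try:
            import torch._dynamo
            return torch._dynamo.disable(func)
        except ImportError:
            # Older PyTorch without Dynamo: nothing to disable.
            return func
    return decorator

@no_torch_dynamo()
def checkpointed_forward(fn, *args):
    # Stand-in for an activation-recompute wrapper built on autograd.Function
    # and autograd.graph hooks, which Dynamo cannot trace reliably.
    return fn(*args)
```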
Currently the number of GQA groups per partition for DotProductAttention is computed using a default value, rather than the actual value computed earlier in the initializer, causing errors when tensor...
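For illustration, the per-partition count should follow from the resolved group count and the tensor-parallel size rather than a default; a small sketch of that arithmetic (function and argument names are assumptions):

```python
def gqa_groups_per_partition(num_gqa_groups: int, tp_world_size: int) -> int:
    # The per-rank GQA group count must divide evenly across tensor-parallel
    # ranks; using a stale default here is what triggers the reported errors.
    assert num_gqa_groups % tp_world_size == 0, (
        "number of GQA groups must be divisible by the tensor-parallel size"
    )
    return num_gqa_groups // tp_world_size

# Example: 8 KV groups sharded over a tensor-parallel size of 4 -> 2 per rank.
print(gqa_groups_per_partition(8, 4))
```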