Tim Moon

227 comments by Tim Moon

Merging with approval from @ptrendx and @ksivaman.

@binxuan This implementation matches [`torch.nn.LayerNorm`](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html):

```python
# Flatten the normalized dims and apply a 2D layer norm, then restore the shape
x_2d = x.reshape(-1, prod(normalized_shape))
y_2d = layer_norm_2d(x_2d, weight.reshape(-1), bias.reshape(-1))
y = y_2d.reshape(x.size())
```
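
In other words, a minimal sketch of the equivalence, with `torch.nn.functional.layer_norm` standing in for the 2D kernel above and arbitrary example shapes:

```python
import torch
import torch.nn.functional as F
from math import prod

# Example shapes (hypothetical): normalize over the trailing (16, 32) dims
x = torch.randn(8, 4, 16, 32)
normalized_shape = (16, 32)
weight = torch.randn(normalized_shape)
bias = torch.randn(normalized_shape)

# Reference: LayerNorm applied directly over the trailing dims
y_ref = F.layer_norm(x, normalized_shape, weight, bias)

# Equivalent: flatten the normalized dims, apply a 1D LayerNorm, reshape back
n = prod(normalized_shape)
x_2d = x.reshape(-1, n)
y_2d = F.layer_norm(x_2d, (n,), weight.reshape(-1), bias.reshape(-1))
y = y_2d.reshape(x.size())

torch.testing.assert_close(y, y_ref)
```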

The easiest approach is to use native PyTorch FP8 dtypes:

```python
import torch

x = torch.randn(128, device="cuda", dtype=torch.float32)
y = x.to(dtype=torch.float8_e4m3fn)  # or torch.float8_e5m2
```

You could also use [`transformer_engine.pytorch.Float8Tensor`](https://github.com/NVIDIA/TransformerEngine/blob/744624d004f4514ffbaa90ac83e214311c86c607/transformer_engine/pytorch/float8_tensor.py#L329) / [`float8_experimental.Float8Tensor`](https://github.com/pytorch-labs/float8_experimental/blob/57136bdaa4d181d3bb10e54537ad119551b55c11/float8_experimental/float8_tensor.py#L173):...
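
For the native-dtype route, a minimal round-trip sketch (note that no per-tensor scaling is applied here, unlike Transformer Engine's FP8 recipes):

```python
import torch

# Cast down to FP8, cast back up, and inspect the quantization error
x = torch.randn(128, device="cuda", dtype=torch.float32)
x_fp8 = x.to(dtype=torch.float8_e4m3fn)   # or torch.float8_e5m2
x_back = x_fp8.to(dtype=torch.float32)    # upcast for inspection
print((x - x_back).abs().max())           # worst-case round-trip error
```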

If you just want the performance benefit of FP8 matmuls, I recommend using Transformer Engine modules (like [`te.Linear`](https://github.com/NVIDIA/TransformerEngine/blob/744624d004f4514ffbaa90ac83e214311c86c607/transformer_engine/pytorch/module/linear.py#L646)) in your model (see this [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html)). They will internally handle the...
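
For reference, a minimal sketch along the lines of the linked FP8 tutorial (dimensions are arbitrary; the recipe arguments follow the tutorial's example):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Arbitrary example dimensions
in_features = 768
out_features = 3072
batch_size = 4096

# Replace a plain nn.Linear with the Transformer Engine module
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(batch_size, in_features, device="cuda")

# FP8 recipe with delayed scaling (all arguments are optional)
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

# Run the forward pass with FP8 autocasting; TE handles casts and scaling internally
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

loss = out.sum()
loss.backward()
```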

It depends on the type of communication. For FP8 with delayed scaling:

- Tensor-parallel communication: all-gather in FP8 (see [`_all_gather_fp8`](https://github.com/NVIDIA/TransformerEngine/blob/a7eeb28bd917a647abf7854fa22239b8ee85c2af/transformer_engine/pytorch/distributed.py#L844), sketched below), reduce-scatter in BF16 (see [`reduce_scatter_along_first_dim`](https://github.com/NVIDIA/TransformerEngine/blob/a7eeb28bd917a647abf7854fa22239b8ee85c2af/transformer_engine/pytorch/distributed.py#L821))
- PyTorch FSDP: param all-gather...
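
For the first bullet, a hypothetical sketch of the idea (not the actual `_all_gather_fp8` implementation), assuming the FP8 scale is already synchronized across ranks:

```python
import torch
import torch.distributed as dist

# Gather the raw 1-byte FP8 payload instead of upcasting to BF16 before the
# collective; the tensor's scale is shared, so only the data needs gathering.
def all_gather_fp8(x_fp8: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # View the FP8 data as raw bytes so the collective moves 1 byte/element
    payload = x_fp8.contiguous().view(torch.uint8)
    out = torch.empty(
        (world_size * payload.size(0), *payload.size()[1:]),
        dtype=torch.uint8,
        device=payload.device,
    )
    dist.all_gather_into_tensor(out, payload, group=group)
    return out.view(x_fp8.dtype)
```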

Do you need both DDP and FP8 params for your use-case? We haven't considered this combination so far since optimizing FP8 params tends to have poor convergence. There are a...