[FP8][H100] training performance when te layers are mixed with torch.nn layers
Hi, I am training a model which has conv layers in addition to attention and linear layers. Since Transformer Engine does not provide layer modules for convolutions, I added torch.nn layers for those. Does this adversely affect FP8 training performance on H100?
- Are there any best practices for these scenarios?
- How does FP8 precision work for torch.nn layers? As I understand it, FP8 is not supported for torch.nn layers; do they fall back to the default dtype? Does this mix and match affect the loss?
You can freely mix and match TE layers and regular PyTorch layers.
When it comes to training performance, it generally depends on the size of the model: since the main speedup from FP8 comes from the Linear layers, the bigger they are, the more speedup you should expect. The general rule is that the higher-level the TE API you use, the better the performance, since it lets us fuse e.g. the casts to and from FP8 with other operations, like LayerNorm. This performance difference is mostly visible in smaller models, where the ratio of GEMM to non-GEMM time is smaller and so overheads matter more.
FP8 is not supported for regular torch.nn modules, so they work as usual: TE modules output tensors in the default precision (FP32, or another dtype if Automatic Mixed Precision or an explicit cast is used), and those values are passed to the other modules.
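To illustrate the handoff described above, here is a minimal sketch (the module and tensor names are mine, not from this thread). The `te.Linear` and `fp8_autocast` lines assume Transformer Engine is installed and an FP8-capable GPU (e.g. H100) is available; on a plain CPU setup the sketch falls back to `nn.Linear` so the dtype flow is still visible. Either way, the `nn.Conv2d` simply consumes whatever default-precision tensor it is given.

```python
# Sketch of mixing TE and torch.nn modules. te.Linear runs in FP8 inside
# fp8_autocast; nn.Conv2d stays in the default precision and consumes the
# default-precision tensors that the TE module outputs.
import torch
import torch.nn as nn

try:
    import transformer_engine.pytorch as te  # needs an FP8-capable GPU
    HAS_TE = True
except ImportError:
    HAS_TE = False

class MixedBlock(nn.Module):
    def __init__(self, channels: int = 8, hidden: int = 64):
        super().__init__()
        # conv has no TE equivalent, so it is a plain torch.nn module
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # the linear layer is FP8-capable when TE is available
        linear_cls = te.Linear if HAS_TE else nn.Linear
        self.proj = linear_cls(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)            # default precision (e.g. FP32)
        y = y.flatten(start_dim=1)  # (N, channels * H * W) == (N, hidden)
        return self.proj(y)         # FP8 GEMM when inside fp8_autocast

block = MixedBlock()
x = torch.randn(2, 8, 2, 4)  # 8 * 2 * 4 == 64 features after flatten
if HAS_TE:
    # FP8 applies only inside the autocast region; the output tensor
    # handed to downstream modules is still in the default precision.
    with te.fp8_autocast(enabled=True):
        out = block(x)
else:
    out = block(x)  # CPU fallback: everything runs in FP32
print(out.shape, out.dtype)  # torch.Size([2, 64]) torch.float32
```

Note that the FP8 scaling only affects the internal GEMMs of the TE modules; the tensors crossing module boundaries stay in the default precision, which is why the loss behaves the same way as with any mixed-precision setup.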
@ptrendx is there a plan to support nn.Conv2d layers (layers apart from linear) in FP8?