[PyTorch] Sequential fuser
Currently, Transformer Engine exposes fused operations with custom modules like LayerNormLinear. These are highly tuned for certain workloads (especially GPT), but are not easy to generalize to other models. This approach is especially cumbersome when the forward and backward passes have different fusion opportunities (e.g. forward GEMM+bias+gelu and backward dgelu+dbias+cast+transpose).
This PR adds a new API for specifying Transformer Engine models. Instead of using large compound modules (e.g. LayerNormLinear), users build up a Sequential module out of small FusableOperations (e.g. LayerNorm, Linear). The Sequential module (with an API similar to torch.nn.Sequential) will internally attempt to fuse operations together, possibly differently in the forward and backward passes.
Some of the more important components:
- te.fuser.ops.FusableOperation: A neural network operation that can be processed by the fuser. It has forward and backward functions similar to torch.autograd.Function.
- te.fuser.ops.UnfusedOperation: A minimal FusableOperation. Its forward and backward functions must be implemented, and it holds the model state and parameters.
- te.fuser.ops.FusedOperation: A FusableOperation that is interchangeable with multiple UnfusedOperations. If it implements a forward or backward function, they must save the same context as the UnfusedOperations.
- te.fuser.Sequential: A container module with an API similar to torch.nn.Sequential.
- te.fuser.Fuser: A helper class that manages autograd, performs the operation fusions, and keeps track of corresponding UnfusedOperations and FusedOperations.
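To make the relationship between these classes concrete, here is a minimal pure-Python sketch of the pattern. The class names mirror the PR's, but the bodies (scalar "ops", a greedy fusion pass over adjacent Linear+Bias pairs) are illustrative assumptions, not the actual Transformer Engine implementation:

```python
# Illustrative sketch only: names follow the PR, but the logic is a toy
# stand-in for the real fuser (no autograd, scalars instead of tensors).

class FusableOperation:
    """Base class: an operation the fuser can process."""
    def forward(self, x):
        raise NotImplementedError

class UnfusedOperation(FusableOperation):
    """A minimal op that owns its own state and parameters."""

class Linear(UnfusedOperation):
    def __init__(self, weight):
        self.weight = weight
    def forward(self, x):
        return self.weight * x  # scalar stand-in for a GEMM

class Bias(UnfusedOperation):
    def __init__(self, bias):
        self.bias = bias
    def forward(self, x):
        return x + self.bias

class FusedLinearBias(FusableOperation):
    """Interchangeable with a (Linear, Bias) pair: one 'kernel' call."""
    def __init__(self, linear, bias):
        self.linear, self.bias = linear, bias
    def forward(self, x):
        return self.linear.weight * x + self.bias.bias

class Sequential:
    """Container that fuses adjacent Linear+Bias pairs before running."""
    def __init__(self, *ops):
        self.ops = list(ops)

    def _fuse(self):
        # Greedy left-to-right pass: replace Linear followed by Bias
        # with the fused equivalent; leave everything else as-is.
        fused, i = [], 0
        while i < len(self.ops):
            if (i + 1 < len(self.ops)
                    and isinstance(self.ops[i], Linear)
                    and isinstance(self.ops[i + 1], Bias)):
                fused.append(FusedLinearBias(self.ops[i], self.ops[i + 1]))
                i += 2
            else:
                fused.append(self.ops[i])
                i += 1
        return fused

    def __call__(self, x):
        for op in self._fuse():
            x = op.forward(x)
        return x

model = Sequential(Linear(2.0), Bias(1.0), Linear(3.0))
print(model(4.0))  # (2*4 + 1) * 3 = 27.0
```

The real Fuser would additionally manage autograd context and could pick different fusions for the forward and backward passes; this sketch only shows the structural idea of swapping a run of UnfusedOperations for an interchangeable FusedOperation.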
As a proof-of-concept, I've been able to fuse Linear and Bias operations, on a single GPU and with tensor parallelism. These modules have been implemented to support Float8Tensor, which simplifies the implementation and will be important for future work with e.g. FP8 attention. I've also added single-GPU and multi-GPU tests.
This work is heavily influenced by https://github.com/NVIDIA/TransformerEngine/pull/377 from @janekb04.
Remaining tasks:
- [x] FP8 scaling factor updates
- [ ] Checkpointing
- [x] Documentation
Future work:
- [ ] Operations: layer norm, activations, attention
- [ ] Fusions
- [ ] Possibly reimplementing the existing modules using this infrastructure
/te-ci pytorch
Merging with approval from @ksivaman, @sudhakarsingh27, @ptrendx. This feature is still experimental and incomplete.