[PyTorch] Sequential fuser

Open timmoon10 opened this issue 1 year ago • 1 comments

Currently, Transformer Engine exposes fused operations with custom modules like LayerNormLinear. These are highly tuned for certain workloads (especially GPT), but are not easy to generalize to other models. This approach is especially cumbersome when the forward and backward passes have different fusion opportunities (e.g. forward GEMM+bias+gelu and backward dgelu+dbias+cast+transpose).

This PR adds a new API for specifying Transformer Engine models. Instead of using large compound modules (e.g. LayerNormLinear), users can build up a Sequential module out of small FusableOperations (e.g. LayerNorm, Linear). The Sequential module (with an API similar to torch.nn.Sequential) will internally attempt to fuse operations together, possibly choosing different fusions in the forward and backward passes.

Some of the more important components:

  • te.fuser.ops.FusableOperation: A neural network operation that can be processed by the fuser. It has forward and backward functions similar to those of torch.autograd.Function.
  • te.fuser.ops.UnfusedOperation: A minimal FusableOperation. It must implement its forward and backward functions, and it should hold the model state and parameters.
  • te.fuser.ops.FusedOperation: A FusableOperation that is interchangeable with multiple UnfusedOperations. If it implements a forward or backward function, that function must save the same context as the UnfusedOperations it replaces.
  • te.fuser.Sequential: A container module with an API similar to torch.nn.Sequential.
  • te.fuser.Fuser: A helper class that manages autograd, performs the operation fusions, and keeps track of corresponding UnfusedOperations and FusedOperations.
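
To make the relationship between these components concrete, here is a minimal pure-Python sketch of the idea. It is not the Transformer Engine implementation: the class names mirror the list above, but every method signature and body is an assumption made for illustration. It works on scalar floats instead of tensors, `Scale` is a hypothetical stand-in for Linear, and `FusedScaleBias` is a hypothetical fusion that exists only in the forward pass, so the backward pass falls back to the unfused operations.

```python
class FusableOperation:
    """Anything the fuser can schedule (hypothetical stand-in)."""

    def forward(self, ctxs, x):
        raise NotImplementedError

    def backward(self, ctxs, dy):
        raise NotImplementedError


class UnfusedOperation(FusableOperation):
    """Minimal operation: owns parameters and implements both passes."""


class Scale(UnfusedOperation):
    """Scalar stand-in for Linear: y = w * x."""

    def __init__(self, w):
        self.w, self.grad_w = w, 0.0

    def forward(self, ctxs, x):
        ctxs[0]["x"] = x  # saved for the weight gradient
        return self.w * x

    def backward(self, ctxs, dy):
        self.grad_w += dy * ctxs[0]["x"]
        return dy * self.w


class Bias(UnfusedOperation):
    """y = x + b."""

    def __init__(self, b):
        self.b, self.grad_b = b, 0.0

    def forward(self, ctxs, x):
        return x + self.b

    def backward(self, ctxs, dy):
        self.grad_b += dy
        return dy


class FusedScaleBias(FusableOperation):
    """Forward-only fusion of Scale followed by Bias (assumed for this sketch)."""

    def __init__(self, scale, bias):
        self.scale, self.bias = scale, bias

    def forward(self, ctxs, x):
        ctxs[0]["x"] = x  # same context the unfused Scale would save
        return self.scale.w * x + self.bias.b


class Fuser:
    """Builds forward/backward plans as (op, first_index, count) entries."""

    def __init__(self, ops):
        self.ops = list(ops)
        self.forward_plan = []
        i = 0
        while i < len(self.ops):
            pair = self.ops[i:i + 2]
            if len(pair) == 2 and isinstance(pair[0], Scale) and isinstance(pair[1], Bias):
                self.forward_plan.append((FusedScaleBias(*pair), i, 2))
                i += 2
            else:
                self.forward_plan.append((self.ops[i], i, 1))
                i += 1
        # No backward fusions in this sketch: one unfused op per step, reversed.
        self.backward_plan = [(op, i, 1) for i, op in enumerate(self.ops)][::-1]

    def forward(self, x):
        self.ctxs = [{} for _ in self.ops]  # one context per unfused op
        for op, start, count in self.forward_plan:
            x = op.forward(self.ctxs[start:start + count], x)
        return x

    def backward(self, dy):
        for op, start, count in self.backward_plan:
            dy = op.backward(self.ctxs[start:start + count], dy)
        return dy


class Sequential:
    """Container that delegates both passes to a Fuser."""

    def __init__(self, *ops):
        self.fuser = Fuser(ops)

    def __call__(self, x):
        return self.fuser.forward(x)

    def backward(self, dy):
        return self.fuser.backward(dy)


scale, bias = Scale(3.0), Bias(1.0)
model = Sequential(scale, bias)
y = model(2.0)            # fused forward: 3.0 * 2.0 + 1.0
dx = model.backward(1.0)  # unfused backward: Bias first, then Scale
```

Because the fused forward saves exactly the context that the unfused `Scale` would have saved, the unfused backward can run unchanged. That is the reason for the requirement that a FusedOperation save the same context as the UnfusedOperations it replaces.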

As a proof-of-concept, I've been able to fuse Linear and Bias operations, on a single GPU and with tensor parallelism. These modules have been implemented to support Float8Tensor, which simplifies the implementation and will be important for future work with e.g. FP8 attention. I've also added single-GPU and multi-GPU tests.

This work is heavily influenced by https://github.com/NVIDIA/TransformerEngine/pull/377 from @janekb04.

Remaining tasks:

  • [x] FP8 scaling factor updates
  • [ ] Checkpointing
  • [x] Documentation

Future work:

  • [ ] Operations: layer norm, activations, attention
  • [ ] Fusions
  • [ ] Possibly reimplementing the existing modules using this infrastructure

timmoon10 avatar Mar 09 '24 03:03 timmoon10

/te-ci pytorch

timmoon10 avatar May 13 '24 17:05 timmoon10

Merging with approval from @ksivaman, @sudhakarsingh27, @ptrendx. This feature is still experimental and incomplete.

timmoon10 avatar Jul 09 '24 22:07 timmoon10