[PyTorch] Minor optimizations to reduce CPU overheads in modules

Open timmoon10 opened this issue 1 year ago • 5 comments

Description

We have observed that TE modules experience non-trivial CPU overhead, which often becomes a performance bottleneck in the forward pass of small models. For example, measuring the CPU runtime for Megatron-core modules with BF16 compute and TP=1:

Unfortunately, this overhead is distributed throughout the forward pass rather than concentrated in one place. Many basic PyTorch operations, e.g. getting attributes from torch.Tensor, each cost on the order of 1 us, so even the basic checks needed to handle all of our advanced features add up to something non-trivial.
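To get a feel for how per-operation costs accumulate, a small micro-benchmark can help. The following is a minimal sketch that runs without PyTorch: `FakeTensor`, `per_call_us`, `check_repeated`, and `check_cached` are hypothetical names made up for illustration, and the property accessors only stand in for the much more expensive attribute lookups on a real `torch.Tensor`. Absolute timings will differ by machine.

```python
import timeit


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, shape, dtype):
        self._shape = shape
        self._dtype = dtype

    @property
    def shape(self):
        return self._shape

    @property
    def dtype(self):
        return self._dtype


def per_call_us(fn, number=100_000):
    """Average wall-clock cost of one call to fn, in microseconds."""
    total_seconds = timeit.timeit(fn, number=number)
    return total_seconds / number * 1e6


x = FakeTensor((16, 1024), "bfloat16")


def check_repeated():
    # Reads x.dtype twice and x.shape twice (four attribute lookups).
    return (
        x.dtype == "bfloat16"
        and len(x.shape) == 2
        and x.shape[-1] % 16 == 0
        and x.dtype != "float32"
    )


def check_cached():
    # Reads each attribute once and reuses the local variables.
    shape, dtype = x.shape, x.dtype
    return (
        dtype == "bfloat16"
        and len(shape) == 2
        and shape[-1] % 16 == 0
        and dtype != "float32"
    )


if __name__ == "__main__":
    print(f"repeated lookups: {per_call_us(check_repeated):.3f} us per call")
    print(f"cached lookups:   {per_call_us(check_cached):.3f} us per call")
```

Both functions compute the same result; the only difference is how many times the attributes are read. With real tensors the gap per lookup is larger, which is why trimming redundant accesses is worthwhile on a hot path.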

This PR makes a few minor optimizations:

  • Avoid importing from te.pytorch.cpu_offload in every forward pass
  • Memoize NCCL process group properties
  • Avoid custom logic in torch.nn.Module.__setattr__ when possible
  • Avoid custom logic for accessing params in torch.nn.Module when possible
  • Avoid accessing tensor attrs more than necessary
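As an illustration of the memoization item above, here is a minimal sketch of caching a process group property so it is queried only once per group. It is not the code in this PR: `FakeProcessGroup` and `get_world_size_cached` are hypothetical names, and the fake group stands in for a real `torch.distributed` process group so the example runs without PyTorch or NCCL.

```python
import functools


class FakeProcessGroup:
    """Hypothetical stand-in for a torch.distributed process group."""

    def __init__(self, world_size):
        self._world_size = world_size
        self.queries = 0  # counts how often the "expensive" query runs

    def size(self):
        self.queries += 1
        return self._world_size


@functools.lru_cache(maxsize=None)
def get_world_size_cached(group):
    """Query the group size once per group object and reuse the result.

    Assumption: the size of a process group does not change after creation.
    The cache keeps a reference to each group it has seen.
    """
    return group.size()


if __name__ == "__main__":
    tp_group = FakeProcessGroup(world_size=8)
    for _ in range(1000):
        get_world_size_cached(tp_group)
    print(tp_group.queries)  # the underlying query ran once
```

The cache is keyed on the group object itself, so distinct groups get distinct entries. This is only valid for properties that are fixed for the lifetime of the group.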

I see a 1.22x speedup, with 115 us per forward pass.

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [ ] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [x] Code refactor

Changes

  • Avoid importing from te.pytorch.cpu_offload in every forward pass
  • Memoize NCCL process group properties
  • Avoid custom logic in torch.nn.Module.__setattr__ when possible
  • Avoid custom logic for accessing params in torch.nn.Module when possible
  • Avoid accessing tensor attrs more than necessary
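The `__setattr__` item can be sketched in plain Python. This is a hedged illustration under stated assumptions, not the implementation in this PR: `SlowBase` is a hypothetical stand-in for a base class whose `__setattr__` does extra work on every assignment (as `torch.nn.Module.__setattr__` does when it checks for parameters, buffers, and submodules), and `Layer` is a made-up subclass. For attributes known to be plain Python values, calling `object.__setattr__` skips that extra work.

```python
class SlowBase:
    """Hypothetical stand-in for a base class with a costly __setattr__."""

    setattr_calls = 0  # counts how often the custom logic runs

    def __setattr__(self, name, value):
        SlowBase.setattr_calls += 1
        # A real implementation would inspect the value's type here.
        object.__setattr__(self, name, value)


class Layer(SlowBase):
    def __init__(self):
        # One-time setup goes through the custom __setattr__ as usual.
        self.weight = [2.0]

    def forward(self, x):
        # Per-call bookkeeping with a plain Python value: bypass the
        # custom logic, since the value needs no special registration.
        object.__setattr__(self, "last_batch_size", len(x))
        return [v * self.weight[0] for v in x]


if __name__ == "__main__":
    layer = Layer()
    before = SlowBase.setattr_calls
    layer.forward([1.0, 2.0, 3.0])
    print(SlowBase.setattr_calls - before)  # custom logic was not triggered
```

The bypass is only safe for values that the base class does not need to track. Anything that must be registered (in PyTorch: parameters, buffers, submodules) should still go through the normal assignment path.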

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [ ] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [ ] New and existing unit tests pass locally with my changes

timmoon10 • Sep 18 '24 23:09

/te-ci pytorch

timmoon10 • Sep 25 '24 22:09

/te-ci pytorch

timmoon10 • Sep 27 '24 00:09

/te-ci pytorch

timmoon10 • Sep 27 '24 23:09

/te-ci pytorch

timmoon10 • Oct 01 '24 20:10

/te-ci pytorch

timmoon10 • Oct 02 '24 23:10