TransformerEngine
[PyTorch] Improve CP P2P efficiency
Description
The original implementation keeps every attention layer's full set of K, V chunks resident for the duration of the ring pass.
Instead, we can discard each K, V chunk as soon as it has been consumed, so that each GPU holds only the current chunk and the next (in-flight) chunk.
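As a rough illustration of the buffering scheme, here is a pure-Python sketch with no real P2P communication; the function and chunk names are hypothetical, not TransformerEngine APIs. Each step "receives" the next chunk and frees the one just used, so residency peaks at two chunks instead of `world_size`:

```python
def max_resident_chunks(world_size):
    """Simulate one rank's KV buffers over a full ring pass.

    Returns the peak number of KV chunks held at once. With the
    discard-after-use scheme the peak is 2 (current + next) regardless
    of world_size; keeping every chunk would peak at world_size.
    """
    resident = ["kv_step0"]  # the rank's own K, V chunk
    peak = len(resident)
    for step in range(1, world_size):
        # Stand-in for the async P2P recv of the neighbor's chunk,
        # overlapped with attention compute on the current chunk.
        resident.append(f"kv_step{step}")
        peak = max(peak, len(resident))
        # Free the chunk we just consumed; only the new one stays.
        resident.pop(0)
    return peak
```

For example, with 8 context-parallel ranks this scheme holds at most 2 chunks at any time, versus 8 if every received chunk were retained.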
Memory profiling results from the toy unit test
Original:
This MR:
One can see that used K, V chunks are released immediately.
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [x] Code refactor