[Draft][PyTorch] Add context parallel support for packed dataset in THD format
What does this PR do?
This PR adds context parallel support for packed datasets in the THD format in NeMo, in response to this TE PR: https://github.com/NVIDIA/TransformerEngine/pull/641. Currently, the TE PR requires that each individual sequence length be divisible by (2 * context_parallel_size); for example, with context_parallel_size = 2, every sequence in the pack must have a length divisible by 4.
Changes
- Add support for splitting a packed dataset across CP ranks in a load-balanced way (see the sketch after this list)
- Add the necessary padding to the dataset during the packing stage to ensure each individual sequence length is a multiple of 2 * cp_size
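As a rough illustration of both changes, here is a minimal sketch rather than the PR's actual code: `pad_seq_lengths` and `get_cp_rank_slices` are hypothetical helper names, and the split shown is the standard context-parallel load-balancing scheme of cutting each sequence into 2 * cp_size chunks and assigning rank r the r-th and (2 * cp_size - 1 - r)-th chunks, so causal-attention work is roughly even across ranks.

```python
def pad_seq_lengths(seq_lens: list[int], cp_size: int) -> list[int]:
    """Round each sequence length in the pack up to the next
    multiple of 2 * cp_size (the TE requirement)."""
    multiple = 2 * cp_size
    return [((l + multiple - 1) // multiple) * multiple for l in seq_lens]

def get_cp_rank_slices(seq_len: int, cp_size: int, cp_rank: int) -> list[slice]:
    """Load-balanced split of one (already padded) sequence: cut it into
    2 * cp_size equal chunks and give rank r the r-th chunk plus the
    (2 * cp_size - 1 - r)-th chunk, pairing an early chunk (cheap under a
    causal mask) with a late one (expensive) on each rank."""
    chunk = seq_len // (2 * cp_size)
    first = slice(cp_rank * chunk, (cp_rank + 1) * chunk)
    second = slice((2 * cp_size - 1 - cp_rank) * chunk,
                   (2 * cp_size - cp_rank) * chunk)
    return [first, second]

# With cp_size = 2, lengths are padded to multiples of 4:
print(pad_seq_lengths([5, 8, 13], cp_size=2))       # [8, 8, 16]
# Rank 0 of a padded length-8 sequence gets chunks 0 and 3:
print(get_cp_rank_slices(8, cp_size=2, cp_rank=0))  # [slice(0, 2), slice(6, 8)]
```

In a THD-format packed batch, this per-sequence slicing would be applied to each sequence segment independently, using the packed dataset's cumulative sequence lengths to locate segment boundaries.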
PR Type:
- [x] New Feature
- [ ] Bugfix
- [ ] Documentation