Enable DoMINO parallelization via ShardTensor
Description
This PR is the initial version of domain parallelism for DoMINO. Currently, the forward pass (inference) is supported, though some numerical instabilities with Conv3d at certain shapes still need to be tracked down.
This PR also includes several other pending parallelization operations; many, but not all, are required for DoMINO. These include:
- Sequence-parallel attention via ring attention
- Normalization layers (group norm, layer norm via DTensor)
- ConvTranspose
- Upsampling via `torch.nn.functional.interpolate`
- MaxPooling and AvgPooling
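Ring attention is the least obvious item in the list above, so here is a minimal single-process sketch in NumPy. It is not the implementation in this PR. It assumes non-causal attention and one query/key/value shard per simulated rank. Each rank keeps its query shard fixed while key/value blocks rotate around the ring, and it folds each incoming block into an online (log-sum-exp) softmax, so no rank ever materializes the full K or V. `ring_attention` and `full_attention` are hypothetical names used only for illustration.

```python
import numpy as np


def full_attention(q, k, v):
    """Reference single-device softmax attention (non-causal)."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def ring_attention(q_shards, k_shards, v_shards):
    """Simulate ring attention on one process.

    The ring is simulated with a loop: at step `s`, rank `r` holds the K/V
    block that originated on rank (r - s) % world, as if every rank had
    forwarded its current block to rank r + 1 after each step.
    """
    world = len(q_shards)
    scale = 1.0 / np.sqrt(q_shards[0].shape[-1])
    outputs = []
    for r in range(world):
        q = q_shards[r]
        m = np.full((q.shape[0], 1), -np.inf)                # running row max
        l = np.zeros((q.shape[0], 1))                         # running softmax denominator
        acc = np.zeros((q.shape[0], v_shards[0].shape[-1]))   # running numerator
        for step in range(world):
            src = (r - step) % world  # block received at this ring step
            s = q @ k_shards[src].T * scale
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            correction = np.exp(m - m_new)  # rescale old partial sums to the new max
            p = np.exp(s - m_new)
            l = l * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ v_shards[src]
            m = m_new
        outputs.append(acc / l)
    return np.concatenate(outputs, axis=0)
```

Usage: split a sequence into four shards and check that the result matches single-device attention.

```python
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
split = lambda x: np.split(x, 4, axis=0)
assert np.allclose(ring_attention(split(q), split(k), split(v)), full_attention(q, k, v))
```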
The halo-passing code (and the new ring algorithm) has also been reorganized, which should make the halo algorithm more readable and maintainable.
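For readers unfamiliar with halo passing, the following is a minimal single-process sketch of the idea, not the reorganized code in this PR. It assumes a 1D signal split along its only axis, zero "same" padding, an odd kernel length of at least 3, and halos no wider than a shard. Each shard is padded with boundary values from its neighbours (zeros at the global edges), so a local "valid" convolution reproduces the corresponding slice of the global result. `halo_exchange` and `sharded_conv1d` are hypothetical names; in a real run each shard lives on a different rank and the halo slices arrive via point-to-point communication.

```python
import numpy as np


def halo_exchange(shards, halo):
    """Pad each shard with `halo` values copied from its neighbours.

    Shards at the global boundary get zeros instead, matching the zero
    padding of a global 'same' convolution.
    """
    assert halo >= 1 and all(len(s) >= halo for s in shards)
    padded = []
    for i, local in enumerate(shards):
        left = shards[i - 1][-halo:] if i > 0 else np.zeros(halo)
        right = shards[i + 1][:halo] if i < len(shards) - 1 else np.zeros(halo)
        padded.append(np.concatenate([left, local, right]))
    return padded


def sharded_conv1d(shards, kernel):
    """Convolve each halo-padded shard locally; no global gather is needed."""
    assert len(kernel) % 2 == 1, "sketch assumes an odd kernel length"
    halo = len(kernel) // 2
    return [np.convolve(p, kernel, mode="valid") for p in halo_exchange(shards, halo)]
```

Usage: the concatenated per-shard outputs match the single-device convolution.

```python
x = np.arange(12, dtype=float)
kernel = np.array([0.25, 0.5, 0.25])
out = np.concatenate(sharded_conv1d(np.split(x, 3), kernel))
assert np.allclose(out, np.convolve(x, kernel, mode="same"))
```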
There are some remaining optimizations to be done:
- The ring message passing should be made into an async op
- The convolution and ball query wrappers should not use "infer" when creating ShardTensors, so that creation does not block
Testing all of the distributed algorithms would add several hundred distributed tests; with the current fork-every-test setup, that is not feasible in CI. Those tests currently live in a separate repository that validates the numerical correctness of external operations. Operations implemented in PhysicsNeMo itself (BallQuery) are tested here.
(After the 25.03 release and the renaming, I had to merge many files manually, so some changed files are unrelated to this PR.)
Checklist
- [x] I am familiar with the Contributing Guidelines.
- [x] New or existing tests cover these changes.
- [ ] The documentation is up to date with these changes.
- [ ] The CHANGELOG.md is up to date with these changes.
- [ ] An issue is linked to this pull request.
Dependencies
External testing repository, to be opened or possibly merged.
/blossom-ci