torchrec
Support non-contiguous sharding
Summary: This is more of an RFC diff; it depends on https://github.com/pytorch/torchrec/pull/1837 (see the diagram there).
The high-level idea is to disaggregate the placement of the dense and sparse towers in distributed training of recommendation models.
Let's say we have 2 DGX hosts with 16 GPUs in total.
Today:
We flat-shard DMP/FSDP across all 16 GPUs, so the A2A/AG/RS collectives run at a world size of 16. This poses a scalability challenge: the model quickly becomes communication-bound above 128 GPUs.
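For context, the baseline corresponds to a single flat process group, roughly like the sketch below (the tensor shapes and the embedding exchange are illustrative assumptions, not code from this diff):

```python
import torch
import torch.distributed as dist

# Baseline: one flat process group over all 16 ranks, so the sparse
# all-to-all and the dense all-gather/reduce-scatter span world size 16.
dist.init_process_group(backend="nccl")            # WORLD_SIZE=16 from the launcher
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % 8)                    # 8 GPUs per DGX host

# Illustrative embedding exchange: every rank talks to all 16 ranks.
send = torch.randn(world_size * 4, 16, device="cuda")
recv = torch.empty_like(send)
dist.all_to_all_single(recv, send, group=dist.group.WORLD)
```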
After:
We allow logically segregated placement. E.g., for the same 16 GPUs we can do a 1:3 split and place the sparse tower on 4 GPUs and the dense tower on the other 12.
To leverage the intra-host NVSwitch connectivity, we can use the following placement (one block per host, with the sparse rank to the left of the "|" and the dense ranks to the right):
[
  [0 | 1 2 3],
  [4 | 5 6 7],
]
[
  [8 | 9 10 11],
  [12 | 13 14 15],
]
That way, the sparse and dense world sizes become 4 and 12 respectively, and the two groups communicate with each other via P2P.
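A minimal sketch of how the two logical groups could be wired up with torch.distributed, assuming the layout above (sparse ranks [0, 4, 8, 12], dense ranks elsewhere); the tensor shapes, collectives, and cross-group hand-off are illustrative, not the actual planner/sharder changes in this diff:

```python
import torch
import torch.distributed as dist

# Rank-to-tower mapping taken from the placement diagram above.
SPARSE_RANKS = [0, 4, 8, 12]                                   # world size 4
DENSE_RANKS = [r for r in range(16) if r not in SPARSE_RANKS]  # world size 12

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % 8)                    # 8 GPUs per DGX host

# new_group() must be called by every rank for every group,
# even by ranks that are not members of that group.
sparse_pg = dist.new_group(ranks=SPARSE_RANKS)
dense_pg = dist.new_group(ranks=DENSE_RANKS)

if rank in SPARSE_RANKS:
    # Sparse A2A now runs at world size 4 instead of 16.
    send = torch.randn(len(SPARSE_RANKS) * 4, 16, device="cuda")
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=sparse_pg)
else:
    # Dense AG/RS now run at world size 12 instead of 16.
    grads = torch.randn(len(DENSE_RANKS) * 8, device="cuda")
    shard = torch.empty(8, device="cuda")
    dist.reduce_scatter_tensor(shard, grads, group=dense_pg)

# The hand-off between the sparse and dense groups (e.g. looked-up
# embeddings from rank 0 to ranks 1-3) would use point-to-point
# send/recv rather than a collective over the full 16-rank world.
```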
Differential Revision: D55577262
This pull request was exported from Phabricator. Differential Revision: D55577262