Implement per table sharding in DI sharder
Summary:
Overview
https://fburl.com/gdoc/efm9by7o
- Generate table-wise sharding for LocationType.HBM and row-wise sharding for LocationType.DI_HOST
  - So we can directly edit the DIShardingPass with extra arguments
  - For table-wise sharding, we split the tables evenly across the available GPUs based on available memory (_split_n_list_min_difference)
- Generate per-parameter sharding and apply construct_module_sharding_plan to get a customized sharding plan per table
  - Fix Torchrec sharding with RW/TW hybrid
- Update tag_rules for different RW / TS TBEs
  - di_sharding_spec can be used to tell which TBEs are DI (remote_ro, inter) and which are Torchrec (remote)
- Device allocation for Torchrec sharding needs to be remapped
  - When di world size is 3, TBE=0-2 can be DI sharding TBEs, GPU TBE=4-5.
  - So we need to generate a device_map for generate_app_graph and device allocation str
- Also need to keep an eye on _unwrap_kjt, as not all of them should go to remote. For the ones going to DI, it is better to leave them in local. Net split into remote, local, remote_ro, inter; app_graph: remote (_unwrap_kjt -> device), local (_unwrap_kjt -> None). See P1190912023, P1190914045
  - Choose to change tag_rules with _unwrap_kjt_4: remote for now
  - In the long run, we can put all _unwrap_kjt in local once RO can do rebatching for tensors at the boundary (YazhiGao is working on this); then we can simply use device_map to determine the target device.
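The first and fourth points of the overview can be illustrated with a small, self-contained sketch. It is not the code in this diff: the `LocationType` enum below is a stand-in, `split_tables_balanced` is a simple greedy heuristic standing in for `_split_n_list_min_difference` (whose actual algorithm is not shown here), it balances by table size rather than by per-GPU free memory, and the direction of `build_device_map` (GPU TBE index to zero-based Torchrec rank) is an assumption.

```python
from enum import Enum
from typing import Dict, List


class LocationType(Enum):
    # Stand-in for the LocationType named in the summary, not the real class.
    HBM = "hbm"
    DI_HOST = "di_host"


def sharding_type_for(location: LocationType) -> str:
    """HBM tables are sharded table-wise, DI_HOST tables row-wise."""
    return "table_wise" if location is LocationType.HBM else "row_wise"


def split_tables_balanced(table_sizes: Dict[str, int], num_gpus: int) -> List[List[str]]:
    """Greedy balance: place each table, largest first, on the least-loaded GPU.

    Hypothetical stand-in for _split_n_list_min_difference.
    """
    if num_gpus <= 0:
        raise ValueError("num_gpus must be positive")
    buckets: List[List[str]] = [[] for _ in range(num_gpus)]
    loads = [0] * num_gpus
    # Sort by descending size, then by name so the result is deterministic.
    for name, size in sorted(table_sizes.items(), key=lambda kv: (-kv[1], kv[0])):
        target = loads.index(min(loads))
        buckets[target].append(name)
        loads[target] += size
    return buckets


def build_device_map(gpu_tbe_ids: List[int]) -> Dict[int, int]:
    """Map each GPU TBE index to a zero-based Torchrec rank (assumed direction)."""
    return {tbe_id: rank for rank, tbe_id in enumerate(gpu_tbe_ids)}


if __name__ == "__main__":
    sizes = {"t0": 40, "t1": 30, "t2": 20, "t3": 10}
    print(split_tables_balanced(sizes, num_gpus=2))
    print(build_device_map([4, 5]))
```

Passing the GPU TBE indices explicitly, rather than deriving them from the DI world size, keeps the sketch independent of how the TBEs are numbered.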
This diff
Implement GPU sharding in DI sharding pass, so we have both GPU and DI sharding combined in one pass:
- As in Torchrec's create_infer_embedding_sharding (https://fburl.com/code/5r596mkz), create per-table sharding based on the sharding schema.
- Allow ShardingEnv to be a dictionary keyed by sharding_type.
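One way to read "a per sharding_type dictionary" is a lookup that accepts either a single environment or a mapping from sharding type to environment. The sketch below is a guess at that shape, not the code in this diff: `EnvStub` is a hypothetical stand-in for Torchrec's ShardingEnv, and `resolve_env` and the string keys are illustrative names.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass(frozen=True)
class EnvStub:
    # Hypothetical stand-in for ShardingEnv; carries only the fields used here.
    world_size: int
    rank: int = 0


EnvOrEnvs = Union[EnvStub, Dict[str, EnvStub]]


def resolve_env(env: EnvOrEnvs, sharding_type: str) -> EnvStub:
    """Return the environment to use for the given sharding type.

    A single env applies to every sharding type; a dict is looked up by key.
    """
    if isinstance(env, dict):
        if sharding_type not in env:
            raise KeyError(f"no sharding env configured for {sharding_type!r}")
        return env[sharding_type]
    return env


if __name__ == "__main__":
    envs = {"table_wise": EnvStub(world_size=2), "row_wise": EnvStub(world_size=3)}
    print(resolve_env(envs, "row_wise"))
    print(resolve_env(EnvStub(world_size=4), "table_wise"))
```

Accepting both forms keeps existing callers that pass a single environment working while letting the combined GPU and DI pass supply a different world size per sharding type.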
Differential Revision: D54570308
This pull request was exported from Phabricator. Differential Revision: D54570308