TorchRec Sharding Composability
Desired behavior
We want to make TorchRec sharding composable with other sharding/parallelism techniques. In practice, this means that after applying TorchRec sharding the model's characteristics remain the same (e.g. state_dict() doesn't change) and we don't affect the non-sharded parts of the model, e.g.
m = Model(device="meta")
orig_keys = list(m.state_dict().keys())
m.ebc = torchrec.shard_embedding_bag_collection(m.ebc, ....)
sharded_keys = list(m.ebc.state_dict().keys())
assert orig_keys == sharded_keys  # at all ranks, even for table-wise sharded
Observations on nn.Module
- state_dict(): recursive call via each child's state_dict()
- load_state_dict(): recursive call via a local load(), which calls _load_from_state_dict()
- named_parameters()/named_buffers(): call named_modules()
- named_modules(): recursive call
Since we don't want to modify the state_dict()/named_parameters()/etc. implementations of non-sharded modules (including the top-level module), we have to ensure that _modules stays consistent, as illustrated in the sketch below.
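For illustration, a toy sketch (not TorchRec code; FakeShardedEBC is hypothetical): state_dict() and named_parameters() build FQNs by recursing over _modules, so swapping a child for a replacement that registers the same parameter names under the same attribute leaves the parent's keys unchanged.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = nn.EmbeddingBag(10, 4)   # stands in for an EmbeddingBagCollection
        self.dense = nn.Linear(4, 1)

class FakeShardedEBC(nn.Module):
    # Hypothetical replacement: registers a parameter under the same name
    # ("weight") as the module it replaces.
    def __init__(self, orig: nn.EmbeddingBag):
        super().__init__()
        self.weight = nn.Parameter(orig.weight.detach().clone())

m = Toy()
orig_keys = list(m.state_dict().keys())   # ['ebc.weight', 'dense.weight', 'dense.bias']
m.ebc = FakeShardedEBC(m.ebc)             # swap goes through _modules; FQNs are derived from it
assert list(m.state_dict().keys()) == orig_keys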
Motivations
- The current implementation of named_parameters()/named_buffers()/named_modules() is consistent with the nn.Module defaults
- This logic is tricky to get right
- Removes the need for custom code (e.g. the tricks in DMP and ShardedModules to keep module FQNs consistent)
- Currently state_dict() does not contain the state of all tensors
  - if a table is sharded table-wise, only that rank's state_dict() sees it
- Easier integration with other distributed solutions (e.g. FSDP for DHEN)
  - We need to share similar semantics for what the nn.Module APIs return, and this will make things consistent
- We could also do this within the DMP nn.Module overrides, but that route is much heavier and more bug-prone
Proposed APIs
Use the trick employed by torch.fx.GraphModule: register empty nn.Modules to hold state, and have the forward() implementation dispatch through a different module hierarchy (see the ShardedEmbeddingBagCollection sketch below).
New APIs
def shard_embedding_modules(model, sharders, plan: EmbeddingShardingPlan):
    # Replaces all child ebc/fused_ebc modules with sharded variants based on
    # the plan and initializes the meta tensors.
    # It returns the unsharded modules and their replacements; this is useful
    # for identifying which modules are already sharded.
    return sharded_embedding_modules
def dlrm_parallelize(model,
                     sharders=[EmbeddingBagCollectionSharder, FusedEmbeddingBagCollectionSharder, ...],
                     embedding_plan: Optional[EmbeddingShardingPlan] = None):
    # This style of sharding shards the embedding modules with the TorchRec
    # sharders and everything else in a distributed data parallel fashion.
    if embedding_plan is None:
        embedding_plan = EmbeddingPlanner(sharders, topology, ...)
    sharded_embedding_modules = shard_embedding_modules(model, sharders, embedding_plan)

    # DDP can be done in two ways:
    # 1) with the DDP wrapper (this relies on DDP being composable; see the sketch after this function):
    model = DistributedDataParallel(
        model,
        params_and_buffers_to_ignore=get_params_and_buffers(sharded_embedding_modules),
    )
    modules = model.modules()

    # 2) or replace the underlying module tensors with DataParallelTensor:
    for name, param in model.named_parameters():
        if param not in get_params_and_buffers(sharded_embedding_modules):
            model.register_parameter(name, DataParallelTensor(param))

    # A similar wrapper can be done with FSDP, or this can be done directly in the model definition.
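For the DDP-wrapper path above, one mechanism that exists today is DistributedDataParallel's (private) _set_params_and_buffers_to_ignore_for_model hook. A minimal sketch, where sharded_param_fqns is assumed to be the fully qualified names of the ShardedTensor-backed embedding parameters and buffers:

import torch.nn as nn
from typing import List
from torch.nn.parallel import DistributedDataParallel

def wrap_dense_with_ddp(model: nn.Module, sharded_param_fqns: List[str]) -> DistributedDataParallel:
    # Tell DDP to skip the sharded embedding parameters/buffers so it only
    # broadcasts and all-reduces the dense (data-parallel) part of the model.
    DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
        model, sharded_param_fqns
    )
    return DistributedDataParallel(model)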
class ShardedEmbeddingBagCollection(nn.Module):
    def __init__(self):
        # modules that hold state
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
        # these are the sharded modules; a plain Dict is not registered as sub-modules
        self._lookups: Dict[str, nn.Module] = {}
        # Add parameters/buffers from self._lookups to self.embedding_bags

    def forward(self, KJT):
        # unchanged impl
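A minimal, self-contained sketch of that state/forward split (toy code, not the real ShardedEmbeddingBagCollection): state-holding modules are registered so state_dict()/named_parameters() expose the original FQNs, while forward() dispatches through a plain dict that is not part of the module hierarchy.

import torch
import torch.nn as nn
from typing import Dict

class ShardedCollectionSketch(nn.Module):
    def __init__(self, table_weights: Dict[str, torch.Tensor]):
        super().__init__()
        # Registered module: exists only to expose state under the original
        # FQNs (e.g. "embedding_bags.t1.weight").
        self.embedding_bags = nn.ModuleDict()
        # Not registered: the modules that actually run the (sharded) lookups.
        self._lookups: Dict[str, nn.Module] = {}
        for name, weight in table_weights.items():
            lookup = nn.EmbeddingBag.from_pretrained(weight, freeze=False)
            self._lookups[name] = lookup
            holder = nn.Module()
            # Re-register the lookup's parameter on the state-holding module
            # so state_dict() keys match the unsharded module's keys.
            holder.register_parameter("weight", lookup.weight)
            self.embedding_bags[name] = holder

    def forward(self, name: str, ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        # forward() uses the unregistered hierarchy.
        return self._lookups[name](ids, offsets)

Because _lookups is a plain dict, its modules never show up in named_modules(), so the parent module's FQNs look exactly like the unsharded collection's.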
Model Authoring
embedding_bag_configs = [
EmbeddingBagConfig(
name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
),
EmbeddingBagConfig(
name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
),
EmbeddingBagConfig(
name="t3", embedding_dim=4, num_embeddings=10, feature_names=["f3"]
)
]
model = DLRMModel(ebc=EmbeddingBagCollection(embedding_bag_configs))
model = fuse_embedding_optimizer(model, ...) # recursively replaces model.ebc with FusedEBC
# using wrappers to parallelize
dlrm_parallelize(model)
# or explicitly do it
shard_embedding_modules(
    model,
    plan={
        "embedding_bag_configs": {
            "t1": ParameterSharding(sharding_type=ROW_WISE, placement=DEVICE, ranks=...),
            "t2": ParameterSharding(sharding_type=ROW_WISE, placement=UVM_CACHING, ranks=...),
        },
    },
)
print(list(model.named_parameters()))
>>> [overarch..., linear, ..., embedding_bag_configs.t1.weight: ShardedTensor]
opt = torch.optim.SGD(model.parameters(), lr=.02)
# if we do not use FusedEBC, then we expect that these ShardedTensors have
# a grad field (or maybe these are ShardedParameters), and optimizers will
# naturally work on top of them
# If we have FusedEBC, their grads will be None, and opt.step will be a no-op
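# Illustrative training step (not part of the original proposal; `batch` is
# assumed to be defined elsewhere):
out = model(batch)
out.sum().backward()   # with FusedEBC, the embedding tables are updated during backward
opt.step()             # updates the dense params; fused embedding grads are None
opt.zero_grad()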
print(model.state_dict()["embedding_bag_configs.t1.weight"])
>>> ShardedTensor(rank0: (tensor, size, offset), rank1: (tensor, size, offset))
# If t1 is placed on rank 1, rank 0's state dict will still see the ShardedTensor
# but its local shards will be empty.
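For illustration, here is how a single rank could inspect that ShardedTensor (a sketch assuming the torch.distributed ShardedTensor API; not part of the original proposal):

st = model.state_dict()["embedding_bag_configs.t1.weight"]
print(st.metadata().size)             # the global size is visible on every rank
for shard in st.local_shards():       # empty list on ranks that own no shard of t1
    print(shard.metadata.shard_offsets, shard.tensor.shape)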
I have a question about the code snippet at the beginning of the previous comment.
m = Model(device="meta")
orig_keys = list(m.state_dict().keys())
m.ebc = torchrec.shard_embedding_bag_collection(m.ebc, ....)
sharded_keys = list(m.ebc.state_dict().keys())
assert orig_keys == sharded_keys  # at all ranks, even for table-wise sharded
What is m.ebc? Why are we supposed to compare m with m.ebc?
Do we mean the following instead?
m = Model(device="meta")
n = torchrec.shard_embeddings(m)
assert m.state_dict().keys() == n.state_dict().keys()
cc @YLGH, @colin2328
@wangkuiyi sorry for the delay :)
I think the snippet is a bit confusing, but the core API as landed is shard_embedding_modules https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/shard_embedding_modules.py#L24
This will replace (module swap) the embedding modules and return a module that you can compare as you indicated:
assert m.state_dict().keys() == n.state_dict().keys()
See https://github.com/pytorch/torchrec/blob/main/examples/golden_training/train_dlrm.py#L95 for an example in action
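Roughly, the pattern looks like this (a sketch with arguments and the return value elided; see the linked file for the real signature):

m = Model(device="meta")
orig_keys = list(m.state_dict().keys())
shard_embedding_modules(m, ...)                   # module-swaps the embedding modules inside m
assert orig_keys == list(m.state_dict().keys())   # holds on every rank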
cc @YLGH
This has landed in master and will go out in the next stable release.