
TorchRec Sharding Composability

YLGH opened this issue

Desired behavior

We want to make TorchRec sharding composable with other sharding/parallelism techniques. Practically, this means that after applying TorchRec sharding, model characteristics remain the same (e.g. the state_dict() keys don't change) and we don't affect non-sharded parts of the model, e.g.

m = Model(device="meta")
orig_keys = list(m.state_dict().keys())
m.ebc = torchrec.shard_embedding_bag_collection(m.ebc, ....)
sharded_keys = list(m.ebc.state_dict().keys())

assert orig_keys == sharded_keys  # at all ranks, even for table-wise sharded

Observations on nn.Module

  • state_dict(): recursively calls state_dict() on child modules
  • load_state_dict(): recursive call via a local load() function, which calls _load_from_state_dict()
  • named_parameters()/named_buffers(): call named_modules()
  • named_modules(): recursive call over _modules

As we don't want to modify the state_dict/named_parameters/etc. implementations of non-sharded modules (including the top-level module), we have to ensure that _modules stays consistent.
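To make the recursion concrete, here is a minimal sketch (a toy model; the Holder name is made up for illustration, not TorchRec code) showing that FQNs and state_dict keys are derived entirely by walking _modules, which is why a module swap that keeps _modules consistent leaves them unchanged:

import torch
from torch import nn

class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.t1 = nn.EmbeddingBag(10, 4)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.over_arch = nn.Linear(4, 1)
        self.ebc = Holder()

m = ToyModel()
# state_dict()/named_parameters()/named_modules() all recurse through _modules,
# so the keys below are fully determined by the registered module hierarchy.
assert "ebc.t1.weight" in m.state_dict()
assert "ebc.t1.weight" in dict(m.named_parameters())
assert "ebc.t1" in dict(m.named_modules())
# Replacing m.ebc with a module that exposes the same child/parameter names
# keeps every one of these keys identical on every rank.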

Motivations

  • The current implementation (of named_parameters()/named_buffers()/named_modules()) is consistent with the nn.Module defaults
    • This logic is tricky to get right
  • Removes the need for custom code (e.g. tricks in DMP and ShardedModules to keep module FQNs consistent)
  • Currently state_dict does not contain the state of all tensors
    • if a table is sharded table-wise, then only the rank holding it sees it in its state_dict
  • Easier integration with other distributed solutions (e.g. FSDP for DHEN)
    • We need to share similar semantics for what is returned by the nn.Module APIs, and this will make things consistent
      • We could also do this within DMP's nn.Module overrides, but that route is much heavier and more bug-prone

Proposed APIs

Use the trick employed by torch.fx.GraphModule: register empty nn.Modules to hold state, and have the forward() implementation use a different module hierarchy.

New APIs

def shard_embedding_modules(model, sharders, plan: EmbeddingShardingPlan):
    # Replaces all child ebc/fused_ebc modules with sharded variants based on
    # the plan, and initializes meta tensors.

    # It returns the names of the sharded parameters, which is useful for
    # identifying which modules have been sharded (and which parameters other
    # wrappers such as DDP/FSDP should ignore).
    return sharded_parameter_names
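# For intuition, the module swap itself can be sketched with a tiny helper
# (hypothetical, not the TorchRec implementation): replacing a sub-module at a
# dotted path keeps its FQN prefix, so the parent model's state_dict()/
# named_parameters() keys are preserved as long as the replacement exposes the
# same child/parameter names.
from torch import nn

def _swap_module(model: nn.Module, path: str, new_module: nn.Module) -> None:
    *parent_path, name = path.split(".")
    parent = model
    for attr in parent_path:
        parent = getattr(parent, attr)
    setattr(parent, name, new_module)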

def dlrm_parallelize(model,
    sharders=[EmbeddingBagCollectionSharder, FusedEmbeddingBagCollectionSharder, ...],
    embedding_plan: Optional[EmbeddingShardingPlan] = None):

    # This style of sharding will shard the embedding modules with the TorchRec
    # sharders and everything else in a distributed data parallel fashion.

    if embedding_plan is None:
        embedding_plan = EmbeddingPlanner(sharders, topology, ...)

    sharded_embedding_modules = shard_embedding_modules(model, sharders, embedding_plan)

    # DDP can be done in two ways.
    # 1) With the DDP wrapper (this relies on DDP being composable):
    model = DistributedDataParallel(
        model,
        params_and_buffers_to_ignore=get_params_and_buffers(sharded_embedding_modules),
    )

    # 2) Or replace the underlying module tensors with DataParallelTensor:
    for name, param in model.named_parameters():
        if param not in get_params_and_buffers(sharded_embedding_modules):
            model.register_parameter(name, DataParallelTensor(param))

   
# similar wrapper can be done with FSDP, or just do it in model definition
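# A hedged sketch of how the "ignore sharded params" path maps onto today's
# PyTorch: DistributedDataParallel has a private helper (subject to change)
# for excluding parameters/buffers by FQN. The "ebc." prefix filter below is
# purely illustrative; it assumes `model` already had its embedding modules
# sharded as above and that a process group has been initialized.
from torch.nn.parallel import DistributedDataParallel as DDP

sharded_fqns = [name for name, _ in model.named_parameters() if name.startswith("ebc.")]
DDP._set_params_and_buffers_to_ignore_for_model(model, sharded_fqns)
model = DDP(model)  # only the remaining (dense) params are handled by DDP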

class ShardedEmbeddingBagCollection(nn.Module):
    def __init__(self):
        super().__init__()
        # registered module to handle state
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
        # these are the sharded lookup modules; a plain Dict is not registered as sub-modules
        self._lookups: Dict[str, nn.Module] = {}
        # Add parameters/buffers from self._lookups to self.embedding_bags

    def forward(self, kjt: KeyedJaggedTensor):
        # unchanged impl
        ...
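A minimal, self-contained illustration of this trick (toy module and names, not the actual TorchRec implementation): state lives in a registered nn.ModuleDict, forward() goes through an unregistered plain dict, and the module hierarchy seen by state_dict()/named_modules() stays exactly what was registered.

import torch
from torch import nn

class TrickModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered: parameters live here, so state_dict keys are stable.
        self.embedding_bags = nn.ModuleDict({"t1": nn.EmbeddingBag(10, 4)})
        # Plain dict: nn.Module does not register it, so it never shows up in
        # named_modules()/state_dict(); forward() uses it for compute.
        self._lookups = {"t1": self.embedding_bags["t1"]}

    def forward(self, ids, offsets):
        return self._lookups["t1"](ids, offsets)

m = TrickModule()
print(list(m.state_dict().keys()))        # ['embedding_bags.t1.weight']
print([n for n, _ in m.named_modules()])  # ['', 'embedding_bags', 'embedding_bags.t1']
out = m(torch.tensor([1, 2, 3]), torch.tensor([0, 2]))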

Model Authoring

embedding_bag_configs = [
   EmbeddingBagConfig(
       name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
   ),
   EmbeddingBagConfig(
       name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
   ),
   EmbeddingBagConfig(
       name="t3", embedding_dim=4, num_embeddings=10, feature_names=["f3"]
   )
]
model = DLRMModel(ebc=EmbeddingBagCollection(embedding_bag_configs))
model = fuse_embedding_optimizer(model, ...) # recursively replaces model.ebc with FusedEBC

# using wrappers to parallelize
dlrm_parallelize(model)

# or explicitly do it
shard_embedding_modules(model,
   plan={"embedding_bag_configs":
     {"t1": ParameterSharding(sharding_type=ROW_WISE, placement=DEVICE, ranks=...),
      "t2": ParameterSharding(sharding_type=ROW_WISE, placement=UVM_CACHING, ranks=...)},
   })
   

print(model.named_parameters())
>>> [overarch..., linear, ..., embedding_bag_configs.t1.weight: ShardedTensor]

opt = torch.optim.SGD(model.parameters(), lr=.02)

# if we do not use FusedEBC, then we expect that these ShardedTensors have 
# a grad field (or maybe these are ShardedParameters), and optimizers will
# naturally work on top of them
# If we have FusedEBC, their grads will be None, and opt.step will be a no-op

print(model.state_dict()["embedding_bag_configs.t1.weight"])
>>> ShardedTensor(rank0: (tensor, size, offset), rank1: (tensor, size, offset))

# If t1 is placed on rank 1, rank 0's state_dict will still see the ShardedTensor,
# but its local shards will be empty.
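As a hedged sketch of inspecting that (the key name is taken from the example above; the ShardedTensor API shown is the one under torch.distributed._shard and may move between releases):

from torch.distributed._shard.sharded_tensor import ShardedTensor

st = model.state_dict()["embedding_bag_configs.t1.weight"]
if isinstance(st, ShardedTensor):
    # every rank sees the full sharding metadata ...
    print(st.metadata().shards_metadata)
    # ... but only ranks that own a shard have non-empty local shards
    for shard in st.local_shards():
        print(shard.metadata.shard_offsets, shard.tensor.shape)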

YLGH · Jul 14 '22 17:07

I have a question about the code snippet at the beginning of the previous comment.

m = Model(device="meta")
orig_keys = list(m.state_dict().keys())
m.ebc = torchrec.shard_embedding_bag_collection(m.ebc, ....)
sharded_keys = list(m.ebc.state_dict().keys())

assert orig_keys == sharded_keys  # at all ranks, even for table-wise sharded

What is m.ebc? Why are we supposed to compare m with m.ebc?

Do we mean the following instead?

m = Model(device="meta")
n = torchrec.shard_embeddings(m)
assert m.state_dict().keys() == n.state_dict().keys()

cc @YLGH, @colin2328

wangkuiyi · Aug 01 '22 20:08

@wangkuiyi sorry for the delay :)

I think the snippet is a bit confusing, but the core API as landed is shard_embedding_modules: https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/shard_embedding_modules.py#L24

This will replace (module-swap) the embedding modules and return a module that you can compare as you indicated:

assert m.state_dict().keys() == n.state_dict().keys()

See https://github.com/pytorch/torchrec/blob/main/examples/golden_training/train_dlrm.py#L95 for an example in action

cc @YLGH

colin2328 · Sep 13 '22 17:09

This has landed in master and will go out in the next stable release.

colin2328 · Jan 03 '23 23:01