
force CW shards to be contiguous

Open · iamzainhuda opened this issue · 1 comment

Summary: Force column-wise (CW) shards to be contiguous so that multiple shards can be concatenated easily when DT.full_tensor() is called with LocalShardsWrapper. The most important case is checkpointing with state_dict, or any case where we need the global tensor of a CW-sharded table from a DTensor. With contiguous shards we avoid any extra logic to rearrange the shards at checkpoint time and can do a simple concat on each rank.
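A minimal sketch (plain PyTorch, not the actual LocalShardsWrapper/DTensor code path) of why contiguity matters: when CW shards are stored in column order, producing the global tensor is a single torch.cat along the embedding dimension, whereas out-of-order shards would first have to be sorted by their column offsets.

```python
import torch

# Suppose a table with 8 columns is column-wise sharded into three local shards.
# If the shards are stored contiguously (in column order), reconstructing the
# global tensor is a single concat along dim=1 with no index bookkeeping.
global_table = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8)
contiguous_shards = [global_table[:, 0:3], global_table[:, 3:6], global_table[:, 6:8]]

reassembled = torch.cat(contiguous_shards, dim=1)
assert torch.equal(reassembled, global_table)

# If the shards were stored out of column order, the caller would first need to
# reorder them by their starting column offsets before concatenating:
offsets = [3, 0, 6]  # hypothetical starting column of each stored shard
shuffled = [global_table[:, 3:6], global_table[:, 0:3], global_table[:, 6:8]]
order = sorted(range(len(shuffled)), key=lambda i: offsets[i])
reassembled_again = torch.cat([shuffled[i] for i in order], dim=1)
assert torch.equal(reassembled_again, global_table)
```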

Also adds callbacks to MemoryBalancedEmbeddingShardingPlanner and HeteroEmbeddingShardingPlanner.

Differential Revision: D59134513

iamzainhuda · Jun 28 '24