
cache weight/optimizer tensor mappings for efficient sync()

Open · iamzainhuda opened this issue 1 week ago · 1 comment

Summary: Every call to sync() was rebuilding dictionaries that map dtypes to tensor lists by iterating through all embedding kernels and their weights. This is unnecessary, since the mappings do not change between sync() calls.
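A rough sketch of the per-call rebuild being removed; the kernel and weight accessors here are illustrative, not the exact torchrec API:

```python
from collections import defaultdict
from typing import Dict, List

import torch


def _build_tensor_maps(embedding_kernels) -> Dict[torch.dtype, List[torch.Tensor]]:
    # Previously re-run on every sync(): walk all embedding kernels and
    # bucket their weight tensors by dtype so they can be reduced together.
    tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
    for kernel in embedding_kernels:
        for weight in kernel.weights:  # illustrative accessor, not the real kernel API
            tensors_by_dtype[weight.dtype].append(weight)
    return dict(tensors_by_dtype)
```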

This diff optimizes sync() by:

  1. Adding weights_by_dtype and optimizer_tensors_by_dtype fields to DMPCollectionContext to store pre-computed tensor mappings
  2. Adding _cache_sync_tensors() method that populates these caches once after _group_sharded_modules() completes
  3. Modifying _sync() to use the cached mappings instead of rebuilding them

This reduces per-sync() overhead, particularly for models with many embedding tables or frequent sync operations. Practically speaking, this does not matter much for most models, but it makes the code easier to understand and removes unneeded inefficiency.
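A minimal sketch of the cached approach, assuming a simplified DMPCollectionContext and illustrative per-module weight/optimizer accessors; the real DMPCollection code paths and collectives differ:

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List

import torch
import torch.distributed as dist


@dataclass
class DMPCollectionContext:
    # New fields: pre-computed dtype -> tensor-list mappings, built once.
    weights_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = field(default_factory=dict)
    optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = field(default_factory=dict)


class DMPCollectionSketch:
    def __init__(self, sharded_modules, process_group) -> None:
        self._ctx = DMPCollectionContext()
        self._pg = process_group
        self._sharded_modules = sharded_modules  # output of _group_sharded_modules()
        self._cache_sync_tensors()

    def _cache_sync_tensors(self) -> None:
        # Populate the caches once, after _group_sharded_modules() completes.
        # Weights are updated in place during training, so the cached tensor
        # references stay valid across sync() calls.
        weights: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
        opt_tensors: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
        for module in self._sharded_modules:
            for w in module.weights:             # illustrative accessors,
                weights[w.dtype].append(w)       # not the exact kernel API
            for t in module.optimizer_tensors:
                opt_tensors[t.dtype].append(t)
        self._ctx.weights_by_dtype = dict(weights)
        self._ctx.optimizer_tensors_by_dtype = dict(opt_tensors)

    def _sync(self) -> None:
        # Use the cached mappings instead of rebuilding them on every call.
        for dtype_to_tensors in (
            self._ctx.weights_by_dtype,
            self._ctx.optimizer_tensors_by_dtype,
        ):
            for tensors in dtype_to_tensors.values():
                for t in tensors:
                    dist.all_reduce(t, group=self._pg)  # placeholder collective
```

The caching is valid only as long as the grouped sharded modules and their tensor objects are stable; if resharding ever replaces the underlying tensors, the caches would need to be rebuilt (e.g., by running _cache_sync_tensors() again).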

Added unit tests to guard the implementation against erroneous changes.

Differential Revision: D88993014

iamzainhuda · Dec 11 '25 23:12

@iamzainhuda has exported this pull request. If you are a Meta employee, you can view the originating Diff in D88993014.

meta-codesync[bot] · Dec 11 '25 23:12