cache weight/optimizer tensor mappings for efficient sync()
Summary:
Every call to sync() was rebuilding the dictionaries that map dtypes to tensor lists by iterating through all embedding kernels and their weights. This work is unnecessary, since the mappings do not change between sync() calls.
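For context, the old path effectively re-derived something like the following on every call (a minimal sketch; the `sharded_modules` argument and the `.weights` attribute are illustrative placeholders, not the actual TorchRec internals):

```python
from collections import defaultdict
from typing import Dict, List

import torch


def _build_tensors_by_dtype(sharded_modules) -> Dict[torch.dtype, List[torch.Tensor]]:
    # Re-group every embedding kernel's weights by dtype -- previously repeated on each sync().
    tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
    for module in sharded_modules:
        for weight in module.weights:  # hypothetical attribute for illustration
            tensors_by_dtype[weight.dtype].append(weight)
    return tensors_by_dtype
```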
This diff optimizes sync() by:
- Adding `weights_by_dtype` and `optimizer_tensors_by_dtype` fields to `DMPCollectionContext` to store pre-computed tensor mappings
- Adding a `_cache_sync_tensors()` method that populates these caches once after `_group_sharded_modules()` completes
- Modifying `_sync()` to use the cached mappings instead of rebuilding them
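A rough sketch of the shape of the change is below. The names `DMPCollectionContext`, `weights_by_dtype`, `optimizer_tensors_by_dtype`, `_cache_sync_tensors()`, and `_sync()` mirror the diff, but the class structure, function signatures, and the `allreduce_fn` hook are simplified assumptions, not the actual implementation:

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class DMPCollectionContext:
    # Pre-computed dtype -> tensor-list mappings, populated once after sharding.
    weights_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = field(default_factory=dict)
    optimizer_tensors_by_dtype: Dict[torch.dtype, List[torch.Tensor]] = field(default_factory=dict)


def _cache_sync_tensors(context, sharded_weights, optimizer_tensors) -> None:
    # Called once after _group_sharded_modules(); the mappings are stable afterwards.
    weights: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
    opt: Dict[torch.dtype, List[torch.Tensor]] = defaultdict(list)
    for w in sharded_weights:
        weights[w.dtype].append(w)
    for t in optimizer_tensors:
        opt[t.dtype].append(t)
    context.weights_by_dtype = dict(weights)
    context.optimizer_tensors_by_dtype = dict(opt)


def _sync(context, allreduce_fn) -> None:
    # Use the cached mappings instead of rebuilding them on every call.
    for dtype, tensors in context.weights_by_dtype.items():
        allreduce_fn(dtype, tensors)
    for dtype, tensors in context.optimizer_tensors_by_dtype.items():
        allreduce_fn(dtype, tensors)
```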
This reduces the per-sync() overhead, particularly for models with many embedding tables or frequent sync operations. Practically speaking, this does not matter much for most models, but it makes the code easier to understand and removes unneeded inefficiency.

Added unit tests to guard the implementation against erroneous future changes.
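As an illustration of the kind of check such a test might make (a hypothetical sketch building on the definitions above, not the actual test added in this diff):

```python
import torch


def test_cached_mappings_match_fresh_rebuild() -> None:
    # Build a toy set of tensors with mixed dtypes.
    weights = [torch.zeros(4, dtype=torch.float32), torch.zeros(4, dtype=torch.float16)]
    opt_tensors = [torch.zeros(4, dtype=torch.float32)]

    context = DMPCollectionContext()
    _cache_sync_tensors(context, weights, opt_tensors)

    # The cache should group tensors by dtype exactly as a fresh rebuild would.
    assert set(context.weights_by_dtype) == {torch.float32, torch.float16}
    assert context.optimizer_tensors_by_dtype[torch.float32][0] is opt_tensors[0]
```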
Differential Revision: D88993014