Issue: Adam optimizer iteration counter in FBGEMM cannot be reset with set_optimizer_step()
Description
I've identified a bug in the latest version of FBGEMM: set_optimizer_step() no longer resets the Adam optimizer iteration counter, breaking behavior that worked in earlier versions.
Context
We're using NVIDIA's DynamicEmb module (from https://github.com/NVIDIA/recsys-examples), a TorchRec embedding plugin that provides GPU-accelerated dynamic embeddings. Our workflow involves constructing "twin modules", which requires transferring embedding data between TorchRec and DynamicEmb modules.
Steps to Reproduce
Our process for constructing twin modules:
- Create a TorchRec module with embedding collections
- Create a DynamicEmb module with identical configuration
- Use forward lookup to retrieve embedding indices/values from TorchRec (which increments Adam's iteration counter by 1)
- Insert these indices/values into the DynamicEmb module
- Reset the Adam iteration counter in TorchRec with torchrec_model.fused_optimizer.set_optimizer_step(0) (a sketch of this workflow follows the list)
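A minimal sketch of steps 3-5 at the call-site level is shown below. Here torchrec_model, dynamicemb_model, features, and build_twin_modules are placeholder names for objects built in steps 1-2; the DynamicEmb insertion call is elided, and only fused_optimizer.set_optimizer_step() is the actual API under discussion.

```python
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def build_twin_modules(torchrec_model, dynamicemb_model,
                       features: KeyedJaggedTensor) -> None:
    # Step 3: forward lookup on the TorchRec side. As a side effect this
    # advances the fused Adam optimizer's iteration counter by 1.
    lookups = torchrec_model(features)

    # Step 4: insert the indices/values contained in `lookups` into the
    # DynamicEmb twin. (The exact DynamicEmb insertion API is omitted here.)
    ...

    # Step 5: undo the side effect of the lookup by resetting the Adam
    # iteration counter. On FBGEMM v1.2.0 this call no longer has the
    # intended effect -- see "Actual Behavior" below.
    torchrec_model.fused_optimizer.set_optimizer_step(0)
```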
Expected Behavior
The Adam iteration counter should reset to 0 after calling set_optimizer_step(0), as it did in previous versions.
Actual Behavior
The iteration counter is not reset, regardless of the value passed to set_optimizer_step().
Root Cause Analysis
After investigating the source code, I found that in the latest FBGEMM version, a new tensor self.iter_cpu has been added to track Adam iterations:
https://github.com/pytorch/FBGEMM/blob/v1.2.0/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py#L2070
The issue is that set_optimizer_step() only updates self.iter, while the Adam optimizer step reads self.iter_cpu, which has no setter accessible from outside the module.
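A possible workaround is sketched below: after calling set_optimizer_step(), locate every FBGEMM TBE module inside the model and overwrite both counters directly. This is not a documented API; the traversal via named_modules() and the hasattr guards are assumptions about how TorchRec wraps SplitTableBatchedEmbeddingBagsCodegen, and the reset_adam_iteration helper is hypothetical.

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    SplitTableBatchedEmbeddingBagsCodegen,
)


def reset_adam_iteration(torchrec_model, step: int = 0) -> None:
    """Hypothetical workaround: force both iteration counters on every
    FBGEMM TBE module found inside `torchrec_model` to `step`."""
    # Keep calling the public API first, in case a future release fixes it.
    torchrec_model.fused_optimizer.set_optimizer_step(step)

    for _, module in torchrec_model.named_modules():
        if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
            # `iter` is the counter set_optimizer_step() already updates;
            # `iter_cpu` is the new counter actually read by the Adam step.
            if hasattr(module, "iter_cpu"):
                module.iter_cpu.fill_(step)
            if hasattr(module, "iter"):
                module.iter.fill_(step)
```

In the twin-module flow above, this helper would take the place of the plain set_optimizer_step(0) call in the last step.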