[Bug]: ManagedCollisionEmbeddingCollection returns all-zero embeddings after applying apply_optimizer_in_backward with RowWiseAdagrad
Hi there, ManagedCollisionEmbeddingCollection with multiple tables and shared features returns all-zero embeddings after apply_optimizer_in_backward is applied with RowWiseAdagrad. This result is unexpected.
The bug likely relates to how the RowWiseAdagrad optimizer state and the corresponding embedding rows are initialized and updated during eviction events in ManagedCollisionEmbeddingCollection.
A minimal reproducible Python example (mch_rowrisegrad_bug.py) is launched with:
torchrun --standalone --nnodes=1 --node-rank=0 --nproc-per-node=1 mch_rowrisegrad_bug.py
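The original script is not reproduced in full here; the sketch below only illustrates how such a setup is typically assembled from TorchRec's modules, under stated assumptions: the table/feature layout (a tag_table shared by item_tag and user_tag plus a separate id_table for item_id), the eviction_interval, the input ids, the learning rate, and the CPU device are all illustrative, and the DistributedModelParallel wrapping used under torchrun is omitted for brevity.

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

device = torch.device("cpu")  # assumption: the actual repro runs on GPU under torchrun

# Two tables, one of them shared by two features ("multiple tables + shared features").
tables = [
    EmbeddingConfig(
        name="tag_table",
        embedding_dim=8,
        num_embeddings=1000,
        feature_names=["item_tag", "user_tag"],
    ),
    EmbeddingConfig(
        name="id_table",
        embedding_dim=8,
        num_embeddings=1000,
        feature_names=["item_id"],
    ),
]

# One MCH managed-collision module per table; evictions are where the
# RowWiseAdagrad state and the embedding rows are suspected to get out of sync.
mc_modules = {
    table.name: MCHManagedCollisionModule(
        zch_size=table.num_embeddings,
        input_hash_size=2**31,
        device=device,
        eviction_interval=1,
        eviction_policy=DistanceLFU_EvictionPolicy(),
    )
    for table in tables
}

mc_ec = ManagedCollisionEmbeddingCollection(
    EmbeddingCollection(tables=tables, device=device),
    ManagedCollisionCollection(
        managed_collision_modules=mc_modules,
        embedding_configs=tables,
    ),
    return_remapped_features=True,
)

# Attach RowWiseAdagrad so that the optimizer step runs during the backward pass.
apply_optimizer_in_backward(RowWiseAdagrad, mc_ec.parameters(), {"lr": 0.01})

# One raw id per feature (values are illustrative).
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["item_tag", "user_tag", "item_id"],
    values=torch.tensor([11111, 22222, 33333], dtype=torch.int64),
    lengths=torch.tensor([1, 1, 1], dtype=torch.int64),
)

# Run a few training steps so that remapping/eviction kicks in, then inspect the output.
for _ in range(3):
    emb_result, remapped_ids = mc_ec(kjt)
    loss = torch.cat([jt.values() for jt in emb_result.values()]).sum()
    loss.backward()

for key, jt in emb_result.items():
    print(f"emb_result key: {key}, jt: {jt}")
print(f"remapped_ids: {remapped_ids}")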
Unexpected Result
[RANK0] emb_result key: item_tag, jt: JaggedTensor({
[[[-0.00694586057215929, 0.005635389592498541, 0.029554935172200203, -0.014213510788977146, 0.027853110805153847, 0.023257633671164513, 0.004495333414524794, -0.01736217364668846]]]
})
[RANK0] emb_result key: user_tag, jt: JaggedTensor({
[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]] # --> all-zero embeddings, bug
})
[RANK0] emb_result key: item_id, jt: JaggedTensor({
[[[0.026882518082857132, -0.008349019102752209, 0.025774799287319183, 0.010714510455727577, 0.022058645263314247, -0.02674921043217182, 0.029537828639149666, 0.007071810774505138]]]
})
[RANK0] remapped_ids: KeyedJaggedTensor({
"item_tag": [[997]],
"user_tag": [[998]],
"item_id": [[998]]
})
My Current Environment
fbgemm_gpu==1.1.0+cu118
numpy==2.1.2
protobuf==3.19.6
torch==2.6.0+cu118
torchrec==1.1.0+cu118
transformers==4.48.0
triton==3.2.0