
[Bug]: ShardedManagedCollisionEmbeddingCollection throws an IndexError when "return_remapped_features=True"

Open · rayhuang90 opened this issue 8 months ago · 1 comment

I use ManagedCollisionEmbeddingCollection with DistributedModelParallel to store hashed-ID embeddings during distributed training.

Setting return_remapped_features=True raises an IndexError when the model has a single embedding table configuration; the error disappears when a second table configuration is added.

The expected behavior is that return_remapped_features=True works without error regardless of the number of embedding table configurations.

A minimal reproducible Python example is attached; a condensed sketch of it follows the run command below:

mch_bug_reproduce.py.txt

torchrun --standalone --nnodes=1 --node-rank=0 --nproc-per-node=1 mch_bug_reproduce.py
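
Since the repro is attached as a file, here is a condensed sketch of the failing setup for quick reference. The table/feature names, sizes, and eviction settings are illustrative stand-ins, not copied from the attachment:

import os

import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
device = (
    torch.device("cuda", int(os.environ.get("LOCAL_RANK", 0)))
    if torch.cuda.is_available()
    else torch.device("cpu")
)

# A SINGLE table config -- the failing case. Adding a second config makes
# the error go away, as described above.
tables = [
    EmbeddingConfig(
        name="table_0",
        embedding_dim=16,
        num_embeddings=100,
        feature_names=["feature_0"],
    ),
]
mc_modules = {
    "table_0": MCHManagedCollisionModule(
        zch_size=100,
        device=device,
        eviction_interval=1,
        eviction_policy=DistanceLFU_EvictionPolicy(),
    ),
}
mc_ec = ManagedCollisionEmbeddingCollection(
    EmbeddingCollection(tables=tables, device=device),
    ManagedCollisionCollection(
        managed_collision_modules=mc_modules,
        embedding_configs=tables,
    ),
    return_remapped_features=True,  # the flag that triggers the IndexError
)
dmp_mc_ec = DistributedModelParallel(mc_ec, device=device)

mb = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0"],
    values=torch.tensor([1000, 2000, 3000], device=device),
    lengths=torch.tensor([1, 1, 1], device=device),
)
emb_result, remapped_ids = dmp_mc_ec(mb)  # raises IndexError with one table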

Error message

[rank0]: Traceback (most recent call last):
[rank0]:   File "~/mch_bug_reproduce.py", line 72, in <module>
[rank0]:     emb_result, remapped_ids = dmp_mc_ec(mb)
[rank0]:                                ^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 308, in forward
[rank0]:     return self._dmp_wrapped_module(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/types.py", line 998, in forward
[rank0]:     return self.compute_and_output_dist(ctx, dist_input)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/types.py", line 982, in compute_and_output_dist
[rank0]:     return self.output_dist(ctx, output)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/mc_embedding_modules.py", line 243, in output_dist
[rank0]:     kjt_awaitable = self._managed_collision_collection.output_dist(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/mc_modules.py", line 791, in output_dist
[rank0]:     awaitables_per_sharding.append(odist(remapped_ids, sharding_ctx))
[rank0]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/sharding/rw_sequence_sharding.py", line 102, in forward
[rank0]:     return self._dist(
[rank0]:            ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 1469, in forward
[rank0]:     embedding_dim=local_embs.shape[1],
[rank0]:                   ~~~~~~~~~~~~~~~~^^^
[rank0]: IndexError: tuple index out of range
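
The last frame points at the likely mechanism: local_embs reaches the output-dist module in dist_data.py as a 1-D tensor, so local_embs.shape[1] does not exist. Presumably this is the flat remapped-ID tensor taking the single-table path. A minimal illustration of the same IndexError:

import torch

remapped = torch.tensor([3, 7, 1])  # 1-D, like a flat tensor of remapped IDs
print(remapped.shape)               # torch.Size([3]) -- only one dimension
remapped.shape[1]                   # IndexError: tuple index out of range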

My current environment

fbgemm_gpu==1.1.0+cu118
numpy==2.1.2
protobuf==3.19.6
torch==2.6.0+cu118
torchrec==1.1.0+cu118
transformers==4.48.0
triton==3.2.0
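
Until this is fixed, the observation that a second table configuration avoids the error suggests a stopgap: pad the model with a tiny dummy table and carry an always-empty dummy feature in the input. This is an untested sketch continuing the repro sketch above, not a verified fix; all dummy names are mine:

# Hypothetical stopgap (untested): add a tiny second table so the
# collection no longer has a single config. Dummy names are illustrative.
tables.append(
    EmbeddingConfig(
        name="dummy_table",
        embedding_dim=16,
        num_embeddings=4,
        feature_names=["dummy_feature"],
    )
)
mc_modules["dummy_table"] = MCHManagedCollisionModule(
    zch_size=4,
    device=device,
    eviction_interval=1,
    eviction_policy=DistanceLFU_EvictionPolicy(),
)

# The dummy feature rides along in the input as an always-empty feature.
mb = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "dummy_feature"],
    values=torch.tensor([1000, 2000, 3000], device=device),
    lengths=torch.tensor([1, 1, 1, 0, 0, 0], device=device),
)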

rayhuang90 · Mar 20 '25