[Bug]: ShardedManagedCollisionEmbeddingCollection throws an IndexError when "return_remapped_features=True"
I use ManagedCollisionEmbeddingCollection with DistributedModelParallel to store hash-ID embeddings during distributed training.
An IndexError is thrown when return_remapped_features=True is set with a single embedding table configuration; the error disappears once a second table configuration is added.
The expected behavior is that return_remapped_features=True should work regardless of the number of embedding table configurations.
The minimal reproduction script (mch_bug_reproduce.py) is launched with:
torchrun --standalone --nnodes=1 --node-rank=0 --nproc-per-node=1 mch_bug_reproduce.py
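The reproduction script itself is not shown above; the following is a minimal sketch of what mch_bug_reproduce.py plausibly contains, reconstructed from the description and the traceback. The table name table_0, feature name feature_0, embedding dimension, ZCH size, and the sample input values are assumptions, not the original code.

```python
# Hedged reconstruction of mch_bug_reproduce.py (names and sizes are assumed).
import os

import torch
import torch.distributed as dist
from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
backend = "nccl" if torch.cuda.is_available() else "gloo"
if torch.cuda.is_available():
    torch.cuda.set_device(device)
dist.init_process_group(backend=backend)

# A single embedding table configuration -- the error only shows up in this
# case; adding a second EmbeddingConfig makes it go away.
tables = [
    EmbeddingConfig(
        name="table_0",
        embedding_dim=16,
        num_embeddings=1000,
        feature_names=["feature_0"],
    ),
]

mc_modules = {
    "table_0": MCHManagedCollisionModule(
        zch_size=1000,
        device=device,
        eviction_interval=1,
        eviction_policy=DistanceLFU_EvictionPolicy(),
    ),
}

mc_ec = ManagedCollisionEmbeddingCollection(
    EmbeddingCollection(tables=tables, device=device),
    ManagedCollisionCollection(
        managed_collision_modules=mc_modules,
        embedding_configs=tables,
    ),
    return_remapped_features=True,  # setting this to False avoids the error
)

# Shard the module; the MC-EC sharder is passed explicitly here.
dmp_mc_ec = DistributedModelParallel(
    mc_ec,
    device=device,
    sharders=[ManagedCollisionEmbeddingCollectionSharder()],
)

# A small KJT batch with hashed ids for the single feature.
mb = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0"],
    values=torch.tensor([1000001, 1000002, 1000003], device=device),
    lengths=torch.tensor([1, 1, 1], device=device),
)

# Fails with "IndexError: tuple index out of range" inside dist_data.py
# when only one table configuration is used.
emb_result, remapped_ids = dmp_mc_ec(mb)
```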
Error message
[rank0]: Traceback (most recent call last):
[rank0]: File "~/mch_bug_reproduce.py", line 72, in <module>
[rank0]: emb_result, remapped_ids = dmp_mc_ec(mb)
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 308, in forward
[rank0]: return self._dmp_wrapped_module(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/types.py", line 998, in forward
[rank0]: return self.compute_and_output_dist(ctx, dist_input)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/types.py", line 982, in compute_and_output_dist
[rank0]: return self.output_dist(ctx, output)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/mc_embedding_modules.py", line 243, in output_dist
[rank0]: kjt_awaitable = self._managed_collision_collection.output_dist(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/mc_modules.py", line 791, in output_dist
[rank0]: awaitables_per_sharding.append(odist(remapped_ids, sharding_ctx))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/sharding/rw_sequence_sharding.py", line 102, in forward
[rank0]: return self._dist(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 1469, in forward
[rank0]: embedding_dim=local_embs.shape[1],
[rank0]: ~~~~~~~~~~~~~~~~^^^
[rank0]: IndexError: tuple index out of range
My current environment
fbgemm_gpu==1.1.0+cu118
numpy==2.1.2
protobuf==3.19.6
torch==2.6.0+cu118
torchrec==1.1.0+cu118
transformers==4.48.0
triton==3.2.0