[BUG]: Cannot copy out of meta tensor; no data!
Hi torchrec developers. I found a fatal bug when using EmbeddingCollection built on the meta device. The full stack trace is:
[rank0]:   File "/home/admin/hippo/worker/slave/aop_418921_aop_launcher_job_temp_m_20250528093245_6524584_job.worker_0_57_12/train/test_ebd.py", line 44, in <module>
[rank0]:     main()
[rank0]:   File "/home/admin/hippo/worker/slave/aop_418921_aop_launcher_job_temp_m_20250528093245_6524584_job.worker_0_57_12/train/test_ebd.py", line 36, in main
[rank0]:     dmp = DistributedModelParallel(module = ec,
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 278, in __init__
[rank0]:     self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 343, in _init_dmp
[rank0]:     return self._shard_modules_impl(module)
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 381, in _shard_modules_impl
[rank0]:     module = self._sharder_map[sharder_key].shard(
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 1372, in shard
[rank0]:     return ShardedEmbeddingCollection(
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 632, in __init__
[rank0]:     self.load_state_dict(module.state_dict())
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Error(s) in loading state_dict for ShardedEmbeddingCollection:
[rank0]:     While copying the parameter named "embeddings.t1.weight", whose dimensions in the model are torch.Size([625000, 16]) and whose dimensions in the checkpoint are torch.Size([625000, 16]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
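The relevant frame is embedding.py line 632: ShardedEmbeddingCollection.__init__ calls self.load_state_dict(module.state_dict()), and because the EmbeddingCollection was built on the meta device, that state_dict contains meta tensors (shape and dtype only, no storage). Copying out of a meta tensor is not supported in PyTorch, which is the error surfaced above. A minimal standalone sketch of the underlying behavior (the shapes here are just illustrative):

import torch

meta_weight = torch.empty(625000, 16, device="meta")  # metadata only, no storage
real_weight = torch.empty(625000, 16)                 # materialized tensor

# Raises NotImplementedError: Cannot copy out of meta tensor; no data!
real_weight.copy_(meta_weight)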
To reproduce it, run the command
torchrun --standalone --nnodes 1 --node_rank 0 --nproc_per_node 8 test_ebd.py
with the following test_ebd.py:
import os

import torch
import torch.distributed as dist
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.model_parallel import DistributedModelParallel


def main():
    rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        device: torch.device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

    e1_config = EmbeddingConfig(
        name="t1", embedding_dim=16, num_embeddings=5000000, feature_names=["f1"]
    )
    # Build the table on the meta device so the 5,000,000 x 16 weight
    # is never materialized on a single device.
    ec = EmbeddingCollection(
        tables=[e1_config],
        device="meta",
    )

    _pg = dist.GroupMember.WORLD
    _sharder = [EmbeddingCollectionSharder()]
    planner = EmbeddingShardingPlanner()
    plan = planner.collective_plan(
        module=ec,
        sharders=_sharder,
        pg=_pg,
    )
    # Sharding the meta-device module fails here with
    # "Cannot copy out of meta tensor; no data!"
    dmp = DistributedModelParallel(
        module=ec,
        device=device,
        plan=plan,
        sharders=_sharder,
    )


if __name__ == "__main__":
    dist.init_process_group("nccl")
    main()
Versions: torch 2.6.0, fbgemm 1.1.0, torchrec 1.1.0
This bug does not seem to be caused by torchrec itself, since similar issues have been reported in other repositories. A temporary workaround is to construct the EmbeddingCollection with device="cuda" instead of "meta", but then the full, unsharded table must fit in a single GPU's memory, so large embedding tables cannot be trained.
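A minimal sketch of that workaround, reusing the names from the repro script above: construct the table directly on the local GPU so state_dict() contains real tensors, at the cost of materializing the whole table on one device before sharding.

# Workaround sketch: build the table on the local GPU instead of the meta device.
# The copy in ShardedEmbeddingCollection.__init__ then succeeds, but the full
# 5,000,000 x 16 table must fit in a single GPU's memory before sharding.
ec = EmbeddingCollection(
    tables=[e1_config],
    device=device,  # e.g. torch.device(f"cuda:{rank}") as set in main()
)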