
[BUG]: Cannot copy out of meta tensor; no data!

Open gouchangjiang opened this issue 6 months ago • 4 comments

Hi, dear torchrec developers. I found a fatal bug when using EmbeddingCollection. The full stack trace is:

[rank0]:   File "/home/admin/hippo/worker/slave/aop_418921_aop_launcher_job_temp_m_20250528093245_6524584_job.worker_0_57_12/train/test_ebd.py", line 44, in <module>
[rank0]:     main()
[rank0]:   File "/home/admin/hippo/worker/slave/aop_418921_aop_launcher_job_temp_m_20250528093245_6524584_job.worker_0_57_12/train/test_ebd.py", line 36, in main
[rank0]:     dmp = DistributedModelParallel(module = ec,
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 278, in __init__
[rank0]:     self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 343, in _init_dmp
[rank0]:     return self._shard_modules_impl(module)
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/model_parallel.py", line 381, in _shard_modules_impl
[rank0]:     module = self._sharder_map[sharder_key].shard(
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 1372, in shard
[rank0]:     return ShardedEmbeddingCollection(
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torchrec/distributed/embedding.py", line 632, in __init__
[rank0]:     self.load_state_dict(module.state_dict())
[rank0]:   File "/opt/conda/envs/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Error(s) in loading state_dict for ShardedEmbeddingCollection:
[rank0]:  While copying the parameter named "embeddings.t1.weight", whose dimensions in the model are torch.Size([625000, 16]) and whose dimensions in the checkpoint are torch.Size([625000, 16]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
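
For context, the failing copy can be reproduced in isolation with plain PyTorch: a meta tensor carries only shape and dtype metadata and has no storage, so copying data out of it always fails with this message. A minimal sketch (independent of torchrec):

import torch

src = torch.empty(4, 16, device="meta")  # meta tensor: shape/dtype only, no backing data
dst = torch.empty(4, 16)
dst.copy_(src)  # raises: RuntimeError: Cannot copy out of meta tensor; no data!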

To reproduce it, run the following code snippet with the command torchrun --standalone --nnodes 1 --node_rank 0 --nproc_per_node 8 test_ebd.py

test_ebd.py is attached below:

import os

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.model_parallel import DistributedModelParallel

import torch.distributed as dist

def main():
    rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        device: torch.device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

    # A single 5M-row table; with 8 ranks the planner shards it into the
    # 625000-row pieces seen in the stack trace above.
    e1_config = EmbeddingConfig(
        name="t1", embedding_dim=16, num_embeddings=5000000, feature_names=["f1"]
    )

    # Build the collection on the meta device so the full table is never materialized.
    ec = EmbeddingCollection(
        tables=[e1_config],
        device="meta",
    )

    _pg = dist.GroupMember.WORLD
    _sharder = [EmbeddingCollectionSharder()]
    planner = EmbeddingShardingPlanner()
    plan = planner.collective_plan(
        module=ec,
        sharders=_sharder,
        pg=_pg,
    )

    # Wrapping the meta-device collection in DistributedModelParallel raises
    # the RuntimeError above.
    dmp = DistributedModelParallel(
        module=ec,
        device=device,
        plan=plan,
        sharders=_sharder,
    )

if __name__ == "__main__":
    dist.init_process_group("nccl")
    main()

Versions: torch 2.6.0, fbgemm 1.1.0, torchrec 1.1.0

This bug does not appear to be caused by torchrec itself, since I found similar reports in other repositories. The temporary workaround is to construct the EmbeddingCollection with device="cuda" instead of "meta", but then the full table must fit in GPU memory, so you cannot train large embedding tables.
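
For reference, a minimal sketch of that workaround; only the construction of the collection changes (variable names match the snippet above). It avoids the meta-tensor copy but materializes the whole table up front:

# Workaround: build the collection directly on the target GPU instead of "meta".
# The full 5000000 x 16 table is allocated immediately, so this does not scale
# to tables larger than a single device's memory.
ec = EmbeddingCollection(
    tables=[e1_config],
    device=device,  # e.g. torch.device(f"cuda:{rank}") rather than "meta"
)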

gouchangjiang · May 28 '25 03:05