
row-wise alltoall error when some embeddings use mean pooling and others use sum pooling

Open tiankongdeguiji opened this issue 8 months ago • 2 comments

There is an "alltoall" error when using row-wise sharding if some embedding bags use mean pooling while others use sum pooling. It can be reproduced with the following command: torchrun --master_addr=localhost --master_port=49941 --nnodes=1 --nproc-per-node=2 test_row_wise_pooling.py, in an environment with torchrec==1.1.0+cu124, torch==2.6.0+cu124, and fbgemm-gpu==1.1.0+cu124.

test_row_wise_pooling.py

import os
from typing import Dict, cast

import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        # mix pooling types: SUM for even-indexed tables, MEAN for odd-indexed ones
        pooling=torchrec.PoolingType.SUM if i % 2 == 0 else torchrec.PoolingType.MEAN,
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM if i % 2 == 0 else torchrec.PoolingType.MEAN,
    )
    for i in range(small_table_cnt)
]


def gen_constraints(
    sharding_type: ShardingType = ShardingType.ROW_WISE,
) -> Dict[str, ParameterConstraints]:
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = EmbeddingBagCollection(tables=large_tables + small_tables, device="meta")
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1)

    def forward(self, kjt: KeyedJaggedTensor):
        emb = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))


rank = int(os.environ["RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()

model = DebugModel()
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})

topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.ROW_WISE)
planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)

sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    sharders=sharders,
    device=device,
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])

batch_size = 64
lengths_large = torch.randint(0, 10, (batch_size * large_table_cnt,))
lengths_small = torch.randint(0, 10, (batch_size * small_table_cnt,))
kjt = KeyedJaggedTensor(
    keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
    + ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
    values=torch.cat([
        torch.randint(0, 4096, (torch.sum(lengths_large),)),
        torch.randint(0, 1023, (torch.sum(lengths_small),)),
    ]),
    lengths=torch.cat([lengths_large, lengths_small]),
).to(device=device)
losses = sharded_model.forward(kjt)
torch.sum(losses, dim=0).backward()
optimizer.step()
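
For reference, the failure only shows up with the SUM/MEAN mix above; with a uniform pooling type the same script presumably runs without the alltoall error. A hypothetical sanity-check variant of the table config (not part of the failing repro) would look like:

# Hypothetical sanity-check variant: give every table the same pooling type so
# there is no SUM/MEAN mix under row-wise sharding.
uniform_large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,  # uniform pooling across tables
    )
    for i in range(large_table_cnt)
]

The error info below is from the original mixed-pooling script.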

Error info:

[rank0]: Traceback (most recent call last):
[rank0]:   File "test_row_wise_pooling.py", line 124, in <module>
[rank0]:     losses = sharded_model.forward(kjt)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 308, in forward
[rank0]:     return self._dmp_wrapped_module(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "test_row_wise_pooling.py", line 73, in forward
[rank0]:     emb = self.ebc(kjt)
[rank0]:           ^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 997, in forward
[rank0]:     dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 334, in wait
[rank0]:     ret: W = self._wait_impl()
[rank0]:              ^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/embedding_sharding.py", line 745, in _wait_impl
[rank0]:     tensors_awaitables.append(w.wait())
[rank0]:                               ^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 334, in wait
[rank0]:     ret: W = self._wait_impl()
[rank0]:              ^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 530, in _wait_impl
[rank0]:     return KJTAllToAllTensorsAwaitable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 398, in __init__
[rank0]:     awaitable = dist.all_to_all_single(
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 4388, in all_to_all_single
[rank0]:     work = group.alltoall_base(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Split sizes doesn't match total dim 0 size
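
My reading of the last frame (a minimal sketch, not torchrec code; the wrapper name below is made up for illustration): dist.all_to_all_single requires the input split sizes to sum to input.size(0) and the output split sizes to sum to output.size(0), and violating either check raises the "Split sizes doesn't match total dim 0 size" error above. So the splits computed during the KJT input dist for the mixed SUM/MEAN table grouping apparently no longer match the length of the tensor being exchanged.

import torch
import torch.distributed as dist

def checked_all_to_all_single(
    output_tensor, input_tensor, output_split_sizes, input_split_sizes, group=None
):
    # all_to_all_single enforces these two invariants; the RuntimeError in the
    # traceback above means the splits torchrec computed violate one of them.
    assert sum(input_split_sizes) == input_tensor.size(0), "input splits must sum to input.size(0)"
    assert sum(output_split_sizes) == output_tensor.size(0), "output splits must sum to output.size(0)"
    return dist.all_to_all_single(
        output_tensor,
        input_tensor,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )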

tiankongdeguiji · Mar 12 '25 09:03