torchrec
row-wise alltoall error when some embeddings use mean pooling and others use sum pooling
There is an "alltoall" error when using row-wise sharding if some embedding bags use mean pooling while others use sum pooling. It can be reproduced with the following command: torchrun --master_addr=localhost --master_port=49941 --nnodes=1 --nproc-per-node=2 test_row_wise_pooling.py, in an environment with torchrec==1.1.0+cu124, torch==2.6.0+cu124, and fbgemm-gpu==1.1.0+cu124. (A possible workaround sketch is included after the error output below.)
test_row_wise_pooling.py
import os
from typing import Dict, cast

import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Table configs: pooling alternates between SUM (even index) and MEAN (odd index).
large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM if i % 2 == 0 else torchrec.PoolingType.MEAN,
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM if i % 2 == 0 else torchrec.PoolingType.MEAN,
    )
    for i in range(small_table_cnt)
]


def gen_constraints(
    sharding_type: ShardingType = ShardingType.ROW_WISE,
) -> Dict[str, ParameterConstraints]:
    # Force the same sharding type (row-wise) for every table.
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = EmbeddingBagCollection(tables=large_tables + small_tables, device="meta")
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1)

    def forward(self, kjt: KeyedJaggedTensor):
        emb = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))


# Process group / device setup (launched via torchrun).
rank = int(os.environ["RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()

model = DebugModel()
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})

# Plan and shard the embedding tables row-wise across ranks.
topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.ROW_WISE)
planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)
sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    sharders=sharders,
    device=device,
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])

# Random input batch covering all four features.
batch_size = 64
lengths_large = torch.randint(0, 10, (batch_size * large_table_cnt,))
lengths_small = torch.randint(0, 10, (batch_size * small_table_cnt,))
kjt = KeyedJaggedTensor(
    keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
    + ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
    values=torch.cat([
        torch.randint(0, 4096, (torch.sum(lengths_large),)),
        torch.randint(0, 1023, (torch.sum(lengths_small),)),
    ]),
    lengths=torch.cat([lengths_large, lengths_small]),
).to(device=device)

# One forward/backward/optimizer step; the forward pass fails in the KJT all-to-all.
losses = sharded_model.forward(kjt)
torch.sum(losses, dim=0).backward()
optimizer.step()
Error info:
[rank0]: Traceback (most recent call last):
[rank0]: File "test_row_wise_pooling.py", line 124, in <module>
[rank0]: losses = sharded_model.forward(kjt)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 308, in forward
[rank0]: return self._dmp_wrapped_module(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank0]: else self._run_ddp_forward(*inputs, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank0]: return self.module(*inputs, **kwargs) # type: ignore[index]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "test_row_wise_pooling.py", line 73, in forward
[rank0]: emb = self.ebc(kjt)
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 997, in forward
[rank0]: dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 334, in wait
[rank0]: ret: W = self._wait_impl()
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torchrec/distributed/embedding_sharding.py", line 745, in _wait_impl
[rank0]: tensors_awaitables.append(w.wait())
[rank0]: ^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 334, in wait
[rank0]: ret: W = self._wait_impl()
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 530, in _wait_impl
[rank0]: return KJTAllToAllTensorsAwaitable(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 398, in __init__
[rank0]: awaitable = dist.all_to_all_single(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/conda/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 4388, in all_to_all_single
[rank0]: work = group.alltoall_base(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Split sizes doesn't match total dim 0 size
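The RuntimeError is raised by dist.all_to_all_single, which requires the input split sizes to sum to the dim-0 size of the tensor being redistributed; with row-wise sharding and mixed pooling types, the splits computed for the KJT all-to-all apparently no longer match the values tensor. If mixed pooling is indeed the trigger, a possible temporary workaround (a minimal, untested sketch; gen_constraints_workaround is a hypothetical helper, not part of the script above) is to keep ROW_WISE sharding only for the SUM-pooled tables and constrain the MEAN-pooled ones to TABLE_WISE, so the two pooling types never end up in the same row-wise sharding group:

# Hedged workaround sketch: route MEAN-pooled tables to table-wise sharding
# and keep row-wise sharding only for SUM-pooled tables.
def gen_constraints_workaround() -> Dict[str, ParameterConstraints]:
    constraints: Dict[str, ParameterConstraints] = {}
    for table in large_tables + small_tables:
        if table.pooling == torchrec.PoolingType.MEAN:
            sharding = ShardingType.TABLE_WISE
        else:
            sharding = ShardingType.ROW_WISE
        constraints[table.name] = ParameterConstraints(
            sharding_types=[sharding.value],
        )
    return constraints

# Usage: replace `constraints = gen_constraints(ShardingType.ROW_WISE)` with
# `constraints = gen_constraints_workaround()` in the repro script.

Alternatively, making every table use the same pooling type (all SUM or all MEAN) should also avoid this code path, at the cost of changing the intended pooling semantics.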