`ValueError: Tensors must be contiguous` when running a specific model in DistributedModelParallel with world size 2
Description
I'm using torch.compile with DistributedModelParallel. Running the code below raises `ValueError: Tensors must be contiguous`. The error appears to be specific to this model and to world size 2; with other world sizes the same code runs without errors, which is what I would expect here as well.
Environment:
python 3.11.8, torch 2.2.2+cu121, torchrec 0.6.0+cu121.
Reproduction code:
import os
from typing import Callable, List, Union, Tuple
import multiprocessing
import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.test_utils import get_free_port

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        # define model parameters
        self.dense_in_feature = 820
        self.dense_out_feature = 784
        self.table_params = [
            [311, 108],
            [739, 408],
        ]
        self.weighted_table_params = [
            [159, 96],
            [69, 24],
            [412, 564],
            [940, 300],
        ]
        self.over_out_feature = 61
        # sparse layer
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=self.table_params[i][0],
                embedding_dim=self.table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(len(self.table_params))
        ]
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        # weighted sparse layer
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=self.weighted_table_params[i][0],
                embedding_dim=self.weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(self.weighted_table_params))
        ]
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables,
            is_weighted=True,
        )
        # dense layer
        self.dense = nn.Linear(in_features=self.dense_in_feature, out_features=self.dense_out_feature, bias=True)
        # over layer
        in_features_concat = (
            self.dense_out_feature
            + sum([table.embedding_dim * len(table.feature_names) for table in self.tables])
            + sum([table.embedding_dim * len(table.feature_names) for table in self.weighted_tables])
        )
        self.over = nn.Linear(in_features=in_features_concat, out_features=self.over_out_feature, bias=True)

    def forward(
        self,
        input: ModelInput,
        print_intermediate_layer: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # dense, sparse, weighted sparse layer output
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        # concat dense, sparse, weighted sparse layer output
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [feature for table in self.tables for feature in table.feature_names]
        _weighted_features = [feature for table in self.weighted_tables for feature in table.feature_names]
        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)
        # over layer output
        over_r = self.over(ret_concat)
        # sigmoid output
        pred = torch.sigmoid(torch.mean(over_r, dim=1))
        return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)

def sharding_single_rank_test(
    rank: int,
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    compiled=True,
) -> None:
    with MultiProcessContext(rank, world_size, backend) as ctx:
        if compiled:
            model = torch.compile(model)
        local_model = model.to(ctx.device)
        planner = EmbeddingShardingPlanner(
            topology=Topology(
                world_size, ctx.device.type
            ),
        )
        plan: ShardingPlan = planner.collective_plan(local_model, sharders, ctx.pg)
        local_model = DistributedModelParallel(
            local_model,
            env=ShardingEnv.from_process_group(ctx.pg),
            plan=plan,
            sharders=sharders,
            device=ctx.device,
        )
        # Run a single training step of the sharded model.
        local_input = inputs[0][1][rank].to(ctx.device)
        with torch.no_grad():
            local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)
        # record the local prediction
        all_local_pred = []
        for _ in range(world_size):
            all_local_pred.append(torch.empty_like(local_pred))
        dist.all_gather(all_local_pred, local_pred, group=ctx.pg)
        # record the local model's layer output
        all_dense_r = []
        for _ in range(world_size):
            all_dense_r.append(torch.empty_like(dense_r))
        dist.all_gather(all_dense_r, dense_r, group=ctx.pg)
        sparse_r_dict = sparse_r.to_dict()
        all_sparse_r_dict = {}
        for key in sparse_r_dict:
            all_sparse_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_r_dict[key].append(torch.empty_like(sparse_r_dict[key]))
            dist.all_gather(all_sparse_r_dict[key], sparse_r_dict[key].contiguous(), group=ctx.pg)
        sparse_weighted_r_dict = sparse_weighted_r.to_dict()
        all_sparse_weighted_r_dict = {}
        for key in sparse_weighted_r_dict:
            all_sparse_weighted_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_weighted_r_dict[key].append(torch.empty_like(sparse_weighted_r_dict[key]))
            dist.all_gather(all_sparse_weighted_r_dict[key], sparse_weighted_r_dict[key].contiguous(), group=ctx.pg)
        all_over_r = []
        for _ in range(world_size):
            all_over_r.append(torch.empty_like(over_r))
        dist.all_gather(all_over_r, over_r, group=ctx.pg)

def setUp():
    os.environ["MASTER_ADDR"] = str("localhost")
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
    os.environ["NCCL_SOCKET_IFNAME"] = "lo"
    torch.use_deterministic_algorithms(True)
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def run_multi_process_test(
    callable: Callable[
        ...,
        None,
    ],
    world_size: int,
    # pyre-ignore
    **kwargs,
) -> None:
    setUp()
    ctx = multiprocessing.get_context("forkserver")
    processes = []
    for rank in range(world_size):
        kwargs["rank"] = rank
        kwargs["world_size"] = world_size
        p = ctx.Process(
            target=callable,
            kwargs=kwargs,
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    world_size: int,
    compiled: bool,
) -> None:
    model = TestModel()
    inputs = [
        ModelInput.generate(
            batch_size=1200,
            world_size=world_size,
            num_float_features=model.dense_in_feature,
            tables=model.tables,
            weighted_tables=model.weighted_tables,
        )
    ]
    run_multi_process_test(
        callable=sharding_single_rank_test,
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        backend=backend,
        compiled=compiled,
    )

if __name__ == "__main__":
    sharders = [create_test_sharder("embedding_bag_collection", "column_wise", "dense")]
    backend = "nccl"
    world_size = 2
    main_test(
        sharders=sharders,
        backend=backend,
        world_size=world_size,
        compiled=True,
    )
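For context, as far as I understand the check that fails here can also be hit with plain torch.distributed, independent of torchrec: NCCL's all_gather rejects non-contiguous tensors. Below is a minimal sketch of that behavior (an assumption for illustration only; `gather_non_contiguous` is a hypothetical helper, and it presumes an NCCL process group is already initialized on a CUDA device, e.g. inside MultiProcessContext as in the repro above):

import torch
import torch.distributed as dist

def gather_non_contiguous(pg, world_size: int, device: torch.device) -> None:
    # .t() returns a strided view, so the tensor is no longer contiguous
    non_contig = torch.randn(4, 8, device=device).t()
    out = [torch.empty(8, 4, device=device) for _ in range(world_size)]
    # On an NCCL process group this is expected to raise
    # "ValueError: Tensors must be contiguous"
    dist.all_gather(out, non_contig, group=pg)
    # Passing non_contig.contiguous() instead satisfies the check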
Log:
The error message is copied below.
Traceback (most recent call last):
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/tests/reproduce_nccl_tensor_must_be_contiguous.py", line 203, in sharding_single_rank_test
    dist.all_gather(all_over_r, over_r, group=ctx.pg)
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2617, in all_gather
    work = group.allgather([tensor_list], [tensor])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Tensors must be contiguous
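On the test side, a possible workaround (just a sketch of an assumption about this repro, not a fix for the underlying torch.compile / DistributedModelParallel behavior) is to gather contiguous copies, the same way the KeyedTensor outputs are already handled above:

# hypothetical change inside sharding_single_rank_test: make both the gathered
# tensor and the preallocated outputs explicitly contiguous before the collective
all_over_r = [
    torch.empty_like(over_r, memory_format=torch.contiguous_format)
    for _ in range(world_size)
]
dist.all_gather(all_over_r, over_r.contiguous(), group=ctx.pg)

Even so, I would expect the un-modified code to work the same way it does with other world sizes, so the real question is why over_r comes back non-contiguous only in this configuration.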
cc @IvanKobzarev