
`ValueError: Tensors must be contiguous` when running a specific model in DistributedModelParallel with world size 2

Open jiannanWang opened this issue 10 months ago • 1 comment

Description

I'm using torch.compile with DistributedModelParallel. Running the code below raises ValueError: Tensors must be contiguous. The error appears to be specific to this model and to a world size of 2; with other world sizes the same code runs without error, which is what I would expect here as well.
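For illustration, here is a minimal standalone sketch (written just for this report; the helper name and port are arbitrary) of the constraint that seems to be hit: with the NCCL backend, dist.all_gather rejects a non-contiguous input tensor, while passing a contiguous copy goes through.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _contiguity_demo(rank: int, world_size: int) -> None:
    # Illustrative helper (name and port are arbitrary); assumes one GPU per rank.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    x = torch.randn(4, 8, device=f"cuda:{rank}").t()  # transposed view, so x.is_contiguous() is False
    gathered = [torch.empty(8, 4, device=f"cuda:{rank}") for _ in range(world_size)]
    # dist.all_gather(gathered, x)               # raises ValueError: Tensors must be contiguous
    dist.all_gather(gathered, x.contiguous())    # a contiguous copy is accepted
    dist.destroy_process_group()

# mp.spawn(_contiguity_demo, args=(2,), nprocs=2)  # world_size = 2, matching the failing case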

Environment:

python 3.11.8, torch 2.2.2+cu121, torchrec 0.6.0+cu121

Reproduction code:

import os
from typing import Callable, List, Union, Tuple
import multiprocessing

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.test_utils import get_free_port

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        # define model parameters
        self.dense_in_feature = 820
        self.dense_out_feature = 784
        self.table_params = [
            [311, 108],
            [739, 408],
        ]
        self.weighted_table_params = [
            [159, 96],
            [69, 24],
            [412, 564],
            [940, 300],
        ]
        self.over_out_feature = 61

        # sparse layer
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=self.table_params[i][0],
                embedding_dim=self.table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(len(self.table_params))
        ]
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        # weighted sparse layer
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=self.weighted_table_params[i][0],
                embedding_dim=self.weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(self.weighted_table_params))
        ]
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables, 
            is_weighted=True,
        )
        # dense layer
        self.dense = nn.Linear(in_features=self.dense_in_feature, out_features=self.dense_out_feature, bias=True)
        # over layer
        in_features_concat = (
            self.dense_out_feature
            + sum([table.embedding_dim * len(table.feature_names) for table in self.tables])
            + sum([table.embedding_dim * len(table.feature_names) for table in self.weighted_tables])
        )
        self.over = nn.Linear(in_features=in_features_concat, out_features=self.over_out_feature, bias=True)

    def forward(
        self,
        input: ModelInput,
        print_intermediate_layer: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # dense, sparse, weighted sparse layer output
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        # concat dense, sparse, weighted sparse layer output
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [feature for table in self.tables for feature in table.feature_names]
        _weighted_features = [feature for table in self.weighted_tables for feature in table.feature_names]

        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)
        # over layer output
        over_r = self.over(ret_concat)
        # sigmoid output
        pred = torch.sigmoid(torch.mean(over_r, dim=1))

        return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)


def sharding_single_rank_test(
    rank: int,
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    compiled: bool = True,
) -> None:

    with MultiProcessContext(rank, world_size, backend) as ctx:
        if compiled:
            model = torch.compile(model)
        local_model = model.to(ctx.device)

        planner = EmbeddingShardingPlanner(
            topology=Topology(
                world_size, ctx.device.type
            ),
        )
        plan: ShardingPlan = planner.collective_plan(local_model, sharders, ctx.pg)

        local_model = DistributedModelParallel(
            local_model,
            env=ShardingEnv.from_process_group(ctx.pg),
            plan=plan,
            sharders=sharders,
            device=ctx.device,
        )

        # Run a single training step of the sharded model.
        local_input = inputs[0][1][rank].to(ctx.device)

        with torch.no_grad():
            local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)

        # record the local prediction
        all_local_pred = []
        for _ in range(world_size):
            all_local_pred.append(torch.empty_like(local_pred))
        dist.all_gather(all_local_pred, local_pred, group=ctx.pg)

        # record the local model's layer output
        all_dense_r = []
        for _ in range(world_size):
            all_dense_r.append(torch.empty_like(dense_r))
        dist.all_gather(all_dense_r, dense_r, group=ctx.pg)

        sparse_r_dict = sparse_r.to_dict()
        all_sparse_r_dict = {}
        for key in sparse_r_dict:
            all_sparse_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_r_dict[key].append(torch.empty_like(sparse_r_dict[key]))
            dist.all_gather(all_sparse_r_dict[key], sparse_r_dict[key].contiguous(), group=ctx.pg)

        sparse_weighted_r_dict = sparse_weighted_r.to_dict()
        all_sparse_weighted_r_dict = {}
        for key in sparse_weighted_r_dict:
            all_sparse_weighted_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_weighted_r_dict[key].append(torch.empty_like(sparse_weighted_r_dict[key]))
            dist.all_gather(all_sparse_weighted_r_dict[key], sparse_weighted_r_dict[key].contiguous(), group=ctx.pg)

        all_over_r = []
        for _ in range(world_size):
            all_over_r.append(torch.empty_like(over_r))
        dist.all_gather(all_over_r, over_r, group=ctx.pg)


def setUp():
    os.environ["MASTER_ADDR"] = str("localhost")
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
    os.environ["NCCL_SOCKET_IFNAME"] = "lo"

    torch.use_deterministic_algorithms(True)
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def run_multi_process_test(
    callable: Callable[
        ...,
        None,
    ],
    world_size: int,
    # pyre-ignore
    **kwargs,
) -> None:
    setUp()
    ctx = multiprocessing.get_context("forkserver")
    processes = []
    for rank in range(world_size):
        kwargs["rank"] = rank
        kwargs["world_size"] = world_size
        p = ctx.Process(
            target=callable,
            kwargs=kwargs,
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    world_size: int,
    compiled: bool,
) -> None:
    model = TestModel()
    inputs = [ModelInput.generate(
        batch_size=1200,
        world_size=world_size,
        num_float_features=model.dense_in_feature,
        tables=model.tables,
        weighted_tables=model.weighted_tables,
    )]

    run_multi_process_test(
        callable=sharding_single_rank_test,
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        backend=backend,
        compiled=compiled,
    )


if __name__ == "__main__":
    sharders = [create_test_sharder("embedding_bag_collection", "column_wise", "dense")]
    backend = "nccl"
    world_size = 2
    main_test(
        sharders=sharders,
        backend=backend,
        world_size=world_size,
        compiled=True,
    )

Log:

The error message is copied below.

Traceback (most recent call last):
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/tests/reproduce_nccl_tensor_must_be_contiguous.py", line 203, in sharding_single_rank_test
    dist.all_gather(all_over_r, over_r, group=ctx.pg)
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2617, in all_gather
    work = group.allgather([tensor_list], [tensor])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Tensors must be contiguous
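The failing collective is the all_gather on over_r. A possible workaround that I have not verified is to pass a contiguous copy into that call in sharding_single_rank_test, the same way the script already does for the sparse outputs:

# Untested workaround sketch for the call that fails in the traceback above,
# mirroring the .contiguous() calls already used for the sparse outputs:
dist.all_gather(all_over_r, over_r.contiguous(), group=ctx.pg)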

jiannanWang · Apr 17 '24 18:04

cc @IvanKobzarev

colin2328 · May 13 '24 23:05