torchrec icon indicating copy to clipboard operation
torchrec copied to clipboard

[Bug] The number of embeddings in ManagedCollisionCollection must be a multiple of the number of devices

Open fangleigit opened this issue 1 year ago • 3 comments

When the number of embeddings in the code below is changed to 4091 and mch_size to 1021, it throws the following exception:

ValueError: ShardedTensor global_size property does not match from different ranks! Found global_size=torch.Size([3070]) on rank:0, and global_size=torch.Size([3068]) on rank:1.
ValueError: ShardedTensor global_size property does not match from different ranks! Found global_size=torch.Size([3070]) on rank:0, and global_size=torch.Size([3068]) on rank:1.
Traceback (most recent call last):
  File "test2.py", line 143, in <module>
    spmd_sharing_simulation(ShardingType.ROW_WISE)
  File "test2.py", line 139, in spmd_sharing_simulation
    assert 0 == p.exitcode
AssertionError
import os
from typing import Dict, cast

import multiprocess
import torch
import torch.distributed as dist
import torchrec
from torchrec.distributed.mc_embeddingbag import ManagedCollisionEmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    ManagedCollisionModule,
    MCHManagedCollisionModule,
)


def preprocess_func(id: torch.Tensor, hash_size: int) -> torch.Tensor:
    """Bucket raw ids by taking their remainder modulo ``hash_size``.

    Used as the ``mch_hash_func`` of the MCH managed-collision module. For a
    positive ``hash_size`` every result lies in ``[0, hash_size)``, including
    for negative ids, because the remainder takes the sign of the divisor.

    Args:
        id: Raw feature ids.
        hash_size: Number of hash buckets.

    Returns:
        The ids reduced modulo ``hash_size``.
    """
    bucketed = id % hash_size
    return bucketed


# Single-host rendezvous address; presumably consumed by the default (env://)
# initialization in dist.init_process_group below — confirm if init changes.
os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "29500"})

table_name = "sample"

# Table size and managed-collision split. zch_size + mch_size is chosen to equal
# num_embeddings (3070 + 1026 == 4096); presumably the remapped ids must fit the
# table's rows — confirm against torchrec's ManagedCollisionCollection checks.
NUM_EMBEDDINGS = 4096
ZCH_SIZE = 3070
MCH_SIZE = NUM_EMBEDDINGS - ZCH_SIZE

tables = [
    torchrec.EmbeddingBagConfig(
        name=table_name,
        embedding_dim=64,
        num_embeddings=NUM_EMBEDDINGS,
        feature_names=[table_name],
        pooling=torchrec.PoolingType.SUM,
    )
]

# MCH remapper for the single table, built on the meta device (no real storage).
_mch_module = MCHManagedCollisionModule(
    zch_size=ZCH_SIZE,
    mch_size=MCH_SIZE,
    device="meta",
    eviction_interval=1,
    eviction_policy=DistanceLFU_EvictionPolicy(),
    mch_hash_func=preprocess_func,
)

mcc = ManagedCollisionCollection(
    managed_collision_modules={
        table_name: cast(ManagedCollisionModule, _mch_module),
    },
    embedding_configs=tables,
)

# Unsharded model that every spawned rank receives and shards.
ebc: ManagedCollisionEmbeddingBagCollection = ManagedCollisionEmbeddingBagCollection(
    EmbeddingBagCollection(tables=tables, device="meta"),
    mcc,
    return_remapped_features=False,
)


def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> DistributedModelParallel:
    """Shard ``module`` on one rank of a single-host SPMD job and return it.

    Initializes the default process group for this rank, computes a sharding
    plan collectively across all ranks, and wraps ``module`` in
    ``DistributedModelParallel`` using that plan.

    Args:
        rank: Index of this process in the job. It is also the CUDA device
            index when ``backend == "nccl"``.
        world_size: Total number of participating processes.
        constraints: Planner constraints keyed by table name.
        module: Unsharded model to distribute.
        backend: ``torch.distributed`` backend. ``"nccl"`` places this rank
            on ``cuda:{rank}``; any other backend runs on CPU.

    Returns:
        The sharded model.
    """

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        """Export RANK/WORLD_SIZE, init the default process group, return it."""
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(
            rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    # Plan for the device this rank actually runs on ("cuda" or "cpu").
    topology = Topology(world_size=world_size, compute_device=device.type)
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module],
                     ManagedCollisionEmbeddingBagCollectionSharder())]
    # The planner must see the same sharders DMP will use. With sharders=None it
    # falls back to its defaults, which may not cover a
    # ManagedCollisionEmbeddingBagCollection.
    plan = planner.collective_plan(module, sharders=sharders, pg=pg)

    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    print(f"rank:{rank},sharding plan: {plan}")
    return sharded_model


def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size: int = 2,
    backend: str = "nccl",
) -> None:
    """Run ``single_rank_execution`` in ``world_size`` spawned processes.

    Every rank shards the module-level ``ebc``. Planner constraints restrict
    ``table_name`` to ``sharding_type``.

    Args:
        sharding_type: The only sharding type the planner may use for the table.
        world_size: Number of ranks (processes) to spawn.
        backend: ``torch.distributed`` backend passed to every rank.

    Raises:
        AssertionError: If any rank exits with a non-zero exit code.
    """
    ctx = multiprocess.get_context("spawn")
    constraints = {
        table_name: ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
    }
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=single_rank_execution,
            args=(rank, world_size, constraints, ebc, backend),
        )
        p.start()
        processes.append(p)

    # Join every rank before checking exit codes. This keeps one failing rank
    # from leaving the others un-joined, and reports all failures together.
    for p in processes:
        p.join()
    failed = {
        rank: p.exitcode for rank, p in enumerate(processes) if p.exitcode != 0
    }
    if failed:
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError(f"ranks exited with non-zero exit codes: {failed}")


if __name__ == '__main__':
    # Reproduce the issue: row-wise sharding over the default world_size of 2.
    spmd_sharing_simulation(sharding_type=ShardingType.ROW_WISE)

fangleigit avatar Dec 18 '23 02:12 fangleigit

Hi, thanks for trying out ManagedCollisionCollection!

Not sure if it's a bug. We are currently trying to use ManagedCollisionCollection only with row-wise sharding, which shards the table evenly across all GPUs — hence the divisibility requirement.

henrylhtsang avatar Dec 18 '23 20:12 henrylhtsang

Hi, thanks for trying out ManagedCollisionCollection!

Not sure if it's a bug. We are currently trying to use ManagedCollisionCollection only with row-wise sharding, which shards the table evenly across all GPUs — hence the divisibility requirement.

Thanks for your quick response. Yes, I tried ManagedCollisionCollection on our data, and model performance degraded when using it; training time also increased significantly. Is there any guideline or documentation on how to set the hyperparameters of this module (e.g., eviction_interval, zch_size, mch_size)? Also, which eviction policy, DistanceLFU_EvictionPolicy or LFU_EvictionPolicy, works better in which scenarios?

fangleigit avatar Dec 19 '23 08:12 fangleigit

@fangleigit Thanks. We are still actively developing MCH/ZCH, so we don't have a clear answer so far. Let us know if you have it figured out as well!

henrylhtsang avatar Dec 19 '23 19:12 henrylhtsang