torchrec

DMP doesn't broadcast DataParallel ShardingType embedding table from the process with rank 0 to all other processes

Open · tiankongdeguiji opened this issue · 13 comments

DMP should broadcast the parameters of embedding tables with the DataParallel ShardingType from the rank-0 process to all other processes in the group, so that all model replicas start from exactly the same state.

However, DMP currently adds every name returned by _sharded_parameter_names to DDP's params_and_buffers_to_ignore list, and that list includes the parameters of DataParallel ShardingType embedding tables. As a result, DDP skips its usual rank-0 broadcast for these parameters during initialization, so model replicas may start from different states. This can lead to inconsistent training results and may hurt model convergence.
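The mechanism can be reproduced without TorchRec. This is a minimal sketch, assuming CPU-only PyTorch 2.x, the gloo backend, and the Linux fork start method: two processes initialize an nn.Linear with different seeds, optionally mark weight as ignored through the private DistributedDataParallel._set_params_and_buffers_to_ignore_for_model (the same call _ddp_wrap makes), wrap the model in DDP, and report which parameters ended up identical across ranks. The helper names _free_port, _worker, and check_sync are hypothetical.

```python
import multiprocessing as mp
import socket
from typing import Dict

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def _free_port() -> int:
    # Ask the OS for an unused TCP port for the rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _worker(rank: int, world_size: int, port: int, ignore: bool, queue) -> None:
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{port}", rank=rank, world_size=world_size
    )
    torch.manual_seed(rank)  # deliberately give every rank a different initialization
    model = nn.Linear(4, 4)
    if ignore:
        # Same private hook that DMP's _ddp_wrap uses for sharded parameter names.
        DDP._set_params_and_buffers_to_ignore_for_model(
            module=model, params_and_buffers_to_ignore=["weight"]
        )
    ddp = DDP(model)  # broadcasts rank 0's non-ignored parameters during construction
    same: Dict[str, bool] = {}
    for name, p in ddp.module.named_parameters():
        gathered = [torch.zeros_like(p) for _ in range(world_size)]
        dist.all_gather(gathered, p.detach())
        same[name] = all(torch.equal(g, gathered[0]) for g in gathered)
    queue.put((rank, same))
    dist.destroy_process_group()


def check_sync(ignore: bool, world_size: int = 2) -> Dict[str, bool]:
    ctx = mp.get_context("fork")  # fork avoids pickling the worker (Linux only)
    queue = ctx.Queue()
    port = _free_port()
    procs = [
        ctx.Process(target=_worker, args=(r, world_size, port, ignore, queue))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=120) for _ in procs)
    for p in procs:
        p.join()
    return results[0]


if __name__ == "__main__":
    print(check_sync(ignore=True))   # bias is synchronized, weight is not
    print(check_sync(ignore=False))  # both are synchronized
```

With ignore=True only bias matches across ranks, while weight keeps each rank's own initialization. That mirrors what happens to DataParallel tables whose names DMP hands to DDP as ignored.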

class DefaultDataParallelWrapper(DataParallelWrapper):
    ...

    def wrap(
        self,
        dmp: "DistributedModelParallel",
        env: ShardingEnv,
        device: torch.device,
    ) -> None:
        if isinstance(dmp._dmp_wrapped_module, DistributedDataParallel) or isinstance(
            dmp._dmp_wrapped_module, FullyShardedDataParallel
        ):
            return
        sharded_parameter_names = set(
            DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module)
        )
        self._ddp_wrap(dmp, env, device, sharded_parameter_names)

    ...

class DistributedModelParallel(nn.Module, FusedOptimizerModule):
    @staticmethod
    def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
        module = get_unwrapped_module(module)
        if isinstance(module, ShardedModule):
            yield from module.sharded_parameter_names(prefix)
        else:
            for name, child in module.named_children():
                yield from DistributedModelParallel._sharded_parameter_names(
                    child, append_prefix(prefix, name)
                )

tiankongdeguiji, Mar 01 '24

@tiankongdeguiji DDP tensors are replicated in the state_dict, which should be loaded on each rank, so no comm ops are needed to broadcast them.

IvanKobzarev, Mar 04 '24

@IvanKobzarev In TorchRec DMP, all parameters of a ShardedModule (including tables with ShardingType DATA_PARALLEL) are added to DDP's params_and_buffers_to_ignore list, which prevents DDP from broadcasting them. However, the parameters of a ShardedModule whose ShardingType is DATA_PARALLEL do need to be broadcast for training to be correct.

# torchrec/distributed/model_parallel.py
class DefaultDataParallelWrapper(DataParallelWrapper):
    ...

    def _ddp_wrap(
        self,
        dmp: "DistributedModelParallel",
        env: ShardingEnv,
        device: torch.device,
        ddp_ignore_param_names: Set[str],
    ) -> None:
        pg = env.process_group
        if pg is None:
            raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv")
        all_parameter_names = set(dict(dmp.named_parameters()).keys())
        if len(all_parameter_names - ddp_ignore_param_names) == 0:
            return
        DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            module=dmp._dmp_wrapped_module,
            params_and_buffers_to_ignore=ddp_ignore_param_names,
        )
        # initialize DDP
        dmp._dmp_wrapped_module = cast(
            nn.Module,
            DistributedDataParallel(
                module=dmp._dmp_wrapped_module.to(device),
                device_ids=None if device.type == "cpu" else [device],
                process_group=pg,
                gradient_as_bucket_view=True,
                broadcast_buffers=False,
                static_graph=self._static_graph,
                find_unused_parameters=self._find_unused_parameters,
                bucket_cap_mb=self._bucket_cap_mb,
            ),
        )
        if self._allreduce_comm_precision == "fp16":
            dmp._dmp_wrapped_module.register_comm_hook(
                None, ddp_default_hooks.fp16_compress_hook
            )
        elif self._allreduce_comm_precision == "bf16":
            dmp._dmp_wrapped_module.register_comm_hook(
                None, ddp_default_hooks.bf16_compress_hook
            )

    def wrap(
        self,
        dmp: "DistributedModelParallel",
        env: ShardingEnv,
        device: torch.device,
    ) -> None:
        if isinstance(dmp._dmp_wrapped_module, DistributedDataParallel) or isinstance(
            dmp._dmp_wrapped_module, FullyShardedDataParallel
        ):
            return
        sharded_parameter_names = set(
            DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module)
        )
        self._ddp_wrap(dmp, env, device, sharded_parameter_names)

tiankongdeguiji, Mar 05 '24

Hi, @henrylhtsang @IvanKobzarev @joshuadeng

tiankongdeguiji, Mar 11 '24

@tiankongdeguiji

I suspect this isn't really a problem. I tested it with the NCCL model_parallel test_sharding_dp by printing the state_dict out, and found them to be the same.

I suspect it is working due to this hack https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embeddingbag.py?fbclid=IwAR0LKttb3ZOvOhIh1lAfnbO7hUR16YrPp9kPfZSM3WW4pHvv800z0G3a718#L499-L500

henrylhtsang, Mar 12 '24

@henrylhtsang Yes, that is where DDP modules are set up (using actual DDP) so that these data_parallel tables call all_reduce to get the correct gradients. Why do you call this part a hack?

@tiankongdeguiji, is your concern that all_reduce won't be called during training, or that restoring from a checkpoint will be incorrect?

colin2328, Mar 13 '24

@henrylhtsang

In the test_sharding_dp function, the state_dict of global_model is copied to local_model, and the global_model is initialized using torch.manual_seed(0). This ensures that all model replicas start from identical initial parameters.

https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/test_utils/test_sharding.py#L376

However, a typical training job performs no such state_dict copy at the start, so the DataParallel ShardingType embedding table parameters on different replicas would start from different states.
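The effect of that copy can be seen without any distributed setup. This is a minimal single-process sketch, assuming two nn.EmbeddingBag modules stand in for the replicas on two ranks: each construction draws fresh random numbers, playing the role of a rank with its own RNG state. The tables differ until load_state_dict copies one into the other, which is what hides the problem in the test. make_replica is a hypothetical helper.

```python
import torch
from torch import nn


def make_replica(num_embeddings: int = 1024, dim: int = 64) -> nn.EmbeddingBag:
    # Each call consumes fresh random numbers, like a rank initializing on its own.
    return nn.EmbeddingBag(num_embeddings, dim, mode="sum")


global_model = make_replica()
local_model = make_replica()
same_before_copy = torch.equal(global_model.weight, local_model.weight)

# What the test effectively does: overwrite the local replica with the global state.
local_model.load_state_dict(global_model.state_dict())
same_after_copy = torch.equal(global_model.weight, local_model.weight)

print(same_before_copy, same_after_copy)  # False True
```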

tiankongdeguiji, Mar 18 '24

@tiankongdeguiji Can you try inspecting the state dict of local_model right after local_model = DistributedModelParallel(...), i.e. before copy_state_dict?

When I ran it, those parameters were the same.

Update: I just tested it again. Before DMP, the table weights are on the meta device. After DMP, they are on CUDA and are the same.

henrylhtsang, Mar 18 '24

@henrylhtsang

Can you run torchrun --master_addr=localhost --master_port=54926 --nnodes=1 --nproc-per-node=2 --node_rank=0 debug_dp_shard.py using the environment torchrec==0.6.0+cu121, torch==2.2.0+cu121, fbgemm-gpu==0.6.0+cu121?

debug_dp_shard.py

import os
from typing import Dict, cast

import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(small_table_cnt)
]


def gen_constraints(
    sharding_type: ShardingType = ShardingType.DATA_PARALLEL,
) -> Dict[str, ParameterConstraints]:
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = EmbeddingBagCollection(tables=large_tables + small_tables, device="meta")
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1)

    def forward(self, kjt: KeyedJaggedTensor):
        emb = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))


rank = int(os.environ["RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()
print("world_size:", world_size)

model = DebugModel()
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})

topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.DATA_PARALLEL)
planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)

sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    sharders=sharders,
    device=device,
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])
print(f"rank:{rank},sharding plan: {plan}")

batch_size = 64
kjt = KeyedJaggedTensor(
    keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
    + ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
    values=torch.cat(
        [
            torch.randint(0, 4096, (batch_size * 2,)),  # ids for the two large tables
            torch.randint(0, 1023, (batch_size * 2,)),  # ids for the two small tables
        ]
    ),
    lengths=torch.ones(batch_size * (small_table_cnt + large_table_cnt), dtype=torch.int32),
).to(device=device)
losses = sharded_model.forward(kjt)
torch.sum(losses, dim=0).backward()
optimizer.step()

dist.barrier()
for k, v in sharded_model.named_parameters():
    if 'ebc' in k:
        t_list = [torch.zeros_like(v) for _ in range(world_size)]
        dist.all_gather(t_list, v)
        if rank == 0:
            print(k, t_list[0].equal(t_list[1]))

It prints the following log, which shows that the ebc parameters differ across ranks:

ebc.embedding_bags.large_table_0.weight False
ebc.embedding_bags.large_table_1.weight False
ebc.embedding_bags.small_table_0.weight False
ebc.embedding_bags.small_table_1.weight False

tiankongdeguiji, Mar 19 '24

@tiankongdeguiji Okay, I think you are 100% right. Sorry, I didn't understand your point about the torch seed.

I looked into it. A few points:

  1. The problem isn't with training. The tables already differ at initialization.
  2. The "hack" is working perfectly. If you print self._lookups[index].state_dict() right after that point, they are the same.
  3. I think the problem is self._initialize_torch_state(), which initializes the tables separately on each rank. I tested putting torch.manual_seed(0) before the self._initialize_torch_state() line, and your script then prints four Trues. Please test this and report back whether it works for you.
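The seeding workaround from point 3 can be illustrated in a single process. This is a minimal sketch under two assumptions: each simulated rank starts from its own RNG state (modeled here as torch.manual_seed(1000 + rank), a hypothetical stand-in), and reseeding immediately before table initialization plays the role of torch.manual_seed(0) placed before self._initialize_torch_state(). It only helps if every rank reseeds with the same value at exactly the same point, so it is more fragile than an explicit broadcast. init_table_on_rank is a hypothetical helper.

```python
from typing import Optional

import torch
from torch import nn


def init_table_on_rank(rank: int, seed: Optional[int] = None) -> torch.Tensor:
    torch.manual_seed(1000 + rank)  # hypothetical: simulate a rank-specific RNG state
    if seed is not None:
        torch.manual_seed(seed)  # workaround: same seed on every rank right before init
    table = nn.EmbeddingBag(1024, 64, mode="sum")
    return table.weight.detach().clone()


unseeded = [init_table_on_rank(r) for r in range(2)]
seeded = [init_table_on_rank(r, seed=0) for r in range(2)]
print(torch.equal(*unseeded), torch.equal(*seeded))  # False True
```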

Followups:

  1. Fix the problem.
  2. Remove torch.manual_seed in tests

cc @PaulZhang12 @colin2328

henrylhtsang, Mar 22 '24

@henrylhtsang Yes, it works. However, I think it would be better to broadcast these parameters from rank 0 to the other ranks, as DDP does.

tiankongdeguiji, Mar 22 '24

@tiankongdeguiji FYI, I've already raised the issue with the team. We will probably wait a bit.

On the other hand, any suggestions on how to fix this in a nice way? Maybe remove the names from _sharded_parameter_names?

henrylhtsang, Mar 22 '24

@henrylhtsang We can't remove the names from _sharded_parameter_names, because the dist.Reducer inside DDP can't handle the parameters of DataParallel ShardingType embedding tables. For now, I call dist._broadcast_coalesced on these parameters after constructing DMP.
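A workaround along those lines can be sketched with the public dist.broadcast instead of the private dist._broadcast_coalesced, which batches tensors into buckets but is not a stable API. The sketch assumes CPU-only PyTorch, the gloo backend, and the Linux fork start method. broadcast_named_params and the other helpers are hypothetical; in a DMP setup, names would be the DataParallel table parameters, such as the ebc.embedding_bags.*.weight names in the log above.

```python
import multiprocessing as mp
import socket
from typing import Dict, Iterable

import torch
import torch.distributed as dist
from torch import nn


def broadcast_named_params(module: nn.Module, names: Iterable[str], src: int = 0) -> None:
    """Overwrite the named parameters on every rank with rank ``src``'s values."""
    params = dict(module.named_parameters())
    # Sorting keeps the collective calls in the same order on every rank.
    for name in sorted(names):
        # detach() shares storage, so the in-place broadcast updates the parameter.
        dist.broadcast(params[name].detach(), src=src)


def _free_port() -> int:
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{port}", rank=rank, world_size=world_size
    )
    torch.manual_seed(rank)  # each rank builds a different table
    model = nn.ModuleDict({"table": nn.EmbeddingBag(1024, 64, mode="sum")})
    broadcast_named_params(model, {"table.weight"})
    w = model["table"].weight.detach()
    gathered = [torch.zeros_like(w) for _ in range(world_size)]
    dist.all_gather(gathered, w)
    queue.put((rank, all(torch.equal(g, gathered[0]) for g in gathered)))
    dist.destroy_process_group()


def run(world_size: int = 2) -> Dict[int, bool]:
    ctx = mp.get_context("fork")  # fork avoids pickling the worker (Linux only)
    queue = ctx.Queue()
    port = _free_port()
    procs = [
        ctx.Process(target=_worker, args=(r, world_size, port, queue))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=120) for _ in procs)
    for p in procs:
        p.join()
    return results
```

Every rank must call broadcast_named_params with the same set of names, because each broadcast is a collective.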

tiankongdeguiji, Mar 25 '24

@tiankongdeguiji FYI, I landed the fix: https://github.com/pytorch/torchrec/commit/cc482f8a5f80fd8975de82ad22b65cda3348d872

Basically, every time we call reset_parameters, we also broadcast the re-initialized DP tables from rank 0 to all other ranks.

Just be sure not to forget to apply an optimizer to the DP tables.
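The pattern described for the fix can be sketched on a toy module: re-initialize locally, then make rank 0's values authoritative. This is not the code from the linked commit. ReplicatedTable is a hypothetical stand-in for a DATA_PARALLEL table, and the sketch assumes CPU-only PyTorch, the gloo backend, and the Linux fork start method. Because the broadcast is a collective, every rank must call reset_parameters at the same point.

```python
import multiprocessing as mp
import socket
from typing import Dict, Tuple

import torch
import torch.distributed as dist
from torch import nn


class ReplicatedTable(nn.Module):
    """Hypothetical stand-in for a DATA_PARALLEL embedding table."""

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Local initialization: may differ from rank to rank.
        nn.init.uniform_(self.weight, -0.01, 0.01)
        # Make rank 0's values authoritative whenever a process group exists.
        if dist.is_available() and dist.is_initialized():
            dist.broadcast(self.weight.detach(), src=0)


def _free_port() -> int:
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _all_ranks_equal(t: torch.Tensor) -> bool:
    gathered = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    return all(torch.equal(g, gathered[0]) for g in gathered)


def _worker(rank: int, world_size: int, port: int, queue) -> None:
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{port}", rank=rank, world_size=world_size
    )
    torch.manual_seed(rank)  # different local RNG state on every rank
    table = ReplicatedTable(1024, 64)
    after_init = _all_ranks_equal(table.weight.detach())
    torch.manual_seed(100 + rank)
    table.reset_parameters()  # the re-initialization is broadcast again
    after_reset = _all_ranks_equal(table.weight.detach())
    queue.put((rank, (after_init, after_reset)))
    dist.destroy_process_group()


def run(world_size: int = 2) -> Dict[int, Tuple[bool, bool]]:
    ctx = mp.get_context("fork")  # fork avoids pickling the worker (Linux only)
    queue = ctx.Queue()
    port = _free_port()
    procs = [
        ctx.Process(target=_worker, args=(r, world_size, port, queue))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=120) for _ in procs)
    for p in procs:
        p.join()
    return results
```

Outside a process group, reset_parameters simply initializes locally, so the same module still works in single-process code.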

henrylhtsang, Apr 11 '24

Thanks

tiankongdeguiji, Aug 02 '24