torchrec icon indicating copy to clipboard operation
torchrec copied to clipboard

Which Lightning strategy should I use with TorchRec optimizers?

Open JacobHelwig opened this issue 5 months ago • 0 comments

Hi, thank you for this great work. I would like to know which distributed strategy to use with the Lightning Trainer. I see two potential avenues:

  1. DDP strategy: following this example, I verified that the updates are not sparse, i.e., embeddings that were not used to compute the loss for the current batch were still updated when using Adam (due to momentum/weight decay).
  2. Custom strategy for DMP: when using DMP, I've verified that the updates are sparse. However, AFAIK there is no DMP strategy for Lightning, so I would need to define a custom strategy.

Is it possible to make DDP work with sparse optimizer updates, and if not, is a custom strategy the best option?

Minimal working example (MWE):


import argparse
import os
import sys
from typing import Any, cast, Dict, List, Union

from fbgemm_gpu.split_embedding_configs import EmbOptimType

import torch
from torch import distributed as dist, nn, optim 
import torch.utils.data as data_utils
from torch.nn.parallel import DistributedDataParallel as DDP

import torchrec
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
from torchrec.distributed.types import ModuleSharder
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.modules.embedding_configs import ShardingType
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, PoolingType

class DataSet(torch.utils.data.IterableDataset):
    """Infinite stream of random, variable-length item-id sequences.

    Each yielded element is a 1-D tensor of ids drawn with ``torch.randint``
    from ``[0, max_id)``; its length is drawn from ``[1, max_seq_len]``.

    Args:
        max_id: Exclusive upper bound on generated ids; must be >= 1.
        max_seq_len: Maximum sequence length (inclusive); must be >= 1.

    Raises:
        ValueError: if either bound is smaller than 1 (``torch.randint`` would
            otherwise fail later, during iteration, with a less clear error).
    """

    def __init__(
            self,
            max_id: int,
            max_seq_len: int
    ) -> None:
        super().__init__()
        if max_id < 1:
            raise ValueError(f"max_id must be >= 1, got {max_id}")
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
        self.max_seq_len = max_seq_len
        self.max_id = max_id

    def __iter__(self):
        # Never terminates: the consumer decides how many batches to take.
        while True:
            len_ = torch.randint(1, self.max_seq_len + 1, (1, )).item()
            yield torch.randint(0, self.max_id, (len_,))


class Model(torch.nn.Module):
    """Toy model: one mean-pooled embedding bag followed by a linear head.

    The single table ``item_embedding`` has ``max_id`` rows of width
    ``emb_dim`` (init bounds ``[-1.0, 1.0]``) and serves the ``"item"``
    feature. Its pooled output is fed to ``head``, an ``emb_dim -> 1`` linear
    layer.
    """

    def __init__(
        self,
        max_id: int,
        emb_dim: int,
    ) -> None:
        super().__init__()
        self.emb_dim = emb_dim

        # Build the embedding collection before the head so that module
        # registration order (and RNG consumption at init) is unchanged.
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="item_embedding",
                    embedding_dim=emb_dim,
                    num_embeddings=max_id,
                    feature_names=["item"],
                    weight_init_max=1.0,
                    weight_init_min=-1.0,
                    pooling=PoolingType.MEAN,
                )
            ],
        )
        self.head = nn.Linear(emb_dim, 1)

    def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
        pooled = self.ebc(x)
        item_vectors = pooled["item"].to_dense()
        return self.head(item_vectors)

def parse_args(argv: List[str]) -> argparse.Namespace:
    """Parse the repro script's command-line flags.

    ``--mode`` selects how the embedding table is distributed:

    * ``dmp``: wrap the model in DistributedModelParallel without an explicit
      sharding plan.
    * ``ddp`` (default): still wrap in DistributedModelParallel, but constrain
      the table to DATA_PARALLEL sharding.

    Any other value is rejected by argparse (it exits with status 2). Without
    this check, a typo would silently fall through to the ``ddp`` branch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        type=str,
        default="ddp",
        choices=("ddp", "dmp"),
        help="dmp (distributed model parallel) or ddp (distributed data parallel)",
    )
    return parser.parse_args(argv)


def _to_kjt(seqs: torch.LongTensor, device: torch.device) -> KeyedJaggedTensor:
    """Pack a batch of 1-D id tensors into a single-key ("item") KJT.

    ``seqs`` is iterated row by row; the caller in this file passes a list of
    1-D tensors. The rows are concatenated into ``values``, their sizes form
    an int32 ``lengths`` tensor, and the resulting KJT is moved to ``device``.
    """
    rows = [row for row in seqs]
    values = torch.cat(rows, dim=0)
    lengths = torch.IntTensor([row.size(0) for row in rows])
    return KeyedJaggedTensor.from_lengths_sync(
        keys=["item"], values=values, lengths=lengths
    ).to(device)

def get_embedding_weights(model: nn.Module, x: List[torch.Tensor]):
    """Snapshot the embedding table and split its rows into used / unused.

    Args:
        model: Any ``nn.Module``, e.g. the DDP- or DMP-wrapped model. Its
            parameters must include exactly one whose name contains
            ``"embedding"``.
        x: Batch of 1-D id tensors. A table row counts as "used" when its
            index appears in any of them.

    Returns:
        ``(used_rows, unused_rows)``: copies of the table rows whose index
        does / does not appear in ``x``. Later updates to the model do not
        affect them.

    Raises:
        ValueError: if the number of matching parameters is not exactly one.
    """
    tables = [p for name, p in model.named_parameters() if "embedding" in name]
    if len(tables) != 1:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        raise ValueError(
            "expected exactly one parameter with 'embedding' in its name, "
            f"found {len(tables)}"
        )
    table = tables[0].data.clone()
    seen_ids = torch.cat(x)
    row_ids = torch.arange(len(table)).type_as(seen_ids)
    used_mask = torch.isin(row_ids, seen_ids)
    return table[used_mask], table[~used_mask]

def _train_one_epoch(
    model: Union[DDP, DMP],
    loader: data_utils.DataLoader,
    device: torch.device,
    optimizer: optim.Optimizer,
    num_iters: int = 5,
) -> None:
    """Run ``num_iters`` training steps and report how much embedding rows move.

    Each step snapshots the embedding table (via ``get_embedding_weights``)
    before ``backward()`` and again after ``optimizer.step()``. It then prints
    the norm of the change for rows whose ids appear in the batch ("∆ hot")
    and for rows whose ids do not ("∆ cold"). A non-zero "∆ cold" means the
    embedding update was not sparse.

    Args:
        model: Wrapped model taking a KeyedJaggedTensor.
        loader: Yields batches as lists of 1-D id tensors.
        device: Device the batch is moved to. For CUDA devices it also becomes
            the current CUDA device.
        optimizer: Optimizer stepped once per batch, e.g. the
            ``CombinedOptimizer`` built in ``main``.
        num_iters: Number of steps to run. Values <= 0 run nothing.
    """
    if num_iters <= 0:
        return
    model.train()
    if device.type == "cuda":
        # Use the caller's device (from LOCAL_RANK) rather than the global
        # rank, which is not a valid local device index on multi-node runs.
        torch.cuda.set_device(device)
    for i, batch in enumerate(loader, start=1):
        batch = [x.to(device) for x in batch]
        optimizer.zero_grad()
        kjt = _to_kjt(batch, device)
        loss = model(kjt).norm()
        # Snapshot after forward but before backward. The fused embedding
        # optimizer is filtered out as "in backward" in main, so it
        # presumably updates rows during backward(). Verify for other setups.
        used_embs_pre, unused_embs_pre = get_embedding_weights(model, batch)
        loss.backward()
        optimizer.step()
        used_embs_post, unused_embs_post = get_embedding_weights(model, batch)

        diffs_used = torch.norm(used_embs_post - used_embs_pre).item()
        diffs_unused = torch.norm(unused_embs_post - unused_embs_pre).item()

        print(f"Iter {i}, loss {loss.item():.2e}, ∆ hot: {diffs_used:.2e}, ∆ cold: {diffs_unused:.2e}")
        # ``>=`` so exactly ``num_iters`` steps run; the original ``>`` ran one extra.
        if i >= num_iters:
            break


def _embedding_fused_params(lr: float) -> Dict[str, Any]:
    """Return sharder ``fused_params`` selecting FBGEMM's fused Adam at ``lr``."""
    return {"optimizer": EmbOptimType.ADAM, "learning_rate": lr}


def _build_optimizer(model: DMP, lr: float) -> CombinedOptimizer:
    """Combine the model's fused embedding optimizer with Adam for the rest.

    ``in_backward_optimizer_filter`` presumably excludes parameters that are
    already owned by the in-backward fused optimizer, so the dense Adam only
    receives the remaining parameters.
    """
    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        lambda params: optim.Adam(params, lr=lr),
    )
    return CombinedOptimizer([model.fused_optimizer, dense_optimizer])


def main(argv: List[str]) -> None:
    """Set up distributed training, shard the toy model, and run a few steps.

    Reads ``LOCAL_RANK`` from the environment (set by ``torchrun``). Uses NCCL
    on ``cuda:LOCAL_RANK`` when CUDA is available and Gloo on CPU otherwise.
    The process group is destroyed even if training raises.
    """
    args = parse_args(argv)
    use_dmp = args.mode == "dmp"
    rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        backend = "nccl"
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
        backend = "gloo"

    if not dist.is_initialized():
        dist.init_process_group(backend=backend)

    try:
        world_size = dist.get_world_size()

        MAX_ID = 1000
        MAX_SEQ_LEN = 10
        LR = 0.1
        loader = torch.utils.data.DataLoader(
            DataSet(max_id=MAX_ID, max_seq_len=MAX_SEQ_LEN),
            batch_size=32,
            num_workers=0,
            collate_fn=lambda x: x,
        )

        model = Model(
            max_id=MAX_ID,
            emb_dim=16,
        ).to(device)

        # Both modes shard the EBC with the same fused-Adam sharder; they only
        # differ in the sharding plan handed to DMP.
        sharders = [
            cast(
                ModuleSharder[nn.Module],
                EmbeddingBagCollectionSharder(_embedding_fused_params(LR)),
            )
        ]

        if use_dmp:
            print("Using DMP")
            model = DMP(module=model, device=device, sharders=sharders)
        else:
            print("Using DDP")
            # NOTE: despite the mode name, this still wraps the model in DMP.
            # The table is only constrained to DATA_PARALLEL sharding; torch's
            # DistributedDataParallel is not used directly.
            constraints = {
                "item_embedding": torchrec.distributed.planner.ParameterConstraints(
                    sharding_types=[ShardingType.DATA_PARALLEL.value]
                )
            }
            planner = torchrec.distributed.planner.EmbeddingShardingPlanner(
                topology=torchrec.distributed.planner.Topology(
                    world_size=world_size,
                    compute_device=device.type,
                ),
                constraints=constraints,
            )
            plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)
            model = DMP(module=model, device=device, plan=plan, sharders=sharders)

        optimizer = _build_optimizer(model, LR)

        _train_one_epoch(
            model=model,
            loader=loader,
            device=device,
            optimizer=optimizer,
        )
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    import importlib

    print(sys.argv)
    # Report versions of the packages relevant to this repro in a fixed,
    # alphabetical order. The original derived the list from a set, so the
    # order changed between runs. torchvision and triton are not used by the
    # script, so a missing one is skipped rather than aborting the run.
    print("-" * 56)
    for module_name in ("argparse", "fbgemm_gpu", "torch", "torchrec", "torchvision", "triton"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        version = getattr(module, "__version__", None)
        if version is not None:
            print(module.__name__, version)
    print("-" * 56)

    main(sys.argv[1:])

# torchrun --rdzv_endpoint=localhost:29400 --nnodes 1 train.py --mode ddp

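Output with DDP (note: the argv echo and banner in this log read `dmp` / "Using DMP", which does not match a `--mode ddp` run):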

['train.py', '--mode', 'dmp']
--------------------------------------------------------
fbgemm_gpu 1.2.0+cu128
torchvision 0.22.0.2+cu128
torch 2.7.0.2+cu128
torchrec 1.2.0+cu128
argparse 1.1
triton 3.1.0
--------------------------------------------------------
Using DMP
/workspace/model-building-framework/.venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:859: UserWarning: `_get_pg_default_device` will be deprecated, it only stays for backward-compatiblity reason. If you need to find a device for object collectives, please use `_get_object_coll_device`. If you need to query the device types supported by group, please use `_device_capability(group)`. 
  warnings.warn(
Iter 1, loss 1.14e+00, ∆ hot: 5.25e+00, ∆ cold: 0.00e+00
Iter 2, loss 8.24e-01, ∆ hot: 3.61e+00, ∆ cold: 3.23e+00
Iter 3, loss 1.00e+00, ∆ hot: 3.42e+00, ∆ cold: 3.36e+00
Iter 4, loss 1.03e+00, ∆ hot: 2.72e+00, ∆ cold: 3.64e+00
Iter 5, loss 7.13e-01, ∆ hot: 2.63e+00, ∆ cold: 3.54e+00
Iter 6, loss 7.17e-01, ∆ hot: 2.37e+00, ∆ cold: 3.48e+00

Output with DMP:

torchrun --rdzv_endpoint=localhost:29400 --nnodes 1 train.py --mode dmp
['train.py', '--mode', 'dmp']
--------------------------------------------------------
triton 3.1.0
torchvision 0.22.0.2+cu128
torch 2.7.0.2+cu128
fbgemm_gpu 1.2.0+cu128
argparse 1.1
torchrec 1.2.0+cu128
--------------------------------------------------------
Using DMP
/workspace/model-building-framework/.venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:859: UserWarning: `_get_pg_default_device` will be deprecated, it only stays for backward-compatiblity reason. If you need to find a device for object collectives, please use `_get_object_coll_device`. If you need to query the device types supported by group, please use `_device_capability(group)`. 
  warnings.warn(
/workspace/model-building-framework/.venv/lib/python3.10/site-packages/fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py:39: UserWarning: 
        [FBGEMM_GPU] NOTE: The training optimizer 'adam' is marked as
        EXPERIMENTAL and thus not optimized, in order to reduce code compilation
        times and build sizes!
        
  warnings.warn(
Iter 1, loss 1.86e+00, ∆ hot: 5.14e+00, ∆ cold: 0.00e+00
Iter 2, loss 1.39e+00, ∆ hot: 3.94e+00, ∆ cold: 0.00e+00
Iter 3, loss 9.56e-01, ∆ hot: 2.83e+00, ∆ cold: 0.00e+00
Iter 4, loss 6.22e-01, ∆ hot: 3.01e+00, ∆ cold: 0.00e+00
Iter 5, loss 6.56e-01, ∆ hot: 2.58e+00, ∆ cold: 0.00e+00
Iter 6, loss 4.84e-01, ∆ hot: 2.72e+00, ∆ cold: 0.00e+00

JacobHelwig avatar Jun 18 '25 19:06 JacobHelwig