Which lightning strategy to use with torchrec optimizers?
Hi, thank you for this great work. I would like to know which distributed strategy to use with the Lightning Trainer. I see two potential avenues:
- DDP strategy: following this example, I verified that the updates are not sparse, i.e., embeddings that were not used to compute the loss for the current batch were still updated when using Adam (due to momentum/weight decay); a minimal standalone demonstration is sketched right after this list.
- Custom strategy for DMP: when using DMP, I've verified that the updates are sparse. However, AFAIK there is no DMP strategy for Lightning, so I would need to define a custom one (a rough sketch of what I have in mind follows the question below).
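To make the first point concrete, here is a tiny standalone sketch in plain PyTorch (no torchrec, no distributed setup) of the effect I mean: once a row of a dense embedding table has accumulated Adam state, it keeps moving even in batches where it is never looked up. The table size, ids and learning rate are arbitrary, purely for illustration.

import torch
from torch import nn, optim

torch.manual_seed(0)
emb = nn.EmbeddingBag(10, 4, mode="mean")   # tiny dense embedding table
opt = optim.Adam(emb.parameters(), lr=0.1)  # plain dense Adam, no weight decay

def step(ids: torch.Tensor) -> None:
    opt.zero_grad()
    emb(ids, torch.tensor([0])).norm().backward()
    opt.step()

step(torch.tensor([0, 1]))                            # rows 0 and 1 acquire Adam state
before = emb.weight.detach().clone()
step(torch.tensor([2, 3]))                            # this batch only touches rows 2 and 3 ...
print((emb.weight.detach()[:2] - before[:2]).norm())  # ... yet rows 0 and 1 still moved (> 0)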
Is it possible to make DDP work with a sparse optimizer, and if not, is a custom strategy the best option?
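For reference, the rough shape of the custom strategy I have in mind is below. This is an untested sketch based on two assumptions of mine: that Lightning's DDPStrategy can be subclassed so that its process-group setup is reused, and that overriding its internal _setup_model hook is enough to wrap the module in DistributedModelParallel instead of DistributedDataParallel. The class name DMPStrategy is made up; this is not an official torchrec/Lightning integration.

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from torch import nn
from torchrec.distributed.model_parallel import DistributedModelParallel as DMP


class DMPStrategy(DDPStrategy):
    """Untested sketch: reuse DDPStrategy's setup, but wrap the module with DMP."""

    def _setup_model(self, model: nn.Module) -> nn.Module:
        # Assumption: swapping the wrapper here is sufficient. DMP shards the
        # EmbeddingBagCollection and attaches the fused (sparse) optimizer.
        return DMP(module=model, device=self.root_device)


# Hypothetical usage:
# trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=DMPStrategy())

If DDPStrategy does DDP-specific work elsewhere (gradient sync hooks, checkpointing of sharded state, hooking up model.fused_optimizer), this would clearly need more than that, hence the question.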
MWE:
import argparse
import os
import sys
from typing import Any, cast, Dict, List, Union
from fbgemm_gpu.split_embedding_configs import EmbOptimType
import torch
from torch import distributed as dist, nn, optim
import torch.utils.data as data_utils
from torch.nn.parallel import DistributedDataParallel as DDP
import torchrec
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel as DMP
from torchrec.distributed.types import ModuleSharder
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.distributed.types import ShardingType
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, PoolingType
class DataSet(torch.utils.data.IterableDataset):
    def __init__(
        self,
        max_id: int,
        max_seq_len: int,
    ) -> None:
        self.max_seq_len = max_seq_len
        self.max_id = max_id

    def __iter__(self):
        while True:
            len_ = torch.randint(1, self.max_seq_len + 1, (1,)).item()
            yield torch.randint(0, self.max_id, (len_,))
class Model(torch.nn.Module):
    def __init__(
        self,
        max_id: int,
        emb_dim: int,
    ) -> None:
        super().__init__()
        self.emb_dim = emb_dim
        item_embedding_config = EmbeddingBagConfig(
            name="item_embedding",
            embedding_dim=emb_dim,
            num_embeddings=max_id,
            feature_names=["item"],
            weight_init_max=1.0,
            weight_init_min=-1.0,
            pooling=PoolingType.MEAN,
        )
        self.ebc = EmbeddingBagCollection(
            tables=[item_embedding_config],
        )
        self.head = nn.Linear(emb_dim, 1)

    def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
        out = self.ebc(x)["item"].to_dense()
        return self.head(out)
def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        type=str,
        default="ddp",
        help="dmp (distributed model parallel) or ddp (distributed data parallel)",
    )
    return parser.parse_args(argv)
def _to_kjt(seqs: torch.LongTensor, device: torch.device) -> KeyedJaggedTensor:
    seqs_list = list(seqs)
    lengths = torch.IntTensor([value.size(0) for value in seqs_list])
    values = torch.cat(seqs_list, dim=0)
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["item"], values=values, lengths=lengths
    ).to(device)
    return kjt
def get_embedding_weights(model: Union[DDP, DMP], x: List[torch.Tensor]):
    emb_weights = [v.data.clone() for k, v in model.named_parameters() if "embedding" in k]
    assert len(emb_weights) == 1
    emb_weights = emb_weights[0]
    x = torch.cat(x)
    ids = torch.arange(len(emb_weights)).type_as(x)
    used_mask = torch.isin(ids, x)
    return emb_weights[used_mask], emb_weights[~used_mask]
def _train_one_epoch(
    model: Union[DDP, DMP],
    loader: data_utils.DataLoader,
    device: torch.device,
    optimizer: optim.Adam,
) -> None:
    model.train()
    if torch.cuda.is_available():
        torch.cuda.set_device(dist.get_rank())
    i = 0
    NUM_ITER = 5
    for batch in loader:
        i += 1
        batch = [x.to(device) for x in batch]
        optimizer.zero_grad()
        kjt = _to_kjt(batch, device)
        loss = model(kjt).norm()
        used_embs_pre, unused_embs_pre = get_embedding_weights(model, batch)
        loss.backward()
        optimizer.step()
        used_embs_post, unused_embs_post = get_embedding_weights(model, batch)
        diffs_used = torch.norm(used_embs_post - used_embs_pre).item()
        diffs_unused = torch.norm(unused_embs_post - unused_embs_pre).item()
        print(f"Iter {i}, loss {loss.item():.2e}, ∆ hot: {diffs_used:.2e}, ∆ cold: {diffs_unused:.2e}")
        if i > NUM_ITER:
            break
def main(argv: List[str]) -> None:
    args = parse_args(argv)
    use_dmp = args.mode == "dmp"
    rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        backend = "nccl"
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
        backend = "gloo"
    if not torch.distributed.is_initialized():
        dist.init_process_group(backend=backend)
    world_size = dist.get_world_size()
    MAX_ID = 1000
    MAX_SEQ_LEN = 10
    loader = torch.utils.data.DataLoader(
        DataSet(max_id=MAX_ID, max_seq_len=MAX_SEQ_LEN),
        batch_size=32,
        num_workers=0,
        collate_fn=lambda x: x,
    )
    model = Model(
        max_id=MAX_ID,
        emb_dim=16,
    ).to(device)
    LR = 0.1
    if use_dmp:
        print("Using DMP")
        fused_params: Dict[str, Any] = {}
        fused_params["optimizer"] = EmbOptimType.ADAM
        fused_params["learning_rate"] = LR
        model = DMP(
            module=model,
            device=device,
            sharders=[
                cast(ModuleSharder[nn.Module], EmbeddingBagCollectionSharder(fused_params))
            ],
        )
        dense_optimizer = KeyedOptimizerWrapper(
            dict(in_backward_optimizer_filter(model.named_parameters())),
            lambda params: optim.Adam(params, lr=LR),
        )
        optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
    else:
        print("Using DDP")
        fused_params: Dict[str, Any] = {}
        fused_params["optimizer"] = EmbOptimType.ADAM
        fused_params["learning_rate"] = LR
        sharding_types = [ShardingType.DATA_PARALLEL.value]
        constraints = {
            "item_embedding": torchrec.distributed.planner.ParameterConstraints(
                sharding_types=sharding_types
            )
        }
        sharders = [
            cast(ModuleSharder[nn.Module], EmbeddingBagCollectionSharder(fused_params))
        ]
        pg = dist.GroupMember.WORLD
        model = DMP(
            module=model,
            device=device,
            plan=torchrec.distributed.planner.EmbeddingShardingPlanner(
                topology=torchrec.distributed.planner.Topology(
                    world_size=world_size,
                    compute_device=device.type,
                ),
                constraints=constraints,
            ).collective_plan(model, sharders, pg),
            sharders=sharders,
        )
        dense_optimizer = KeyedOptimizerWrapper(
            dict(in_backward_optimizer_filter(model.named_parameters())),
            lambda params: optim.Adam(params, lr=LR),
        )
        optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
    _train_one_epoch(
        model=model,
        loader=loader,
        device=device,
        optimizer=optimizer,
    )
    dist.destroy_process_group()
if __name__ == "__main__":
    print(sys.argv)
    import torch, torchrec, fbgemm_gpu, torchvision, triton

    modulenames = set(sys.modules) & set(globals())
    allmodules = [sys.modules[name] for name in modulenames]
    print("-" * 56)
    for module in allmodules:
        if hasattr(module, "__version__"):
            print(module.__name__, module.__version__)
    print("-" * 56)
    main(sys.argv[1:])
# torchrun --rdzv_endpoint=localhost:29400 --nnodes 1 train.py --mode ddp
Output with DDP:
['train.py', '--mode', 'ddp']
--------------------------------------------------------
fbgemm_gpu 1.2.0+cu128
torchvision 0.22.0.2+cu128
torch 2.7.0.2+cu128
torchrec 1.2.0+cu128
argparse 1.1
triton 3.1.0
--------------------------------------------------------
Using DDP
/workspace/model-building-framework/.venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:859: UserWarning: `_get_pg_default_device` will be deprecated, it only stays for backward-compatiblity reason. If you need to find a device for object collectives, please use `_get_object_coll_device`. If you need to query the device types supported by group, please use `_device_capability(group)`.
warnings.warn(
Iter 1, loss 1.14e+00, ∆ hot: 5.25e+00, ∆ cold: 0.00e+00
Iter 2, loss 8.24e-01, ∆ hot: 3.61e+00, ∆ cold: 3.23e+00
Iter 3, loss 1.00e+00, ∆ hot: 3.42e+00, ∆ cold: 3.36e+00
Iter 4, loss 1.03e+00, ∆ hot: 2.72e+00, ∆ cold: 3.64e+00
Iter 5, loss 7.13e-01, ∆ hot: 2.63e+00, ∆ cold: 3.54e+00
Iter 6, loss 7.17e-01, ∆ hot: 2.37e+00, ∆ cold: 3.48e+00
Output with DMP:
torchrun --rdzv_endpoint=localhost:29400 --nnodes 1 train.py --mode dmp
['train.py', '--mode', 'dmp']
--------------------------------------------------------
triton 3.1.0
torchvision 0.22.0.2+cu128
torch 2.7.0.2+cu128
fbgemm_gpu 1.2.0+cu128
argparse 1.1
torchrec 1.2.0+cu128
--------------------------------------------------------
Using DMP
/workspace/model-building-framework/.venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:859: UserWarning: `_get_pg_default_device` will be deprecated, it only stays for backward-compatiblity reason. If you need to find a device for object collectives, please use `_get_object_coll_device`. If you need to query the device types supported by group, please use `_device_capability(group)`.
warnings.warn(
/workspace/model-building-framework/.venv/lib/python3.10/site-packages/fbgemm_gpu/split_embedding_codegen_lookup_invokers/lookup_adam.py:39: UserWarning:
[FBGEMM_GPU] NOTE: The training optimizer 'adam' is marked as
EXPERIMENTAL and thus not optimized, in order to reduce code compilation
times and build sizes!
warnings.warn(
Iter 1, loss 1.86e+00, ∆ hot: 5.14e+00, ∆ cold: 0.00e+00
Iter 2, loss 1.39e+00, ∆ hot: 3.94e+00, ∆ cold: 0.00e+00
Iter 3, loss 9.56e-01, ∆ hot: 2.83e+00, ∆ cold: 0.00e+00
Iter 4, loss 6.22e-01, ∆ hot: 3.01e+00, ∆ cold: 0.00e+00
Iter 5, loss 6.56e-01, ∆ hot: 2.58e+00, ∆ cold: 0.00e+00
Iter 6, loss 4.84e-01, ∆ hot: 2.72e+00, ∆ cold: 0.00e+00