lightning-thunder
lightning-thunder copied to clipboard
Support accessing the module reference for the process group
What does this PR do?
Supports the following usage:
import os

import torch
import torch.distributed as torch_dist

import thunder
from thunder import ThunderModule

# Rank/world-size info is provided via the environment by the launcher
# (e.g. torchrun); default to a single-process setup when absent.
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))

if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    # Grab the default process group so it can be referenced later.
    pg = torch_dist.distributed_c10d._get_default_group()

device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)


def fwd_loss(m, x):
    """Run a forward pass through `m` and reduce the output to a scalar loss."""
    return m(x).sum()


fwd_loss = thunder.jit(fwd_loss)
model = thunder.distributed.ddp(model)

# notice how we cannot do `model.no_sync()` because it's not a ThunderModule
with ThunderModule.no_sync(model):
    out = fwd_loss(model, x)