lightning-thunder

Support accessing the module reference for the process group

Status: Open · opened by carmocca 1 year ago · 0 comments

What does this PR do?

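It supports the following usage, where `ThunderModule.no_sync` is applied to a model wrapped with `thunder.distributed.ddp`: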

import os
import thunder
import torch
import torch.distributed as torch_dist
from thunder import ThunderModule

# Read the launch configuration from the environment. The defaults
# (world size 1, rank 0) correspond to a single-process run.
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))  # NOTE(review): not used in this snippet
if world_size > 1:
    # Only initialize the NCCL process group when launched with multiple processes.
    torch_dist.init_process_group(backend="nccl")
    # NOTE(review): `pg` is not used below and is only bound when world_size > 1.
    pg = torch_dist.distributed_c10d._get_default_group()
# Select this process's local GPU as the current CUDA device before any
# tensors are allocated on it.
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

# A bias-free 5 -> 10 linear layer and a (2, 5) random input, both on `device`.
model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

def fwd_loss(m, x):
    """Apply the callable module ``m`` to ``x`` and return the sum of its output."""
    output = m(x)
    return output.sum()

# Compile the loss function with thunder. The model is passed in as an
# argument rather than being jitted on its own.
fwd_loss = thunder.jit(fwd_loss)
# Apply thunder's DDP transform to the model.
model = thunder.distributed.ddp(model)

# The `model` returned above is not a ThunderModule, so `model.no_sync()` is not
# available. Instead, call the unbound `ThunderModule.no_sync` and pass `model` explicitly.
with ThunderModule.no_sync(model):
    out = fwd_loss(model, x)

Posted by carmocca on Mar 28, 2024 at 03:03.