lightning-thunder
[distributed][Tensor Parallelism] Implement early transform for Column-wise Parallel
This PR implements a trace transform that converts one or more linear layers into column-wise tensor-parallel ones by (1) sharding their weights and biases across ranks and (2) inserting the necessary communication after the modified linear layers.
Example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import thunder


class Model(nn.Module):
    def __init__(self, n_in: int, n_hidden: int, n_out: int) -> None:
        super().__init__()
        self.l1 = nn.Linear(n_in, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.l2(F.gelu(self.l1(x)))


device = torch.device(f"cuda:{rank}")
model = Model(n_in, n_hidden, n_out).to(device)
jitted_model = thunder.jit(model)
colwise_jitted_model = thunder.distributed.convert_module_to_columnwise_parallel(
    jitted_model, ("l1", "l2")
)
x = torch.randn(..., device=device)
y = colwise_jitted_model(x)
assert y.size(1) == n_out
```
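For intuition, the two steps above can be sketched on a single process without any actual communication: shard a linear layer's weight and bias along the output (column) dimension, let each "rank" compute its output slice, then concatenate the slices in place of the all-gather. The names and shapes below are illustrative, not Thunder's internal API.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_in, n_hidden, world_size = 8, 6, 2
linear = torch.nn.Linear(n_in, n_hidden)

# (1) Shard weight and bias along the output (column) dimension.
#     weight has shape (n_hidden, n_in), so dim=0 is the output dim.
w_shards = linear.weight.chunk(world_size, dim=0)
b_shards = linear.bias.chunk(world_size, dim=0)

x = torch.randn(4, n_in)

# Each simulated rank computes only its slice of the output features.
partials = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]

# (2) The inserted communication corresponds to an all-gather along the
#     feature dimension; here we just concatenate the partial outputs.
y = torch.cat(partials, dim=-1)

assert torch.allclose(y, linear(x), atol=1e-6)
```

This also shows why communication is needed after the layer: each shard holds only a slice of the output features, so downstream ops that need the full feature dimension (like the `gelu` feeding `l2` in the example) must see the gathered result.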
cc @borda @apaz-cli @carmocca @awaelchli @crcrpar
The failures as of https://github.com/Lightning-AI/lightning-thunder/commit/f724a886639bc616c93879435efcbbeb4c8ac2fe look related to https://github.com/Lightning-AI/lightning-thunder/issues/432.