[distributed][Tensor Parallelism] Implement early transform for Column-wise Parallel

Open crcrpar opened this issue 1 year ago • 1 comments

This implements a trace transform that converts one or more linear layers into column-wise tensor-parallel ones by (1) sharding their weight and bias and (2) inserting the needed communication after the modified linear layers.


Example:

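import torch
import torch.nn as nn
import torch.nn.functional as F
import thunder
import thunder.distributed


class Model(nn.Module):
    def __init__(self, n_in: int, n_hidden: int, n_out: int) -> None:
        super().__init__()  # required before registering submodules
        self.l1 = nn.Linear(n_in, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.l2(F.gelu(self.l1(x)))


# `rank`, `n_in`, `n_hidden`, and `n_out` are assumed to be defined by the surrounding script.
device = torch.device(f"cuda:{rank}")

model = Model(n_in, n_hidden, n_out).to(device)
jitted_model = thunder.jit(model)
# Convert the linear layers named "l1" and "l2" to column-wise parallel.
colwise_jitted_model = thunder.distributed.convert_module_to_columnwise_parallel(jitted_model, ("l1", "l2"))

x = torch.randn(..., device=device)
y = colwise_jitted_model(x)
# The output keeps its full feature dimension after the inserted communication.
assert y.size(1) == n_out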
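
To make the two steps in the description concrete, here is a minimal single-process sketch of column-wise sharding in plain PyTorch. It is not the transform from this PR: `world_size` and the tensor sizes are hypothetical, the ranks are simulated with a Python list, and `torch.cat` stands in for the communication step on the assumption that it gathers output columns (the description does not name the collective).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2      # hypothetical number of tensor-parallel ranks
n_in, n_out = 4, 6  # n_out is chosen to be divisible by world_size

# float64 keeps the numerical comparison below robust.
x = torch.randn(3, n_in, dtype=torch.float64)
weight = torch.randn(n_out, n_in, dtype=torch.float64)  # same layout as nn.Linear.weight
bias = torch.randn(n_out, dtype=torch.float64)

# (1) Shard weight and bias along the output-feature dimension (dim 0).
weight_shards = weight.chunk(world_size, dim=0)
bias_shards = bias.chunk(world_size, dim=0)

# Each simulated rank computes only its own slice of the output columns.
partial_outputs = [F.linear(x, w, b) for w, b in zip(weight_shards, bias_shards)]

# (2) Stand-in for the communication: concatenating along the last dimension
# plays the role of gathering the column slices from all ranks.
y_parallel = torch.cat(partial_outputs, dim=-1)

# The gathered result should agree with the unsharded linear layer.
y_reference = F.linear(x, weight, bias)
matches = torch.allclose(y_parallel, y_reference)
```

Each shard holds `n_out // world_size` rows of the weight, so per-rank parameter memory shrinks with the number of ranks while the gathered output keeps its full width, which is what the `y.size(1) == n_out` assertion in the example checks.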

cc @borda @apaz-cli @carmocca @awaelchli @crcrpar

crcrpar, May 13 '24 13:05

The failures as of https://github.com/Lightning-AI/lightning-thunder/commit/f724a886639bc616c93879435efcbbeb4c8ac2fe look related to https://github.com/Lightning-AI/lightning-thunder/issues/432.

crcrpar, May 17 '24 16:05
