torchft

LocalSGD / DiLoCo support

Open d4l3k opened this issue 11 months ago • 5 comments

This is a tracking issue for adding LocalSGD support to torchft. There has been interest in LocalSGD, and it's something we'd like to support.

This should be fairly straightforward: we can use the Manager + quorum in an outer loop and only periodically allreduce a copy of the weights.

Something like:

import torch
from torchft import Manager

manager = Manager(...)
model = ...

while True:
    for step in range(local_steps):
        inputs, labels = next(dataloader_iter)
        optimizer.zero_grad()
        criterion(model(inputs), labels).backward()
        optimizer.step()

    # update quorum and PGs (could overlap with the optimizer steps above)
    manager.step()

    # free gradient memory to make room for averaged weights
    optimizer.zero_grad(set_to_none=True)

    # copy the model weights and start the allreduce mean
    # we need a temporary copy to gracefully handle failures
    params = {}
    for name, param in model.named_parameters():
        copy = param.detach().clone()
        manager.allreduce_grad(copy)
        params[name] = copy

    # this will wait for all transfers to complete successfully
    if manager.should_commit():
        # in-place writes to leaf parameters must not be tracked by autograd
        with torch.no_grad():
            for name, param in model.named_parameters():
                param.copy_(params[name])
                del params[name]
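To see why periodic averaging works, the loop above can be simulated without torch or torchft. This is a toy sketch, not torchft code: `local_sgd` is a hypothetical helper, each worker minimizes a scalar quadratic `0.5 * (w - t_i) ** 2`, the allreduce is a plain mean, and a round listed in `failed_rounds` plays the role of `should_commit()` returning False, so the workers keep their local weights for that round.

```python
def local_sgd(
    targets: list[float],
    w0: float = 0.0,
    lr: float = 0.1,
    local_steps: int = 4,
    rounds: int = 50,
    failed_rounds=(),
) -> list[float]:
    """Toy LocalSGD: worker i minimizes 0.5 * (w - targets[i]) ** 2."""
    weights = [w0] * len(targets)
    for r in range(rounds):
        # inner loop: independent local SGD steps, no communication
        for _step in range(local_steps):
            weights = [w - lr * (w - t) for w, t in zip(weights, targets)]
        # outer step: allreduce-mean computed on a temporary value
        averaged = sum(weights) / len(weights)
        # only overwrite the local weights if the round "commits"
        if r not in failed_rounds:
            weights = [averaged] * len(weights)
    return weights


print(local_sgd([1.0, 2.0, 6.0], failed_rounds={10, 20}))
```

Because the local steps in this toy are affine with the same slope on every worker, averaging contracts the error of the workers' mean toward `mean(targets)` by a factor of `(1 - lr) ** local_steps` per round; a skipped commit leaves the workers temporarily out of sync but does not change their mean.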

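DiLoCo should be a small modification of this algorithm: instead of just averaging the weights, each replica allreduces the difference between the weights at the start of the round and its local weights, and a separate outer optimizer applies that averaged difference.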
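A toy, pure-Python sketch of the DiLoCo outer step on the same scalar quadratic losses is below. `diloco` is a hypothetical helper, not torchft API; inner SGD is used for brevity where the DiLoCo paper uses AdamW, and the outer optimizer is SGD with Nesterov momentum written with the same update rule as `torch.optim.SGD(nesterov=True)`. The default hyperparameters are illustrative only.

```python
def diloco(
    targets: list[float],
    theta0: float = 0.0,
    inner_lr: float = 0.1,
    local_steps: int = 4,
    rounds: int = 200,
    outer_lr: float = 0.7,
    momentum: float = 0.9,
) -> float:
    """Toy DiLoCo: worker i minimizes 0.5 * (w - targets[i]) ** 2."""
    theta = theta0  # global weights, identical on every replica
    buf = 0.0       # outer momentum buffer
    for _round in range(rounds):
        local = []
        for t in targets:
            w = theta  # every worker starts the round from the global weights
            for _step in range(local_steps):
                w -= inner_lr * (w - t)
            local.append(w)
        # outer "pseudo-gradient": mean of (global - local); this is what gets allreduced
        delta = sum(theta - w for w in local) / len(local)
        # outer optimizer: SGD with Nesterov momentum (PyTorch-style update)
        buf = momentum * buf + delta
        theta -= outer_lr * (delta + momentum * buf)
    return theta


print(diloco([1.0, 2.0, 6.0]))
```

With `outer_lr=1.0` and `momentum=0.0` the outer step reduces to `theta - delta`, which is exactly the mean of the local weights, i.e. the plain averaging of the LocalSGD loop.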

For efficiency, we should probably use the DDP reducer on the parameters directly and copy the underlying Storage to make the backup copy.
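The backup idea can be illustrated without torch: if every parameter is a view into one contiguous buffer, a single bulk copy backs up (and restores) all of them at once, instead of one `clone()` per parameter. This is a pure-Python toy under that assumption; `FlatParams` is a hypothetical helper, `array`/`memoryview` stand in for a torch Storage and the tensor views into it, and the DDP reducer is not modeled.

```python
from array import array


class FlatParams:
    """Hypothetical toy: all parameters are views into one contiguous buffer,
    standing in for tensors that share a single torch Storage."""

    def __init__(self, sizes: dict[str, int]) -> None:
        self.storage = array("d", [0.0] * sum(sizes.values()))
        flat = memoryview(self.storage)
        self.params: dict[str, memoryview] = {}
        offset = 0
        for name, n in sizes.items():
            self.params[name] = flat[offset:offset + n]  # a view, not a copy
            offset += n

    def backup(self) -> array:
        # one bulk copy of the whole storage instead of one clone per parameter
        return array("d", self.storage)

    def restore(self, saved: array) -> None:
        # copy back in place so the existing parameter views see restored values
        memoryview(self.storage)[:] = saved


p = FlatParams({"weight": 3, "bias": 1})
p.params["weight"][0] = 1.5
saved = p.backup()
p.params["weight"][0] = 99.0  # e.g. a round whose update should be discarded
p.restore(saved)
```

Restoring in place (rather than rebinding to a new buffer) matters because the parameter views keep pointing at the original storage.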

References:

  • LocalSGD: https://arxiv.org/abs/2311.08105
  • DiLoCo: https://arxiv.org/abs/2311.08105

d4l3k, Dec 13 '24 00:12