
fp8_model_init doesn't work with DDP

Open · MaciejBalaNV opened this issue · 3 comments

When I try to use the fp8_model_init feature, it doesn't seem to be compatible with DDP. It throws the following error: RuntimeError: Modules with uninitialized parameters can't be used with "DistributedDataParallel". Run a dummy forward pass to correctly initialize the modules

Running a dummy forward pass doesn't help, and neither does calling reset_parameters. Using a separate stream for DDP also does not fix the issue.

A simple reproducible case:

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import transformer_engine as te

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12364"

    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = te.pytorch.Linear(1024, 1024)
        self.fc2 = te.pytorch.Linear(1024, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def ddp_main(rank, world_size):
    setup(rank, world_size)

    torch.cuda.set_device(rank)


    with te.pytorch.fp8.fp8_model_init(enabled=True):
        model = Net().to(rank)
    for i, m in enumerate(model.modules()):
        if hasattr(m, "reset_parameters"):
            print(f"resetting {i}")
            m.reset_parameters()
    # Dummy forward pass, as suggested by the error message
    input_data = torch.randn((16, 1024)).cuda()
    with torch.no_grad():
        model(input_data)
    torch.cuda.synchronize()
    model = DDP(model)
    torch.cuda.synchronize()

    dist.barrier()
    cleanup()


if __name__ == "__main__":
    WORLD_SIZE = 8
    mp.spawn(ddp_main, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)

@denera

MaciejBalaNV · Aug 26 '24

Do you need both DDP and FP8 params for your use-case? We haven't considered this combination so far since optimizing FP8 params tends to have poor convergence. There are a few ways to proceed:

  • Initialize your weights in higher precision and rely on TE's automatic FP8 casting. If your training loop involves multiple grad accumulation steps, you can pass in is_first_microbatch=True/False to cache the FP8 weights (see the first sketch below).
  • Maintain a separate set of FP32 master params for the grad all-reduce and optimizer. This is generally how DDP is implemented in Megatron-LM and NeMo (see the second sketch below).
  • Use FSDP mixed precision support to store the sharded weights in FP32 and the gathered weights in FP8. This isn't supported with TE yet, but FSDP has some callback hooks where we can add FP8-related logic (see fsdp_pre_all_gather and fsdp_post_all_gather).
  • DDP doesn't actually need the param values, just the grads. We could debug this case and figure out a way to bypass this error.
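
For what it's worth, here's a minimal sketch of the first option, assuming BF16 parameters and TE's is_first_microbatch argument to cache the FP8-cast weights across gradient-accumulation steps; the layer sizes, recipe settings, and loop structure are only illustrative:

import torch
import transformer_engine as te
from transformer_engine.common.recipe import DelayedScaling

# Weights stay in BF16 (no fp8_model_init); TE casts them to FP8 on the fly.
model = te.pytorch.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
fp8_recipe = DelayedScaling()

grad_accum_steps = 4
for step in range(grad_accum_steps):
    x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
    with te.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        # FP8 copies of the weights are built on the first microbatch
        # and reused for the remaining ones.
        out = model(x, is_first_microbatch=(step == 0))
    out.sum().backward()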

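And a minimal sketch of the second option (a separate set of FP32 master params), shown here with a plain BF16 nn.Linear for brevity; Megatron-LM and NeMo implement a more elaborate version of this pattern:

import torch
import torch.nn as nn

model = nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

# FP32 master copies that the optimizer actually updates.
master_params = [p.detach().clone().float() for p in model.parameters()]
optim = torch.optim.Adam(master_params, lr=1e-4)

for _ in range(3):
    x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
    model(x).sum().backward()  # with DDP, the grad all-reduce happens during backward

    # Copy the (all-reduced) grads into the FP32 masters, step, then copy back.
    for p, mp in zip(model.parameters(), master_params):
        mp.grad = p.grad.detach().float()
    optim.step()
    optim.zero_grad()
    with torch.no_grad():
        for p, mp in zip(model.parameters(), master_params):
            p.copy_(mp)
    model.zero_grad(set_to_none=True)
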
timmoon10 · Aug 26 '24

@MaciejBalaNV Transformer Engine modules that are initialized under te.pytorch.fp8_model_init() still need to be executed under te.pytorch.fp8_autocast() with an FP8 recipe for operations that we have to perform in higher precision. Missing this context might be the reason why the model parameters were not correctly initialized in your case, and if so, we should definitely catch that and show a useful error message.

For reference, here's a modified version of your DDP example that works correctly on my end:

import os
import socket
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import transformer_engine as te


class BasicMLP(nn.Module):
    """Basic MLP block"""

    def __init__(self, hidden_size, ffn_hidden_size, **kwargs):
        super().__init__()
        tp_group = kwargs.pop("tp_group", None)
        parallel_mode = kwargs.pop("parallel_mode", None)
        fc1_parallel_mode = fc2_parallel_mode = parallel_mode
        if tp_group is not None:
            fc1_parallel_mode = "row"
            fc2_parallel_mode = "column"
        self.fc1 = te.pytorch.Linear(hidden_size, ffn_hidden_size, parallel_mode=fc1_parallel_mode,
                                     **kwargs)
        self.fc2 = te.pytorch.Linear(ffn_hidden_size, hidden_size, parallel_mode=fc2_parallel_mode,
                                     **kwargs)

    def forward(self, x):
        """Forward pass: FC2(act_fn(FC1(x)))"""
        return self.fc2(self.fc1(x))


def _ddp_main(rank, world_size, num_replicas):
    SEQ_LENGTH = 512
    BATCH_SIZE = 2
    HIDDEN_SIZE = 256
    FFN_HIDDEN_SIZE = 4 * HIDDEN_SIZE

    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = socket.gethostname()
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(rank)

    if num_replicas == 1:
        dp_group = None
        tp_group = dist.new_group()
    elif num_replicas == world_size:
        dp_group = dist.new_group()
        tp_group = None
    else:
        assert num_replicas > 0 and num_replicas < world_size and world_size % num_replicas == 0
        replica_size = world_size // num_replicas
        mesh_2d = dist.init_device_mesh("cuda", (num_replicas, replica_size))
        dp_group, tp_group = mesh_2d.get_all_groups()

    with te.pytorch.fp8.fp8_model_init(enabled=True):
        model = BasicMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, tp_group=tp_group)

    if dp_group is not None:
        model = DDP(model, process_group=dp_group)
    optim = torch.optim.Adam(model.parameters())

    for _ in range(10):
        optim.zero_grad()
        input_data = torch.randn((SEQ_LENGTH, BATCH_SIZE, HIDDEN_SIZE), device="cuda")
        with te.pytorch.fp8_autocast(enabled=True):
            output = model(input_data)
        loss = output.sum()
        loss.backward()
        optim.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    NUM_REPLICAS = 2

    if "TORCHELASTIC_RUN_ID" in os.environ:
        # Using the `torchrun` utility
        WORLD_RANK = int(os.getenv("RANK", "0"))
        WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
        _ddp_main(WORLD_RANK, WORLD_SIZE, NUM_REPLICAS)

    else:
        WORLD_SIZE = 8
        mp.spawn(_ddp_main, args=(WORLD_SIZE, NUM_REPLICAS), nprocs=WORLD_SIZE, join=True)

denera · Aug 26 '24

@timmoon10 The use case is to use TP/SP without FSDP (it's problematic for many reasons, and fp8_model_init not working there is one of them) for large model training, while still utilizing data parallelism through DDP. Are you suggesting that there are better options to achieve data parallelism than DDP?

@denera I don't think the error is caused by the lack of te.pytorch.fp8_autocast() - if we delete the forward pass before the DDP wrapping, the error still happens at the wrapping stage. I only included that forward pass to address the error message, which suggested running a dummy forward. Thanks for this example - I've played around with it and found that it only works on a nightly build with PyTorch 2.5 and TE 1.10. It still breaks with the same error message on the 24.07 PyTorch container (which ships PyTorch 2.4), even if I reinstall TE 1.10 or 1.11. It seems like something changed in PyTorch very recently, in which case I'm not sure whether any fixes are necessary on the TE side.

MaciejBalaNV · Aug 27 '24