
No overlapping observed when enabling Smart Scheduling


Describe the bug I am trying to create a minimal runnable example of the Smart Scheduling proposed in the FasterMoE paper. However, when I profile the example with Nsight Systems, there seems to be no overlap between the all-to-all communication and the expert computation.

Example of the profile result (one of the forward passes): [screenshot]

Looking at the CUDA API stack trace, the smart schedule code path does seem to be the one running: [screenshot]

The code I used can be found below. Could you let me know whether this is caused by my misusing FastMoE or by some other issue? Thanks.

To Reproduce The test is done on 2 nodes, each with 8 V100 GPUs. The code I used for the tests (example.py):

import torch
import os

from fmoe import DistributedGroupedDataParallel as fmoeDDP
from fmoe.transformer import FMoETransformerMLP
from fmoe.gates import SwitchGate

class DummyMoEModel(torch.nn.Module):
    def __init__(self, world_size):
        super().__init__()
        self.non_moe = torch.nn.Sequential(
            torch.nn.Linear(1024, 8192),
            torch.nn.ReLU(),
            torch.nn.Linear(8192, 1024))
        self.moe = FMoETransformerMLP(
            num_expert=1,
            world_size=world_size,
            d_model=1024,
            d_hidden=4096,
            top_k=1,
        )

    def forward(self, inp):
        torch.cuda.nvtx.range_push("Non-MoE")
        out = self.non_moe(inp)
        torch.cuda.nvtx.range_pop()
        torch.cuda.nvtx.range_push("FMoETransformerMLP")
        out = self.moe(out)
        torch.cuda.nvtx.range_pop()
        return torch.sum(out)

if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(local_rank)
    model = DummyMoEModel(world_size).to(f"cuda:{local_rank}")
    model = fmoeDDP(model)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)

    for i in range(20):
        inp = torch.randn(8192, 1024).to(f"cuda:{local_rank}")
        opt.zero_grad()
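        # cudaProfilerStart/Stop bound the iterations captured by Nsight Systems
        # when it is launched with --capture-range=cudaProfilerApi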
        if i == 10:
            torch.cuda.cudart().cudaProfilerStart()
        out = model(inp)
        if i == 15:
            torch.cuda.cudart().cudaProfilerStop()
        out.backward()

Steps to reproduce the behavior:

  1. Set up the Docker environment using the image pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel and install FastMoE.
  2. Run the above code with command FMOE_FASTER_SCHEDULE_ENABLE=1 torchrun --nnodes=2 --nproc-per-node=8 --rdzv-id=0 --rdzv-backend=c10d --rdzv-endpoint=xxx.xxx.xx.xx example.py

Expected behavior Overlap between expert computation and the all-to-all communication.

Logs N/A

Platform

  • Device: NVIDIA V100
  • OS: Ubuntu 20.04.5 LTS
  • CUDA version: 11.7
  • NCCL version: 2.14.3-1
  • PyTorch version: 2.0.1

Additional context N/A

chenyu-jiang avatar Aug 04 '23 09:08 chenyu-jiang

I will check it out. However, it looks like you forgot to set the FMOE_FASTER_GROUP_SIZE variable.

zms1999 avatar Aug 04 '23 09:08 zms1999

Thanks for the fast reply! I tried setting FMOE_FASTER_GROUP_SIZE=4, but I still don't see any overlap: [screenshot]

chenyu-jiang avatar Aug 04 '23 09:08 chenyu-jiang

This issue turns out to be caused by using the default CUDA stream, which synchronizes with all other streams. Simply using a separate stream in smgr for NCCL solves the problem. Credit to @Harry-Chen for finding this. Looking forward to a pull request.
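For reference, a minimal PyTorch-level sketch of the overlap pattern under discussion (illustrative only; FastMoE's actual fix lives in its C++ stream manager): compute and communication each get their own non-default stream, so the GPU is free to run them concurrently.

import torch

# Illustrative only -- not FastMoE code. Two independent pieces of GPU work
# can overlap when each is issued on its own (non-default) stream.
compute_stream = torch.cuda.Stream()   # stands in for the expert computation
comm_stream = torch.cuda.Stream()      # stands in for the NCCL all-to-all

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

# Make both side streams wait for the inputs produced on the current stream.
compute_stream.wait_stream(torch.cuda.current_stream())
comm_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(compute_stream):
    c = a @ a                          # long-running "expert" kernel

with torch.cuda.stream(comm_stream):
    d = b * 2                          # stand-in for the communication kernel

torch.cuda.synchronize()               # join both streams before using c and d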

laekov avatar Aug 17 '23 00:08 laekov

Hi @chenyu-jiang , I finally found some bugs. I've fixed them in this branch; could you try profiling your program on it again?

zms1999 avatar Aug 25 '23 05:08 zms1999

Hi @zms1999, extremely sorry for the (very) delayed response. After the fix, I can now see overlap in the example program. Thanks a lot for the fix! It is tremendously helpful.

chenyu-jiang avatar Sep 10 '23 12:09 chenyu-jiang

Sorry for bothering you again, but I am still running into problems when running the above example code with SwitchGate (i.e., adding gate=SwitchGate when initializing FMoETransformerMLP).

The error message is:

Traceback (most recent call last):
  File "example.py", line 47, in <module>
    out = model(inp)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/fastmoe-1.0.2-py3.8-linux-x86_64.egg/fmoe/distributed.py", line 114, in forward
    return self.module(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "example.py", line 29, in forward
    out = self.moe(out)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/fastmoe-1.0.2-py3.8-linux-x86_64.egg/fmoe/transformer.py", line 65, in forward
    output = super().forward(inp)
  File "/root/.local/lib/python3.8/site-packages/fastmoe-1.0.2-py3.8-linux-x86_64.egg/fmoe/layers.py", line 228, in forward
    gate_top_k_idx, gate_score = self.gate(moe_inp)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/fastmoe-1.0.2-py3.8-linux-x86_64.egg/fmoe/gates/switch_gate.py", line 49, in forward
    valid_idx = top1_idx[top1_idx > -1]
RuntimeError: CUDA error: an illegal memory access was encountered

However, if the code is run with CUDA_LAUNCH_BLOCKING=1, the error goes away. Could there still be some synchronization issue?
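As a side note, one generic way to narrow such asynchronous CUDA errors down (a debugging sketch only, not a fix; run_checked is just an illustrative helper name) is to force a device synchronization after each suspect region, so the failing launch is reported where it was issued rather than at a later, unrelated call:

import torch

def run_checked(tag, fn, *args, **kwargs):
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()   # any pending async error from `fn` surfaces here
    print(f"{tag}: ok")
    return out

# e.g. inside DummyMoEModel.forward:
#   out = run_checked("non_moe", self.non_moe, inp)
#   out = run_checked("moe", self.moe, out)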

chenyu-jiang avatar Sep 10 '23 14:09 chenyu-jiang

I guess you're right, since I've already fixed some synchronization bugs. There could be more; I will check next week.

zms1999 avatar Sep 10 '23 15:09 zms1999

The switch gate problem seems to be caused by using the old, problematic stream manager in the expert counting and balancing kernels. I put the torch stream into smgr and replaced the smgr streams in the other places in PR #173. @zms1999 can you please have a look?
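Roughly, the idea is as follows (a Python-level illustration only; the actual change is in FastMoE's C++ stream manager, and my_extension below is a hypothetical name): the extension reuses the stream PyTorch is already launching on instead of creating private streams the framework knows nothing about.

import torch

torch_stream = torch.cuda.current_stream()   # the stream PyTorch launches kernels on
raw_handle = torch_stream.cuda_stream        # underlying cudaStream_t as an integer

# A hypothetical extension entry point could accept this handle and launch its
# expert-counting / NCCL kernels on it (or on streams event-synchronized with
# it), so their ordering with PyTorch's own kernels is well defined:
# my_extension.expert_count(inp, stream=raw_handle)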

laekov avatar Sep 11 '23 07:09 laekov