
[RFC] MOE design in Torchtune


Background

This RFC proposes adding MOE support to torchtune. We want the design to be general so that components can easily be swapped when implementing different MOE models. An MOE layer directly replaces the dense FFN layer in the transformer decoder layer and has two main components: the router and the experts.

Expert

An expert is essentially an FFN layer similar to the original dense FFN layer in the transformer decoder layer. There are two kinds of experts: routed experts and shared experts. Each routed expert specializes in learning certain patterns/aspects, and only a subset of the routed experts is activated for any given token. Shared experts, on the other hand, are always activated and aim to capture and consolidate common knowledge across varying contexts.

Here's the proposed Experts design in torchtune:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Experts(nn.Module):
    def __init__(self, dim_in, dim_out, num_experts=1, swiglu=True, nonlinearity=None):
        super().__init__()
        self.dim_in = dim_in
        self.num_experts = num_experts
        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
        self.down_proj = nn.Parameter(torch.empty(num_experts, dim_out, dim_in))
        if swiglu:
            self.up_proj = nn.Parameter(torch.empty(num_experts, dim_in, dim_out))
            self.act_fn = F.silu
        else:
            self.up_proj = None
            self.act_fn = nonlinearity

    def forward(self, x, num_local_tokens_per_expert=None):
        '''
        inputs:
            x: input tokens
                shape [bs*slen*experts_per_token, hidden_dim] for TC forward
                shape [num_experts*tokens_per_expert, hidden_dim] for EC forward
            num_local_tokens_per_expert: number of tokens for each expert, only used for TC forward
        outputs:
            out: output tokens
                shape [bs*slen*experts_per_token, hidden_dim] for TC forward
                shape [num_experts*tokens_per_expert, hidden_dim] for EC forward
        '''
        # TC forward
        if num_local_tokens_per_expert is not None:
            # TODO: use cutlass groupGEMM instead of torch.matmul() to optimize performance
            # x shape [bs*slen*experts_per_token, hidden_dim]
            # x_expert_splits shape [num_experts, tokens_per_expert(varying), hidden_dim]
            # torch.split expects python ints for varying split sizes, so convert the per-expert counts
            x_expert_splits = torch.split(
                x, split_size_or_sections=num_local_tokens_per_expert.long().tolist(), dim=0
            )
            out_expert_splits = []
            for expert_index, x_expert_split in enumerate(x_expert_splits):
                gate_proj = self.gate_proj[expert_index]
                down_proj = self.down_proj[expert_index]
                up_proj = None
                if self.up_proj is not None:
                    up_proj = self.up_proj[expert_index]

                h = self.act_fn(torch.matmul(x_expert_split, gate_proj))
                if up_proj is not None:
                    h = h * torch.matmul(x_expert_split, up_proj)
                # [tokens_per_expert, hidden_dim]
                h = torch.matmul(h, down_proj)

                out_expert_splits.append(h)
            # shape [num_experts * tokens_per_expert(varying), hidden_dim] = [bs*slen*experts_per_token, hidden_dim]
            out = torch.cat(out_expert_splits, dim=0)
        # EC forward
        else:
            # x shape [num_experts, tokens_per_expert, hidden_dim]
            x = x.view(self.num_experts, -1, self.dim_in)
            h = self.act_fn(torch.bmm(x, self.gate_proj))
            if self.up_proj is not None:
                h = h * torch.bmm(x, self.up_proj)
            out = torch.bmm(h, self.down_proj).view(-1, self.dim_in)
        return out

# Expert builder for routed experts
def moe_experts(hidden_dim, model_dim, num_experts, swiglu=True, nonlinearity=None):
    return Experts(dim_in=hidden_dim, dim_out=model_dim, num_experts=num_experts, swiglu=swiglu, nonlinearity=nonlinearity)

# Single expert / shared expert
def moe_expert(hidden_dim, model_dim, swiglu=True, nonlinearity=None):
    return Experts(dim_in=hidden_dim, dim_out=model_dim, num_experts=1, swiglu=swiglu, nonlinearity=nonlinearity)
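
For illustration, here is a minimal usage sketch of these builders, assuming the Experts class above; all sizes are arbitrary, and the parameters are initialized manually only because the class allocates them with torch.empty:

import torch

hidden_dim, model_dim, num_experts, tokens_per_expert = 16, 32, 4, 8

routed_experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts)
shared_expert = moe_expert(hidden_dim, model_dim)
for module in (routed_experts, shared_expert):
    for p in module.parameters():
        torch.nn.init.normal_(p, std=0.02)

# EC-style forward: tokens are already grouped per expert with a fixed tokens_per_expert
x = torch.randn(num_experts * tokens_per_expert, hidden_dim)
out = routed_experts(x)                                  # [num_experts*tokens_per_expert, hidden_dim]
shared_out = shared_expert(torch.randn(6, hidden_dim))   # the shared expert sees every token
print(out.shape, shared_out.shape)                       # torch.Size([32, 16]) torch.Size([6, 16])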

Router

The router is a gating network that computes router scores and learns token-to-expert affinity. There are two types of routing: token choice routing and expert choice routing.

Mixtral uses token choice topK routing, which means each token selects its topK experts. The router is implemented as a learnable gate function whose outputs go through a softmax and topK; the resulting router scores determine how tokens select experts.

Here's the proposed Token Choice Routing design in torchtune:

class TokenChoiceTopKRouter(nn.Module):
    def __init__(self, hidden_dim, num_experts, experts_per_token):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.gate = nn.Linear(hidden_dim, num_experts)

    def forward(self, x, use_sigmoid=False):
        '''
        input:
            x: input tokens
                shape [bs*slen, hidden_dim]
        outputs:
            routed_input: tokens gather by selected experts
                shape [bs*slen*experts_per_token, hidden_dim]
            token_indices: token indices sorted by selected experts indices
            num_local_tokens_per_expert: number of tokens assigned to each expert
                shape [num_experts,]
        '''
        # scores shape [bs*slen, num_experts]
        scores = self.gate(x)
        if use_sigmoid:
            # compute router scores in float32 for numerical stability
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)

        # TODO: implement load balancing auxiliary loss for token choice routing
        # https://github.com/NVIDIA/Megatron-LM/blob/f1f039224584f0bc6ba89c21ef4f491d7136e3ce/megatron/core/transformer/moe/router.py#L162

        # router scores/indices shape [bs*slen, experts_per_token]
        top_scores, selected_experts_indices = torch.topk(scores, k=self.experts_per_token, dim=1)
        top_scores /= top_scores.sum(dim=-1, keepdim=True).to(x.dtype)

        # shape [num_experts,]: how many tokens for each expert
        # histc requires a floating point input
        num_local_tokens_per_expert = torch.histc(
            selected_experts_indices.view(-1).float(), bins=self.num_experts, min=0, max=self.num_experts
        )
        # shape [bs*slen*experts_per_token,]
        token_indices_experts_sorted = torch.argsort(selected_experts_indices.view(-1), stable=True)
        # top_scores shape [bs*slen*experts_per_token,]
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]
        # map flattened (token, expert) positions back to token indices in [0, bs*slen)
        token_indices_experts_sorted = token_indices_experts_sorted // self.experts_per_token

        # token_indices shape [bs*slen*experts_per_token, hidden_dim]
        token_indices = token_indices_experts_sorted.reshape(-1, 1).expand(-1, self.hidden_dim)
        # routed_input shape [bs*slen*experts_per_token, hidden_dim]
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)

        return routed_input, token_indices, num_local_tokens_per_expert
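
For illustration, a minimal end-to-end sketch of the token choice path, assuming the Experts and TokenChoiceTopKRouter classes above (arbitrary sizes, manual parameter init for the demo):

import torch

bs, slen, hidden_dim, model_dim = 2, 4, 16, 32
num_experts, experts_per_token = 4, 2

router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token)
experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts)
for p in experts.parameters():
    torch.nn.init.normal_(p, std=0.02)

x = torch.randn(bs * slen, hidden_dim)
routed_input, token_indices, num_local_tokens_per_expert = router(x)
# tokens are grouped by their selected expert: [bs*slen*experts_per_token, hidden_dim]
routed_output = experts(routed_input, num_local_tokens_per_expert=num_local_tokens_per_expert)
print(routed_output.shape)  # torch.Size([16, 16])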

However, token choice routing has several pitfalls according to the expert choice paper.

  1. Poor load balancing. Experts can become under- or over-specialized, and load imbalance hurts step latency and inference time (see the auxiliary-loss sketch after this list).
  2. Expert under-specialization. Ideally the gating network learns a token-to-expert affinity such that similar or relevant tokens are routed to the same expert. A sub-optimal strategy can instead produce redundant experts and/or experts that are not sufficiently specialized.
  3. The same compute for every token. Token choice allocates a fixed number of experts to each token regardless of the importance of different tokens. Ideally an MOE model should allocate compute flexibly based on the complexity of the input.
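
To address the load-balancing concern, the TODO in the token choice router above points at an auxiliary loss. Below is a minimal sketch following the standard Switch-Transformer-style formulation, using a hypothetical helper name load_balancing_loss; it is illustrative only, not the torchtune implementation:

import torch

def load_balancing_loss(scores, selected_experts_indices, num_experts):
    # scores: [bs*slen, num_experts] post-softmax router probabilities
    # selected_experts_indices: [bs*slen, experts_per_token] top-k expert ids per token
    # fraction of (token, expert) assignments dispatched to each expert
    counts = torch.bincount(selected_experts_indices.view(-1), minlength=num_experts)
    dispatch_fraction = counts.to(scores.dtype) / selected_experts_indices.numel()
    # mean router probability mass assigned to each expert
    mean_prob = scores.mean(dim=0)
    # minimized when both distributions are uniform over the experts
    return num_experts * torch.sum(dispatch_fraction * mean_prob)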

Compared to token choice, expert choice topK routing lets each expert select its topK tokens. The ExpertChoiceTopKRouter class routes input tokens to different experts based on the router scores.

Here's the proposed Expert Choice Routing design in torchtune:

class ExpertChoiceTopKRouter(nn.Module):
    def __init__(self, hidden_dim, num_experts, tokens_per_expert):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.tokens_per_expert = tokens_per_expert
        self.gate = nn.Linear(hidden_dim, num_experts)

    def forward(self, x, use_sigmoid=False):
        '''
        input:
            x: shape [bs*slen, hidden_dim]
        outputs:
            routed_input: selected tokens
                shape [num_experts*tokens_per_expert, hidden_dim]
            token_indices: selected token indices
            num_local_tokens_per_expert: None
        '''
        # scores shape [num_experts, bs*slen]
        scores = self.gate(x).transpose(0,1)
        if use_sigmoid:
            # compute router scores in float32 for numerical stability
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=0).to(x.dtype)
        # router scores/indices shape [num_experts, tokens_per_expert]
        top_scores, selected_token_indices = torch.topk(scores, k=self.tokens_per_expert, dim=1)

        # apply the token preprocess function and then run experts forward
        token_indices = selected_token_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
        # routed input shape [num_experts*tokens_per_expert, hidden_dim]
        routed_input = torch.gather(x, dim=0, index=token_indices)
        routed_input = routed_input * top_scores.reshape(-1, 1)
        return routed_input, token_indices, None
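
And a matching sketch of the expert choice path, assuming the classes above; note that tokens_per_expert must not exceed bs*slen:

import torch

bs, slen, hidden_dim, model_dim = 2, 8, 16, 32
num_experts, tokens_per_expert = 4, 4

router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert)
experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts)
for p in experts.parameters():
    torch.nn.init.normal_(p, std=0.02)

x = torch.randn(bs * slen, hidden_dim)
routed_input, token_indices, _ = router(x)
# EC forward needs no per-expert counts since tokens_per_expert is fixed
routed_output = experts(routed_input)
print(routed_output.shape)  # torch.Size([16, 16]) == [num_experts*tokens_per_expert, hidden_dim]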

MOE Layer

An MOE layer consists of experts and routers.

Here's the proposed MoeLayer design in torchtune:

class MoeLayer(nn.Module):
    def __init__(
        self,
        hidden_dim,
        model_dim,
        num_experts,
        experts_per_token=None,
        tokens_per_expert=None,
        use_shared_expert=True,
        router="token_choice",
    ):
        super().__init__()
        self.experts = moe_experts(hidden_dim, model_dim, num_experts=num_experts)
        self.shared_expert = moe_expert(hidden_dim, model_dim) if use_shared_expert else None
        if router == "token_choice":
            self.router = TokenChoiceTopKRouter(hidden_dim, num_experts, experts_per_token)
        elif router == "expert_choice":
            self.router = ExpertChoiceTopKRouter(hidden_dim, num_experts, tokens_per_expert)
        else:
            raise NotImplementedError("This router is not supported yet!")

    def forward(self, x, inference=False):
        routed_input, token_indices, num_local_tokens_per_expert = self.router(x)
        # routed output shape [num_experts*tokens_per_expert, hidden_dim] for EC, [bs*slen*experts_per_token, hidden_dim] for TC
        routed_output = self.experts(routed_input, num_local_tokens_per_expert=num_local_tokens_per_expert)

        # shared expert output (always active), or a zero base if no shared expert is used
        if self.shared_expert is not None:
            out = self.shared_expert(x)
        else:
            out = torch.zeros_like(x)

        # scatter-add each routed expert output back to its token's position
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        return out
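
A minimal usage sketch, assuming the MoeLayer constructor sketched above (its exact signature is still open in this RFC):

import torch

bs, slen, hidden_dim, model_dim = 2, 4, 16, 32
moe_layer = MoeLayer(hidden_dim, model_dim, num_experts=4, experts_per_token=2, router="token_choice")
for p in moe_layer.parameters():
    if p.dim() > 1:
        torch.nn.init.normal_(p, std=0.02)

x = torch.randn(bs * slen, hidden_dim)
out = moe_layer(x)
print(out.shape)  # torch.Size([8, 16]) -- same shape as the dense FFN output it replaces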

Model builder

Besides the components above (experts, routers, and MOE layers), we need a model builder to pull all the pieces together into the transformer decoder layer and then the full transformer decoder:

Here's the proposed MOE model builder design in torchtune:

def moe(...) -> TransformerDecoder:
    # Build the decoder associated with the MOE model. This includes
    # - Token embeddings
    # - num_layers TransformerSelfAttentionLayer blocks
    # - RMS Norm layer applied to the output of the transformer
    # - Final projection into the token space
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    self_attn = MultiHeadAttention()
    moe_layer = MoeLayer(..., router="token_choice")  # or router="expert_choice"
    layer = TransformerSelfAttentionLayer(
        attn=self_attn,
        mlp=moe_layer,
        sa_norm=RMSNorm(dim=embed_dim),
        mlp_norm=RMSNorm(dim=embed_dim),
    )
    output_proj = nn.Linear(embed_dim, vocab_size)
    return TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layers=layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        norm=RMSNorm(dim=embed_dim),
        output=output_proj,
    )
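
For reference, a hypothetical builder invocation (the argument names below are illustrative only; the real signature is left open as moe(...) above):

model = moe(
    vocab_size=32_000,
    embed_dim=4096,
    num_layers=32,
    num_heads=32,
    head_dim=128,
    max_seq_len=4096,
    num_experts=8,
    experts_per_token=2,
)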

File changes for new modules/functions

torchtune/
    modules/
        moe/
            moe_layers.py
                TokenChoiceTopKRouter()
                ExpertChoiceTopKRouter()
                MoeLayer()
            experts.py
                Experts()
    models/
        moe/
            _component_builders.py
                moe()
                moe_expert()
                moe_experts()

acisseJZhong avatar Oct 24 '24 20:10 acisseJZhong


Thanks for the awesome RFC, really appreciate the detail in the code implementations. One concern I have is that token choice vs expert choice is conceptually spread across both the router and the experts classes, making them pretty tightly coupled, i.e., if I use TokenChoiceRouter I have to make sure I set use_token_choice=False in the expert forward, otherwise I will get totally incorrect results. Also, if I ever want to use the Experts class with a new routing mechanism, I would need to add a lot of if-else chunks and more parameters.

I would try to either make the experts entirely routing agnostic (not sure if this is possible, based on your code it seems to affect the forward quite significantly), make separate expert classes for token choice / expert choice, or just combine the expert forward logic and routing logic into one class.

but at inference time wouldn't you want to have separate parameters to reduce compute and potentially allow tricks like offloading or compressing unused experts? Maybe we could have a method to split/merge the experts

This is a great point, wouldn't the inference behavior be entirely different anyway? How will we handle different inference logic and potential optimizations?

RdoubleA avatar Oct 27 '24 02:10 RdoubleA

I would try to either make the experts entirely routing agnostic (not sure if this is possible, based on your code it seems to affect the forward quite significantly), make separate expert classes for token choice / expert choice, or just combine the expert forward logic and routing logic into one class.

Thanks for the suggestion, it makes sense! Unfortunately, I think expert choice and token choice need different forward implementations. This is because for expert choice, each expert gets a fixed number of tokens (tokens_per_expert), whereas for token choice, tokens_per_expert varies per expert. This is also why we pass num_local_tokens_per_expert into the token choice forward function.

I think making them into separate expert classes is reasonable, so we would have TokenChoiceExperts and ExpertChoiceExperts. I am hesitant about combining the expert forward logic and routing logic, as this makes things even more complicated and harder to understand.

but at inference time wouldn't you want to have separate parameters to reduce compute and potentially allow tricks like offloading or compressing unused experts? Maybe we could have a method to split/merge the experts

This is a great point, wouldn't the inference behavior be entirely different anyway? How will we handle different inference logic and potential optimizations?

Yeah also thanks for raising this question. I didn't keep inference in mind during the first draft design. I am discussing/consulting with Jie more about MOE inference.

acisseJZhong avatar Oct 28 '24 17:10 acisseJZhong

beautiful rfc, looking forward to this

silentlustre avatar Nov 14 '24 22:11 silentlustre

Has there been any progress on this? Would really like to fine-tune the new Qwen3 MoE models. I'd be happy to contribute in any way to take this forward.

prvnsmpth avatar May 03 '25 15:05 prvnsmpth

@prvnsmpth thanks for your interest! (I am just reviewing your Qwen3 PR now.) We have now landed most of these components in torchtune -- one slight difference from this RFC is that our MoE components use token choice routing, not expert choice routing. You can see the various components here. Happy to discuss any modifications we'd need to make in order to support the Qwen MoE models

ebsmothers avatar May 05 '25 14:05 ebsmothers

Thanks @ebsmothers ! I'll dig into the implementation and see how we can support the Qwen MoE models.

prvnsmpth avatar May 06 '25 03:05 prvnsmpth

Closing this now that all our MoE components have landed

ebsmothers avatar Jun 04 '25 16:06 ebsmothers