Idea: Scaling the Down-Projection Matrix in 'Mixture of Experts' Models
Problem
In a Mixture of Experts (MoE) LLM, the gating network selects the top $n$ of $n_{max}$ experts and outputs a categorical distribution over them, which is then used to form a convex combination of the outputs of those $n$ chosen expert MLP blocks only (e.g., $n = 2$ and $n_{max} = 8$ for Mixtral-8x7b and Mixtral-8x22b). If the model was trained to choose only the top $n$ experts and we want to change the number of active experts to $m$, how should we scale the down-projection matrix of each expert MLP to maintain the expected norm of the final output?
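For reference, here is a minimal sketch of this routing scheme as I understand it (the names `router`, `experts`, `moe_forward` are illustrative only, not taken from any particular codebase): the gate scores all $n_{max}$ experts, keeps the top $n$, renormalises their weights, and returns the convex combination of just those experts' outputs.

```python
# Minimal sketch of Mixtral-style top-k routing (illustrative names, not real APIs).
import torch

def moe_forward(x, router, experts, n=2):
    """x: (hidden_dim,) vector; router: Linear(hidden_dim, n_max); experts: list of MLPs."""
    probs = torch.softmax(router(x), dim=-1)        # categorical distribution over n_max experts
    weights, idx = torch.topk(probs, n)             # keep only the top-n experts
    weights = weights / weights.sum()               # renormalise so the weights sum to 1
    # convex combination of the chosen experts' outputs only
    return sum(w * experts[i](x) for w, i in zip(weights.tolist(), idx.tolist()))
```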
Solution
For simplicity, let's assume that the output of each expert is an i.i.d. random vector with norm $r$, written $\vec{v_i} = r\,\vec{u_i}$ where $\vec{u_i}$ is a random unit vector, and that the gating network outputs a discrete uniform distribution, $g_i = \frac{1}{n}$ for all $i$. The final output is a convex combination of the expert outputs:
$$\vec{S_n} = \sum_{i=1}^n g_i \vec{v_i}$$
The expected norm of this output is:
$$E[|\vec{S_n}|] = rE\left[\left|\sum_{i=1}^n g_i \vec{u_i}\right|\right] = r\sqrt{\sum_{i=1}^n g_i^2} = \frac{r}{\sqrt{n}}$$
NOTE: The last equality holds only for a balanced distribution, where $g_i = \frac{1}{n}$ for all $i$; the middle equality relies on the i.i.d. assumption (see the expansion below).
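To spell out the middle equality (this is where the i.i.d. assumption does the work): for independent, isotropically distributed unit vectors the cross terms vanish in expectation,

$$E\left[\left|\sum_{i=1}^n g_i \vec{u_i}\right|^2\right] = \sum_{i=1}^n g_i^2 + \sum_{i \neq j} g_i g_j \, E[\vec{u_i} \cdot \vec{u_j}] = \sum_{i=1}^n g_i^2$$

and in high dimensions the norm concentrates, so $E\left[\left|\sum_{i=1}^n g_i \vec{u_i}\right|\right] \approx \sqrt{\sum_{i=1}^n g_i^2}$, which for $g_i = \frac{1}{n}$ gives $\sqrt{n \cdot \frac{1}{n^2}} = \frac{1}{\sqrt{n}}$.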
If we change the number of experts to $m$, and the gating network outputs a balanced distribution over the $m$ experts ($g_i = \frac{1}{m}$), the expected norm of the output becomes:
$$E[|\vec{S_m}|] = rE\left[\left|\sum_{i=1}^m g_i \vec{u_i}\right|\right] = r\sqrt{\sum_{i=1}^m g_i^2} = \frac{r}{\sqrt{m}}$$
To make the expected norm of the output with $m$ experts equal to the expected norm of the output with $n$ experts, we need to scale the down-projection matrix of the MLP by a factor of $\sqrt{\frac{n}{m}}$:
$$\vec{v_i}' = \sqrt{\frac{n}{m}} \vec{v_i}$$
With this scaling, the expected norm of the output with $m$ experts becomes:
$$E[|\vec{S_m}|] = E\left[\left|\sum_{i=1}^m g_i \vec{v_i}'\right|\right] = r\sqrt{\frac{n}{m}}\,E\left[\left|\sum_{i=1}^m g_i \vec{u_i}\right|\right] = \frac{r}{\sqrt{n}}$$
which is the same as the expected norm of the output with $n$ experts.
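As a concrete (hypothetical) way to try this on a Hugging Face checkpoint, something along these lines should work: multiply every expert's down-projection by the factor above and save the result. The attribute names (`block_sparse_moe.experts[*].w2`, `num_experts_per_tok`) are from my reading of the HF Mixtral implementation, so they should be double-checked; this is not mergekit code.

```python
# Rough sketch: scale every expert down-projection in a Mixtral-style checkpoint
# by sqrt(n/m) (the factor derived above) and save it for inference with m experts.
import math
import torch
from transformers import AutoModelForCausalLM

def scale_down_projections(model_id: str, n: int, m: int, out_dir: str):
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    scale = math.sqrt(n / m)                      # n = trained top-k, m = new top-k
    with torch.no_grad():
        for layer in model.model.layers:
            for expert in layer.block_sparse_moe.experts:
                expert.w2.weight.mul_(scale)      # w2 is the down-projection (if I'm reading the HF code right)
    model.config.num_experts_per_tok = m          # so the saved model runs with m active experts
    model.save_pretrained(out_dir)

# e.g.: scale_down_projections("mistralai/Mixtral-8x7B-Instruct-v0.1", n=2, m=4, out_dir="mixtral-m4-scaled")
```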
Scale Factor
The scale factor $\sqrt{\frac{n}{m}}$ depends only on the ratio of the original number of experts ($n$) to the new number of experts ($m$); under the assumptions above, it does not depend on the norm $r$ of the expert outputs.
- When $m > n$, the scale factor $\sqrt{\frac{n}{m}}$ will be less than 1.
- When $m < n$, the scale factor $\sqrt{\frac{n}{m}}$ will be greater than 1.
- When $m = n$, the scale factor $\sqrt{\frac{n}{m}} = 1$.
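For concreteness, plugging the Mixtral numbers ($n = 2$) into the formula above gives:

$$m = 1:\ \sqrt{\tfrac{2}{1}} \approx 1.41,\qquad m = 4:\ \sqrt{\tfrac{2}{4}} \approx 0.71,\qquad m = 8:\ \sqrt{\tfrac{2}{8}} = 0.5$$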
(Sorry for the AI-generated text again, but it's so much easier than trying to write all that LaTeX!)
This all assumes I have correctly understood what the Mixtral-style MoE architecture is doing though (it's not 100% clear from the paper).
If this shows promise, then the i.i.d. assumption and the discrete-uniform simplification can be removed by sampling the actual outputs of the expert MLPs and gating networks (the i.i.d. assumption can be improved on if we are happy to just guess values for $\rho$ [see the other thread for an example], but to use a concrete categorical distribution we would need to sample from it, I think).
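In case it's useful, here is a rough sketch of what that sampling could look like: hook each MoE block, measure the mean norm of its output with $n$ and then with $m$ active experts on some calibration text, and take the per-layer ratio as an empirical scale estimate. The attribute names (`block_sparse_moe`, `top_k`, `num_experts_per_tok`) again follow my reading of the HF Mixtral code and may need checking.

```python
# Sketch: estimate the scale empirically instead of assuming i.i.d. experts and
# uniform gate weights, by comparing measured MoE-block output norms.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def set_active_experts(model, k):
    # Assumes the HF MixtralSparseMoeBlock keeps its top-k in `.top_k` - check your transformers version.
    model.config.num_experts_per_tok = k
    for layer in model.model.layers:
        layer.block_sparse_moe.top_k = k

def mean_moe_norms(model, input_ids):
    norms, hooks = {}, []
    for i, layer in enumerate(model.model.layers):
        def hook(mod, args, out, idx=i):
            hidden = out[0] if isinstance(out, tuple) else out
            norms.setdefault(idx, []).append(hidden.float().norm(dim=-1).mean().item())
        hooks.append(layer.block_sparse_moe.register_forward_hook(hook))
    with torch.no_grad():
        model(input_ids)
    for h in hooks:
        h.remove()
    return {i: sum(v) / len(v) for i, v in norms.items()}

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
ids = tok("Some calibration text goes here ...", return_tensors="pt").input_ids.to(model.device)

set_active_experts(model, 2)                      # n = 2 (what Mixtral was trained with)
norms_n = mean_moe_norms(model, ids)
set_active_experts(model, 4)                      # m = 4 (the new number of active experts)
norms_m = mean_moe_norms(model, ids)
empirical_scale = {i: norms_n[i] / norms_m[i] for i in norms_n}   # per-layer scale estimate
```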
I'm going to try this on Mixtral-8x7b-Instruct now and see if it improves the perplexity vs previous attempts:
https://rentry.org/HowtoMixtral
https://old.reddit.com/r/LocalLLaMA/comments/18m6zjz/for_exllamav2_how_many_mixtral_experts_are/
@cg123 I see you already have a parameter called `residual_scale`, so for `mergekit-moe` merges it should be pretty easy to try scaling models that weren't designed to be in an MoE by $\frac{1}{\sqrt{m}}$, etc.