
question about 'hidden' gate method for mergekit-moe

Open ZeyuTeng96 opened this issue 5 months ago • 2 comments

Hi there,

A question about the process of merging different LLMs into an MoE.

So, for mergekit-moe, if we use the 'hidden' gate method, we have to provide at least one positive prompt for each expert. How can those hidden states make the right routing decision? Can you provide some explanation?

ZeyuTeng96 avatar Jan 17 '24 08:01 ZeyuTeng96

Another question: we cannot guarantee that all experts produce hidden states similar to the base model's. How come the base model's hidden states can be used for gate routing?

ZeyuTeng96 avatar Jan 17 '24 08:01 ZeyuTeng96

The way the positive/negative prompts are used is pretty simple. Routing in Mixtral assigns scores to each expert with a single matrix multiplication, essentially doing a dot product of a vector with the model's hidden state for each expert (for each layer) and using the two experts with the largest values. The script aims to come up with vectors that have maximal dot products with the hidden states associated with your positive prompts.
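A minimal NumPy sketch of that routing step (shapes and values here are purely illustrative, not mergekit's or Mixtral's actual code):

```python
import numpy as np

# Illustrative sizes: one gate weight vector per expert, per layer.
hidden_dim, num_experts = 8, 4
rng = np.random.default_rng(0)

gate_weights = rng.normal(size=(num_experts, hidden_dim))  # one row per expert
hidden_state = rng.normal(size=hidden_dim)                 # a token's hidden state

# Routing is one matrix multiplication: a dot product per expert.
scores = gate_weights @ hidden_state                       # shape: (num_experts,)

# Keep the two experts with the largest scores (Mixtral's top-2 routing).
top2 = np.argsort(scores)[-2:][::-1]
```

Experts whose gate vectors point in the same direction as the hidden state get large dot products and therefore get selected.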

The way it actually does this is embarrassingly simple - it just averages the hidden states for all positive prompts given (minus the hidden states for the negative prompts).

As for why the base model's hidden states can be used for gate routing: since the embedding, LM head, and all self-attention parameters come from the base model, I made the assumption that the latent space of the base model would be a good approximation of the latent space of the final model. This obviously isn't perfect. Given how compatible the latent spaces of fine tunes of a common base tend to be, though (see the script mergekit-layershuffle for a fun party trick), it's a decent assumption in practice. You could probably do something iterative to get a better approximation, but that seems like overkill.

Hope this helps!

cg123 avatar Jan 25 '24 05:01 cg123