[Model] Add OLMoE

Open · Muennighoff opened this issue 1 year ago · 13 comments

There is still an issue that we have been struggling to track down; maybe @huybery can add more details here.

Muennighoff, Aug 27 '24

👋 Hi! Thank you for contributing to the vLLM project. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

github-actions[bot], Aug 27 '24

Hi @Muennighoff are you working on this implementation? It would be great to support OLMoE, especially to test our recently added MoE quantization methods!

mgoin, Sep 04 '24

@huybery and I worked on this for quite a while but could not get it to work well. There is an issue when generating multiple tokens. Maybe you could have a look?

@huybery maybe you can provide more details on what we looked at?

Muennighoff, Sep 04 '24

I tried to align vLLM's outputs with Hugging Face Transformers inference and found some differences. The debugging code is as follows.

test file:

from vllm import LLM
llm = LLM(model="OLMoE/OLMoE-1B-7B-0824", enforce_eager=True, tensor_parallel_size=1)
print(llm.generate("Bitcoin is"))

vllm/vllm/model_executor/models/olmoe.py:

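from typing import Optional

import torch
from torch import nn
from transformers import OlmoeForCausalLM
from vllm.attention import AttentionMetadata

# Hugging Face reference model, loaded once at import time (debugging only)
model = OlmoeForCausalLM.from_pretrained("OLMoE/OLMoE-1B-7B-0824", torch_dtype=torch.bfloat16)

class OlmoeDecoderLayer(nn.Module):

    # __init__ omitted: unchanged from olmoe.py, assuming it also stores
    # self.layer_idx, which forward() uses to pick the matching HF layer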


    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        print("#-"*20)
        print(positions)
        # print(hidden_states)
        print("#-"*20)

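        # Run the matching HF attention layer on the same normalized input.
        # Note: no past_key_value is passed, so during decoding the HF layer
        # sees only the new token(s), not the earlier context held in vLLM's
        # kv_cache; only prefill outputs are directly comparable.
        self_attn = model.model.layers[self.layer_idx].self_attn.cuda()
        hidden_states_old = self_attn(
            hidden_states=hidden_states.unsqueeze(0),
            position_ids=positions.unsqueeze(0)
        )[0][0]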
        print("old:", hidden_states_old)


        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        print("new: ", hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

The log is here: log.txt

I found that vLLM's attention output is aligned with Transformers when processing the first few tokens, but misaligned when generating a new token.

[Screenshot: printed old vs. new attention outputs]
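Rather than eyeballing the printed `old:`/`new:` tensors, a small helper can report the first token row where the two outputs diverge. This is a minimal pure-Python sketch over nested lists; the function name `first_mismatch_row` and the tolerances are illustrative choices, not from the thread (they are loose because the reference model runs in bfloat16). With real tensors, `torch.allclose` or `torch.testing.assert_close` serve the same purpose.

```python
import math
from typing import Optional, Sequence


def first_mismatch_row(
    ref: Sequence[Sequence[float]],
    out: Sequence[Sequence[float]],
    rel_tol: float = 1e-2,
    abs_tol: float = 1e-3,
) -> Optional[int]:
    """Return the index of the first row (token) where `out` differs from `ref`
    beyond the given tolerances, or None if every row matches."""
    if len(ref) != len(out):
        raise ValueError("row counts differ")
    for i, (r_row, o_row) in enumerate(zip(ref, out)):
        if len(r_row) != len(o_row):
            return i
        if not all(math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
                   for a, b in zip(r_row, o_row)):
            return i
    return None


# Hypothetical values: token 0 agrees within tolerance, token 1 diverges.
old = [[0.50, -1.25], [0.10, 0.20]]
new = [[0.501, -1.25], [0.40, 0.20]]
print(first_mismatch_row(old, new))  # 1
```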

I have no idea of a solution at the moment and would appreciate your help! @mgoin

huybery, Sep 05 '24

@huybery could it be due to differences in how SiLU/SwiGLU is implemented in OLMoE versus the existing FusedMoE module?

janimo, Sep 05 '24

@janimo 🤔 Do you have any debugging info to verify this difference? Contributions are very welcome!

huybery, Sep 05 '24

@mgoin Would be great to get this in!

Muennighoff, Sep 09 '24

@Muennighoff Thank you for the PR! Do you mind rebasing this onto main?

dsikka, Sep 09 '24

@dsikka Sure, done. Do you have bandwidth to check why the model is not working?

Muennighoff, Sep 09 '24

I ran a GSM8K eval using Transformers as a reference, and it's still quite a bit lower than the reported accuracy, so I'm not confident the Transformers implementation is a strong reference.

lm_eval --model huggingface --model_args pretrained=allenai/OLMoE-1B-7B-0924-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto

huggingface (pretrained=allenai/OLMoE-1B-7B-0924-Instruct), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3419|±  |0.0131|
|     |       |strict-match    |     5|exact_match|↑  |0.3283|±  |0.0129|

mgoin, Sep 09 '24

Do you mean the Transformers version does not match the reported results? It should, since we used it for our evaluations. Note that for GSM8K we also use CoT and 8-shot prompting, which may explain the difference.

Muennighoff, Sep 09 '24

@mgoin Thanks for taking the time to help us debug! Unfortunately, at the moment even the simplest prompts still don't produce correct responses in vLLM, so reproducing benchmark scores is a much harder task. I'll keep working on it and hope more contributors can help! @dsikka @janimo

huybery, Sep 10 '24

The RMSNorm outputs differ. Fixing that will correct at least some of the differences between the two model attention outputs.

This can be seen by substituting forward_native for forward_cuda in vllm/model_executor/layers/layernorm.py, so that the non-optimized version runs even on CUDA; the outputs then match the Transformers library.

After qkv.split(), the q and k tensors are non-contiguous because they are views on qkv. The built-in optimized RMSNorm operator in csrc/layernorm_kernels.cu assumes contiguous tensors, so the first row of the output, corresponding to the first token, is correct (as seen in the screenshot, the first lines of new and old match). For the following rows, however, it probably computes the norm of q from a slice of k, and the norm of k from a slice of v.

Changing this line q, k = self.q_norm(q), self.k_norm(k) to q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())

will make the outputs match even with the default CUDA RMSNorm operator.
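To illustrate the mechanism described above, here is a minimal pure-Python sketch that simulates memory layout with a flat list. It assumes (following the comment, not verified against csrc/layernorm_kernels.cu) that the kernel locates token i at offset i * width within the tensor it receives, ignoring the view's real row stride; `rms_norm_rows` and the toy sizes are hypothetical. In PyTorch, `q.is_contiguous()` returns False for such a split view, and `q.stride()` shows the larger row stride.

```python
import math
from typing import List


def rms_norm_rows(buf: List[float], offset: int, row_stride: int,
                  num_rows: int, width: int, eps: float = 1e-6) -> List[List[float]]:
    """RMSNorm (without the learned weight) over num_rows rows of `width` values
    read from the flat buffer `buf`; row r starts at offset + r * row_stride."""
    out = []
    for r in range(num_rows):
        start = offset + r * row_stride
        row = buf[start:start + width]
        rms = math.sqrt(sum(x * x for x in row) / width + eps)
        out.append([x / rms for x in row])
    return out


# Toy layout: 3 tokens, each row of qkv holds [q (2) | k (2) | v (2)].
num_tokens, q_size, kv_size = 3, 2, 2
qkv_width = q_size + 2 * kv_size  # 6 values per token in memory
qkv = [float(i + 1) for i in range(num_tokens * qkv_width)]  # flat row-major buffer

# q and k are views: consecutive rows are qkv_width apart in memory.
correct_q = rms_norm_rows(qkv, 0, qkv_width, num_tokens, q_size)
correct_k = rms_norm_rows(qkv, q_size, qkv_width, num_tokens, kv_size)

# A kernel that assumes contiguous input steps by the view's own width instead,
# so from row 1 on q is read from k's slot (and k from v's slot).
buggy_q = rms_norm_rows(qkv, 0, q_size, num_tokens, q_size)
buggy_k = rms_norm_rows(qkv, q_size, kv_size, num_tokens, kv_size)

# Equivalent of q.contiguous(): copy the view into its own packed buffer.
q_packed = [qkv[r * qkv_width + c] for r in range(num_tokens) for c in range(q_size)]
fixed_q = rms_norm_rows(q_packed, 0, q_size, num_tokens, q_size)

print(buggy_q[0] == correct_q[0])  # True: the first token is unaffected
print(buggy_q[1] == correct_q[1])  # False: row 1 was read from k's slot
print(buggy_k[1] == correct_k[1])  # False: row 1 was read from v's slot
print(fixed_q == correct_q)        # True
```

The packed copy restores the layout the kernel expects, which is why calling .contiguous() before the norm is enough to make the outputs match.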

So this should fix at least one of the issues; there are probably more, as the final output is still not correct.

janimo, Sep 13 '24

The weight_loader should be passed name, not weight_name; otherwise it silently fails to load the weights in the MoE layer, and the layer's output is all zeros.

This is the diff against main. The output text is now coherent but still not identical to the Transformers output; there is still a bug somewhere in the attention computation during the decoding (not prefill) phase.

diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 6667c8d7..5e6664df 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -32,12 +32,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput
+from vllm.sequence import IntermediateTensors
 from vllm.utils import print_warning_once
 
 
@@ -166,7 +166,7 @@ class OlmoeAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.q_norm(q), self.k_norm(k)
+        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -387,7 +387,7 @@ class OlmoeForCausalLM(nn.Module):
                     weight_loader = param.weight_loader
                     weight_loader(param,
                                   loaded_weight,
-                                  weight_name,
+                                  name,
                                   shard_id=shard_id,
                                   expert_id=expert_id)
                     break
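One plausible reading of why the wrong argument fails silently rather than raising: if the loader decides what to do by checking substrings of the name it receives, a short checkpoint fragment may match no branch at all, while the full parameter name does. The sketch below is a hypothetical toy loader illustrating that pattern; `toy_expert_weight_loader`, its dispatch rule, and the example names are assumptions for illustration, not vLLM's actual FusedMoE code.

```python
from typing import Dict, List


def toy_expert_weight_loader(store: Dict[str, List[float]], loaded: List[float],
                             weight_name: str, expert_id: int) -> None:
    """Hypothetical loader that dispatches on substrings of `weight_name`.
    A name that matches no branch is skipped without any error."""
    if "weight_scale" in weight_name:
        store[f"scale_{expert_id}"] = list(loaded)
    elif "weight" in weight_name:
        store[f"expert_{expert_id}"] = list(loaded)
    # else: nothing happens, so the parameter keeps its zero initialization


def load(arg_name: str) -> Dict[str, List[float]]:
    store = {"expert_0": [0.0, 0.0]}  # parameters start zero-initialized
    toy_expert_weight_loader(store, [0.3, -0.7], arg_name, expert_id=0)
    return store


# Illustrative names: a checkpoint fragment vs. a full mapped parameter name.
fragment = "experts.0.gate_proj."
full_name = "model.layers.0.mlp.experts.w13_weight"
print(load(fragment)["expert_0"])   # [0.0, 0.0]: silently not loaded
print(load(full_name)["expert_0"])  # [0.3, -0.7]
```

An expert MLP whose weights stay at zero produces all-zero output, which matches the symptom described above.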

janimo, Sep 16 '24

Thanks for the help @janimo! With your patch, the vLLM GSM8K score matches Transformers (which is still lower than expected):

lm_eval --model vllm --model_args pretrained=allenai/OLMoE-1B-7B-0924-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto

vllm (pretrained=allenai/OLMoE-1B-7B-0924-Instruct), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3457|±  |0.0131|
|     |       |strict-match    |     5|exact_match|↑  |0.3313|±  |0.0130|

The model's output also looks reasonable for simple chat:

>>> from vllm import LLM
>>> model = LLM("allenai/OLMoE-1B-7B-0924-Instruct")
>>> model.chat([{"role": "user", "content": "What is the capital of Japan?"}])
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 11.11it/s, est. speed input: 222.40 toks/s, output: 177.89 toks/s]
[RequestOutput(request_id=0, prompt='<|endoftext|><|user|>\nWhat is the capital of Japan?\n<|assistant|>\n', prompt_token_ids=[50279, 29, 93, 4537, 49651, 187, 1276, 310, 253, 5347, 273, 4047, 32, 187, 29, 93, 515, 5567, 49651, 187], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='The capital of Japan is Tokyo, which is also the largest city in Japan.', token_ids=(510, 5347, 273, 4047, 310, 17413, 13, 534, 310, 671, 253, 6253, 2846, 275, 4047, 15), cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1726780518.816808, last_token_time=1726780518.816808, first_scheduled_time=1726780518.8316748, first_token_time=1726780518.8722644, time_in_queue=0.014866828918457031, finished_time=1726780518.919081, scheduler_time=0.001896335743367672, model_forward_time=None, model_execute_time=None), lora_request=None)]

Maybe this is good enough? I would need to reproduce an actual reported eval to be sure.

mgoin, Sep 19 '24

MMLU matches (and actually slightly exceeds) the reported 51.4! I'm considering this done. Thanks again @janimo

lm_eval --model vllm --model_args pretrained=allenai/OLMoE-1B-7B-0924-Instruct --tasks mmlu --num_fewshot 0 --batch_size auto

vllm (pretrained=allenai/OLMoE-1B-7B-0924-Instruct), gen_kwargs: (None), limit: None, num_fewshot: 0, batch_size: auto
|                 Tasks                 |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|---------------------------------------|------:|------|-----:|------|---|-----:|---|-----:|
|mmlu                                   |      2|none  |      |acc   |↑  |0.5220|±  |0.0040|
| - humanities                          |      2|none  |      |acc   |↑  |0.4812|±  |0.0069|
|  - formal_logic                       |      1|none  |     0|acc   |↑  |0.3730|±  |0.0433|
|  - high_school_european_history       |      1|none  |     0|acc   |↑  |0.6727|±  |0.0366|
|  - high_school_us_history             |      1|none  |     0|acc   |↑  |0.6765|±  |0.0328|
|  - high_school_world_history          |      1|none  |     0|acc   |↑  |0.7257|±  |0.0290|
|  - international_law                  |      1|none  |     0|acc   |↑  |0.6033|±  |0.0447|
|  - jurisprudence                      |      1|none  |     0|acc   |↑  |0.5926|±  |0.0475|
|  - logical_fallacies                  |      1|none  |     0|acc   |↑  |0.6503|±  |0.0375|
|  - moral_disputes                     |      1|none  |     0|acc   |↑  |0.5723|±  |0.0266|
|  - moral_scenarios                    |      1|none  |     0|acc   |↑  |0.2425|±  |0.0143|
|  - philosophy                         |      1|none  |     0|acc   |↑  |0.6013|±  |0.0278|
|  - prehistory                         |      1|none  |     0|acc   |↑  |0.5895|±  |0.0274|
|  - professional_law                   |      1|none  |     0|acc   |↑  |0.4081|±  |0.0126|
|  - world_religions                    |      1|none  |     0|acc   |↑  |0.7836|±  |0.0316|
| - other                               |      2|none  |      |acc   |↑  |0.5826|±  |0.0085|
|  - business_ethics                    |      1|none  |     0|acc   |↑  |0.4900|±  |0.0502|
|  - clinical_knowledge                 |      1|none  |     0|acc   |↑  |0.5849|±  |0.0303|
|  - college_medicine                   |      1|none  |     0|acc   |↑  |0.4740|±  |0.0381|
|  - global_facts                       |      1|none  |     0|acc   |↑  |0.2600|±  |0.0441|
|  - human_aging                        |      1|none  |     0|acc   |↑  |0.5695|±  |0.0332|
|  - management                         |      1|none  |     0|acc   |↑  |0.6893|±  |0.0458|
|  - marketing                          |      1|none  |     0|acc   |↑  |0.7607|±  |0.0280|
|  - medical_genetics                   |      1|none  |     0|acc   |↑  |0.6800|±  |0.0469|
|  - miscellaneous                      |      1|none  |     0|acc   |↑  |0.7344|±  |0.0158|
|  - nutrition                          |      1|none  |     0|acc   |↑  |0.5686|±  |0.0284|
|  - professional_accounting            |      1|none  |     0|acc   |↑  |0.3121|±  |0.0276|
|  - professional_medicine              |      1|none  |     0|acc   |↑  |0.5147|±  |0.0304|
|  - virology                           |      1|none  |     0|acc   |↑  |0.4639|±  |0.0388|
| - social sciences                     |      2|none  |      |acc   |↑  |0.6003|±  |0.0086|
|  - econometrics                       |      1|none  |     0|acc   |↑  |0.2807|±  |0.0423|
|  - high_school_geography              |      1|none  |     0|acc   |↑  |0.6414|±  |0.0342|
|  - high_school_government_and_politics|      1|none  |     0|acc   |↑  |0.7409|±  |0.0316|
|  - high_school_macroeconomics         |      1|none  |     0|acc   |↑  |0.5000|±  |0.0254|
|  - high_school_microeconomics         |      1|none  |     0|acc   |↑  |0.5672|±  |0.0322|
|  - high_school_psychology             |      1|none  |     0|acc   |↑  |0.7138|±  |0.0194|
|  - human_sexuality                    |      1|none  |     0|acc   |↑  |0.6489|±  |0.0419|
|  - professional_psychology            |      1|none  |     0|acc   |↑  |0.5180|±  |0.0202|
|  - public_relations                   |      1|none  |     0|acc   |↑  |0.5182|±  |0.0479|
|  - security_studies                   |      1|none  |     0|acc   |↑  |0.5755|±  |0.0316|
|  - sociology                          |      1|none  |     0|acc   |↑  |0.7463|±  |0.0308|
|  - us_foreign_policy                  |      1|none  |     0|acc   |↑  |0.7600|±  |0.0429|
| - stem                                |      2|none  |      |acc   |↑  |0.4469|±  |0.0086|
|  - abstract_algebra                   |      1|none  |     0|acc   |↑  |0.3200|±  |0.0469|
|  - anatomy                            |      1|none  |     0|acc   |↑  |0.5185|±  |0.0432|
|  - astronomy                          |      1|none  |     0|acc   |↑  |0.6645|±  |0.0384|
|  - college_biology                    |      1|none  |     0|acc   |↑  |0.5903|±  |0.0411|
|  - college_chemistry                  |      1|none  |     0|acc   |↑  |0.4200|±  |0.0496|
|  - college_computer_science           |      1|none  |     0|acc   |↑  |0.4400|±  |0.0499|
|  - college_mathematics                |      1|none  |     0|acc   |↑  |0.3800|±  |0.0488|
|  - college_physics                    |      1|none  |     0|acc   |↑  |0.2059|±  |0.0402|
|  - computer_security                  |      1|none  |     0|acc   |↑  |0.6500|±  |0.0479|
|  - conceptual_physics                 |      1|none  |     0|acc   |↑  |0.4681|±  |0.0326|
|  - electrical_engineering             |      1|none  |     0|acc   |↑  |0.4483|±  |0.0414|
|  - elementary_mathematics             |      1|none  |     0|acc   |↑  |0.3333|±  |0.0243|
|  - high_school_biology                |      1|none  |     0|acc   |↑  |0.6548|±  |0.0270|
|  - high_school_chemistry              |      1|none  |     0|acc   |↑  |0.4335|±  |0.0349|
|  - high_school_computer_science       |      1|none  |     0|acc   |↑  |0.5000|±  |0.0503|
|  - high_school_mathematics            |      1|none  |     0|acc   |↑  |0.3222|±  |0.0285|
|  - high_school_physics                |      1|none  |     0|acc   |↑  |0.3179|±  |0.0380|
|  - high_school_statistics             |      1|none  |     0|acc   |↑  |0.4120|±  |0.0336|
|  - machine_learning                   |      1|none  |     0|acc   |↑  |0.4018|±  |0.0465|

|      Groups      |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------------------|------:|------|------|------|---|-----:|---|-----:|
|mmlu              |      2|none  |      |acc   |↑  |0.5220|±  |0.0040|
| - humanities     |      2|none  |      |acc   |↑  |0.4812|±  |0.0069|
| - other          |      2|none  |      |acc   |↑  |0.5826|±  |0.0085|
| - social sciences|      2|none  |      |acc   |↑  |0.6003|±  |0.0086|
| - stem           |      2|none  |      |acc   |↑  |0.4469|±  |0.0086|

mgoin, Sep 19 '24