litgpt

Gemma 2: `9b` and `27b` versions

Andrei-Aksionov opened this issue 1 year ago • 8 comments

Hi there 👋

Fixes #1535

Google released the latest and greatest Gemma model, v2. This time it comes in three sizes:

  • 2b (not yet released)
  • 9b
  • 27b

Based on the technical report and the official implementation, here are the main changes that I've spotted:

  1. The embedding scaler (sqrt(n_embd)) needs to be cast down to the embeddings' dtype before it is applied.
  2. We need to be careful with the attention-score scaler: it is not necessarily derived from head_size and can instead be n_embd/n_head. In Gemma, head_size isn't always equal to n_embd/n_head.
  3. Logit soft-capping is applied to both the attention scores and the final logits. Soft-capping is needed only for training (it seems to matter more for the larger model), not so much for inference. Since flash attention doesn't support soft-capping, soft-capping has to be disabled outside of training mode if we want to use flash attention.
  4. Sliding window attention replaces global attention on every other layer, with a window half the size of the global one (4096 vs. 8192 tokens).
  5. RMSNorm downcasts only right at the very end, after the weight is applied. This is how RMSNorm behaved before I added support for Gemma v1.
  6. The transformer block now has two more normalization layers: one right after the attention layer and one right after the MLP (each before its residual connection). Previously we had norm -> attn -> residual -> norm -> MLP -> residual. Now: norm -> attn -> norm -> residual -> norm -> MLP -> norm -> residual.
  7. Both 9b and 27b use grouped-query attention. In Gemma v1, the 7b variant had regular multi-head attention, while the 2b variant had multi-query attention (a single key-value head shared across all query heads).
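Points 2 and 3 can be illustrated with a minimal PyTorch sketch. The function and argument names are hypothetical, and for brevity it assumes the same number of query and key/value heads (no grouped-query attention). The scores are scaled by 1/sqrt(attention_scores_scalar) rather than 1/sqrt(head_size), then soft-capped as cap * tanh(scores / cap) before masking and softmax. The default caps (50.0 for attention scores, 30.0 for the final logits) are the values in the published Gemma 2 configs.

```python
import math

import torch


def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squashes values into [-cap, cap]; close to the identity when |x| << cap
    return cap * torch.tanh(x / cap)


def softcapped_causal_attention(
    q: torch.Tensor,  # (B, n_head, T, head_size)
    k: torch.Tensor,  # (B, n_head, T, head_size), same head count as q in this sketch
    v: torch.Tensor,  # (B, n_head, T, head_size)
    attention_scores_scalar: float,  # e.g. n_embd / n_head, not necessarily head_size
    attention_logit_softcapping: float = 50.0,
) -> torch.Tensor:
    T = q.size(-2)
    # Point 2: scale by 1/sqrt(attention_scores_scalar) instead of 1/sqrt(head_size)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(attention_scores_scalar)
    # Point 3: soft-cap the raw scores before masking and softmax
    scores = softcap(scores, attention_logit_softcapping)
    causal = torch.ones(T, T, dtype=torch.bool, device=q.device).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return probs @ v


def cap_final_logits(logits: torch.Tensor, final_logit_softcapping: float = 30.0) -> torch.Tensor:
    # Point 3, second half: the LM-head output is soft-capped as well
    return softcap(logits, final_logit_softcapping)
```

Because the tanh sits between the matmul and the softmax, a fused kernel such as `torch.nn.functional.scaled_dot_product_attention` has no way to express it, which is why soft-capping and flash attention don't mix. With a very large cap, the function reduces to ordinary causal attention.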
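The normalization-related changes (points 1, 5 and 6) can be sketched in PyTorch as follows. Class and attribute names are hypothetical, `attn` and `mlp` stand in for any modules that map (B, T, n_embd) to the same shape, and the (1 + weight) scaling follows Gemma's RMSNorm.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Gemma-style RMSNorm: computes in float32 and downcasts only at the very end."""

    def __init__(self, size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(size))  # used as (1 + weight)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * (1.0 + self.weight.float())  # weight applied while still in float32
        return x.to(dtype)  # point 5: downcast right at the end


class Gemma2StyleBlock(nn.Module):
    """norm -> attn -> norm -> residual -> norm -> MLP -> norm -> residual"""

    def __init__(self, n_embd: int, attn: nn.Module, mlp: nn.Module) -> None:
        super().__init__()
        self.attn, self.mlp = attn, mlp
        self.norm_1 = RMSNorm(n_embd)  # pre-attention norm (as before)
        self.post_attention_norm = RMSNorm(n_embd)  # new: before the first residual
        self.norm_2 = RMSNorm(n_embd)  # pre-MLP norm (as before)
        self.post_mlp_norm = RMSNorm(n_embd)  # new: before the second residual

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.post_attention_norm(self.attn(self.norm_1(x)))
        x = x + self.post_mlp_norm(self.mlp(self.norm_2(x)))
        return x


def scale_embeddings(emb: torch.Tensor, n_embd: int) -> torch.Tensor:
    # Point 1: cast the sqrt(n_embd) scaler to the embedding dtype *before* multiplying
    scaler = torch.tensor(n_embd**0.5, dtype=emb.dtype)
    return emb * scaler
```

For n_embd = 3584 (the Gemma 2 9b width), `scale_embeddings` multiplies bfloat16 embeddings by 59.75 rather than sqrt(3584) ≈ 59.87, because the scaler itself is rounded to bfloat16 first; keeping it in float32 would give slightly different activations.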

Andrei-Aksionov • Jul 02 '24 16:07

Nice summary. I think this touches all the main points. The other changes (knowledge distillation for the small models; tied embeddings) would not affect the architecture; they're more about the pretraining method. So yeah, looks great! Many thanks for taking this on!

rasbt • Jul 02 '24 16:07

@Andrei-Aksionov

> Sliding window attention (an ugly one, but hey, it works)

Cool! We can also add that to the existing Mistral/Mixtral models then 😊

rasbt • Jul 05 '24 11:07

> Cool! We can also add that to the existing Mistral/Mixtral models then

I believe only Mistral v0.1 supported sliding window attention; none of the subsequent models by Mistral.ai use it. But after this PR is merged, adding SWA would be just a matter of one additional line in a config.
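To illustrate the "one additional line in a config" idea, here is a sketch with a hypothetical, trimmed-down `Config` (the field names are assumptions, not necessarily what this PR uses) that decides per layer whether sliding window attention applies, plus a builder for the corresponding boolean mask. Implementations differ by one on where the window edge sits; in this sketch a query sees itself and the `window - 1` tokens before it.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Config:
    """Hypothetical, trimmed-down model config."""

    n_layer: int
    sliding_window_size: Optional[int] = None  # the "one additional line" that enables SWA
    sliding_window_every: int = 1  # 1: every layer, 2: every other layer (assumed field name)


def uses_sliding_window(config: Config, layer_idx: int) -> bool:
    # With sliding_window_every=2, layers 0, 2, 4, ... get the window (an assumption)
    return config.sliding_window_size is not None and layer_idx % config.sliding_window_every == 0


def build_mask(T: int, window: Optional[int] = None) -> torch.Tensor:
    """Boolean (T, T) mask where True means query i may attend to key j."""
    i = torch.arange(T).unsqueeze(1)  # query positions, as a column
    j = torch.arange(T).unsqueeze(0)  # key positions, as a row
    mask = j <= i  # causal: no looking ahead
    if window is not None:
        mask &= (i - j) < window  # only the current token and the window - 1 before it
    return mask


gemma2_like = Config(n_layer=42, sliding_window_size=4096, sliding_window_every=2)
mistral_v01_like = Config(n_layer=32, sliding_window_size=4096)
```

A Mistral v0.1-style model applies the window on every layer, while a Gemma 2-style model alternates sliding and global layers; which parity gets the window is an assumption here and should be checked against the reference implementation.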

Andrei-Aksionov • Jul 05 '24 11:07

> I believe only Mistral v0.1 supported sliding window attention; none of the subsequent models by Mistral.ai use it.

I think you are right.

> But after this PR is merged, adding SWA would be just a matter of one additional line in a config.

Nice!

rasbt • Jul 05 '24 11:07

Gemma 2 9b/9b-it now has initial support (with a lot of “scaffolding”).

Generation returns plausible results, but chat does a couple of strange things:

  1. OOM. I don't understand why, given that the regular generate script consumes ~20 GB. Update: it's not a Gemma-specific problem (#1558), so it's not a blocker.
  2. I had to use quantization (bnb.nf4), and the model was very restrictive: it often didn't want to respond and instead asked me to rephrase the question. I know that Llama 3, because of its very long training, has saturated the bf16 dtype down to the very last digit, so quantization affects it more than other models. Maybe we have the same here (thanks to a “teacher”)? 🤷 Update #1: if I use the generate script with quantization, I get proper output, so something else is broken in the chat script besides the higher memory consumption. Update #2: the KV-cache needs to be changed to support sliding window attention. The chat script pre-allocates the cache up to model.max_seq_length, so a layer with sliding window attention ends up with the wrong KV-cache.
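One way to give sliding-window layers a correct KV-cache is sketched below with a hypothetical class (not litgpt's API). Instead of pre-allocating `max_seq_length` slots, the layer keeps only the `window - 1` most recent keys/values, which together with the current token is everything the next query may attend to when a query sees itself plus the `window - 1` tokens before it. For simplicity it grows via `torch.cat`; a real implementation would more likely use a preallocated ring buffer.

```python
from typing import Optional, Tuple

import torch


class SlidingWindowKVCache:
    """Hypothetical per-layer KV-cache for a sliding-window attention layer."""

    def __init__(self, window: int) -> None:
        assert window >= 1
        self.window = window
        self.k: Optional[torch.Tensor] = None  # (B, n_kv_heads, T_cached, head_size)
        self.v: Optional[torch.Tensor] = None

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Append the new step(s) along the sequence dimension (dim=2)
        k = k_new if self.k is None else torch.cat([self.k, k_new], dim=2)
        v = v_new if self.v is None else torch.cat([self.v, v_new], dim=2)
        # Store only what future tokens can still attend to (at most window - 1 entries)
        start = max(0, k.size(2) - (self.window - 1))
        self.k, self.v = k[:, :, start:], v[:, :, start:]
        # Return all keys/values for this step; a multi-token prefill still
        # needs a causal sliding-window mask applied by the caller
        return k, v
```

After a 10-token prefill with `window=4`, the cache holds 3 positions, and every following single-token step returns exactly 4 keys/values, so decoding needs no extra mask and the cache never grows past the window regardless of `max_seq_length`.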

Anyway, there is a lot of work that needs to be done (besides what I've mentioned above) before I can open this PR for review:

  • [x] 1. Only final_softcapping affects the tests. Need to make the tests fail if attention_logit_softcapping is broken.
  • [x] 2. Deal with all TODOs. The code works, but is very ugly and non-performant.
  • [x] 3. Use the torch profiler to make sure that no shady device syncs happen in the background.
  • [x] 4. Add code and a test for LitGPT --> HF format conversion.
  • [x] 5. Does torch.compile work with soft-capping? If not, a clear error message needs to be printed.
  • [x] 6. Do a short training as a sanity-check.
  • [x] 7. Figure out what to do with CausalSelfAttention from adapter.py. The adapter tests don't fail, for the same reason as in item 1.
  • [x] 8. Add support for 27b variant.
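A plausible reason for item 1 (an assumption, not verified against the actual tests): soft-capping with cap * tanh(x / cap) is almost the identity when |x| is much smaller than the cap, and a tiny randomly initialised test model tends to produce small attention scores, so a broken or missing attention soft-cap barely changes the output. The sketch below shows the effect; a test meant to catch a broken attention soft-cap needs inputs or weights scaled up enough that the scores actually reach the cap (50.0 here). The score magnitudes are illustrative assumptions.

```python
import torch


def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    return cap * torch.tanh(x / cap)


def max_softcap_effect(scores: torch.Tensor, cap: float) -> float:
    """Largest absolute change that soft-capping makes to a tensor of scores."""
    return (softcap(scores, cap) - scores).abs().max().item()


torch.manual_seed(0)
small_scores = 0.1 * torch.randn(4, 4)  # roughly what a tiny random test model yields (assumed)
large_scores = 100.0 * torch.randn(4, 4)  # scores pushed well past the cap

print(f"small scores: {max_softcap_effect(small_scores, 50.0):.2e}")  # tiny: the cap is invisible
print(f"large scores: {max_softcap_effect(large_scores, 50.0):.2e}")  # large: the cap dominates
```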

Andrei-Aksionov • Jul 06 '24 17:07

One more thing. Due to time constraints, I didn't test the Gemma v2 27b version. The tests run fine, but it would be nice to check the generated output.

@rasbt could you do this?

Andrei-Aksionov • Jul 19 '24 09:07

> One more thing. Due to time constraints, I didn't test the Gemma v2 27b version. The tests run fine, but it would be nice to check the generated output.
>
> @rasbt could you do this?

Yes, I'm happy to do this. I will also generate config files for the smaller models.

rasbt • Jul 19 '24 13:07

Works great!

[Screenshot, 2024-07-19 10:04:03 AM]

rasbt • Jul 19 '24 15:07

Based on the run with the config file, the train and val losses look great. The MMLU score is surprisingly low, though. Otherwise there's nothing wrong with the finetuned model, and it works fine in chat:

[Screenshot, 2024-07-22 3:38:27 PM]

(I'm not 100% sure, but maybe the MMLU scores in the README were computed with --num_fewshot greater than 1.)

Anyway, I think everything else is fine and it's good to merge now, right?

rasbt • Jul 22 '24 20:07

Yep, let's merge.

Andrei-Aksionov • Jul 22 '24 21:07

Awesome, this is great! Thanks for this amazing PR!

rasbt • Jul 22 '24 22:07