Gemma 2: `9b` and `27b` versions
Hi there 👋
Fixes #1535
Google released the latest and greatest Gemma model - v2. This time it comes in three sizes:
- 2b (not yet released)
- 9b
- 27b
Based on the technical report and the official implementation, here are the main changes that I've spotted (a rough sketch of several of them follows the list):
- The embeddings scaler needs to be cast down before applying.
- Need to be careful with the attention scores scaler: it's not equal to `head_size`, but rather `n_embd/n_head`. In the case of Gemma, `head_size` might not be equal to `n_embd/n_head`.
- Logits soft-capping for attention scores and for the final logits. Soft-capping is needed mostly for training (it looks like it's more important for a larger model) and not so much for inference. Since flash attention doesn't support soft-capping, it needs to be disabled if not in training mode.
- Sliding window attention is used instead of a global window on every odd (idx) layer, with half the size.
- RMSNorm does downcasting right at the end. That was the behavior before I added support for Gemma v1.
- The transformer block now has two more normalization layers: right after the attention layer (before the residual connection) and right after the MLP (also before the residual connection). Previously we had `norm -> attn -> residual -> norm -> MLP -> residual`. Now: `norm -> attn -> norm -> residual -> norm -> MLP -> norm -> residual`.
- Both `9b` and `27b` use grouped query attention. In Gemma v1, `7b` had regular multi-head attention, while the `2b` variant had multi-query attention (a single key-value pair is shared across all query heads).
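To make a few of these points concrete, here is a minimal, self-contained PyTorch sketch of the embedding scaling, the `n_embd/n_head` score scaler, soft-capping, and the extra post-norms. It is an illustration only, not the implementation in this PR: the class and argument names are made up, GQA and the sliding window are omitted, and `nn.RMSNorm` (PyTorch >= 2.4) stands in for the model's own RMSNorm.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Soft-capping squashes values into (-cap, cap) instead of hard clipping;
    # Gemma 2 applies it to the attention scores and to the final logits.
    return cap * torch.tanh(x / cap)


def scale_embeddings(x: torch.Tensor, n_embd: int) -> torch.Tensor:
    # The sqrt(n_embd) scaler is cast down to the embeddings dtype *before*
    # the multiplication, which matters when running in bf16.
    return x * torch.tensor(math.sqrt(n_embd), dtype=x.dtype)


class SoftcappedAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, head_size: int, attn_cap: float = 50.0) -> None:
        super().__init__()
        self.n_head, self.head_size, self.attn_cap = n_head, head_size, attn_cap
        self.qkv = nn.Linear(n_embd, 3 * n_head * head_size, bias=False)
        self.proj = nn.Linear(n_head * head_size, n_embd, bias=False)
        # Scores are scaled by 1/sqrt(n_embd / n_head), not 1/sqrt(head_size);
        # in Gemma these two values are not necessarily the same.
        self.scale = 1.0 / math.sqrt(n_embd / n_head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        q, k, v = self.qkv(x).split(self.n_head * self.head_size, dim=-1)
        q, k, v = (t.view(B, T, self.n_head, self.head_size).transpose(1, 2) for t in (q, k, v))
        scores = softcap((q @ k.transpose(-2, -1)) * self.scale, self.attn_cap)
        # Flash attention has no hook for the tanh above, hence the eager path here.
        causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        scores = scores.masked_fill(~causal_mask, float("-inf"))
        y = F.softmax(scores, dim=-1) @ v
        return self.proj(y.transpose(1, 2).reshape(B, T, self.n_head * self.head_size))


class Gemma2Block(nn.Module):
    # norm -> attn -> norm -> residual -> norm -> MLP -> norm -> residual
    def __init__(self, n_embd: int, n_head: int, head_size: int, mlp: nn.Module) -> None:
        super().__init__()
        self.norm_1 = nn.RMSNorm(n_embd)
        self.post_attention_norm = nn.RMSNorm(n_embd)
        self.norm_2 = nn.RMSNorm(n_embd)
        self.post_mlp_norm = nn.RMSNorm(n_embd)
        self.attn = SoftcappedAttention(n_embd, n_head, head_size)
        self.mlp = mlp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.post_attention_norm(self.attn(self.norm_1(x)))
        x = x + self.post_mlp_norm(self.mlp(self.norm_2(x)))
        return x
```

The final logits would be capped the same way, with a separate cap value, e.g. `logits = softcap(lm_head(x), final_logit_softcapping)`.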
Nice summary. I think this touches all the main points. The others (knowledge distillation for the small models; tied embeddings) would not affect the architecture; they're more of a pretraining method. So yeah, looks great! Many thanks for taking this on!
@Andrei-Aksionov Sliding window attention (an ugly one, but hey, it works)
Cool! We can also add that to the existing Mistral/Mixtral models then 😊
> Cool! We can also add that to the existing Mistral/Mixtral models then
I believe only Mistral v0.1 used sliding window attention; all the subsequent Mistral.ai models don't use it.
But after this PR is merged, adding SWA would be just a matter of one additional line in a config.
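For illustration only, this is roughly what that could look like; the `sliding_window_size` field name and the values below are assumptions, not LitGPT's actual config:

```python
# Purely hypothetical illustration of "one additional line in a config".
some_model_config = dict(
    n_layer=32,
    n_head=32,
    n_embd=4096,
    sliding_window_size=4096,  # the single extra line that would enable SWA
)
```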
> I believe only Mistral v0.1 used sliding window attention; all the subsequent Mistral.ai models don't use it.
I think you are right.
> But after this PR is merged, adding SWA would be just a matter of one additional line in a config.
Nice!
Gemma 2 `9b`/`9b-it` now has initial support (with a lot of “scaffolding”).
Generation returns plausible results, but chat does a couple of strange things:
- OOM. I don't understand why, since the regular generate script consumes ~20 GB. Update: it's not a Gemma-specific problem (#1558), so it's not a blocker.
- Had to use quantization (bnb.nf4), and the model was very restrictive: it often didn't want to respond and instead asked to rephrase the question. I know that Llama 3, because of its very long training, has saturated the bf16 dtype up to the very last digit, and thus quantization affects it more than other models. Maybe here we have the same (thanks to a “teacher”)? 🤷 Update #1: if I use the generate script with quantization, I get proper output, so something else is broken in the chat script besides the higher memory consumption. Update #2: the KV cache needs to be changed to support sliding window attention. The chat script pre-allocates too much memory (up to `model.max_seq_length`), so a layer with a sliding window gets a wrong KV cache (see the sketch after this list).
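A rough sketch of the KV-cache sizing problem from Update #2. The helper below is an illustrative assumption, not the actual LitGPT code; it only encodes the "every odd layer uses a sliding window" rule mentioned above:

```python
def kv_cache_length(layer_idx: int, max_seq_length: int, sliding_window_size: int) -> int:
    # Assume, as in Gemma 2, that every odd layer uses sliding window attention.
    uses_sliding_window = layer_idx % 2 == 1
    if uses_sliding_window:
        # Such a layer never attends further back than the window, so
        # pre-allocating `max_seq_length` entries (what the chat script does
        # today) wastes memory and gives the layer a wrongly sized cache.
        return min(sliding_window_size, max_seq_length)
    return max_seq_length


# e.g. with max_seq_length=8192 and a 4096-token window:
# layer 0 -> 8192 cached positions, layer 1 -> 4096, layer 2 -> 8192, ...
```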
Anyway, there is a lot of work that needs to be done (besides what I've mentioned above) before I can open this PR for review:
- [x] 1. Only `final_softcapping` affects tests. Need to make tests fail if `attention_logit_softcapping` is messed up.
- [x] 2. Deal with all TODOs. The code works, but is very ugly and non-performant.
- [x] 3. Use the torch profiler to make sure that there are no shady device syncs happening in the background.
- [x] 4. Add code and a test for LitGPT --> HF format conversion.
- [x] 5. Does `torch.compile` work with soft-capping? If not, a clear error message needs to be printed (see the sketch after this list).
- [x] 6. Do a short training run as a sanity check.
- [x] 7. Figure out what to do with `CausalSelfAttention` from `adapter.py`. Tests for the adapter don't fail because of item 1.
- [x] 8. Add support for 27b variant.
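For item 5, a guard along these lines could produce the clear error message. Everything here is an assumption about names and structure; it is not the code in this PR:

```python
def check_softcapping_support(attention_logit_softcapping, fast_path_requested: bool) -> None:
    # Soft-capping inserts a tanh into the attention scores, which the fused
    # flash/SDPA kernels (and hence the compiled fast path) cannot express,
    # so fail loudly instead of silently producing different results.
    if attention_logit_softcapping is not None and fast_path_requested:
        raise NotImplementedError(
            "Attention-score soft-capping is not supported on this attention path; "
            "disable soft-capping or fall back to the eager implementation."
        )
```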
One more thing. Due to time constraints, I didn't test Gemma v2 27b version. Tests are running fine, but it would be nice to check the generated output.
@rasbt could you do this?
> One more thing. Due to time constraints, I didn't test Gemma v2 27b version. Tests are running fine, but it would be nice to check the generated output.
> @rasbt could you do this?
Yes, I am happy to do this. The other thing is that I will also generate config files for the smaller models.
Works great!
Based on the config file run, the train and val loss look great. The MMLU score is surprisingly low, though. There's nothing wrong with the finetuned model, and it works fine during chat:
(Not 100% sure, but maybe the MMLU scores in the README were created with `--num_fewshot` greater than 1.)
Anyway, I think everything else seems to be fine and good to merge now, right?
Yep, let's merge.
Awesome, this is great! Thanks for this amazing PR!