torchtune
[WIP] Gemma3 support.
Context
What is the purpose of this PR? Is it to
- [X] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
- #2484
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- [X] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
- [ ] add unit tests for any new functionality
- [X] update docstrings for any new or updated methods or classes
- [X] run unit tests via pytest tests
- [ ] run recipe tests via pytest tests -m integration_test
- [ ] manually run any new or modified recipes with sufficient proof of correctness
- [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example
- [ ] I did not change any public API
- [ ] I have added an example to docs or docstrings
- [X] Recipes, recipe registry
- [X] Component builders (with all changes according to the tech report)
- [X] Model builders (with all correct params according to the tech report)
- [x] SigLIP
- [x] Weight conversion
- [ ] Multimodal model
- [x] Tokenizer? (I assume it is almost the same as in Gemma2)
- [x] Was able to run 1B without multimodal
- [ ] Manual runs for multimodal versions
Was able to run a full finetune on 1B (without multimodal)
Hey Mark, great work! As a sanity check, do you think you could compare vs HF for correctness with a larger sentence?
You can follow this as a start: https://gist.github.com/felipemello1/e3f1b1c358e145c7a4d610cf44cca374
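A minimal sketch of such a parity check, in the spirit of the linked gist, assuming a torchtune Gemma3 model that already has the converted weights loaded; the model id, prompt, and tolerance below are illustrative, not taken from the gist:

```python
# Hypothetical parity check: compare torchtune logits against the HF reference
# for the same prompt. `tt_model` is an already-built torchtune model with the
# converted checkpoint loaded (its exact builder API is what this PR introduces).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@torch.no_grad()
def compare_logits(tt_model, model_id="google/gemma-3-1b-it",
                   prompt="The quick brown fox jumps over the lazy dog because"):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]

    # Reference forward pass through the HF implementation.
    hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    hf_model.eval()
    hf_logits = hf_model(tokens).logits

    # Forward pass through the torchtune implementation under review.
    tt_model.eval()
    tt_logits = tt_model(tokens)

    print("max abs diff:", (hf_logits - tt_logits).abs().max().item())
    print("allclose (atol=1e-3):", torch.allclose(hf_logits, tt_logits, atol=1e-3))
```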
Hey Felipe! Yep, sure, it is still WIP until I'm confident (we will do multimodal runs) and some configs are fixed by the Gemma team
Curious about
some configs will be fixed by Gemma team
any chance it refers to https://github.com/huggingface/transformers/issues/36683 ? That seems to have been patched in Transformers
Hey @bzz. Exactly, the issue is very similar. Unfortunately, we need more information from the config at the conversion stage, and it is missing in the 4B config.
- So the actual builders are all text-only. I think that's fine if there's still stuff up in the air (viz. (2)), but wonder if we should do something similar to our other multimodal models: provide …
- Sure! I will push these changes (with EarlyFusion, after we decide on the structure of 1B vs 4B+ and discuss the blocker)
- Speaking about the blocker: https://huggingface.co/google/gemma-3-4b-it/discussions/14
The config issue doesn’t seem likely to get resolved anytime soon and doesn't seem to be a big problem. How about ignoring it for now and moving forward?
I agree, let's go forward with this @ebsmothers
Waiting for this, any chance it's going to be merged? :)
@ebsmothers we can merge without multimodality, given how busy we are with our PRs.
I assume we will merge it soon
Any updated ETA on when this may be merged? cc @ebsmothers @krammnic
Interested in getting this merged, willing to help if needed.
For gemma-3-27b-it, I could not get the logits of this implementation to match the HF version, so I began tracking down where they deviate.
This led me to a bug in the HF-to-torchtune weight conversion here.
According to the implementation of TransformerSelfAttentionLayer, mlp_norm is the norm before the FF block and mlp_scale is the norm after it, but the assignments in the weight conversion code above are the opposite.
To fix this, we should simply swap the two assignment lines: pre_feedforward_layernorm.weight should go to mlp_norm and post_feedforward_layernorm.weight should go to mlp_scale.
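A sketch of what the swapped entries might look like after the fix. The key templates below are assumptions modeled on torchtune's Gemma2-style mapping (where RMSNorm parameters are named scale); the only point being made is which HF tensor should feed which torchtune submodule:

```python
# Illustrative fragment of the HF -> torchtune key mapping after the swap.
# Key strings are assumptions; the pairing (pre -> mlp_norm, post -> mlp_scale)
# is the actual content of the fix.
_GEMMA3_FF_NORM_MAPPING = {
    # norm applied *before* the feed-forward block -> TransformerSelfAttentionLayer.mlp_norm
    "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.mlp_norm.scale",
    # norm applied *after* the feed-forward block -> TransformerSelfAttentionLayer.mlp_scale
    "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.mlp_scale.scale",
}
```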
Let's continue with this. Thanks for the reviews and attention to this PR!
- gemma3/12B_lora_single_device.yaml and all the other gemma3 12B configs have a typo in checkpoint_files: model-00001-of-00002.safetensors
- tune run lora_finetune_single_device --config gemma3/12B_lora_single_device throws an exception (see the sketch after this list for a quick way to inspect the checkpoint keys):
  torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key
  raise Exception(
  Exception: Error converting the state dict. Found unexpected key: "language_model.model.embed_tokens.weight". Please make sure you're loading a checkpoint with the right format.
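A quick, hypothetical way to confirm what the converter is tripping on: list the tensor names in one of the downloaded HF shards and check whether the text weights are nested under a language_model. prefix (as in the multimodal Gemma3 checkpoints) rather than the bare model. prefix the converter expects. The shard path below is illustrative.

```python
# Inspect the keys of an HF safetensors shard to see which prefixes it uses.
# The path is an assumption — point it at your local gemma-3-12b-it download.
from safetensors import safe_open

shard = "gemma-3-12b-it/model-00001-of-00005.safetensors"
with safe_open(shard, framework="pt") as f:
    keys = list(f.keys())

print(keys[:5])
print("nested under 'language_model.':",
      any(k.startswith("language_model.") for k in keys))
```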
I haven't spent enough time on this library, so please let me know if there's something I'm missing. Happy to provide more information if needed.