
[WIP] Gemma3 support.

Open krammnic opened this issue 8 months ago • 18 comments

Context

What is the purpose of this PR? Is it to

  • [X] add a new feature
  • [ ] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • #2484

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • [X] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • [ ] add unit tests for any new functionality
  • [X] update docstrings for any new or updated methods or classes
  • [X] run unit tests via pytest tests
  • [ ] run recipe tests via pytest tests -m integration_test
  • [ ] manually run any new or modified recipes with sufficient proof of correctness
  • [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example

  • [ ] I did not change any public API
  • [ ] I have added an example to docs or docstrings

krammnic avatar Mar 12 '25 11:03 krammnic

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2485

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Mar 12 '25 11:03 pytorch-bot[bot]

  • [X] Recipes, recipe registry
  • [X] Component builders (with all changes according to the tech report)
  • [X] Model builders (with all correct params according to the tech report)
  • [x] SigLIP
  • [x] Weight conversion
  • [ ] Multimodal model
  • [x] Tokenizer? (I assume it is almost the same as in Gemma2)
  • [x] Was able to run 1B without multimodal
  • [ ] Manual runs for multimodal versions

krammnic avatar Mar 12 '25 11:03 krammnic

Was able to run a full finetune on 1B (without multimodal)

krammnic avatar Mar 12 '25 17:03 krammnic

Hey Mark, great work! As a sanity check, do you think you could compare vs HF for correctness with a larger sentence?

You can follow this as a start: https://gist.github.com/felipemello1/e3f1b1c358e145c7a4d610cf44cca374
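
For reference, a minimal sketch of such a logit-parity check (the linked gist is the authoritative recipe; the torchtune side below is an assumption, and the gemma3_1b builder name, checkpoint wiring, and model IDs are illustrative):

```python
# Sketch of a logit-parity check between the HF reference and a torchtune build.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def max_logit_diff(hf_model, tt_model, input_ids):
    """Max absolute difference between HF logits and torchtune logits."""
    with torch.no_grad():
        hf_logits = hf_model(input_ids).logits
        tt_logits = tt_model(input_ids)  # torchtune decoders return logits directly
    return (hf_logits - tt_logits).abs().max().item()

prompt = "The quick brown fox jumps over the lazy dog. " * 8  # a longer sentence

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
hf_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it", torch_dtype=torch.float32
).eval()
input_ids = tok(prompt, return_tensors="pt").input_ids

# Assumed torchtune side (builder name and converted state dict are placeholders):
# from torchtune.models.gemma3 import gemma3_1b
# tt_model = gemma3_1b().eval()
# tt_model.load_state_dict(converted_state_dict)
# print(max_logit_diff(hf_model, tt_model, input_ids))
```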

felipemello1 avatar Mar 12 '25 18:03 felipemello1

Hey Mark, great work! As a sanity check, do you think you could compare vs HF for correctness with a larger sentence?

You can follow this as a start: https://gist.github.com/felipemello1/e3f1b1c358e145c7a4d610cf44cca374

Hey Felipe! Yep, sure. It is still WIP until I'm confident (we will do multimodal runs) and until some configs are fixed by the Gemma team.

krammnic avatar Mar 12 '25 18:03 krammnic

Curious about

some configs will be fixed by Gemma team

any chance it refers to https://github.com/huggingface/transformers/issues/36683 ? That seems to have been patched in Transformers

bzz avatar Apr 04 '25 08:04 bzz

Curious about

some configs will be fixed by Gemma team

any chance it refers to huggingface/transformers#36683 ? That seems to have been patched in Transformers

Hey @bzz. Exactly, the issue is very similar. Unfortunately, we require more information from the config at the conversion stage, and it is missing from the 4B config.

krammnic avatar Apr 06 '25 07:04 krammnic

  Quoting the review: "…o the actual builders are all text-only. I think that's fine if there's still stuff up in the air (viz. (2)), but wonder if we should do something similar to our other multimodal models: provide g…"

  1. Sure! I will push these changes (with EarlyFusion) after we decide on the structure of 1B vs 4B+ and discuss the blocker.
  2. Speaking about the blocker: https://huggingface.co/google/gemma-3-4b-it/discussions/14

krammnic avatar Apr 16 '25 17:04 krammnic

The config issue doesn’t seem likely to get resolved anytime soon and doesn't seem to be a big problem. How about ignoring it for now and moving forward?

pocca2048 avatar May 21 '25 01:05 pocca2048

I agree, let's go forward with this @ebsmothers

krammnic avatar May 28 '25 13:05 krammnic

Waiting for this, any chance it's going to be merged? :)

tsvisab avatar Jun 04 '25 10:06 tsvisab

@ebsmothers we can merge without multimodality, given how busy we are with our PRs.

krammnic avatar Jun 04 '25 13:06 krammnic

Waiting for this, any chance it's going to be merged? :)

I assume we will merge it soon

krammnic avatar Jun 04 '25 13:06 krammnic

Any updated eta on when this may be merged? cc @ebsmothers @krammnic

bradhilton avatar Jun 25 '25 14:06 bradhilton

Interested in getting this merged, willing to help if needed.

rlrs avatar Jul 08 '25 11:07 rlrs

For gemma-3-27b-it, I could not get the logits of this implementation to match the HF version, so I began to track down where they deviate. This led me to a bug in the HF-to-torchtune weight conversion here. According to the implementation of TransformerSelfAttentionLayer, mlp_norm is the norm before the feed-forward and mlp_scale is the norm after it, but the assignment in the above weight conversion code is the opposite. To fix this, we should simply swap the two assignment lines: pre_feedforward_layernorm.weight should go to mlp_norm and post_feedforward_layernorm.weight should go to mlp_scale.
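
For illustration, a minimal sketch of the described swap, assuming the Gemma3 mapping follows the same pattern as torchtune's Gemma2 converter (the exact key strings here are assumptions, not the PR's actual code):

```python
# Illustrative HF -> torchtune key mapping fragment for the feed-forward norms.
# TransformerSelfAttentionLayer applies mlp_norm BEFORE the feed-forward and
# mlp_scale AFTER it, so the HF layernorms should map as follows.

# Swapped (buggy) assignment described above:
# "model.layers.{}.pre_feedforward_layernorm.weight":  "layers.{}.mlp_scale.scale",
# "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.mlp_norm.scale",

# Corrected assignment:
GEMMA3_FF_NORM_MAPPING = {
    "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.mlp_norm.scale",
    "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.mlp_scale.scale",
}
```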

mistysesame avatar Jul 15 '25 20:07 mistysesame

Let's continue with this. Thanks for the reviews and attention to this PR!

krammnic avatar Jul 17 '25 11:07 krammnic

  1. gemma3/12B_lora_single_device.yaml & all the other gemma3 12B configs have a typo in checkpoint_files: model-00001-of-00002.safetensors

  2. tune run lora_finetune_single_device --config gemma3/12B_lora_single_device throws an exception

torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key
    raise Exception(
Exception: Error converting the state dict. Found unexpected key: "language_model.model.embed_tokens.weight". Please make sure you're loading a checkpoint with the right format.

I haven't spent enough time on this library, so please let me know if there's something I'm missing. Happy to provide more information if needed.
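
One possible reading of that error, offered as a hedged guess rather than a confirmed diagnosis: the 12B HF checkpoint is the multimodal export, so its text weights carry a language_model. prefix that the text-only converter does not expect. A sketch of stripping that prefix before conversion (function name and usage are hypothetical):

```python
# Hypothetical pre-processing step: drop the multimodal "language_model." prefix
# from an HF Gemma3 state dict so the text-only converter sees the keys it expects.
# This illustrates the key mismatch in the error above; it is not a verified fix.
def strip_language_model_prefix(hf_state_dict):
    prefix = "language_model."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in hf_state_dict.items()
    }

# Usage sketch (loading code and shard name are placeholders):
# from safetensors.torch import load_file
# sd = load_file("some_shard.safetensors")
# sd = strip_language_model_prefix(sd)
```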

rhn19 avatar Jul 18 '25 19:07 rhn19