feat: add gemma7b support
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
This PR adds support for gemma 7b: #969
Changelog
Minimal changes, simply adding gemma7b configs.
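For readers skimming the diff, the new configs boil down to a handful of hyperparameters. Here is a minimal sketch; the class name, field names, and values are my assumptions based on the public Gemma 7B release, not torchtune's actual builder API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class GemmaConfig:
    """Illustrative hyperparameters only; not torchtune's real config class."""
    vocab_size: int
    num_layers: int
    num_heads: int
    num_kv_heads: int
    embed_dim: int
    intermediate_dim: int
    head_dim: int
    max_seq_len: int

# Values as published for Gemma 7B (double-check against the official release)
GEMMA_7B = GemmaConfig(
    vocab_size=256_000,
    num_layers=28,
    num_heads=16,
    num_kv_heads=16,   # multi-head attention (Gemma 2B uses multi-query, num_kv_heads=1)
    embed_dim=3072,
    intermediate_dim=24_576,
    head_dim=256,
    max_seq_len=8192,
)
```

Note that `num_heads * head_dim` (4096) differs from `embed_dim` (3072), which is the unusual part of this architecture compared to Llama-style models.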
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
- [ ] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- [ ] add unit tests for any new functionality
- [ ] update docstrings for any new or updated methods or classes
- [ ] run unit tests via `pytest tests`
- [ ] run recipe tests via `pytest tests -m integration_test`
- [ ] manually run any new or modified recipes with sufficient proof of correctness
- [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/971
:heavy_exclamation_mark: 1 Active SEV
There is currently 1 active SEV. If your PR is affected, please view it below:
:white_check_mark: No Failures
As of commit 0dc0c23bf06c903e49e931df48baa50fc141d5c0 with merge base dc2b9911cf5a793a8b4f9cd4df321c5271745b4a:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
[WIP] I need to run the training on alpaca and check that everything is working, but my GPU is busy at the moment.
@joecummings should I add unit tests for this PR?
I am new to using pre-commit. When running pre-commit run --all-files I see
```
trim trailing whitespace.................................................Passed
check python ast.........................................................Passed
check for merge conflicts................................................Passed
don't commit to branch...................................................Passed
check for added large files..............................................Passed
fix end of files.........................................................Passed
Insert license in comments...............................................Passed
flake8...................................................................Passed
Format files with µfmt...................................................Failed
- hook id: ufmt
- files were modified by this hook

Formatted /mnt/datasets/mytorchtune/torchtune/torchtune/models/gemma/__init__.py ✨
1 file formatted, 191 files already formatted ✨
```
Could anyone tell me what ufmt format is?
Whoops, I keep overwriting instead of quote and reply. Let's just start with W&B run first.
UFMT formats your code for you. So the above "error" just means that the current code was formatted and you'll have to use `git add` to add the newly formatted files to your commit.
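If it helps, the typical loop looks something like this (using the `ufmt` hook id and the file path from the log above; adjust to your checkout):

```shell
# Re-run only the formatter hook to confirm everything is clean now
pre-commit run ufmt --all-files

# Stage the file(s) the hook rewrote, then commit as usual
git add torchtune/models/gemma/__init__.py
git commit -m "Apply ufmt formatting"
```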
Thanks for hopping on this so quickly!!
@joecummings I have a working version. Things have been slightly more complicated than expected because there was a silent bug in the gemma architecture.
I ran the qlora single gpu pipeline for 1 epoch, please find attached the logs. log_1715681624.txt
I could not run the full training pipeline because of OOM on my GPU.
Please let me know what is left to be done to check that everything works as expected!
Hi @Optimox can you elaborate on the silent bug? I think @kartikayk mentioned that for Gemma 7B it may be the case that embed_dim != head_dim * num_heads, is it related to that?
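For anyone following along, the Gemma 7B quirk mentioned above can be sketched numerically. The values below are from the public Gemma 7B release, and the "buggy assumption" is my guess at the failure mode, not a quote of the actual fix:

```python
# Published Gemma 7B attention sizes (treat as assumptions, not torchtune code)
embed_dim = 3072   # model/hidden dimension
num_heads = 16     # number of attention heads
head_dim = 256     # dimension per head

proj_dim = num_heads * head_dim            # 4096: width of the Q projection
inferred_head_dim = embed_dim // num_heads # 192: what deriving head_dim from embed_dim yields

# For Llama-style models these coincide (proj_dim == embed_dim), so code that
# silently derives head_dim from embed_dim looks correct. For Gemma 7B it isn't:
assert proj_dim != embed_dim
assert inferred_head_dim != head_dim
```

A mismatch like this can go unnoticed because the shapes may still be internally consistent, just wrong relative to the pretrained checkpoint.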
@joecummings @ebsmothers would any one of you accept to review this PR? Let me know if I need to add something! Thanks!
Yep, looking today
Hey, is there any update on this? It would be great if this was added soon.
@1Krypt0 I'll let @joecummings review the PR, but I was generally curious about why you need Gemma 7B. Looking at benchmarks, it seems like Mistral 7B and Llama3 are very competitive and have better community support (inference etc.), so I was curious about the use of Gemma 7B. Would you be able to say more about your use case? Benchmarks are definitely not comprehensive, so it would be nice to learn a bit more about the kind of use cases where Gemma shines.
@kartikayk Yeah, of course! It's for my Master's dissertation: I am assessing the capability of multiple models on "large" document summarization (around the 8k-token limit). It just so happens that three of the models I would like to test are Mistral-7B, Llama3-8B, and Gemma-7B, each in its "base" form and fine-tuned, to see if there is any improvement. They seem to be some of the more popular and capable "small" models, and are at the limit of what my GPU can handle as well.
@1Krypt0 sounds like a very interesting dissertation! We'd love to learn more when you have some results :) In the mean time we'll review ASAP
@ebsmothers I've just pushed a new version which takes all your points into consideration. I think the only thing missing is running the other pipelines (I only ran the QLoRA single-device recipe).
@ebsmothers I updated the download command as asked! Thank you for your careful review!
Thank you @ebsmothers for the last commit!