Add Gemma3 VL model
[!IMPORTANT]
The "Update branch" button must only be pressed on very rare occasions. An outdated branch is never blocking the merge of a PR. Please reach out to the automation team before pressing that button.
What does this PR do?
Add Gemma3 VL model implementation
Collection: [Note which collection this PR will affect]
Changelog
- Add specific line-by-line info of high-level changes in this PR.
Usage
# Run gemma3 1B text only model generate
python scripts/llm/gemma3_generate.py
# Run gemma3 1B text only model pretrain
python scripts/llm/gemma3_pretrain.py
# Run gemma3 VL 4B model finetune
python scripts/vlm/gemma3vl_finetune.py
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR. To re-run CI, remove and add the label again. To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
- [ ] Make sure you have read and followed the Contributor guidelines
- [ ] Did you write any new necessary tests?
- [ ] Did you add or update any necessary documentation?
- [ ] Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
- [ ] Reviewer: Does the PR have correct import guards for all optional libraries?
PR Type:
- [ ] New Feature
- [ ] Bugfix
- [ ] Documentation
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed. The Contributor guidelines contain specific people who can review PRs to various areas.
Additional Information
- Related to # (issue)
Thanks @xiangxu-google for the PR. I will try to verify this on our end as well.
Hi @xiangxu-google, I wonder whether you have tested inference and compared the results between NeMo's and HF's? I observed that, at least for the 1B model, the NeMo output is gibberish.
Hi, the issue is fixed, verified by `python scripts/llm/gemma3_generate.py`.
Thanks @xiangxu-google, could you summarize what you have changed? It is hard to see your changes with the force merge. I also managed to review your change and made a few modifications, based on your original commit b4f29d7, which seem to resolve the issue as well:
Here is what I have changed:
- `_is_local_attn_layer` implementation is incorrect (always returns True). Fixed this (see the sketch after this list).
- `Gemma3RotaryEmbedding` implementation is incorrect (`rope_local` and `rope_global` always return the same value, I think due to the cache).
- Removed `Gemma3TERowParallelLinearLayerNorm`, and reused the class we have for Gemma2.
- Based on the HF implementation, `rope_scaling_factor` should not be applied in local attention. Fixed this.
- During inference, the embedding is not updated with the Gemma3 implementation.
- Some doc fixes.
- Made changes to work with the latest MCore (to always add BOS for Gemma3 inference).
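For context, here is a hedged sketch of a local/global layer calculator of the kind described in the first bullet. It assumes Gemma3's interleaved pattern in which every N-th layer uses global (full) attention and the rest use local (sliding-window) attention; the period of 6 and the 1-based layer indexing are illustrative assumptions, not the exact NeMo code.

```python
def is_local_attn_layer(layer_number: int, global_period: int = 6) -> bool:
    """Return True if this layer should use local (sliding-window) attention.

    Illustrative only: assumes 1-based layer numbers and that every
    `global_period`-th layer is a global-attention layer. The original bug
    returned True unconditionally, which made every layer local.
    """
    return layer_number % global_period != 0
```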
I have my changes here: https://github.com/NVIDIA/NeMo/tree/pr-13536, since I cannot push to this PR.
Let me know what you think
At this point, I have only verified the LLM part, not the VLM model yet. But here is what I plan to do:
- Verify the VL model
- Add unit tests / conversion tests
- Add our internal CI tests

I am hoping to get this model in for the next container release. Would love to get your input on this.
@suiyoubi Yeah, the two major bugs that I fixed were:
- The local/global layer calculator was wrong
- The `model.embedding = Gemma3LanguageModelEmbedding(...)` assignment was not executed because no "pre_process" and "post_process" args were passed in during inference (see the sketch below).
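For illustration, a hypothetical sketch of that second bug (the provider name and call pattern are stand-ins, not the actual NeMo inference path): the embedding override is gated on `pre_process`, so if that flag isn't forwarded when the model is built for inference, the default embedding stays in place and generation degrades.

```python
def build_inference_model(gemma3_model_provider):
    # Hypothetical helper. Before the fix, the stage flags were not forwarded,
    # so the provider's `if pre_process: model.embedding =
    # Gemma3LanguageModelEmbedding(...)` branch never executed during inference.
    # For a single-stage inference setup, pass both flags explicitly.
    return gemma3_model_provider(pre_process=True, post_process=True)
```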
I will also test other functionalities. Thanks for your support on this PR!
Thank you @xiangxu-google. I worked on the VLM a bit, found several issues with the importer & exporter, and fixed them in https://github.com/NVIDIA/NeMo/pull/13582.
I added one more VL inference script, and currently it is working as expected: I was able to generate reasonable results with image input.
There are still quite a lot of tasks to do in order to have this feature merged, mostly related to testing.
Moving forward, let's discuss in that PR to keep us on the same page. I'll close this PR. Feel free to leave any comments/suggestions in the new PR. Thank you!
Hi @suiyoubi, nice to talk to you offline. As we discussed, I pushed a patching commit reusing your fix in #13582 for the VL model importer/exporter. This PR is ready for review now.
Thanks @xiangxu-google, could you help fix the pylint, flake8, and copyright issues?
Thanks @xiangxu-google for preparing this PR and fixing all the linter issues. I left a few comments on the issues.
Also, I wonder if the pretraining convergence for LLM/VLM has been tested?
Hi @suiyoubi, I've resolved your comments. I tested the pretraining convergence using a small subset of the Fineweb dataset, and it worked well. I didn't test VL because I don't have a VL dataset around. Our team will use this VL model for real training tasks in the future, so I can report back if I get data points.
Hi @xiangxu-google, thanks for addressing the comments. I have one question relating to Gemma models in general: do you know why Gemma models are sensitive to the BOS token? I found that without the BOS token, the generated results are prone to be of very low quality. Do you have a clue about this?
I also did some testing with gemma3vl_finetune.py. I think there might be some edge cases that are not handled properly by _process_sequence_parallel.
The script would fail with TP=2, PP=2, SP=True: combined_embedding would be None for the second PP group. Do you think we should change it to something like:
# https://github.com/NVIDIA/NeMo/blob/2e89c0553aa25baf4e5cf8e0d13bc577bc900770/nemo/collections/vlm/gemma3vl/model/base.py#L525
# After doing CP, the shape needs to be (T / CP, B, D)
# If not using CP, the shape needs to be (T, B, D).
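# Note: combined_embedding can be None on pipeline stages without pre_process
# (e.g., the second PP group when running TP=2 PP=2 SP=True, as noted above),
# so guard before transposing.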
if combined_embedding is not None:
    combined_embedding = combined_embedding.transpose(1, 0).contiguous()
    if self.is_sequence_parallel and self.pre_process:
        # (T / (CP * TP), B, D)
        combined_embedding = scatter_to_sequence_parallel_region(combined_embedding)
Hi @suiyoubi
For BOS, it's because Gemma3 was pretrained (PT) and instruction-tuned (IT) with this leading token. Section 3 of the tech report also indicates that the BOS token is explicitly needed.
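For reference, a hedged illustration of the BOS behavior with the Hugging Face tokenizer (the model id is just an example; this assumes the Gemma tokenizer's default of prepending BOS when special tokens are added):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")  # example checkpoint

with_bos = tok("The capital of France is").input_ids  # BOS prepended by default
without_bos = tok("The capital of France is", add_special_tokens=False).input_ids  # no BOS

# Generation should start from `with_bos`; feeding `without_bos` is what tends
# to produce the low-quality outputs mentioned above.
assert with_bos[0] == tok.bos_token_id
```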
For Context Parallel, after fixing the issue you mentioned, it still fails due to:
[rank1]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention.py", line 5212, in forward
[rank1]: output = attn_forward_func_with_cp(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention.py", line 3572, in attn_forward_func_with_cp
[rank1]: assert not sliding_window_attn or cp_comm_type in [
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: The context parallel running configs cannot support sliding window attetnion!
TE's attention doesn't support enabling sliding window and CP together. I planned to switch to Megatron's SDPA when CP is enabled, but it doesn't support sliding window or CP either. An alternative is to enable CP for global layers and disable it for local layers, but that requires non-trivial changes; I need to think about how to do that, and it won't be implemented in this PR. So CP can't be enabled for this model for now.
I won't remove the CP-related code in gemma3vl, because it might be useful in the future if TE supports sliding window + CP. Instead, I added a check in gemma3 to raise a failure when CP > 1.
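For reference, a minimal sketch of the kind of guard described above (the function and attribute names are illustrative, not the exact NeMo code):

```python
def validate_context_parallelism(config) -> None:
    # Gemma3 relies on sliding-window (local) attention for most layers, and
    # TE's attention cannot combine sliding-window attention with context
    # parallelism, so reject CP > 1 up front.
    if getattr(config, "context_parallel_size", 1) > 1:
        raise ValueError(
            "Gemma3 does not support context parallelism because sliding-window "
            "attention cannot be used with CP; set context_parallel_size=1."
        )
```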
Thanks @xiangxu-google:
"TE's attention doesn't support enabling sliding window and CP together." That's fine; adding the assertion to check that CP is not used is good. Thanks.
Thanks for your explanation on the BOS.
The PR looks good to me now. Thanks for all the efforts.
In order to have this merged, we will need the CI/CD to pass. I am going to run it (might take up to a day), and since we are not really writing any tests at this point, I am assuming it should pass. I'll add the tests in a separate PR.
@suiyoubi @xiangxu-google I had to merge main into this PR to resolve CI failures during the first run. Looks like DCO issues still exist on some commits. Please resolve the DCO issues.