
Add Gemma3 VL model

Open xiangxu-google opened this issue 7 months ago • 15 comments

IMPORTANT: The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR. Please reach out to the automation team before pressing that button.

What does this PR do?

Add Gemma3 VL model implementation

Collection: [Note which collection this PR will affect]

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

# Run gemma3 1B text only model generate
python scripts/llm/gemma3_generate.py

# Run gemma3 1B text only model pretrain
python scripts/llm/gemma3_pretrain.py

# Run gemma3 VL 4B model finetune
python scripts/vlm/gemma3vl_finetune.py

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR. To re-run CI remove and add the label again. To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • [ ] Make sure you read and followed Contributor guidelines
  • [ ] Did you write any new necessary tests?
  • [ ] Did you add or update any necessary documentation?
  • [ ] Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex)
    • [ ] Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • [ ] New Feature
  • [ ] Bugfix
  • [ ] Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed. The contributor guidelines list specific people who can review PRs in various areas.

Additional Information

  • Related to # (issue)

xiangxu-google, May 12 '25 00:05

Thanks @xiangxu-google for the PR. I will try to verify this on our end as well.

suiyoubi, May 13 '25 13:05

Hi @xiangxu-google, have you tested inference and compared the results between NeMo and HF? I observed that, at least for the 1B model, the NeMo output is gibberish.

suiyoubi, May 13 '25 14:05

Hi, the issue is fixed; I verified it by running python scripts/llm/gemma3_generate.py.

xiangxu-google, May 13 '25 21:05

Thanks @xiangxu-google, could you summarize what you changed? It is hard to see your changes after the force merge. I also reviewed your change and made a few modifications, based on your original commit b4f29d7, which seem to resolve the issue as well:

Here is what I have changed:

  1. The _is_local_attn_layer implementation was incorrect (it always returned True). Fixed.
  2. The Gemma3RotaryEmbedding implementation was incorrect: rope_local and rope_global always returned the same value (I think due to the cache).
  3. Removed Gemma3TERowParallelLinearLayerNorm and reused the class we have for Gemma2.
  4. Based on the HF implementation, rope_scaling_factor should not be applied in local attention. Fixed.
  5. During inference, the embedding was not updated with the Gemma3 implementation.
  6. Some doc fixes.
  7. Made changes to work with the latest MCORE (always add BOS for Gemma3 inference).
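Items 2 and 4 concern how local (sliding-window) and global layers get different rotary embeddings. Below is a minimal sketch, not the NeMo or HF code, of computing the RoPE inverse frequencies separately for the two layer types. The constants are assumptions based on the Gemma 3 tech report and HF configs (local base 10k without scaling; global base 1M with a linear factor of 8 on the long-context checkpoints), and all names are hypothetical. Putting the base and scaling factor into the cache key is one way to guarantee the local and global tables can never alias each other, which is the kind of bug item 2 describes.

```python
from functools import lru_cache
from typing import Tuple

# Assumed Gemma3 RoPE settings; verify against the checkpoint's config before relying on them.
LOCAL_ROPE_BASE = 10_000.0
GLOBAL_ROPE_BASE = 1_000_000.0
GLOBAL_ROPE_SCALING = 8.0  # assumed linear factor for the long-context checkpoints


@lru_cache(maxsize=None)
def rope_inv_freq(head_dim: int, base: float, scaling_factor: float = 1.0) -> Tuple[float, ...]:
    """Inverse RoPE frequencies 1 / base^(2i / head_dim), divided by a linear scaling factor.

    Every argument is part of the lru_cache key, so the local and global tables are
    cached separately and cannot return each other's values.
    """
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even for RoPE")
    return tuple(
        1.0 / (base ** (2 * i / head_dim)) / scaling_factor for i in range(head_dim // 2)
    )


def rope_inv_freq_for_layer(head_dim: int, is_local: bool) -> Tuple[float, ...]:
    """Pick the RoPE table for one layer (hypothetical helper)."""
    if is_local:
        # Sliding-window (local) layers: small base and no rope_scaling_factor.
        return rope_inv_freq(head_dim, LOCAL_ROPE_BASE)
    # Global layers: large base plus linear scaling.
    return rope_inv_freq(head_dim, GLOBAL_ROPE_BASE, GLOBAL_ROPE_SCALING)
```

With this split, rope_inv_freq_for_layer(head_dim, is_local=True) and rope_inv_freq_for_layer(head_dim, is_local=False) return different tables for the same head_dim, and only the global one is affected by the scaling factor.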

My changes are here: https://github.com/NVIDIA/NeMo/tree/pr-13536 (I cannot push to this PR).

Let me know what you think

suiyoubi, May 13 '25 21:05

At this point, I have only verified the LLM part, not the VLM. Here is what I plan to do:

  1. Verify the VL model
  2. Add unit tests and conversion tests
  3. Add our internal CI tests

I am hoping to get this model in for the next container release. Would love to get your input on this.

suiyoubi, May 13 '25 21:05

@suiyoubi Yes, the two major bugs I fixed were:

  1. The local/global layer calculator was wrong
  2. The model.embedding = Gemma3LanguageModelEmbedding(...) assignment was not executed because the "pre_process" and "post_process" args were not passed in during inference.
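The first bug is about deciding which layers use sliding-window attention. A minimal sketch of a correct local/global calculator follows; it is an illustration, not the NeMo function. It assumes 1-based layer numbers (the Megatron convention) and the Gemma 3 pattern of 5 local layers followed by 1 global layer, starting with a local layer. The (num_local, num_global) tuple argument is a hypothetical parameterization.

```python
from typing import Tuple


def is_local_attn_layer(layer_number: int, layer_pattern: Tuple[int, int] = (5, 1)) -> bool:
    """Return True if a 1-based layer index uses sliding-window (local) attention.

    layer_pattern is (num_local, num_global) for one repeating block. Gemma3 is
    assumed to use 5 local layers followed by 1 global layer.
    """
    if layer_number < 1:
        raise ValueError("layer_number is 1-based")
    num_local, num_global = layer_pattern
    # Position of this layer inside its repeating block: 0 .. num_local + num_global - 1.
    position_in_block = (layer_number - 1) % (num_local + num_global)
    return position_in_block < num_local
```

For a 26-layer stack this marks layers 6, 12, 18, and 24 as global and every other layer as local, whereas an implementation that always returns True silently makes every layer sliding-window.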

I will also test other functionality. Thanks for your support on this PR!

xiangxu-google, May 13 '25 22:05

Thank you @xiangxu-google. I worked on the VLM a bit, found several issues with the importer and exporter, and fixed them in https://github.com/NVIDIA/NeMo/pull/13582.

I added one more VL inference script, and it currently works as expected: I was able to generate reasonable results with image input.

There are still quite a lot of tasks to do before this feature can be merged, mostly related to testing.

Moving forward, let's discuss in that PR to keep us on the same page. I'll close this PR. Feel free to leave any comments or suggestions in the new PR. Thank you!

suiyoubi, May 14 '25 15:05

Hi @suiyoubi, nice talking to you offline. As we discussed, I pushed a patch commit that reuses your fix in #13582 for the VL model importer/exporter. This PR is ready for review now.

xiangxu-google, May 14 '25 20:05

Thanks @xiangxu-google, could you fix the pylint, flake8, and copyright issues?

suiyoubi, May 14 '25 21:05

Thanks @xiangxu-google for preparing this PR and fixing all the linter issues. I left a few comments.

Also, has pretraining convergence been tested for the LLM and VLM?

suiyoubi, May 14 '25 22:05

Hi @suiyoubi, I've resolved your comments. I tested pretraining convergence on a small subset of the Fineweb dataset, and it worked well. I didn't test the VL model because I don't have a VL dataset at hand. Our team will use this VL model for real training tasks in the future, so I can report back once I have data points.

xiangxu-google, May 15 '25 00:05

Hi @xiangxu-google, thanks for addressing the comments. I have one question about Gemma models in general: do you know why they are so sensitive to the BOS token? I found that without it, the generated results tend to be very low quality. Any idea why?

suiyoubi, May 15 '25 14:05

I also did some testing with gemma3vl_finetune.py. I think there might be an edge case that _process_sequence_parallel does not handle properly.

The script fails with TP=2, PP=2, SP=True: combined_embedding is None on the second pipeline-parallel stage. Do you think we should change it to something like this:

        # https://github.com/NVIDIA/NeMo/blob/2e89c0553aa25baf4e5cf8e0d13bc577bc900770/nemo/collections/vlm/gemma3vl/model/base.py#L525
        # After doing CP, the shape needs to be (T / CP, B, D)
        # If not using CP, the shape needs to be (T, B, D).
        if combined_embedding is not None:
            combined_embedding = combined_embedding.transpose(1, 0).contiguous()

            if self.is_sequence_parallel and self.pre_process:
                # (T / (CP * TP), B, D)
                combined_embedding = scatter_to_sequence_parallel_region(combined_embedding)
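The proposed guard works because, with pipeline parallelism, only the first stage (pre_process=True) builds an embedding; later stages receive hidden states from the previous stage, so combined_embedding is None there and must not be transposed or scattered. Below is a toy, framework-free sketch of that control flow. It is not Megatron code: the sequence is modeled as a plain list of per-token items, the (T, B, D) transpose is omitted, and scatter_to_sequence_shard is a hypothetical stand-in for scatter_to_sequence_parallel_region that keeps each TP rank's contiguous T / TP slice.

```python
from typing import List, Optional, Sequence


def scatter_to_sequence_shard(x: Sequence, tp_size: int, tp_rank: int) -> List:
    """Toy stand-in for a sequence-parallel scatter: keep this rank's T / TP slice."""
    if len(x) % tp_size != 0:
        raise ValueError("sequence length must be divisible by the TP size")
    chunk = len(x) // tp_size
    return list(x[tp_rank * chunk : (tp_rank + 1) * chunk])


def prepare_decoder_input(
    combined_embedding: Optional[Sequence],
    pre_process: bool,
    sequence_parallel: bool,
    tp_size: int,
    tp_rank: int,
) -> Optional[List]:
    if combined_embedding is None:
        # Non-first pipeline stages build no embedding; their input comes from the previous stage.
        return None
    if sequence_parallel and pre_process:
        # Only the stage that owns the embedding shards it along the sequence dimension.
        return scatter_to_sequence_shard(combined_embedding, tp_size, tp_rank)
    return list(combined_embedding)
```

With TP=2, a four-token sequence on the first stage is split into two tokens per rank, while a later stage passes None through unchanged instead of crashing.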

suiyoubi, May 15 '25 15:05

Hi @suiyoubi

For BOS: Gemma3 was pretrained (PT) and instruction-tuned (IT) with this leading token. Section 3 of the tech report also states that the BOS token is explicitly required.
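Since the model only saw sequences that start with BOS, inference inputs should be normalized the same way. A minimal sketch of an idempotent "prepend BOS if missing" step follows. The default id of 2 is an assumption about the Gemma tokenizer; in practice read the id from the tokenizer instead of hard-coding it. ensure_bos is a hypothetical helper, not a NeMo or MCORE API.

```python
from typing import List, Sequence

GEMMA_BOS_ID = 2  # assumed <bos> id for Gemma tokenizers; prefer tokenizer.bos_id / bos_token_id


def ensure_bos(token_ids: Sequence[int], bos_id: int = GEMMA_BOS_ID) -> List[int]:
    """Return token_ids with exactly one leading BOS, without doubling an existing one."""
    if token_ids and token_ids[0] == bos_id:
        return list(token_ids)
    return [bos_id, *token_ids]
```

Making the step idempotent matters because some tokenizers already add BOS themselves, and a doubled BOS is also a sequence the model never saw in training.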

For context parallelism (CP): after fixing the issue you mentioned, it still fails with:

[rank1]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention.py", line 5212, in forward
[rank1]:     output = attn_forward_func_with_cp(
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention.py", line 3572, in attn_forward_func_with_cp
[rank1]:     assert not sliding_window_attn or cp_comm_type in [
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: The context parallel running configs cannot support sliding window attetnion!

TE's attention does not support sliding window and CP together. I planned to switch to Megatron's SDPA when CP is enabled, but it supports neither sliding window nor CP. An alternative is to enable CP for global layers and disable it for local layers, but that requires non-trivial changes that I still need to think through, and it won't be implemented in this PR. So CP can't be enabled for this model for now.

I won't remove the CP-related code in gemma3vl, because it might be useful in the future if TE supports sliding window + CP. Instead, I added a check in gemma3 that raises an error when CP > 1.
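A check like the one described can live in the config so the run fails at construction time rather than deep inside the attention kernel. Below is a minimal sketch using a hypothetical dataclass; the field names mirror Megatron's TransformerConfig parallelism fields, but the class itself is an illustration, not the Gemma3 config from this PR.

```python
from dataclasses import dataclass


@dataclass
class Gemma3ParallelConfigSketch:
    """Hypothetical stand-in for the parallelism fields of a Gemma3 model config."""

    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    context_parallel_size: int = 1

    def __post_init__(self) -> None:
        # Fail fast at config time instead of inside the attention forward pass.
        if self.context_parallel_size > 1:
            raise ValueError(
                "Gemma3 uses sliding-window attention, which is not supported together with "
                f"context parallelism here; got context_parallel_size={self.context_parallel_size}"
            )
```

TP and PP remain configurable, so a TP=2, PP=2 setup still constructs normally, while any CP > 1 is rejected with a message that names the reason.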

xiangxu-google, May 15 '25 19:05

Thanks @xiangxu-google:

Regarding "TE's attention doesn't support enabling sliding window and CP together": that's fine. Adding the assertion that CP is not used is good. Thanks.

Thanks for the explanation of the BOS token.

The PR looks good to me now. Thanks for all the efforts.

To get this merged, CI/CD needs to pass. I am going to run it (it might take up to a day), and since we are not adding any tests at this point, I expect it to pass. I'll add the tests in a separate PR.

suiyoubi, May 15 '25 19:05

@suiyoubi @xiangxu-google I had to merge main into this PR to resolve CI failures during the first run. It looks like DCO issues still exist on some commits. Please resolve them.

chtruong814, May 17 '25 12:05