
Convert Gemma 2 to HuggingFace

peregilk opened this issue 1 year ago · 14 comments

Are there any scripts for converting Gemma-2 models to HuggingFace? I see there are Llama and Mistral scripts.

peregilk avatar Feb 28 '25 18:02 peregilk

You can check my commit here for the conversion script.

hxssgaa avatar Mar 01 '25 16:03 hxssgaa

That is fantastic. Thanks.

peregilk avatar Mar 01 '25 17:03 peregilk

@hxssgaa I made a quick test of the script, trying to convert a 2B Gemma2 model. However, I am seeing this error: ValueError: Requested shape: (2048,) is not compatible with the stored shape: (2304,). Truncating/padding is disabled by setting of strict=True. When using standard Orbax APIs, this behavior can be modified by specifying strict=False in ArrayRestoreArgs for any array in which padding/truncation is desired.

peregilk avatar Mar 06 '25 18:03 peregilk

@hxssgaa I understand this is because it uses the settings from the base.yml file. However, it was not obvious to me how to get the script to rely either on the structure of the loaded model, or on the model yml files.

I also see the script refers to convert_maxtext_to_hf.py. Is that a helper file?

peregilk avatar Mar 07 '25 10:03 peregilk

> @hxssgaa I made a quick test of the script, trying to convert a 2B Gemma2 model. However, I am seeing this error: ValueError: Requested shape: (2048,) is not compatible with the stored shape: (2304,). Truncating/padding is disabled by setting of strict=True. When using standard Orbax APIs, this behavior can be modified by specifying strict=False in ArrayRestoreArgs for any array in which padding/truncation is desired.

Hi @peregilk, I just ran another test of the gemma2-2b conversion script, and couldn't reproduce the issue you are seeing. The converted checkpoint exactly matches the official HuggingFace gemma2-2b-it. Please use the correct yml settings for conversion; your command should look like:

JAX_PLATFORMS=cpu python MaxText/gemma2_orbax_to_hf.py MaxText/configs/base.yml \
        base_output_directory=/tmp/output \
        load_parameters_path=/path/to/maxtext/checkpoint \
        model_name='gemma2-2b' \
        hf_model_path=/path/to/save/hf_model.bin \
        model_size=2b

> @hxssgaa I understand this is because it uses the settings from the base.yml file. However, it was not obvious to me how to get the script to rely either on the structure of the loaded model, or on the model yml files.

> I also see the script refers to convert_maxtext_to_hf.py. Is that a helper file?

It's a typo, and I have already fixed it in the latest commit; it should be gemma2_orbax_to_hf.py instead.

hxssgaa avatar Mar 07 '25 10:03 hxssgaa
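
For context, the heavy lifting in such a conversion script is renaming (and where needed reshaping) parameters from the MaxText/Orbax layout to the HuggingFace layout. A hedged sketch of the renaming step, with hypothetical parameter names that are illustrative only and differ from the actual names in gemma2_orbax_to_hf.py:

```python
# Hypothetical sketch of the key-renaming step a MaxText -> HF converter
# performs. The parameter names below are illustrative assumptions, not the
# actual names used by gemma2_orbax_to_hf.py.
def maxtext_key_to_hf(key: str) -> str:
    # Keys that appear once per model.
    static = {
        "params.token_embedder.embedding": "model.embed_tokens.weight",
        "params.decoder.decoder_norm.scale": "model.norm.weight",
    }
    if key in static:
        return static[key]
    # Per-layer keys of the assumed form "params.decoder.layers_<i>.<suffix>".
    _, _, rest = key.partition(".layers_")
    layer, _, suffix = rest.partition(".")
    per_layer = {
        "self_attention.query.kernel": "self_attn.q_proj.weight",
        "self_attention.key.kernel": "self_attn.k_proj.weight",
        "self_attention.value.kernel": "self_attn.v_proj.weight",
        "mlp.wi_0.kernel": "mlp.gate_proj.weight",
        "mlp.wo.kernel": "mlp.down_proj.weight",
    }
    return f"model.layers.{layer}.{per_layer[suffix]}"

name = maxtext_key_to_hf("params.decoder.layers_0.mlp.wo.kernel")
# -> "model.layers.0.mlp.down_proj.weight"
```

Besides renaming, a real converter also has to transpose the weight matrices, since Flax Dense kernels are stored (in_features, out_features) while PyTorch Linear weights are (out_features, in_features).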

@hxssgaa Thanks for answering me, and sorry for posing stupid questions here. Do you first save/convert the checkpoint to local disk?

Or can /path/to/maxtext/checkpoint be the bucket where the trained checkpoints are stored, i.e. 'gs://mybucket/gemma2-2B-instruct-myfinetunedmodel1/checkpoints/0/items'?

I still don't think the example command is exactly correct, but if this is stored locally and does not require a specific yml file, it is probably just a typo.

peregilk avatar Mar 07 '25 11:03 peregilk

@peregilk, no need to save the ckpt locally; you can just point load_parameters_path at the Google bucket checkpoint location. Sorry for the confusion here. I have changed the ckpt conversion format to be similar to llama_or_mistral_orbax_to_huggingface.py, so the correct conversion command should be:

JAX_PLATFORMS=cpu python MaxText/gemma2_orbax_to_hf.py MaxText/configs/base.yml \
        base_output_directory=/tmp/output \
        load_parameters_path=/path/to/maxtext/checkpoint \
        model_name='gemma2-27b' \
        hf_model_path=/path/to/save/hf_model.bin \
        model_size=27b

hxssgaa avatar Mar 07 '25 12:03 hxssgaa

Awesome @hxssgaa. I actually tried something similar earlier, but I think there was a small typo in my script that prevented it from picking up the correct yaml.

However, now it works. I have also taken one model "all the way": I can confirm that I get exactly the same MMLU scores on the original google/gemma2-2b-it as on a model that was converted from Kaggle/Flax, stored as a MaxText checkpoint, and converted to HF with the gemma2_orbax_to_hf.py script.

peregilk avatar Mar 07 '25 16:03 peregilk

Has anyone tried converting Gemma 3 (4B) to HuggingFace yet?

I have now taken a Gemma 3 model from Kaggle --> MaxText format (Orbax) --> continued pretraining --> (I would now like to convert to HF, but it seems there is no script available; I am trying to do it myself, with no luck yet)

R4ZZ3 avatar Apr 12 '25 09:04 R4ZZ3

@hxssgaa any chance you could develop similar code to convert a Gemma 3 checkpoint to HF?

salrowili avatar Apr 25 '25 10:04 salrowili

I created such a conversion script based on https://github.com/AI-Hypercomputer/maxtext/blob/f6ebc1662cb944bd7748fb350bba164b13479b68/MaxText/gemma2_orbax_to_hf.py and a bunch of trial and error with Gemini 2.5 Pro in Cursor.

I was then able to run some benchmarks with the converted model, and tested that it would start GRPO finetuning with Unsloth. I can share the script, maybe this evening once I am finished with work.

R4ZZ3 avatar Apr 25 '25 11:04 R4ZZ3

Great @R4ZZ3, I will also test the code once you share it and get back to you with my findings.

salrowili avatar Apr 25 '25 12:04 salrowili

Hi @salrowili

The file can now be found here: https://github.com/R4ZZ3/gemma_3_orbax_to_hf/blob/main/convert_gemma_3_orbax_to_hf.py

R4ZZ3 avatar Apr 25 '25 19:04 R4ZZ3

@gagika can you please take a look?

shralex avatar May 01 '25 13:05 shralex

This has been completed by @YixuanWang-99; please see https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh

shralex avatar Aug 23 '25 18:08 shralex