Convert Orbax ckpt to HuggingFace
The change LGTMs, but without a test I'm skeptical. I don't think this needs to be tested exhaustively, but how could we test it at least somewhat?
@rwitten I have tested locally, but were you thinking of running this at a nightly cadence?
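For example, a minimal end-to-end check could convert a checkpoint and compare logits against the reference HF model on a single prompt. This is only a sketch, not part of this PR; the converted checkpoint path and the tolerance are placeholders.

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

# Placeholder: wherever the conversion script wrote the HF-format weights.
converted_ckpt_dir = "/path/to/converted_hf_checkpoint"

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
reference = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)
converted = LlamaForCausalLM.from_pretrained(converted_ckpt_dir, torch_dtype=torch.bfloat16)

inputs = tokenizer("I love to", return_tensors="pt")
with torch.no_grad():
  ref_logits = reference(**inputs).logits
  conv_logits = converted(**inputs).logits

# Tolerance is a placeholder; bfloat16 round-tripping will not be bit-exact.
assert torch.allclose(ref_logits, conv_logits, atol=1e-2), "converted weights diverge from reference"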
@A9isha Hello, do you by any chance have a script that does the opposite, converting HF to Orbax?
We have the script llama_or_mistral_ckpt.py to convert the original PyTorch Llama2 checkpoint that Meta provides into a MaxText checkpoint.
You can see the usage for Llama2-7b here, for example.
Hi @A9isha, I found two bugs in your conversion code. I have fixed them and validated the weights converted from the MaxText version of Llama3-8B against the HF ones.
The first bug is that the unpermute function is wrong. The original MaxText ckpt script permuted with a step size of 2, so the correct way to undo it is to split the tensor into its even and odd halves, stack them, and reshape:
def unpermute_from_match_maxtext_rope(arr):
  """
  Function to get the RoPE values in correct ordering
  """
  split_size = arr.shape[-1] // 2  # Assuming half for evens, half for odds
  evens = arr[..., :split_size]
  odds = arr[..., split_size:]
  # Stack along a new trailing axis and reshape to re-interleave evens and odds.
  return jax.numpy.stack([evens, odds], axis=len(arr.shape)).reshape(arr.shape)
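As a quick sanity check, the fix inverts a step-of-2 permutation. Below is a standalone NumPy round-trip sketch; permute_to_match_maxtext_rope here is my reading of the forward permutation in llama_or_mistral_ckpt.py (an assumption, not copied from it):

import numpy as np

def permute_to_match_maxtext_rope(arr):
  # Assumed forward permutation: take the interleaved RoPE layout and put
  # all even positions first, then all odd positions.
  evens = arr[..., ::2]
  odds = arr[..., 1::2]
  return np.concatenate((evens, odds), axis=arr.ndim - 1)

def unpermute_from_match_maxtext_rope(arr):
  # Same logic as the fix above, in NumPy for a standalone check.
  split_size = arr.shape[-1] // 2
  evens = arr[..., :split_size]
  odds = arr[..., split_size:]
  return np.stack([evens, odds], axis=len(arr.shape)).reshape(arr.shape)

x = np.arange(2 * 8).reshape(2, 8)
assert np.array_equal(unpermute_from_match_maxtext_rope(permute_to_match_maxtext_rope(x)), x)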
The second bug is related to Q and K. I understand it's easy to make mistakes here because the original LLaMA, LLaMA-HF, and MaxText all store these tensors differently. The correct way is the following: reverse first to the original LLaMA weights, then convert to the HF weights:
hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = torch.tensor(np.asarray(
unpermute_from_match_maxtext_rope(
reverse_scale(
training_state.params['params']["decoder"]["layers"]["self_attention"]["query"]["kernel"][:, layer_int, :, :]
,head_dim
)
)),
dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"].view(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T.view(base_num_query_heads, head_dim // 2, 2, base_num_query_heads * head_dim).transpose(1, 2).reshape(-1, base_num_query_heads * head_dim)
hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor(np.asarray(
unpermute_from_match_maxtext_rope(
training_state.params['params']["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :]
)),
dtype=torch.bfloat16
)
hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"].view(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T.reshape(base_num_kv_heads, head_dim // 2, 2, base_num_query_heads * head_dim).transpose(1, 2).reshape(-1 ,base_num_query_heads * head_dim)
I think this script is fine, and I have been using it quite a lot. It should be updated for Llama3.1, though (whenever that is merged), and maybe also for the 70B models (see the sketch after the loader snippet below)?
"""
Load the model that we are interested in from HuggingFace
"""
if model_size == "llama2-7b":
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
elif model_size == "llama3-8b":
model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
elif model_size == "llama3.1-8b":
model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
elif model_size == "mistral-7b":
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
else:
raise NotImplementedError
return model
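A possible extension for the 70B variants would just add branches to the if/elif chain above, along these lines (a sketch; the Hugging Face repo ids are assumed to follow the same naming pattern as the 7B/8B ones):

# Extra branches for the chain above (sketch only).
elif model_size == "llama2-70b":
  model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf")
elif model_size == "llama3.1-70b":
  model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-70B")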
Any chance this can be merged, @A9isha?