
Convert Orbax ckpt to HuggingFace

Open · A9isha opened this issue 1 year ago • 5 comments

A9isha · Apr 09 '24 21:04

Change LGTMs but without a test I'm skeptical. I don't think this needs to be tested as exhaustively but how could we test it somewhat?

@rwitten I have tested it locally, but were you thinking of running this at a nightly cadence?

A9isha · Apr 11 '24 19:04

@A9isha Hello, do you by any chance have a script that does the opposite, converting HF to Orbax?

thiagolaitz · May 06 '24 18:05

@A9isha Hello, do you by any chance have a script that does the opposite, converting HF to Orbax?

We have the script llama_or_mistral_ckpt.py to convert the original PyTorch Llama2 checkpoint that Meta provides into a MaxText checkpoint.

You can see the usage here for Llama2-7b, for example.
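
For reference, the invocation looks roughly like this (flag names quoted from memory, so double-check them against the script's argument parser; the paths are placeholders):

    JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py \
        --base-model-path <path/to/meta/llama2-7b/checkpoint> \
        --maxtext-model-path <gs://your-bucket/path/for/maxtext/checkpoint> \
        --model-size llama2-7b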

A9isha · May 06 '24 18:05

Hi @A9isha, I found two bugs in your conversion code. I have fixed them and validated that the weights converted from the MaxText version of Llama3-8B match the HF ones.

The first bug is that the unpermute function is wrong: the original MaxText checkpoint script used a step size of 2, so the correct inverse is to split the tensor into its even and odd halves, stack them, and reshape:

import jax

def unpermute_from_match_maxtext_rope(arr):
  """
  Function to get the RoPE values in the correct ordering
  """
  split_size = arr.shape[-1] // 2  # first half holds the even positions, second half the odd positions
  evens = arr[..., :split_size]
  odds = arr[..., split_size:]
  # interleave the two halves back together along a new trailing axis, then flatten it away
  return jax.numpy.stack([evens, odds], axis=len(arr.shape)).reshape(arr.shape)
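
To sanity-check the fix, here is a minimal round-trip sketch using the function above. It assumes the forward permutation in llama_or_mistral_ckpt.py is the step-size-2 slicing mentioned earlier (even positions first, then odd positions); the helper below is written from that assumption rather than copied from the script:

    import jax
    import numpy as np

    def permute_to_match_maxtext_rope(arr):
        # Assumed forward permutation: take the even positions, then the odd positions.
        return jax.numpy.concatenate((arr[..., ::2], arr[..., 1::2]), axis=arr.ndim - 1)

    x = np.arange(2 * 3 * 8, dtype=np.float32).reshape(2, 3, 8)
    roundtrip = unpermute_from_match_maxtext_rope(permute_to_match_maxtext_rope(x))
    assert np.allclose(np.asarray(roundtrip), x)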

The second bug is related to Q and K. I understand it is easy to make mistakes here because the original LLaMA, LLaMA-HF, and MaxText all store these tensors differently; the correct way is to first reverse back to the original LLaMA weights and then convert them to the HF layout:

    hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = torch.tensor(np.asarray(
        unpermute_from_match_maxtext_rope(
          reverse_scale(
            training_state.params['params']["decoder"]["layers"]["self_attention"]["query"]["kernel"][:, layer_int, :, :]
            ,head_dim
            )
        )),
        dtype=torch.bfloat16
    )
    hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"].view(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T.view(base_num_query_heads, head_dim // 2, 2, base_num_query_heads * head_dim).transpose(1, 2).reshape(-1, base_num_query_heads * head_dim)

    hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor(np.asarray(
        unpermute_from_match_maxtext_rope(
          training_state.params['params']["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :]
          )),
        dtype=torch.bfloat16
    )
    hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"].view(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T.reshape(base_num_kv_heads, head_dim // 2, 2, base_num_query_heads * head_dim).transpose(1, 2).reshape(-1 ,base_num_query_heads * head_dim)

hxssgaa · May 10 '24 08:05

I think this script is fine, and I have been using it quite a lot. It should be updated for Llama3.1 though (whenever that is merged). And maybe also the 70B models?

    """
    Load the model that we are interested in from HuggingFace
    """
    if model_size == "llama2-7b":
        model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    elif model_size == "llama3-8b":
        model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
    elif model_size == "llama3.1-8b":
        model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
    elif model_size == "mistral-7b":
        model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
    else:
        raise NotImplementedError

    return model
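
Once the per-layer tensors have been filled in, the tail of the flow is roughly the following sketch (hf_model_params is the dict built up above; output_dir is a hypothetical destination path):

    # Copy the converted tensors into the freshly loaded HF model, then write it out
    # in HuggingFace format so it can be reloaded later with from_pretrained().
    model.load_state_dict(hf_model_params)
    model.save_pretrained(output_dir)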

Any chance this can be merged, @A9isha?

peregilk · Sep 08 '24 10:09