
Error when running a llama-2-70b-qkv-mlperf model

Open · kailashg26 opened this issue 10 months ago · 10 comments

Hello,

I'm trying to run LoRA fine-tuning of the llama-2-70b-qkv model with torchtune, using the llama2 70B_lora.yaml config and the SCROLLS gov_report dataset.

First, I downloaded the model, then used the script below to convert the safetensors files into PyTorch .bin files:

Script:

from safetensors.torch import load_file
import torch
import os

def convert_safetensors_to_bin(input_dir, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    safetensors_files = sorted([
        f for f in os.listdir(input_dir)
        if f.endswith(".safetensors") and "model" in f
    ])

    for idx, filename in enumerate(safetensors_files):
        full_path = os.path.join(input_dir, filename)
        print(f"Loading {full_path}...")

        # Load weights
        weights = load_file(full_path)

        # Output filename format like Hugging Face's: pytorch_model-00001-of-00015.bin
        bin_filename = f"pytorch_model-{idx+1:05d}-of-{len(safetensors_files):05d}.bin"
        torch.save(weights, os.path.join(output_dir, bin_filename))
        print(f"Saved {bin_filename}.")

# Example usage
convert_safetensors_to_bin("safetensors_model_dir", "pytorch_bin_model_dir")

This does produce the PyTorch .bin files. I then edited 70B_lora.yaml to change max_filename to 00029.

But even before switching the dataset from Alpaca to gov_report, I get an error like this:

[rank4]: The above exception was the direct cause of the following exception:

[rank4]: Traceback (most recent call last):
[rank4]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module>
[rank4]:     sys.exit(recipe_main())
[rank4]:   File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank4]:     sys.exit(recipe_main(conf))
[rank4]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 927, in recipe_main
[rank4]:     recipe.setup(cfg=cfg)
[rank4]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 270, in setup
[rank4]:     checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
[rank4]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 204, in load_checkpoint
[rank4]:     checkpoint_dict = self._checkpointer.load_checkpoint()
[rank4]:   File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 667, in load_checkpoint
[rank4]:     converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune(
[rank4]:   File "/workspace/torchtune/torchtune/models/convert_weights.py", line 152, in hf_to_tune
[rank4]:     new_key = get_mapped_key(key, _FROM_HF)
[rank4]:   File "/workspace/torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key
[rank4]:     raise Exception(
[rank4]: Exception: Error converting the state dict. Found unexpected key: "model.layers.0.self_attn.qkv_proj.weight". Please make sure you're loading a checkpoint with the right format.
[rank7]: Traceback (most recent call last):
[rank7]:   File "/workspace/torchtune/torchtune/models/convert_weights.py", line 54, in get_mapped_key
[rank7]:     new_key = mapping_dict[abstract_key]
[rank7]: KeyError: 'model.layers.{}.self_attn.qkv_proj.weight'

[rank7]: The above exception was the direct cause of the following exception:

[rank7]: Traceback (most recent call last):
[rank7]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module>
[rank7]:     sys.exit(recipe_main())
[rank7]:   File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank7]:     sys.exit(recipe_main(conf))
[rank7]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 927, in recipe_main
[rank7]:     recipe.setup(cfg=cfg)
[rank7]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 270, in setup
[rank7]:     checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
[rank7]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 204, in load_checkpoint
[rank7]:     checkpoint_dict = self._checkpointer.load_checkpoint()
[rank7]:   File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 667, in load_checkpoint
[rank7]:     converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune(
[rank7]:   File "/workspace/torchtune/torchtune/models/convert_weights.py", line 152, in hf_to_tune
[rank7]:     new_key = get_mapped_key(key, _FROM_HF)
[rank7]:   File "/workspace/torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key
[rank7]:     raise Exception(
[rank7]: Exception: Error converting the state dict. Found unexpected key: "model.layers.0.self_attn.qkv_proj.weight". Please make sure you're loading a checkpoint with the right format.

Can someone help me figure out what I missed?

kailashg26, Apr 19 '25
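The traceback shows torchtune's key mapping rejecting "model.layers.0.self_attn.qkv_proj.weight". To list every checkpoint key that a mapping does not cover before launching a distributed run, you can read the weight_map from the checkpoint's Hugging Face *.safetensors.index.json file and abstract away the layer indices, similar in spirit to what get_mapped_key does. This is a minimal sketch: abstract_key, find_unmapped_keys and load_weight_map_keys are hypothetical helpers, and STOCK_LLAMA_SUBSET is a small illustrative subset of an unfused Llama mapping, not torchtune's actual table.

```python
import json
import re
from typing import Dict, Iterable, List


def abstract_key(key: str) -> str:
    """Replace numeric path components (layer indices) with '{}', e.g.
    'model.layers.0.self_attn.qkv_proj.weight' -> 'model.layers.{}.self_attn.qkv_proj.weight'."""
    return re.sub(r"\.\d+(?=\.)", ".{}", key)


def find_unmapped_keys(keys: Iterable[str], mapping: Dict[str, str]) -> List[str]:
    """Return the sorted, de-duplicated abstract keys that the mapping does not cover."""
    return sorted({abstract_key(k) for k in keys if abstract_key(k) not in mapping})


def load_weight_map_keys(index_path: str) -> List[str]:
    """Read parameter names from a Hugging Face *.safetensors.index.json file."""
    with open(index_path) as f:
        return list(json.load(f)["weight_map"].keys())


# Illustrative subset of an unfused Llama HF -> torchtune key mapping (assumption, not exhaustive).
STOCK_LLAMA_SUBSET = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
}

if __name__ == "__main__":
    keys = [
        "model.embed_tokens.weight",
        "model.layers.0.self_attn.qkv_proj.weight",
        "model.layers.1.self_attn.qkv_proj.weight",
    ]
    print(find_unmapped_keys(keys, STOCK_LLAMA_SUBSET))
    # ['model.layers.{}.self_attn.qkv_proj.weight']
```

On a real checkpoint you would call find_unmapped_keys(load_weight_map_keys(path), mapping) with path pointing at the checkpoint's index JSON.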

Hey @kailashg26, good question! The error comes from the fact that this model uses a fused QKV projection (qkv_proj), which torchtune does not natively support. Take a look at how we convert weights in Phi3: https://github.com/pytorch/torchtune/blob/main/torchtune/models/phi3/_convert_weights.py.

To be even more precise, you should take the following steps:

  1. Write a function, modeled closely on the Phi3 one, that converts the Llama model from its original state dict to a torchtune-compatible state dict.
  2. If you're working from a fork of torchtune, simply modify the FullModelHFCheckpointer to use this conversion function when loading in a checkpoint for a LLAMA2_QKV model.
  3. Update your config to point to the original safetensors files, and make sure you specify that it's a LLAMA2_QKV model.

joecummings, Apr 19 '25
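The core of step 1 is cutting the fused qkv_proj weight back into separate q, k and v rows (and concatenating them again when saving). With grouped-query attention the three pieces are not the same size: q has num_heads * head_dim rows, while k and v each have num_kv_heads * head_dim rows. For Llama-2-70B (hidden size 8192, 64 attention heads, 8 KV heads, so head_dim = 128) that is 8192 + 1024 + 1024 = 10240 rows, so an even three-way chunk would cut in the wrong places. A minimal sketch, assuming the fused weight is stored as [q; k; v] along dim 0 (the order used by the converter later in this thread); split_fused_qkv and fuse_qkv are hypothetical names:

```python
import torch


def split_fused_qkv(qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int):
    """Split a fused [q; k; v] projection weight along dim 0 into q, k, v."""
    q_rows = num_heads * head_dim
    kv_rows = num_kv_heads * head_dim
    if qkv.shape[0] != q_rows + 2 * kv_rows:
        raise ValueError(f"expected {q_rows + 2 * kv_rows} rows, got {qkv.shape[0]}")
    q, k, v = torch.split(qkv, [q_rows, kv_rows, kv_rows], dim=0)
    return q, k, v


def fuse_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Inverse of split_fused_qkv: stack q, k, v rows back into one weight."""
    return torch.cat([q, k, v], dim=0)


if __name__ == "__main__":
    # Toy sizes with the same 64:8 head ratio as Llama-2-70B but a tiny head_dim.
    num_heads, num_kv_heads, head_dim = 64, 8, 2
    hidden = num_heads * head_dim
    qkv = torch.randn(hidden + 2 * num_kv_heads * head_dim, hidden)
    q, k, v = split_fused_qkv(qkv, num_heads, num_kv_heads, head_dim)
    print(q.shape, k.shape, v.shape)
    assert torch.equal(fuse_qkv(q, k, v), qkv)  # the round trip is lossless
```

The explicit row-count check is there so that a wrong head count fails loudly at load time instead of silently producing mis-sized q/k/v weights.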

Hi @joecummings, do you think this is the correct way to do it?

First, I created a weight-conversion file at torchtune/models/llama2/llama2_qkv.py:

from typing import Dict, Optional
import torch
from torchtune.models.convert_weights import get_mapped_key

_LLAMA2_QKV = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight",  # q_proj will be base
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
    "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
    "model.norm.weight": "norm.scale",
    "lm_head.weight": "output.weight",
}


def llama2_qkv_hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: Optional[int],
    num_kv_heads: Optional[int],
    dim: Optional[int],
) -> Dict[str, torch.Tensor]:
    """
    Convertor from HF state dict to Torchtune state dict for LLaMA2 models with fused QKV and gate_up.
    """
    converted_state_dict = {}

    if dim is not None:
        if num_heads is None or num_kv_heads is None:
            raise ValueError("LLaMA2 with GQA requires dim, num_heads and num_kv_heads.")
        q_dim = dim
        kv_dim = q_dim * num_kv_heads // num_heads
    else:
        q_dim = kv_dim = None

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, _LLAMA2_QKV)

        if "qkv" in key:
            if q_dim is not None:
                q, k, v = torch.split(value, [q_dim, kv_dim, kv_dim], dim=0)
            else:
                q, k, v = value.chunk(3, dim=0)
            converted_state_dict[new_key] = q
            converted_state_dict[new_key.replace("q_proj", "k_proj")] = k
            converted_state_dict[new_key.replace("q_proj", "v_proj")] = v

        elif "gate_up_proj" in key:
            gate, up = value.chunk(2, dim=0)
            converted_state_dict[new_key] = gate
            converted_state_dict[new_key.replace("w1", "w3")] = up

        else:
            converted_state_dict[new_key] = value

    return converted_state_dict

Should I add this to torchtune/training/checkpointing/_checkpointer.py?


elif model_type == "LLAMA2_QKV":
    from torchtune.models.llama2.llama2_qkv import llama2_qkv_hf_to_tune
    state_dict = llama2_qkv_hf_to_tune(
        state_dict,
        num_heads=config["num_attention_heads"],
        num_kv_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
        dim=config["hidden_size"],
    )

In the config.yaml file:

model_type: LLAMA2_QKV
checkpoint_path: /path/to/llama2-70b-fused-qkv-mlperf/*.safetensors

# Model config
num_attention_heads: 64
num_key_value_heads: 8
hidden_size: 8192

Please let me know if this is the correct way to do it. Thanks!

kailashg26, Apr 19 '25
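One more difference worth checking in the converter above against the stock Llama path: torchtune's regular hf_to_tune in torchtune/models/convert_weights.py appears to re-permute the q_proj and k_proj rows of Hugging Face Llama checkpoints (HF stores them in the layout its rotate-half RoPE expects, while torchtune's Llama RoPE works on interleaved pairs), and tune_to_hf applies the inverse. The Phi3 converter used as a template may not do this, so it is worth comparing against that file for your torchtune version. A standalone sketch of the permutation and its inverse, written here for illustration (hf_to_meta_permute and meta_to_hf_permute are hypothetical names); you would apply it to q after the split with num_heads, and to k with num_kv_heads:

```python
import torch


def hf_to_meta_permute(w: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    """Reorder q/k projection rows, per head, from the HF (rotate-half) layout
    to the interleaved layout. Intended to mirror the _permute step in hf_to_tune."""
    in_dim = w.shape[1]
    return (
        w.reshape(n_heads, 2, head_dim // 2, in_dim)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, in_dim)
    )


def meta_to_hf_permute(w: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    """Inverse of hf_to_meta_permute (interleaved -> HF rotate-half layout)."""
    in_dim = w.shape[1]
    return (
        w.reshape(n_heads, head_dim // 2, 2, in_dim)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, in_dim)
    )


if __name__ == "__main__":
    # Label each row with its index so the reordering is visible.
    rows = torch.arange(8, dtype=torch.float32).unsqueeze(1)  # 2 heads x head_dim 4
    print(hf_to_meta_permute(rows, n_heads=2, head_dim=4)[:, 0].tolist())
    # [0.0, 2.0, 1.0, 3.0, 4.0, 6.0, 5.0, 7.0]
```

Within each head, the first and second halves of the rows are interleaved, which is the pairing an interleaved rotary embedding rotates together.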

Also, @joecummings, could you give some insight into how I should feed in a dataset that has input IDs and labels?

kailashg26, Apr 19 '25

Your converter looks right, but you don't need to specify the extra fields like num_attention_heads in the config. Those are picked up automatically from the Hugging Face config.

joecummings, Apr 21 '25
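For reference, the values the converter needs correspond to standard fields of a Hugging Face Llama config.json (hidden_size, num_attention_heads, num_key_value_heads), which is presumably where the checkpointer reads them from. A small sketch of deriving the fused-QKV split sizes from those fields, falling back to num_attention_heads when num_key_value_heads is absent (as in the question's snippet); qkv_split_sizes and load_hf_config are hypothetical helpers:

```python
import json
from typing import Dict, Tuple


def qkv_split_sizes(hf_config: Dict) -> Tuple[int, int, int]:
    """Return (q_rows, k_rows, v_rows) of a fused qkv_proj weight from HF config fields."""
    hidden = hf_config["hidden_size"]
    num_heads = hf_config["num_attention_heads"]
    num_kv_heads = hf_config.get("num_key_value_heads", num_heads)  # no GQA -> same as heads
    head_dim = hf_config.get("head_dim", hidden // num_heads)  # optional field in some configs
    return num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim


def load_hf_config(path: str) -> Dict:
    """Load a Hugging Face config.json as a plain dict."""
    with open(path) as f:
        return json.load(f)


if __name__ == "__main__":
    llama2_70b = {"hidden_size": 8192, "num_attention_heads": 64, "num_key_value_heads": 8}
    print(qkv_split_sizes(llama2_70b))  # (8192, 1024, 1024)
```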

Also, @joecummings could you give some insights into how I should input the dataset, which has input IDs and labels?

Can you provide an example? By input IDs, do you mean they're already tokenized?

joecummings, Apr 21 '25

Hi @joecummings, this is the dataset I'm trying to use: https://huggingface.co/datasets/regisss/scrolls_gov_report_preprocessed_mlperf_2. It looks like this format isn't supported, right? Could you let me know how to use this dataset with the llama2-qkv model?

kailashg26, Apr 21 '25

Hi @joecummings, I'm getting an error when running llama-2-70b-qkv.

Here is my llama2_qkv.py in torchtune/models/llama2/:

from typing import Dict, Optional
import torch
from torchtune.models.convert_weights import get_mapped_key


_LLAMA2_QKV = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
    "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
    "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
    "model.norm.weight": "norm.scale",
    "lm_head.weight": "output.weight",
}



def llama2_qkv_hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: Optional[int],
    num_kv_heads: Optional[int],
    dim: Optional[int],
) -> Dict[str, torch.Tensor]:
    """
    Convertor from HF state dict to Torchtune state dict for LLaMA2 models with fused QKV and gate_up.
    """
    converted_state_dict = {}

    if dim is not None:
        if num_heads is None or num_kv_heads is None:
            raise ValueError("LLaMA2 with GQA requires dim, num_heads and num_kv_heads.")
        q_dim = dim
        kv_dim = q_dim * num_kv_heads // num_heads
    else:
        q_dim = kv_dim = None

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, _LLAMA2_QKV)

        if "qkv" in key:
            if q_dim is not None:
                q, k, v = torch.split(value, [q_dim, kv_dim, kv_dim], dim=0)
            else:
                q, k, v = value.chunk(3, dim=0)
            converted_state_dict[new_key] = q
            converted_state_dict[new_key.replace("q_proj", "k_proj")] = k
            converted_state_dict[new_key.replace("q_proj", "v_proj")] = v

        elif "gate_up_proj" in key:
            gate, up = value.chunk(2, dim=0)
            converted_state_dict[new_key] = gate
            converted_state_dict[new_key.replace("w1", "w3")] = up

        else:
            converted_state_dict[new_key] = value

    return converted_state_dict

Error:

INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Compiling loss with torch.compile...
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|10|Loss: 4.511970043182373: 100%|██████████| 10/10 [05:57<00:00, 34.43s/it]
INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Getting full model state dict took 13.51 secs
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/torchtune/torchtune/models/convert_weights.py", line 57, in get_mapped_key
[rank0]:     new_key = mapping_dict[key]
[rank0]: KeyError: 'tok_embeddings.weight'

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:   File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 928, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 894, in train
[rank0]:     self.save_checkpoint(epoch=curr_epoch)
[rank0]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 737, in save_checkpoint
[rank0]:     self._checkpointer.save_checkpoint(
[rank0]:   File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 781, in save_checkpoint
[rank0]:     state_dict = llama2_qkv_hf_to_tune(
[rank0]:   File "/workspace/torchtune/torchtune/models/llama2/llama2_qkv.py", line 63, in llama2_qkv_hf_to_tune
[rank0]:     new_key = get_mapped_key(key, _LLAMA2_QKV)
[rank0]:   File "/workspace/torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key
[rank0]:     raise Exception(
[rank0]: Exception: Error converting the state dict. Found unexpected key: "tok_embeddings.weight". Please make sure you're loading a checkpoint with the right format.
1|10|Loss: 4.511970043182373: 100%|██████████| 10/10 [07:04<00:00, 42.48s/it]
[rank2]: Traceback (most recent call last):
[rank2]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module>
[rank2]:     sys.exit(recipe_main())
[rank2]:   File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank2]:     sys.exit(recipe_main(conf))
[rank2]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 928, in recipe_main
[rank2]:     recipe.train()
[rank2]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 894, in train
[rank2]:     self.save_checkpoint(epoch=curr_epoch)
[rank2]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 745, in save_checkpoint
[rank2]:     torch.distributed.barrier()
[rank2]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4761, in barrier
[rank2]:     work.wait()

Could you please let me know if I'm missing anything here?

kailashg26, Apr 21 '25

Ah, you'll also want to include a tune_to_hf function that does the opposite mapping. This isn't strictly necessary, but it will let you use the model on the Hugging Face Hub in exactly the same format as the original Llama2 QKV model.

joecummings, Apr 22 '25
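In the earlier traceback, save_checkpoint passes the torchtune-format state dict (keys such as tok_embeddings.weight) to llama2_qkv_hf_to_tune, so the save path needs the reverse conversion instead. A minimal sketch of that direction for the fused-QKV case, assuming the original checkpoint stores q, k and v fused as [q; k; v] in qkv_proj (whether gate/up are fused as well depends on your checkpoint, so check its index.json). HF_TO_TUNE is a hypothetical, abbreviated mapping and map_key and tune_to_hf_fused_qkv are hypothetical helpers; a real version needs every Llama key and, if your hf_to_tune permutes q/k, the inverse permutation too.

```python
import re
from typing import Dict

import torch

# Hypothetical, abbreviated HF -> torchtune mapping for a fused-QKV Llama checkpoint.
HF_TO_TUNE = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
    "model.norm.weight": "norm.scale",
    "lm_head.weight": "output.weight",
}
TUNE_TO_HF = {v: k for k, v in HF_TO_TUNE.items()}  # safe only because all values are unique


def map_key(key: str, mapping: Dict[str, str]) -> str:
    """Look up a key after abstracting its layer index, then re-insert the index."""
    match = re.search(r"\.(\d+)\.", key)
    if match is None:
        return mapping[key]
    abstract = key[: match.start()] + ".{}." + key[match.end():]
    return mapping[abstract].format(match.group(1))


def tune_to_hf_fused_qkv(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Convert torchtune keys back to HF keys, re-fusing q/k/v into qkv_proj."""
    out = {}
    for key, value in state_dict.items():
        if ".attn.k_proj." in key or ".attn.v_proj." in key:
            continue  # folded into qkv_proj together with the matching q_proj
        if ".attn.q_proj." in key:
            k = state_dict[key.replace("q_proj", "k_proj")]
            v = state_dict[key.replace("q_proj", "v_proj")]
            value = torch.cat([value, k, v], dim=0)
        out[map_key(key, TUNE_TO_HF)] = value
    return out


if __name__ == "__main__":
    sd = {
        "tok_embeddings.weight": torch.zeros(10, 4),
        "layers.0.attn.q_proj.weight": torch.zeros(4, 4),
        "layers.0.attn.k_proj.weight": torch.zeros(2, 4),
        "layers.0.attn.v_proj.weight": torch.zeros(2, 4),
    }
    print({k: tuple(v.shape) for k, v in tune_to_hf_fused_qkv(sd).items()})
    # {'model.embed_tokens.weight': (10, 4), 'model.layers.0.self_attn.qkv_proj.weight': (8, 4)}
```

Every key this produces must exist in the original checkpoint's weight map, since the HF checkpointer uses that map to decide which output shard each tensor goes to.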

This is a good question: it looks like your data is already tokenized! We typically work with datasets that still need tokenization, but never fear, this will still work with torchtune.

To do this, I would create a simple new dataset:


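from datasets import load_dataset
from torch.utils.data import Dataset


class PreTokenizedDataset(Dataset):
    def __init__(self, tokenizer, source, split="train"):
        # torchtune recipes instantiate the dataset with the tokenizer as the first
        # positional argument; the data is already tokenized, so it goes unused here.
        # Selecting a split makes load_dataset return a Dataset rather than a DatasetDict.
        self._data = load_dataset(source, split=split)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        sample = self._data[index]
        return {"tokens": sample["input_ids"], "labels": sample["labels"]}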

Then you can point the dataset component in your config at PreTokenizedDataset and pass whichever source you want!

joecummings, Apr 22 '25
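To sanity-check the pre-tokenized wrapper without downloading anything, a variant that accepts any indexable collection of records behaves the same way. InMemoryPreTokenizedDataset is a hypothetical name, the sample values are toy data, and the length check is an extra safeguard added here, assuming the dataset follows the usual Hugging Face convention of labels aligned one-to-one with input_ids (with -100 marking ignored positions):

```python
from typing import Any, Dict, List, Sequence

from torch.utils.data import Dataset


class InMemoryPreTokenizedDataset(Dataset):
    """Hypothetical variant of PreTokenizedDataset that wraps already-loaded records."""

    def __init__(self, tokenizer: Any, records: Sequence[Dict[str, List[int]]]):
        # tokenizer is accepted only so the signature matches how a recipe would call it
        self._data = records

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        sample = self._data[index]
        tokens, labels = list(sample["input_ids"]), list(sample["labels"])
        if len(tokens) != len(labels):
            raise ValueError(f"sample {index}: {len(tokens)} tokens vs {len(labels)} labels")
        return {"tokens": tokens, "labels": labels}


if __name__ == "__main__":
    records = [{"input_ids": [1, 15, 27, 2], "labels": [-100, -100, 27, 2]}]
    ds = InMemoryPreTokenizedDataset(tokenizer=None, records=records)
    print(len(ds), ds[0])
```

Passing load_dataset(source, split="train") as records should give the same samples as PreTokenizedDataset, since indexing a Hugging Face Dataset by integer returns a dict per row.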

Hi @joecummings @felipemello1 @ebsmothers

I get this error after adding the opposite (tune_to_hf) mapping:

Error:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 1005, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:   File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 1000, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 902, in train
[rank0]:     self.save_checkpoint(epoch=curr_epoch)
[rank0]:   File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 758, in save_checkpoint
[rank0]:     self._checkpointer.save_checkpoint(
[rank0]:   File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 839, in save_checkpoint
[rank0]:     cpt_idx = self._weight_map[key]
[rank0]: KeyError: 'model.layers.0.mlp.gate_up_proj.weight'

Code (llama2_qkv.py in torchtune/models/llama2/):

from typing import Dict, Optional
import torch
from torchtune.models.convert_weights import get_mapped_key


_LLAMA2_QKV = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
    "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
    "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
    "model.norm.weight": "norm.scale",
    "model.layers.0.mlp.gate_up_proj.weight": "layers.0.mlp.w1.weight",
    "lm_head.weight": "output.weight",
}



def llama2_qkv_hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: Optional[int],
    num_kv_heads: Optional[int],
    dim: Optional[int],
) -> Dict[str, torch.Tensor]:
    """
    Convertor from HF state dict to Torchtune state dict for LLaMA2 models with fused QKV and gate_up.
    """
    converted_state_dict = {}

    if dim is not None:
        if num_heads is None or num_kv_heads is None:
            raise ValueError("LLaMA2 with GQA requires dim, num_heads and num_kv_heads.")
        q_dim = dim
        kv_dim = q_dim * num_kv_heads // num_heads
    else:
        q_dim = kv_dim = None

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, _LLAMA2_QKV)

        if "qkv" in key:
            if q_dim is not None:
                q, k, v = torch.split(value, [q_dim, kv_dim, kv_dim], dim=0)
            else:
                q, k, v = value.chunk(3, dim=0)
            converted_state_dict[new_key] = q
            converted_state_dict[new_key.replace("q_proj", "k_proj")] = k
            converted_state_dict[new_key.replace("q_proj", "v_proj")] = v

        elif "gate_up_proj" in key:
            gate, up = value.chunk(2, dim=0)
            converted_state_dict[new_key] = gate
            converted_state_dict[new_key.replace("w1", "w3")] = up

        else:
            converted_state_dict[new_key] = value

    return converted_state_dict

def tune_to_hf_llama2_qkv(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convertor from torchtune state dict to HF state dict. This handles:
    - Fusing q,k and v matrix
    - Fusing gate and up projection matrix
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _LLAMA2_QKV.items()}

    for key, value in state_dict.items():
        if "k_proj" in key or "v_proj" in key or "w3" in key:
            # these keys are accounted for separately and should be skipped
            continue
        new_key = get_mapped_key(key, inverted_mapping_dict)

        if "q_proj" in key:
            q = value
            k = state_dict[key.replace("q_proj", "k_proj")]
            v = state_dict[key.replace("q_proj", "v_proj")]
            qkv = torch.cat([q, k, v], dim=0)
            # q_proj maps to qkv_proj; no need to string replace
            converted_state_dict[new_key] = qkv

        elif "w1" in key:
            gate_proj = value
            up_proj = state_dict[key.replace("w1", "w3")]
            gate_up_proj = torch.cat([gate_proj, up_proj], dim=0)
            # w1 maps to gate_up_proj; no need to string replace
            converted_state_dict[new_key] = gate_up_proj

        else:
            converted_state_dict[new_key] = value
    return converted_state_dict

kailashg26, May 01 '25
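The KeyError on 'model.layers.0.mlp.gate_up_proj.weight' looks consistent with how the inverted mapping is built: {v: k for k, v in _LLAMA2_QKV.items()} keeps only the last HF key for each torchtune key, and both gate_proj and gate_up_proj map to layers.{}.mlp.w1.weight, so w1 inverts to gate_up_proj. The checkpointer then looks that key up in the original checkpoint's weight map, which (judging by the separate gate_proj/up_proj entries the mapping needed) has no fused gate_up_proj. The snippet below demonstrates the collision and a guard that refuses ambiguous inversions; invert_mapping is a hypothetical helper. Assuming the checkpoint really has separate gate/up weights, the corresponding fix would be to drop the gate_up_proj entries (including the literal layer-0 one) from the mapping and stop fusing w1/w3 in tune_to_hf_llama2_qkv, leaving only the QKV fusion.

```python
from collections import Counter
from typing import Dict


def invert_mapping(mapping: Dict[str, str]) -> Dict[str, str]:
    """Invert an HF -> torchtune key mapping, refusing to silently drop entries
    when several HF keys map to the same torchtune key."""
    counts = Counter(mapping.values())
    dupes = sorted(tune_key for tune_key, n in counts.items() if n > 1)
    if dupes:
        raise ValueError(f"ambiguous torchtune keys: {dupes}")
    return {tune_key: hf_key for hf_key, tune_key in mapping.items()}


if __name__ == "__main__":
    with_gate_up = {
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
        "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
    }
    # A plain dict comprehension keeps only the last HF key for w1:
    print({v: k for k, v in with_gate_up.items()}["layers.{}.mlp.w1.weight"])
    # model.layers.{}.mlp.gate_up_proj.weight

    # Keeping only keys that exist in the checkpoint makes the inversion unambiguous.
    separate_gate_up = {k: v for k, v in with_gate_up.items() if "gate_up_proj" not in k}
    print(invert_mapping(separate_gate_up)["layers.{}.mlp.w1.weight"])
    # model.layers.{}.mlp.gate_proj.weight
```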