Error when running a llama-2-70b-qkv-mlperf model
Hello,
I'm trying to run the llama-2-70b-qkv model on torchtune for LoRA finetuning, using the script from 70B_lora.yaml in llama2 with the scrolls gov_report dataset.
So, first I downloaded the model, then used the script below to convert it into PyTorch .bin format:
Script:
from safetensors.torch import load_file
import torch
import os
def convert_safetensors_to_bin(input_dir, output_dir):
    """Convert every model ``*.safetensors`` shard in *input_dir* into a
    Hugging Face style ``pytorch_model-XXXXX-of-YYYYY.bin`` file in
    *output_dir* (created if missing)."""
    os.makedirs(output_dir, exist_ok=True)
    # Only shards whose name contains "model" count; sort for stable numbering.
    shard_names = sorted(
        name
        for name in os.listdir(input_dir)
        if name.endswith(".safetensors") and "model" in name
    )
    total = len(shard_names)
    for shard_no, shard_name in enumerate(shard_names, start=1):
        src_path = os.path.join(input_dir, shard_name)
        print(f"Loading {src_path}...")
        # Load the shard's tensors from the safetensors container.
        tensors = load_file(src_path)
        # Output filename format like Hugging Face's: pytorch_model-00001-of-00015.bin
        out_name = f"pytorch_model-{shard_no:05d}-of-{total:05d}.bin"
        torch.save(tensors, os.path.join(output_dir, out_name))
        print(f"Saved {out_name}.")

# Example usage
convert_safetensors_to_bin("safetensors_model_dir", "pytorch_bin_model_dir")
The conversion does produce the PyTorch .bin files, and then I edit 70B_lora.yaml to change the max_filename to 00029.
But then without even trying to change the dataset from alpaca to govt scroll, I get an error like this:
[rank4]: The above exception was the direct cause of the following exception:
[rank4]: Traceback (most recent call last):
[rank4]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module>
[rank4]: sys.exit(recipe_main())
[rank4]: File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank4]: sys.exit(recipe_main(conf))
[rank4]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 927, in recipe_main
[rank4]: recipe.setup(cfg=cfg)
[rank4]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 270, in setup
[rank4]: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
[rank4]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 204, in load_checkpoint
[rank4]: checkpoint_dict = self._checkpointer.load_checkpoint()
[rank4]: File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 667, in load_checkpoint
[rank4]: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune(
[rank4]: File "/workspace/torchtune/torchtune/models/convert_weights.py", line 152, in hf_to_tune
[rank4]: new_key = get_mapped_key(key, _FROM_HF)
[rank4]: File "/workspace/torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key
[rank4]: raise Exception(
[rank4]: Exception: Error converting the state dict. Found unexpected key: "model.layers.0.self_attn.qkv_proj.weight". Please make sure you're loading a checkpoint with the right format.
[rank7]: Traceback (most recent call last):
[rank7]: File "/workspace/torchtune/torchtune/models/convert_weights.py", line 54, in get_mapped_key
[rank7]: new_key = mapping_dict[abstract_key]
[rank7]: KeyError: 'model.layers.{}.self_attn.qkv_proj.weight'
[rank7]: The above exception was the direct cause of the following exception:
[rank7]: Traceback (most recent call last):
[rank7]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module>
[rank7]: sys.exit(recipe_main())
[rank7]: File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank7]: sys.exit(recipe_main(conf))
[rank7]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 927, in recipe_main
[rank7]: recipe.setup(cfg=cfg)
[rank7]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 270, in setup
[rank7]: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
[rank7]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 204, in load_checkpoint
[rank7]: checkpoint_dict = self._checkpointer.load_checkpoint()
[rank7]: File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 667, in load_checkpoint
[rank7]: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune(
[rank7]: File "/workspace/torchtune/torchtune/models/convert_weights.py", line 152, in hf_to_tune
[rank7]: new_key = get_mapped_key(key, _FROM_HF)
[rank7]: File "/workspace/torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key
[rank7]: raise Exception(
[rank7]: Exception: Error converting the state dict. Found unexpected key: "model.layers.0.self_attn.qkv_proj.weight". Please make sure you're loading a checkpoint with the right format.
Can someone help me if I missed anything here?
Hey @kailashg26 - good question! The error here comes from the fact that the model utilizes a fused QKV while torchtune does not natively support this. You should take a look at how we convert weights in Phi3: https://github.com/pytorch/torchtune/blob/main/torchtune/models/phi3/_convert_weights.py.
To be even more precise, you should take the following steps:
- Write a function very similar to the Phi3 one that converts the Llama model from its original state dict to the torchtune-compatible state dict.
- If you're working from a fork of torchtune, simply modify the FullModelHFCheckpointer to use this conversion function when loading in a checkpoint for a LLAMA2_QKV model.
- Update your config to point to the original safetensors files and make sure that you specify that it's a LLAMA2_QKV model
Hi @joecummings, Do you think is this the correct way to do it?
First, I create a convert weights file in torchtune/models/llama2/llama2_qkv.py with
from typing import Dict, Optional
import torch
from torchtune.models.convert_weights import get_mapped_key
_LLAMA2_QKV = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight", # q_proj will be base
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
"model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
"model.norm.weight": "norm.scale",
"lm_head.weight": "output.weight",
}
def llama2_qkv_hf_to_tune(
state_dict: Dict[str, torch.Tensor],
num_heads: Optional[int],
num_kv_heads: Optional[int],
dim: Optional[int],
) -> Dict[str, torch.Tensor]:
"""
Convertor from HF state dict to Torchtune state dict for LLaMA2 models with fused QKV and gate_up.
"""
converted_state_dict = {}
if dim is not None:
if num_heads is None or num_kv_heads is None:
raise ValueError("LLaMA2 with GQA requires dim, num_heads and num_kv_heads.")
q_dim = dim
kv_dim = q_dim * num_kv_heads // num_heads
else:
q_dim = kv_dim = None
for key, value in state_dict.items():
new_key = get_mapped_key(key, _LLAMA2_QKV)
if "qkv" in key:
if q_dim is not None:
q, k, v = torch.split(value, [q_dim, kv_dim, kv_dim], dim=0)
else:
q, k, v = value.chunk(3, dim=0)
converted_state_dict[new_key] = q
converted_state_dict[new_key.replace("q_proj", "k_proj")] = k
converted_state_dict[new_key.replace("q_proj", "v_proj")] = v
elif "gate_up_proj" in key:
gate, up = value.chunk(2, dim=0)
converted_state_dict[new_key] = gate
converted_state_dict[new_key.replace("w1", "w3")] = up
else:
converted_state_dict[new_key] = value
return converted_state_dict
In the torchtune/training/checkpointing/_checkpointer.py file, should I add this?
elif model_type == "LLAMA2_QKV":
from torchtune.models.convert_weights.llama2_qkv import llama2_qkv_hf_to_tune
state_dict = llama2_qkv_hf_to_tune(
state_dict,
num_heads=config["num_attention_heads"],
num_kv_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
dim=config["hidden_size"],
)
In the config.yaml file:
model_type: LLAMA2_QKV
checkpoint_path: /path/to/llama2-70b-fused-qkv-mlperf/*.safetensors
# Model config
num_attention_heads: 64
num_key_value_heads: 8
hidden_size: 8192
Please let me know if this is the correct way to do it? Thanks!
Also, @joecummings could you give some insights into how I should input the dataset, which has input IDs and labels?
Hi @joecummings, Do you think is this the correct way to do it?
First, I create a convert weights file in torchtune/models/llama2/llama2_qkv.py with
from typing import Dict, Optional import torch from torchtune.models.convert_weights import get_mapped_key _LLAMA2_QKV = { "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight", # q_proj will be base "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", "model.norm.weight": "norm.scale", "lm_head.weight": "output.weight", } def llama2_qkv_hf_to_tune( state_dict: Dict[str, torch.Tensor], num_heads: Optional[int], num_kv_heads: Optional[int], dim: Optional[int], ) -> Dict[str, torch.Tensor]: """ Convertor from HF state dict to Torchtune state dict for LLaMA2 models with fused QKV and gate_up. """ converted_state_dict = {} if dim is not None: if num_heads is None or num_kv_heads is None: raise ValueError("LLaMA2 with GQA requires dim, num_heads and num_kv_heads.") q_dim = dim kv_dim = q_dim * num_kv_heads // num_heads else: q_dim = kv_dim = None for key, value in state_dict.items(): new_key = get_mapped_key(key, _LLAMA2_QKV) if "qkv" in key: if q_dim is not None: q, k, v = torch.split(value, [q_dim, kv_dim, kv_dim], dim=0) else: q, k, v = value.chunk(3, dim=0) converted_state_dict[new_key] = q converted_state_dict[new_key.replace("q_proj", "k_proj")] = k converted_state_dict[new_key.replace("q_proj", "v_proj")] = v elif "gate_up_proj" in key: gate, up = value.chunk(2, dim=0) converted_state_dict[new_key] = gate converted_state_dict[new_key.replace("w1", "w3")] = up else: converted_state_dict[new_key] = value return converted_state_dictIn the torchtune/training/checkpointing/_checkpointer.py file, should I add this?
elif model_type == "LLAMA2_QKV": from torchtune.models.convert_weights.llama2_qkv import llama2_qkv_hf_to_tune state_dict = llama2_qkv_hf_to_tune( state_dict, num_heads=config["num_attention_heads"], num_kv_heads=config.get("num_key_value_heads", config["num_attention_heads"]), dim=config["hidden_size"], )In the config.yaml file:
model_type: LLAMA2_QKV checkpoint_path: /path/to/llama2-70b-fused-qkv-mlperf/*.safetensors # Model config num_attention_heads: 64 num_key_value_heads: 8 hidden_size: 8192Please let me know if this is the correct way to do it? Thanks!
This looks right, but you don't need to specify the extra things like num_attention_heads in the config. Those will be picked up from the Hugging Face config automatically.
Also, @joecummings could you give some insights into how I should input the dataset, which has input IDs and labels?
Can you provide an example? By input IDs, do you mean they're already tokenized?
Hi @joecummings , this is the dataset https://huggingface.co/datasets/regisss/scrolls_gov_report_preprocessed_mlperf_2 I'm trying to use. It looks like this format is not supported, right? Could you let me know how to use this dataset with the llama2-qkv model?
Hi @joecummings , so I get some error when running llama-270B-qkv
llama2_qkv.py code in torchtune/models/llama2/
from typing import Dict, Optional
import torch
from torchtune.models.convert_weights import get_mapped_key
_LLAMA2_QKV = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
"model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
"model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
"model.norm.weight": "norm.scale",
"lm_head.weight": "output.weight",
}
def llama2_qkv_hf_to_tune(
state_dict: Dict[str, torch.Tensor],
num_heads: Optional[int],
num_kv_heads: Optional[int],
dim: Optional[int],
) -> Dict[str, torch.Tensor]:
"""
Convertor from HF state dict to Torchtune state dict for LLaMA2 models with fused QKV and gate_up.
"""
converted_state_dict = {}
if dim is not None:
if num_heads is None or num_kv_heads is None:
raise ValueError("LLaMA2 with GQA requires dim, num_heads and num_kv_heads.")
q_dim = dim
kv_dim = q_dim * num_kv_heads // num_heads
else:
q_dim = kv_dim = None
for key, value in state_dict.items():
new_key = get_mapped_key(key, _LLAMA2_QKV)
if "qkv" in key:
if q_dim is not None:
q, k, v = torch.split(value, [q_dim, kv_dim, kv_dim], dim=0)
else:
q, k, v = value.chunk(3, dim=0)
converted_state_dict[new_key] = q
converted_state_dict[new_key.replace("q_proj", "k_proj")] = k
converted_state_dict[new_key.replace("q_proj", "v_proj")] = v
elif "gate_up_proj" in key:
gate, up = value.chunk(2, dim=0)
converted_state_dict[new_key] = gate
converted_state_dict[new_key.replace("w1", "w3")] = up
else:
converted_state_dict[new_key] = value
return converted_state_dict
Error:
NFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Compiling loss with torch.compile...
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|10|Loss: 4.511970043182373: 100%|██████████| 10/10 [05:57<00:00, 34.43s/it]INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Getting full model state dict took 13.51 secs
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/torchtune/torchtune/models/convert_weights.py", line 57, in get_mapped_key
[rank0]: new_key = mapping_dict[key]
[rank0]: KeyError: 'tok_embeddings.weight'
[rank0]: The above exception was the direct cause of the following exception:
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module>
[rank0]: sys.exit(recipe_main())
[rank0]: File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]: sys.exit(recipe_main(conf))
[rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 928, in recipe_main
[rank0]: recipe.train()
[rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 894, in train
[rank0]: self.save_checkpoint(epoch=curr_epoch)
[rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 737, in save_checkpoint
[rank0]: self._checkpointer.save_checkpoint(
[rank0]: File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 781, in save_checkpoint
[rank0]: state_dict = llama2_qkv_hf_to_tune(
[rank0]: File "/workspace/torchtune/torchtune/models/llama2/llama2_qkv.py", line 63, in llama2_qkv_hf_to_tune
[rank0]: new_key = get_mapped_key(key, _LLAMA2_QKV)
[rank0]: File "/workspace/torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key
[rank0]: raise Exception(
[rank0]: Exception: Error converting the state dict. Found unexpected key: "tok_embeddings.weight". Please make sure you're loading a checkpoint with the right format.
1|10|Loss: 4.511970043182373: 100%|██████████| 10/10 [07:04<00:00, 42.48s/it]
[rank2]: Traceback (most recent call last):
[rank2]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module>
[rank2]: sys.exit(recipe_main())
[rank2]: File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank2]: sys.exit(recipe_main(conf))
[rank2]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 928, in recipe_main
[rank2]: recipe.train()
[rank2]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 894, in train
[rank2]: self.save_checkpoint(epoch=curr_epoch)
[rank2]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 745, in save_checkpoint
[rank2]: torch.distributed.barrier()
[rank2]: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank2]: return func(*args, **kwargs)
[rank2]: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4761, in barrier
[rank2]: work.wait()
Could you please let me know if I'm missing anything here?
Hi @joecummings , so I get some error when running llama-270B-qkv
llama2_qkv.py code in torchtune/models/llama2/
from typing import Dict, Optional import torch from torchtune.models.convert_weights import get_mapped_key _LLAMA2_QKV = { "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", "model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight", "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", "model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale", "model.norm.weight": "norm.scale", "lm_head.weight": "output.weight", } def llama2_qkv_hf_to_tune( state_dict: Dict[str, torch.Tensor], num_heads: Optional[int], num_kv_heads: Optional[int], dim: Optional[int], ) -> Dict[str, torch.Tensor]: """ Convertor from HF state dict to Torchtune state dict for LLaMA2 models with fused QKV and gate_up. 
""" converted_state_dict = {} if dim is not None: if num_heads is None or num_kv_heads is None: raise ValueError("LLaMA2 with GQA requires dim, num_heads and num_kv_heads.") q_dim = dim kv_dim = q_dim * num_kv_heads // num_heads else: q_dim = kv_dim = None for key, value in state_dict.items(): new_key = get_mapped_key(key, _LLAMA2_QKV) if "qkv" in key: if q_dim is not None: q, k, v = torch.split(value, [q_dim, kv_dim, kv_dim], dim=0) else: q, k, v = value.chunk(3, dim=0) converted_state_dict[new_key] = q converted_state_dict[new_key.replace("q_proj", "k_proj")] = k converted_state_dict[new_key.replace("q_proj", "v_proj")] = v elif "gate_up_proj" in key: gate, up = value.chunk(2, dim=0) converted_state_dict[new_key] = gate converted_state_dict[new_key.replace("w1", "w3")] = up else: converted_state_dict[new_key] = value return converted_state_dictError:
NFO:torchtune.utils._logging:Optimizer is initialized. INFO:torchtune.utils._logging:Compiling loss with torch.compile... INFO:torchtune.utils._logging:Loss is initialized. INFO:torchtune.utils._logging:Learning rate scheduler is initialized. WARNING:torchtune.utils._logging: Profiling disabled. INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False} 1|10|Loss: 4.511970043182373: 100%|██████████| 10/10 [05:57<00:00, 34.43s/it]INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict... INFO:torchtune.utils._logging:Getting full model state dict took 13.51 secs [rank0]: Traceback (most recent call last): [rank0]: File "/workspace/torchtune/torchtune/models/convert_weights.py", line 57, in get_mapped_key [rank0]: new_key = mapping_dict[key] [rank0]: KeyError: 'tok_embeddings.weight' [rank0]: The above exception was the direct cause of the following exception: [rank0]: Traceback (most recent call last): [rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module> [rank0]: sys.exit(recipe_main()) [rank0]: File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper [rank0]: sys.exit(recipe_main(conf)) [rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 928, in recipe_main [rank0]: recipe.train() [rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 894, in train [rank0]: self.save_checkpoint(epoch=curr_epoch) [rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 737, in save_checkpoint [rank0]: self._checkpointer.save_checkpoint( [rank0]: File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 781, in save_checkpoint [rank0]: state_dict = llama2_qkv_hf_to_tune( [rank0]: File "/workspace/torchtune/torchtune/models/llama2/llama2_qkv.py", line 63, in llama2_qkv_hf_to_tune [rank0]: new_key = get_mapped_key(key, _LLAMA2_QKV) [rank0]: File 
"/workspace/torchtune/torchtune/models/convert_weights.py", line 59, in get_mapped_key [rank0]: raise Exception( [rank0]: Exception: Error converting the state dict. Found unexpected key: "tok_embeddings.weight". Please make sure you're loading a checkpoint with the right format. 1|10|Loss: 4.511970043182373: 100%|██████████| 10/10 [07:04<00:00, 42.48s/it] [rank2]: Traceback (most recent call last): [rank2]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 933, in <module> [rank2]: sys.exit(recipe_main()) [rank2]: File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper [rank2]: sys.exit(recipe_main(conf)) [rank2]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 928, in recipe_main [rank2]: recipe.train() [rank2]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 894, in train [rank2]: self.save_checkpoint(epoch=curr_epoch) [rank2]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 745, in save_checkpoint [rank2]: torch.distributed.barrier() [rank2]: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper [rank2]: return func(*args, **kwargs) [rank2]: File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4761, in barrier [rank2]: work.wait()Could you please let me know if I'm missing anything here?
Ahh you'll want to also include a function for tune_to_hf that does the opposite mapping. This isn't strictly necessary, but it will allow you to use the model on the Hugging Face Hub as expected in the very same format as the Llama2 QKV.
Hi @joecummings , this is the dataset huggingface.co/datasets/regisss/scrolls_gov_report_preprocessed_mlperf_2 I'm trying to use. Looks like this format is not supported right? Could you let me know how to use this datset with the llama2-qkv model?
This is a good question - it looks like your data is already tokenized! We typically work with datasets that require tokenization, but never fear, this will still work with torchtune.
To do this, I would create a new simple dataset.
class PreTokenizedDataset(Dataset):
    """Wrap an already-tokenized Hugging Face dataset for torchtune.

    Each sample is returned as ``{"tokens": ..., "labels": ...}`` taken
    directly from the source rows' ``input_ids`` / ``labels`` columns.

    NOTE(review): ``load_dataset(source)`` with no ``split`` returns a
    ``DatasetDict`` keyed by split name, so integer indexing in
    ``__getitem__`` would fail on it — pass ``split`` (e.g. ``"train"``)
    to get a single, positionally-indexable split. The parameter defaults
    to ``None`` to keep the original call signature working unchanged.
    """

    def __init__(self, source, split=None):
        # When `split` is given, load just that split so the dataset is
        # indexable by integer position.
        if split is None:
            self._data = load_dataset(source)
        else:
            self._data = load_dataset(source, split=split)

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        sample = self._data[index]
        return {"tokens": sample["input_ids"], "labels": sample["labels"]}
Then you can reference this PreTokenizedDataset from your config with the dataset you want!
Hi @joecummings @felipemello1 @ebsmothers
I get this error when I try the opposite mapping:
Error:
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 1005, in <module>
[rank0]: sys.exit(recipe_main())
[rank0]: File "/workspace/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]: sys.exit(recipe_main(conf))
[rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 1000, in recipe_main
[rank0]: recipe.train()
[rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 902, in train
[rank0]: self.save_checkpoint(epoch=curr_epoch)
[rank0]: File "/workspace/torchtune/recipes/lora_finetune_distributed.py", line 758, in save_checkpoint
[rank0]: self._checkpointer.save_checkpoint(
[rank0]: File "/workspace/torchtune/torchtune/training/checkpointing/_checkpointer.py", line 839, in save_checkpoint
[rank0]: cpt_idx = self._weight_map[key]
[rank0]: KeyError: 'model.layers.0.mlp.gate_up_proj.weight'
Code: llama2_qkv.py code in torchtune/models/llama2/
from typing import Dict, Optional
import torch
from torchtune.models.convert_weights import get_mapped_key
_LLAMA2_QKV = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
"model.layers.{}.self_attn.qkv_proj.weight": "layers.{}.attn.q_proj.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
"model.layers.{}.mlp.gate_up_proj.weight": "layers.{}.mlp.w1.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
"model.norm.weight": "norm.scale",
"model.layers.0.mlp.gate_up_proj.weight": "layers.0.mlp.w1.weight",
"lm_head.weight": "output.weight",
}
def llama2_qkv_hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: Optional[int],
    num_kv_heads: Optional[int],
    dim: Optional[int],
) -> Dict[str, torch.Tensor]:
    """
    Convertor from HF state dict to Torchtune state dict for LLaMA2 models with fused QKV and gate_up.
    """
    # Row sizes for splitting the fused qkv matrix. With GQA the k/v blocks
    # are num_kv_heads / num_heads the size of the q block.
    if dim is None:
        q_rows = kv_rows = None
    else:
        if num_heads is None or num_kv_heads is None:
            raise ValueError("LLaMA2 with GQA requires dim, num_heads and num_kv_heads.")
        q_rows = dim
        kv_rows = dim * num_kv_heads // num_heads

    converted: Dict[str, torch.Tensor] = {}
    for hf_key, tensor in state_dict.items():
        tune_key = get_mapped_key(hf_key, _LLAMA2_QKV)
        if "qkv" in hf_key:
            # Unfuse the attention projection into q, k, v.
            if q_rows is None:
                # No dims given: assume equal-sized q/k/v (no GQA).
                q, k, v = tensor.chunk(3, dim=0)
            else:
                q, k, v = torch.split(tensor, [q_rows, kv_rows, kv_rows], dim=0)
            converted[tune_key] = q
            converted[tune_key.replace("q_proj", "k_proj")] = k
            converted[tune_key.replace("q_proj", "v_proj")] = v
        elif "gate_up_proj" in hf_key:
            # Unfuse the MLP projection into gate (w1) and up (w3).
            gate, up = tensor.chunk(2, dim=0)
            converted[tune_key] = gate
            converted[tune_key.replace("w1", "w3")] = up
        else:
            converted[tune_key] = tensor
    return converted
def tune_to_hf_llama2_qkv(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convertor from torchtune state dict to HF state dict. This handles:
    - Fusing q,k and v matrix
    - Fusing gate and up projection matrix
    """
    hf_state_dict: Dict[str, torch.Tensor] = {}
    # NOTE(review): this inversion is only well-defined if _LLAMA2_QKV has
    # unique values; duplicated values silently keep the last HF key.
    tune_to_hf_map = {tune_key: hf_key for hf_key, tune_key in _LLAMA2_QKV.items()}
    for tune_key, tensor in state_dict.items():
        # k/v and w3 are folded into their q / w1 partners below, so skip
        # them here rather than emitting them separately.
        if any(tag in tune_key for tag in ("k_proj", "v_proj", "w3")):
            continue
        hf_key = get_mapped_key(tune_key, tune_to_hf_map)
        if "q_proj" in tune_key:
            # Re-fuse q, k, v along dim 0; q_proj already maps to qkv_proj,
            # so no key rewriting is needed.
            fused_qkv = torch.cat(
                [
                    tensor,
                    state_dict[tune_key.replace("q_proj", "k_proj")],
                    state_dict[tune_key.replace("q_proj", "v_proj")],
                ],
                dim=0,
            )
            hf_state_dict[hf_key] = fused_qkv
        elif "w1" in tune_key:
            # Re-fuse gate (w1) and up (w3); w1 already maps to gate_up_proj.
            fused_gate_up = torch.cat(
                [tensor, state_dict[tune_key.replace("w1", "w3")]], dim=0
            )
            hf_state_dict[hf_key] = fused_gate_up
        else:
            hf_state_dict[hf_key] = tensor
    return hf_state_dict