torchtitan
RoPE implementation differences
I've been working with the pretrained Llama 3 weights and found that the RoPE implementation here does not match the one found in other places. The difference is whether you treat sequential entries of the embeddings as (real, imaginary) pairs, or treat the first half as real and the second half as imaginary.
The current torchtitan implementation uses the former, while both Transformers and llama.cpp, for example, use the latter. This also seems to mean that loading weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B does not work. I've verified numerically that you need to use the latter RoPE implementation to get correct results with the existing weights. I'm slightly worried that I'm doing something wrong, but perhaps someone else can verify? I can post some code if that helps.
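To make the difference concrete, here's a toy sketch (the head dimension of 4 is just for illustration):

import torch

# toy head_dim = 4 vector
x = torch.tensor([0.0, 1.0, 2.0, 3.0])

# interleaved convention (meta-llama / torchtitan): pairs are (x0, x1), (x2, x3)
interleaved = torch.view_as_complex(x.view(-1, 2))  # tensor([0.+1.j, 2.+3.j])

# half-split convention (HF transformers / llama.cpp): pairs are (x0, x2), (x1, x3)
half_split = torch.complex(x[:2], x[2:])             # tensor([0.+2.j, 1.+3.j])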
Here's a small change to apply_rotary_emb which can be used to make it match the cos/sin implementation numerically.
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    # treat the first half of the head dimension as real, the second half as imaginary
    xq_ = torch.complex(xq[..., : xq.shape[-1] // 2].float(), xq[..., xq.shape[-1] // 2 :].float())
    xk_ = torch.complex(xk[..., : xk.shape[-1] // 2].float(), xk[..., xk.shape[-1] // 2 :].float())
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # added: de-interleave the (real, imag) pairs back into the half-split layout
    xq_out = torch.cat([xq_out[..., ::2], xq_out[..., 1::2]], dim=-1)
    xk_out = torch.cat([xk_out[..., ::2], xk_out[..., 1::2]], dim=-1)
    return xq_out.type_as(xq), xk_out.type_as(xk)
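And for reference, a quick standalone check that this half-split complex formulation agrees with the cos/sin + rotate_half formulation used by HF transformers. The shapes and the rotate_half helper below are just for this snippet, not from torchtitan:

import torch

def rotate_half(x):
    # HF-style helper: [x1, x2] -> [-x2, x1] on the last dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

torch.manual_seed(0)
bsz, seqlen, n_heads, head_dim = 1, 8, 2, 16
x = torch.randn(bsz, seqlen, n_heads, head_dim)

theta = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seqlen).float(), theta)   # (seqlen, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

# half-split complex rotation (first half real, second half imaginary)
x_ = torch.complex(x[..., : head_dim // 2], x[..., head_dim // 2 :])
out_complex = x_ * freqs_cis[:, None, :]
out_a = torch.cat([out_complex.real, out_complex.imag], dim=-1)

# cos/sin formulation, with cos/sin duplicated across the two halves
cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)[:, None, :]
sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)[:, None, :]
out_b = x * cos + rotate_half(x) * sin

assert torch.allclose(out_a, out_b, atol=1e-5)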
Hi @rlrs ! Could you share the script to transform the weights from HF to dcp? Thanks!
I'm using a modified script based on gpt-fast, will paste it here.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import re
import sys
from pathlib import Path

from safetensors import safe_open
import torch
import torch.distributed.checkpoint as DCP

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from maester.models import models_config


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    model_name: str,
    variant: str,
    checkpoint_dir: Path,
    output_dir: Path,
) -> None:
    if model_name is None:
        model_name = checkpoint_dir.name
    config = models_config[model_name][variant]
    print(f"Model config {config.__dict__}")

    # Load the json file containing weight mapping
    model_map_json = checkpoint_dir / "model.safetensors.index.json"
    assert model_map_json.is_file()
    with open(model_map_json) as json_map:
        bin_index = json.load(json_map)

    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

    merged_result = {}
    for file in sorted(bin_files):
        with safe_open(file, framework="pt", device="cpu") as f:
            for k in f.keys():
                merged_result[k] = f.get_tensor(k)

    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r"(\d+)", "{}", key)
            layer_num = re.search(r"\d+", key).group(0)
            new_key = weight_map[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = weight_map[key]
        final_result[new_key] = value

    output_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(output_dir)
    DCP.save({"model": final_result}, storage_writer=storage_writer)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.")
    parser.add_argument("--checkpoint", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--variant", type=str, required=True)
    args = parser.parse_args()
    convert_hf_checkpoint(
        checkpoint_dir=args.checkpoint,
        output_dir=args.output,
        model_name=args.model,
        variant=args.variant,
    )
Hi @rlrs, thanks for bringing up the concern!
We are using the same definition as in the llama3 code (https://github.com/meta-llama/llama3/blob/main/llama/model.py#L65). Would you provide more details on how you verified the loaded weights to be wrong or correct?
After the conversion, I verified layer by layer that all loaded weights match HF transformers. I also verified that the inputs/outputs match HF transformers layer by layer (the only difference is in RoPE), and I manually checked that inference outputs match as well. Here's a snippet that compares attention layers after setting up and loading weights for both models:
# Compare attention layers
input_tensor = torch.randn(cfg.batch_size, cfg.seq_len, model_config.dim)
freqs_cis = model.freqs_cis[0:cfg.seq_len]
for i, (layer, hf_layer) in enumerate(zip(model.layers, hf_model.model.layers)):
    attention_output = layer.attention(layer.attention_norm(input_tensor), freqs_cis)
    hf_attention_output, _, _ = hf_layer.self_attn(
        hf_layer.input_layernorm(input_tensor),
        position_ids=torch.arange(cfg.seq_len, dtype=torch.long).unsqueeze(0).expand(cfg.batch_size, -1),
    )
    assert torch.allclose(attention_output, hf_attention_output, atol=1e-5), f"Attention layer {i} outputs do not match"
@rlrs It seems HF's llama implementation is different from the official llama's. We'll need to understand why that's the case.
HF: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L184
meta-llama: https://github.com/meta-llama/llama3/blob/main/llama/model.py#L65
asked here: https://github.com/huggingface/transformers/issues/30872
Thanks for asking over there. I didn't try to download the weights from anywhere other than HF, but I would be a bit surprised if there were some simple transformation you could apply to the weights to switch the RoPE implementation. Anyway, let's await some information from their side.
As discussed in the HF issue, there is indeed a permutation of the weights that causes the two implementations to be equivalent. I don't believe anything needs to be done in the torchtitan repo, and if you agree, feel free to close this issue.
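For anyone who finds this later, here's a rough sketch of the inverse permutation (HF layout back to the interleaved meta-llama layout), adapted from the permute used in HF's conversion script; the shapes in the comments are the Llama 3 8B values and are only illustrative:

import torch

def unpermute(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # HF stores each head's rows as [all "real" rows, then all "imag" rows];
    # this regroups them into interleaved (real, imag) pairs per head.
    return (
        w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )

# e.g. for Llama 3 8B: wq = unpermute(q_proj, n_heads=32, dim1=4096, dim2=4096)
# and for wk with grouped-query attention: wk = unpermute(k_proj, n_heads=8, dim1=1024, dim2=4096)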
@rlrs
This also seems to mean that loading weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B does not work.
If I understand correctly, there are two ways you can download weights from HF. The first is from the original folder, which gives the same weights as those downloaded from the Meta Llama website; the second is through the HF API transformers.pipeline, which presumably does the conversion.
I think torchtitan should at least have code and a tutorial for loading the original weights. For the second, HF should support converting HF transformers weights back to llama weights.
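As a possible starting point, downloading just the original folder could look something like this (the exact arguments are an assumption on my part):

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="meta-llama/Meta-Llama-3-8B",
    allow_patterns=["original/*"],   # only the meta-format weights
    local_dir="Meta-Llama-3-8B",
)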