Numerical Mismatch between Torchtune and HuggingFace models
This is potentially related to Issue #2809.
There seem to be numerical mismatches between several models as implemented in torchtune and in HuggingFace Transformers, as demonstrated by the following script:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from torchtune.models.qwen3 import (
    qwen3_1_7b_base,
    qwen3_4b_base,
    qwen3_8b_base,
    qwen3_tokenizer,
)
from torchtune.models.qwen3._convert_weights import qwen3_hf_to_tune
from torchtune.models.gemma import gemma_2b, gemma_7b, gemma_tokenizer
from torchtune.models.phi3 import phi3_mini, phi3_mini_tokenizer, phi3_hf_to_tune
from torchtune.models.llama3 import llama3_8b, llama3_tokenizer

MODELS = {
    "gemma": {
        "variants": {
            "2b": (gemma_2b, "google/gemma-2b"),
            "7b": (gemma_7b, "google/gemma-7b"),
        },
        "tokenizer": gemma_tokenizer,
        "convert": None,
    },
    "phi": {
        "variants": {
            "mini": (phi3_mini, "microsoft/Phi-3-mini-4k-instruct"),
        },
        "tokenizer": phi3_mini_tokenizer,
        "convert": phi3_hf_to_tune,
    },
    "llama": {
        "variants": {
            "8b": (llama3_8b, "meta-llama/Meta-Llama-3-8B"),
        },
        "tokenizer": llama3_tokenizer,
        "convert": None,
    },
    "qwen": {
        "variants": {
            "1.7b": (qwen3_1_7b_base, "Qwen/Qwen3-1.7B-Base"),
            "4b": (qwen3_4b_base, "Qwen/Qwen3-4B-Base"),
            "8b": (qwen3_8b_base, "Qwen/Qwen3-8B-Base"),
        },
        "tokenizer": qwen3_tokenizer,
        "convert": qwen3_hf_to_tune,
    },
}


def compare_logits(model_name, variant_key, builder, hf_name, convert_fn):
    print(f"Comparing {model_name} {variant_key}")
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_name)
    hf_model = AutoModelForCausalLM.from_pretrained(hf_name)

    # Create tune model
    tune_model = builder()
    if convert_fn is not None:
        if 'phi' in model_name:
            converted_sd = convert_fn(
                hf_model.state_dict(),
                num_heads=hf_model.config.num_attention_heads,
                num_kv_heads=hf_model.config.num_key_value_heads,
                dim=hf_model.config.hidden_size
            )
        elif 'qwen' in model_name and variant_key != '8b':
            converted_sd = convert_fn(
                hf_model.state_dict(),
                tie_word_embeddings=True
            )
        else:
            converted_sd = convert_fn(hf_model.state_dict())
        tune_model.load_state_dict(converted_sd, strict=True)

    # Use the HF tokenizer for simplicity
    input_str = "Hello world"
    tokens = hf_tokenizer(input_str, return_tensors="pt").input_ids

    with torch.no_grad():
        hf_logits = hf_model(tokens).logits
        tune_logits = tune_model(tokens)

    diff = (hf_logits - tune_logits).abs().max().item()
    print(f"Max logit diff: {diff:.6f}\n")


def main():
    for name, info in MODELS.items():
        for variant, (builder, hf_name) in info["variants"].items():
            compare_logits(
                name,
                variant,
                builder,
                hf_name,
                info["convert"],
            )


if __name__ == "__main__":
    main()
Hey! Thanks for the issue. I still need to understand how critical this is, but it definitely needs to be fixed.
@pytorchbot label bug
Thank you for such a thorough analysis.
In general you will see small numerical differences between torchtune and HF models because our RoPE implementation differs from HF's: the two are mathematically equivalent but numerically different. You can also run into situations like this depending on whether the linear multiply and the bias addition are done separately or together, or whether the QKV projections are fused or not. We've tested these cases, and the differences do not affect training or inference.
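For illustration, here is a minimal sketch (not torchtune code) of how two mathematically equivalent ways of arranging the same computation can still disagree at the floating-point level:

import torch

torch.manual_seed(0)
x = torch.randn(8, 512)
w1 = torch.randn(512, 512)
w2 = torch.randn(512, 512)

# Applying two projections one after the other vs. pre-multiplying the weights:
# identical in exact arithmetic, but rounding happens in a different order.
out_a = (x @ w1) @ w2
out_b = x @ (w1 @ w2)

print((out_a - out_b).abs().max())  # small but typically nonzero, even in float32

The same effect shows up when comparing fused vs. unfused QKV projections or different RoPE formulations.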
The issue with Qwen3 is different, though, and is being addressed. Can you share some of your difference values for the different models? I can try to see if they're within the expected tolerances.
Thank you for your response @pbontrager. See below for the differences resulting from the script above:
Comparing gemma 2b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 3.91it/s]
Max logit diff: 2079.876953
Comparing gemma 7b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.65it/s]
Max logit diff: 3198.977051
Comparing phi mini
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.50it/s]
Max logit diff: 0.000000
Comparing llama 8b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.57it/s]
Max logit diff: 15.436760
Comparing qwen 1.7b
Max logit diff: 2.295973
Comparing qwen 4b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.53it/s]
Max logit diff: 1.490764
Comparing qwen 8b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 3.16it/s]
Max logit diff: 1.048040
Thank you! I'm investigating and will get back to you when I've found the source of the issues. I've started with Gemma and found that some of the weights don't match inside the script, even though from the way you load them it looks like they should. I'm not sure yet whether it's an issue with our convert_weights or with the test script somehow. The other models look like they might be within more normal bounds, but I'll investigate them after Gemma.
I got back to this and found the bug. Your testing script does not load any HF weights when "convert": None. You should use the default convert function instead of None:
from torchtune.models.convert_weights import hf_to_tune

converted_sd = convert_fn(
    hf_model.state_dict(),
    num_heads=hf_model.config.num_attention_heads,
    num_kv_heads=hf_model.config.num_key_value_heads,
    dim=hf_model.config.hidden_size,
    head_dim=hf_model.config.head_dim,
)
if hf_model.config.tie_word_embeddings:
    converted_sd.pop("output.weight")
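For concreteness, in the test script above this amounts to replacing "convert": None with a wrapper around the generic converter. The wrapper name below is just for illustration, and note that it takes the full HF model rather than its state dict, since it needs access to the config:

# Illustrative wrapper (name is hypothetical) for models in the script above
# that have no model-specific convert function (gemma, llama):
def generic_hf_to_tune(hf_model):
    converted_sd = hf_to_tune(
        hf_model.state_dict(),
        num_heads=hf_model.config.num_attention_heads,
        num_kv_heads=hf_model.config.num_key_value_heads,
        dim=hf_model.config.hidden_size,
        head_dim=hf_model.config.head_dim,
    )
    # With tied embeddings the torchtune model has no separate output
    # projection, so drop the mapped lm_head weight before strict loading.
    if hf_model.config.tie_word_embeddings:
        converted_sd.pop("output.weight")
    return converted_sd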
Now, rerunning your script, I get:
Comparing gemma 2b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.10it/s]
Max logit diff: 0.068563
Comparing gemma 7b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.34it/s]
Max logit diff: 0.554550
Comparing phi mini
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00, 1.02s/it]
Max logit diff: 0.000000
Comparing llama 8b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.44it/s]
Max logit diff: 0.000016
Comparing qwen 1.7b
Max logit diff: 2.295973
Comparing qwen 4b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.01it/s]
Max logit diff: 1.490764
Comparing qwen 8b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 5/5 [00:03<00:00, 1.61it/s]
Max logit diff: 1.048040
This is much more in the expected range. I'm confident that Llama and Phi are fine. We'll need to rerun for Qwen once this fix lands. Gemma is within a possibly acceptable range, but I'm suspicious of it and will keep investigating to make sure the current difference matches the one from the original PR.
Thank you @pbontrager for tracking this down! If it's helpful, I'm happy to help with this.
Note that this script compares the models in float32 precision on CPU. When I modify the script to compare the models on GPU in bf16 precision, I suddenly see a difference with models that had ~0 difference when compared in float32 on CPU or GPU.
Modified script
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from torchtune.models.qwen3 import (
    qwen3_8b_base,
    qwen3_tokenizer,
)
from torchtune.models.llama3 import llama3_8b, llama3_tokenizer
from torchtune.models.qwen3._convert_weights import qwen3_hf_to_tune
from torchtune import config, generation, training, utils
from torchtune.models.convert_weights import hf_to_tune


def llama_hf_to_tune(hf_model):
    return hf_to_tune(
        hf_model.state_dict(),
        num_heads=hf_model.config.num_attention_heads,
        num_kv_heads=hf_model.config.num_key_value_heads,
        dim=hf_model.config.hidden_size,
        head_dim=hf_model.config.head_dim,
    )


MODELS = {
    "qwen": {
        "variants": {
            "8b": (qwen3_8b_base, "Qwen/Qwen3-8B-Base"),
        },
        "tokenizer": qwen3_tokenizer,
        "convert": qwen3_hf_to_tune,
    },
    "llama": {
        "variants": {
            "8b": (llama3_8b, "meta-llama/Meta-Llama-3-8B"),
        },
        "tokenizer": llama3_tokenizer,
        "convert": llama_hf_to_tune,
    },
}


def compare_logits(model_name, variant_key, builder, hf_name, convert_fn):
    print(f"Comparing {model_name} {variant_key}")
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_name)
    hf_model = AutoModelForCausalLM.from_pretrained(hf_name)

    dtype = torch.bfloat16
    print("Using dtype", dtype)
    _device = utils.get_device(device="cuda")
    with training.set_default_dtype(dtype), _device:
        # Create tune model
        tune_model = builder()
        if convert_fn is not None:
            if model_name == "llama":
                converted_sd = convert_fn(hf_model)
            else:
                converted_sd = convert_fn(hf_model.state_dict())
            tune_model.load_state_dict(converted_sd, strict=True)

    # Use the HF tokenizer for simplicity
    input_str = "Qwen3 is the latest generation of large language models in Qwen series, offering a comprehensive"
    tokens = hf_tokenizer(input_str, return_tensors="pt").input_ids
    tokens_cuda = tokens.to(_device)
    tokens_cuda_1 = tokens.to("cuda:1")

    with torch.no_grad():
        tune_logits = tune_model(tokens_cuda).to("cpu")
        hf_model = hf_model.to("cuda:1", dtype)
        hf_logits = hf_model(tokens_cuda_1).logits.to("cpu")

    diff = (hf_logits - tune_logits).abs().max().item()
    print(f"Max logit diff: {diff:.6f}\n")


def main():
    for name, info in MODELS.items():
        for variant, (builder, hf_name) in info["variants"].items():
            compare_logits(
                name,
                variant,
                builder,
                hf_name,
                info["convert"],
            )


if __name__ == "__main__":
    main()
Llama3-8B comparison
python compare_logits.py
Comparing llama 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.14it/s]
Using dtype torch.float32
Max logit diff: 0.000051
python compare_logits.py
Comparing llama 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.17it/s]
Using dtype torch.float16
Max logit diff: 0.025391
python compare_logits.py
Comparing llama 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.30it/s]
Using dtype torch.bfloat16
Max logit diff: 0.203125
Qwen3-8B comparison after issue 2809 fix
python compare_logits.py
Comparing qwen 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 2.55it/s]
Using dtype torch.float32
Max logit diff: 0.000048
python compare_logits.py
Comparing qwen 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 2.58it/s]
Using dtype torch.float16
Max logit diff: 0.039062
python compare_logits.py
Comparing qwen 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00, 2.47it/s]
Using dtype torch.bfloat16
Max logit diff: 0.375000
Some numerical differences at lower precision (e.g. bf16/fp16) are to be expected because floating-point arithmetic is not associative.
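A minimal sketch of that non-associativity (illustrative, not tied to any model): with bf16's coarser rounding, merely regrouping the same three additions already changes the result.

import torch

a = torch.tensor(1.0, dtype=torch.bfloat16)
b = torch.tensor(1e-3, dtype=torch.bfloat16)
c = torch.tensor(-1.0, dtype=torch.bfloat16)

# Each intermediate result is rounded back to bf16, so grouping matters:
print((a + b) + c)  # 1e-3 is swallowed when added to 1.0 in bf16 -> 0.0
print(a + (b + c))  # grouped the other way it mostly survives -> ~0.002

Different kernels (e.g. a GPU matmul vs. a CPU one) are free to accumulate in different orders, so small logit differences at bf16/fp16 are expected even with bit-identical weights.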