torchtune

Numerical Mismatch between Torchtune and HuggingFace models

Open · athms opened this issue 6 months ago · 9 comments

This is potentially related to issue #2809.

There seem to be numerical mismatches between various models implemented in torchtune and HuggingFace, as demonstrated by this code:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from torchtune.models.qwen3 import (
    qwen3_1_7b_base,
    qwen3_4b_base,
    qwen3_8b_base,
    qwen3_tokenizer,
)
from torchtune.models.qwen3._convert_weights import qwen3_hf_to_tune
from torchtune.models.gemma import gemma_2b, gemma_7b, gemma_tokenizer
from torchtune.models.phi3 import phi3_mini, phi3_mini_tokenizer, phi3_hf_to_tune
from torchtune.models.llama3 import llama3_8b, llama3_tokenizer

MODELS = {
    "gemma": {
        "variants": {
            "2b": (gemma_2b, "google/gemma-2b"),
            "7b": (gemma_7b, "google/gemma-7b"),
        },
        "tokenizer": gemma_tokenizer,
        "convert": None,
    },
    "phi": {
        "variants": {
            "mini": (phi3_mini, "microsoft/Phi-3-mini-4k-instruct"),
        },
        "tokenizer": phi3_mini_tokenizer,
        "convert": phi3_hf_to_tune,
    },
    "llama": {
        "variants": {
            "8b": (llama3_8b, "meta-llama/Meta-Llama-3-8B"),
        },
        "tokenizer": llama3_tokenizer,
        "convert": None,
    },
    "qwen": {
        "variants": {
            "1.7b": (qwen3_1_7b_base, "Qwen/Qwen3-1.7B-Base"),
            "4b": (qwen3_4b_base, "Qwen/Qwen3-4B-Base"),
            "8b": (qwen3_8b_base, "Qwen/Qwen3-8B-Base"),
        },
        "tokenizer": qwen3_tokenizer,
        "convert": qwen3_hf_to_tune,
    },
}


def compare_logits(model_name, variant_key, builder, hf_name, convert_fn):
    print(f"Comparing {model_name} {variant_key}")
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_name)
    hf_model = AutoModelForCausalLM.from_pretrained(hf_name)

    # Create tune model
    tune_model = builder()
    if convert_fn is not None:
        if 'phi' in model_name:
            converted_sd = convert_fn(
                hf_model.state_dict(),
                num_heads=hf_model.config.num_attention_heads,
                num_kv_heads=hf_model.config.num_key_value_heads,
                dim=hf_model.config.hidden_size
            )
        elif 'qwen' in model_name and variant_key!='8b':
            converted_sd = convert_fn(
                hf_model.state_dict(),
                tie_word_embeddings=True
            )
        else:
            converted_sd = convert_fn(hf_model.state_dict())
        tune_model.load_state_dict(converted_sd, strict=True)

    # Use the HF tokenizer for simplicity
    input_str = "Hello world"
    tokens = hf_tokenizer(input_str, return_tensors="pt").input_ids
    with torch.no_grad():
        hf_logits = hf_model(tokens).logits
        tune_logits = tune_model(tokens)
    diff = (hf_logits - tune_logits).abs().max().item()
    print(f"Max logit diff: {diff:.6f}\n")


def main():
    for name, info in MODELS.items():
        for variant, (builder, hf_name) in info["variants"].items():
            compare_logits(
                name,
                variant,
                builder,
                hf_name,
                info["convert"],
            )


if __name__ == "__main__":
    main()

athms, Jun 11 '25 20:06

Hey! Thanks for the issue. I still need to understand how critical this is, but it definitely needs to be fixed.

krammnic, Jun 11 '25 20:06

@pytorchbot label bug

krammnic, Jun 11 '25 20:06

Thank you for such a thorough analysis.

In general, you will see small numerical differences between torchtune and HF models because our RoPE is implemented differently: the two versions are mathematically equivalent but numerically different. You can also run into this depending on whether the linear matmul and the bias addition are done separately or together, or whether the QKV projections are fused. We've tested these cases, and the differences do not affect training or inference.
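A minimal sketch of the "mathematically equivalent" point. It assumes, as we understand the two codebases, that HF's Llama-style RoPE rotates dimension i together with dimension i + head_dim/2, while torchtune rotates adjacent pairs (0, 1), (2, 3) and so on, and that checkpoint conversion reorders the q/k projection outputs between the two layouts. All function names below are hypothetical; this is not either library's code.

```python
import torch

def rope_cos_sin(seq_len, head_dim, base=10000.0):
    # Standard RoPE angles: one frequency per pair of head dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # [seq, head_dim // 2]
    return angles.cos(), angles.sin()

def rope_half_split(x, cos, sin):
    # Half-split layout: rotate dim i together with dim i + head_dim // 2
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

def rope_interleaved(x, cos, sin):
    # Interleaved layout: rotate adjacent pairs (x0, x1), (x2, x3), ...
    xe, xo = x[..., 0::2], x[..., 1::2]
    return torch.stack((xe * cos - xo * sin, xe * sin + xo * cos), dim=-1).flatten(-2)

def interleaved_to_half(x):
    # Permute head dims so that pair j lands at positions j and j + head_dim // 2
    return torch.cat((x[..., 0::2], x[..., 1::2]), dim=-1)

torch.manual_seed(0)
x = torch.randn(3, 7, 8)  # [heads, seq, head_dim]
cos, sin = rope_cos_sin(seq_len=7, head_dim=8)
a = rope_half_split(interleaved_to_half(x), cos, sin)
b = interleaved_to_half(rope_interleaved(x, cos, sin))
print(torch.allclose(a, b, atol=1e-6))
```

This toy only shows the equivalence under a permutation of the head dimensions. Here both paths perform the same floating-point operations, so it does not reproduce the small numerical differences, which come from implementation details the sketch does not model.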

Now the issue with Qwen3 is different and being addressed. Can you share some of your difference values for the different models? I can try to see if they’re within the expected tolerances.
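Max absolute logit difference alone can be hard to judge against a tolerance. A minimal sketch of a hypothetical `logit_report` helper that also reports top-1 agreement and the mean KL divergence between the two next-token distributions; the choice of metrics is ours, not part of torchtune or transformers.

```python
import torch
import torch.nn.functional as F

def logit_report(ref, test):
    """Hypothetical helper: compare two [batch, seq, vocab] logit tensors in float32."""
    ref, test = ref.float(), test.float()
    max_abs = (ref - test).abs().max().item()
    # Fraction of positions where both models predict the same next token
    top1 = (ref.argmax(-1) == test.argmax(-1)).float().mean().item()
    # Mean KL(ref || test) over positions, in nats
    kl = F.kl_div(
        test.log_softmax(-1), ref.log_softmax(-1), reduction="none", log_target=True
    ).sum(-1).mean().item()
    return {"max_abs": max_abs, "top1_agreement": top1, "kl": kl}
```

For example, `print(logit_report(hf_logits, tune_logits))` could replace the max-diff line in `compare_logits`.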

pbontrager, Jun 11 '25 21:06

Thank you for your response @pbontrager. See below for the differences resulting from the script above:

Comparing gemma 2b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.91it/s]
Max logit diff: 2079.876953

Comparing gemma 7b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.65it/s]
Max logit diff: 3198.977051

Comparing phi mini
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.50it/s]
Max logit diff: 0.000000

Comparing llama 8b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.57it/s]
Max logit diff: 15.436760

Comparing qwen 1.7b
Max logit diff: 2.295973

Comparing qwen 4b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.53it/s]
Max logit diff: 1.490764

Comparing qwen 8b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  3.16it/s]
Max logit diff: 1.048040

athms, Jun 11 '25 21:06

Thank you! I'm investigating and will get back to you when I've found the source of the issues. I started with Gemma and found that some of the weights inside the script don't match, even though, given how you load them, they look like they should. I'm not sure yet whether the issue is in our convert_weights or in the test script. The other models look like they might be within more normal bounds, but I'll investigate them after Gemma.
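One way to pinpoint which weights disagree is to diff the converted state dict against the model's own state dict after loading. A minimal sketch with a hypothetical `mismatched_keys` helper (not a torchtune API):

```python
import torch

def mismatched_keys(expected, actual, atol=0.0):
    """Hypothetical helper: list state-dict keys whose tensors differ.

    Reports keys present in only one dict, keys with different shapes,
    and keys whose values differ by more than `atol`.
    """
    problems = []
    for key in sorted(set(expected) | set(actual)):
        if key not in expected or key not in actual:
            problems.append(f"{key}: missing")
            continue
        a, b = expected[key], actual[key]
        if a.shape != b.shape:
            problems.append(f"{key}: shape {tuple(a.shape)} vs {tuple(b.shape)}")
        elif (a.float() - b.float()).abs().max().item() > atol:
            problems.append(f"{key}: values differ")
    return problems
```

For example, `mismatched_keys(converted_sd, tune_model.state_dict())` after `load_state_dict`; an empty list means every key matches.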

pbontrager, Jun 12 '25 21:06

I got back to this and found the bug. Your testing script does not load any HF weights when "convert" is None, so the torchtune model keeps its randomly initialized weights. You should use the default convert function instead of None:

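from torchtune.models.convert_weights import hf_to_tune

# In compare_logits, with "convert": hf_to_tune set in MODELS for Gemma and Llama:
converted_sd = convert_fn(
    hf_model.state_dict(),
    num_heads=hf_model.config.num_attention_heads,
    num_kv_heads=hf_model.config.num_key_value_heads,
    dim=hf_model.config.hidden_size,
    head_dim=hf_model.config.head_dim,
)
if hf_model.config.tie_word_embeddings:
    # With tied embeddings the output projection reuses the token embedding,
    # so drop the separate output weight before a strict load
    converted_sd.pop("output.weight")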
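To see why the `pop` matters for a strict load, here is a minimal sketch with a hypothetical toy model (`TinyTiedLM` and `load_converted` are illustrative names, not torchtune code) whose output projection reuses the embedding matrix and therefore has no `output.weight` parameter:

```python
import torch
from torch import nn

class TinyTiedLM(nn.Module):
    """Hypothetical toy model whose output projection reuses the embedding matrix."""

    def __init__(self, vocab_size=16, dim=4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        # Tied output projection: no separate "output.weight" parameter exists
        return h @ self.tok_embeddings.weight.T

def load_converted(model, converted_sd, tie_word_embeddings):
    converted_sd = dict(converted_sd)  # copy so the caller's dict is not mutated
    if tie_word_embeddings:
        # Without this, strict=True fails with an "Unexpected key(s)" RuntimeError
        converted_sd.pop("output.weight", None)
    return model.load_state_dict(converted_sd, strict=True)
```

Loading a dict that still contains `output.weight` into `TinyTiedLM` with `strict=True` raises a `RuntimeError`; `load_converted(model, sd, tie_word_embeddings=True)` succeeds.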

Rerunning your script with this change, I get:

Comparing gemma 2b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.10it/s]
Max logit diff: 0.068563

Comparing gemma 7b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.34it/s]
Max logit diff: 0.554550

Comparing phi mini
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.02s/it]
Max logit diff: 0.000000

Comparing llama 8b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.44it/s]
Max logit diff: 0.000016

Comparing qwen 1.7b
Max logit diff: 2.295973

Comparing qwen 4b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.01it/s]
Max logit diff: 1.490764

Comparing qwen 8b
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.61it/s]
Max logit diff: 1.048040

This is much closer to the expected range. I'm confident that Llama and Phi are fine. We'll need to rerun Qwen once the Qwen3 fix lands. Gemma may be within an acceptable range, but I'm suspicious of it and will keep investigating to make sure the current difference matches the one in the original PR.

pbontrager, Jun 13 '25 16:06

Thank you @pbontrager for tracking this down! If it's helpful, I'm happy to help with this.

athms, Jun 13 '25 17:06

Note that this script compares the models in float32 precision on CPU. When I modify it to compare the models on GPU in bf16, I see a clear difference even for models that had ~0 difference when compared in float32 on CPU or GPU.

Modified script:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from torchtune.models.qwen3 import (
    qwen3_8b_base,
    qwen3_tokenizer,
)
from torchtune.models.llama3 import llama3_8b, llama3_tokenizer
from torchtune.models.qwen3._convert_weights import qwen3_hf_to_tune
from torchtune import config, generation, training, utils
from torchtune.models.convert_weights import hf_to_tune

def llama_hf_to_tune(hf_model):
    return hf_to_tune(
            hf_model.state_dict(),
            num_heads=hf_model.config.num_attention_heads,
            num_kv_heads=hf_model.config.num_key_value_heads,
            dim=hf_model.config.hidden_size,
            head_dim=hf_model.config.head_dim,
        )

MODELS = {
    "qwen": {
        "variants": {
            "8b": (qwen3_8b_base, "Qwen/Qwen3-8B-Base"),
        },
        "tokenizer": qwen3_tokenizer,
        "convert": qwen3_hf_to_tune,
    },
    "llama": {
        "variants": {
            "8b": (llama3_8b, "meta-llama/Meta-Llama-3-8B"),
        },
        "tokenizer": llama3_tokenizer,
        "convert": llama_hf_to_tune,
    },
}

def compare_logits(model_name, variant_key, builder, hf_name, convert_fn):
    print(f"Comparing {model_name} {variant_key}")
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_name)
    hf_model = AutoModelForCausalLM.from_pretrained(hf_name)

    dtype = torch.bfloat16
    print("Using dtype", dtype)
    _device = utils.get_device(device="cuda")
    with training.set_default_dtype(dtype), _device:
        # Create tune model
        tune_model = builder()
    
    if convert_fn is not None:
        if model_name == "llama":
            converted_sd = convert_fn(hf_model)
        else:
            converted_sd = convert_fn(hf_model.state_dict())
        tune_model.load_state_dict(converted_sd, strict=True)

    # Use the HF tokenizer for simplicity
    input_str = "Qwen3 is the latest generation of large language models in Qwen series, offering a comprehensive"
    tokens = hf_tokenizer(input_str, return_tensors="pt").input_ids
    tokens_cuda = tokens.to(_device)
    # The HF model runs on a second GPU (cuda:1), so this script needs at least two GPUs
    tokens_cuda_1 = tokens.to("cuda:1")
    with torch.no_grad():
        tune_logits = tune_model(tokens_cuda).to("cpu")
        hf_model = hf_model.to("cuda:1", dtype)
        hf_logits = hf_model(tokens_cuda_1).logits.to("cpu")
        
    diff = (hf_logits - tune_logits).abs().max().item()
    print(f"Max logit diff: {diff:.6f}\n")


def main():
    for name, info in MODELS.items():
        for variant, (builder, hf_name) in info["variants"].items():
            compare_logits(
                name,
                variant,
                builder,
                hf_name,
                info["convert"],
            )


if __name__ == "__main__":
    main()  

Llama3-8B comparison:
python compare_logits.py
Comparing llama 8b                                                                                                                                                    
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.14it/s]
Using dtype torch.float32                                                                                                                                             
Max logit diff: 0.000051              
                                         
python compare_logits.py                                                                                         
Comparing llama 8b   
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.17it/s]
Using dtype torch.float16
Max logit diff: 0.025391           

python compare_logits.py
Comparing llama 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.30it/s]
Using dtype torch.bfloat16
Max logit diff: 0.203125

Qwen3-8B comparison after the issue #2809 fix:
python compare_logits.py
Comparing qwen 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.55it/s]
Using dtype torch.float32
Max logit diff: 0.000048

python compare_logits.py
Comparing qwen 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  2.58it/s]
Using dtype torch.float16
Max logit diff: 0.039062

python compare_logits.py
Comparing qwen 8b
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  2.47it/s]
Using dtype torch.bfloat16
Max logit diff: 0.375000
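One way to read these numbers is in units of the last place (ULP), the spacing between adjacent representable values. A minimal sketch, assuming normal (non-subnormal) values and a hypothetical logit magnitude of about 20, since the actual magnitudes aren't reported here:

```python
import math
import torch

def ulp_at(value, dtype):
    """Spacing between adjacent representable floats near a normal `value` in `dtype`."""
    exponent = math.floor(math.log2(abs(value)))
    return 2.0 ** exponent * torch.finfo(dtype).eps

# Hypothetical logit magnitude of ~20
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, ulp_at(20.0, dtype))
```

At that magnitude one bf16 ULP is 0.125 and one fp16 ULP is 0.015625, so if the largest difference occurred at a logit of that size, a bf16 max diff of 0.375 would be only three ULPs.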

intervitens, Jun 14 '25 23:06

Some numerical differences at lower precision (e.g., bf16 or fp16) are to be expected, because floating-point arithmetic is not associative: reordering or fusing operations changes how intermediate results are rounded.
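A minimal illustration of that non-associativity in bf16, with values chosen by us for the demo: 2**-8 is exactly half the bf16 spacing at 1.0, so under round-to-nearest-even `1 + 2**-8` rounds back to 1.0, while `2**-8 + 2**-8` is exactly one ULP and survives.

```python
import torch

one = torch.tensor(1.0, dtype=torch.bfloat16)
tiny = torch.tensor(2.0 ** -8, dtype=torch.bfloat16)  # half a bf16 ULP at 1.0

left = (one + tiny) + tiny   # each partial sum rounds back to 1.0
right = one + (tiny + tiny)  # tiny + tiny = 2**-7, exactly one ULP at 1.0
print(left.item(), right.item())  # 1.0 1.0078125

# The same sums in float32 are exact, so the order does not matter there
left32 = (one.float() + tiny.float()) + tiny.float()
right32 = one.float() + (tiny.float() + tiny.float())
```

Fused kernels and different reduction orders reorder sums like these across thousands of terms, which is why bf16 comparisons show larger differences than float32 ones.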

athms, Jun 17 '25 10:06