
[Bug Report] Gemma-2-2b-it output logit doesn't match with huggingface

Open · yeutong opened this issue 1 year ago · 3 comments

Describe the bug: The output logits from transformer_lens and huggingface differ substantially for the Gemma-2-2b-it model.

Code example

import torch
import transformer_lens
from transformers import AutoTokenizer, AutoModelForCausalLM

device = 'cuda'
model_name = 'google/gemma-2-2b-it'
tl_model = transformer_lens.HookedTransformer.from_pretrained(model_name, device=device)

tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

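# The HF tokenizer already prepends <bos>, so TransformerLens is told not to add another one.
inputs = tokenizer('Hello world', return_tensors="pt").to(device)

with torch.no_grad():
    logits_tl = tl_model(inputs.input_ids, return_type='logits', prepend_bos=False)
    logits_hf = hf_model(**inputs).logits

# Mean logit difference at the last position, and the range of the HF logits for scale
print((logits_tl[0, -1] - logits_hf[0, -1]).mean()) # 0.1159
print((logits_hf[0, -1]).min(), (logits_hf[0, -1]).max()) # -19.6916 16.0789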

System Info: transformer_lens 2.3.0, transformers 4.43.2

Additional context: The logit difference is large relative to the range of the HF logits (about -19.7 to 16.1).

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

yeutong · Aug 02 '24 18:08

TransformerLens centers the unembedding, which shifts all logits at a given position by the same constant (the constant can differ between positions). Can you redo the comparison with log probs, which are unaffected by such a shift? Or try from_pretrained_no_processing? There are known accuracy issues, but I want to rule out trivial causes.
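A minimal sketch of why log probs are the fairer comparison, using toy tensors rather than model outputs: adding a per-position constant to the logits changes them, but `log_softmax` over the vocab dimension cancels it. The shapes and shift values are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)

# Toy logits: batch of 1, 3 positions, vocab of 10.
logits = torch.randn(1, 3, 10)

# Mimic the effect of a centered unembedding: every logit at a position moves
# by the same constant, and the constant differs from position to position.
per_position_shift = torch.tensor([0.5, -1.2, 3.0]).view(1, 3, 1)
shifted = logits + per_position_shift

# The raw logits differ by up to the largest shift...
raw_diff = (shifted - logits).abs().max().item()

# ...but log-softmax normalizes over the vocab, so the constant cancels out.
logprob_diff = (shifted.log_softmax(dim=-1) - logits.log_softmax(dim=-1)).abs().max().item()

print(f"max raw logit diff: {raw_diff:.4f}")
print(f"max log-prob diff:  {logprob_diff:.2e}")
```

Applied to the issue's code, this would mean comparing `logits_tl[0, -1].log_softmax(-1)` with `logits_hf[0, -1].log_softmax(-1)` instead of the raw logits.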

View it on GitHub: https://github.com/TransformerLensOrg/TransformerLens/issues/693

neelnanda-io · Aug 02 '24 19:08

I tried from_pretrained_no_processing and got the same results. So it is more than the unembedding centering: the model activations already differ, and the difference grows with each layer.

import plotly.express as px
import torch

def forward_with_cache(model, layer, inputs):
    """Run the HF model and return the hidden state entering decoder layer `layer` (resid_pre)."""
    cache = None

    def hook(module, args, outputs):
        nonlocal cache
        cache = args[0]  # first positional input to the decoder layer: the residual stream
        return outputs

    hook_handle = model.model.layers[layer].register_forward_hook(hook)
    try:
        with torch.no_grad():
            _ = model(**inputs)
    finally:
        hook_handle.remove()  # always detach the hook, even if the forward pass fails

    return cache

resid_pre_diffs = []

for layer in range(tl_model.cfg.n_layers):
    hf_cache = forward_with_cache(hf_model, layer, inputs)
    _, tl_cache = tl_model.run_with_cache(inputs.input_ids, prepend_bos=False, names_filter=[f'blocks.{layer}.hook_resid_pre'])
    tl_cache = tl_cache[f'blocks.{layer}.hook_resid_pre']
    # L2 norm of the HF - TL difference at the last position
    resid_pre_diff = (hf_cache - tl_cache)[0, -1].norm().item()
    resid_pre_diffs.append(resid_pre_diff)

px.line(resid_pre_diffs, markers=True, labels={'index': 'Layer', 'value': 'norm of resid pre diff'}, title='Difference in resid_pre between HF and TL')
[Plot: Difference in resid_pre between HF and TL (norm of resid pre diff vs. layer)]

yeutong · Aug 02 '24 19:08

@yeutong The issue is caused by a different attention scale (~14.96 vs. 16). The HF implementation also disables attention-logit soft capping at inference, but that is less important:

for b in tl_model.blocks:
    b.attn.attn_scale = 16  # use the same attention scale as HF
    b.attn.cfg.attn_scores_soft_cap = 0  # disable soft capping, as HF does at inference
[Plot: resid_diff]
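To get a feel for why the scale matters more than the soft cap, here is a toy single-head sketch with random queries and keys. It assumes scores are divided by `attn_scale` and, when capping is on, squashed as `cap * tanh(scores / cap)`; the cap value 50.0 and head size 256 are assumptions for illustration, not values read from the model config, and `attention_pattern` is a hypothetical helper, not a TransformerLens function.

```python
import torch

torch.manual_seed(0)

def attention_pattern(q, k, attn_scale, soft_cap=0.0):
    """Causal attention pattern for one head; q, k have shape [seq, d_head]."""
    scores = q @ k.T / attn_scale
    if soft_cap > 0:
        # Soft capping: smoothly bound scores to (-soft_cap, soft_cap).
        scores = soft_cap * torch.tanh(scores / soft_cap)
    seq = q.shape[0]
    causal_mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    return scores.softmax(dim=-1)

d_head = 256  # assumed head size for illustration
q = torch.randn(5, d_head)
k = torch.randn(5, d_head)

p_scale_16 = attention_pattern(q, k, attn_scale=16.0)
p_scale_1496 = attention_pattern(q, k, attn_scale=14.96)
p_capped = attention_pattern(q, k, attn_scale=16.0, soft_cap=50.0)

scale_diff = (p_scale_16 - p_scale_1496).abs().max().item()
cap_diff = (p_scale_16 - p_capped).abs().max().item()
print(f"max pattern diff, scale 16 vs 14.96: {scale_diff:.2e}")
print(f"max pattern diff, no cap vs cap 50:  {cap_diff:.2e}")
```

With scores of order one, tanh is nearly linear, so the cap barely moves the pattern, while a ~7% change in the scale rescales every score.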

There is still some difference in the activations, but it is on the order of 5e-4 at the last layer. This remaining difference is probably a deeper issue related to the use of einsum in the attention.
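einsum and an explicit permute-plus-matmul compute the same contraction, so any gap between them comes from floating-point accumulation order, which depends on the backend, device, and dtype. A quick sketch for measuring that gap on your own setup; the `[batch, pos, head, d_head]` shapes are illustrative assumptions, not TransformerLens internals.

```python
import torch

torch.manual_seed(0)

# Illustrative queries and keys: [batch, pos, head, d_head].
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)

# Attention scores [batch, head, query_pos, key_pos] via einsum...
scores_einsum = torch.einsum("bqhd,bkhd->bhqk", q, k)

# ...and via permute + batched matmul: [b, h, q, d] @ [b, h, d, k].
scores_matmul = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)

max_diff = (scores_einsum - scores_matmul).abs().max().item()
print(f"max |einsum - matmul|: {max_diff:.2e}")
```

On CPU in float32 the two may agree exactly or only at rounding level; whether the gap grows large enough to matter is something to measure on the actual device and dtype, not something this sketch shows.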

mntss · Aug 07 '24 08:08

This also seems to be an issue with Llama3.2 1B Instruct. Running the code above with the Gemma model replaced by Llama, I get the following output:
[Plot: Difference in resid_pre between HF and TL]

Vedaant-J · Oct 27 '24 10:10

I think I was able to localize the error in Llama3.2-1B to the positional embedding: the cos and sin vectors in TransformerLens have different values from those in the HF implementation (LlamaRotaryEmbedding). When I copy exactly the same values over, the difference drops to roughly 1e-5.

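Rotary-embedding cos and sin tables are cheap to recompute, so one way to look for this kind of mismatch is to build them two ways and diff them. The sketch below assumes standard RoPE with head size 64 and base 500000 (illustrative values, not read from any config), and it omits the extra frequency scaling HF applies when the model config sets `rope_scaling`, so it is not a drop-in replacement for either implementation. It only compares float32 against float64 to show how precision alone can make two implementations of the same tables disagree slightly; `rope_cos_sin` is a hypothetical helper.

```python
import torch

def rope_cos_sin(seq_len, dim, base, dtype):
    """Standard (unscaled) RoPE cos/sin tables, each of shape [seq_len, dim]."""
    # One inverse frequency per pair of dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
    positions = torch.arange(seq_len, dtype=dtype)
    freqs = torch.outer(positions, inv_freq)  # [seq_len, dim // 2]
    emb = torch.cat([freqs, freqs], dim=-1)   # [seq_len, dim]
    return emb.cos(), emb.sin()

# Assumed values for illustration only.
cos32, sin32 = rope_cos_sin(seq_len=2048, dim=64, base=500000.0, dtype=torch.float32)
cos64, sin64 = rope_cos_sin(seq_len=2048, dim=64, base=500000.0, dtype=torch.float64)

cos_diff = (cos32.double() - cos64).abs().max().item()
sin_diff = (sin32.double() - sin64).abs().max().item()
print(f"max |cos diff|: {cos_diff:.2e}")
print(f"max |sin diff|: {sin_diff:.2e}")
```

The error grows with position, because a small relative error in `inv_freq` is multiplied by the position index before taking cos and sin.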

yanivnik · Oct 30 '24 17:10

@yanivnik What version of TransformerLens were you using? The latest version should have changed that, though that does not mean it is now 100% accurate. It would help to know exactly which version this was generated with.

bryce13950 · Nov 03 '24 22:11

Hi! After all the einsum removals over the past two weeks, I took another look at this, and the logits now seem to match very well:

=== Logits Comparison ===
TransformerLens:
Mean: tensor(-7.1663, device='mps:0', grad_fn=<MeanBackward0>)
Std: tensor(4.3232, device='mps:0', grad_fn=<StdBackward0>)
HuggingFace:
Mean: tensor(-7.1663, device='mps:0', grad_fn=<MeanBackward0>)
Std: tensor(4.3232, device='mps:0', grad_fn=<StdBackward0>)
Max diff: tensor(6.6757e-05, device='mps:0', grad_fn=<MaxBackward1>)
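The script behind this output is not shown. A hypothetical helper that prints a similar summary, and additionally reports the log-prob difference (which, unlike the raw logits, ignores a per-position shift such as the one from centering the unembedding), might look like this:

```python
import torch

def compare_logits(logits_a, logits_b, name_a="TransformerLens", name_b="HuggingFace"):
    """Print summary stats for two logit tensors; return (max logit diff, max log-prob diff)."""
    logits_a = logits_a.detach().float()
    logits_b = logits_b.detach().float()
    print("=== Logits Comparison ===")
    for name, t in [(name_a, logits_a), (name_b, logits_b)]:
        print(f"{name}:\n  Mean: {t.mean().item():.4f}\n  Std: {t.std().item():.4f}")
    max_diff = (logits_a - logits_b).abs().max().item()
    # Log-probs cancel any constant added to all logits at one position.
    logprob_diff = (logits_a.log_softmax(-1) - logits_b.log_softmax(-1)).abs().max().item()
    print(f"Max diff: {max_diff:.3e}")
    print(f"Max log-prob diff: {logprob_diff:.3e}")
    return max_diff, logprob_diff

# Toy usage with random tensors standing in for model outputs.
torch.manual_seed(0)
toy_a = torch.randn(1, 4, 10)
toy_b = toy_a + 1e-5 * torch.randn(1, 4, 10)
compare_logits(toy_a, toy_b)
```

With real models, the call would be `compare_logits(logits_tl, logits_hf)` on the tensors from the code at the top of the issue.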

degenfabian · Dec 13 '24 01:12

Hi @yanivnik, thank you for pointing out the cos/sin mismatch in LlamaRotaryEmbedding! I ran into exactly the same issue with Llama3.2-1B. Could you share the details of how you fixed it?

Thank you!

chengjiali · Oct 30 '25 17:10

Hey! How are you loading Llama? When you load it with from_pretrained_no_processing, the residual stream values should be very close to HF's.

degenfabian · Oct 30 '25 23:10