
[Question] How to load a model in smaller precision?

MittelmanDaniel opened this issue Nov 18 '24 · 5 comments

Question

Is doing

model = HookedTransformer.from_pretrained_no_processing(
    model_name="google/gemma-2-2b-it",
    device=device,
    dtype=torch.bfloat16,
    default_padding_side="left"
)

enough to load the model in bfloat16? The model loads fine directly from Hugging Face, but not through TransformerLens.
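For reference, this is roughly how I load it directly through Hugging Face, which works on the same GPU (a sketch; device_map="cuda" assumes the accelerate package is installed):

from transformers import AutoModelForCausalLM
import torch

hf_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    torch_dtype=torch.bfloat16,  # weights come in as bfloat16 directly
    device_map="cuda",           # device placement handled by accelerate
)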

— MittelmanDaniel, Nov 18 '24

What error did you receive when you ran this in TransformerLens?

— bryce13950, Nov 19 '24

@bryce13950

I received the standard CUDA out-of-memory error.

For context, I have an RTX 3060 Laptop GPU.

— MittelmanDaniel, Nov 19 '24

I am running this script


from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformer_lens import HookedTransformer

model_name = "google/gemma-2-2b-it"

# Load the tokenizer (kept separate to show that only the model line changes)
tokenizer = AutoTokenizer.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model -- this is the only line that differs from the
# AutoModelForCausalLM version of the script
model = HookedTransformer.from_pretrained_no_processing(
    model_name=model_name,
    device=device,
    dtype=torch.bfloat16,
    default_padding_side="left"
)

# Example input text
input_text = "Once upon a time"

# Tokenize the input
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

# Generate output. Note that HookedTransformer.generate takes the tokens
# positionally and uses max_new_tokens; it does not accept input_ids,
# max_length, or num_return_sequences like Hugging Face's generate does.
output_sequences = model.generate(
    inputs,
    max_new_tokens=50,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)

# Decode and print the generated text
generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print(generated_text)
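The OOM is raised inside from_pretrained_no_processing itself, before generation is ever reached. If the load did succeed, a quick way to see how close it comes to the card's limit would be something like this sketch:

# (Sketch) Report peak GPU memory right after the model loads.
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")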

The only change I made was using HookedTransformer in place of the standard Hugging Face transformer.

Here is a comparison of the outputs:

[two screenshots comparing the Hugging Face output with the TransformerLens result]

— MittelmanDaniel, Nov 19 '24

You should be able to load that model without issues. If you are getting memory errors, my best guess is that something else running on the machine is leaving TransformerLens without enough memory. Looking at your code, I am curious: is there a reason you are creating your own tokenizer? That alone shouldn't account for the difference, but it does add more overhead than needed. Also, make sure you are using the most recent version of TransformerLens; a pretty large memory issue was fixed a few months back, so an older version could be part of the problem.
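For a rough sense of the headroom involved: gemma-2-2b has about 2.6B parameters, so its bfloat16 weights alone take roughly 5.2 GB, and an RTX 3060 Laptop GPU typically has 6 GB of VRAM, so any extra allocation during loading can tip it over. A quick sanity check of the installed version and of what is already occupying the GPU (a sketch):

# (Sketch) Print the installed TransformerLens version and the GPU memory
# already in use before the model is loaded.
from importlib.metadata import version
import torch

print(version("transformer-lens"))  # distribution name on PyPI
print(f"{torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")
print(f"{torch.cuda.memory_reserved() / 1e9:.2f} GB reserved")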

— bryce13950, Nov 21 '24

I had an external tokenizer just to showcase that I was only changing that one line.

I was using TransformerLens 2.9.0.

— MittelmanDaniel, Dec 03 '24