[Question] How to load a model in smaller precision?
Question
Is doing
model = HookedTransformer.from_pretrained_no_processing(
    model_name="google/gemma-2-2b-it",
    device=device,
    dtype=torch.bfloat16,
    default_padding_side="left"
)
enough to load a model in bfloat16? The model loads fine directly from Hugging Face, but not through TransformerLens.
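For comparison, a direct load through transformers in bfloat16 would look roughly like the sketch below; the torch_dtype argument and the .to(device) call are assumptions about that setup, not copied from it.

from transformers import AutoModelForCausalLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Direct transformers load of the same checkpoint in bfloat16 (sketch)
hf_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    torch_dtype=torch.bfloat16,
).to(device)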
What error did you receive when you ran this in TransformerLens?
@bryce13950
I received the standard CUDA OutOfMemory error.
For context I have an RTX 3060 Laptop.
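As a quick sanity check on available memory, the total and already-used VRAM can be read with standard torch.cuda calls; a minimal diagnostic sketch:

import torch

# Report total, allocated, and reserved GPU memory in GiB
props = torch.cuda.get_device_properties(0)
print(f"Total VRAM: {props.total_memory / 1024**3:.1f} GiB")
print(f"Allocated:  {torch.cuda.memory_allocated(0) / 1024**3:.1f} GiB")
print(f"Reserved:   {torch.cuda.memory_reserved(0) / 1024**3:.1f} GiB")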
I am running this script:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformer_lens import HookedTransformer
model_name = "google/gemma-2-2b-it"
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the model
model = HookedTransformer.from_pretrained_no_processing(
    model_name="google/gemma-2-2b-it",
    device=device,
    dtype=torch.bfloat16,
    default_padding_side="left"
)
# Example input text
input_text = "Once upon a time"
# Tokenize the input
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
# Generate output
output_sequences = model.generate(
    input_ids=inputs,
    max_length=50,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    num_return_sequences=1
)
# Decode and print the generated text
generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print(generated_text)
The only change I made was using HookedTransformer in place of the standard transformers model.
Here is a comparison of the outputs
You should be able to load that model without issues. If you are getting memory errors, my best guess is that something else is running that is leaving TransformerLens without enough memory. Looking at your code, I am curious: is there a reason you are creating your own tokenizer? That shouldn't account for everything, but it is definitely adding more overhead than needed. Also, make sure you are using the most recent version of TransformerLens. A pretty large memory issue was fixed a few months back, so if you have an older version, that could also be part of the issue.
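In case it is useful, here is a minimal sketch of the same generation without a separate tokenizer, relying on the tokenizer that HookedTransformer loads itself; the parameter values are copied from the script above, and the string-input form of generate() is assumed to be acceptable for this use case.

from transformer_lens import HookedTransformer
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# HookedTransformer loads its own tokenizer, so no AutoTokenizer is needed
model = HookedTransformer.from_pretrained_no_processing(
    model_name="google/gemma-2-2b-it",
    device=device,
    dtype=torch.bfloat16,
    default_padding_side="left",
)

# generate() accepts a string prompt and tokenizes it internally
output = model.generate(
    "Once upon a time",
    max_new_tokens=50,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)
print(output)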
I had an external tokenizer just to showcase that I was only changing that one line.
I was using version 2.9.0.
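For reference, one way to confirm which version is actually installed (a sketch assuming a pip-based install whose distribution name is transformer-lens); upgrading would then be a matter of pip install --upgrade transformer-lens.

import importlib.metadata

# Print the installed TransformerLens version (distribution name assumed to be "transformer-lens")
print(importlib.metadata.version("transformer-lens"))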