[Bug Report] loading gemma-7b-it runs out of memory
Describe the bug
When loading gemma-7b-it on a machine with 2 H100 GPUs from a Jupyter notebook, the load still fails with an out-of-memory error.
Code example
```python
import torch
from transformer_lens import HookedTransformer

n_devices = 2
model = HookedTransformer.from_pretrained_no_processing(
    "google/gemma-7b-it",
    center_writing_weights=False,
    fold_value_biases=False,
    fold_ln=False,
    center_unembed=False,
    default_padding_side="left",
    default_prepend_bos=False,
    device="cuda",
    dtype=torch.bfloat16,
    n_devices=n_devices,
)
```
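For reference, a minimal per-device memory check (plain torch.cuda bookkeeping, nothing specific to transformer-lens) can be run right after the load attempt to see whether the weights are actually being spread across both H100s or piling up on one device:

```python
import torch

# Report how much memory each visible GPU currently holds.
for i in range(torch.cuda.device_count()):
    allocated = torch.cuda.memory_allocated(i) / 1e9
    reserved = torch.cuda.memory_reserved(i) / 1e9
    total = torch.cuda.get_device_properties(i).total_memory / 1e9
    print(f"cuda:{i}: allocated {allocated:.1f} GB, reserved {reserved:.1f} GB, total {total:.1f} GB")
```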
System Info
Describe the characteristics of your environment:
- How `transformer_lens` was installed: conda
- What OS are you using: Linux
- Python version: 3.11.5
torch==2.4.1
torch-scatter==2.1.2
torchaudio==2.4.1
torchmetrics==1.2.0
transformer-lens==2.13.0
transformers==4.45.1
Additional context
- I tried using `from_pretrained` instead of `from_pretrained_no_processing`.
- I tried passing the model as `hf_model` to `from_pretrained` (see the sketch below).
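A minimal sketch of that `hf_model` route, assuming the HF weights are loaded on CPU first and with the processing flags from the original call omitted for brevity:

```python
import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

# Load the HF weights on CPU first, then hand them to TransformerLens so it
# only has to place the converted weights across the two GPUs.
hf_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.bfloat16,
)

model = HookedTransformer.from_pretrained(
    "google/gemma-7b-it",
    hf_model=hf_model,
    device="cuda",
    dtype=torch.bfloat16,
    n_devices=2,
)
```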
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
The current implementation of multi-GPU support is not very reliable. A new version with this entirely overhauled is coming up, but I have not specifically tested Gemma models, and it may still error out depending on the model architecture. If you have a moment to test this again after the next update is out, that would be very helpful. Fingers crossed it will work now.
@qwenzo the more robust multi-GPU support has now been added. I did not specifically test Gemma on multi-GPU before the update went up, so if you do try it, let me know if there are any issues. I will make sure to test the model family in the coming week.
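For anyone retesting after the update lands, a minimal multi-GPU smoke test could look like the following (the prompt is arbitrary and the call is not Gemma-specific):

```python
import torch
from transformer_lens import HookedTransformer

# Minimal smoke test: load across two devices and run one forward pass.
model = HookedTransformer.from_pretrained_no_processing(
    "google/gemma-7b-it",
    device="cuda",
    dtype=torch.bfloat16,
    n_devices=2,
)

logits = model("The capital of France is", return_type="logits")
print(logits.shape)  # (batch, seq_len, d_vocab) if the load and forward pass succeed
```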