
[Bug Report] loading gemma-7b-it runs out of memory

Open qwenzo opened this issue 10 months ago • 2 comments

Describe the bug
When loading gemma-7b-it on a machine with 2 H100 GPUs, loading still runs out of memory in a Jupyter notebook.

Code example

import torch
from transformer_lens import HookedTransformer

# Split the model across both GPUs.
n_devices = 2
model = HookedTransformer.from_pretrained_no_processing(
    "google/gemma-7b-it",
    center_writing_weights=False,
    fold_value_biases=False,
    fold_ln=False,
    center_unembed=False,
    default_padding_side="left",
    default_prepend_bos=False,
    device="cuda",
    dtype=torch.bfloat16,
    n_devices=n_devices,
)

System Info
Describe the characteristics of your environment:

  • Describe how transformer_lens was installed: conda
  • What OS are you using: Linux
  • Python version: 3.11.5
  • Relevant package versions:
    torch==2.4.1
    torch-scatter==2.1.2
    torchaudio==2.4.1
    torchmetrics==1.2.0
    transformer-lens==2.13.0
    transformers==4.45.1

Additional context

  • I tried using from_pretrained instead of from_pretrained_no_processing.
  • I tried passing the model as hf_model to from_pretrained (sketched below).
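
For reference, here is a minimal sketch of that second attempt; it assumes the standard transformers AutoModelForCausalLM loader, and the remaining keyword arguments mirror the ones above:

import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

# Load the Hugging Face model first, then hand it to TransformerLens.
hf_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.bfloat16,
)
model = HookedTransformer.from_pretrained(
    "google/gemma-7b-it",
    hf_model=hf_model,
    device="cuda",
    dtype=torch.bfloat16,
    n_devices=2,
)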

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

qwenzo • Feb 06 '25 15:02

The current implementation of multi-GPU support is not very reliable. There is a new version coming up in which this is entirely overhauled, but I have not specifically tested Gemma models, and it may still error out depending on the model architecture. If you have a moment to test this again after the next update is up, that would be very helpful. Fingers crossed it will work now.
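
When you do retest, a quick sanity check like the following (an untested sketch using standard torch.cuda calls) would confirm the weights are actually being split across both GPUs:

import torch

# After loading the model, report the memory allocated on each GPU; if the
# split worked, both devices should show a multi-GB allocation.
for i in range(torch.cuda.device_count()):
    allocated_gb = torch.cuda.memory_allocated(i) / 1e9
    print(f"cuda:{i}: {allocated_gb:.1f} GB allocated")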

bryce13950 • Feb 13 '25 00:02

@qwenzo the more robust multi-GPU support has now been added. I did not specifically test Gemma on multi-GPU before the update went up, so if you do try it, let me know if there are any issues. I will make sure to test the model family in the upcoming week.
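
If you retest, something along these lines (a sketch that assumes the loading code from the original report succeeded) should be enough to confirm the full forward pass works across both GPUs:

# Re-run the loading code from the report on the upgraded transformer-lens,
# then exercise the forward pass with a short generation.
output = model.generate("The capital of France is", max_new_tokens=10)
print(output)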

bryce13950 • Feb 15 '25 19:02