[Bug Report] Issues loading Llama-2
**Describe the bug**
When loading Llama-2, `HookedTransformer.from_pretrained` throws a "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!" error.
**Code example**
I try to load Llama-2 with the following code:
```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
```
and I get the following error:
```
File ~/miniconda3/envs/arena-env/lib/python3.11/site-packages/transformer_lens/HookedTransformer.py:1668, in HookedTransformer.fold_value_biases(self, state_dict)
   1666     # [d_model]
   1667     b_O_original = state_dict[f"blocks.{layer}.attn.b_O"]
-> 1668     folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
   1670     state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O
   1671     if self.cfg.n_key_value_heads is None:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
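The error suggests that inside `fold_value_biases`, one of the operands (presumably `b_V`, `W_O`, or `b_O_original`) is still on the CPU while the others have already been moved to cuda:0. Below is a minimal, self-contained sketch of the failure mode and the kind of fix (moving every operand to a common device before folding); the shapes are hypothetical and the names merely mirror the traceback:

```python
import torch

# Minimal reproduction of the device mismatch: mixing a CUDA tensor with a
# CPU tensor in one expression raises the same RuntimeError as the traceback.
# Shapes are hypothetical; names mirror fold_value_biases (b_V, W_O, b_O).
n_heads, d_head, d_model = 4, 8, 32
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

W_O = torch.randn(n_heads, d_head, d_model, device=dev)
b_O_original = torch.zeros(d_model, device=dev)
b_V = torch.randn(n_heads, d_head)  # left on CPU, as the error suggests

# Moving every operand to a common device before folding avoids the error:
b_V = b_V.to(W_O.device)
folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
print(folded_b_O.shape)  # torch.Size([32])
```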
I can work around this by forcing the device to CPU, loading the model, and then moving it to the GPU afterwards, but I should be able to load the model directly onto the GPU:
```python
import torch
from transformer_lens import HookedTransformer

# Load entirely on the CPU first...
device = torch.device('cpu')
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf", device=device)

# ...then move the whole model over to the GPU.
device = torch.device('cuda')
model.to(device)
```
This, however, just sidesteps the problem rather than fixing it.
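As a sanity check after the workaround, you can verify that every parameter actually ended up on the GPU. Continuing from the snippet above, this is plain PyTorch, nothing TransformerLens-specific:

```python
# Collect the set of devices across all parameters; after model.to('cuda')
# this should contain only CUDA devices.
devices = {p.device for p in model.parameters()}
print(devices)
assert all(d.type == "cuda" for d in devices), "some parameters are still on CPU"
```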
**System Info**
Describe the characteristics of your environment:

- `transformer_lens` installed via `pip`
- Ubuntu 22.04.2 LTS "Jammy Jellyfish"
- Python version 3.11 (hopefully that shouldn't be causing it?)
**Checklist**
- [X] I have checked that there is no similar issue in the repo (required)
I loaded an instruct-tuned version of Llama-2-7b onto a GPU and it seemed to take up 50+ GB. Is that normal?
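For what it's worth, 7 billion parameters at float32 take about 28 GB for the weights alone, and loading can temporarily hold more than one copy of the weights while the state dict is processed, so a 50+ GB peak is plausible. Below is a sketch of loading in half precision to roughly halve the weight memory; it assumes your installed `transformer_lens` version supports the `dtype` keyword of `from_pretrained`:

```python
import torch
from transformer_lens import HookedTransformer

# float16 weights use ~2 bytes/param, so ~14 GB for a 7B model
# (assumes from_pretrained accepts a dtype keyword in your version).
model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    dtype=torch.float16,
)
```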