Error loading SAE with SAElens library
Hello LLAMASCOPE team, thank you for your work! I'm having trouble loading the SAE with the SAElens library. I keep getting the error below:
File "/anaconda3/envs/mechinterp/lib/python3.12/site-packages/sae_lens/sae.py", line 616, in from_pretrained
cfg_dict, state_dict, log_sparsities = conversion_loader(
^^^^^^^^^^^^^^^^^^
File "/anaconda3/envs/mechinterp/lib/python3.12/site-packages/sae_lens/toolkit/pretrained_sae_loaders.py", line 503, in llama_scope_sae_loader
state_dict_loaded = load_file(sae_path, device=device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/anaconda3/envs/mechinterp/lib/python3.12/site-packages/safetensors/torch.py", line 313, in load_file
with safe_open(filename, framework="pt", device=device) as f:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
safetensors_rust.SafetensorError: device cuda is invalid
The loading script I use is:
from sae_lens import SAE
sae, cfg_dict, sparsity = SAE.from_pretrained(release="llama_scope_lxm_8x", sae_id="l16m_8x", device=device)
where device is cuda.
This problem doesn't appear with the GemmaScope SAEs. Do you have any idea what might cause it? Thanks in advance!
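A minimal sketch of one thing to try, not a confirmed fix from this thread. The assumption is that device might be a torch.device object rather than the plain string "cuda". SAE.from_pretrained types its device argument as a string, and when the Python-side value is printed it reads "cuda" in both cases, so the error message alone doesn't tell them apart. The helper names as_device_string and load_llama_scope_sae are hypothetical, and the tuple return mirrors the sae-lens 4.x/5.x call shown above.

```python
from typing import Any


def as_device_string(device: Any) -> str:
    """Return a plain device string such as "cuda" or "cuda:0".

    Assumption (hypothetical helper): the loader expects a string, so a
    torch.device is converted to its string form; a bare int is treated
    as a CUDA device index.
    """
    if isinstance(device, str):
        return device
    if isinstance(device, int) and not isinstance(device, bool):
        return f"cuda:{device}"
    # str(torch.device("cuda")) == "cuda"; str(torch.device("cuda", 0)) == "cuda:0"
    return str(device)


def load_llama_scope_sae(device: Any = "cuda"):
    """Hypothetical wrapper around the loading call from the report above."""
    # Imported lazily so this module can be defined without sae_lens installed.
    from sae_lens import SAE

    # Same release and sae_id as in the report, with the device forced to a str.
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release="llama_scope_lxm_8x",
        sae_id="l16m_8x",
        device=as_device_string(device),
    )
    return sae, cfg_dict, sparsity
```

Usage: sae, cfg_dict, sparsity = load_llama_scope_sae(device). If the error persists with a plain string, the cause lies elsewhere.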
Hi, we have received several reports of this issue recently. We will look into it within 24 hours.
I ran your script in two different environments and could not reproduce this error. Could you please provide more details?
Running pip list gives me:
sae-lens 4.4.5
safetensors 0.4.5
torch 2.5.1
... (Other packages omitted)
Hi, my environment is:
sae-lens 5.1.0
safetensors 0.4.5
torch 2.5.1
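To compare just the relevant packages between environments without scanning the full pip list output, here is a short sketch using the standard library's importlib.metadata. The helper name package_versions is hypothetical, and the package list is simply the three names reported above.

```python
from importlib.metadata import PackageNotFoundError, version

# The three distributions compared in this thread.
PACKAGES = ("sae-lens", "safetensors", "torch")


def package_versions(names=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in package_versions().items():
        print(name, ver if ver is not None else "(not installed)")
```

Run it inside each environment (for example the mechinterp conda env) and compare the printed lines side by side.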