Easy-Transformer
[Bug Report] Issues with PosEmbed device when used with accelerate
Describe the bug
PosEmbed uses the wrong device when the model is used with the accelerate library.
Code example
The following code is a minimal training loop that trains the gpt2 model on randomly generated data.
import torch
from transformer_lens import HookedTransformer
from accelerate import Accelerator
from tqdm import tqdm
import sys
print(f"Python version: {sys.version}")
accelerator = Accelerator()
print(f"Running on device: {accelerator.device}")
model_name = "gpt2"
model = HookedTransformer.from_pretrained(model_name)
# Random token data with right padding: each row keeps 90-99 real tokens.
tokens = torch.randint(0, model.tokenizer.vocab_size, (64, 100))
random_lengths = torch.randint(90, 100, (64,))
attention_mask = torch.ones_like(tokens)
for i in range(64):
    attention_mask[i, random_lengths[i]:] = 0
dataset = torch.utils.data.TensorDataset(tokens, attention_mask)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
for batch in tqdm(dataloader):
    tokens, attention_mask = batch
    # Raises RuntimeError inside PosEmbed on rank 1 under accelerate launch.
    loss = model.forward(tokens, attention_mask=attention_mask, return_type="loss")
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    avg_loss = accelerator.gather(loss).mean().item()
    if accelerator.is_main_process:
        print(f"Loss: {avg_loss}")
When this code is run with python minimal_example.py, it works fine.
When it is run with CUDA_VISIBLE_DEVICES=1,2 accelerate launch minimal_example.py, I get the following error.
(py310) guest@All:~/david_quarel/advint$ CUDA_VISIBLE_DEVICES=1,2 accelerate launch minimal_inference.py
The following values were not passed to `accelerate launch` and had defaults used instead:
`--num_processes` was set to a value of `2`
More than one GPU was found, enabling multi-GPU training.
If this was unintended please pass in `--num_processes=1`.
`--num_machines` was set to a value of `1`
`--mixed_precision` was set to a value of `'no'`
`--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
Python version: 3.10.17 | packaged by conda-forge | (main, Apr 10 2025, 22:19:12) [GCC 13.3.0]
Python version: 3.10.17 | packaged by conda-forge | (main, Apr 10 2025, 22:19:12) [GCC 13.3.0]
Running on device: cuda:0
Running on device: cuda:1
Loaded pretrained model gpt2 into HookedTransformer
Moving model to device: cuda
Loaded pretrained model gpt2 into HookedTransformer
Moving model to device: cuda
0%| | 0/4 [00:00<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]: File "/workspace/HOME/guest/david_quarel/advint/minimal_inference.py", line 32, in <module>
[rank1]: loss = model.forward(tokens, attention_mask=attention_mask, return_type="loss")
[rank1]: File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank1]: else self._run_ddp_forward(*inputs, **kwargs)
[rank1]: File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank1]: return self.module(*inputs, **kwargs) # type: ignore[index]
[rank1]: File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/workspace/HOME/guest/.conda/envs/py310/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 583, in forward
[rank1]: ) = self.input_to_embed(
[rank1]: File "/workspace/HOME/guest/.conda/envs/py310/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 410, in input_to_embed
[rank1]: residual, shortformer_pos_embed = self.get_residual(
[rank1]: File "/workspace/HOME/guest/.conda/envs/py310/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 302, in get_residual
[rank1]: self.pos_embed(tokens, pos_offset, attention_mask)
[rank1]: File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/workspace/HOME/guest/.conda/envs/py310/lib/python3.10/site-packages/transformer_lens/components/pos_embed.py", line 58, in forward
[rank1]: pos_embed = self.W_pos[offset_position_ids] # [batch, pos, d_model]
[rank1]: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
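The error comes from PyTorch advanced indexing: the index tensor (here offset_position_ids) must be on the CPU or on the same device as W_pos. Both processes log "Moving model to device: cuda", which suggests (unconfirmed) that the position ids are built on the device named in the model config, i.e. cuda:0, while accelerate has moved W_pos to cuda:1 on rank 1. The sketch below is not TransformerLens's code; it shows a device-safe lookup under stated assumptions. The helpers offset_position_ids and lookup_pos_embed are hypothetical, and the cumsum-based position-id scheme is an assumption about how padding is handled. An untested workaround worth trying is passing device=accelerator.device to HookedTransformer.from_pretrained so that the config device matches each process's GPU.

```python
import torch


def offset_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: position ids that count only unmasked tokens.

    Assumes a cumsum-based scheme; masked slots before the first real token clamp to 0.
    """
    ids = attention_mask.long().cumsum(dim=1) - 1
    return ids.clamp(min=0)


def lookup_pos_embed(W_pos: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index W_pos [n_ctx, d_model] with per-token position ids."""
    ids = offset_position_ids(attention_mask)
    # Advanced indexing needs the indices on the CPU or on W_pos's device.
    # Moving them explicitly avoids "indices should be either on cpu or on
    # the same device as the indexed tensor".
    ids = ids.to(W_pos.device)
    return W_pos[ids]  # [batch, pos, d_model]


if __name__ == "__main__":
    # Use a second GPU when available to mimic rank 1; otherwise fall back to CPU.
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
    W_pos = torch.randn(8, 4, device=device)
    mask = torch.tensor([[1, 1, 1, 0], [0, 1, 1, 1]])  # mask created on the CPU
    print(lookup_pos_embed(W_pos, mask).shape)
```

Because the ids are moved to W_pos.device right before indexing, the lookup does not depend on where the attention mask was created.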
System Info
Describe the characteristics of your environment:
- transformer_lens version 2.15.0, installed via pip
- OS: Linux
- Python version: 3.10.17
Additional context
I'm trying to train an SAE on a transformer_lens model with multiple GPUs using accelerate, which is how I ran into this bug.
Checklist
- [X] I have checked that there is no similar issue in the repo (required)