Easy-Transformer
[Bug Report] Load model to multiple devices
When I use the following code to load Llama-2 and generate:
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                          hf_model=hf_model,
                                          device="cuda",
                                          n_devices=4,
                                          move_to_device=True,
                                          fold_ln=False,
                                          center_writing_weights=False,
                                          center_unembed=False,
                                          tokenizer=tokenizer)
model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)
I got an error:
mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
I found that when I load the model onto multiple devices, the attention mask matrix is always on cuda:0, which raises the error above. So I made the following change in the forward function of the attention module:
def forward(
    self,
    query_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    key_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    value_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
    additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """
    shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details.
    past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None.
    additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
    attention_mask is the attention mask for padded tokens. Defaults to None.
    """
    # Move the attention masks to the device this attention block is on
    # (self.rotary_sin lives on the block's device, so it serves as a reference).
    if additive_attention_mask is not None and additive_attention_mask.device != self.rotary_sin.device:
        additive_attention_mask = additive_attention_mask.to(self.rotary_sin.device)
    if attention_mask is not None and attention_mask.device != self.rotary_sin.device:
        attention_mask = attention_mask.to(self.rotary_sin.device)
and it works well.
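As a side note, a quick generic PyTorch check (plain nn.Module introspection, not a TransformerLens API) shows which devices the weights and buffers actually landed on after loading with n_devices=4:

# List every device the model's parameters and buffers ended up on after multi-device loading.
used_devices = {p.device for p in model.parameters()}
used_devices |= {b.device for b in model.buffers()}
print(sorted(str(d) for d in used_devices))  # e.g. ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3']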
Thanks - feel free to submit a PR for this!
Hi, I'm also using n_devices > 1 and I've found this bug in many other parts of the code.
Some examples:
https://github.com/neelnanda-io/TransformerLens/blob/174209ea708fe3838ccf08b70f2f4f28e7397cb4/transformer_lens/components.py#L676
https://github.com/neelnanda-io/TransformerLens/blob/174209ea708fe3838ccf08b70f2f4f28e7397cb4/transformer_lens/components.py#L754-L755
https://github.com/neelnanda-io/TransformerLens/blob/174209ea708fe3838ccf08b70f2f4f28e7397cb4/transformer_lens/utils.py#L696-L698
https://github.com/neelnanda-io/TransformerLens/blob/174209ea708fe3838ccf08b70f2f4f28e7397cb4/transformer_lens/ActivationCache.py#L453
In general, it breaks every time it has to do an operation between two tensors that are stored on different GPUs.
It seems like a more structural issue than just fixing those cases with a .to(X.device). Or am I missing something?
Thanks!
It's a bit messy. In my opinion the crucial thing is that the model runs, so fixing bugs 1 and 2 seems important.
I'm in general kinda fine with some utilities in the library assuming things all live on one device, which is what happened with bugs 3 and 4. Bug 3 is in utils.test_prompt and should be easy to fix by e.g. moving the values to the CPU. Bug 4 is a messier problem, because it's trying to stack activations across devices. It would be easy to fix by adding something to move the activations to the same device, but that might give you out-of-memory errors. One option is to have a method on the cache that moves all activations to the same device (or to the CPU, which should probably be the default).
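A minimal sketch of what such a cache method could look like, written as a standalone helper over a plain dict of activations (the name move_cache_to_device is hypothetical, and defaulting to CPU matches the suggestion above; this is not the actual ActivationCache API):

import torch

def move_cache_to_device(
    cache_dict: dict[str, torch.Tensor], device: str = "cpu"
) -> dict[str, torch.Tensor]:
    """Return a new dict with every cached activation on one device,
    so stacking activations across layers no longer mixes GPUs."""
    return {name: act.to(device) for name, act in cache_dict.items()}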
Generally, multi-device is something I want the library's core to support, but I'm fine with some features breaking on it, since it seems costly to make everything robust to multi-device setups. But if anything specific annoys you and is easy to fix, please send a PR!
I've encountered the same issue. Instead of changing things at various places in the attention block, I added these four lines to HookedTransformer.forward (roughly L546):
if attention_mask is not None:
    attention_mask = attention_mask.to(
        devices.get_device_for_block_index(i, self.cfg)
    )
I personally think this is slightly more parsimonious.
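For context, a simplified sketch of roughly where that snippet sits inside the block loop of HookedTransformer.forward (the real forward passes more arguments; only the placement of the device move matters here, and devices refers to the same module used in the snippet above):

# Simplified sketch, not the actual loop body.
for i, block in enumerate(self.blocks):
    if attention_mask is not None:
        # Keep the padding mask on the same device as the block consuming it.
        attention_mask = attention_mask.to(
            devices.get_device_for_block_index(i, self.cfg)
        )
    residual = block(residual, attention_mask=attention_mask)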
I wanted to check whether any progress/PRs have been made on this issue. I am running into this error as well.
Not yet, I'm afraid.
There's a task involved here to remove most of the manual device setting throughout the codebase (e.g. .to(device=) and tensor([], device=)), as torch handles most of this by default with e.g. model.to(). Then getting multi-GPU support to work everywhere should be v. easy.
I'll peel this off if no-one else does over the next few weeks - it's also a bottleneck for distributed SAE training.
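As a minimal illustration of that pattern (a toy module, not TransformerLens code): tensors registered as buffers move with the module on .to()/.cuda(), so no hard-coded device= arguments are needed when they are created.

import torch
import torch.nn as nn

class RotaryDemo(nn.Module):
    """Toy module: buffers follow the module when .to() is called,
    so no manual device= is needed at tensor-creation time."""
    def __init__(self, n_ctx: int = 16, d_head: int = 8):
        super().__init__()
        angles = torch.outer(torch.arange(n_ctx), torch.arange(d_head)).float()
        self.register_buffer("rotary_sin", torch.sin(angles))
        self.register_buffer("rotary_cos", torch.cos(angles))

demo = RotaryDemo()
if torch.cuda.is_available():
    demo = demo.to("cuda")
    print(demo.rotary_cos.device)  # cuda:0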
> There's a task involved here to remove most of the manual device setting throughout the codebase (e.g. .to(device=) and tensor([], device=)), as torch handles most of this by default with e.g. model.to(). Then getting multi-GPU support to work everywhere should be v. easy.
In this case, as I understand it, that would not support distributing a model across several GPUs?
I had similar problems when I tried to load llama-13b on two GTX 1080s. This worked with quantization (#486), but then I got similar issues to those described above: tensors on different devices.
I made some fixes to move the tensors to the same device: https://github.com/coolvision/TransformerLens/commit/5b250b30abcbf1c4f1d482759082d897e2ef2843
With these fixes, inference works with the model distributed across 2 GPUs:
model_name = "meta-llama/Llama-2-13b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=inference_dtype,
                                             device_map="auto",
                                             load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model_tl = HookedTransformer.from_pretrained(model_name,
                                             hf_model=model,
                                             dtype=inference_dtype,
                                             fold_ln=False,
                                             n_devices=2,
                                             fold_value_biases=False,
                                             center_writing_weights=False,
                                             center_unembed=False,
                                             tokenizer=tokenizer)
model_tl.generate("The capital of Germany is", max_new_tokens=2, temperature=0)
>'The capital of Germany is Berlin.'
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 Off | N/A |
| 29% 44C P8 10W / 250W | 6136MiB / 11264MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce ... Off | 00000000:02:00.0 Off | N/A |
| 25% 41C P8 9W / 250W | 5312MiB / 11264MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
I can add these fixes to my quantization PR #486, or make a separate PR.