[generate] fix eos/pad id check on mps devices
What does this PR do?
Generation currently fails on main for mps devices:
from transformers.models.gemma2 import Gemma2ForCausalLM, Gemma2Config
import torch
config = Gemma2Config(num_hidden_layers=1, vocab_size=128, hidden_size=16, intermediate_size=32, num_attention_heads=1, num_key_value_heads=1)
model = Gemma2ForCausalLM(config).to("mps")
input_ids = torch.ones((1, 10), dtype=torch.int).to("mps")
model.generate(input_ids, attention_mask=input_ids)
Traceback
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[2], line 8
5 model = Gemma2ForCausalLM(config).to("mps")
7 input_ids = torch.ones((1, 10), dtype=torch.int).to("mps")
----> 8 model.generate(input_ids, attention_mask=input_ids)
File ~/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/transformers/src/transformers/generation/utils.py:1664, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
1661 batch_size = inputs_tensor.shape[0]
1663 device = inputs_tensor.device
-> 1664 self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
1666 # decoder-only models must use left-padding for batched generation.
1667 if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
1668 # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
1669 # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
File ~/transformers/src/transformers/generation/utils.py:1513, in GenerationMixin._prepare_special_tokens(self, generation_config, kwargs_has_attention_mask, device)
1510 logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
1512 # we can't infer attn mask if pad token is set to be eos token in model's generation config
-> 1513 if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
1514 if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
1515 logger.warning_once(
1516 "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
1517 "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
1518 "to obtain reliable results."
1519 )
NotImplementedError: The operator 'aten::isin.Tensor_Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
=> This is due to the torch.isin operator not being implemented for the mps device in PyTorch. This PR removes the torch.isin operator from the main body of generation, while keeping compatibility with the eos/pad checks added in #31254.
Following this PR, Gemma-2 (and other generate-compatible models) can be run on mps.
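For illustration only, here is a minimal sketch of one way to express the eos/pad membership check without torch.isin, using a broadcast equality comparison followed by any. The helper name isin_via_broadcast is hypothetical and is not claimed to be the code in this PR; it also assumes that the elementwise == and any ops are available on the target backend. The example runs on CPU so it can be tried anywhere.

```python
import torch


def isin_via_broadcast(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for torch.isin(elements, test_elements) on tensor inputs.

    Uses only broadcasting `==` and `any`, so it does not dispatch to the
    `aten::isin` operator that raised NotImplementedError in the traceback.
    """
    # Flatten the test set so 0-dim inputs (a single pad id) also work.
    flat_tests = test_elements.reshape(-1)
    # Compare every element with every test value: shape (*elements.shape, n_tests).
    matches = elements.unsqueeze(-1) == flat_tests
    # An element is "in" the set if it equals at least one test value.
    return matches.any(dim=-1)


# Same shape of check as in `_prepare_special_tokens`: is any eos id equal to the pad id?
eos_token_id = torch.tensor([1, 2])
pad_token_id = torch.tensor(2)
pad_is_eos = isin_via_broadcast(eos_token_id, pad_token_id).any()
print(bool(pad_is_eos))
```

The comparison materializes a tensor of size len(elements) x len(test_elements), which is negligible for the handful of special token ids involved here but would not be a good general replacement for large inputs.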