
[generate] fix eos/pad id check on mps devices

Open · sanchit-gandhi opened this pull request 6 days ago · 1 comment

What does this PR do?

Generation currently fails on main for mps devices:

from transformers.models.gemma2 import Gemma2ForCausalLM, Gemma2Config
import torch

config = Gemma2Config(num_hidden_layers=1, vocab_size=128, hidden_size=16, intermediate_size=32, num_attention_heads=1, num_key_value_heads=1)
model = Gemma2ForCausalLM(config).to("mps")

input_ids = torch.ones((1, 10), dtype=torch.int).to("mps")
model.generate(input_ids, attention_mask=input_ids)
Traceback
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[2], line 8
      5 model = Gemma2ForCausalLM(config).to("mps")
      7 input_ids = torch.ones((1, 10), dtype=torch.int).to("mps")
----> 8 model.generate(input_ids, attention_mask=input_ids)

File ~/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/transformers/src/transformers/generation/utils.py:1664, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1661 batch_size = inputs_tensor.shape[0]
   1663 device = inputs_tensor.device
-> 1664 self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
   1666 # decoder-only models must use left-padding for batched generation.
   1667 if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
   1668     # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
   1669     # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.

File ~/transformers/src/transformers/generation/utils.py:1513, in GenerationMixin._prepare_special_tokens(self, generation_config, kwargs_has_attention_mask, device)
   1510     logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
   1512 # we can't infer attn mask if pad token is set to be eos token in model's generation config
-> 1513 if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
   1514     if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
   1515         logger.warning_once(
   1516             "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
   1517             "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
   1518             "to obtain reliable results."
   1519         )

NotImplementedError: The operator 'aten::isin.Tensor_Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
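As the error message itself suggests, an interim workaround (independent of this PR) is to let unsupported MPS ops fall back to the CPU. This is a minimal sketch of that workaround, assuming the environment variable is set before torch is imported; the fallback op runs on CPU and is therefore slower than native MPS:

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # set before `import torch`

import torch  # aten::isin now falls back to CPU instead of raising NotImplementedError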

=> This is caused by the `torch.isin` operator not being implemented for the PyTorch MPS backend. This PR removes the `torch.isin` call from the main body of `generate`, while keeping compatibility with the eos/pad checks added in #31254.
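For illustration, here is a minimal sketch of one MPS-compatible way to perform the same eos/pad membership check without `torch.isin`, using broadcasted equality (which MPS does support); the helper name `isin_mps_friendly` is chosen for this example and this is not necessarily the exact implementation in the PR:

import torch

def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    # Compare every element against every test element, then reduce over the
    # test dimension. Broadcasted `eq` is implemented on MPS, unlike aten::isin.
    return elements.unsqueeze(-1).eq(test_elements).any(dim=-1)

# Equivalent to torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
eos_token_id = torch.tensor([1, 107])
pad_token_id = torch.tensor([1])
print(isin_mps_friendly(eos_token_id, pad_token_id).any())  # tensor(True)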

Following this PR, Gemma-2 (and other generate-compatible models) can be run on MPS devices.

sanchit-gandhi · Jun 28 '24 13:06