[MusicGen] SDPA gives nans/infs during sampling
System Info
- transformers version: 4.40.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.29
- Python version: 3.8.10
- Huggingface_hub version: 0.22.1
- Safetensors version: 0.4.2
- Accelerate version: 0.27.2
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: NO
  - mixed_precision: bf16
  - use_cpu: False
  - debug: False
  - num_processes: 1
  - machine_rank: 0
  - num_machines: 1
  - gpu_ids: 0
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- PyTorch version (GPU?): 2.2.1+cu121 (True)
- Tensorflow version (GPU?): 2.13.1 (False)
- Flax version (CPU?/GPU?/TPU?): 0.7.2 (cpu)
- Jax version: 0.4.13
- JaxLib version: 0.4.13
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
No response
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Following #29939, running the following raises an inf/nan error during sampling:
from transformers import MusicgenForConditionalGeneration
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small", attn_implementation="sdpa")
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)
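For debugging, one option is a hypothetical helper (not part of transformers) matching the `LogitsProcessor` call signature that raises as soon as non-finite scores appear, pinning the failure to a specific generation step; it could be passed via `model.generate(..., logits_processor=[NanDetector()])`:

```python
import torch

# Hypothetical debugging helper (not part of transformers): a callable with
# the LogitsProcessor signature (input_ids, scores) that raises as soon as
# non-finite scores appear during generation.
class NanDetector:
    def __call__(self, input_ids, scores):
        if not torch.isfinite(scores).all():
            raise RuntimeError(
                f"non-finite logits at generation step {input_ids.shape[-1]}"
            )
        return scores
```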
Traceback
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[11], line 3
1 unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
----> 3 audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)
File ~/hf/lib/python3.8/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/transformers/src/transformers/models/musicgen/modeling_musicgen.py:2822, in MusicgenForConditionalGeneration.generate(self, inputs, generation_config, logits_processor, stopping_criteria, synced_gpus, streamer, **kwargs)
2814 input_ids, model_kwargs = self._expand_inputs_for_generation(
2815 input_ids=input_ids,
2816 expand_size=generation_config.num_return_sequences,
2817 is_encoder_decoder=self.config.is_encoder_decoder,
2818 **model_kwargs,
2819 )
2821 # 12. run sample
-> 2822 outputs = self._sample(
2823 input_ids,
2824 logits_processor=logits_processor,
2825 logits_warper=logits_warper,
2826 stopping_criteria=stopping_criteria,
2827 pad_token_id=generation_config.pad_token_id,
2828 eos_token_id=generation_config.eos_token_id,
2829 output_scores=generation_config.output_scores,
2830 return_dict_in_generate=generation_config.return_dict_in_generate,
2831 synced_gpus=synced_gpus,
2832 streamer=streamer,
2833 **model_kwargs,
2834 )
2836 else:
2837 raise ValueError(
2838 "Got incompatible mode for generation, should be one of greedy or sampling. "
2839 "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
2840 )
File ~/transformers/src/transformers/generation/utils.py:2771, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, output_logits, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
2769 # sample
2770 probs = nn.functional.softmax(next_token_scores, dim=-1)
-> 2771 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
2773 # finished sentences should have their next token be a padding token
2774 if eos_token_id is not None:
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
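For context, the RuntimeError comes from `torch.multinomial`'s own validation of the probability tensor, so any nan/inf produced upstream by the attention path surfaces here; a minimal, MusicGen-independent illustration:

```python
import torch

# torch.multinomial rejects any probability tensor containing inf/nan or
# negative entries, which is exactly the error surfaced during sampling.
probs = torch.tensor([[0.5, float("nan"), 0.5]])
try:
    torch.multinomial(probs, num_samples=1)
except RuntimeError as err:
    print("raised:", type(err).__name__)  # raised: RuntimeError
```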
Expected behavior
With eager, the code functions as expected:
from transformers import MusicgenForConditionalGeneration
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small", attn_implementation="eager")
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)
Could you have a quick look to see if there's a bug in the SDPA implementation @ylacombe? We could also add an integration test that confirms we get sensible outputs with the "facebook/musicgen-small" checkpoint.
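As a lighter-weight sanity check than a full checkpoint test, one could compare torch's SDPA kernel against a manual eager attention on random tensors; a sketch (shapes and tolerances are illustrative assumptions, not the transformers test suite):

```python
import torch
import torch.nn.functional as F

# Compare F.scaled_dot_product_attention against a manual "eager"
# softmax(Q K^T / sqrt(d)) V and verify the outputs are finite and close.
torch.manual_seed(0)
q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

sdpa_out = F.scaled_dot_product_attention(q, k, v)

scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
eager_out = torch.softmax(scores, dim=-1) @ v

assert torch.isfinite(sdpa_out).all()
assert torch.allclose(sdpa_out, eager_out, atol=1e-5)
```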
Hey @sanchit-gandhi, thanks for opening the issue!
It's working in my environment, but that might be explained by the torch version I'm using (2.2).
Nonetheless, before I dive deeper, could you verify that you still get nans/infs when using a GPU and/or when using torch_dtype=torch.float16?
@ylacombe Is there a known issue with GPU + float16 and SDPA? I was searching and could not find anything, yet I'm having issues with other models (mistral, mixtral) sampling with SDPA. Happy to make a separate issue if it has not been reported.
Hey @cjekel, not that I'm aware of! The current issue is without GPU and with fp32! Feel free to open an issue for the other models with a reproducing script (and to tag me as well)!