
[MusicGen] SDPA gives nans/infs during sampling

Open sanchit-gandhi opened this issue 3 months ago • 3 comments

System Info

  • transformers version: 4.40.0.dev0
  • Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.22.1
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    • distributed_type: NO
    • mixed_precision: bf16
    • use_cpu: False
    • debug: False
    • num_processes: 1
    • machine_rank: 0
    • num_machines: 1
    • gpu_ids: 0
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): 2.13.1 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.7.2 (cpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

No response

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Following #29939, running the snippet below fails during sampling because the probability tensor contains non-finite values:

from transformers import MusicgenForConditionalGeneration

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small", attn_implementation="sdpa")

unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)

Traceback

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[11], line 3
      1 unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
----> 3 audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)

File ~/hf/lib/python3.8/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/transformers/src/transformers/models/musicgen/modeling_musicgen.py:2822, in MusicgenForConditionalGeneration.generate(self, inputs, generation_config, logits_processor, stopping_criteria, synced_gpus, streamer, **kwargs)
   2814     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2815         input_ids=input_ids,
   2816         expand_size=generation_config.num_return_sequences,
   2817         is_encoder_decoder=self.config.is_encoder_decoder,
   2818         **model_kwargs,
   2819     )
   2821     # 12. run sample
-> 2822     outputs = self._sample(
   2823         input_ids,
   2824         logits_processor=logits_processor,
   2825         logits_warper=logits_warper,
   2826         stopping_criteria=stopping_criteria,
   2827         pad_token_id=generation_config.pad_token_id,
   2828         eos_token_id=generation_config.eos_token_id,
   2829         output_scores=generation_config.output_scores,
   2830         return_dict_in_generate=generation_config.return_dict_in_generate,
   2831         synced_gpus=synced_gpus,
   2832         streamer=streamer,
   2833         **model_kwargs,
   2834     )
   2836 else:
   2837     raise ValueError(
   2838         "Got incompatible mode for generation, should be one of greedy or sampling. "
   2839         "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
   2840     )

File ~/transformers/src/transformers/generation/utils.py:2771, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, output_logits, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2769 # sample
   2770 probs = nn.functional.softmax(next_token_scores, dim=-1)
-> 2771 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
   2773 # finished sentences should have their next token be a padding token
   2774 if eos_token_id is not None:

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
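The RuntimeError is torch.multinomial rejecting a probability tensor that contains nan/inf. A single non-finite logit is enough to poison the entire softmax distribution; a minimal stdlib sketch (the softmax helper below is illustrative, not the transformers implementation):

```python
import math

def softmax(logits):
    # Plain max-shifted softmax over one row, mirroring what
    # nn.functional.softmax computes in _sample above.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# One inf logit makes the shift m = inf, so exp(inf - inf) = exp(nan) = nan
# and the normalizer becomes nan: every probability in the row ends up nan,
# which is exactly the tensor state torch.multinomial rejects.
probs = softmax([0.1, 0.2, float("inf")])
print(probs)
```

So the nans observed at sampling time only tell us the logits were already non-finite; the actual overflow happens earlier, inside the model's forward pass.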

Expected behavior

With the eager attention implementation, the code runs as expected:

from transformers import MusicgenForConditionalGeneration

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small", attn_implementation="eager")

unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)
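For what it's worth, one classic way attention produces non-finite values is a fully masked row: if the attention mask sets every score in a row to -inf, the softmax of that row is all nan, and eager and SDPA backends can handle that corner differently. Whether this is actually what happens in MusicGen would need checking in the mask construction; a stdlib sketch of the failure mode (the masked_softmax helper is illustrative, not the transformers code):

```python
import math

def masked_softmax(scores, keep):
    # Additive masking as attention implementations do: positions to be
    # dropped get a score of -inf before the softmax.
    masked = [s if k else float("-inf") for s, k in zip(scores, keep)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

# Normal row: one position masked out, the rest renormalize and stay finite.
ok = masked_softmax([1.0, 2.0, 3.0], [True, False, True])

# Fully masked row: m = -inf, so every exp(-inf - (-inf)) = exp(nan) = nan
# and the whole row of attention weights is nan.
bad = masked_softmax([1.0, 2.0, 3.0], [False, False, False])
```

If this is the cause, a targeted check would be to inspect whether any row of the decoder attention mask is entirely masked for the unconditional inputs.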

Could you have a quick look to see if there's a bug in the sdpa implementation @ylacombe? We could also add an integration test that confirms we get sensible outputs with the checkpoint "facebook/musicgen-small".

sanchit-gandhi avatar Apr 03 '24 16:04 sanchit-gandhi

Hey @sanchit-gandhi, thanks for opening the issue! It works in my environment, but that might be down to the torch version I'm using (2.2). Nonetheless, before I dive deeper, could you verify that you still get nans/infs when using a GPU and/or torch_dtype=torch.float16?

ylacombe avatar Apr 03 '24 17:04 ylacombe

@ylacombe Is there a known issue with GPU + float16 and SDPA? I've searched and could not find anything, yet I'm having issues sampling with SDPA on other models (Mistral, Mixtral). Happy to open a separate issue if it has not been reported.

cjekel avatar Apr 03 '24 19:04 cjekel

Hey @cjekel, not that I'm aware of! The current issue occurs without a GPU and in fp32. Feel free to open an issue for the other models with a reproducing script (and tag me as well)!

ylacombe avatar Apr 04 '24 08:04 ylacombe