FBGEMM icon indicating copy to clipboard operation
FBGEMM copied to clipboard

Qwen2 support?

Open HadXu opened this issue 5 months ago • 0 comments

I followed https://huggingface.co/docs/transformers/en/quantization/fbgemm_fp8 and ran the example successfully.

But when I run it with a Qwen2 model, it fails with the error "RuntimeError: Invalid datatype. input must be BF16".

img_v3_02f1_3f046408-68c8-465f-bb54-71c26a139eag


RuntimeError Traceback (most recent call last) Cell In[4], line 11 8 input_text = "What are we having for dinner?" 9 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") ---> 11 output = quantized_model.generate(**input_ids, max_new_tokens=10) 12 print(tokenizer.decode(output[0], skip_special_tokens=True))

File /usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py:116, in context_decorator..decorate_context(*args, **kwargs) 113 @functools.wraps(func) 114 def decorate_context(*args, **kwargs): 115 with ctx_factory(): --> 116 return func(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs) 2016 input_ids, model_kwargs = self._expand_inputs_for_generation( 2017 input_ids=input_ids, 2018 expand_size=generation_config.num_return_sequences, 2019 is_encoder_decoder=self.config.is_encoder_decoder, 2020 **model_kwargs, 2021 ) 2023 # 13. run sample (it degenerates to greedy search when generation_config.do_sample=False) -> 2024 result = self._sample( 2025 input_ids, 2026 logits_processor=prepared_logits_processor, 2027 logits_warper=prepared_logits_warper, 2028 stopping_criteria=prepared_stopping_criteria, 2029 generation_config=generation_config, 2030 synced_gpus=synced_gpus, 2031 streamer=streamer, 2032 **model_kwargs, 2033 ) 2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): 2036 # 11. prepare logits warper 2037 prepared_logits_warper = ( 2038 self._get_logits_warper(generation_config, device=input_ids.device) 2039 if generation_config.do_sample 2040 else None 2041 )

File /usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs) 2979 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) 2981 # forward pass to get next token -> 2982 outputs = self(**model_inputs, return_dict=True) 2984 if synced_gpus and this_peer_finished: 2985 continue # don't waste resources running the code we don't need

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs) 1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1552 else: -> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs) 1557 # If we don't have any hooks, we want to skip the rest of the logic in 1558 # this function, and just call forward. 1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1560 or _global_backward_pre_hooks or _global_backward_hooks 1561 or _global_forward_hooks or _global_forward_pre_hooks): -> 1562 return forward_call(*args, **kwargs) 1564 try: 1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:1104, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position) 1101 return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1103 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) -> 1104 outputs = self.model( 1105 input_ids=input_ids, 1106 attention_mask=attention_mask, 1107 position_ids=position_ids, 1108 past_key_values=past_key_values, 1109 inputs_embeds=inputs_embeds, 1110 use_cache=use_cache, 1111 output_attentions=output_attentions, 1112 output_hidden_states=output_hidden_states, 1113 return_dict=return_dict, 1114 cache_position=cache_position, 1115 ) 1117 hidden_states = outputs[0] 1118 logits = self.lm_head(hidden_states)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs) 1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1552 else: -> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs) 1557 # If we don't have any hooks, we want to skip the rest of the logic in 1558 # this function, and just call forward. 1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1560 or _global_backward_pre_hooks or _global_backward_hooks 1561 or _global_forward_hooks or _global_forward_pre_hooks): -> 1562 return forward_call(*args, **kwargs) 1564 try: 1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:915, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position) 904 layer_outputs = self._gradient_checkpointing_func( 905 decoder_layer.call, 906 hidden_states, (...) 912 cache_position, 913 ) 914 else: --> 915 layer_outputs = decoder_layer( 916 hidden_states, 917 attention_mask=causal_mask, 918 position_ids=position_ids, 919 past_key_value=past_key_values, 920 output_attentions=output_attentions, 921 use_cache=use_cache, 922 cache_position=cache_position, 923 ) 925 hidden_states = layer_outputs[0] 927 if use_cache:

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs) 1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1552 else: -> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs) 1557 # If we don't have any hooks, we want to skip the rest of the logic in 1558 # this function, and just call forward. 1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1560 or _global_backward_pre_hooks or _global_backward_hooks 1561 or _global_forward_hooks or _global_forward_pre_hooks): -> 1562 return forward_call(*args, **kwargs) 1564 try: 1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:655, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs) 652 hidden_states = self.input_layernorm(hidden_states) 654 # Self Attention --> 655 hidden_states, self_attn_weights, present_key_value = self.self_attn( 656 hidden_states=hidden_states, 657 attention_mask=attention_mask, 658 position_ids=position_ids, 659 past_key_value=past_key_value, 660 output_attentions=output_attentions, 661 use_cache=use_cache, 662 cache_position=cache_position, 663 ) 664 hidden_states = residual + hidden_states 666 # Fully Connected

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs) 1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1552 else: -> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs) 1557 # If we don't have any hooks, we want to skip the rest of the logic in 1558 # this function, and just call forward. 1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1560 or _global_backward_pre_hooks or _global_backward_hooks 1561 or _global_forward_hooks or _global_forward_pre_hooks): -> 1562 return forward_call(*args, **kwargs) 1564 try: 1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:592, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position) 589 attn_output = attn_output.transpose(1, 2).contiguous() 590 attn_output = attn_output.view(bsz, q_len, self.hidden_size) --> 592 attn_output = self.o_proj(attn_output) 594 return attn_output, None, past_key_value

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs) 1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1552 else: -> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs) 1557 # If we don't have any hooks, we want to skip the rest of the logic in 1558 # this function, and just call forward. 1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1560 or _global_backward_pre_hooks or _global_backward_hooks 1561 or _global_forward_hooks or _global_forward_pre_hooks): -> 1562 return forward_call(*args, **kwargs) 1564 try: 1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/integrations/fbgemm_fp8.py:50, in FbgemmFp8Linear.forward(self, x) 47 num_tokens = None 48 # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. 49 # FBGEMM/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu at e08af8539c391437f447173863df0f3f6f ---> 50 x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( 51 x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub 52 ) 53 # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works 54 # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device) 55 56 # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight 57 output = torch.ops.fbgemm.f8f8bf16_rowwise( 58 x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True 59 )

File /usr/local/lib/python3.11/dist-packages/torch/_ops.py:1061, in OpOverloadPacket.__call__(self_, *args, **kwargs) 1059 if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs): 1060 return _call_overload_packet_from_python(self_, args, kwargs) -> 1061 return self_._op(*args, **(kwargs or {}))

RuntimeError: Invalid datatype. input must be BF16

However, I compared Qwen2 and Llama 3 8B, and the dtypes of both models are all bf16.

img_v3_02f1_fac20cb7-6896-4763-9100-7fc92ff76f6g

HadXu avatar Sep 24 '24 03:09 HadXu