NaNs or Infs when using llama-13B-chat

Open jamesdborin opened this issue 2 years ago • 6 comments

I ran through the conversion script for llama 13B Chat, but when I run the model on longer generations I sometimes get the following error:

RuntimeError: probability tensor contains either inf, nan or element < 0

I think this has to do with the attention being computed in float16, and some overflow or underflow in the softmax. If I edit the PyTorch code in tinychat/modules/fused_attn.py to run attention in float32 and then convert back to half precision, I haven't seen the error again:

# Run the attention itself in float32 so the softmax cannot overflow/underflow in fp16.
with torch.autocast(device_type='cuda', dtype=torch.float32):
    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)

del query_states, key_states, value_states

attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
# Cast back to half precision before the output projection.
attn_output = self.o_proj(attn_output.half())
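Depending on the PyTorch version, autocast may warn and ignore a float32 dtype on CUDA (it only officially supports float16/bfloat16 there), so an explicit upcast is a more direct way to force the computation into fp32. Below is a minimal sketch of that variant, reusing the tensor names from the snippet above; it is an assumption about the surrounding fused_attn.py code, not a tested patch:

import torch.nn.functional as F

# Upcast q/k/v to float32 so the softmax inside SDPA cannot overflow fp16,
# then cast the result back to half precision for the rest of the model.
attn_output = F.scaled_dot_product_attention(
    query_states.float(),
    key_states.float(),
    value_states.float(),
    is_causal=is_causal,
).half()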

What would be nice is to run the model in bfloat16, but the custom kernels don't support that. Are there any plans to make a bfloat16 version of the kernels?

Thanks a lot!

jamesdborin avatar Aug 15 '23 17:08 jamesdborin

@jamesdborin: Can you try this model and check if the same problem occurs?

abhinavkulkarni avatar Aug 16 '23 15:08 abhinavkulkarni

I have seen this error before, but I'm not quite sure why it happens. As far as I remember, it happened to me with the 7B and 13B LLaMa 2 models. However, it does not happen with MPT models.

casper-hansen avatar Aug 16 '23 21:08 casper-hansen

@jamesdborin: Can you try this model and check if the same problem occurs?

Thanks a lot, I'll give this a go later today and let you know if I get the same problem!

jamesdborin avatar Aug 17 '23 10:08 jamesdborin

Hi, I have seen the same error recently. Has there been any conclusion yet? @jamesdborin

moonlightian avatar Sep 26 '23 02:09 moonlightian

It seems that torch.half() is not enough to get through the inference process, and a higher-precision numerical format needs to be used, like torch.bfloat16 or something else.
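A quick way to see why fp16 struggles here: its representable range tops out around 6.5e4, while bfloat16 keeps float32's exponent range, so large intermediates overflow to inf in fp16 but stay finite in bf16. A small illustration (plain PyTorch, nothing model-specific):

import torch

# float16 saturates to inf above ~65504; bfloat16 does not (it trades mantissa bits for range).
print(torch.finfo(torch.float16).max)    # 65504.0
print(torch.finfo(torch.bfloat16).max)   # ~3.39e38
big = torch.tensor(70000.0)
print(big.to(torch.float16))             # inf (overflows the fp16 range)
print(big.to(torch.bfloat16))            # ~7.0e4, finite but with reduced precision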

moonlightian avatar Sep 26 '23 02:09 moonlightian

@moonlightian I experimented with a few things. I tried running attention in fp32, but I still saw the error. In the end I found that removing the custom rotary embedding kernel fixed the issue. I swapped it for the HF implementation, sacrificed some performance, and the problem has disappeared.
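For anyone trying the same workaround, here is a rough sketch of the HF-style rotary embedding applied in float32. It mirrors the transformers LLaMA implementation; how it gets wired into tinychat/modules/fused_attn.py in place of the fused kernel is assumed, not shown:

import torch

def rotate_half(x):
    # Rotate half of the head dimensions, as in the HF LLaMA implementation.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # q, k: [batch, heads, seq_len, head_dim]; cos, sin: float32 tables broadcastable to q/k.
    # Computing the rotation in float32 and casting back avoids the half-precision
    # rounding done by the fused kernel.
    q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
    k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
    return q_embed.half(), k_embed.half()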

jamesdborin avatar Sep 26 '23 08:09 jamesdborin