
LlamaRMSNorm() Dtype Casting Error

Open Ritz111 opened this issue 10 months ago • 2 comments

System Info

transformers==4.37.2

Who can help?

@ArthurZucker @younesbelkada

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

@ArthurZucker @younesbelkada Hi~ I found a bug in LlamaRMSNorm(nn.Module) (lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py):

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

On the last line, if input_dtype is bfloat16, the returned tensor will still be float32, because self.weight was initialized as float32. So the last line should be modified to:

return (self.weight * hidden_states).to(input_dtype)
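
A minimal sketch (not part of the original report) of the promotion described above, assuming self.weight is left in its default float32 and the input is bfloat16:

import torch

weight = torch.ones(8)                                   # float32, as in nn.Parameter(torch.ones(hidden_size))
hidden_states = torch.randn(2, 8, dtype=torch.bfloat16)
input_dtype = hidden_states.dtype                        # torch.bfloat16
hidden_states = hidden_states.to(torch.float32)          # the upcast inside forward()

current  = weight * hidden_states.to(input_dtype)        # fp32 * bf16 -> promoted to float32
proposed = (weight * hidden_states).to(input_dtype)      # fp32 * fp32, then cast -> bfloat16
print(current.dtype, proposed.dtype)                     # torch.float32 torch.bfloat16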

Expected behavior

See above; looking forward to your reply~ Thank you!

Ritz111 avatar Apr 13 '24 13:04 Ritz111

Hi @Ritz111, thanks! I think this is not a bug; see https://github.com/huggingface/transformers/pull/23535 for more details.

younesbelkada avatar Apr 16 '24 08:04 younesbelkada

Why should class LlamaRMSNorm do "hidden_states = hidden_states.to(torch.float32)"? Why not follow the type promotion rules of PyTorch ops?

GuWei007 avatar May 11 '24 10:05 GuWei007
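
A minimal sketch of one plausible reason for the explicit float32 upcast (my reading, not stated in this thread): the squared-mean reduction is less accurate when it runs in bfloat16.

import torch

torch.manual_seed(0)
x = torch.randn(1, 4096).to(torch.bfloat16)
var_bf16 = x.pow(2).mean(-1, keepdim=True)                      # reduction in bfloat16
var_fp32 = x.to(torch.float32).pow(2).mean(-1, keepdim=True)    # reduction in float32
print((var_bf16.float() - var_fp32).abs().max())                # typically a small non-zero gap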

self.weight is bf16, hidden_states is fp32. I found that the dtypes involved in these two methods are different:

method 1: return (self.weight * hidden_states).to(input_dtype)   # (bf16 * fp32).to(input_dtype)
method 2: return self.weight * hidden_states.to(input_dtype)     # bf16 * bf16

GuWei007 avatar May 13 '24 06:05 GuWei007
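
A minimal sketch (not from the thread) of the two variants compared above, assuming self.weight is bfloat16 and the already-upcast hidden states are float32: the final dtype is bfloat16 either way, but the multiply runs in a different precision, so the values can differ slightly.

import torch

torch.manual_seed(0)
weight = torch.rand(8, dtype=torch.bfloat16)
hidden_states = torch.randn(2, 8, dtype=torch.float32)     # already upcast inside forward()
input_dtype = torch.bfloat16                               # dtype of the original input

method_1 = (weight * hidden_states).to(input_dtype)        # multiply in float32, then cast
method_2 = weight * hidden_states.to(input_dtype)          # cast first, multiply in bfloat16
print(method_1.dtype, method_2.dtype)                      # both torch.bfloat16
print((method_1.float() - method_2.float()).abs().max())   # can be non-zero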

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jun 06 '24 08:06 github-actions[bot]