LlamaRMSNorm() Dtype Casting Error
System Info
transformers==4.37.2
Who can help?
@ArthurZucker @younesbelkada
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
@ArthurZucker @younesbelkada Hi~ I found a bug in `LlamaRMSNorm(nn.Module)` (lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py):
```python
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```
On the last line, if `input_dtype` is bfloat16, the returned tensor will still be float32, because `self.weight` was initialized as float32. Thus the last line should be modified to:

```python
return (self.weight * hidden_states).to(input_dtype)
```
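For illustration, here is a minimal sketch (not from the original report) of the promotion behavior being described, assuming the norm's weight stays in float32 while the activations are bfloat16; the tensor names are hypothetical:

```python
import torch
import torch.nn as nn

weight = nn.Parameter(torch.ones(8))                      # fp32, as created in __init__
hidden_states = torch.randn(2, 8, dtype=torch.bfloat16)   # bf16 activations
input_dtype = hidden_states.dtype                         # torch.bfloat16

current  = weight * hidden_states.to(input_dtype)         # fp32 * bf16 promotes to fp32
proposed = (weight * hidden_states).to(input_dtype)       # multiply first, then cast back

print(current.dtype)   # torch.float32
print(proposed.dtype)  # torch.bfloat16
```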
Expected behavior
See above. Looking forward to your reply~ Thank you!
Hi @Ritz111, thanks! I think this is not a bug, see https://github.com/huggingface/transformers/pull/23535 for more details.
Why should `LlamaRMSNorm` do `hidden_states = hidden_states.to(torch.float32)`? Why not follow the type promotion rules of PyTorch ops?
When `self.weight` is bf16 and `hidden_states` is fp32, I found that the two methods perform the multiplication in different dtypes:
method 1: `return (self.weight * hidden_states).to(input_dtype)`  # (bf16 * fp32).to(input_dtype)
method 2: `return self.weight * hidden_states.to(input_dtype)`    # bf16 * bf16
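A minimal sketch (not from the thread) of that comparison, assuming the model was loaded in bf16 so `self.weight` is bf16 while `hidden_states` is the fp32 tensor produced by the upcast in `forward`; the names are hypothetical:

```python
import torch
import torch.nn as nn

weight = nn.Parameter(torch.ones(8, dtype=torch.bfloat16))  # bf16 weight (model loaded in bf16)
hidden_states = torch.randn(2, 8)                           # fp32, after the upcast in forward()
input_dtype = torch.bfloat16                                # dtype of the original input

method_1 = (weight * hidden_states).to(input_dtype)  # bf16 * fp32 promotes to fp32, then cast to bf16
method_2 = weight * hidden_states.to(input_dtype)    # cast first, multiply entirely in bf16

print(method_1.dtype, method_2.dtype)  # both torch.bfloat16; only the multiply precision differs
```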
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.