
I have a question about the source code in modeling_llama.py

Open park1200656 opened this issue 2 years ago • 5 comments

System Info

@ArthurZucker @gante

path : "src/transformers/models/llama/modeling_llama.py"

Lines 85 and 232 of this file hard-code float32 as a constant. This looks like a bug to me. Or is there another reason?

Thanks.

Who can help?

No response

Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

import torch
from torch import nn


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        return (self.weight * hidden_states).to(input_dtype)
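
Written out, the forward pass above computes the following, with d = hidden_size, w = self.weight and ε = self.variance_epsilon. The torch.float32 cast in it applies to the d-term mean of squares under the square root; the final result is cast back to the input dtype.

```latex
\operatorname{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon}}
```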

======

class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
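
For one query row, let s be the scaled and masked scores, of length T = kv_seq_len. With dtype=torch.float32, the softmax call evaluates the expression below in float32 and only then casts p back to query_states.dtype. Subtracting the row maximum m is the usual numerically stable formulation and does not change the result mathematically; whether a given PyTorch kernel does it exactly this way is an assumption here. The denominator is a T-term sum, so the number of terms accumulated grows with context length.

```latex
p_j = \frac{\exp(s_j - m)}{\sum_{k=1}^{T} \exp(s_k - m)}, \qquad m = \max_{1 \le k \le T} s_k
```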

Expected behavior

Maybe float32 should be replaced with the input dtype.

park1200656, Jun 27 '23

Hey @park1200656 👋

Some operations degrade the quality of the outputs if not performed at a certain minimum precision. The softmax in the attention layer and the variance accumulation in RMSNorm performed in FP32 are two examples of that :) Related read: this issue
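
To make the variance-accumulation point concrete, here is a small pure-Python sketch that emulates bfloat16 by rounding float32 bit patterns to their top 16 bits (round-to-nearest-even). to_bf16 and bf16_sum are hypothetical helpers, not PyTorch APIs. The sketch sums 4096 squared activations of 1.0 (4096 is LLaMA-7B's hidden size), once with a naive sequential bfloat16 accumulator and once in full precision. PyTorch's reduction kernels may well accumulate in higher precision internally, so treat this as a worst-case illustration of why the accumulator's precision matters, not as a measurement of torch.mean.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    Hypothetical helper: packs x as float32 and keeps the top 16 bits,
    with rounding. NaN and overflow handling are omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties round to an even upper half
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_sum(values):
    """Sum values sequentially with a bfloat16 accumulator (hypothetical worst case)."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + to_bf16(v))
    return acc


hidden = [1.0] * 4096  # toy activations for one token
squares = [x * x for x in hidden]

mean_bf16 = bf16_sum(squares) / len(squares)
mean_full = sum(squares) / len(squares)

print(mean_bf16)  # 0.0625: the accumulator stalls at 256, because 256 + 1 rounds back to 256
print(mean_full)  # 1.0
```

With the true variance of 1.0, RMSNorm (taking w = 1 and ignoring ε) would leave these activations unchanged; with the stalled estimate of 0.0625 it would scale them by 1/sqrt(0.0625) = 4.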


Following our issues guidelines, we reserve GitHub issues for bugs in the repository and/or feature requests. For any other matters, we'd like to invite you to use our forum 🤗 If this is your first issue with us, check this guide.

gante, Jun 27 '23

I had the same question yesterday. Can we make it optional? At least for the softmax?

BF16 is good enough. By "good enough" I mean that it doesn't crash at long context on my laptop's 3080 Ti, and the returned values are the same anyway, so the instability might be overstated.

Example of making it optional:

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 24231c3f7..230e5333c 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -228,8 +228,12 @@ class LlamaAttention(nn.Module):
                 attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
             )
 
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # optionally upcast attention to fp32
+        if self.config.use_attn_upcast:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        else:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
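
The difference between the two branches of this diff can be illustrated without a GPU. This is a pure-Python sketch for a single row of attention scores, emulating bfloat16 by rounding float32 bit patterns; to_bf16 and softmax_row are hypothetical helpers, not transformers or PyTorch APIs, and Python's float64 stands in for float32 in the upcast branch. The upcast=False branch uses a naive sequential bfloat16 accumulator, which is a deliberate worst case: PyTorch's softmax kernels may already accumulate in float32 internally for bfloat16 inputs, which would be consistent with the identical outputs reported in this thread.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even); hypothetical helper."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softmax_row(scores, upcast=True):
    """Softmax over one row of scores, returning bfloat16-rounded probabilities.

    upcast=True: exponentials and their sum stay in high precision, and only
    the probabilities are rounded (like softmax(..., dtype=float32).to(bf16)).
    upcast=False: every intermediate is rounded to bfloat16 and the normalizer
    uses a naive sequential bfloat16 accumulator (a worst case).
    """
    m = max(scores)  # subtract the row max for numerical stability
    if upcast:
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
    else:
        exps = [to_bf16(math.exp(to_bf16(s - m))) for s in scores]
        total = 0.0
        for e in exps:
            total = to_bf16(total + e)
    return [to_bf16(e / total) for e in exps]


scores = [0.0] * 1000  # 1000 equal scores, as in a 1000-token context
print(sum(softmax_row(scores, upcast=True)))   # close to 1.0
print(sum(softmax_row(scores, upcast=False)))  # 3.90625
```

With 1000 equal scores each probability should be 0.001. The high-precision path gets close to that; the bfloat16-accumulator path stops growing its normalizer at 256, so every probability becomes 1/256 and the row sums to about 3.9 instead of 1.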

Test script:

from transformers import AutoModelForCausalLM
import torch
import sys

# Requires the patched modeling_llama.py from the diff above, which reads config.use_attn_upcast.
model = AutoModelForCausalLM.from_pretrained("./models/open_llama_3b/", torch_dtype=torch.bfloat16).cuda()
model.config.use_attn_upcast = "--no-oom" not in sys.argv
print("Predict that OOM will happen: ", model.config.use_attn_upcast)

# Short input (20 tokens): mean of the logits at each position
input_ids = torch.arange(20)[None].cuda()
print(model(input_ids).logits.mean(-1))

# Long input (1000 tokens): overall mean of the logits
input_ids = torch.arange(1000)[None].cuda()
print(model(input_ids).logits.mean())

With the upcast removed:

$  python demo_py.py --no-oom

Predict that OOM will happen:  False
tensor([[-9.0000, -6.0938, -1.8281, -7.7812, -7.5000, -7.5000, -7.6250, -7.7500,
         -7.1250, -7.0000, -7.7188, -7.5625, -6.9688, -5.5312, -6.1562, -6.5312,
         -7.5938, -7.0000, -7.1875, -6.8750]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<MeanBackward1>)
tensor(-6.9062, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)

With upcast:

$ python demo_py.py 

Predict that OOM will happen:  True
tensor([[-9.0000, -6.0938, -1.8281, -7.7812, -7.5000, -7.5000, -7.6250, -7.7500,
         -7.1250, -7.0000, -7.7188, -7.5625, -6.9688, -5.5312, -6.1562, -6.5312,
         -7.5938, -7.0000, -7.1875, -6.8750]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<MeanBackward1>)
Traceback (most recent call last):
  File "/home/fella/src/llama/text-generation-webui/demo_py.py", line 14, in <module>
    print(model(input_ids).logits.mean())
          ^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 690, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 580, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 295, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 232, in forward
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/functional.py", line 1845, in softmax
    ret = input.softmax(dim, dtype=dtype)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 124.00 MiB (GPU 0; 15.74 GiB total capacity; 14.83 GiB already allocated; 134.38 MiB free; 15.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
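
The failed allocation is roughly the size of one float32 copy of the attention-weight tensor, shape (bsz, num_heads, q_len, kv_seq_len), which the upcast materializes before casting back. Here is a back-of-the-envelope sketch, assuming a batch of 1, the 1000-token input from the test script, and 32 attention heads for open_llama_3b (an assumption; check num_attention_heads in the model's config.json). attn_weights_mib is a hypothetical helper.

```python
def attn_weights_mib(bsz: int, num_heads: int, q_len: int, kv_len: int, bytes_per_elem: int) -> float:
    """Size in MiB of one attention-weight tensor of shape (bsz, num_heads, q_len, kv_len)."""
    return bsz * num_heads * q_len * kv_len * bytes_per_elem / 2**20


# Assumed: batch of 1, 32 heads, 1000-token prompt with no cache (q_len == kv_len).
fp32 = attn_weights_mib(1, 32, 1000, 1000, bytes_per_elem=4)
bf16 = attn_weights_mib(1, 32, 1000, 1000, bytes_per_elem=2)
print(f"float32 softmax output: {fp32:.2f} MiB")  # 122.07 MiB
print(f"bfloat16 copy:          {bf16:.2f} MiB")  # 61.04 MiB
```

Under these assumptions, 122.07 MiB is consistent with the 124.00 MiB in the traceback, since PyTorch's CUDA caching allocator rounds large requests up to a multiple of 2 MiB. The tensor is transient, but it must fit alongside everything already resident, and because it grows with q_len × kv_seq_len, doubling the prompt length roughly quadruples it.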

Maykeye, Jun 29 '23

@Maykeye we have other options to reduce the memory footprint at inference time -- have you tried playing with our support for 4-bit inference? On a 3080 TI you may be able to run the 7B LLaMA model this way :)

gante, Jun 29 '23

Yes, and quantized models produce noticeably different results.

Maykeye, Jun 29 '23

In general, lowering the precision of these operations will have a more significant impact on downstream performance (take it from the person that initially added the upcast at Meta).

Since we have other memory reduction strategies, we will not add the flag you're proposing. (Still, the code is open-source, feel free to fork transformers and keep your changes 🤗 )

gante, Jun 29 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Jul 27 '23

> In general, lowering the precision of these operations will have a more significant impact on downstream performance (take it from the person that initially added the upcast at Meta).
>
> Since we have other memory reduction strategies, we will not add the flag you're proposing. (Still, the code is open-source, feel free to fork transformers and keep your changes 🤗 )

I don't think this is true. I have experimented a lot with Mistral and Fuyu, and removing or changing the fused softmax cast in both has had little to no impact (in terms of tracked accuracy/loss) compared to alternative memory-saving approaches.

This seems like something that should be allowed for models, with a warning.

grahamannett, Feb 21 '24

@grahamannett the NaN question remains to be cleared (see original thread).

If you can come up with a strategy to reproduce the original NaNs in OPT and show that recent models are free from them, I'd be happy to include the change in more recent architectures 🤗 Otherwise, we'd rather play it safe; NaNs are very disruptive.

gante, Feb 26 '24

@gante For sure. I also just wanted to post this in case anyone finds this issue and needs to change the cast due to GPU memory limits.

grahamannett, Mar 01 '24