Baichuan-13B icon indicating copy to clipboard operation
Baichuan-13B copied to clipboard

关于softmax在fp32和bf16下的精度误差问题

Open shiqingzhangCSU opened this issue 2 years ago • 6 comments

在推理baichuan13B时,由于增加了alibi mask导致模型在计算softmax时候精度溢出,bf16和fp32的结果是不一样的。请问baichuan训练中,有没有考虑到了这个精度的问题呢?

参考llama中解决softmax精度问题的实现: 在llama代码中会把softmax中的计算转为32位:

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

我们还测试fp32和bf16的误差,代码如下:

import math

import numpy as np
import torch


def self_attention_reference_softmax32(q, k, v, attention_mask):
    bsz, q_len, num_heads, head_dim = q.shape
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

    # upcast to 32
    attn_weights = attn_weights.to(dtype=torch.float32)
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)

    attn_output = torch.matmul(attn_weights, v)
    return attn_output

def self_attention_reference_bf16(q, k, v, attention_mask):
    bsz, q_len, num_heads, head_dim = q.shape
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, v)
    return attn_output


def _fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)

if __name__ == '__main__':
    size = 4096
    bsz, q_len, num_heads, head_dim = 1, size, 40, 64
    qkv_shape = (bsz, q_len, num_heads, head_dim)
    device = "cuda"
    dtype = torch.bfloat16

    q = torch.rand(qkv_shape, device=device, dtype=dtype)
    k = torch.rand(qkv_shape, device=device, dtype=dtype)
    v = torch.rand(qkv_shape, device=device, dtype=dtype)
    attention_mask = 20 * torch.rand(1, num_heads, q_len, q_len, device=device, dtype=dtype)

    mask2 = attention_mask + torch.triu(_fill_with_neg_inf(torch.zeros([size, size], dtype=dtype)), 1).unsqueeze(0).to("cuda")

    result_reference = self_attention_reference_softmax32(q, k, v, mask2)
    result_bf16= self_attention_reference_bf16(q, k, v, mask2)


    result_reference_numpy = result_reference.float().cpu().detach().numpy()
    result_bf16_numpy = result_bf16.float().cpu().detach().numpy()

    if not np.allclose(result_reference_numpy, result_bf16_numpy, 1e-4, 1e-4):
        diff = np.abs(result_reference_numpy - result_bf16_numpy)
        print("model output diff larger than 1e-4")
        print("max diff:  %f" % np.max(diff))
        index = np.argmax(diff)
        print("index: %d     value: %f  %f" % (index, result_reference_numpy.flatten()[index], result_bf16_numpy.flatten()[index]))
        print("")
        for i in range(20):
            idx = index - 10 + i 
            print("index: %d     value: %f  %f" % (idx, result_reference_numpy.flatten()[idx], result_bf16_numpy.flatten()[idx]))

shiqingzhangCSU avatar Aug 15 '23 06:08 shiqingzhangCSU

mark

KelleyYin avatar Aug 17 '23 02:08 KelleyYin

mark

老哥有什么高见?说一说

shiqingzhangCSU avatar Aug 17 '23 08:08 shiqingzhangCSU

mark,同样遇到了精度问题

NicholasYoungAI avatar Aug 31 '23 03:08 NicholasYoungAI

mark,同样遇到了精度问题

我这边解决了,通过用一个mask2 - max(mask2),这个操作是等价的,你可以试试。或者直接把softmax这部分计算全部还原为32。

shiqingzhangCSU avatar Aug 31 '23 03:08 shiqingzhangCSU

mark,同样遇到了精度问题

我这边解决了,通过用一个mask2 - max(mask2),这个操作是等价的,你可以试试。或者直接把softmax这部分计算全部还原为32。

我试过把softmax部分还原为fp32,仍然没有解决问题。

mask2 - max(mask2) 是指什么,我没有get到。

谢谢

NicholasYoungAI avatar Aug 31 '23 03:08 NicholasYoungAI

mark,同样遇到了精度问题

我这边解决了,通过用一个mask2 - max(mask2),这个操作是等价的,你可以试试。或者直接把softmax这部分计算全部还原为32。

我试过把softmax部分还原为fp32,仍然没有解决问题。

mask2 - max(mask2) 是指什么,我没有get到。

谢谢 1.softmax的相关的所有全部转为fp32试试。 2.就是你进行softmax的操作 attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) 这里的attn_weights减去自己的最大值max(attn_weights),这个操作对softmax是等价的,然后可以减少精度溢出。

shiqingzhangCSU avatar Aug 31 '23 03:08 shiqingzhangCSU