FLASHQuad_pytorch icon indicating copy to clipboard operation
FLASHQuad_pytorch copied to clipboard

关于A = square(relu(qk / seq_len + bias))

Open ShomyLiu opened this issue 2 years ago • 34 comments

你好,非常感谢您用PyTorch复现Flash-Quad,我对这个模型也很感兴趣, 有几个小问题,想讨论下:

  • A = square(relu(qk / seq_len + bias)), 这里的seq_len是不是用当前batch的length更合适,代码中https://github.com/JunnYu/FLASHQuad_pytorch/blob/main/flash/gau.py#L117 用的是预设的max_length(如512 ). 不同batch 的序列长度可能是不同的。
  • 您有在不同任务上对比过GAU与Transformer的性能吗 我这边试了几个序列建模任务,发现性能会下降,可能训练超参数差异?

谢谢

ShomyLiu avatar Mar 13 '22 03:03 ShomyLiu

你可以改成batch里面的seqlen试试。我之前试过然后发现模型的输出结果不对劲,于是就改成了max length。我在预训练的时候seqlen基本都是512,也就是说模型只见过512这一个长度,而如果做别的短句子的任务时候,可能seqlen为几十或者一百多,模型都没见过,然后效果不知道为啥不咋行

JunnYu avatar Mar 13 '22 03:03 JunnYu

还有我发现我这个预训练的small权重效果不太行,不知道原论文在embedding处用了dropout没,用了layernorm还是scalenorm。最主要的一个疑惑也就是你提出的那个部分,我不太清楚模型的这个细节有没有实现错

JunnYu avatar Mar 13 '22 04:03 JunnYu

改成seqlen的输出:

pytorch: 天气预报说今天的天[台+0.2037||的+0.0798||定+0.0446||好+0.0422||以+0.0386]很好,那么我[大+0.1093||的+0.0697||本+0.0629||以+0.0559||一+0.0518]一起去公园玩吧!

使用max_length的输出:

pytorch: 天气预报说今天的天[气+0.9948||空+0.0011||色+0.0007||候+0.0004||势+0.0003]很好,那么我[就+0.4915||们+0.4186||也+0.0753||还+0.0021||都+0.0016]一起去公园玩吧!

JunnYu avatar Mar 13 '22 05:03 JunnYu

改成seqlen的输出:

pytorch: 天气预报说今天的天[台+0.2037||的+0.0798||定+0.0446||好+0.0422||以+0.0386]很好,那么我[大+0.1093||的+0.0697||本+0.0629||以+0.0559||一+0.0518]一起去公园玩吧!

使用max_length的输出:

pytorch: 天气预报说今天的天[气+0.9948||空+0.0011||色+0.0007||候+0.0004||势+0.0003]很好,那么我[就+0.4915||们+0.4186||也+0.0753||还+0.0021||都+0.0016]一起去公园玩吧!

谢谢分享,从结果上看max_length的效果更好!

aoom avatar Mar 13 '22 06:03 aoom

谢谢回复分享结果,看结果貌似max的确合理很多。 比较奇怪~我也再测试下,不同设置下的结果,到时候贴出来看看。

我这边主要是用GAU部分 来替代self-attention 和 FNN来做序列建模任务,并不是语料预训练的MLM任务, 目前基本上没有提升。 而且发现,学习率对模型效果影响蛮大的, 波动很明显,可能也是因为数据集的原因。

ShomyLiu avatar Mar 13 '22 08:03 ShomyLiu

Hi, 还有个小问题,关于文章中的RoPE要用到GAU单元内部呢,一般位置向量不是直接融合到最开始的Embedding模块吗? 这里有什么原因吗?

ShomyLiu avatar Mar 15 '22 06:03 ShomyLiu

那直接让qk / attention_mask.sum(-1)[:, None, None]是不是思路上更合适一些

JaheimLee avatar Mar 15 '22 07:03 JaheimLee

Hi, 还有个小问题,关于文章中的RoPE要用到GAU单元内部呢,一般位置向量不是直接融合到最开始的Embedding模块吗? 这里有什么原因吗?

又回头看了看Rope大概知道了~ 通过在q,k中用RoPE 能够体现相对位置编码

ShomyLiu avatar Mar 15 '22 08:03 ShomyLiu

看了苏神的代码,他的l确实是从mask那来的,而且放在了激活函数的外部. https://github.com/bojone/bert4keras/blob/8bf47989488009c2b8f68c20a97000fb96e07f9b/bert4keras/layers.py#L583

JaheimLee avatar Mar 16 '22 13:03 JaheimLee

原论文是这样实现的,苏神后来改了一下,修改了一下缩放的地方

JunnYu avatar Mar 16 '22 13:03 JunnYu

大概说一下,在我们的序列建模任务上,Flash-Quad的效果总是比Transformer还是低1-2个点。调了很长时间,一直上不去,而且收敛也慢。

ShomyLiu avatar Mar 17 '22 01:03 ShomyLiu

我也感觉当前实现的效果不太行,因此还是要等官方代码放出来才知道他里面的一些细节到底怎么处理的,比如A = square(relu(qk / seq_len + bias))这个部分的代码。

JunnYu avatar Mar 17 '22 05:03 JunnYu

是呀,我这边也是测试了不少模块,从最开始的严格按照论文和伪代码,到自己改动改动,最终结果都还是比不上Transformer,可能是GAU的通用性没有那么强。

ShomyLiu avatar Mar 17 '22 07:03 ShomyLiu

是呀,我这边也是测试了不少模块,从最开始的严格按照论文和伪代码,到自己改动改动,最终结果都还是比不上Transformer,可能是GAU的通用性没有那么强。

有尝试把仿射变换改回全连接吗,总感觉这个操作有点神奇

JaheimLee avatar Mar 17 '22 08:03 JaheimLee

这个还没,我测试一下,而且比较奇怪的地方是,文章也没有用dropout;

ShomyLiu avatar Mar 17 '22 08:03 ShomyLiu

我发现给的伪代码中这个rel_pos_bias好像也有点问题,下面这个是原始的实现方式。

import torch

max_position_embeddings = 512
w = torch.arange(2 * max_position_embeddings - 1).float()
print(w.long())
def rel_pos_bias(seq_len, w):
    # Construct Toeplitz matrix directly when the sequence length is less than 512
    t = torch.nn.functional.pad(w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
    t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
    r = (2 * seq_len - 1) // 2
    t = t[..., r:-r]
    return t
#############
seqlen = 4
rel_pos_bias(seqlen, w)
# tensor([[[3., 4., 5., 6.],
#          [2., 3., 4., 5.],
#          [1., 2., 3., 4.],
#          [0., 1., 2., 3.]]])
#############
seqlen = 8
rel_pos_bias(seqlen, w)
# tensor([   0,    1,    2,  ..., 1020, 1021, 1022])
# tensor([[[ 7.,  8.,  9., 10., 11., 12., 13., 14.],
#          [ 6.,  7.,  8.,  9., 10., 11., 12., 13.],
#          [ 5.,  6.,  7.,  8.,  9., 10., 11., 12.],
#          [ 4.,  5.,  6.,  7.,  8.,  9., 10., 11.],
#          [ 3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
#          [ 2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
#          [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
#          [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.]]])

感觉这样不对劲,于是我改成了下面的这种形式,感觉下面这种形式才对。

seqlen = 4
rel_pos_bias(512, w)[:,:seqlen,:seqlen]
# tensor([[[511., 512., 513., 514.],
#          [510., 511., 512., 513.],
#          [509., 510., 511., 512.],
#          [508., 509., 510., 511.]]])
seqlen = 8
rel_pos_bias(512, w)[:,:seqlen,:seqlen]
# tensor([[[511., 512., 513., 514., 515., 516., 517., 518.],
#          [510., 511., 512., 513., 514., 515., 516., 517.],
#          [509., 510., 511., 512., 513., 514., 515., 516.],
#          [508., 509., 510., 511., 512., 513., 514., 515.],
#          [507., 508., 509., 510., 511., 512., 513., 514.],
#          [506., 507., 508., 509., 510., 511., 512., 513.],
#          [505., 506., 507., 508., 509., 510., 511., 512.],
#          [504., 505., 506., 507., 508., 509., 510., 511.]]])

JunnYu avatar Mar 17 '22 09:03 JunnYu

Hi, 前几天有其他事情耽搁了,最近开始接着看GAU了。 发现您的复现关于attention_mask的位置好像不太对,不过不确定这个是不是导致GAU性能不好的原因; https://github.com/JunnYu/FLASHQuad_pytorch/blob/main/flash/gau.py#L116-L124

kernel = torch.square(torch.relu(
            qk / self.max_position_embeddings + bias))
        # attention_mask
if attention_mask is not None:
    assert attention_mask.ndim == 2
            attn_mask = (
                attention_mask[:, None, :] * attention_mask[:, :, None]
            ).type_as(x)
    kernel *= attn_mask

这里应该 先mask,再计算归一化: relu(qk)**2

我也在测试一些性能

ShomyLiu avatar Apr 01 '22 10:04 ShomyLiu

感觉影响应该不大吧,先mask掉就qk的部分值成了0,对0进行relu,square操作还是0,主要区别是多加了bias部分。 还有我mask的时候把矩阵padding位置的行和列都进行mask了。 还有正常来说mask一般不是都施加给attetnion注意力得分的吗?

我现在正在使用 https://github.com/lucidrains/FLASH-pytorch 的代码训练small的模型。 也差不多快训练完了,你之后可以试试 https://wandb.ai/junyu/huggingface/runs/1jg2jlgt?workspace=user-junyu

JunnYu avatar Apr 01 '22 10:04 JunnYu

  1. 一般mask 应该是在注意力得分的归一化之前呀, 这样后续进行归一化比如softmax 才有意义,正常位置的注意力得分和为1。如果先softmax再做mask,那正常的位置的得分之和是没有归一的。 只不过这里的归一化函数变成了 relu**2
  2. 嗯嗯 感谢老哥提供预训练的权重。 我也去测试下Lucidrains的flash复现,对比看看下结果~

ShomyLiu avatar Apr 01 '22 10:04 ShomyLiu

我先上传个19W步数的给你试试

JunnYu avatar Apr 01 '22 10:04 JunnYu

import torch
from flash_pytorch import FLASHTransformer
from transformers import BertTokenizerFast
model = FLASHTransformer(
    num_tokens=12000,          # number of tokens
    dim=768,                   # model dimension
    depth=12,                  # depth
    causal=False,              # autoregressive or not
    group_size=256,            # size of the groups
    query_key_dim=128,         # dimension of queries / keys
    expansion_factor=2.,       # hidden dimension = dim * expansion_factor
    # in the paper, they claimed scalenorm led to faster training at no performance hit. the other option is 'layernorm' (also default)
    norm_type='scalenorm',
    shift_tokens=False
)
tokenizer = BertTokenizerFast.from_pretrained("junnyu/roformer_chinese_char_base")
model.load_state_dict(torch.load("flash.pt", map_location="cpu"))
model.eval()
text = "中国的首都是[MASK]京。"
inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=512) #这里必须是512,不然结果不对。

with torch.no_grad():
    pt_outputs = model(inputs["input_ids"])[0]
    
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
    if id == tokenizer.mask_token_id:
        val,idx = pt_outputs[i].softmax(-1).topk(k=5)
        tokens = tokenizer.convert_ids_to_tokens(idx)
        new_tokens = []
        for v,t in zip(val.cpu(),tokens):
            new_tokens.append(f"{t}+{round(v.item(),4)}")
        pt_outputs_sentence += "[" + "||".join(new_tokens) + "]"
    else:
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
print(pt_outputs_sentence)
# pytorch: 中国的首都是[北+0.8221||南+0.0787||东+0.0559||西+0.0055||中+0.0033]京。

JunnYu avatar Apr 01 '22 11:04 JunnYu

赞一个, 我也测试下。然后在另外的序列任务上也尝试下,看看是否有提升

ShomyLiu avatar Apr 01 '22 11:04 ShomyLiu

权重 https://wss1.cn/f/7z0orce18tp 复制链接到浏览器打开

JunnYu avatar Apr 01 '22 11:04 JunnYu

  • small版本 + 25W训练步数 + batch_size 128 + lr 1e-4 + 线性衰减学习率 + max_length 512
  • 最终训练集MLM准确率51%左右
  • 权重现已添加:https://huggingface.co/junnyu/flash_small_wwm_cluecorpussmall
  • 完整训练日志:https://wandb.ai/junyu/huggingface/runs/1jg2jlgt

JunnYu avatar Apr 02 '22 03:04 JunnYu

感谢!简单测试了这几个权重,感觉Flash模型在语言任务上效果是挺好的。 但是在我这这边另外一个非预训练的序列建模上,效果总是提不上去,很诡异,差transformer略多。 不过速度是真快。

ShomyLiu avatar Apr 02 '22 13:04 ShomyLiu

新权重padding到最大长度512的可能效果会好一点把。不知道你实验的适合有没有padding到最大长度。

  • 最终训练集MLM准确率51%左右
  • 权重现已添加:https://huggingface.co/junnyu/flash_small_wwm_cluecorpussmall
  • 完整训练日志:https://wandb.ai/junyu/huggingface/runs/1jg2jlgt

JunnYu avatar Apr 02 '22 15:04 JunnYu

您说的是所有序列都padding到512吗。这个还测试过, 我这边设置最长为512,不过tokenize的时候,根据batch内最长的序列作为当前batch的seq_len 。

ShomyLiu avatar Apr 02 '22 16:04 ShomyLiu

tokenizer(text, return_tensors="pt", max_length=512, padding="max_length")

JunnYu avatar Apr 02 '22 16:04 JunnYu

我测试这新的代码的时候,发现短的文本不padding到512,预测结果不大理想

JunnYu avatar Apr 02 '22 16:04 JunnYu

对,我刚刚改成这样了, 之前用的是padding='longest';改成padding到512的时候,

  • 在FLASH模型上,结果总算正常一些了,稍微接近transformer了。之前动态batch 长度(基本在一二百的长度)就是很不正常的差。
  • 不过在GAU上还是有点问题。我再看看哪里的问题。可能也是实现seqlen的问题。
  • 之前复现的GAU与新版的GAU 基本都是按照论文复现的,也都是短文本不太行。 感觉这里比较奇怪,如果是很长序列的话,都padding到512,1014,4096啥的,太浪费显存了呀。

您先早点休息哈

ShomyLiu avatar Apr 02 '22 16:04 ShomyLiu