VAR
VAR copied to clipboard
flash-atten相关问题
你好,有个flash-atten的问题想请教下,当我想使能flash-attn时,我发现以下图1的逻辑根本走不进去,为此我打印了self.using_flash、attn_bias、qkv.dtype,最后发现attn_bias一直不是None(图2)
图1:
图2:
于是我将代码修改成以下逻辑:
using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
修改为
using_flash = self.using_flash and qkv.dtype != torch.float32
assert attn_bias is None and qkv.dtype != torch.float32
修改为
assert qkv.dtype != torch.float32
但最后报了如图3的错误
图3:
于是我继续打印输入的q、k、v的dtype(如图4)
图4:
最后在代码中添加以下逻辑后功能才OK
请问这是已知bug吗,麻烦请检查下呢,或者是我哪里操作不对吗,请指导下,最后是我的运行命令
torchrun --nproc_per_node=8 --nnodes=8 --node_rank=1 train.py --depth=16 --bs=384 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --afuse=False
参数写错了,以下为执行命令
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 train.py --depth=16 --bs=384 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --afuse=False
hi @Silentssss 这个地方不能使用flash-attn,因为存在attention mask(attn_bias)不能忽视,而flash-attn不支持自定义mask
图1的逻辑在训练是走不进去的,只有测试时attn_bias为None时,才能走进去
但是训练时如果走xformers的memory_efficient_attention也是报错的,q、k的dtype为float32,v为float16
Thanks @Silentssss, i fixed this xformers type error in the latest commit.