VAR icon indicating copy to clipboard operation
VAR copied to clipboard

flash-atten相关问题

Open Silentssss opened this issue 10 months ago • 2 comments

你好,有个flash-atten的问题想请教下,当我想使能flash-attn时,我发现以下图1的逻辑根本走不进去,为此我打印了self.using_flash、attn_bias、qkv.dtype,最后发现attn_bias一直不是None(图2) 图1: image 图2: image

于是我将代码修改成以下逻辑: using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32 修改为 using_flash = self.using_flash and qkv.dtype != torch.float32

assert attn_bias is None and qkv.dtype != torch.float32 修改为 assert qkv.dtype != torch.float32 但最后报了如图3的错误 图3: image

于是我继续打印输入的q、k、v的dtype(如图4) 图4: image 最后在代码中添加以下逻辑后功能才OK image 请问这是已知bug吗,麻烦请检查下呢,或者是我哪里操作不对吗,请指导下,最后是我的运行命令 torchrun --nproc_per_node=8 --nnodes=8 --node_rank=1 train.py --depth=16 --bs=384 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --afuse=False

Silentssss avatar Apr 19 '24 08:04 Silentssss

参数写错了,以下为执行命令 torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 train.py --depth=16 --bs=384 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --afuse=False

Silentssss avatar Apr 19 '24 08:04 Silentssss

hi @Silentssss 这个地方不能使用flash-attn,因为存在attention mask(attn_bias)不能忽视,而flash-attn不支持自定义mask

图1的逻辑在训练是走不进去的,只有测试时attn_bias为None时,才能走进去

keyu-tian avatar Apr 19 '24 10:04 keyu-tian

但是训练时如果走xformers的memory_efficient_attention也是报错的,q、k的dtype为float32,v为float16 image

Silentssss avatar Apr 22 '24 02:04 Silentssss

Thanks @Silentssss, i fixed this xformers type error in the latest commit.

keyu-tian avatar Apr 23 '24 11:04 keyu-tian