FireRedASR icon indicating copy to clipboard operation
FireRedASR copied to clipboard

用torch原生的flash attention性能更好

Open xphh opened this issue 5 months ago • 2 comments

transformer_decoder.py里面可以替换torch原生的scaled_dot_product_attention函数

第247行:output = self.attention(q, k, v, mask=mask)

改成:output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.bool())

整体性能大概可以提升10%

conformer_encoder.py里面应该也可以,但逻辑稍微有点不一样,我还不知道怎么改,麻烦作者可以看看

xphh avatar Jul 04 '25 07:07 xphh