FireRedASR

Achieve Over 20% Speedup with PyTorch SDPA

Open · wxwmd opened this issue 2 months ago · 1 comment

The attention computation is the most time-consuming part of inference. The current attention implementation in this project is:

import torch
from torch import nn


class DecoderScaledDotProductAttention(nn.Module):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        self.INF = float("inf")

    def forward(self, q, k, v, mask=None):
        # (B, H, T_q, d_k) x (B, H, d_k, T_k) -> (B, H, T_q, T_k) scores
        attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
        if mask is not None:
            mask = mask.eq(0)
            attn = attn.masked_fill(mask, -self.INF)
            # Re-zero masked positions so fully-masked rows yield 0, not NaN.
            attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
        else:
            attn = torch.softmax(attn, dim=-1)
        output = torch.matmul(attn, v)
        return output

which can be accelerated with PyTorch's scaled_dot_product_attention (SDPA).

PyTorch's SDPA delivers a significant speedup when no mask is passed, because it can dispatch to the FlashAttention backend (FlashAttention currently does not support attention with arbitrary masks; see https://github.com/Dao-AILab/flash-attention/issues/352).

When batch_size=1 there is no padding, so the attention mask can be dropped, allowing PyTorch's SDPA to take the fast path, as sketched below.
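
A minimal sketch of the change (the class name DecoderSDPAttention is illustrative, and the scale keyword used to reproduce the original 1/temperature scaling requires PyTorch >= 2.1):

import torch
import torch.nn.functional as F
from torch import nn


class DecoderSDPAttention(nn.Module):
    """Sketch: dispatch to F.scaled_dot_product_attention when no mask is given."""

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        self.INF = float("inf")

    def forward(self, q, k, v, mask=None):
        if mask is None:
            # No mask: eligible for the FlashAttention backend.
            # `scale` (PyTorch >= 2.1) reproduces the original 1/temperature scaling.
            return F.scaled_dot_product_attention(q, k, v, scale=1.0 / self.temperature)
        # Masked path: keep the original implementation so fully-masked
        # rows still produce zeros instead of NaNs.
        attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
        mask = mask.eq(0)
        attn = attn.masked_fill(mask, -self.INF)
        attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
        return torch.matmul(attn, v)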

Based on my testing, this change brings an average performance improvement of over 20%.
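
For reference, a rough micro-benchmark along these lines can confirm the speedup on your own hardware (hypothetical, not the PR's benchmark; shapes and dtype are illustrative, and it reuses the module defined above):

import time

import torch
import torch.nn.functional as F

device = "cuda"  # the FlashAttention backend requires a GPU
q = torch.randn(1, 8, 64, 64, device=device, dtype=torch.float16)
k = torch.randn(1, 8, 512, 64, device=device, dtype=torch.float16)
v = torch.randn(1, 8, 512, 64, device=device, dtype=torch.float16)

def bench(fn, iters=100):
    # Warm up, then time with CUDA synchronization for accurate numbers.
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

manual = DecoderScaledDotProductAttention(temperature=64 ** 0.5)
print("manual attention:", bench(lambda: manual(q, k, v)))
print("SDPA (no mask):  ", bench(lambda: F.scaled_dot_product_attention(q, k, v)))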

— wxwmd, Oct 21 '25

Thanks for your PR, we will review it.

— kaituoxu, Nov 24 '25