
Does TensorRT 9.2 support FlashAttention v2?


PyTorch now supports FlashAttention v2, which is about 2x faster than the original FlashAttention: https://pytorch.org/blog/pytorch2-2/ So I'm wondering whether TensorRT 9.2 already supports FlashAttention v2, whether I have to add a FlashAttention v2 plugin to use it, or whether there are any plans to support it. Thanks for answering.

linkedqueue avatar Feb 01 '24 04:02 linkedqueue

I believe it's supported out of the box (OOTB); you can export the model to ONNX and check.

zerollzeng avatar Feb 07 '24 09:02 zerollzeng


Thanks for the reply. To make it clearer, is this right: if scaled_dot_product_attention with FlashAttention v2 can be exported to ONNX, it will be supported by TensorRT without performance loss?
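
For concreteness, this is roughly the export I have in mind (just a minimal sketch; module and tensor names are placeholders, and I'm assuming PyTorch 2.x, where torch.onnx.export decomposes scaled_dot_product_attention into standard ONNX ops):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPA(nn.Module):
    def forward(self, q, k, v):
        # PyTorch dispatches this to a fused (FlashAttention) kernel when eligible.
        return F.scaled_dot_product_attention(q, k, v)

# (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 128, 64)
torch.onnx.export(SDPA(), (q, k, v), "sdpa.onnx",
                  input_names=["q", "k", "v"], output_names=["out"],
                  opset_version=17)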

linkedqueue avatar Feb 18 '24 03:02 linkedqueue

Yes, you may even see a speed-up.

zerollzeng avatar Feb 22 '24 14:02 zerollzeng


Any results? I am working on the same thing and looking forward to your reply; once I finish testing, I will also share my results here.

Feynman1999 avatar Feb 23 '24 03:02 Feynman1999


Thanks for confirming.

linkedqueue avatar Feb 23 '24 10:02 linkedqueue


In my test, TensorRT is about 50% faster than PyTorch 2.2 on a transformer network; I'm just curious whether TRT supports the newest FlashAttention v2 technique.

linkedqueue avatar Feb 23 '24 11:02 linkedqueue

It should be supported OOTB. You can also use https://github.com/NVIDIA/TensorRT-LLM

zerollzeng avatar Feb 28 '24 04:02 zerollzeng

@zerollzeng Forgive me for being new to TRT; I'm still confused about how this is supported OOTB. For example, if I have a multi-head attention in TF, how would I leverage this after converting the TF model to ONNX and using TRT to optimize it for inference? Should I expect to see a FlashAttention block in the TRT engine SVGs?

I also noticed that the MultiHeadFlashAttention plugin was removed from previous versions of TRT. So how do I instruct TRT to optimize an MHA to use FlashAttention?

yixzhou avatar Apr 05 '24 18:04 yixzhou

@nvpohanh I'm not quite sure what the best practice is for this kind of issue now that TRT-LLM has been released; could you please shed some light here? :-D

zerollzeng avatar Apr 12 '24 13:04 zerollzeng

Yes, the recommended path for non-LLM workloads (like Vision Transformers) is to export your trained model to ONNX and load that into TRT. TRT will recognize the MHA patterns and use FlashAttention v2 whenever possible.
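
As a rough illustration of that flow (a minimal sketch with the TensorRT Python API; the file name is a placeholder, static input shapes are assumed, and FP16 is enabled because the fused MHA kernels run in reduced precision):

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
    assert parser.parse(f.read())

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # fused MHA/FlashAttention kernels need FP16/BF16
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED  # keep layer names for inspection
engine_bytes = builder.build_serialized_network(network, config)

# Inspect which kernels were chosen (look for fused MHA layers in the output).
runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(engine_bytes)
inspector = engine.create_engine_inspector()
print(inspector.get_engine_information(trt.LayerInformationFormat.ONELINE))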

nvpohanh avatar Apr 15 '24 02:04 nvpohanh

@nvpohanh Is there no way to use FlashAttention v2 as it exists in TensorRT from within TRT-LLM? The reason I ask is that the FMHA plugin in TRT-LLM does not support flash attention when the sequence length of the query differs from that of the key and value (e.g. cross-attention), but the Stable Diffusion demo in this repo does support flash attention for that case. I am interested in building the entire model directly in TensorRT (like TRT-LLM does, rather than the ONNX -> TensorRT route), but I would still like to use TensorRT's FlashAttention by specifying it manually in my TensorRT network definition. Please guide me on how I can achieve this.

Ashwin-Ramesh2607 avatar Jul 16 '24 22:07 Ashwin-Ramesh2607

If you just construct the TRT INetworkDefinition following the same pattern as the ONNX graph, you should be able to get the FlashAttention kernel for cross-attention.
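
Very roughly, the pattern to build looks like this (a sketch only, with placeholder shapes; the build/config step is the same as in the snippet above, and you should confirm with the engine inspector that a fused MHA/flash kernel was actually selected):

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# Hand-build the SDPA pattern: Q @ K^T -> scale -> softmax -> @ V
B, H, S_Q, S_KV, D = 1, 8, 77, 4096, 64  # cross-attention: S_Q != S_KV
q = network.add_input("q", trt.float16, (B, H, S_Q, D))
k = network.add_input("k", trt.float16, (B, H, S_KV, D))
v = network.add_input("v", trt.float16, (B, H, S_KV, D))

qk = network.add_matrix_multiply(q, trt.MatrixOperation.NONE,
                                 k, trt.MatrixOperation.TRANSPOSE)
scale = network.add_constant(
    (1, 1, 1, 1), np.full((1, 1, 1, 1), 1.0 / np.sqrt(D), dtype=np.float16))
scaled = network.add_elementwise(qk.get_output(0), scale.get_output(0),
                                 trt.ElementWiseOperation.PROD)
softmax = network.add_softmax(scaled.get_output(0))
softmax.axes = 1 << 3  # softmax over the key/value sequence axis
out = network.add_matrix_multiply(softmax.get_output(0), trt.MatrixOperation.NONE,
                                  v, trt.MatrixOperation.NONE)
network.mark_output(out.get_output(0))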

nvpohanh avatar Jul 17 '24 05:07 nvpohanh

I built the following module with FlashAttention and saved it as a .pth model:

import torch
import torch.nn as nn
from flash_attn import flash_attn_func  # flash-attn package (FlashAttention v2 kernels)

class Flash_attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = flash_attn_func

    def forward(self, q, k, v, mask=None):
        # flash_attn_func expects (batch, seq_len, n_heads, head_dim) fp16/bf16 CUDA tensors
        return self.attn(q, k, v)

model = Flash_attn()
torch.save(model, 'model/flash_attn.pth')

But when I convert it to an ONNX model, the export reports the following warning:

TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(

However, the converted ONNX model can still be converted to a TRT engine normally, e.g. using trtexec. Does this warning have any effect? When I tested the converted ONNX model, I found that it does not read the inputs at all. Does that mean this ONNX model cannot be used for inference? Should I write a plugin to support FlashAttention?
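
For reference, one way to check what actually ended up in the exported graph (using the onnx Python package; the .onnx file name here is a placeholder):

import onnx

model = onnx.load("flash_attn.onnx")
# If the exporter could not trace through flash_attn_cuda.fwd, the q/k/v inputs
# may be missing here and the graph may consist mostly of baked-in constants.
print("inputs:", [i.name for i in model.graph.input])
print("ops   :", [n.op_type for n in model.graph.node])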

Beatlesso avatar Aug 19 '24 09:08 Beatlesso