
Is there a way to activate int8 MHA_v2 kernel when SeqLen > 512?

Open zhexinli opened this issue 1 year ago • 4 comments

Description

Hi, I noticed from an earlier issue that the INT8 MHA_v2 kernel only supports SeqLen <= 512. I also tried it on my own diffusion model, whose Q shape is (B, N, S, H) with S >> 512. I used pytorch_quantization to insert Q/DQ nodes in the MHA and converted the model to TRT. As expected, the attention breaks into separate kernels (int8 -> gemm -> fp32 -> softmax -> int8 -> gemm -> int8) and runs slower than the FP16 MHA_v2 kernel.

image
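For reference, the unfused pattern described above (int8 -> gemm -> fp32 softmax -> int8 -> gemm) can be sketched in plain numpy. This is a minimal fake-quant illustration, not TensorRT's actual kernel behavior; the scales and shapes are hypothetical placeholders, not calibrated values from the model.

```python
import numpy as np

def quantize(x, scale):
    # Fake per-tensor INT8 quantization: round, clip to int8 range.
    return np.clip(np.round(x / scale), -128, 127).astype(np.int8)

def dequantize(xq, scale):
    return xq.astype(np.float32) * scale

rng = np.random.default_rng(0)
S, H = 1024, 64  # S >> 512, the case in question
q, k, v = (rng.standard_normal((S, H), dtype=np.float32) for _ in range(3))

s_in = 0.05  # hypothetical calibration scale
# GEMM 1: int8 inputs, int32 accumulation, then dequantize to FP32.
logits = dequantize(
    quantize(q, s_in).astype(np.int32) @ quantize(k, s_in).astype(np.int32).T,
    s_in * s_in) / np.sqrt(H)
# Softmax stays in FP32, as in the kernel breakdown above.
p = np.exp(logits - logits.max(-1, keepdims=True))
p /= p.sum(-1, keepdims=True)
# Probabilities are requantized to int8 before GEMM 2.
s_p = 1.0 / 127
out = dequantize(
    quantize(p, s_p).astype(np.int32) @ quantize(v, s_in).astype(np.int32),
    s_p * s_in)
print(out.shape)  # (1024, 64)
```

Each quantize/dequantize boundary here corresponds to one of the kernel transitions visible in the profile, which is why the unfused path pays extra memory traffic compared with a single fused MHA kernel.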

But I notice the TensorRT 9.3 OSS release introduces ammo to quantize SDXL, and from reading the code I assume the MHA is also quantized, because there is code handling the QKV Q/DQ fusion. So does the SDXL demo manage to invoke the INT8 MHA_v2 kernel? I think the SDXL SeqLen is also >> 512. How does the demo quantization manage to utilize INT8 MHA? image

zhexinli avatar Mar 14 '24 08:03 zhexinli

@nvpohanh ^ ^

zerollzeng avatar Mar 16 '24 16:03 zerollzeng

I need to check internally, but I don't think SDXL INT8 quantizes the MHAs. IIRC, it only quantizes the Convs and Gemms. cc @rajeevsrao

nvpohanh avatar Mar 18 '24 06:03 nvpohanh

Thanks. Is there any way to get around the MHA's seq_len limitation? Or does NVIDIA plan to extend MHA_v2 to larger dimensions so that diffusion-based models can benefit from the INT8 MHA_v2 kernel?

BTW, I'm also confused about why MHA has this limitation, since the INT8 multiplication and accumulation happen along head_dim (i.e., the last dimension) rather than seq_len, and overflow would only occur if the multiply-accumulate result exceeded the int32 range.
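A back-of-the-envelope check supports this point: the INT8 QK^T matmul accumulates along head_dim, so head_dim (not seq_len) is what bounds the INT32 accumulator. Assuming worst-case int8 magnitudes:

```python
# Worst-case INT8 multiply-accumulate bound (pure arithmetic, no TRT internals).
INT8_MAX_ABS = 128            # |-128| is the largest int8 magnitude
INT32_MAX = 2**31 - 1

worst_case_product = INT8_MAX_ABS * INT8_MAX_ABS   # 16384 per term
safe_accum_len = INT32_MAX // worst_case_product   # terms before possible overflow
print(safe_accum_len)  # 131071
```

So even a pathological input would need a reduction dimension of over 131k before the int32 accumulator could overflow, which suggests the SeqLen <= 512 limit comes from kernel implementation constraints (e.g., on-chip memory for the attention tile) rather than numerics.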

zhexinli avatar Mar 19 '24 12:03 zhexinli

When I quantize wav2vec, I encounter the same issue.

zzdxfei avatar May 24 '24 07:05 zzdxfei