TensorRT
Unfused Multihead attention TensorRT 9.2 is 2x slower than PyTorch 2.2 on GPU A100-SXM4-40GB
Description
When comparing multi-head attention (MHA) between Torch 2.2 and TensorRT 9.2 on an A100-SXM4-40GB, I found that for certain input sizes the resulting engine does not use the _gemm_mha_v2 tactic. Without this fusion, TRT is significantly slower than Torch, by almost 2x.
Both use FP16 precision, and the TRT engine was built with builder optimization level 3.
Here is the data, with Shape = [batch, heads, seq, channels]:
| Shape | Torch (ms) | TRT (ms) | Speedup (Torch / TRT, higher is better) |
|---|---|---|---|
| [2x1024, 5, 25, 64] | 6.3 | 5.4 | 1.16 |
| [4x1024, 5, 25, 64] | 11.8 | 10.12 | 1.17 |
| [8x1024, 5, 25, 64] | 23.1 | 20.0 | 1.15 |
| [16x1024, 5, 25, 64] | 45.6 | 71.0 | *0.64 |
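
As a quick cross-check of the table, the following minimal sketch recomputes the speedup column and the latency growth between the two largest batches. The timings are copied from the table; the names `timings_ms`, `speedup`, `torch_growth`, and `trt_growth` are illustrative only.

```python
# Timings copied from the table above: batch multiplier -> (torch_ms, trt_ms).
timings_ms = {
    2: (6.3, 5.4),
    4: (11.8, 10.12),
    8: (23.1, 20.0),
    16: (45.6, 71.0),
}


def speedup(torch_ms, trt_ms):
    """Return torch / trt; values above 1.0 mean TensorRT is faster."""
    return torch_ms / trt_ms


speedups = {k: speedup(t, r) for k, (t, r) in timings_ms.items()}

# Batch multipliers where TensorRT is slower than Torch.
regressions = [k for k, s in speedups.items() if s < 1.0]

# Latency growth when the batch doubles from 8x1024 to 16x1024.
torch_growth = timings_ms[16][0] / timings_ms[8][0]
trt_growth = timings_ms[16][1] / timings_ms[8][1]

for k, s in speedups.items():
    print(f"[{k}x1024, 5, 25, 64]: speedup {s:.2f}")
print(f"8x -> 16x growth: torch {torch_growth:.2f}x, trt {trt_growth:.2f}x")
```

With these numbers, Torch latency roughly doubles when the batch doubles, while TRT latency grows by considerably more than 2x, which is the regression this issue is about.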
The engines built for batch sizes 2x1024, 4x1024, and 8x1024 use the mha_v2 fusion:
Name: __tran6719reshp__mye6715xpos__mye6711, LayerType: kgen, Inputs: [ { Name: __mye6755^deref, Dimensions: [153600,320], Format/Datatype: Half }], Out
puts: [ { Name: __tran6719, Dimensions: [25,6144,320], Format/Datatype: Half }], TacticName: __myl_bb1_2_MovResTra, StreamId: 0, Metadata: __tran6719res
hp__mye6715xpos__mye6711
Name: [MATRIX_MULTIPLY]_[linear]_[x2_matmul]_m ... LY]_[linear]_[x4_matmul]_matrix_multiply, LayerType: gemm, Inputs: [ { Name: __mye6761^deref, Dimensi
ons: [153600,320], Format/Datatype: Half }, { Name: __mye6643_dconst, Dimensions: [3,320,320], Format/Datatype: Half }, { Name: __mye716[MATRIX_MULTIPLY
]_[linear]_[x2_matmul]_matrix_multiply_alpha, Dimensions: [1], Format/Datatype: Float }, { Name: __mye717[MATRIX_MULTIPLY]_[linear]_[x2_matmul]_matrix_m
ultiply_beta, Dimensions: [1], Format/Datatype: Float }], Outputs: [ { Name: __mye6631, Dimensions: [3,153600,320], Format/Datatype: Half }], TacticName
: ampere_h16816gemm_128x128_ldg8_stages_32x5_nn_v1, StreamId: 0, Metadata: [MATRIX_MULTIPLY]_[linear]_[x2_matmul]_matrix_multiply[MATRIX_MULTIPLY]_[line
ar]_[x3_matmul]_matrix_multiply[MATRIX_MULTIPLY]_[linear]_[x4_matmul]_matrix_multiply
Name: [MATRIX_MULTIPLY]_[scaled_dot_product_at ... uct_attention]_[x11_qkv]_matrix_multiply, LayerType: kgen, Inputs: [ { Name: __mye6767^deref, Dimensi
ons: [30720,25,64], Format/Datatype: Half }, { Name: __mye6773^deref, Dimensions: [30720,64,25], Format/Datatype: Half }, { Name: __mye6779^deref, Dimen
sions: [30720,25,64], Format/Datatype: Half }], Outputs: [ { Name: shuffle_output_[MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qkv].1, Dimensio
ns: [30720,25,64], Format/Datatype: Half }], TacticName: __mye6660_gemm_mha_v2, StreamId: 0, Metadata: [MATRIX_MULTIPLY]_[scaled_dot_product_attention]_
[x11_qk]_matrix_multiplyshuffle_output_[MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qk]_reshape[ELEMENTWISE]_[scaled_dot_product_attention]_[x1
1_qk_scaled][SOFTMAX]_[scaled_dot_product_attention]_[x11]shuffle_input_0_[MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qkv]_reshape[MATRIX_MULT
IPLY]_[scaled_dot_product_attention]_[x11_qkv]_matrix_multiply
Name: [MATRIX_MULTIPLY]_[linear]_[x15_matmul]_matrix_multiply, LayerType: gemm, Inputs: [ { Name: __mye6785^deref, Dimensions: [153600,320], Format/Data
type: Half }, { Name: __mye6609_dconst, Dimensions: [320,320], Format/Datatype: Half }, { Name: __mye1320[MATRIX_MULTIPLY]_[linear]_[x15_matmul]_matrix_
multiply_alpha, Dimensions: [1], Format/Datatype: Float }, { Name: __mye1321[MATRIX_MULTIPLY]_[linear]_[x15_matmul]_matrix_multiply_beta, Dimensions: [1
], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_5, Dimensions: [153600,320], Format/Datatype: Half }], TacticName: ampere_h16816gemm_
128x128_ldg8_stages_32x5_nn_v1, StreamId: 0, Metadata: [MATRIX_MULTIPLY]_[linear]_[x15_matmul]_matrix_multiply
Name: __tran6704reshp__mye6700, LayerType: kgen, Inputs: [ { Name: __myln_k_arg__bb1_5, Dimensions: [153600,320], Format/Datatype: Half }], Outputs: [ {
Name: __mye6791^deref, Dimensions: [6144,25,320], Format/Datatype: Half }], TacticName: __myl_bb1_1_ResTra, StreamId: 0, Metadata: __tran6704reshp__mye
6700
The engine built for batch size 16x1024, by contrast, is unfused and almost 2x slower than Torch 2.2:
Name: entry^bb^signal^2, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 0, Metadata: entry^bb^signal^2
Name: entry^bb^wait^2, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 2, Metadata: entry^bb^wait^2
Name: [MATRIX_MULTIPLY]_[linear]_[x2_matmul]_m ... LY]_[linear]_[x3_matmul]_matrix_multiply, LayerType: gemm, Inputs: [ { Name: __mye6997^deref, Dimensi
ons: [460800,320], Format/Datatype: Half }, { Name: __mye6787_dconst, Dimensions: [2,320,320], Format/Datatype: Half }, { Name: __mye716[MATRIX_MULTIPLY
]_[linear]_[x2_matmul]_matrix_multiply_alpha, Dimensions: [1], Format/Datatype: Float }, { Name: __mye717[MATRIX_MULTIPLY]_[linear]_[x2_matmul]_matrix_m
ultiply_beta, Dimensions: [1], Format/Datatype: Float }], Outputs: [ { Name: __mye6777, Dimensions: [2,460800,320], Format/Datatype: Half }], TacticName
: ampere_h16816gemm_128x128_ldg8_stages_32x5_nn_v1, StreamId: 0, Metadata: [MATRIX_MULTIPLY]_[linear]_[x2_matmul]_matrix_multiply[MATRIX_MULTIPLY]_[line
ar]_[x3_matmul]_matrix_multiply
Name: __mye6914, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 0, Metadata: __mye6914
Name: shuffle_output_[SHUFFLE]_[reshape]_[x5] ... [SHUFFLE]_[permute]_[x6]_first_transpose, LayerType: kgen, Inputs: [ { Name: __mye7003^deref, Dimensi
ons: [18432,25,5,64], Format/Datatype: Half }], Outputs: [ { Name: shuffle_output_[SHUFFLE]_[reshape]_[x5] _ [SHUFFLE]_[permute]_[x6]_first_transpose_ou
tput.1, Dimensions: [18432,5,25,64], Format/Datatype: Half }], TacticName: __myl_bb1_5_Tra, StreamId: 0, Metadata: shuffle_output_[SHUFFLE]_[reshape]_[x
5] _ [SHUFFLE]_[permute]_[x6]_first_transpose
Name: __mye6922, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 0, Metadata: __mye6922
Name: __mye6916, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 1, Metadata: __mye6916
Name: shuffle_output_[SHUFFLE]_[reshape]_[x7] ... duct_attention]_[x11_qk]_first_transpose, LayerType: kgen, Inputs: [ { Name: __mye7009^deref, Dimensi
ons: [18432,25,5,64], Format/Datatype: Half }], Outputs: [ { Name: shuffle_input_1_[MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qk]_first_trans
pose_output.1, Dimensions: [18432,5,64,25], Format/Datatype: Half }], TacticName: __myl_bb1_4_Tra, StreamId: 1, Metadata: shuffle_output_[SHUFFLE]_[resh
ape]_[x7] _ [SHUFFLE]_[permute]_[x8]_first_transposeshuffle_input_1_[MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qk]_first_transpose
Name: __mye6918, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 1, Metadata: __mye6918
Name: __mye6920, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 0, Metadata: __mye6920
Name: __mye6723_unroll0, LayerType: gemm, Inputs: [ { Name: __mye7021^deref, Dimensions: [46080,25,64], Format/Datatype: Half }, { Name: __mye7027^deref
, Dimensions: [46080,64,25], Format/Datatype: Half }, { Name: __mye6649__new_fc___mye6647_alpha, Dimensions: [1], Format/Datatype: Float }, { Name: __my
e6650__new_fc___mye6647_beta, Dimensions: [1], Format/Datatype: Float }], Outputs: [ { Name: __mye7015^deref, Dimensions: [46080,25,25], Format/Datatype
: Half }], TacticName: sm80_xmma_gemm_f16f16_f16f16_f16_nn_n_tilesize32x32x64_stage6_warpsize2x2x1_tensor16x8x16_aligna2_alignc2, StreamId: 0, Metadata:
__mye6723_unroll0
Name: __mye6924, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 1, Metadata: __mye6924
Name: [MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qk]_matrix_multiply, LayerType: gemm, Inputs: [ { Name: __mye7039^deref, Dimensions: [46080,
25,64], Format/Datatype: Half }, { Name: __mye7045^deref, Dimensions: [46080,64,25], Format/Datatype: Half }, { Name: __mye6649__new_fc___mye6647_alpha,
Dimensions: [1], Format/Datatype: Float }, { Name: __mye6650__new_fc___mye6647_beta, Dimensions: [1], Format/Datatype: Float }], Outputs: [ { Name: __m
ye7033^deref, Dimensions: [46080,25,25], Format/Datatype: Half }], TacticName: sm80_xmma_gemm_f16f16_f16f16_f16_nn_n_tilesize32x32x64_stage6_warpsize2x2
x1_tensor16x8x16_aligna2_alignc2, StreamId: 1, Metadata: [MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qk]_matrix_multiply
Name: __mye6926, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 1, Metadata: __mye6926
Name: __mye6928, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 0, Metadata: __mye6928
Name: __mye6806__mye6806__mye6800__mye6804 ... _dot_product_attention]_[x11_qk]_reshape, LayerType: kgen, Inputs: [ { Name: __mye6733_ur1, Dimensions: [
92160,25,25], Format/Datatype: Half }], Outputs: [ { Name: [SOFTMAX]_[scaled_dot_product_attention]_[x11]_output'.1, Dimensions: [18432,5,25,25], Format
/Datatype: Half }], TacticName: __myl_bb1_3_ResMulMaxSubExpSumDivMul, StreamId: 0, Metadata: __mye6806__mye6806__mye6800__mye6804[ELEMENTWISE]_[scaled_d
ot_product_attention]_[x11_qk_scaled]__mye6792__mye6796shuffle_output_[MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qk]_reshape
Name: __mye6938, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 0, Metadata: __mye6938
Name: [MATRIX_MULTIPLY]_[linear]_[x4_matmul]_matrix_multiply, LayerType: gemm, Inputs: [ { Name: __mye7051^deref, Dimensions: [460800,320], Format/Datat
ype: Half }, { Name: __mye6606_dconst, Dimensions: [320,320], Format/Datatype: Half }, { Name: __mye1068[MATRIX_MULTIPLY]_[linear]_[x4_matmul]_matrix_mu
ltiply_alpha, Dimensions: [1], Format/Datatype: Float }, { Name: __mye1069[MATRIX_MULTIPLY]_[linear]_[x4_matmul]_matrix_multiply_beta, Dimensions: [1],
Format/Datatype: Float }], Outputs: [ { Name: [MATRIX_MULTIPLY]_[linear]_[x4_matmul]_fold_matmul_output.1, Dimensions: [460800,320], Format/Datatype: Ha
lf }], TacticName: ampere_h16816gemm_128x128_ldg8_stages_32x5_nn_v1, StreamId: 1, Metadata: [MATRIX_MULTIPLY]_[linear]_[x4_matmul]_matrix_multiply
Name: __mye6930, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 1, Metadata: __mye6930
Name: __mye6932, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 2, Metadata: __mye6932
Name: shuffle_output_[SHUFFLE]_[reshape]_[x9] ... SHUFFLE]_[permute]_[x10]_first_transpose, LayerType: kgen, Inputs: [ { Name: __mye7057^deref, Dimensi
ons: [18432,25,5,64], Format/Datatype: Half }], Outputs: [ { Name: shuffle_output_[SHUFFLE]_[reshape]_[x9] _ [SHUFFLE]_[permute]_[x10]_first_transpose_o
utput.1, Dimensions: [18432,5,25,64], Format/Datatype: Half }], TacticName: __myl_bb1_5_Tra, StreamId: 2, Metadata: shuffle_output_[SHUFFLE]_[reshape]_[
x9] _ [SHUFFLE]_[permute]_[x10]_first_transpose
Name: __mye6934, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 2, Metadata: __mye6934
Name: __mye6936, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 0, Metadata: __mye6936
Name: __mye6753_unroll0, LayerType: gemm, Inputs: [ { Name: __mye7069^deref, Dimensions: [46080,25,25], Format/Datatype: Half }, { Name: __mye7075^deref
, Dimensions: [46080,25,64], Format/Datatype: Half }, { Name: __mye6690__new_fc___mye6688_alpha, Dimensions: [1], Format/Datatype: Float }, { Name: __my
e6691__new_fc___mye6688_beta, Dimensions: [1], Format/Datatype: Float }], Outputs: [ { Name: __mye7063^deref, Dimensions: [46080,25,64], Format/Datatype
: Half }], TacticName: sm80_xmma_gemm_f16f16_f16f16_f16_nn_n_tilesize32x32x64_stage6_warpsize2x2x1_tensor16x8x16_aligna2_alignc2, StreamId: 0, Metadata:
__mye6753_unroll0
Name: __mye6940, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 1, Metadata: __mye6940
Name: __mye6942, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 1, Metadata: __mye6942
Name: [MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qkv]_matrix_multiply, LayerType: gemm, Inputs: [ { Name: __mye7087^deref, Dimensions: [46080
,25,25], Format/Datatype: Half }, { Name: __mye7093^deref, Dimensions: [46080,25,64], Format/Datatype: Half }, { Name: __mye6690__new_fc___mye6688_alpha
, Dimensions: [1], Format/Datatype: Float }, { Name: __mye6691__new_fc___mye6688_beta, Dimensions: [1], Format/Datatype: Float }], Outputs: [ { Name: __
mye7081^deref, Dimensions: [46080,25,64], Format/Datatype: Half }], TacticName: sm80_xmma_gemm_f16f16_f16f16_f16_nn_n_tilesize32x32x64_stage6_warpsize2x
2x1_tensor16x8x16_aligna2_alignc2, StreamId: 1, Metadata: [MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qkv]_matrix_multiply
Name: __mye6944, LayerType: signal, Inputs: [], Outputs: [], TacticName: , StreamId: 1, Metadata: __mye6944
Name: __mye6946, LayerType: wait, Inputs: [], Outputs: [], TacticName: , StreamId: 0, Metadata: __mye6946
Name: [SHUFFLE]_[permute]_[x12] _ [SHUFFLE]_[r ... dot_product_attention]_[x11_qkv]_reshape, LayerType: kgen, Inputs: [ { Name: __mye6763_ur1, Dimension
s: [92160,25,64], Format/Datatype: Half }], Outputs: [ { Name: [SHUFFLE]_[permute]_[x12] _ [SHUFFLE]_[reshape]_[x14]_first_transpose_output.1, Dimension
s: [18432,25,5,64], Format/Datatype: Half }], TacticName: __myl_bb1_1_ResTra, StreamId: 0, Metadata: [SHUFFLE]_[permute]_[x12] _ [SHUFFLE]_[reshape]_[x1
4]_first_transposeshuffle_output_[MATRIX_MULTIPLY]_[scaled_dot_product_attention]_[x11_qkv]_reshape
Name: [MATRIX_MULTIPLY]_[linear]_[x15_matmul]_matrix_multiply, LayerType: gemm, Inputs: [ { Name: __mye7105^deref, Dimensions: [460800,320], Format/Data
type: Half }, { Name: __mye6609_dconst, Dimensions: [320,320], Format/Datatype: Half }, { Name: __mye1320[MATRIX_MULTIPLY]_[linear]_[x15_matmul]_matrix_
multiply_alpha, Dimensions: [1], Format/Datatype: Float }, { Name: __mye1321[MATRIX_MULTIPLY]_[linear]_[x15_matmul]_matrix_multiply_beta, Dimensions: [1
], Format/Datatype: Float }], Outputs: [ { Name: __mye7099^deref, Dimensions: [460800,320], Format/Datatype: Half }], TacticName: ampere_h16816gemm_128x
128_ldg8_stages_32x5_nn_v1, StreamId: 0, Metadata: [MATRIX_MULTIPLY]_[linear]_[x15_matmul]_matrix_multiply
Since the code that builds the layer and converts it to a TRT engine is the same in every case, not fusing the MHA for shape [16x1024, 5, 25, 64] is purely TRT's decision. I expect TRT performance to be at least as good as eager PyTorch.
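
To check quickly whether a given engine picked the fused kernel, one option is to scan the layer-info text for the tactic name. This is a minimal sketch that assumes the dump has the `TacticName: ..., StreamId: ...` layout shown above and that the fused kernel can be recognized by the substring `gemm_mha_v2`; the helper names `tactic_names` and `uses_fused_mha` are hypothetical, and the two sample strings are abbreviated excerpts of the dumps above.

```python
import re
from typing import List


def tactic_names(layer_info: str) -> List[str]:
    """Extract every non-empty TacticName value from a layer-info dump."""
    # Undo hard line wraps so tactic names split across lines are rejoined.
    joined = layer_info.replace("\n", "")
    names = re.findall(r"TacticName:\s*([^,\s]*),", joined)
    # Signal/wait entries carry an empty tactic name; drop those.
    return [name for name in names if name]


def uses_fused_mha(layer_info: str, marker: str = "gemm_mha_v2") -> bool:
    """Return True if any layer's tactic name contains the fused-MHA marker."""
    return any(marker in name for name in tactic_names(layer_info))


# Abbreviated excerpts from the two dumps above (most fields omitted).
fused_excerpt = "LayerType: kgen, TacticName: __mye6660_gemm_mha_v2, StreamId: 0"
unfused_excerpt = (
    "LayerType: gemm, TacticName: sm80_xmma_gemm_f16f16_f16f16_f16_nn_n_"
    "tilesize32x32x64_stage6_warpsize2x2\n"
    "x1_tensor16x8x16_aligna2_alignc2, StreamId: 1"
)

print(uses_fused_mha(fused_excerpt), uses_fused_mha(unfused_excerpt))
```

Running such a check over the dumps for each batch size would make it easy to find the size at which the fusion stops being selected.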
Environment
TensorRT Version: 9.2
NVIDIA GPU: A100-SXM4-40GB
NVIDIA Driver Version: 535.154.05
CUDA Version: 12.2
CUDNN Version: 8.9
Operating System: Ubuntu 20.04
Python Version (if applicable): 3.8
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 2.2
Baremetal or Container (if so, version):
Steps To Reproduce
Code to build MHA layer. We use a custom converter to convert aten ops to TensorRT:
- linear -> matmul + elementwise
- permute, reshape -> shuffle
- scaled_dot_product_attention -> matmul + elementwise + softmax + matmul
import torch

aten = torch.ops.aten
linear = aten.linear.default
reshape = aten.reshape.default
permute = aten.permute.default
sdp = aten.scaled_dot_product_attention.default


def mha(x, Wq, Wk, Wv, Wo, heads):
    B, L, C = x.shape
    C, H = C // heads, heads  # C is now the per-head channel count
    q = linear(x, Wq, None)
    k = linear(x, Wk, None)
    v = linear(x, Wv, None)

    def head_dim_to_batch(a):
        # [B, L, H * C] -> [B, H, L, C]
        a = reshape(a, (B, L, H, C))
        a = permute(a, (0, 2, 1, 3))
        return a

    q, k, v = map(head_dim_to_batch, (q, k, v))
    o = sdp(q, k, v)

    def head_dim_to_channel(a):
        # [B, H, L, C] -> [B, L, H * C]
        a = permute(a, (0, 2, 1, 3))
        a = reshape(a, (B, L, C * H))
        return a

    o = head_dim_to_channel(o)
    return linear(o, Wo, None)


def test_torch():
    b, l, h, c = 1024, 25, 5, 64
    wq, wk, wv, wo = (torch.randn((c * h, c * h), dtype=torch.float16, device='cuda') for _ in range(4))
    # benchmark_helper is part of our conversion and timing script, which is not included here.
    benchmark_helper.builder_config.builder_optimization_level = 3
    for i in [4, 8, 16]:
        x = torch.randn(b * i, l, h * c, dtype=torch.float16, device='cuda')
        mha(x, wq, wk, wv, wo, h)
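
For reference, the scaled_dot_product_attention mapping listed above (matmul + elementwise + softmax + matmul) corresponds to the following steps for a single head. This is a dependency-free sketch in plain Python, not the converter's actual code; all function names are illustrative, and it assumes the default scale of 1/sqrt(d) with no mask or dropout.

```python
import math


def matmul(a, b):
    """Multiply two row-major matrices given as lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*a)]


def softmax_rows(a):
    """Numerically stable softmax over each row: max, sub, exp, sum, div."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(value - m) for value in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def unfused_attention(q, k, v):
    """Attention for one head; q, k, v are [L, d] lists of lists."""
    d = len(q[0])
    scores = matmul(q, transpose(k))                        # matmul: [L, d] x [d, L]
    scale = 1.0 / math.sqrt(d)
    scaled = [[s * scale for s in row] for row in scores]   # elementwise scale
    probs = softmax_rows(scaled)                            # softmax
    return matmul(probs, v)                                 # matmul: [L, L] x [L, d]


# With all-zero queries and keys every score is equal, so each output row
# is the plain average of the rows of v.
zeros = [[0.0, 0.0], [0.0, 0.0]]
out = unfused_attention(zeros, zeros, [[1.0, 2.0], [3.0, 4.0]])
print(out)
```

When these four steps run as separate kernels, the [L, L] score matrix is materialized between them, which is the intermediate that a fused MHA kernel avoids writing out.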
I have not tried the ONNX converter or the torch2trt converter, but I do not think they would generate a better engine than our manual converter does. If needed, I'm happy to share a script that includes the conversion and timing code. However, I think this information is sufficient to pinpoint the problem: the TRT engine does not fuse MHA for certain sizes, which leads to worse performance than PyTorch eager.
Commands or scripts:
Have you tried the latest release?: Yes
Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):
We had internal bug 4503042 to track this. cc @nvpohanh
@haijieg Could you provide the ONNX model so that it's easier for us to repro? Thanks
import torch
from torch.nn import Parameter, Module

aten = torch.ops.aten
linear = aten.linear.default
reshape = aten.reshape.default
permute = aten.permute.default
sdp = aten.scaled_dot_product_attention.default


def mha(x, Wq, Wk, Wv, Wo, heads):
    B, L, C = x.shape
    C, H = C // heads, heads
    q = linear(x, Wq, None)
    k = linear(x, Wk, None)
    v = linear(x, Wv, None)

    def head_dim_to_batch(a):
        a = reshape(a, (B, L, H, C))
        a = permute(a, (0, 2, 1, 3))
        a = reshape(a, (B * H, L, C))
        return a

    q, k, v = map(head_dim_to_batch, (q, k, v))
    o = sdp(q, k, v)

    def head_dim_to_channel(a):
        a = reshape(a, (B, H, L, C))
        a = permute(a, (0, 2, 1, 3))
        a = reshape(a, (B, L, C * H))
        return a

    o = head_dim_to_channel(o)
    return linear(o, Wo, None)


def test_export():
    b, l, h, c = 1024, 25, 5, 64

    class M(Module):
        def __init__(self):
            super().__init__()
            wq, wk, wv, wo = (torch.randn((c * h, c * h)) for _ in range(4))
            self.wq = Parameter(wq)
            self.wk = Parameter(wk)
            self.wv = Parameter(wv)
            self.wo = Parameter(wo)
            self.heads = h

        def forward(self, x):
            return mha(x, self.wq, self.wk, self.wv, self.wo, self.heads)

    for i in (9, 18):
        m = M().half().cuda()
        x = torch.randn(b * i, l, h * c).to(torch.float16).cuda()
        with open(f"mha_{i}k.onnx", 'wb') as f:
            torch.onnx.export(m, (x,), f)


if __name__ == "__main__":
    test_export()
Running the script above produces mha_9k.onnx and mha_18k.onnx, corresponding to batch sizes 9k and 18k. Running each model through trtexec --onnx=mha_{9|18}k.onnx --dumpLayerInfo --fp16 --profilingVerbosity=detailed then shows the difference in layers and timing. Specifically, the GPU latency for the 18k batch is about 3x that of the 9k batch, which points to an inefficient engine compiled without MHA fusion. PyTorch latency scales linearly when the batch size goes from 9k to 18k.