
How to implement attention when query and value have different hidden dims?

Open ChaseMonsterAway opened this issue 9 months ago • 2 comments

Hi, I'm trying to export an attention layer with different hidden dimensions for query and value to a TensorRT engine. Do you have any tips?

ChaseMonsterAway avatar Mar 27 '25 10:03 ChaseMonsterAway

@ming-wei

Hi Ming, do you have any suggestions for this question?

Thanks, June

juney-nvidia avatar Mar 27 '25 10:03 juney-nvidia

@ming-wei Hi, could you give me any suggestions? Thanks in advance.

ChaseMonsterAway avatar Mar 28 '25 08:03 ChaseMonsterAway

Did you mean multi-query attention or grouped-query attention, where multiple q heads share each kv head?

We have support for this use case already: https://github.com/NVIDIA/TensorRT-LLM/blob/794f61c99767fd2aa2d28709831c7a9e3501fd43/examples/llama/convert_checkpoint.py#L421

Just set num_attention_heads to the number of q heads, and set num_key_value_heads to the number of kv heads.
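For concreteness, here is a minimal sketch of those two fields in the config dict that convert_checkpoint.py builds. Only the two field names come from the thread and the linked line of code; the head counts and the GQA setup are illustrative assumptions, not values from any specific model:

```python
# Hedged sketch: field names follow the convert_checkpoint.py snippet linked above;
# the numbers are made-up examples for a hypothetical GQA model.
config = {
    "num_attention_heads": 32,  # number of query heads
    "num_key_value_heads": 8,   # number of kv heads; 1 would be MQA, 32 plain MHA
}
```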

You can check out the README in the examples/llama directory for more details.

Let me know if you have further questions.

Thanks, Ming

ming-wei avatar Mar 31 '25 02:03 ming-wei


Thanks for the feedback. However, my question is different from MQA and GQA. It's still a normal attention layer, but the projection layers for query, key, and value (Q/K/V) have different output dimensions.

Here is an example of my question, taken from UniMERNet.
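A layer of the shape I mean looks roughly like the following. This is a minimal PyTorch sketch, not the actual UniMERNet code, and all dimensions are illustrative: query/key are projected to one inner size while value uses a different one.

```python
# Minimal sketch (assumed shapes, not UniMERNet's real implementation):
# q/k inner dim differs from the v inner dim.
import torch
import torch.nn as nn

class AsymmetricAttention(nn.Module):
    def __init__(self, q_dim=512, kv_dim=768, qk_inner=512, v_inner=256, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.qk_head = qk_inner // num_heads
        self.v_head = v_inner // num_heads
        self.q_proj = nn.Linear(q_dim, qk_inner)
        self.k_proj = nn.Linear(kv_dim, qk_inner)
        self.v_proj = nn.Linear(kv_dim, v_inner)   # value dim differs from q/k dim
        self.out_proj = nn.Linear(v_inner, q_dim)

    def forward(self, x, context):
        b, n, _ = x.shape
        m = context.shape[1]
        q = self.q_proj(x).view(b, n, self.num_heads, self.qk_head).transpose(1, 2)
        k = self.k_proj(context).view(b, m, self.num_heads, self.qk_head).transpose(1, 2)
        v = self.v_proj(context).view(b, m, self.num_heads, self.v_head).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(-2, -1) / self.qk_head ** 0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.out_proj(out)
```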

ChaseMonsterAway avatar Mar 31 '25 05:03 ChaseMonsterAway

Thanks for the elaboration.

I don't think TRTLLM supports such use cases.

A workaround you can try is to pad q/k/v with zeros so their dimensions match; however, you'll lose the inference speedup and memory-saving benefits.
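A rough sketch of that workaround, under the assumption that "padding" means extending the value projection with zero rows and widening the output projection with zero columns so the extra channels contribute nothing. It reuses the hypothetical AsymmetricAttention shapes above and is not TensorRT-LLM API code:

```python
# Hedged sketch of the zero-padding workaround; shapes and helper are assumptions.
import torch
import torch.nn as nn

def pad_value_projection(v_proj: nn.Linear, out_proj: nn.Linear, target_dim: int):
    """Return new (v_proj, out_proj) whose inner value dimension is padded to target_dim."""
    v_inner, kv_dim = v_proj.weight.shape          # Linear weight is (out_features, in_features)
    pad = target_dim - v_inner
    assert pad >= 0, "target_dim must be at least the original value dimension"

    # New value projection: original weights on top, zero rows for the padded channels.
    new_v = nn.Linear(kv_dim, target_dim)
    with torch.no_grad():
        new_v.weight.zero_()
        new_v.bias.zero_()
        new_v.weight[:v_inner] = v_proj.weight
        new_v.bias[:v_inner] = v_proj.bias

    # New output projection: zero columns for the padded channels, so the result is unchanged.
    q_dim = out_proj.weight.shape[0]
    new_out = nn.Linear(target_dim, q_dim)
    with torch.no_grad():
        new_out.weight.zero_()
        new_out.weight[:, :v_inner] = out_proj.weight
        new_out.bias.copy_(out_proj.bias)

    return new_v, new_out
```

The padded value channels are always zero and the matching output-projection columns are zero, so the layer's output is numerically identical to the original; you just pay for the larger (uniform) head size, which is exactly the lost speedup/memory benefit mentioned above.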

ming-wei avatar Mar 31 '25 08:03 ming-wei