
Add Support for Multi-Head Attention Operator

Open dfiru opened this issue 2 years ago • 7 comments

New Operator

Self Attention

Describe the operator

Multi-head attention is used pervasively in transformer models (most commonly implemented in PyTorch). Native support for the op would simplify ONNX graphs for networks with complex interconnections of self-attention blocks. Meta's recently released Llama-2 is a good example; here is the 7B model exported to ONNX: llama-2-7b-hf.onnx.zip

Can this operator be constructed using existing onnx operators?

Yes, multi-head attention can already be expressed with basic ONNX operators.
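
For reference, here is a minimal NumPy sketch of the computation (weight names and shapes are illustrative, not taken from any exporter), with comments noting the existing ONNX operator each step maps to:

```python
import numpy as np

def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Minimal multi-head self-attention; each step maps to existing ONNX ops."""
    batch, seq, d_model = x.shape
    head_dim = d_model // num_heads

    # Input projections -> MatMul
    q, k, v = x @ w_q, x @ w_k, x @ w_v

    # Split heads -> Reshape + Transpose, giving (batch, heads, seq, head_dim)
    def split(t):
        return t.reshape(batch, seq, num_heads, head_dim).transpose(0, 2, 1, 3)
    q, k, v = split(q), split(k), split(v)

    # Scaled dot-product -> MatMul + Div, then Softmax, then MatMul
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)            # Softmax
    attn = weights @ v

    # Merge heads -> Transpose + Reshape, output projection -> MatMul
    attn = attn.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)
    return attn @ w_o
```

When a transformer is exported today, each of these steps becomes its own node, which is why graphs with many attention blocks get large.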

Is this operator used by any model currently? Which one?

Llama and all Llama variants, Vision Transformer and its variants.

Are you willing to contribute it? (Y/N)

Yes, happy to help however I can.

dfiru avatar Jul 27 '23 19:07 dfiru

Hi @dfiru, your proposal is interesting but still vague. Are you speaking about an MHA op, e.g. https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html ?

WilliamTambellini avatar Jul 28 '23 00:07 WilliamTambellini

Yes! Updated the original description. Thanks!

dfiru avatar Jul 28 '23 14:07 dfiru

A challenge in standardizing it as an op, raised in previous discussions, is the number of variations of attention. To support PyTorch in particular, we are implementing decompositions as ONNX functions over in onnxscript: https://github.com/microsoft/onnxscript/tree/main/onnxscript/function_libs/torch_lib/ops

justinchuby avatar Aug 01 '23 14:08 justinchuby
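
To make the decomposition idea concrete, here is a rough sketch (using plain onnx.helper rather than onnxscript, with illustrative names, shapes, and opset) of the kind of subgraph such a decomposition produces for scaled dot-product attention:

```python
import onnx
from onnx import TensorProto, helper

# Illustrative shapes: (batch, heads, seq, head_dim) with head_dim = 64.
B, H, S, D = "batch", "heads", "seq", 64

nodes = [
    helper.make_node("Transpose", ["key"], ["key_t"], perm=[0, 1, 3, 2]),
    helper.make_node("MatMul", ["query", "key_t"], ["scores"]),
    helper.make_node("Div", ["scores", "scale"], ["scaled"]),      # scale = sqrt(head_dim)
    helper.make_node("Softmax", ["scaled"], ["weights"], axis=-1),
    helper.make_node("MatMul", ["weights", "value"], ["output"]),
]

graph = helper.make_graph(
    nodes,
    "scaled_dot_product_attention",
    inputs=[
        helper.make_tensor_value_info(name, TensorProto.FLOAT, [B, H, S, D])
        for name in ("query", "key", "value")
    ],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [B, H, S, D])],
    initializer=[helper.make_tensor("scale", TensorProto.FLOAT, [], [D ** 0.5])],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)
```

A real MHA decomposition additionally carries the head split/merge (Reshape/Transpose), the input/output projections, and optional masks and dropout, which is where much of the variation comes from.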

[Image: details of the multi-head attention building blocks]

I'm still wondering if MHA is not too high-level for ONNX. A DPA (dot-product attention) op would be a good first step.

WilliamTambellini avatar Aug 01 '23 15:08 WilliamTambellini

Ref: https://github.com/microsoft/onnxscript/blob/c74bc6a8993436c6516af723d3622170e75b4a1a/onnxscript/function_libs/torch_lib/ops/nn.py#L1633

justinchuby avatar Oct 03 '23 22:10 justinchuby

"...the number of variations of attention"

What variations are you referring to here? Attention is essentially matmul + softmax + matmul. Maybe the variations could be encoded as attributes?

I agree with @WilliamTambellini that dot-product attention would be a good first step.
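
As a strawman for that, a fused op could surface the variations as attributes. The op type, domain, and attribute names below are made up purely for illustration; nothing like this exists in any ONNX opset today:

```python
from onnx import helper

# Hypothetical fused attention node; "DotProductAttention" and the
# "com.example" domain are illustrative only, not part of the ONNX standard.
attn_node = helper.make_node(
    "DotProductAttention",
    inputs=["query", "key", "value", "attn_mask"],
    outputs=["output"],
    domain="com.example",
    num_heads=8,      # 1 would give plain single-head DPA
    scale=0.125,      # 1/sqrt(head_dim); could default to that if omitted
    is_causal=1,      # whether a causal mask is applied internally
)
```

Mask handling, causality, dropout, and whether the projections are folded in are exactly the kinds of variation that would need to be pinned down, either as attributes like these or as optional inputs.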

nicholaiTukanov avatar Oct 06 '23 17:10 nicholaiTukanov

@WilliamTambellini can you explain what you mean by "MHA is not too high level for onnx"?

nicholaiTukanov avatar Oct 06 '23 17:10 nicholaiTukanov

WIP https://github.com/onnx/onnx/pull/6501

WilliamTambellini avatar Nov 05 '24 19:11 WilliamTambellini