Add Support for Multi-Head Attention Operator
New Operator
Self Attention
Describe the operator
Multi-head attention sees prolific use across transformer models (most commonly implemented in PyTorch). Including native support for the op would simplify ONNX graphs for networks with complex interconnections of self-attention blocks. Meta's recently released Llama-2 is a good example (here is the 7B model exported to ONNX): llama-2-7b-hf.onnx.zip
Can this operator be constructed using existing onnx operators?
Yes, multi-head attention can already be expressed with basic ONNX operators.
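For concreteness, here is a minimal NumPy sketch of that decomposition; every step corresponds to an operator ONNX already has (MatMul, Reshape, Transpose, Softmax, Div for scaling). Shapes, weight names, and the layout are illustrative only, not a proposed spec.

```python
import numpy as np

def multi_head_attention(q, k, v, w_q, w_k, w_v, w_o, num_heads):
    """Reference decomposition of MHA into ops that exist in ONNX today.

    q, k, v: (batch, seq, d_model); w_*: (d_model, d_model) projections.
    """
    batch, seq, d_model = q.shape
    d_head = d_model // num_heads

    def split_heads(x):
        # Reshape -> Transpose: (batch, seq, d_model) -> (batch, heads, seq, d_head)
        return x.reshape(batch, seq, num_heads, d_head).transpose(0, 2, 1, 3)

    # MatMul (input projections), then head split
    qh = split_heads(q @ w_q)
    kh = split_heads(k @ w_k)
    vh = split_heads(v @ w_v)

    # MatMul + Div (scale) + Softmax over the key axis
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)

    # MatMul -> Transpose -> Reshape -> MatMul (output projection)
    out = (weights @ vh).transpose(0, 2, 1, 3).reshape(batch, seq, d_model)
    return out @ w_o
```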
Is this operator used by any model currently? Which one?
Llama and its variants, Vision Transformer and its variants
Are you willing to contribute it? (Y/N)
Yes, happy to help however I can.
Hi @dfiru, your proposal is interesting but still vague. Are you speaking about an MHA op, e.g. https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html ?
Yes! Updated the original description. Thanks!
A challenge for standardizing it into an op that I heard from previous discussions is the number of variations of attention. To support PyTorch in particular, we are implementing decompositions as ONNX functions over in onnxscript: https://github.com/microsoft/onnxscript/tree/main/onnxscript/function_libs/torch_lib/ops
I'm still wondering if MHA is not too high level for onnx. A DPA (dot-product attention) op would be a good first step.
Ref: https://github.com/microsoft/onnxscript/blob/c74bc6a8993436c6516af723d3622170e75b4a1a/onnxscript/function_libs/torch_lib/ops/nn.py#L1633
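To make the decomposition idea concrete, here is a minimal sketch of what a dot-product-attention ONNX function could look like in onnxscript, assuming opset 18. The function name, signature, and scale handling are illustrative assumptions, not the actual torch_lib implementation linked above.

```python
from onnxscript import script
from onnxscript import opset18 as op


@script()
def ScaledDotProductAttention(query, key, value, scale: float):
    # scores = (Q @ K^T) * scale, Softmax over the key axis, then weight V
    key_t = op.Transpose(key, perm=[0, 1, 3, 2])
    scores = op.Mul(op.MatMul(query, key_t), op.CastLike(scale, query))
    weights = op.Softmax(scores, axis=-1)
    return op.MatMul(weights, value)
```

A function like this could be registered once and reused by exporters, with masking or dropout variants layered on top of the same core.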
"previous discussions is the number of variations of attention"
What variations are you referring to here? Attention is essentially matmul+softmax. Maybe the variations can be encoded in terms of attributes?
I agree with @WilliamTambellini that dot-product attention would be a good first step.
@WilliamTambellini can you explain what you mean by "MHA is not too high level for onnx"?
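As a rough illustration of the attribute idea: a single core op could cover the common variants (scaling, causal masking, an optional additive mask) while leaving the multi-head reshapes to the surrounding graph. The NumPy sketch below shows plausible semantics; `scale`, `is_causal`, and `mask` are hypothetical attribute/input names, not a proposed specification.

```python
import numpy as np

def attention(q, k, v, *, scale=None, is_causal=False, mask=None):
    """Dot-product attention with common variants exposed as parameters.

    q: (..., seq_q, d), k/v: (..., seq_k, d). `scale` and `is_causal`
    would map naturally to op attributes, `mask` to an optional input.
    """
    d = q.shape[-1]
    scale = 1.0 / np.sqrt(d) if scale is None else scale
    scores = (q @ np.swapaxes(k, -1, -2)) * scale

    if is_causal:  # disallow attending to future positions
        seq_q, seq_k = scores.shape[-2:]
        causal = np.tril(np.ones((seq_q, seq_k), dtype=bool))
        scores = np.where(causal, scores, -np.inf)
    if mask is not None:  # additive attention mask, e.g. for padding
        scores = scores + mask

    # Softmax over the key axis
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v
```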
WIP https://github.com/onnx/onnx/pull/6501