[RFC] Shape handling for `aten::_native_multi_head_attention`
Hi,
I have been looking into adding a decomposition for the `aten::_native_multi_head_attention` op. The issue I have run into is that for certain inputs (specifically `need_weights=False`) the op returns `None` as its second output, contradicting the op signature we take from PyTorch. I filed an issue upstream but was told that there are no guarantees for ops with a leading underscore. A summary of the issue is included below.
```python
import torch

embed_dim = 8
num_heads = 4
bs = 4
sl = 2

qkv = torch.nn.Linear(embed_dim, embed_dim * 3, dtype=torch.float32)
proj = torch.nn.Linear(embed_dim, embed_dim, dtype=torch.float32)

q = torch.randn(bs, sl, embed_dim) * 10
k = torch.randn(bs, sl, embed_dim) * 10
v = torch.randn(bs, sl, embed_dim) * 10

mha = torch.ops.aten._native_multi_head_attention(
    q,
    k,
    v,
    embed_dim,
    num_heads,
    qkv.weight,
    qkv.bias,
    proj.weight,
    proj.bias,
    need_weights=False,
    average_attn_weights=False,
)
print(mha)
```
with example output:

```
(tensor([[[ 1.6427,  2.0966,  2.4298,  1.6536,  2.9116, -0.6659,  0.0086,
            4.0757],
          [ 2.0386,  0.8152, -0.8718,  1.7295,  0.9999, -1.8865, -2.7697,
            1.9216]],

         [[ 4.0717,  0.0476, -0.6383,  3.1022, -2.5480,  2.0922, -4.1062,
           -0.5034],
          [ 2.3662,  0.3523, -1.0895,  1.9332,  0.3525,  0.4775, -2.1356,
            0.4972]],

         [[-5.0851,  3.8904,  2.9651, -3.1131,  6.5247, -2.5286, -1.4031,
            1.0763],
          [-2.5247,  1.5687, -1.5536,  1.0382,  4.8081, -2.2505,  1.6698,
            2.1023]],

         [[-1.7481,  1.0500,  2.4167, -1.5026,  5.5205, -3.3177,  3.3927,
            4.1006],
          [-3.4155,  2.5501,  4.6239, -8.3866,  4.6514, -2.5655,  5.8211,
            2.1764]]], grad_fn=<NotImplemented>), None)
```
This is the entry in native_functions.yaml, which (correct me if I'm wrong) is where we take the function signature from. It declares a `(Tensor, Tensor)` output, yet `None` is returned as the second value:
```yaml
- func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True) -> (Tensor, Tensor)
  variants: function
  dispatch:
    CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention
```
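For what it's worth, the same declared return type can be checked from Python by printing the op's registered TorchScript schema. This is only a sanity check and relies on the private `torch._C._jit_get_schemas_for_operator` helper, which may change between PyTorch releases:

```python
import torch

# Print the TorchScript schema(s) registered for the op. Note that
# _jit_get_schemas_for_operator is a private API and may change between
# PyTorch releases.
for schema in torch._C._jit_get_schemas_for_operator(
        "aten::_native_multi_head_attention"):
    print(schema)
# The printed schema declares `-> (Tensor, Tensor)`, even though the second
# element of the returned tuple is None at runtime when need_weights=False.
```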
My question is as follows: my understanding is that, with the way shape functions are currently handled, torch-mlir cannot reconcile the `None` in the returned tuple. Should this be handled on the torch-mlir side, or does it need to be pushed back to PyTorch so that the op and/or its signature can be updated? I am relatively new to torch-mlir, so please let me know if this is the wrong channel for a question like this.
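For completeness, here is a rough sketch of what I currently think the decomposition has to compute. It covers only the dense-tensor path, and the packed qkv weight layout, mask semantics, and `average_attn_weights` handling are my assumptions; it has not been validated against the native kernel:

```python
import torch
import torch.nn.functional as F

def mha_reference(q, k, v, embed_dim, num_heads,
                  qkv_weight, qkv_bias, proj_weight, proj_bias,
                  mask=None, need_weights=True, average_attn_weights=True):
    # Dense-tensor sketch only; the native kernel also handles nested tensors.
    bs, q_len, _ = q.shape
    head_dim = embed_dim // num_heads
    # Assumption: the packed qkv weight/bias are laid out as [q; k; v] rows.
    w_q, w_k, w_v = qkv_weight.chunk(3, dim=0)
    b_q, b_k, b_v = qkv_bias.chunk(3, dim=0)
    q = F.linear(q, w_q, b_q).view(bs, q_len, num_heads, head_dim).transpose(1, 2)
    k = F.linear(k, w_k, b_k).view(bs, -1, num_heads, head_dim).transpose(1, 2)
    v = F.linear(v, w_v, b_v).view(bs, -1, num_heads, head_dim).transpose(1, 2)
    # Scaled dot-product attention per head.
    attn = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)
    if mask is not None:
        attn = attn.masked_fill(mask, float("-inf"))  # mask semantics are a guess
    attn = attn.softmax(dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(bs, q_len, embed_dim)
    out = F.linear(out, proj_weight, proj_bias)
    if not need_weights:
        # This is the case where the native op returns None as the second value.
        return out, None
    if average_attn_weights:
        attn = attn.mean(dim=1)  # average the attention weights over heads
    return out, attn
```

With the repro above, `mha_reference(q, k, v, embed_dim, num_heads, qkv.weight, qkv.bias, proj.weight, proj.bias, need_weights=False)` likewise produces an `(output, None)` pair, which is exactly what the current `(Tensor, Tensor)` signature cannot express.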
I commented on the upstream issue https://github.com/pytorch/pytorch/issues/80738 - I hope they can change the return type to `Tensor?`. In the meantime, you can hack around it by adding some special-case logic here: https://github.com/llvm/torch-mlir/blob/874fdb7e429175b701602e08df027f756bdf6ba9/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py#L65
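To make the hoped-for upstream change concrete, it would amount to declaring the second return as optional in native_functions.yaml, i.e. something along these lines (illustrative only, not an actual PyTorch patch):

```yaml
# Hypothetical entry with an optional second return:
- func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True) -> (Tensor, Tensor?)
```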