
[RFC] Shape handling for `aten::_native_multi_head_attention`


Hi,

I have been trying to add a decomposition for the `aten::_native_multi_head_attention` op. The issue I have run into is that for certain inputs (specifically when `need_weights=False`) the op returns a `None` value, contradicting the op signature we take from PyTorch. I created an issue upstream but was told that there are no guarantees for ops with a leading underscore. A summary of the issue is included below.

import torch

embed_dim = 8
num_heads = 4
bs = 4
sl = 2
qkv = torch.nn.Linear(embed_dim, embed_dim * 3, dtype=torch.float32)
proj = torch.nn.Linear(embed_dim, embed_dim, dtype=torch.float32)
q = torch.randn(bs, sl, embed_dim) * 10
k = torch.randn(bs, sl, embed_dim) * 10
v = torch.randn(bs, sl, embed_dim) * 10

mha = torch.ops.aten._native_multi_head_attention(
    q,
    k,
    v,
    embed_dim,
    num_heads,
    qkv.weight,
    qkv.bias,
    proj.weight,
    proj.bias,
    need_weights=False,  # triggers the None second return value
    average_attn_weights=False,
)

print(mha)

with example output:

(tensor([[[ 1.6427,  2.0966,  2.4298,  1.6536,  2.9116, -0.6659,  0.0086,
           4.0757],
         [ 2.0386,  0.8152, -0.8718,  1.7295,  0.9999, -1.8865, -2.7697,
           1.9216]],

        [[ 4.0717,  0.0476, -0.6383,  3.1022, -2.5480,  2.0922, -4.1062,
          -0.5034],
         [ 2.3662,  0.3523, -1.0895,  1.9332,  0.3525,  0.4775, -2.1356,
           0.4972]],

        [[-5.0851,  3.8904,  2.9651, -3.1131,  6.5247, -2.5286, -1.4031,
           1.0763],
         [-2.5247,  1.5687, -1.5536,  1.0382,  4.8081, -2.2505,  1.6698,
           2.1023]],

        [[-1.7481,  1.0500,  2.4167, -1.5026,  5.5205, -3.3177,  3.3927,
           4.1006],
         [-3.4155,  2.5501,  4.6239, -8.3866,  4.6514, -2.5655,  5.8211,
           2.1764]]], grad_fn=<NotImplemented>), None)
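For contrast, calling the same op with need_weights=True should return the attention weights as a real tensor in the second slot, matching the declared signature (a quick check, reusing the setup above):

# Reuses q, k, v, qkv, proj, embed_dim, num_heads from the snippet above.
out, weights = torch.ops.aten._native_multi_head_attention(
    q, k, v, embed_dim, num_heads,
    qkv.weight, qkv.bias, proj.weight, proj.bias,
    need_weights=True,
)
print(type(weights))  # expected: <class 'torch.Tensor'>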

The signature in native_functions.yaml, which is where we get the function signature from (correct me if I'm wrong), indicates a (Tensor, Tensor) output, but None shows up as the second value:

 - func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True) -> (Tensor, Tensor)
   variants: function
   dispatch:
      CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention

My question is as follows: my understanding is that, with the way shape functions are currently handled, torch-mlir cannot reconcile the None in the returned tuple. Is this something that should be handled by torch-mlir, or does it need to be pushed back to PyTorch to get the op and/or signature updated? I am relatively new to torch-mlir, so please let me know if this is the wrong channel for something like this.
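To make the mismatch concrete, here is a rough sketch of what a shape function for this op would have to look like under the declared signature (the function name and the attention-weights shape below are illustrative, not actual torch-mlir shape library code):

from typing import List, Tuple

# Hypothetical shape function mirroring the declared (Tensor, Tensor)
# signature. Every declared Tensor result must be given a shape (a
# List[int]), so there is no way to express "the second result is None
# when need_weights=False" without a Tensor? return type.
def _native_multi_head_attention_shapes(
    query: List[int], key: List[int], value: List[int],
    embed_dim: int, num_head: int,
) -> Tuple[List[int], List[int]]:
    # First result: the attention output, same shape as the query.
    output = list(query)
    # Second result: attention weights, (bs, num_heads, tgt_len, src_len)
    # with average_attn_weights=False -- but the kernel may return None
    # here instead, which cannot be typed as List[int].
    attn_weights = [query[0], num_head, query[1], key[1]]
    return output, attn_weights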

qedawkins · Jul 05 '22 19:07

I commented on the upstream issue https://github.com/pytorch/pytorch/issues/80738; I hope they can change the return type to `Tensor?`. You can temporarily hack around it by putting some special-case logic here: https://github.com/llvm/torch-mlir/blob/874fdb7e429175b701602e08df027f756bdf6ba9/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py#L65
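A minimal sketch of what such a special case might look like (the operator-record attributes below are assumptions, not the actual registry.py data model; adapt it to however the registry represents return types):

# Hypothetical fixup, sketched against an assumed operator record: rewrite
# the recorded type of the second return value from "Tensor" to "Tensor?"
# so the importer accepts the None produced when need_weights=False.
if op.name == "aten::_native_multi_head_attention":
    op.returns[1]["type"] = "Tensor?"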

silvasean · Jul 06 '22 23:07