
fx_importer NotImplementedError: MultiheadAttention layer with need_weights=False

Open: alaa-ali opened this issue 8 months ago • 0 comments

This issue describes a bug in torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py. We found it while importing an exported model into MLIR. It occurs for an exported MultiheadAttention layer called with need_weights=False, which means the layer does not return its attention weights. In that case the second output, attn_output_weights, is None.

The following error is raised:

Python Error: NotImplementedError: OutputKind.USER_OUTPUT for <class 'torch.export.graph_signature.ConstantArgument'>: ConstantArgument(name='', value=None)

[Additionally, I couldn't visualize the exported .pt2 model with a tool like https://netron.app/. However, I can import and visualize the exported model when need_weights=True, i.e. when attn_output_weights is not None.]

doc: https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
- parameters: need_weights (bool): if specified, returns attn_output_weights.
- outputs: attn_output_weights: only returned when need_weights=True.
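A quick eager-mode check of the documented behavior, as a minimal sketch: with need_weights=False the layer still returns a 2-tuple, but its second element is None. The embedding size, head count, and input shape below are arbitrary illustration values, not taken from the model in this issue.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.rand(1, 5, 8)  # (batch, seq_len, embed_dim)

# need_weights=False: the second element of the returned tuple is None
out, weights = mha(x, x, x, need_weights=False)

# need_weights=True (the default): weights averaged over heads, shape (batch, L, S)
out2, weights2 = mha(x, x, x, need_weights=True)

print(out.shape, weights)  # torch.Size([1, 5, 8]) None
print(weights2.shape)      # torch.Size([1, 5, 5])
```

This None is what torch.export later records as a constant user output.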

Source code to reproduce the exported model with attn_output_weights = None:

import torch
import torch.nn as nn
import torch.nn.functional as F
class CustomModel(nn.Module):
    def __init__(self, kwargs):
        super().__init__()
        self.kwargs = kwargs
        self.attn = nn.MultiheadAttention(
            embed_dim=kwargs['embedding_dim'],
            num_heads=kwargs['num_heads'],
            dropout=kwargs['dropout'],
            add_bias_kv=kwargs['add_bias_kv'],
            add_zero_attn=kwargs['add_zero_attn'],
            kdim=kwargs['kdim'],
            vdim=kwargs['vdim'],
            batch_first=kwargs['batch_first'],
        )

    def forward(self, *args):
        query, key, value, attn_mask, kp_mask = args
        # Returns (attn_output, attn_output_weights); the weights are None when need_weights=False
        return self.attn(
            query, key, value,
            attn_mask=attn_mask,
            key_padding_mask=kp_mask,
            need_weights=self.kwargs['need_weights'],
            average_attn_weights=self.kwargs['average_attn_weights'],
            is_causal=self.kwargs['is_causal'],
        )

# Create model instance
model = CustomModel(kwargs = {
    'embedding_dim': 64,
    'num_heads': 1,
    'dropout': 0.1,
    'add_bias_kv': True,
    'add_zero_attn': False,
    'kdim': 16,
    'vdim': None,  # None instead of the string "missing"; vdim then defaults to embed_dim
    'batch_first': True,
    'need_weights': False,
    'average_attn_weights': True,
    'is_causal': False
})

# Dummy input tensors
query = torch.rand(1, 50, 64)          # (batch, target_seq_len, embed_dim)
key = torch.rand(1, 10, 16)            # (batch, source_seq_len, kdim)
value = torch.rand(1, 10, 64)          # (batch, source_seq_len, vdim = embed_dim)
attn_mask = torch.zeros(50, 10)        # (target_seq_len, source_seq_len)
key_padding_mask = torch.zeros(1, 10)  # (batch, source_seq_len)

# Export the model
exported_model = torch.export.export(
    model, args=(query, key, value, attn_mask, key_padding_mask))

# Print the ExportedProgram (FX graph plus graph signature); see exported_model.graph_signature for the output specs
print(exported_model)

The error is caused by a missing case at lines 661 and 662 of torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py: torch.export.graph_signature.ConstantArgument is not handled there. [screenshot of those lines]

Before proposing code changes to fix this, we wanted to check the expected behavior: is this OutputSpec intentionally handled this way in the source code, or is it a bug that needs to be fixed?

This is a snippet of the printed exported program (note the None in the return statement and the ConstantArgument in output_specs):

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_attn_q_proj_weight: "f32[64, 64]", p_attn_k_proj_weight: "f32[64, 16]", p_attn_v_proj_weight: "f32[64, 64]", p_attn_in_proj_bias: "f32[192]", p_attn_bias_k: "f32[1, 1, 64]", p_attn_bias_v: "f32[1, 1, 64]", p_attn_out_proj_weight: "f32[64, 64]", p_attn_out_proj_bias: "f32[64]", args_0: "f32[1, 50, 64]", args_1: "f32[1, 10, 16]", args_2: "f32[1, 10, 64]", args_3: "f32[50, 10]", args_4: "f32[1, 10]"):
             # 
            transpose: "f32[50, 1, 64]" = torch.ops.aten.transpose.int(args_0, 1, 0);  args_0 = None
            ....
            view_8: "f32[50, 1, 64]" = torch.ops.aten.view.default(linear_3, [50, 1, 64]);  linear_3 = None
            transpose_6: "f32[1, 50, 64]" = torch.ops.aten.transpose.int(view_8, 1, 0);  view_8 = None
            return (transpose_6, None)
            
Graph signature: ExportGraphSignature(
input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_attn_q_proj_weight'), target='attn.q_proj_weight', persistent=None), ...], 
output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='transpose_6'), target=None),
              OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=ConstantArgument(name='', value=None), target=None)])

We noticed that the arg field of OutputSpec can hold several argument types (listed at https://pytorch.org/docs/stable/export.html#torch.export.graph_signature.OutputSpec), while the importer handles only two of them: TensorArgument and SymIntArgument. [screenshot of the OutputSpec documentation]
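To make the question concrete, here is a minimal, self-contained sketch of one possible handling: pass a ConstantArgument whose value is None through as a None result instead of raising. It does not use torch-mlir's actual code; the dataclasses are hypothetical stand-ins that only mirror the fields visible in the printed graph signature, and `lower_user_outputs` plus the `"%transpose_6"` placeholder for a lowered MLIR value are illustrative names, not torch-mlir APIs.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, List, Union


# Hypothetical stand-ins for torch.export.graph_signature types (illustration only)
@dataclass
class TensorArgument:
    name: str


@dataclass
class SymIntArgument:
    name: str


@dataclass
class ConstantArgument:
    name: str
    value: Any


class OutputKind(Enum):
    USER_OUTPUT = auto()


@dataclass
class OutputSpec:
    kind: OutputKind
    arg: Union[TensorArgument, SymIntArgument, ConstantArgument]


def lower_user_outputs(specs: List[OutputSpec], env: Dict[str, Any]) -> List[Any]:
    """Hypothetical helper: map user output specs to lowered values via env."""
    results = []
    for spec in specs:
        if spec.kind is not OutputKind.USER_OUTPUT:
            continue
        arg = spec.arg
        if isinstance(arg, (TensorArgument, SymIntArgument)):
            results.append(env[arg.name])
        elif isinstance(arg, ConstantArgument) and arg.value is None:
            # e.g. attn_output_weights when need_weights=False
            results.append(None)
        else:
            # Other constants stay unsupported, matching the current error style
            raise NotImplementedError(f"{spec.kind} for {type(arg)}: {arg}")
    return results


specs = [
    OutputSpec(OutputKind.USER_OUTPUT, TensorArgument("transpose_6")),
    OutputSpec(OutputKind.USER_OUTPUT, ConstantArgument("", None)),
]
print(lower_user_outputs(specs, {"transpose_6": "%transpose_6"}))
# ['%transpose_6', None]
```

How a None result should then be materialized in MLIR (for example as a none-typed constant) is exactly the expected-behavior question raised above, so the sketch leaves it open.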

alaa-ali, Apr 24 '25 06:04