onnx-simplifier icon indicating copy to clipboard operation
onnx-simplifier copied to clipboard

[BUG] decoder 结构中的q,k,v 三个matmul会自动concatenate成一个没有名字的matmul算子

Open 1826133674 opened this issue 5 months ago • 0 comments

Describe the bug 在使用onnxsim优化qwen1.5-1.8B的onnx模型的时候,我们遇到了一个问题,即decoder结构中的q,k,v三个matmul算子会被concatenate成一个算子,然后再使用split分开成三个。此时得到的新的融合版的matmul和split算子都没有名字.

Model

依赖库版本 onnxsim-0.4.36 torch 2.3.1

稳定复现代码


import torch
import torch.nn as nn

class MultiMatMulAddModel(nn.Module):
    def __init__(self):
        super(MultiMatMulAddModel, self).__init__()

        # 初始化三个2048x2048的权重矩阵和三个2048维的偏置向量
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(2048, 2048)) for _ in range(3)])
        self.biases = nn.ParameterList([nn.Parameter(torch.randn(2048)) for _ in range(3)])

    def forward(self, x):
        outputs = []
        for i in range(3):
            # 对输入进行矩阵乘法
            matmul_result = torch.matmul(x.squeeze(0), self.weights[i])

            # 执行加法操作
            add_result = matmul_result + self.biases[i]

            # 将结果添加到输出列表中
            outputs.append(add_result)

        # 返回所有三个输出
        return tuple(outputs)

# 创建模型实例
model = MultiMatMulAddModel()
import torch.onnx

# 创建输入张量
input_tensor = torch.randn(1, 1024, 2048)

# 设置输出文件名
output_file = "multi_matmul_add_model.onnx"

# 导出模型到ONNX格式
torch.onnx.export(model,
                  input_tensor,
                  output_file,
                  export_params=True,        # 存储训练好的参数权重
                  opset_version=11,         # ONNX版本
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names=['input'],    # 输入节点名称
                  output_names=['output1', 'output2', 'output3'],  # 输出节点名称
                  dynamic_axes={'input': {0: 'batch_size'},  # 可变轴信息
                                'output1': {0: 'batch_size'},
                                'output2': {0: 'batch_size'},
                                'output3': {0: 'batch_size'}})

print(f"Model has been exported to {output_file}")

1826133674 avatar Aug 28 '24 07:08 1826133674