onnx-simplifier
onnx-simplifier copied to clipboard
[BUG] decoder 结构中的q,k,v 三个matmul会自动concatenate成一个没有名字的matmul算子
Describe the bug 在使用onnxsim优化qwen1.5-1.8B的onnx模型的时候,我们遇到了一个问题,即decoder结构中的q,k,v三个matmul算子会被concatenate成一个算子,然后再使用split分开成三个。此时得到的新的融合版的matmul和split算子都没有名字.
Model
依赖库版本 onnxsim-0.4.36 torch 2.3.1
稳定复现代码
import torch
import torch.nn as nn
class MultiMatMulAddModel(nn.Module):
def __init__(self):
super(MultiMatMulAddModel, self).__init__()
# 初始化三个2048x2048的权重矩阵和三个2048维的偏置向量
self.weights = nn.ParameterList([nn.Parameter(torch.randn(2048, 2048)) for _ in range(3)])
self.biases = nn.ParameterList([nn.Parameter(torch.randn(2048)) for _ in range(3)])
def forward(self, x):
outputs = []
for i in range(3):
# 对输入进行矩阵乘法
matmul_result = torch.matmul(x.squeeze(0), self.weights[i])
# 执行加法操作
add_result = matmul_result + self.biases[i]
# 将结果添加到输出列表中
outputs.append(add_result)
# 返回所有三个输出
return tuple(outputs)
# 创建模型实例
model = MultiMatMulAddModel()
import torch.onnx
# 创建输入张量
input_tensor = torch.randn(1, 1024, 2048)
# 设置输出文件名
output_file = "multi_matmul_add_model.onnx"
# 导出模型到ONNX格式
torch.onnx.export(model,
input_tensor,
output_file,
export_params=True, # 存储训练好的参数权重
opset_version=11, # ONNX版本
do_constant_folding=True, # 是否执行常量折叠优化
input_names=['input'], # 输入节点名称
output_names=['output1', 'output2', 'output3'], # 输出节点名称
dynamic_axes={'input': {0: 'batch_size'}, # 可变轴信息
'output1': {0: 'batch_size'},
'output2': {0: 'batch_size'},
'output3': {0: 'batch_size'}})
print(f"Model has been exported to {output_file}")