gtcrn
gtcrn copied to clipboard
导出onnx的stream模型时可以优化一点点的两个方法
- SFE模块的unfold可以用如下模块代替,可以减少很多算子
import torch
import torch.nn as nn
class Unfold(nn.Module):
def __init__(self):
super().__init__()
kernel = torch.eye(3)
kernel = kernel.view(3, 1, 1, 3)
kernel = nn.Parameter(kernel.repeat(8, 1, 1, 1))
self.conv = nn.Conv2d(8, 24, (1, 3), padding=(0, 1), groups=8, bias=False)
self.conv.weight = kernel
def forward(self, x):
out = self.conv(x)
return out
- onnxsim没办法把ConvTranspose和BN融合在一起,但是pnnx可以,可以节省算力,用如下方法导出若干文件
import pnnx
mod = torch.jit.trace(model_stream, [一堆变量])
mod.save("gtcrn.pt")
opt_net = pnnx.convert("gtcrn.pt", [一堆变量])
然后会在当前文件夹生成一个gtcrn_pnnx.py的文件,里面有一个export_onnx()的函数,可以按喜好修改输出形式,最后当然也可以用onnxsim再跑一次
export_onnx()
import onnx
from onnxsim import simplify
onnx_model = onnx.load('gtcrn.onnx')
onnx.checker.check_model(onnx_model)
model_simp, check = simplify(onnx_model)
onnx.save(model_simp, 'gtcrn_sim.onnx')