onnx2tflite
onnx2tflite copied to clipboard
pow出现维度没有转置的问题
问题可通过如下代码复现:
from onnx2tflite import onnx_converter
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.alpha = nn.Parameter(2*torch.ones((1, 1, 1, 8)), requires_grad=True)
self.conv = nn.Conv2d(2, 2, (1, 1))
def forward(self, x):
y = self.conv(x)
return y.pow(self.alpha)
model = Model()
x = torch.randn((1, 2, 5, 8))**2
model.eval()
torch.onnx.export(model,
(x),
'./model.onnx',
input_names=['x'],
output_names=['y'],
opset_version=11,
verbose=False)
res = onnx_converter(onnx_model_path = "./model.onnx",
need_simplify = True,
output_path = "./",
target_formats = ['tflite'])
会有如下错误:
ValueError: Exception encountered when calling layer "tf.math.pow_1" (type TFOpLambda).
Dimensions must be equal, but are 2 and 8 for '{{node tf.math.pow_1/Pow}} = Pow[T=DT_FLOAT](Placeholder, tf.math.pow_1/Pow/y)' with input shapes: [1,5,8,2], [1,1,1,8].
Call arguments received by layer "tf.math.pow_1" (type TFOpLambda):
• x=tf.Tensor(shape=(1, 5, 8, 2), dtype=float32)
• y=array([[[[2., 2., 2., 2., 2., 2., 2., 2.]]]], dtype=float32)
• name=None
特别的是,self.alpha如果为nn.Parameter(torch.ones((1, 1, 1, 8)), requires_grad=True),就不会报错,其他数比如0.5、2都会报错
这倒是一个不常用的奇葩姿势哈。 不过应该也比较好解决,你可以尝试解决一下,提交一个PR~