onnx2tflite

pow: dimensions are not transposed

Open · SherryYu33 opened this issue 11 months ago · 1 comment

The problem can be reproduced with the following code:

from onnx2tflite import onnx_converter
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Exponent of shape (1, 1, 1, 8): in PyTorch's NCHW layout it broadcasts along the last (W) axis
        self.alpha = nn.Parameter(2*torch.ones((1, 1, 1, 8)), requires_grad=True)
        self.conv = nn.Conv2d(2, 2, (1, 1))

    def forward(self, x):
        y = self.conv(x)
        return y.pow(self.alpha)

model = Model()
# NCHW input: batch 1, 2 channels, height 5, width 8 (squared so the input is non-negative)
x = torch.randn((1, 2, 5, 8))**2
model.eval()
torch.onnx.export(model,
                  (x),
                  './model.onnx',
                  input_names=['x'],
                  output_names=['y'],
                  opset_version=11,
                  verbose=False)

# Convert the exported ONNX model to TFLite
res = onnx_converter(onnx_model_path = "./model.onnx",
                     need_simplify = True,
                     output_path = "./",
                     target_formats = ['tflite'])

This raises the following error:

ValueError: Exception encountered when calling layer "tf.math.pow_1" (type TFOpLambda).

Dimensions must be equal, but are 2 and 8 for '{{node tf.math.pow_1/Pow}} = Pow[T=DT_FLOAT](Placeholder, tf.math.pow_1/Pow/y)' with input shapes: [1,5,8,2], [1,1,1,8].

Call arguments received by layer "tf.math.pow_1" (type TFOpLambda):
  • x=tf.Tensor(shape=(1, 5, 8, 2), dtype=float32)
  • y=array([[[[2., 2., 2., 2., 2., 2., 2., 2.]]]], dtype=float32)
  • name=None
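The shapes in the error message suggest what goes wrong (this is a reading of the log, not a confirmed diagnosis): the activation has been transposed from NCHW `(1, 2, 5, 8)` to NHWC `(1, 5, 8, 2)`, but the constant exponent `(1, 1, 1, 8)` is still in NCHW order. A minimal NumPy sketch reproducing the same broadcast failure, assuming the converter permutes activations with the order `(0, 2, 3, 1)`:

```python
import numpy as np

rng = np.random.default_rng(0)
x_nchw = rng.random((1, 2, 5, 8), dtype=np.float32)   # PyTorch/ONNX layout (N, C, H, W)
alpha = np.full((1, 1, 1, 8), 2.0, dtype=np.float32)   # exponent, NCHW-ordered

# In ONNX (NCHW) the exponent broadcasts along the last axis, W = 8.
y_ref = np.power(x_nchw, alpha)
print(y_ref.shape)  # (1, 2, 5, 8)

# Assumed converter behaviour: the activation is transposed to NHWC...
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))  # (1, 5, 8, 2)

# ...but the constant exponent is not, so its last axis (8) now lines up
# with C = 2 instead of W = 8 and broadcasting fails.
try:
    np.power(x_nhwc, alpha)
    failed = False
except ValueError as err:
    failed = True
    print("broadcast error:", err)
```

NumPy reports the same `2` vs `8` mismatch that TensorFlow reports for `tf.math.pow`.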

Notably, if self.alpha is nn.Parameter(torch.ones((1, 1, 1, 8)), requires_grad=True), no error is raised; with other values, such as 0.5 or 2, the error occurs.

SherryYu33 · Dec 11 '24 02:12

This is a rather unusual usage, ha. It should be fairly easy to fix, though. Feel free to try fixing it yourself and submit a PR~
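One possible direction for such a fix, sketched in NumPy rather than against onnx2tflite's actual code: before applying the element-wise op, reorder a constant operand from NCHW to NHWC the same way the activation is reordered. `nchw_const_to_nhwc` is a hypothetical helper, not part of onnx2tflite's API. Because ONNX broadcasting aligns shapes from the right, the constant is first left-padded with size-1 axes up to rank 4, then permuted with `(0, 2, 3, 1)`. A per-W exponent is used so that broadcasting along the wrong axis would produce a different result.

```python
import numpy as np

def nchw_const_to_nhwc(const, rank=4):
    """Hypothetical helper: reorder a constant operand from NCHW to NHWC.

    Left-pads the shape with size-1 axes (ONNX aligns broadcast shapes from
    the right), then applies the same (0, 2, 3, 1) permutation as the activation.
    """
    const = np.asarray(const)
    padded = const.reshape((1,) * (rank - const.ndim) + const.shape)
    return np.transpose(padded, (0, 2, 3, 1))

rng = np.random.default_rng(0)
x_nchw = rng.random((1, 2, 5, 8), dtype=np.float32) + 0.1       # positive bases
alpha = np.linspace(0.5, 2.0, 8, dtype=np.float32).reshape(1, 1, 1, 8)

y_ref = np.power(x_nchw, alpha)                  # ONNX semantics, NCHW
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))      # (1, 5, 8, 2)
alpha_nhwc = nchw_const_to_nhwc(alpha)           # (1, 1, 8, 1)
y_nhwc = np.power(x_nhwc, alpha_nhwc)            # now broadcasts along W

# Transposing the NHWC result back reproduces the NCHW reference.
print(np.allclose(np.transpose(y_nhwc, (0, 3, 1, 2)), y_ref))  # True
```

The same reordering would presumably apply to other broadcasting element-wise ops that take a constant 4-D operand, not only `Pow`.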

MPolaris · Dec 12 '24 09:12