MNN icon indicating copy to clipboard operation
MNN copied to clipboard

[feature]: 希望社区能够增加 tf.switch_case的支持

Open stricklandye opened this issue 11 months ago • 5 comments

模型用到了 tf.switch_case,使用MNN转换模型时会提示:

 These Op Not Support: Tensorflow::Case 

虽然可以使用 tf.case 或者 if else 平替,但是当分支比较多的时候性能并不好,因此希望社区能够增加对这个算子的支持。或者,有实现算子的更加充实的文档吗,我也可以尝试自定义算子? 官方文档描述的比较粗略,无法上手。

Version

TensorFlow: 1.15

示例模型代码:

def multiply(inputs, idx):
    x = tf.switch_case(idx, branch_fns={
        0: lambda: tf.constant(10, dtype=tf.float32),
        1: lambda: tf.constant(100, dtype=tf.float32)
    }, default=lambda: tf.constant(1000, dtype=tf.float32))
    x = tf.multiply(inputs, x)
    return x


class CustomModel(tf.keras.Model):
    def __init__(self, idx):
        super(CustomModel, self).__init__()
        self.idx = idx

    def call(self, inputs, training=None, mask=None):
        output = multiply(inputs, self.idx)
        return output

以及用这个模型导出的 frozen model。 switch_case.zip

stricklandye avatar Mar 27 '24 11:03 stricklandye

Case算子的参数包含function array,当前MNN实现中并无此类数据结构的先例。

如果需要新增Case算子的自定义实现,可行的实现思路是什么?

juju812 avatar Mar 28 '24 12:03 juju812

这个模型貌似导出有问题,没看到 switch / case

jxt1234 avatar Apr 02 '24 13:04 jxt1234

@jxt1234 感谢大佬回复,明天上班的时候我再看下咋回事。

stricklandye avatar Apr 02 '24 13:04 stricklandye

@jxt1234 不好意思,先前的模型上传不对,已经重新上传。netron可视化以后是这样: image

stricklandye avatar Apr 03 '24 03:04 stricklandye

之前的测试代码看着不像是一个正常的网络。下面这个看起来正常一些:

import tensorflow as tf

class CustomModel(tf.keras.Model):
    def __init__(self, idx):
        super(CustomModel, self).__init__()
        self.fc = tf.keras.layers.Dense(10, activation='relu')
        self.idx = idx

    def call(self, inputs, training=None, mask=None):
        x = self.fc(inputs)
        x = self._fc(x)
        x = tf.nn.relu(x, name="ac_relu")
        return x

    def _fc(self, inputs):
        r = tf.switch_case(self.idx, branch_fns={
            0: lambda: tf.layers.dense(inputs, 2, name="case_0"),
            1: lambda: tf.layers.dense(inputs, 3, name="case_1"),
            2: lambda: tf.layers.dense(inputs, 4, name="case_2"),
        }, default=lambda: tf.layers.dense(inputs, 5, name="default_case"))
        return r

if __name__ == '__main__':
    with tf.Session() as sess:
        model = CustomModel(tf.constant(1))
        example_input = tf.random.uniform(shape=[1, 28 * 28])
        infer_res = model(example_input)
        sess.run(tf.global_variables_initializer())
        final_res = sess.run(infer_res)
        # 保存为 frozen model
        nodes = [node.name for node in tf.get_default_graph().get_operations()]
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node_names=["custom_model/ac_relu"]
        )
        tf.train.write_graph(output_graph_def, "./", "switch_case_network.pb", as_text=False)

以及它所对应的pb: switch_case_network.zip

stricklandye avatar Apr 03 '24 04:04 stricklandye

Marking as stale. No activity in 60 days.

github-actions[bot] avatar Jun 02 '24 09:06 github-actions[bot]