MNN
MNN copied to clipboard
[feature]: 希望社区能够增加 tf.switch_case的支持
模型用到了 tf.switch_case
,使用MNN转换模型时会提示:
These Op Not Support: Tensorflow::Case
虽然可以使用 tf.case
或者 if else 平替,但是当分支比较多的时候性能并不好,因此希望社区能够增加对这个算子的支持。或者,有实现算子的更加充实的文档吗,我也可以尝试自定义算子? 官方文档描述的比较粗略,无法上手。
Version
TensorFlow: 1.15
示例模型代码:
def multiply(inputs, idx):
x = tf.switch_case(idx, branch_fns={
0: lambda: tf.constant(10, dtype=tf.float32),
1: lambda: tf.constant(100, dtype=tf.float32)
}, default=lambda: tf.constant(1000, dtype=tf.float32))
x = tf.multiply(inputs, x)
return x
class CustomModel(tf.keras.Model):
def __init__(self, idx):
super(CustomModel, self).__init__()
self.idx = idx
def call(self, inputs, training=None, mask=None):
output = multiply(inputs, self.idx)
return output
以及用这个模型导出的 frozen model。 switch_case.zip
Case
算子的参数包含function array,当前MNN实现中并无此类数据结构的先例。
如果需要新增Case
算子的自定义实现,可行的实现思路是什么?
这个模型貌似导出有问题,没看到 switch / case
@jxt1234 感谢大佬回复,明天上班的时候我再看下咋回事。
@jxt1234 不好意思,先前的模型上传不对,已经重新上传。netron可视化以后是这样:
之前的测试代码看着不像是一个正常的网络。下面这个看起来正常一些:
import tensorflow as tf
class CustomModel(tf.keras.Model):
def __init__(self, idx):
super(CustomModel, self).__init__()
self.fc = tf.keras.layers.Dense(10, activation='relu')
self.idx = idx
def call(self, inputs, training=None, mask=None):
x = self.fc(inputs)
x = self._fc(x)
x = tf.nn.relu(x, name="ac_relu")
return x
def _fc(self, inputs):
r = tf.switch_case(self.idx, branch_fns={
0: lambda: tf.layers.dense(inputs, 2, name="case_0"),
1: lambda: tf.layers.dense(inputs, 3, name="case_1"),
2: lambda: tf.layers.dense(inputs, 4, name="case_2"),
}, default=lambda: tf.layers.dense(inputs, 5, name="default_case"))
return r
if __name__ == '__main__':
with tf.Session() as sess:
model = CustomModel(tf.constant(1))
example_input = tf.random.uniform(shape=[1, 28 * 28])
infer_res = model(example_input)
sess.run(tf.global_variables_initializer())
final_res = sess.run(infer_res)
# 保存为 frozen model
nodes = [node.name for node in tf.get_default_graph().get_operations()]
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, sess.graph_def, output_node_names=["custom_model/ac_relu"]
)
tf.train.write_graph(output_graph_def, "./", "switch_case_network.pb", as_text=False)
以及它所对应的pb: switch_case_network.zip
Marking as stale. No activity in 60 days.