MNN
MNN copied to clipboard
对于控制流算子的支持
Hi 我刚接触MNN,我有一些疑问还麻烦社区帮忙解答,如果有理解不到位的地方还麻烦各位指出。
-
为什么 MNN 不支持
tf.switch_case
的算子? 我尝试将下面的示例代码用MNN来进行推理,这只是一个演示tf.switch_case
的样例:condition = tf.compat.v1.placeholder(dtype=tf.int32, name="input") def multiply(): return tf.compat.v1.multiply(condition, 100) def add(): return tf.compat.v1.add(condition, 10) res = tf.compat.v1.switch_case(condition, branch_fns={ 0: multiply, 1: add, },default= None)
在实践中,我们发现
tf.switch_case
效果比tf.case
性能更好,遗憾的MNN提示:[17:32:45]/MNN/tools/converter/source/common/writeFb.cpp:105: These Op Not Support: Tensorflow::Case
-
MNN 如何判断模型是否包含子图? 对于包含
tf.cond
或者tf.case
的模型,MNNConverter转换时会提示:The model has subgraphs, please use MNN::Express::Module to run it
这是否意味着只用使用了这两个算子MNN就认为包含子图? 那子图的划分是怎么样的呢? 对于下面的
tf.case
样例:pred_fn_pairs = [ (tf.equal(input_value, 2), lambda: tf.compat.v1.add(input_value, 10)), (tf.equal(input_value, 1), lambda: tf.compat.v1.add(input_value, 100)), (tf.equal(input_value, 5), lambda: tf.compat.v1.add(input_value, 20)), ] default_fn = lambda: tf.compat.v1.add(input_value, 30) output = tf.compat.v1.case(pred_fn_pairs, default_fn, exclusive=True)
将
tools/converter/source/common/writeFb.cpp
里面的MNN_DUMP_SUBGRAPH
解注,会出现8个subgrah,为什么是8个? 为什么这个宏现在没用了? -
能否以手动加载子图的形式实现
tf.switch_case
的等价功能? 这个是我的拍脑袋想法,我不清楚这样的伪代码是否能够实现:switch (value) { case 1: load_subgraph_1(); break; case 2: load_subgraph_2(); break; }
- switch_case 一般可以用 if 替代,你可以上传一个简单模型,我们排期支持下
- 存在控制流算子 (if / while) 时,都会产生子图
- 这个你可以把各分支分别导出 pb 并转 mnn 模型来实现。不过建议还是换成 if ,加载1个mnn比较方便
- switch_case 一般可以用 if 替代,你可以上传一个简单模型,我们排期支持下
- 存在控制流算子 (if / while) 时,都会产生子图
- 这个你可以把各分支分别导出 pb 并转 mnn 模型来实现。不过建议还是换成 if ,加载1个mnn比较方便
if else在推理的时候不会很慢吗,如果是一定批量输入的话
批量输入建议不要用控制流的方式实现,可以用 select 类似的算子替代
批量输入建议不要用控制流的方式实现,可以用 select 类似的算子替代
那如果是MoE这类网络呢,有什么比较好的方式,select 类似的算子在MoE场景中似乎也会回到tf.case的问题上
@jxt1234 大佬,我们这边仍然有switch_case的需求,如果可以的话可以在后续版本加上。感谢!
就以前面提到的测试代码为例:
tensorflow 版本: 1.1.5
import tensorflow as tf
condition = tf.placeholder(dtype=tf.int32, name="input")
def multiply():
return tf.multiply(condition, 100
def add():
return tf.add(condition, 10)
res = tf.switch_case(condition, branch_fns={
0: multiply,
1: add,
}, default=None)
with tf.Session() as sess:
output = sess.run(res,feed_dict={condition:0})
print(output)
nodes = [node.name for node in tf.get_default_graph().get_operations()]
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, sess.graph_def, output_node_names=nodes
)
tf.train.write_graph(output_graph_def, "./", "switch_case.pb", as_text=False)
使用MNNConvert转换模型提示:
The device support i8sdot:0, support fp16:0, support i8mm: 0
Start to Convert Other Model Format To MNN Model..., target version: 2.8
Start to Optimize the MNN Net...
[10:59:47] :105: These Op Not Support: Tensorflow::Case
Converted Failed!
以及用上面代码保存的测试模型的zip: switch_case.zip