MNN icon indicating copy to clipboard operation
MNN copied to clipboard

对于控制流算子的支持

Open stricklandye opened this issue 1 year ago • 4 comments

Hi 我刚接触MNN,我有一些疑问还麻烦社区帮忙解答,如果有理解不到位的地方还麻烦各位指出。

  1. 为什么 MNN 不支持 tf.switch_case 的算子? 我尝试将下面的示例代码用MNN来进行推理,这只是一个演示 tf.switch_case 的样例:

     condition = tf.compat.v1.placeholder(dtype=tf.int32, name="input")
     def multiply():
         return tf.compat.v1.multiply(condition, 100)
     def add():
         return tf.compat.v1.add(condition, 10)
     res = tf.compat.v1.switch_case(condition, branch_fns={
         0: multiply,
         1: add,
     },default= None)
    

    在实践中,我们发现 tf.switch_case 效果比tf.case性能更好,遗憾的MNN提示:

    [17:32:45]/MNN/tools/converter/source/common/writeFb.cpp:105: These Op Not Support: Tensorflow::Case  
    
    
  2. MNN 如何判断模型是否包含子图? 对于包含tf.cond或者tf.case的模型,MNNConverter转换时会提示:

    The model has subgraphs, please use MNN::Express::Module to run it
    

    这是否意味着只用使用了这两个算子MNN就认为包含子图? 那子图的划分是怎么样的呢? 对于下面的 tf.case 样例:

    pred_fn_pairs = [
        (tf.equal(input_value, 2), lambda: tf.compat.v1.add(input_value, 10)),
        (tf.equal(input_value, 1), lambda: tf.compat.v1.add(input_value, 100)),
        (tf.equal(input_value, 5), lambda: tf.compat.v1.add(input_value, 20)),
    ]
    default_fn = lambda: tf.compat.v1.add(input_value, 30)
    output = tf.compat.v1.case(pred_fn_pairs, default_fn, exclusive=True)
    

    tools/converter/source/common/writeFb.cpp里面的MNN_DUMP_SUBGRAPH 解注,会出现8个subgrah,为什么是8个? 为什么这个宏现在没用了?

  3. 能否以手动加载子图的形式实现tf.switch_case的等价功能? 这个是我的拍脑袋想法,我不清楚这样的伪代码是否能够实现:

    switch (value) {
    case 1: load_subgraph_1(); break;
    case 2: load_subgraph_2(); break;
    }
    

stricklandye avatar Dec 27 '23 09:12 stricklandye

  1. switch_case 一般可以用 if 替代,你可以上传一个简单模型,我们排期支持下
  2. 存在控制流算子 (if / while) 时,都会产生子图
  3. 这个你可以把各分支分别导出 pb 并转 mnn 模型来实现。不过建议还是换成 if ,加载1个mnn比较方便

jxt1234 avatar Dec 27 '23 11:12 jxt1234

  1. switch_case 一般可以用 if 替代,你可以上传一个简单模型,我们排期支持下
  2. 存在控制流算子 (if / while) 时,都会产生子图
  3. 这个你可以把各分支分别导出 pb 并转 mnn 模型来实现。不过建议还是换成 if ,加载1个mnn比较方便

if else在推理的时候不会很慢吗,如果是一定批量输入的话

johnjim0816 avatar Dec 28 '23 07:12 johnjim0816

批量输入建议不要用控制流的方式实现,可以用 select 类似的算子替代

jxt1234 avatar Jan 06 '24 12:01 jxt1234

批量输入建议不要用控制流的方式实现,可以用 select 类似的算子替代

那如果是MoE这类网络呢,有什么比较好的方式,select 类似的算子在MoE场景中似乎也会回到tf.case的问题上

johnjim0816 avatar Jan 09 '24 08:01 johnjim0816

@jxt1234 大佬,我们这边仍然有switch_case的需求,如果可以的话可以在后续版本加上。感谢!

就以前面提到的测试代码为例:

tensorflow 版本: 1.1.5

import tensorflow as tf

condition = tf.placeholder(dtype=tf.int32, name="input")
def multiply():
    return tf.multiply(condition, 100
def add():
    return tf.add(condition, 10)

res = tf.switch_case(condition, branch_fns={
    0: multiply,
    1: add,
}, default=None)

with tf.Session() as sess:
    output = sess.run(res,feed_dict={condition:0})
    print(output)
    nodes = [node.name for node in tf.get_default_graph().get_operations()]
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=nodes
    )
    tf.train.write_graph(output_graph_def, "./", "switch_case.pb", as_text=False)

使用MNNConvert转换模型提示:

The device support i8sdot:0, support fp16:0, support i8mm: 0
Start to Convert Other Model Format To MNN Model..., target version: 2.8
Start to Optimize the MNN Net...
[10:59:47] :105: These Op Not Support: Tensorflow::Case 
Converted Failed!

以及用上面代码保存的测试模型的zip: switch_case.zip

stricklandye avatar Mar 05 '24 03:03 stricklandye