unsupportedOperator: "Equal" (and "If"?)

Open • ronyfadel opened this issue 11 months ago • 5 comments

I'm trying to run inference on the Silero VAD model (https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx) using MPSX, but it fails with:

result: failure(MPSX.OnnxError.unsupportedOperator("Equal"))

ronyfadel, Mar 13 '24 13:03
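A possible starting point for the `Equal` failure, offered as a sketch: ONNX `Equal` is an element-wise comparison that produces a boolean tensor, and MPSGraph has a matching op, `equal(_:_:name:)`. The standalone program below (macOS with a Metal device) shows only that op in isolation. The helpers `onnxEqual` and `makeTensorData` are hypothetical, and how such a function would be registered in MPSX's operator dispatch is not shown here.

```swift
import Foundation
import Metal
import MetalPerformanceShaders
import MetalPerformanceShadersGraph

// Hypothetical helper (not MPSX API): ONNX `Equal` mapped onto MPSGraph's
// element-wise `equal`, which yields a `.bool` tensor like the ONNX operator.
func onnxEqual(_ graph: MPSGraph, _ a: MPSGraphTensor, _ b: MPSGraphTensor) -> MPSGraphTensor {
    graph.equal(a, b, name: nil)
}

// Hypothetical helper: wraps a Float array as 1-D tensor data on the given device.
func makeTensorData(_ values: [Float], device: MPSGraphDevice) -> MPSGraphTensorData {
    let data = values.withUnsafeBufferPointer { Data(buffer: $0) }
    return MPSGraphTensorData(device: device, data: data,
                              shape: [NSNumber(value: values.count)], dataType: .float32)
}

guard let mtlDevice = MTLCreateSystemDefaultDevice() else { fatalError("No Metal device") }
let device = MPSGraphDevice(mtlDevice: mtlDevice)
let graph = MPSGraph()
let a = graph.placeholder(shape: [3], dataType: .float32, name: "a")
let b = graph.placeholder(shape: [3], dataType: .float32, name: "b")
let eq = onnxEqual(graph, a, b)

let results = graph.run(feeds: [a: makeTensorData([1, 2, 3], device: device),
                                b: makeTensorData([1, 5, 3], device: device)],
                        targetTensors: [eq],
                        targetOperations: nil)

// `.bool` elements are one byte each: read them raw, then convert to Bool.
var bytes = [UInt8](repeating: 0, count: 3)
results[eq]!.mpsndarray().readBytes(&bytes, strideBytes: nil)
let flags = bytes.map { $0 != 0 }
print(flags)  // [true, false, true]
```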

Please take a look at this draft, https://github.com/prisma-ai/MPSX/pull/14, and use it as a reference. Maybe you could help and open a PR to MPSX :)

geor-kasapidi, Mar 13 '24 13:03

Thanks for the guidance, @geor-kasapidi! I'm trying to implement If, but to be honest this is all new to me:

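import MetalPerformanceShadersGraph

extension MPSGraph {
  /// https://github.com/onnx/onnx/blob/main/docs/Operators.md#If
  func `if`(
    _ node: Onnx_NodeProto,
    _ tensors: [String: MPSGraphTensor]
  ) throws -> MPSGraphTensor {
    guard let cond = tensors(node.input(0)),
          let else_branch = node.attr("else_branch"),
          let then_branch = node.attr("then_branch")
    else { throw OnnxError.invalidInput(node.name) }

    // MPSGraph's branch closures can't throw, so errors are captured here and rethrown below.
    var error: Error?

    // Each closure must return that branch's output tensors as [MPSGraphTensor].
    let outputs = self.if(cond, then: {
      // TODO: build then_branch's subgraph and return its output tensors
      return []
    }, else: {
      // TODO: build else_branch's subgraph and return its output tensors
      return []
    }, name: nil)

    if let error {
      throw error
    }

    // TODO: If can have several outputs; this returns only the first one.
    guard let output = outputs.first else { throw OnnxError.invalidInput(node.name) }
    return output
  }
}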

@geor-kasapidi, your input is appreciated!

ronyfadel, Mar 13 '24 15:03

Well, I took a look at the If operator, and it's a tricky one. If I understand the ONNX spec correctly, both branches (then and else) require building a subgraph. This is not a one-line implementation; it needs new logic for recursive graph creation. But you can try this: for each branch, call MPSX's onnx function with a local tensor table, and return the output tensors of those calls from the closures passed to MPSGraph's if method:

self.onnx(node: <#T##Onnx_NodeProto#>,
          optimizedForMPS: <#T##Bool#>,
          tensorsDataType: <#T##MPSDataType#>,
          tensors: &<#T##[String : MPSGraphTensor]#>,
          constants: &<#T##[String : Onnx_TensorProto]#>)
self.if(<#T##predicateTensor: MPSGraphTensor##MPSGraphTensor#>, then: {
    // onnx result for true
}, else: {
    // onnx result for false
}, name: nil)

Try this yourself, and I'll help more if you get stuck :)

geor-kasapidi, Mar 15 '24 06:03
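A minimal sketch of the approach suggested above, under stated assumptions: each branch is built inside its own closure with a local tensor table copied from the outer scope, and `MPSGraph.if(_:then:else:name:)` receives each branch's output tensors. `BranchNode`, `BranchGraph`, `validateBranch`, `buildBranch` and `onnxIf` are hypothetical stand-ins, not MPSX API, and only Add/Sub/Mul are handled. A real implementation would instead walk the nodes of the `then_branch`/`else_branch` graph attributes and call MPSX's `onnx(node:...)` for each one.

```swift
import Foundation
import Metal
import MetalPerformanceShaders
import MetalPerformanceShadersGraph

// Hypothetical stand-ins for an ONNX branch subgraph (NOT MPSX's Onnx_* types).
struct BranchNode {
    let op: String        // only "Add", "Sub" and "Mul" in this sketch
    let inputs: [String]  // names looked up in the local tensor table
    let output: String
}
struct BranchGraph {
    let nodes: [BranchNode]
    let outputs: [String]
}
enum SketchError: Error { case unsupported(String), missingTensor(String), outputCountMismatch }

// Checks ops and names up front, because MPSGraph's branch closures cannot throw.
func validateBranch(_ branch: BranchGraph, outerNames: Set<String>) throws {
    var known = outerNames
    for node in branch.nodes {
        guard ["Add", "Sub", "Mul"].contains(node.op), node.inputs.count == 2 else {
            throw SketchError.unsupported(node.op)
        }
        for name in node.inputs where !known.contains(name) { throw SketchError.missingTensor(name) }
        known.insert(node.output)
    }
    for name in branch.outputs where !known.contains(name) { throw SketchError.missingTensor(name) }
}

extension MPSGraph {
    // Builds one branch with a local tensor table that starts as a copy of the
    // outer scope; tensors created by the branch never leak back to the caller.
    func buildBranch(_ branch: BranchGraph, outerScope: [String: MPSGraphTensor]) -> [MPSGraphTensor] {
        var local = outerScope
        for node in branch.nodes {
            let x = local[node.inputs[0]]!, y = local[node.inputs[1]]!  // safe after validation
            switch node.op {
            case "Add": local[node.output] = addition(x, y, name: nil)
            case "Sub": local[node.output] = subtraction(x, y, name: nil)
            default:    local[node.output] = multiplication(x, y, name: nil)
            }
        }
        return branch.outputs.map { local[$0]! }
    }

    // ONNX-style If on top of MPSGraph's `if(_:then:else:name:)`.
    func onnxIf(_ cond: MPSGraphTensor, thenBranch: BranchGraph, elseBranch: BranchGraph,
                scope: [String: MPSGraphTensor]) throws -> [MPSGraphTensor] {
        guard thenBranch.outputs.count == elseBranch.outputs.count else {
            throw SketchError.outputCountMismatch
        }
        let names = Set(scope.keys)
        try validateBranch(thenBranch, outerNames: names)
        try validateBranch(elseBranch, outerNames: names)
        return self.if(cond, then: {
            self.buildBranch(thenBranch, outerScope: scope)
        }, else: {
            self.buildBranch(elseBranch, outerScope: scope)
        }, name: nil)
    }
}

// Usage: out = cond ? a + b : a * b
guard let mtlDevice = MTLCreateSystemDefaultDevice() else { fatalError("No Metal device") }
let device = MPSGraphDevice(mtlDevice: mtlDevice)
let graph = MPSGraph()
let a = graph.placeholder(shape: [2], dataType: .float32, name: "a")
let b = graph.placeholder(shape: [2], dataType: .float32, name: "b")
let cond = graph.placeholder(shape: [1], dataType: .bool, name: "cond")
let out = try graph.onnxIf(
    cond,
    thenBranch: BranchGraph(nodes: [BranchNode(op: "Add", inputs: ["a", "b"], output: "sum")], outputs: ["sum"]),
    elseBranch: BranchGraph(nodes: [BranchNode(op: "Mul", inputs: ["a", "b"], output: "prod")], outputs: ["prod"]),
    scope: ["a": a, "b": b])[0]

// Runs the graph with a = [1, 2], b = [3, 4] and the given predicate value.
func run(_ flag: Bool) -> [Float] {
    func floats(_ v: [Float]) -> Data { v.withUnsafeBufferPointer { Data(buffer: $0) } }
    let feeds: [MPSGraphTensor: MPSGraphTensorData] = [
        a: MPSGraphTensorData(device: device, data: floats([1, 2]), shape: [2], dataType: .float32),
        b: MPSGraphTensorData(device: device, data: floats([3, 4]), shape: [2], dataType: .float32),
        cond: MPSGraphTensorData(device: device, data: Data([UInt8(flag ? 1 : 0)]), shape: [1], dataType: .bool),
    ]
    let result = graph.run(feeds: feeds, targetTensors: [out], targetOperations: nil)
    var values = [Float](repeating: 0, count: 2)
    result[out]!.mpsndarray().readBytes(&values, strideBytes: nil)
    return values
}

print(run(true), run(false))
```

Validating both branches before calling `if`, rather than capturing an error inside the closures as in the draft earlier in this thread, means the two closures never return output lists of different lengths, which MPSGraph would presumably reject.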

@ronyfadel, any updates here?

geor-kasapidi, Mar 27 '24 11:03

Hey @geor-kasapidi, I've paused work on this for a while because of competing priorities.

ronyfadel, Mar 28 '24 12:03