torch.cond operator not supported on a simple example
Issue:
I am trying to implement logic in a neural network that dynamically routes a sample based on some condition. I built a dummy example of what the network should look like, and I would like to export this model to MLIR. When I try to do so using torch-mlir, I get an error. I would like to know whether the torch.cond operator is unsupported or whether my implementation is just wrong.
Steps to reproduce:
Run this code:
```python
import torch
import torch.nn as nn
import copy
from torch_mlir.fx import export_and_import


class CondNetwork(nn.Module):
    def __init__(self):
        super(CondNetwork, self).__init__()
        self.confidence_threshold = 2
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):
        condition = torch.mean(x) > self.confidence_threshold

        def true_fn():
            feature = x.clone().flatten()
            return self.linear1(feature)

        def false_fn():
            feature = x.clone().flatten()
            return self.linear2(feature)

        return torch.cond(condition, true_fn, false_fn)


def torch_mlir_model_export(model):
    cond_model = copy.deepcopy(model)
    with torch.no_grad():
        cond_model.eval()
        module = export_and_import(cond_model, torch.ones(1, 3, 32, 32), output_type="torch")
    with open("torchmlir_condmodel.mlir", "w") as f:
        f.write(str(module))


# -- Main
def main():
    model = CondNetwork()
    # model_export(model, "cpu")
    torch_mlir_model_export(model)


if __name__ == "__main__":
    main()
```
You should get this error:
```
    module = export_and_import(cond_model, torch.ones(1, 3, 32, 32), output_type="torch")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/fx.py", line 111, in export_and_import
    fx_importer.import_frozen_program(
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 901, in import_frozen_program
    return self.import_stateless_graph(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 947, in import_stateless_graph
    node_importer.import_nodes(
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 1462, in import_nodes
    self._import_hop(loc, node, target)
  File "/home/jhassani/miniconda3/lib/python3.12/site-packages/torch_mlir/extras/fx_importer.py", line 1566, in _import_hop
    raise NotImplementedError(
NotImplementedError: Higher-order operation 'cond' not implemented in the FxImporter (tried '_import_hop_cond')
```
Additional information
- torch version: 2.7.0.dev20250210+cpu
- torchvision version: torchvision-0.22.0.dev20250210+cpu
- torch_mlir version: 20250127.357
I ran the code and got a similar error message. Correct me if I'm wrong, but I think the problem is in how Torch-MLIR's FxImporter (and the underlying TorchDynamo integration) currently handles higher-order operations like torch.cond with nested (inline) functions. In other words, Torch-MLIR currently isn't set up to handle models like the example above.
Yes, I think so. After working on this, the best workaround I found is to split the model into two models and implement the "if" logic at the application level. This is not ideal for me because I want to compile the IR afterwards, but it solves the issue.
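A minimal sketch of that split, assuming each branch is exported as its own control-flow-free model and the routing decision is made in plain Python. `Branch`, `route`, and `export_branches` are hypothetical names. `export_branches` reuses the `export_and_import` call from the reproduction above and imports torch_mlir lazily, so the routing logic can be tried without torch-mlir installed.

```python
import torch
import torch.nn as nn


class Branch(nn.Module):
    """One routing target (hypothetical): flatten the input, apply a linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3072, 3)

    def forward(self, x):
        return self.linear(x.flatten())


def route(x, branch_a, branch_b, threshold=2.0):
    # The data-dependent decision happens in application code, outside any exported graph.
    if torch.mean(x).item() > threshold:
        return branch_a(x)
    return branch_b(x)


def export_branches(branch_a, branch_b):
    # Each branch has no control flow, so it can be imported on its own.
    # Assumes torch-mlir is installed; the import is deferred until this is called.
    from torch_mlir.fx import export_and_import

    example = torch.ones(1, 3, 32, 32)
    return (
        export_and_import(branch_a, example, output_type="torch"),
        export_and_import(branch_b, example, output_type="torch"),
    )
```

The cost of this design is the one mentioned above: the branch decision is no longer part of the IR, so any compiler downstream only ever sees the two branches separately.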
Hello @JibAxelera and @amemov
I believe that under the hood, torch_mlir.fx.export_and_import runs torch.export on the nn.Module and then imports the resulting FX graph.
According to the documentation, this kind of graph capture only handles static control flow: it doesn't preserve a Python "if" that depends on actual input data values. torch.cond is the way to express such a branch, and it does survive export as a `cond` higher-order op, but the traceback shows that the FxImporter has no handler for that op yet (`_import_hop_cond`).
Workaround
One way to preserve both outputs in your exported IR is to use torch.where so that both potential outputs exist in the same dataflow expression. This ensures FX sees a single where op rather than Python-level branching. The downside is that both sub-networks compute every forward pass, even if you only need one in practice.
Potential workaround:
```python
import torch
import torch.nn as nn


class CondNetwork(nn.Module):
    def __init__(self):
        super(CondNetwork, self).__init__()
        self.confidence_threshold = 2
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):
        condition = torch.mean(x) > self.confidence_threshold
        feature = x.clone().flatten()
        # Both branches are computed; torch.where selects one result.
        feature1 = self.linear1(feature)
        feature2 = self.linear2(feature)
        return torch.where(condition, feature1, feature2)
```
With this code, I can successfully run:
```python
module = export_and_import(model, torch.ones(1, 3, 32, 32), output_type="torch")
```
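To convince yourself that the torch.where version routes samples the same way the original branching would, you can compare it in eager mode against a plain Python `if`. This is a sketch. `WhereNetwork` mirrors the workaround above (without the `clone`), and `reference_branch` is a hypothetical eager-only helper that is never exported.

```python
import torch
import torch.nn as nn


class WhereNetwork(nn.Module):
    """The torch.where workaround: both branches run, where picks one result."""

    def __init__(self):
        super().__init__()
        self.confidence_threshold = 2
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):
        condition = torch.mean(x) > self.confidence_threshold  # 0-dim bool tensor
        feature = x.flatten()
        # The 0-dim condition broadcasts over the (3,) outputs of both branches.
        return torch.where(condition, self.linear1(feature), self.linear2(feature))


def reference_branch(model, x):
    # Eager-only reference: a Python if evaluated on the concrete input value.
    feature = x.flatten()
    if torch.mean(x) > model.confidence_threshold:
        return model.linear1(feature)
    return model.linear2(feature)
```

An input with mean 3.0 should match `linear1`, and an all-zeros input should match `linear2`. The results are identical, but `WhereNetwork` always pays for both linear layers.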
Final caveat: if the sub-networks are large, computing both might be inefficient. In principle, TorchScript (torch.jit.script) can handle truly dynamic data-based branching, but I’m not sure if Torch-MLIR still supports direct ScriptModule importing. As far as I know, the FX-based path is the recommended or primary route at the moment.
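For reference, this is roughly what the TorchScript route looks like on the PyTorch side: `torch.jit.script` compiles the Python `if` into a `prim::If` node, so only one branch runs per call. This is a sketch. `ScriptedCondNetwork` is a hypothetical name, and it says nothing about whether Torch-MLIR can still import the resulting ScriptModule, which is the open question above.

```python
import torch
import torch.nn as nn


class ScriptedCondNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.confidence_threshold = 2.0
        self.linear1 = nn.Linear(3072, 3)
        self.linear2 = nn.Linear(3072, 3)

    def forward(self, x):
        feature = x.flatten()
        # TorchScript keeps this data-dependent branch as a prim::If node.
        if bool(torch.mean(x) > self.confidence_threshold):
            return self.linear1(feature)
        return self.linear2(feature)


scripted = torch.jit.script(ScriptedCondNetwork().eval())
```

Printing `scripted.graph` should show a `prim::If` node with one block per branch.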
Very clear response, thank you. It would be nice if we could route samples dynamically and still export. Is this a feature request that makes sense, or is it out of the scope of torch-mlir?
@JibAxelera I think it should be in the scope of torch-mlir, since semantically your code above is acceptable for PyTorch.
I wouldn't mind implementing this feature myself.
I'm trying to implement this (my first attempt at MLIR stuff). It generates MLIR for torch.cond (although not necessarily technically correct 😄). I tested it in IREE: it compiles and works in the runtime for the example given in this issue.
FX graph and generated MLIR graph: https://gist.github.com/thomasverelst/3e4d564a0ad6aebadd227c8c5b8cadac
Modified FX importer: https://github.com/thomasverelst/torch-mlir/commit/33d9faf52d2fb39917c27a21c25fe4370ce3933b
Changes:
- implemented `_import_hop_cond`, which inserts a `torch.prim.If` operation with function calls.
- changed `import_frozen_program`/`import_program` so that they import both the main GraphModule and the nested GraphModules (the if and else GraphModules), each as a separate MLIR function (probably not done as it should be).
- made some changes in `ContextCache.value_info_to_type` to deal with the output arguments of the functions (a tuple of one or more tensors). I'm definitely not sure this is a correct approach.
Definitely open for feedback!