
Non-value Tensor Type

Open. ZihengJiang opened this issue (5 comments).

I tried to lower a model from TorchScript into the Torch dialect, but hit this error:

torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass
note: see current operation: %286 = "torch.copy.to_tensor"(%285) : (!torch.vtensor) -> !torch.tensor

Any idea?

ZihengJiang, Aug 03 '22 23:08

Hi @ZihengJiang,

This error means that the MaximizeValueSemantics pass could not convert the IR graph into one that uses only !torch.vtensor (tensors with value semantics) and no !torch.tensor. Many different issues can cause this, so it is hard to tell what is going wrong from the error message alone. If you have a small reproducible example that triggers this error, I can help you debug further.
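As a rough PyTorch-level intuition for the two tensor types (not a description of how the pass itself works), assuming the usual reading that a !torch.tensor may be mutated or aliased while a !torch.vtensor may not: the sketch below contrasts an in-place op, whose effect is visible through every alias, with an out-of-place op, which produces an independent new value.

```python
import torch

# Non-value ("reference") semantics: an in-place op on one name is
# visible through every alias of the same storage.
a = torch.zeros(3)
alias = a
alias.add_(1)          # in-place: mutates the storage shared by a and alias
print(a)               # tensor([1., 1., 1.])

# Value semantics: an out-of-place op returns a fresh tensor and leaves
# its input untouched, so each value can be reasoned about on its own.
b = torch.zeros(3)
c = b + 1              # out-of-place: b is unchanged
print(b, c)            # tensor([0., 0., 0.]) tensor([1., 1., 1.])
```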

ramiro050, Aug 04 '22 00:08

Hi @ramiro050, we are running into the same problem. Below is a small example that reproduces this error. Please have a look, thanks a lot!

import torch
from torch import nn
import torch_mlir

hidden_size = 768
max_ngram = 3
device = "cpu"

class MyPad1D(torch.nn.Module):
    def __init__(self, ngram_idx, device):
        super(MyPad1D, self).__init__()
        self.left = 1 if ngram_idx == 2 else 0
        self.right = 1 if ngram_idx >= 1 else 0
        self.device = device

    def forward(self, x):
        left = torch.zeros((x.shape[0], x.shape[1], self.left), device=self.device)
        right = torch.zeros((x.shape[0], x.shape[1], self.right), device=self.device)
        x = torch.cat((left, x, right), dim=-1)
        return x

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.q_convs = nn.ModuleList()
        self.d_convs = nn.ModuleList()
        for i in range(max_ngram):
            conv = nn.Sequential(
                #torch.nn.ConstantPad1d((0, i), 1),  
                MyPad1D(i, device),  # This version of TensorRT only supports constant 0 padding
                torch.nn.Conv1d(hidden_size, hidden_size, i+1),
                torch.nn.ReLU()
            )
            self.q_convs.append(conv)
            self.d_convs.append(conv)
    
    def forward(self, input0, input1):
        q_convs = []
        d_convs = []
        for q_conv, d_conv in zip(self.q_convs, self.d_convs):
            q_convs.append(q_conv(input0.transpose(1, 2)).transpose(1, 2))
            d_convs.append(d_conv(input1.transpose(1, 2)).transpose(1, 2))
        return q_convs, d_convs

def main():
    model = SimpleModel().eval()

    example_input0 = torch.rand((20, 12, 768), dtype=torch.float32)
    example_input1 = torch.rand((20, 50, 768), dtype=torch.float32)

    # print(model(example_input0, example_input1))

    traced = torch.jit.trace(model, [example_input0, example_input1])
    # output_type=TORCH yields Torch-dialect IR (not linalg-on-tensors), so name it accordingly
    torch_mlir_module = torch_mlir.compile(traced, [example_input0, example_input1],
                                           output_type=torch_mlir.OutputType.TORCH)
    print(torch_mlir_module)


if __name__ == "__main__":
    main()

Update: if I use return torch.cat(q_convs, dim=0), torch.cat(d_convs, dim=0) instead of return q_convs, d_convs, the error goes away. So I wonder:

  1. What is the difference between the two return types? According to the MLIR output, return q_convs, d_convs introduces torch.copy.to_tensor, while return torch.cat(q_convs, dim=0), torch.cat(d_convs, dim=0) does not.
  2. If I really need q_convs, d_convs as the return value, how can I do it?

Tengxu-Sun, Aug 08 '22 06:08

Hi @ramiro050, I get the same error, "found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass", when I use the latest torch-mlir.

However, the same code generates MLIR successfully with the torch-mlir snapshot snapshot-20220701.520.

I've pasted my code below (exported with use_tracing=False). Both snippets generate MLIR with snapshot-20220701, but fail with snapshot-20220807.

code1:

import torch
import torch.nn as nn
import torch_mlir


class CtrlFlow(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        if x > 0:
            return x + 1
        else:
            return x - 1


example_input = torch.ones((1, 4, 2), dtype=torch.float32)
torch_on_tensors_mlir = torch_mlir.compile(CtrlFlow(), example_input,
                                                output_type=torch_mlir.OutputType.TORCH, use_tracing=False)
print("CtrlFlow Mlir = ", torch_on_tensors_mlir)

code2:

import torch
import torch.nn as nn
import torch_mlir


class ScriptMode(nn.Module):
    def __init__(self):
        super(ScriptMode, self).__init__()
        self.linear = nn.Linear(4, 2)
    
    def forward(self, x):
        a = []
        a.append(self.linear(x[:, :, 0]).unsqueeze(-1))
        a.append(self.linear(x[:, :, 1]).unsqueeze(-1))
        return torch.cat(a, -1)


example_input = torch.ones((1, 4, 2), dtype=torch.float32)
torch_on_tensors_mlir = torch_mlir.compile(ScriptMode(), example_input,
                                                output_type=torch_mlir.OutputType.TORCH, use_tracing=False)
print("ScriptMode Mlir = ", torch_on_tensors_mlir)

KangHe000, Aug 08 '22 12:08

Hi @Tengxu-Sun, regarding the difference between the two return types: torch-mlir currently does not support returning lists. If you need to return several tensors, the only supported way at the moment is to return a tuple.
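A minimal sketch of the tuple-based workaround, using a hypothetical MultiOutput module (not from this thread) as a small stand-in for SimpleModel. It only checks that PyTorch tracing produces a fixed-arity tuple output instead of a list; it has not been verified against a specific torch-mlir snapshot.

```python
import torch
from torch import nn


class MultiOutput(nn.Module):
    """Hypothetical stand-in for SimpleModel: several conv branches."""

    def __init__(self, channels: int = 8, branches: int = 3):
        super().__init__()
        self.convs = nn.ModuleList(nn.Conv1d(channels, channels, 1) for _ in range(branches))

    def forward(self, x: torch.Tensor):
        outs = [conv(x) for conv in self.convs]
        # Returning the Python list `outs` directly is the pattern that failed;
        # converting it to a tuple gives the traced graph a tuple output instead.
        return tuple(outs)


model = MultiOutput().eval()
example = torch.rand(2, 8, 5)
traced = torch.jit.trace(model, example)
out = traced(example)
print(type(out).__name__, len(out))  # tuple 3
```

For SimpleModel, the analogous change would be something like return tuple(q_convs) + tuple(d_convs), which yields one flat tuple of six tensors; whether torch-mlir also accepts nested tuples such as (tuple(q_convs), tuple(d_convs)) is not confirmed in this thread.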

ramiro050, Aug 08 '22 17:08

Hi @KangHe000,

Torch-mlir currently does not support data-dependent control flow, so the failure of the first snippet is expected. You most likely did not get an error with the older snapshot because that version did not verify that the maximize-value-semantics pass succeeded, so the pipeline finished without an error even when the pass failed. The check that verifies maximize-value-semantics was added only recently.
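If the branch is really meant to be elementwise, one common workaround (a swapped-in technique, not something suggested in this thread) is to replace the Python if/else with torch.where, which computes both branches and selects per element, so no data-dependent control flow remains. The sketch assumes that elementwise reading and that torch-mlir can lower torch.where; the second point has not been checked against these snapshots. Note that with the (1, 4, 2) example input, the original if x > 0 would raise in eager mode anyway, because a multi-element tensor has no single truth value.

```python
import torch
from torch import nn


class CtrlFlowWhere(nn.Module):
    """Hypothetical branch-free variant of CtrlFlow.

    Assumes the intent is elementwise: add 1 where x > 0, subtract 1 elsewhere.
    For a single-element x this matches the original if/else.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both branches are computed; torch.where picks one value per element,
        # so the graph is straight-line dataflow with no if/else.
        return torch.where(x > 0, x + 1, x - 1)


scripted = torch.jit.script(CtrlFlowWhere())
x = torch.tensor([[-2.0, 0.0, 3.0]])
print(scripted(x))
```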

The second snippet fails because append is not currently supported. If you remove the appends, the code should work.
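A minimal sketch of ScriptMode with the appends removed, keeping the same Linear layer and slicing. ScriptModeNoAppend is a hypothetical name, and the sketch assumes that a list literal passed directly to torch.cat is acceptable to torch-mlir (the reply only identifies the appends as the problem). The example is checked only under torch.jit.script in PyTorch itself.

```python
import torch
import torch.nn as nn


class ScriptModeNoAppend(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        # Compute each slice as its own value instead of appending to a mutable list.
        a0 = self.linear(x[:, :, 0]).unsqueeze(-1)
        a1 = self.linear(x[:, :, 1]).unsqueeze(-1)
        # Pass the tensors to torch.cat directly as a list literal.
        return torch.cat([a0, a1], -1)


model = ScriptModeNoAppend()
scripted = torch.jit.script(model)
example_input = torch.ones((1, 4, 2), dtype=torch.float32)
print(scripted(example_input).shape)  # torch.Size([1, 2, 2])
```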

ramiro050, Aug 09 '22 18:08

Is there anything further to do in this issue? Or can we close it?

silvasean, Oct 07 '22 13:10

I found that the error was caused by missing support for some ops in the Torch dialect. After adding that support, the error disappeared in my case.

ZihengJiang, Oct 11 '22 00:10