error: operand types should have the same type as the list contained type

Tengxu-Sun opened this issue 2 years ago • 5 comments

We're trying to use the torch.cat operation but get an unexpected error. The code below reproduces it.

```python
import torch
from torch import nn
import torch_mlir


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input0, input1, input2, input3, input4):
        # Concatenate five 2-D tensors along dim 1 (only the widths differ).
        outputs = [input0, input1, input2, input3, input4]
        out = torch.cat(outputs, dim=1)
        return out


def main():
    model = SimpleModel()

    # All inputs have 20 rows but different sizes along dim 1.
    example_input0 = torch.rand((20, 12), dtype=torch.float32)
    example_input1 = torch.rand((20, 44), dtype=torch.float32)
    example_input2 = torch.rand((20, 40), dtype=torch.float32)
    example_input3 = torch.rand((20, 36), dtype=torch.float32)
    example_input4 = torch.rand((20, 32), dtype=torch.float32)
    example_inputs = [example_input0, example_input1, example_input2, example_input3, example_input4]

    traced = torch.jit.trace(model, example_inputs)
    # Fails with: "operand types should have the same type as the list contained type"
    torch_on_tensors_mlir = torch_mlir.compile(
        traced,
        example_inputs,
        output_type=torch_mlir.OutputType.TORCH,
        use_tracing=True,
    )
    print(torch_on_tensors_mlir)


if __name__ == "__main__":
    main()
```
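For reference, these inputs are valid for `torch.cat` in eager PyTorch: every dimension except the concatenation dimension must match, and the output size along `dim` is the sum of the input sizes. The sketch below checks that rule on plain shape tuples; `cat_output_shape` is a hypothetical helper written for illustration, not part of PyTorch or torch-mlir, and it ignores the special case of 1-D empty tensors that `torch.cat` also accepts.

```python
def cat_output_shape(shapes, dim):
    """Hypothetical helper mirroring torch.cat's shape rule on shape tuples.

    All shapes must have the same rank and agree on every dimension except
    `dim`; the result's size along `dim` is the sum of the inputs' sizes.
    """
    if not shapes:
        raise ValueError("need at least one shape")
    rank = len(shapes[0])
    if dim < 0:
        dim += rank  # accept negative dims, as torch.cat does
    if not 0 <= dim < rank:
        raise IndexError(f"dim {dim} out of range for rank {rank}")
    for shape in shapes:
        if len(shape) != rank:
            raise ValueError(f"rank mismatch: {shape} vs {shapes[0]}")
        for axis in range(rank):
            if axis != dim and shape[axis] != shapes[0][axis]:
                raise ValueError(f"size mismatch on dim {axis}: {shape} vs {shapes[0]}")
    out = list(shapes[0])
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


# Shapes from the reproducer: same number of rows, different widths.
shapes = [(20, 12), (20, 44), (20, 40), (20, 36), (20, 32)]
print(cat_output_shape(shapes, dim=1))  # (20, 164)
```

Since the shapes satisfy this rule (12 + 44 + 40 + 36 + 32 = 164), the model itself is a valid PyTorch program, which suggests the error is raised during torch-mlir compilation rather than by `torch.cat`.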

However, if all inputs have the same shape, the code works. For example:

```python
traced = torch.jit.trace(model, [example_input0, example_input0, example_input0, example_input0, example_input0])
torch_on_tensors_mlir = torch_mlir.compile(traced, [example_input0, example_input0, example_input0, example_input0, example_input0], output_type=torch_mlir.OutputType.TORCH, use_tracing=True)
print(torch_on_tensors_mlir)
```
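The error text describes a rule on list construction: every operand must have exactly the list's element (contained) type. One reading that is consistent with the same-shape case working is that the list's element type carried one concrete tensor shape, which differently shaped operands cannot match. The toy model below illustrates that idea only; `TensorType` and `verify_list_construct` are hypothetical names and do not reflect torch-mlir's actual implementation.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TensorType:
    """Toy stand-in for a tensor type: a concrete shape plus a dtype name."""
    shape: Tuple[int, ...]
    dtype: str


def verify_list_construct(operands, element_type):
    """Hypothetical check: every operand type must equal the list's element type."""
    for operand in operands:
        if operand != element_type:
            raise ValueError(
                "operand types should have the same type as the list contained type"
            )


# Differently shaped operands, as in the failing reproducer.
widths = [12, 44, 40, 36, 32]
operands = [TensorType((20, w), "f32") for w in widths]

# If the element type is pinned to the first operand's exact shape, the others mismatch.
try:
    verify_list_construct(operands, operands[0])
except ValueError as err:
    print("mixed shapes:", err)

# Identical shapes (the case reported to work) pass the same check.
same = [TensorType((20, 12), "f32")] * 5
verify_list_construct(same, same[0])
print("identical shapes: ok")
```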

Hoping for your reply, thanks!

Tengxu-Sun, Jul 20 '22

Hi @Tengxu-Sun,

I just wanted to let you know that we have torch-MLIR office hours every Thursday where I can show you my approach for debugging issues like this one. If the time doesn't work for you, I'm happy to help through GitHub issues. I just wanted to make sure you were aware of that option! 😄

ramiro050, Jul 21 '22

Thanks very much for your invitation. I will be there!

Tengxu-Sun, Jul 21 '22

Was there a solution to this? I'm getting the same error.

kkiningh, Jul 25 '22

Hi @kkiningh,

I was not able to reproduce this issue on my machine. Does your branch include this commit: https://github.com/llvm/torch-mlir/pull/971/commits/a495be1905237df77bb5e0e6f59534348d9cb070? If it does and you're still getting the error, can you post the error message here, along with the MLIR file that was generated under /tmp, so I can reproduce it?

ramiro050, Jul 25 '22

Upgrade your torch-mlir version and this error will be gone.

Tengxu-Sun, Jul 26 '22

Closing the issue since a solution was found.

silvasean, Oct 07 '22