Providing complete support for the ViT transformer model
Hello, I've been working on adding support for transformer models via torch-mlir. I changed the example model from resnet18 to vit_b_16 and found that lowering it to linalg-on-tensors via `torch_mlir.compile` fails with the following error.
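For reference, here is a minimal sketch of the kind of script that triggers the failure. This is a hypothetical reconstruction, assuming torch, torchvision, and torch-mlir are installed and using the `torch_mlir.compile` / `OutputType.LINALG_ON_TENSORS` API; it is not the exact script used above.

```python
# Hypothetical reproduction sketch (assumes torch, torchvision, and
# torch-mlir are installed; API names follow the torch-mlir Python
# frontend of the time).

def lower_vit_to_linalg():
    import torch
    import torch_mlir
    from torchvision.models import vit_b_16

    model = vit_b_16().eval()
    example = torch.randn(1, 3, 224, 224)  # standard ViT-B/16 input shape
    # Lowering to the linalg-on-tensors dialect; this is the step that
    # reports the 'torch.prim.If' verifier error below.
    return torch_mlir.compile(
        model, example,
        output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
```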
```
/home/dj/new-torch-mlir/torch-mlir/mlir_venv/lib/python3.9/site-packages/torch/nn/modules/activation.py:1156:16: error: 'torch.prim.If' op along control flow edge from Region #1 to parent results: source type #0 '!torch.tensor<*,f32>' should match input type #0 '!torch.tensor'
        if query is key:
               ^
/home/dj/new-torch-mlir/torch-mlir/mlir_venv/lib/python3.9/site-packages/torchvision/models/vision_transformer.py:298:12: note: called from
        x = self.encoder(x)
            ^
/home/dj/new-torch-mlir/torch-mlir/mlir_venv/lib/python3.9/site-packages/torch/nn/modules/activation.py:1156:16: note: see current operation:
%512:3 = "torch.prim.If"(%511) ({
  %513 = "torch.aten.transpose.int"(%76, %1, %19) : (!torch.tensor<*,f32>, !torch.int, !torch.int) -> !torch.tensor<*,f32>
  %514 = "torch.tensor_static_info_cast"(%513) : (!torch.tensor<*,f32>) -> !torch.tensor
  "torch.prim.If.yield"(%514, %514, %514) : (!torch.tensor, !torch.tensor, !torch.tensor) -> ()
}, {
  %513 = "torch.aten.transpose.int"(%76, %1, %19) : (!torch.tensor<*,f32>, !torch.int, !torch.int) -> !torch.tensor<*,f32>
  %514 = "torch.aten.transpose.int"(%76, %1, %19) : (!torch.tensor<*,f32>, !torch.int, !torch.int) -> !torch.tensor<*,f32>
  "torch.prim.If.yield"(%513, %514, %514) : (!torch.tensor<*,f32>, !torch.tensor<*,f32>, !torch.tensor<*,f32>) -> ()
}) : (!torch.bool) -> (!torch.tensor, !torch.tensor, !torch.tensor)
        if query is key:
               ^
```
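Reading the diagnostic, the failure reduces to the two regions of the `torch.prim.If` yielding values with different type refinement: one branch yields the unrefined `!torch.tensor` (via `torch.tensor_static_info_cast`) while the other yields the refined `!torch.tensor<*,f32>`, which does not match the op's declared result type along that control-flow edge. A distilled, hypothetical form of the mismatch (not the actual IR above):

```mlir
// Hypothetical reduction: the else-branch yields a refined tensor type,
// but the parent op's result is the unrefined !torch.tensor, so the
// verifier rejects the edge from that region to the parent results.
%r = torch.prim.If %cond -> (!torch.tensor) {
  %0 = torch.tensor_static_info_cast %t : !torch.tensor<*,f32> to !torch.tensor
  torch.prim.If.yield %0 : !torch.tensor
} else {
  torch.prim.If.yield %t : !torch.tensor<*,f32>  // type mismatch here
}
```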
How can I solve this problem?
@powderluv on model coverage -- is this in our set? vit_b_16 sounds familiar.
Yes, we are tracking ViT lowering. I think https://github.com/llvm/torch-mlir/issues/1390 is the last piece required.
@JakopinA, who is working on this.