[Bug] InternalError: Squeeze dimension check too strict compared to PyTorch behavior
Description
When converting a PyTorch model that calls squeeze on a dimension whose size is not 1, TVM fails with an InternalError. PyTorch's squeeze silently ignores dimensions whose size is not 1 and returns the tensor unchanged, but TVM's Relax frontend performs a strict check and requires the dimension to be exactly 1.
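For reference, PyTorch leaves the tensor unchanged when the requested dimension does not have size 1 (a minimal check I added for illustration, not part of the original report):

import torch

x = torch.randn(32, 10, 5)
y = x.squeeze(1)            # dim 1 has size 10, so squeeze is a no-op
print(y.shape)              # torch.Size([32, 10, 5])
assert y.shape == x.shape   # PyTorch returns the shape unchanged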
Expected behavior
The PyTorch model with a squeeze on a non-1 dimension should be successfully converted to a TVM Relax module, matching PyTorch's behavior of silently ignoring such dimensions.
Actual behavior
An InternalError is raised during from_exported_program conversion with the message "Squeeze expects the input tensor shape values at the given axis positions to be all 1. However, the tensor shape at axis 1 is T.int64(10) which is not 1." This indicates that TVM's squeeze check is stricter than PyTorch's semantics.
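One way a converter could mirror PyTorch's semantics (a sketch under my own assumptions, not necessarily how the TVM frontend or #18478 handles it) is to drop any requested axis whose static extent is known to be different from 1 before emitting the squeeze:

from typing import Optional, Sequence, Tuple

def pytorch_like_squeeze_axes(shape: Sequence[Optional[int]], axes: Sequence[int]) -> Tuple[int, ...]:
    # Keep only axes whose static size is 1; axes with a known size != 1
    # are silently dropped, matching torch.squeeze(dim) semantics.
    # Axes with unknown (symbolic) size are kept here for simplicity;
    # a real converter would need a policy for that case.
    kept = []
    for ax in axes:
        size = shape[ax]
        if size is None or size == 1:
            kept.append(ax)
    return tuple(kept)

# Example: shape [32, 10, 5], squeeze(dim=1) -> no axes left, squeeze becomes a no-op
print(pytorch_like_squeeze_axes([32, 10, 5], [1]))  # ()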
Environment
- OS: Ubuntu 20.04.6 LTS
- TVM version: 0.23.dev0
- Python version: 3.11.14
Steps to reproduce
import torch
import torch.nn as nn
import tvm
from tvm import relax

class SqueezeModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.squeeze(1)  # PyTorch silently ignores this because dim 1 has size 10

model = SqueezeModel()
model.eval()

# Create a tensor whose dim 1 is not 1
x = torch.randn(32, 10, 5)  # shape [32, 10, 5]

# PyTorch execution works (squeeze is a no-op when the dim size is not 1)
with torch.no_grad():
    output = model(x)
print(f"PyTorch output shape: {output.shape}")  # torch.Size([32, 10, 5])

# PyTorch export works
exported_program = torch.export.export(model, (x,))

# TVM conversion fails
from tvm.relax.frontend.torch import from_exported_program
mod = from_exported_program(exported_program)  # InternalError here
Error Log
Traceback (most recent call last):
File "test.py", line 33, in <module>
mod = from_exported_program(exported_program)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
tvm.error.InternalError: Squeeze expects the input tensor shape values at the given axis positions to be all 1. However, the tensor shape at axis 1 is T.int64(10) which is not 1. If it is symbolic, please use MatchCast to cast it to 1 before doing Squeeze.
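A possible temporary workaround (my own suggestion, not part of the report or a confirmed fix) is to guard the squeeze in the model; with static input shapes torch.export resolves the branch at trace time, so no squeeze node reaches the converter:

import torch
import torch.nn as nn
from tvm.relax.frontend.torch import from_exported_program

class SafeSqueezeModel(nn.Module):
    def forward(self, x):
        # With static shapes, x.shape[1] is a plain Python int at export
        # time, so the branch is resolved during tracing and the exported
        # graph contains no squeeze node when dim 1 is not 1.
        return x.squeeze(1) if x.shape[1] == 1 else x

x = torch.randn(32, 10, 5)
exported_program = torch.export.export(SafeSqueezeModel().eval(), (x,))
mod = from_exported_program(exported_program)  # no squeeze node, so the strict check never fires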
Triage
- needs-triage
- bug
- frontend: pytorch
I think this one has been resolved by #18478.
cc @tlopex