torch-mlir
E2E test framework does not detect when annotation does not match input tensors
Below is an example of a test that passes in torch-mlir despite having an annotation that does not match the input tensors: both the size and the dtype of the tensors differ from what the annotation declares. This test is inspired by an actual test I saw in a recent PR.
```python
import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class DivBugModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2], torch.float32, True),   # annotation says shape [2], float32
        ([1], torch.float32, True),   # annotation says shape [1], float32
    ])
    def forward(self, lhs, rhs):
        return torch.ops.aten.div(lhs[0], rhs)


@register_test_case(module_factory=lambda: DivBugModule())
def DivBugModule_basic(module, tu: TestUtils):
    # Actual inputs are shape (3,) and (1,), dtype int64 -- neither the
    # shapes nor the dtypes match the annotation above, yet the test passes.
    module.forward(tu.randint(3, low=1, high=10), tu.randint(1, low=1, high=10))
```
However, in most cases, tests do fail when the annotation does not match the input. It would be nice if the error message produced by such tests mentioned that the failure could be due to the annotation not matching the input tensors.
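As one possible direction, the framework could validate each input against its annotation before running the test and emit a pointed error message. The sketch below is only illustrative: `check_annotations` is a hypothetical helper, and it assumes the annotation format used by `annotate_args` (a leading `None` for `self`, then `(shape, dtype, has_value_semantics)` tuples, with `-1` marking a dynamic dimension).

```python
import torch


def check_annotations(annotations, args):
    """Hypothetical pre-flight check comparing @annotate_args entries
    against the actual input tensors passed to forward()."""
    # Skip the leading None entry, which stands for `self`.
    for i, (ann, arg) in enumerate(zip(annotations[1:], args)):
        if ann is None:
            continue
        shape, dtype, _ = ann
        actual_shape = list(arg.shape)
        # A dimension annotated as -1 is dynamic and matches any size.
        rank_ok = len(shape) == len(actual_shape)
        dims_ok = rank_ok and all(
            d == -1 or d == s for d, s in zip(shape, actual_shape))
        if not dims_ok or arg.dtype != dtype:
            raise ValueError(
                f"Argument {i}: annotation (shape={shape}, dtype={dtype}) "
                f"does not match input tensor (shape={actual_shape}, "
                f"dtype={arg.dtype}). This may mean the @annotate_args "
                "annotation is out of sync with the test's input tensors.")
```

With the inputs from the example above, this check would fail loudly instead of letting the mismatch pass silently.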