
Issue with the aten.matmul op in eager_mode config


Hi, after some changes in upstream PyTorch, the test case https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir_e2e_test/test_suite/matmul.py#L54-L70, which was previously passing for eager_mode, is now failing (and the LTC variant, which was failing earlier, is now passing). The error in the eager_mode test is that a vector of shape [4] somehow gets converted to shape [1, 4], which causes the failure. An unsqueeze op is inserted into the Torch IR before the linalg conversion, and that creates this error.
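For context, an unsqueeze on a 1-D lhs is consistent with how torch.matmul is documented for vector-matrix inputs: a 1 is prepended to the vector's shape, a 2-D matmul is performed, and the prepended dimension is removed afterwards. A minimal sketch (plain PyTorch, not torch-mlir code) of that equivalence:

```python
import torch

# torch.matmul with a 1-D lhs and 2-D rhs: a 1 is prepended to the
# vector's shape (an unsqueeze), a 2-D matmul is done, then the added
# dimension is dropped again.
vec = torch.rand(4)
mat = torch.rand(4, 5)

expected = torch.matmul(vec, mat)                     # shape [5]
decomposed = vec.unsqueeze(0).matmul(mat).squeeze(0)  # same result

assert expected.shape == torch.Size([5])
assert torch.allclose(expected, decomposed)
```

So the unsqueeze itself is numerically harmless; the problem described below is that it now shows up as a separate op in the IR seen by eager mode.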

vivekkhandelwal1 avatar Aug 04 '22 06:08 vivekkhandelwal1

@makslevental PTAL

silvasean avatar Aug 04 '22 23:08 silvasean

I don't know what the change was that caused this, but it is true that it is due to upstream changes, i.e., this is being done somewhere in the dispatcher prior to interception by __torch_dispatch__.

MWE:

import torch

from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from torch_mlir.eager_mode import torch_mlir_tensor

class MatmulVecMat(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.float32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)

module = MatmulVecMat()
t = torch_mlir_tensor.TorchMLIRTensor(torch.rand(4))
u = torch_mlir_tensor.TorchMLIRTensor(torch.rand(4, 5))
print(module.forward(t, u))

Put a breakpoint at eager_mode/torch_mlir_tensor.py#L120 like so:

    @classmethod
    def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
        import pdb
        pdb.set_trace()
        requires_grad = check_requires_grad(*args, **kwargs)
    ...

and then the first break is:

/home/mlevental/dev_projects/bragghls/venv/bin/python /home/mlevental/dev_projects/bragghls/examples/demo.py 

> /home/mlevental/dev_projects/bragghls/venv/lib/python3.10/site-packages/torch_mlir/eager_mode/torch_mlir_tensor.py(122)__torch_dispatch__()
-> requires_grad = check_requires_grad(*args, **kwargs)
(Pdb) func
<OpOverload(op='aten.unsqueeze', overload='default')>
(Pdb) 

So you can see that the unsqueeze is the first op our eager mode sees, prior to any intervention/munging/transformation by our code. Note that you can also put a breakpoint at eager_mode/torch_mlir_tensor.py#L102 to verify that "numpy"ing the tensor in the constructor doesn't change the shape either.

The answer to the problem of the test divergence is that there's no way to enforce/require/depend on both these paths being structurally equivalent: the schemas that torch_ods_gen.py implements against are at the Tracer dispatch key level[^1], while the ops observed in __torch_dispatch__ are at the Python key level, i.e., at a lower level of abstraction. So there are a lot of ops (and decompositions) that pop up in __torch_dispatch__ that will never be seen in a TorchScript graph (let alone an FX graph). Cf. @Chillee's epic diagram of all of the paths to _convolution.
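As a sketch of the two levels diverging, a TorchDispatchMode (a stand-alone repro using plain PyTorch, not torch-mlir code) can log what actually reaches the Python key for this exact matmul. Composite ops like matmul decompose before that point, so the log shows the unsqueeze even though a TorchScript graph would record a single aten::matmul:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Log every aten op that reaches the Python dispatch key. matmul is a
# composite op, so by the time dispatch gets here it has already been
# decomposed; for a vec @ mat call the unsqueeze shows up explicitly.
class LogOps(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.seen.append(str(func))
        # Re-dispatch to the real kernel (the mode is popped here,
        # so this doesn't recurse).
        return func(*args, **(kwargs or {}))

with LogOps() as mode:
    torch.matmul(torch.rand(4), torch.rand(4, 5))

print(mode.seen)
```

This is the same observation as the pdb session above, just without torch-mlir in the loop: the divergence is introduced by the dispatcher itself, below the level the TorchScript-derived schemas describe.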

Hence, the solution is probably to maintain a separate set of tests for eager mode (at least for these divergences). Presumably the fact that the tests hadn't diverged until now (and this divergence is trivial) lends credence to the idea that this path is valid, the code is correct, and the numerics can be relied upon (and thus being 100% in sync with the conventional path isn't necessary anymore).

cc @powderluv

[^1]: JIT ops are registered through the dispatcher, not statically: https://github.com/pytorch/pytorch/blob/f50a248a5eacb9a9aa475a9e610486aea136e4f5/aten/src/ATen/core/dispatch/Dispatcher.cpp#L155-L175 ; https://github.com/pytorch/pytorch/blob/bfebf254dd92f3ed35154597166e7e71fb04f31b/tools/autograd/templates/TraceType.cpp#L34-L36

makslevental avatar Aug 05 '22 02:08 makslevental

Hi @makslevental, should we close this issue now?

vivekkhandelwal1 avatar Dec 01 '22 12:12 vivekkhandelwal1

I don't know, did it disappear? My earlier response is basically a "can't fix" type response.

makslevental avatar Dec 01 '22 20:12 makslevental

> I don't know, did it disappear? My earlier response is basically a "can't fix" type response.

No, it's still there in the eager mode xfail set: https://github.com/llvm/torch-mlir/blob/main/e2e_testing/xfail_sets.py#L21.

vivekkhandelwal1 avatar Dec 02 '22 06:12 vivekkhandelwal1

Closing this issue since eager_mode support has now been removed from torch-mlir (https://github.com/llvm/torch-mlir/pull/1697).

vivekkhandelwal1 avatar Dec 09 '22 14:12 vivekkhandelwal1