
(TorchToLinalg) Support for lowering torch.aten.as_strided

Status: Open · monorimet opened this issue 7 months ago · 6 comments

This was previously tracked in pull request https://github.com/llvm/torch-mlir/issues/1683, which was closed. That PR may serve as a reference to some degree, but it is quite outdated and likely not a reliable starting point.

Torch versions 2.6.0+ frequently emit this op in model exports through dynamo, with no option to decompose it at FX graph construction time. The best way to fix forward seems to be implementing the lowering for the op explicitly, despite the difficulty of doing so.
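For reference, the op's semantics are a strided gather over the input's underlying storage: each output element out[i0, ..., in] reads flat_input[offset + sum_k(i_k * stride_k)]. A minimal eager-mode illustration (semantics only, not the lowering), using the same parameters as the test case below:

import torch

x = torch.arange(25.0)
# out[i, j] = x[1 + 3*i + 3*j]
y = torch.as_strided(x, size=(2, 2), stride=(3, 3), storage_offset=1)
print(y)  # tensor([[1., 4.], [4., 7.]])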

The previous PR mentioned above supplies a test case for this op:

# ==============================================================================

import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class AsStridedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)


@register_test_case(module_factory=lambda: AsStridedModule())
def AsStridedModule_basic(module, tu: TestUtils):
    # The annotated input is rank 2, and a 5x5 tensor covers the largest
    # strided index: offset 1 + 1*3 + 1*3 = 7 < 25 elements.
    module.forward(tu.rand(5, 5))

# ==============================================================================
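For anyone attempting this, a plain-Python mirror of what the lowering must compute may be useful: as_strided is conceptually a gather from the flattened input, which would map onto a linalg.generic over the output whose body linearizes the index and extracts from the input. This is only a sketch of the semantics (rank >= 2, for brevity), not the TorchToLinalg implementation:

import torch

def as_strided_ref(x, size, stride, offset):
    # out[idx] = flat[offset + dot(idx, stride)], gathered elementwise.
    flat = x.reshape(-1)
    out = torch.empty(size, dtype=x.dtype)
    for idx in torch.cartesian_prod(*[torch.arange(s) for s in size]):
        lin = offset + sum(int(i) * st for i, st in zip(idx, stride))
        out[tuple(idx.tolist())] = flat[lin]
    return out

x = torch.rand(5, 5)
assert torch.equal(as_strided_ref(x, (2, 2), (3, 3), 1),
                   torch.as_strided(x, (2, 2), (3, 3), 1))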

Running the following script with torch >= 2.7.0 shows how relatively trivial ops like chunk end up relying on as_strided in torch dialect IR when exported via torch-mlir's fx.export_and_import.

import torch

from torch_mlir import fx

N_CHUNK = 6


class ChunkModule(torch.nn.Module):
    def forward(self, x):
        return x.chunk(N_CHUNK, dim=1)

    def sample_inputs(self):
        return torch.rand([1, N_CHUNK, 2048])  # dims here are arbitrary


module = ChunkModule().eval()
export_output = fx.export_and_import(module, module.sample_inputs(), output_type="torch")
export_output.dump()

Resulting torch dialect IR:

module {
  func.func @main(%arg0: !torch.vtensor<[1,6,2048],f32>) -> (!torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>) {
    %int10240 = torch.constant.int 10240
    %int8192 = torch.constant.int 8192
    %int6144 = torch.constant.int 6144
    %int4096 = torch.constant.int 4096
    %int0 = torch.constant.int 0
    %int12288 = torch.constant.int 12288
    %int1 = torch.constant.int 1
    %int2048 = torch.constant.int 2048
    %0 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %3 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %4 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %5 = torch.aten.as_strided %arg0, %3, %4, %int2048 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %6 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %7 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %8 = torch.aten.as_strided %arg0, %6, %7, %int4096 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %9 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %10 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %11 = torch.aten.as_strided %arg0, %9, %10, %int6144 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %12 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %13 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %14 = torch.aten.as_strided %arg0, %12, %13, %int8192 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %15 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %16 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %17 = torch.aten.as_strided %arg0, %15, %16, %int10240 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    return %2, %5, %8, %11, %14, %17 : !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>
  }
}
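Note the pattern in this dump: every as_strided reads a [1,1,2048] window with strides [12288, 2048, 1], which are exactly the natural contiguous strides of the [1,6,2048] input, at offsets k * 2048 for k = 0..5. When the requested strides match the input's own strides like this, the op degenerates to a contiguous slice (expressible as tensor.extract_slice), which suggests one possible fast path; the general lowering still has to handle arbitrary strides, e.g. as a gather.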

monorimet · May 19 '25

Hi! I'd like to work on implementing the lowering from torch.aten.as_strided to the linalg dialect. This would be my first contribution to the project. Would it be okay if I took this on?

gmalasan · Jun 09 '25

By all means, I'd take a look at any PRs submitted for this. It is, however, not an entry-level task and I don't recommend it for a first-time contributor. @gmalasan

monorimet · Jun 10 '25

It may help to enumerate the cases where as_strided is used in PyTorch lowerings and decompositions. chunk (the example in the OP) is just one of many sources, and is likely one of the simpler use cases.
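As a starting point for that enumeration, one could run torch.export followed by run_decompositions() and scan the graph for as_strided call sites. A hedged sketch; which ops decompose to as_strided depends on the torch version and decomposition table (torch-mlir's fx.export_and_import applies a decomposition step internally), so on some versions you may see split/slice ops instead:

import torch

class ChunkModule(torch.nn.Module):
    def forward(self, x):
        return x.chunk(6, dim=1)

ep = torch.export.export(ChunkModule(), (torch.rand(1, 6, 2048),))
ep = ep.run_decompositions()  # default core-ATen decomposition table
strided = [n for n in ep.graph.nodes
           if n.op == "call_function"
           and n.target == torch.ops.aten.as_strided.default]
print(f"{len(strided)} aten.as_strided call sites")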

monorimet · Jun 10 '25

> By all means, I'd take a look at any PRs submitted for this. It is, however, not an entry-level task and I don't recommend it for a first-time contributor. @gmalasan

I agree with @monorimet. This is not a starter task, and I would also recommend that @gmalasan not pick it up. There are several other open issues, @gmalasan; you should pick one of those instead.

vivekkhandelwal1 · Jun 11 '25

Ok, thanks for the feedback! I'll focus on finding first-time contributor issues instead. Thanks again!

gmalasan · Jun 11 '25

And if you wouldn't mind, are there any open issues you'd recommend I check out?

gmalasan · Jun 11 '25