(TorchToLinalg) Support for lowering torch.aten.as_strided
This was tracked previously in pull request https://github.com/llvm/torch-mlir/issues/1683, which was closed. That PR may serve as a reference to some degree, but it is quite outdated and not a reliable starting point.
Torch versions 2.6.0+ emit this op frequently in model exports through dynamo, with no option to decompose it at FX graph construction. The best way to fix forward appears to be implementing the lowering for this op explicitly, despite the difficulty of doing so.
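For reference, as_strided reinterprets the input's underlying storage: element (i, j, ...) of the result reads the flattened storage at storage_offset + i*stride[0] + j*stride[1] + .... A minimal illustration:

import torch

x = torch.arange(9.0)  # storage: [0., 1., ..., 8.]
# Element (i, j) of the view reads storage[1 + 3*i + 1*j].
y = torch.as_strided(x, size=(2, 2), stride=(3, 1), storage_offset=1)
print(y)  # tensor([[1., 2.], [4., 5.]])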
The previous PR mentioned above supplies a test case for this op:
# Test-framework imports, as used throughout torch-mlir's e2e test suite.
import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

# ==============================================================================


class AsStridedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),  # rank must match the rank-3 input below
    ])
    def forward(self, x):
        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)


@register_test_case(module_factory=lambda: AsStridedModule())
def AsStridedModule_basic(module, tu: TestUtils):
    x = torch.randn(25, 1, 1)
    print(x)
    print(torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1))
    module.forward(x)


# ==============================================================================
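Since tensors at the linalg level carry no explicit stride/offset view, a lowering most likely has to materialize the result as a gather over the flattened input. The following reference implementation of that index arithmetic in plain PyTorch (as_strided_ref is a hypothetical helper, for illustration only) computes exactly what a linalg.generic-based lowering would:

import torch

def as_strided_ref(x, size, stride, storage_offset):
    # Reference sketch: out[i0, i1, ...] = flat(x)[storage_offset + sum_k(i_k * stride_k)].
    # Assumes a contiguous input so flatten() matches the underlying storage order.
    flat = x.flatten()
    idx = torch.full(size, storage_offset, dtype=torch.int64)
    for dim, (sz, st) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[dim] = sz
        idx = idx + torch.arange(sz).view(shape) * st
    return flat[idx]

x = torch.randn(25, 1, 1)
assert torch.equal(as_strided_ref(x, (2, 2), (3, 3), 1),
                   torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1))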
Running the following script with torch >= 2.7.0 shows how even relatively trivial ops like chunk end up relying on as_strided in torch dialect IR when exported via torch-mlir's fx export_and_import.
import torch

from torch_mlir import fx

N_CHUNK = 6


class ChunkModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.chunk(N_CHUNK, dim=1)

    def sample_inputs(self):
        return torch.rand([1, N_CHUNK, 2048])  # dims here are arbitrary


module = ChunkModule().eval()
export_output = fx.export_and_import(module, module.sample_inputs(), output_type="torch")
export_output.dump()
Resulting torch dialect IR:
module {
  func.func @main(%arg0: !torch.vtensor<[1,6,2048],f32>) -> (!torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>) {
    %int10240 = torch.constant.int 10240
    %int8192 = torch.constant.int 8192
    %int6144 = torch.constant.int 6144
    %int4096 = torch.constant.int 4096
    %int0 = torch.constant.int 0
    %int12288 = torch.constant.int 12288
    %int1 = torch.constant.int 1
    %int2048 = torch.constant.int 2048
    %0 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %3 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %4 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %5 = torch.aten.as_strided %arg0, %3, %4, %int2048 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %6 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %7 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %8 = torch.aten.as_strided %arg0, %6, %7, %int4096 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %9 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %10 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %11 = torch.aten.as_strided %arg0, %9, %10, %int6144 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %12 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %13 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %14 = torch.aten.as_strided %arg0, %12, %13, %int8192 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    %15 = torch.prim.ListConstruct %int1, %int1, %int2048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %16 = torch.prim.ListConstruct %int12288, %int2048, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %17 = torch.aten.as_strided %arg0, %15, %16, %int10240 : !torch.vtensor<[1,6,2048],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2048],f32>
    return %2, %5, %8, %11, %14, %17 : !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>, !torch.vtensor<[1,1,2048],f32>
  }
}
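Note that in this particular export every as_strided call is just a contiguous slice along dim 1 of the input, which is easy to confirm (illustrative check, assuming a contiguous input):

import torch

x = torch.rand(1, 6, 2048)  # contiguous, with strides (12288, 2048, 1)
for k in range(6):
    strided = torch.ops.aten.as_strided(x, (1, 1, 2048), (12288, 2048, 1), k * 2048)
    assert torch.equal(strided, x[:, k:k + 1, :])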
Hi! I'd like to work on implementing the lowering from torch.aten.as_strided to the linalg dialect. This would be my first contribution to the project. Would it be okay if I took this on?
By all means, I'd take a look at any PRs submitted for this. It is, however, not an entry-level task and I don't recommend it for a first-time contributor. @gmalasan
It may help to enumerate the cases where as_strided is used in PyTorch lowerings and decompositions. Chunk (the example in the OP) is just one of many sources and is likely one of the simpler use cases.
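As a rough starting point for such a survey, one could export candidate ops through torch.export with decompositions applied (roughly what fx.export_and_import does under the hood) and count the as_strided nodes; count_as_strided below is a hypothetical helper, not existing tooling:

import torch

def count_as_strided(mod: torch.nn.Module, *args) -> int:
    # Export and run default decompositions, then count as_strided
    # call_function nodes in the resulting FX graph.
    ep = torch.export.export(mod, args).run_decompositions()
    return sum(
        1
        for n in ep.graph.nodes
        if n.op == "call_function" and "as_strided" in str(n.target)
    )

class Chunk(torch.nn.Module):
    def forward(self, x):
        return x.chunk(6, dim=1)

print(count_as_strided(Chunk(), torch.rand(1, 6, 2048)))  # expected: 6, matching the IR above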
I agree with @monorimet. This is not a starter task, and I would also recommend that @gmalasan not pick this up. There are several other open issues, @gmalasan; you should pick any one of them.
Ok, thanks for the feedback! I'll try to focus on finding first-time contributor issues to contribute to. Thanks again!
And if you wouldn't mind, are there any issues that are available that you'd recommend I check out?