
[fx] add all_reduce test

Opened by linuxlonelyeagle · 5 comments

linuxlonelyeagle, Jan 22 '24

@stellaraccident @ramiro050 Hi, happy to introduce the first PR to support communication operators. The CI failed because the nightly-build torch has a different signature from the stable-build torch. The stable build's signature is c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor), whereas the nightly build's signature is _c10d_functional::all_reduce : (Tensor, str, str) -> (Tensor). In our testing, the stable build's signature is the correct one.

Do you have any ideas on how to fix this?

qingyunqu, Jan 23 '24

I don't think I've ever seen this particular issue. We do have a place where we check the PyTorch version because of differences in ops supported:

https://github.com/llvm/torch-mlir/blob/77ae56337dbf95eb809f4a7d218a9fb3dc1f41b0/projects/pt1/python/torch_mlir/dynamo.py#L69

but the issue here is that the ODS for the ops is hard-coded. One simple workaround would be to modify the ODS generator to output two versions of the op, appending a Stable or Nightly suffix to the op names. Once things converge upstream, we can get rid of the workaround.
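For illustration, here is a minimal sketch of how such a version gate could look. The helper name is hypothetical, not torch-mlir API; it relies only on the fact that nightly wheels carry a `.dev` segment in their version string:

```python
# Hedged sketch: distinguish a nightly/dev PyTorch build from a stable one
# by its version string, so an importer could select the matching
# all_reduce signature. The helper name is hypothetical.
def is_nightly_torch(torch_version: str) -> bool:
    # Nightly wheels embed a ".dev<date>" segment in the version,
    # e.g. "2.3.0.dev20240109+cpu"; stable wheels like "2.1.2+cpu" do not.
    return ".dev" in torch_version


if __name__ == "__main__":
    print(is_nightly_torch("2.3.0.dev20240109+cpu"))  # True
    print(is_nightly_torch("2.1.2+cpu"))              # False
```

In practice one would read the string from `torch.__version__`, similar to the existing check in `dynamo.py` linked above.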

ramiro050, Jan 24 '24

> I don't think I've ever seen this particular issue. We do have a place where we check the PyTorch version because of differences in ops supported:
>
> https://github.com/llvm/torch-mlir/blob/77ae56337dbf95eb809f4a7d218a9fb3dc1f41b0/projects/pt1/python/torch_mlir/dynamo.py#L69
>
> but the issue here is that the ODS for the ops is hard-coded. One simple workaround would be to modify the ODS generator to output two versions of the op, appending a Stable or Nightly suffix to the op names. Once things converge upstream, we can get rid of the workaround.

In my local test with torch==2.3.0.dev20240109+cpu and an unmodified GeneratedTorchOps.td, it generates:

test_import_frozen_exported_program
-----------------------------------
module {
  func.func @main(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> {
    %str = torch.constant.str "sum"
    %str_0 = torch.constant.str ""
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %int4 = torch.constant.int 4
    %1 = torch.operator "torch.c10d_functional.all_reduce"(%arg0, %str, %str_0, %0, %int4) : (!torch.vtensor<[4],f32>, !torch.str, !torch.str, !torch.list<int>, !torch.int) -> !torch.vtensor<[4],f32>
    %2 = torch.operator "torch.c10d_functional.wait_tensor"(%1) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32>
    return %2 : !torch.vtensor<[4],f32>
  }
}

It seems that torch.export always generates the op with the signature c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor), even on nightly-build torch.

So how about adding another td file, CommunicationOps.td, written by hand (not generated by torch_ods_gen.py) as a workaround, and merging it into GeneratedTorchOps.td once things converge upstream?

qingyunqu, Jan 24 '24

> It seems that torch.export always generates the op with the signature c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor), even on nightly-build torch.

Interesting. Yeah, your proposed solution seems fine to me. No need to add an extra td file. I would just place it here next to this op:

https://github.com/llvm/torch-mlir/blob/ac8975ea1276e1f53f1bb3eaedc26f32252c91a1/include/torch-mlir/Dialect/Torch/IR/TorchOps.td#L1085-L1087
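For reference, a hand-written ODS entry matching the stable-build signature might look roughly like this. It is sketched after the conventions in GeneratedTorchOps.td; the exact class name, argument names, and trait list are assumptions, not the definition that eventually landed:

```tablegen
// Hedged sketch of a hand-written ODS entry for the stable-build signature
// c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor).
// Class name, operand names, and traits follow the style used in
// GeneratedTorchOps.td; treat this as illustrative only.
def Torch_C10dFunctionalAllReduceOp : Torch_Op<"c10d_functional.all_reduce", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Op for `c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_StringType:$reduceOp,
    Torch_StringType:$tag,
    AnyTorchListOfTorchIntType:$ranks,
    Torch_IntType:$group_size
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
}
```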

ramiro050, Jan 24 '24

> > It seems that torch.export always generates the op with the signature c10d_functional::all_reduce : (Tensor, str, str, int[], int) -> (Tensor), even on nightly-build torch.
>
> Interesting. Yeah, your proposed solution seems fine to me. No need to add an extra td file. I would just place it here next to this op:
>
> https://github.com/llvm/torch-mlir/blob/ac8975ea1276e1f53f1bb3eaedc26f32252c91a1/include/torch-mlir/Dialect/Torch/IR/TorchOps.td#L1085-L1087

OK, I'll place it at the end of TorchOps.td.

qingyunqu, Jan 25 '24