torch-mlir

Can the ConvertAten_IndexPutImplOp pattern handle the case where the indices are bool?

Open • lingfengqiu opened this issue 1 year ago • 1 comment

```python
import torch

from typing import List
import numpy as np
import time
import torch._dynamo as dynamo

class fw_graph_0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg1_1, bitwise_or, scalar_tensor):
        # aten.index_put with a single boolean mask as the index
        index_put = torch.ops.aten.index_put(arg1_1, [bitwise_or], scalar_tensor)
        return index_put

device = "cuda"

mod = fw_graph_0().to(device)

optimized_mod = torch.compile(mod, backend="inductor")

# "grace" is a custom backend that is not defined in this snippet;
# this line overrides the inductor-compiled module above.
optimized_mod = torch.compile(mod, backend="grace")

arg1_1 = torch.ones((5, 1), dtype=torch.int64).to(device)
bitwise_or = torch.ones((5, 1), dtype=torch.bool).to(device)
scalar_tensor = torch.tensor(0, dtype=torch.int64).to(device)

warmups, repetitions = 10, 100
infer_sum = warmups + repetitions

times_grace = np.zeros(infer_sum)
times_golden = np.zeros(infer_sum)

for idx in range(infer_sum):
    T1 = time.perf_counter()
    res = optimized_mod(arg1_1, bitwise_or, scalar_tensor)
    if device == "musa":
        torch.musa.synchronize()
    elif device == "cuda":
        torch.cuda.synchronize()
    T2 = time.perf_counter()
    times_grace[idx] = (T2 - T1) * 1000
```
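For comparison, the eager result that the compiled module should reproduce can be computed on CPU without a GPU or the custom backend. This is a minimal sketch using the same inputs; the `torch.where` line is only an equivalent formulation for the case where the mask has the same shape as the input, not something the original report uses.

```python
import torch

# Same inputs as the repro, on CPU so no GPU or custom backend is needed.
arg1_1 = torch.ones((5, 1), dtype=torch.int64)
bitwise_or = torch.ones((5, 1), dtype=torch.bool)
scalar_tensor = torch.tensor(0, dtype=torch.int64)

# Eager reference: the boolean mask selects every True position and writes the scalar there.
expected = torch.ops.aten.index_put(arg1_1, [bitwise_or], scalar_tensor)

# When the mask has the same shape as the input, this is equivalent to torch.where.
alt = torch.where(bitwise_or, scalar_tensor, arg1_1)

print(expected.flatten().tolist())  # [0, 0, 0, 0, 0]
```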

This case fails. Does torch-mlir support the index_put op with bool-typed indices?

Error information:

```
MLIRError: Failure while executing pass pipeline:

error: "/host/workspace/MUSA-Megatron-DeepSpeed/test_op/op_case/test_index_put_min.py":14:0: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
note: "/host/workspace/MUSA-Megatron-DeepSpeed/test_op/op_case/test_index_put_min.py":14:0: see current operation: %50 = "tm_tensor.scatter"(%44, %49, %0) <{operandSegmentSizes = array<i32: 2, 1>, unique_indices = false}> ({ ^bb0(%arg3: i64, %arg4: i64): "tm_tensor.yield"(%arg3) : (i64) -> () }) : (tensor<1xi64>, tensor<5x2xi32>, tensor<5x1xi64>) -> tensor<5x1xi64>
```
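One reading of the error: the `tm_tensor.scatter` op appears to receive integer indices of shape `5x2` (five selected positions, two coordinates each, as a (5, 1) all-True mask would produce) but an update tensor of shape `1`, so the scalar value seems not to have been broadcast to the number of selected positions. The sketch below shows those semantics in eager PyTorch: a boolean mask is turned into integer coordinates with `nonzero`, and the value is expanded so dim 0 of the updates equals the number of indices. The helper `index_put_via_nonzero` is hypothetical and illustrates the expected behavior only; it is not the torch-mlir lowering.

```python
import torch

def index_put_via_nonzero(x, mask, value):
    """Hypothetical helper: express a boolean-mask index_put with integer indices."""
    # One 1-D index tensor per mask dimension; each has length = number of True entries.
    coords = mask.nonzero(as_tuple=True)
    num_selected = coords[0].numel()
    # Broadcast the (scalar) value so dim 0 of the updates equals the number of indices.
    updates = value.expand(num_selected)
    return torch.ops.aten.index_put(x, list(coords), updates)

x = torch.ones((5, 1), dtype=torch.int64)
mask = torch.ones((5, 1), dtype=torch.bool)
value = torch.tensor(0, dtype=torch.int64)

coords = mask.nonzero()
print(coords.shape)  # torch.Size([5, 2]): 5 selected positions x 2 coordinates
print(torch.equal(index_put_via_nonzero(x, mask, value),
                  torch.ops.aten.index_put(x, [mask], value)))  # True
```

The `nonzero` output shape matches the `tensor<5x2xi32>` indices in the failing op, which is why an update operand with a leading dimension of 1 would not line up at dim #0.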

lingfengqiu, Feb 21 '24 09:02

I'm seeing a similar failure and filed https://github.com/llvm/torch-mlir/issues/3433 with some additional details.

ScottTodd, Jun 07 '24 18:06