[Question] make_block_ptr different offsets for each row/column

Open Rikyf3 opened this issue 1 year ago • 2 comments

Hi. I would like to load into memory a block that has a different offset for each row.

For example, load the elements in square brackets as a 2x2 block:

1  2  [3 4]  5
6  [7 8]  9  10

I tried using the following toy code:

import torch
import triton
import triton.language as tl


@triton.jit
def func(
        x_ptr, offsets_ptr,
        stride_xm, stride_xk,
        stride_om, stride_ok,
        BLOCK_SIZE: tl.constexpr,
        ):
    pid = tl.program_id(0)

    block_ptr = tl.make_block_ptr(base=x_ptr, shape=(10, 10), strides=(stride_xm, stride_xk),
                                  offsets=(pid * BLOCK_SIZE, 0), block_shape=(BLOCK_SIZE, BLOCK_SIZE),
                                  order=(1, 0))

    block_ptr_offsets = tl.make_block_ptr(base=offsets_ptr, shape=(2, 2), strides=(stride_om, stride_ok),
                                          offsets=(pid * BLOCK_SIZE, 0), block_shape=(BLOCK_SIZE, BLOCK_SIZE),
                                          order=(1, 0))

    offsets = tl.load(block_ptr_offsets)

    x = tl.load(block_ptr + offsets)


def main():
    x = torch.randn(10, 10, device="cuda")
    offsets = torch.tensor([[0, 0], [1, 1]], device="cuda", dtype=torch.int32)

    func[(1, )](x, offsets, x.stride(0), x.stride(1), offsets.stride(0), offsets.stride(1), 2)


if __name__ == "__main__":
    main()

The last line of the kernel fails with the following error:

error: invalid tensor element type: 'tensor<2x2xf32>'
python: /source/llvm-project/mlir/include/mlir/IR/StorageUniquerSupport.h:181: static ConcreteT mlir::detail::StorageUserBase<mlir::RankedTensorType, mlir::TensorType, mlir::detail::RankedTensorTypeStorage, mlir::detail::TypeUniquer, Trait>::get(mlir::MLIRContext *, Args &&...) [ConcreteT = mlir::RankedTensorType, BaseT = mlir::TensorType, StorageT = mlir::detail::RankedTensorTypeStorage, UniquerT = mlir::detail::TypeUniquer, Traits = <Trait>, Args = <llvm::ArrayRef<long> &, mlir::Type &, mlir::Attribute &>]: Assertion `succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...))' failed

Is this the right approach for what I need to do? Right now I have to load an offsets matrix that is as big as block_ptr, even though I only need to adjust the offset of each row. Does a better way exist? How can I solve the above issue?

Thanks.

Rikyf3 avatar Feb 23 '24 08:02 Rikyf3

Same question, +1.

xinji1 avatar Apr 09 '24 03:04 xinji1

You can try this: make the output size the same as the input.

ihaterecursion avatar May 08 '24 01:05 ihaterecursion
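
For readers following the thread, here is a minimal sketch of the address arithmetic behind the 2x2 example in the question. It is plain Python over a flat row-major list, not a Triton kernel and not a fix confirmed by anyone in this thread. The helper name `gather_rows_with_offsets` and its parameters are hypothetical. The point it illustrates is that one offset per row (a 1-D vector) is enough to describe the block, so a full offsets matrix is not strictly needed. As an assumption, the analogous Triton code would build a tensor of ordinary pointers from `x_ptr`, `tl.arange`, and the strides, then pass that to `tl.load` instead of a block pointer.

```python
def gather_rows_with_offsets(flat, stride_m, stride_k, row_offsets, block_cols):
    """Collect a block whose rows start at different columns.

    flat        : flat, row-major storage of the matrix (stands in for memory)
    stride_m    : elements to skip to move down one row
    stride_k    : elements to skip to move right one column
    row_offsets : starting column for each row of the block (one entry per row)
    block_cols  : number of columns to read per row
    """
    block = []
    for i, col_start in enumerate(row_offsets):
        row = []
        for j in range(block_cols):
            # Per-element address: row term plus a column term shifted per row.
            addr = i * stride_m + (col_start + j) * stride_k
            row.append(flat[addr])
        block.append(row)
    return block


# The 2x5 matrix from the question, stored row-major:
# 1  2  [3 4]  5
# 6  [7 8]  9  10
flat = list(range(1, 11))

result = gather_rows_with_offsets(
    flat, stride_m=5, stride_k=1, row_offsets=[2, 1], block_cols=2
)
print(result)  # [[3, 4], [7, 8]]
```

In a kernel, the two nested loops would become broadcasted index tensors (rows along one axis, columns along the other), and a mask would be needed wherever a shifted row could run past the matrix edge.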