[Question] make_block_ptr different offsets for each row/column
Hi. I would like to load a block from memory where each row has a different column offset.
For example, load the elements in square brackets as a 2x2 block:
1 2 [3 4] 5
6 [7 8] 9 10
I tried the following toy code:
import torch
import triton
import triton.language as tl

@triton.jit
def func(
    x_ptr, offsets_ptr,
    stride_xm, stride_xk,
    stride_om, stride_ok,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    block_ptr = tl.make_block_ptr(base=x_ptr, shape=(10, 10), strides=(stride_xm, stride_xk),
                                  offsets=(pid * BLOCK_SIZE, 0), block_shape=(BLOCK_SIZE, BLOCK_SIZE),
                                  order=(1, 0))
    block_ptr_offsets = tl.make_block_ptr(base=offsets_ptr, shape=(2, 2), strides=(stride_om, stride_ok),
                                          offsets=(pid * BLOCK_SIZE, 0), block_shape=(BLOCK_SIZE, BLOCK_SIZE),
                                          order=(1, 0))
    offsets = tl.load(block_ptr_offsets)
    x = tl.load(block_ptr + offsets)  # this line triggers the error below

def main():
    x = torch.randn(10, 10, device="cuda")
    offsets = torch.tensor([[0, 0], [1, 1]], device="cuda", dtype=torch.int32)
    func[(1, )](x, offsets, x.stride(0), x.stride(1), offsets.stride(0), offsets.stride(1), 2)

if __name__ == "__main__":
    main()
The kernel launch fails with the following error:
error: invalid tensor element type: 'tensor<2x2xf32>'
python: /source/llvm-project/mlir/include/mlir/IR/StorageUniquerSupport.h:181: static ConcreteT mlir::detail::StorageUserBase<mlir::RankedTensorType, mlir::TensorType, mlir::detail::RankedTensorTypeStorage, mlir::detail::TypeUniquer, Trait>::get(mlir::MLIRContext *, Args &&...) [ConcreteT = mlir::RankedTensorType, BaseT = mlir::TensorType, StorageT = mlir::detail::RankedTensorTypeStorage, UniquerT = mlir::detail::TypeUniquer, Traits = <Trait>, Args = <llvm::ArrayRef<long> &, mlir::Type &, mlir::Attribute &>]: Assertion `succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...))' failed
Is this the right approach for what I am trying to do? Right now I have to load an offsets matrix that is as big as the block, even though I only need to adjust each row's offset; does a better way exist? How can I fix the error above?
Thanks.
Same question, +1.
You can try this; the output size is the same as the input.
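For reference, a minimal sketch of doing this with plain pointer arithmetic (a gather of 2D pointers) instead of make_block_ptr, which only accepts scalar offsets. The kernel name gather_rows, the col_start buffer of one starting column per row, and the 2x5 toy input are made-up for illustration, not from the original reply; no bounds masks are applied, so all loads are assumed in range.

import torch
import triton
import triton.language as tl

@triton.jit
def gather_rows(
    x_ptr, col_start_ptr, out_ptr,
    stride_xm, stride_xk,
    BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)   # row indices handled by this program
    cols = tl.arange(0, BLOCK_K)                   # column indices within the block
    col_start = tl.load(col_start_ptr + rows)      # one starting column per row
    # Build a (BLOCK_M, BLOCK_K) grid of pointers: each row begins at its own column.
    ptrs = x_ptr + rows[:, None] * stride_xm + (col_start[:, None] + cols[None, :]) * stride_xk
    x = tl.load(ptrs)
    # Store the gathered block contiguously, one BLOCK_K-wide row per input row.
    tl.store(out_ptr + rows[:, None] * BLOCK_K + cols[None, :], x)

def main():
    # Toy input matching the example above: row 0 keeps columns 2-3, row 1 keeps columns 1-2.
    x = torch.arange(1, 11, dtype=torch.float32, device="cuda").reshape(2, 5)
    col_start = torch.tensor([2, 1], device="cuda", dtype=torch.int32)
    out = torch.empty(2, 2, device="cuda", dtype=torch.float32)
    gather_rows[(1, )](x, col_start, out, x.stride(0), x.stride(1), BLOCK_M=2, BLOCK_K=2)
    print(out)  # expected: [[3., 4.], [7., 8.]]

if __name__ == "__main__":
    main()

With this layout the offsets buffer only needs one entry per row rather than a full block-sized matrix, since the per-column offsets are generated inside the kernel with tl.arange.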