
`tl.reshape` causes `[CUDA]: misaligned address`

Open · jansel opened this issue 1 year ago · 6 comments

When adding an extra size-1 dimension, tl.reshape(..., [XBLOCK, 1]) causes an error, while ...[:, None] works fine. This happens in the context of a larger program.
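VERSION 1 and VERSION 2 in the repro are meant to compute the same [XBLOCK, 1] block. As a rough illustration of the intended semantics, here is the same pair of operations in NumPy. This assumes NumPy is a fair stand-in, on the premise that Triton's reshape and [:, None] indexing agree with NumPy's for this case; the XBLOCK and xoffset values are arbitrary:

```python
import numpy as np

XBLOCK = 4
xoffset = 8  # arbitrary program offset, for illustration only

# 1D block of indices, analogous to xoffset + tl.arange(0, XBLOCK)
index_1d = xoffset + np.arange(XBLOCK)

expanded = index_1d[:, None]                  # VERSION 1 style
reshaped = np.reshape(index_1d, (XBLOCK, 1))  # VERSION 2 style

print(expanded.shape, reshaped.shape)      # (4, 1) (4, 1)
print(np.array_equal(expanded, reshaped))  # True
```

Both forms yield the same shape and values, so the two kernel versions should differ only in how the compiler lowers them, not in what they compute.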

Repro:

import torch
from torch import empty
from torch._dynamo.testing import rand_strided
import triton
import triton.language as tl


@triton.jit
def triton_(
    in_ptr0,
    in_ptr1,
    in_ptr2,
    in_ptr3,
    in_ptr4,
    out_ptr0,
    out_ptr1,
    XBLOCK: tl.constexpr,
    VERSION: tl.constexpr,
):
    xnumel = 1024
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex % 16
    r2 = rindex // 16
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (16 * x0) + (16384 * r2)), rmask & xmask).to(tl.int1)
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp4 = tl.load(
        in_ptr2 + (x0 + (1024 * r2)),
        rmask & xmask,
        eviction_policy="evict_last",
        other=0.0,
    ).to(tl.float32)
    tmp9 = tl.load(in_ptr3 + (r1 + (16 * x0) + (16384 * r2)), rmask & xmask, other=0.0)

    if VERSION == 0:
        # works loading with a 2D [XBLOCK, 1] block
        tmp10 = tl.load(in_ptr4 + (x0), xmask)
    elif VERSION == 1:
        # works loading with a 1D block and using [:, None] to expand
        index_1d = xoffset + tl.arange(0, XBLOCK)
        tmp10 = tl.load(in_ptr4 + index_1d, index_1d < xnumel)[:, None]
    elif VERSION == 2:
        # tl.reshape() causes a RuntimeError: Triton Error [CUDA]: misaligned address
        index_1d = xoffset + tl.arange(0, XBLOCK)
        tmp10 = tl.reshape(
            tl.load(in_ptr4 + index_1d, index_1d < xnumel),
            [XBLOCK, 1],
        )

    tmp3 = tmp2.to(tl.float32)
    tmp5 = 16.0
    tmp6 = tmp4 / tmp5
    tmp7 = tl.where(tmp0, tmp3, tmp6)
    tmp8 = tmp7.to(tl.float32)
    tmp11 = tmp9 - tmp10
    tmp12 = tmp8 * tmp11
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
    tmp19 = tl.where(rmask & xmask, tmp17, 0)
    tmp20 = tl.sum(tmp19, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp16, xmask)
    tl.store(out_ptr1 + (x0), tmp20, xmask)


arg0_1 = rand_strided(
    (8, 1024, 4, 4), (1024, 1, 0, 0), device="cuda:0", dtype=torch.float16
)
arg1_1 = rand_strided(
    (8, 1024, 4, 4), (16384, 16, 4, 1), device="cuda:0", dtype=torch.bool
)
arg2_1 = rand_strided((), (), device="cuda:0", dtype=torch.float32)
arg3_1 = rand_strided(
    (8, 1024, 4, 4), (16384, 16, 4, 1), device="cuda:0", dtype=torch.float32
)
arg4_1 = rand_strided(
    (1, 1024, 1, 1), (1024, 1, 1, 1), device="cuda:0", dtype=torch.float32
)
buf0 = empty((1024,), device="cuda", dtype=torch.float32)
buf1 = empty((1024,), device="cuda", dtype=torch.float32)
xblock = 1

for version in range(3):
    print(f"Running with VERSION={version}")
    triton_[(triton.cdiv(1024, xblock),)](
        arg1_1,
        arg2_1,
        arg0_1,
        arg3_1,
        arg4_1,
        buf0,
        buf1,
        XBLOCK=xblock,
        VERSION=version,
    )
    torch.cuda.synchronize()
    print("SUCCESS")

Output:

$ CUDA_LAUNCH_BLOCKING=1 python repro.py
Running with VERSION=0
SUCCESS
Running with VERSION=1
SUCCESS
Running with VERSION=2
Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 93, in <module>
    triton_[(triton.cdiv(1024, xblock),)](
  File "/home/jansel/conda/envs/pytorch/lib/python3.11/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
RuntimeError: Triton Error [CUDA]: misaligned address

This code is originally from TorchInductor compiling the following program:

def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3):
    div = torch.ops.aten.div.Scalar(expand, 16)
    where = torch.ops.aten.where.self(arg207_1, full, div)
    convert_element_type_43 = torch.ops.prims.convert_element_type.default(
        where, torch.float32
    )
    sum_2 = torch.ops.aten.sum.dim_IntList(convert_element_type_43, [0, 2, 3])
    sub = torch.ops.aten.sub.Tensor(convert_element_type_40, arg208_1)
    mul = torch.ops.aten.mul.Tensor(convert_element_type_43, sub)
    sum_3 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3])
    mul_1 = torch.ops.aten.mul.Tensor(sum_2, 0.0078125)
    unsqueeze = torch.ops.aten.unsqueeze.default(mul_1, 0)
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3)
    mul_2 = torch.ops.aten.mul.Tensor(sum_3, 0.0078125)
    mul_4 = torch.ops.aten.mul.Tensor(mul_2, mul_3)
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(mul_4, 0)
    unsqueeze_4 = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2)
    unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
    mul_6 = torch.ops.aten.mul.Tensor(sub, unsqueeze_5)
    sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_43, mul_6)
    sub_2 = torch.ops.aten.sub.Tensor(sub_1, unsqueeze_2)
    return (sub_2,)

This was generated by running the PyTorch minifier on a 1500+ node graph coming from selecsls42b.
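The kernel is a per-channel reduction. sum_2 and sum_3 reduce over dims [0, 2, 3] of an (8, 1024, 4, 4) tensor, so each of the 1024 channels reduces 8 * 4 * 4 = 128 elements, which matches xnumel = 1024 and rnumel = 128. Below is a plain-Python sanity check of the kernel's offset expressions against the shapes and strides passed to rand_strided. It assumes the launch order maps arg1_1 to in_ptr0 and arg0_1 (the expanded input with zero strides) to in_ptr2:

```python
SHAPE = (8, 1024, 4, 4)
STRIDES = (16384, 16, 4, 1)       # arg1_1 / arg3_1 (in_ptr0 / in_ptr3)
EXPAND_STRIDES = (1024, 1, 0, 0)  # arg0_1 (in_ptr2), an expanded tensor


def strided_offset(n, c, h, w, strides=STRIDES):
    """Element offset of [n, c, h, w] in a strided tensor."""
    return n * strides[0] + c * strides[1] + h * strides[2] + w * strides[3]


def kernel_offset(x0, rindex):
    """The kernel's offset expression: r1 + 16*x0 + 16384*r2."""
    r1 = rindex % 16   # flattened (h, w) position, 0..15
    r2 = rindex // 16  # batch index n, 0..7
    return r1 + (16 * x0) + (16384 * r2)


# One reduction row per channel covers dims 0, 2 and 3.
rnumel = SHAPE[0] * SHAPE[2] * SHAPE[3]
assert rnumel == 128

for x0 in (0, 1, 1023):
    for rindex in range(rnumel):
        n, hw = divmod(rindex, 16)
        h, w = divmod(hw, 4)
        assert kernel_offset(x0, rindex) == strided_offset(n, x0, h, w)
        # in_ptr2 is indexed with x0 + 1024*r2 because of the zero strides.
        assert x0 + 1024 * (rindex // 16) == strided_offset(n, x0, h, w, EXPAND_STRIDES)

print("kernel indexing matches the strided layout")
```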

jansel, Dec 22 '23 04:12

I think @plotfi is looking at this!

manman-ren, Jan 24 '24 01:01

I have some findings on this.

1. At https://github.com/openai/triton/blob/0327b9d32db6d1d63d207ccab722bd45e00a6678/python/src/llvm.cc#L173 Triton enables the SLPVectorizer with an empty target machine in order to get wider vectors. This causes sequences of IR instructions like:

  %i123 = load i8, ptr addrspace(3) getelementptr ([0 x i8], ptr addrspace(3) @global_smem, i64 0, i64 97), align 1, !dbg !19
  %.not97 = icmp eq i8 %i123, 0, !dbg !19
  %i124 = load i8, ptr addrspace(3) getelementptr ([0 x i8], ptr addrspace(3) @global_smem, i64 0, i64 98), align 1, !dbg !19

to be turned into:

  %1 = load <16 x i8>, ptr addrspace(3) getelementptr ([0 x i8], ptr addrspace(3) @global_smem, i64 0, i64 97), align 1

I tried extracting the IR from Triton before the SLPVectorizer and running it through LLVM opt with and without the target (passing -mtriple=nvptx64-nvidia-cuda). When the target is passed, opt does not produce the vector loads that ultimately result in the misaligned access.

By the way, the problematic loads originate from Python line 31, column 80: tmp0 = tl.load(in_ptr0 + (r1 + (16 * x0) + (16384 * r2)), rmask & xmask).to(tl.int1). Before these loads are materialized, they come from a triton_gpu.convert_layout that is part of the layout conversions needed to support the reshape operation.
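A toy model of the merge described above (not LLVM's actual SLPVectorizer): sixteen contiguous i8 loads starting at byte 97 of global_smem combine into one 16-byte load. In this model the combined load keeps the align 1 of its first element, as in the <16 x i8> ... align 1 IR. Nothing in the model asks whether a 16-byte access with align 1 is legal on the target, which is the kind of information the missing target machine would supply. ScalarLoad and merge_consecutive are hypothetical names:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ScalarLoad:
    offset: int  # byte offset into global_smem
    size: int    # bytes per element (1 for i8)
    align: int   # alignment annotation on the scalar load


def merge_consecutive(loads):
    """Toy merge of adjacent scalar loads into one vector load.

    Returns (offset, total_bytes, align). The merged load keeps the
    alignment of its first element; no target legality check is made.
    """
    loads = sorted(loads, key=lambda ld: ld.offset)
    for prev, cur in zip(loads, loads[1:]):
        if cur.offset != prev.offset + prev.size:
            raise ValueError("loads are not contiguous")
    total = sum(ld.size for ld in loads)
    return loads[0].offset, total, loads[0].align


byte_loads = [ScalarLoad(offset=97 + i, size=1, align=1) for i in range(16)]
print(merge_consecutive(byte_loads))  # (97, 16, 1)
```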

2. Later in the compilation pipeline, LLVM's SelectionDAG turns the load <16 x i8> into an NVPTXISD::LoadV4 DAG node, which is then lowered into the MIR instruction NVPTX::LDV_i32_v4_asi at https://github.com/llvm/llvm-project/blob/f07eb24bb003aea435da94569910529fe5e332a4/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp#L1134

This is what ultimately becomes the PTX instruction ld.shared.v4.u32 {%r676, %r677, %r678, %r679}, [global_smem+97];
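The PTX ISA requires a memory access to be aligned to its access size, and for ld.shared.v4.u32 that is 4 * 4 = 16 bytes. Here is a small check that assumes global_smem itself starts on a 16-byte boundary (ptx_vector_access_ok is a hypothetical helper). Because 97 is odd, the address is not even 2-byte aligned as long as global_smem is at least 2-byte aligned:

```python
def ptx_vector_access_ok(address, elem_bytes, vec_width):
    """True if `address` is aligned to the full vector access size."""
    access_size = elem_bytes * vec_width  # ld.v4.u32 -> 16 bytes
    return address % access_size == 0


# Assumption: global_smem starts at a 16-byte-aligned address.
SMEM_BASE = 0
print(ptx_vector_access_ok(SMEM_BASE + 97, 4, 4))  # False: 97 % 16 == 1
print(ptx_vector_access_ok(SMEM_BASE + 96, 4, 4))  # True
```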

Summary

  • In summary, tl.reshape requires a triton_gpu.convert_layout, which results in a valid sequence of loads from addrspace 3 (shared memory).
  • This sequence of loads is vectorized by the SLPVectorizer without regard to legal vector lengths or alignment (as part of what seems to be an expedient path to better performance).
  • Later, during machine lowering, the vector load is materialized into an unaligned 4 x i32 machine load at offset 97, which results in the misaligned PTX.

Potential Solutions

Still brainstorming, but here are some ideas:

  • Introduce a separate legalization pass, since we aren't providing target information to the SLPVectorizer
  • Add additional legalization to the NVPTX backend
  • Introduce some additional rules around misaligned accesses to the SLPVectorizer without directly involving the target machine

plotfi, Jan 26 '24 15:01

@Artem-B is this a bug in the nvptx backend's legalization routine?

jlebar, Jan 26 '24 21:01

load <16 x i8>, ptr addrspace(3) getelementptr ([0 x i8], ptr addrspace(3) @global_smem, i64 0, i64 97), align 1

A vector aligned by 1 should not have been lowered into a vectorized load. Period. The only way to load it correctly is to split it into individual byte loads. https://godbolt.org/z/4PPovnxxb

Looks like a bug to me.
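A sketch of the splitting described above. It assumes the backend can rely only on the alignment annotation on the load, not on the actual address. Each piece is a power of two no larger than the known alignment, so with align 1 the 16-byte load degrades to sixteen byte loads. split_by_alignment is a hypothetical helper, not the NVPTX backend's code:

```python
def split_by_alignment(offset, size, align, max_bytes=16):
    """Split an access into pieces that are legal given only `align`.

    `align` must be a power of two. Every piece is a power of two no
    larger than min(align, max_bytes), so each piece's start stays
    aligned to its own size. Returns a list of (offset, nbytes) pairs.
    """
    if align <= 0 or align & (align - 1):
        raise ValueError("align must be a power of two")
    pieces = []
    step = min(align, max_bytes)
    while size > 0:
        while step > size:  # shrink for a short tail
            step //= 2
        pieces.append((offset, step))
        offset += step
        size -= step
    return pieces


print(len(split_by_alignment(97, 16, 1)))  # 16
print(split_by_alignment(96, 16, 4))       # [(96, 4), (100, 4), (104, 4), (108, 4)]
```

With align 4, the same access would split into four 4-byte pieces; only align 16 permits a single 16-byte vector load.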

Artem-B, Jan 27 '24 01:01

E.g. here's what we do with an array: https://godbolt.org/z/rszcqdqoc though in this case we fail to vectorize even if we align by 16.

Artem-B, Jan 27 '24 01:01

Deleted my previous message, which was incorrect. This is originally an align 1 load of i8 type. In this case the LLVM IR is correct, but the load should have been broken down here: https://github.com/llvm/llvm-project/blob/25e1916d88ebeef786956b678a4eb9a757e219d9/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L5730

I wonder why it didn't hit this condition. @plotfi, it could be interesting to put a breakpoint there and see why it is not taking this path.

ThomasRaoux, Jan 27 '24 02:01