`tl.reshape` causes `[CUDA]: misaligned address`
When adding an extra size=1 dimension, tl.reshape(..., [XBLOCK, 1]) causes an error while ...[:, None] works fine -- in the context of a larger program.
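A minimal sketch of just the two forms in isolation (hypothetical, not from the original report; the kernel and launcher names below are illustrative, and this small program is not claimed to reproduce the crash on its own, which seems to need the larger program below):

import torch
import triton
import triton.language as tl

@triton.jit
def expand_dim_kernel(in_ptr, out_ptr, n, XBLOCK: tl.constexpr, USE_RESHAPE: tl.constexpr):
    offset = tl.program_id(0) * XBLOCK
    index_1d = offset + tl.arange(0, XBLOCK)
    x = tl.load(in_ptr + index_1d, index_1d < n)  # 1D block of shape [XBLOCK]
    if USE_RESHAPE:
        # the form that crashes inside the larger kernel below
        x2d = tl.reshape(x, [XBLOCK, 1])
    else:
        # the slice-expand form that works in the larger kernel below
        x2d = x[:, None]
    # store the [XBLOCK, 1] block back out
    tl.store(out_ptr + index_1d[:, None], x2d, (index_1d < n)[:, None])

src = torch.randn(1024, device="cuda")
dst = torch.empty_like(src)
for use_reshape in (False, True):
    expand_dim_kernel[(triton.cdiv(1024, 128),)](
        src, dst, 1024, XBLOCK=128, USE_RESHAPE=use_reshape
    )
    torch.cuda.synchronize()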
Repro:
import torch
from torch import empty
from torch._dynamo.testing import rand_strided
import triton
import triton.language as tl
@triton.jit
def triton_(
    in_ptr0,
    in_ptr1,
    in_ptr2,
    in_ptr3,
    in_ptr4,
    out_ptr0,
    out_ptr1,
    XBLOCK: tl.constexpr,
    VERSION: tl.constexpr,
):
    xnumel = 1024
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex % 16
    r2 = rindex // 16
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (16 * x0) + (16384 * r2)), rmask & xmask).to(tl.int1)
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp4 = tl.load(
        in_ptr2 + (x0 + (1024 * r2)),
        rmask & xmask,
        eviction_policy="evict_last",
        other=0.0,
    ).to(tl.float32)
    tmp9 = tl.load(in_ptr3 + (r1 + (16 * x0) + (16384 * r2)), rmask & xmask, other=0.0)
    if VERSION == 0:
        # works loading with a 2D [XBLOCK, 1] block
        tmp10 = tl.load(in_ptr4 + (x0), xmask)
    elif VERSION == 1:
        # works loading with a 1D block and using [:, None] to expand
        index_1d = xoffset + tl.arange(0, XBLOCK)
        tmp10 = tl.load(in_ptr4 + index_1d, index_1d < xnumel)[:, None]
    elif VERSION == 2:
        # tl.reshape() causes a RuntimeError: Triton Error [CUDA]: misaligned address
        index_1d = xoffset + tl.arange(0, XBLOCK)
        tmp10 = tl.reshape(
            tl.load(in_ptr4 + index_1d, index_1d < xnumel),
            [XBLOCK, 1],
        )
    tmp3 = tmp2.to(tl.float32)
    tmp5 = 16.0
    tmp6 = tmp4 / tmp5
    tmp7 = tl.where(tmp0, tmp3, tmp6)
    tmp8 = tmp7.to(tl.float32)
    tmp11 = tmp9 - tmp10
    tmp12 = tmp8 * tmp11
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
    tmp19 = tl.where(rmask & xmask, tmp17, 0)
    tmp20 = tl.sum(tmp19, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp16, xmask)
    tl.store(out_ptr1 + (x0), tmp20, xmask)

arg0_1 = rand_strided(
    (8, 1024, 4, 4), (1024, 1, 0, 0), device="cuda:0", dtype=torch.float16
)
arg1_1 = rand_strided(
    (8, 1024, 4, 4), (16384, 16, 4, 1), device="cuda:0", dtype=torch.bool
)
arg2_1 = rand_strided((), (), device="cuda:0", dtype=torch.float32)
arg3_1 = rand_strided(
    (8, 1024, 4, 4), (16384, 16, 4, 1), device="cuda:0", dtype=torch.float32
)
arg4_1 = rand_strided(
    (1, 1024, 1, 1), (1024, 1, 1, 1), device="cuda:0", dtype=torch.float32
)
buf0 = empty((1024,), device="cuda", dtype=torch.float32)
buf1 = empty((1024,), device="cuda", dtype=torch.float32)
xblock = 1
for version in range(3):
    print(f"Running with VERSION={version}")
    triton_[(triton.cdiv(1024, xblock),)](
        arg1_1,
        arg2_1,
        arg0_1,
        arg3_1,
        arg4_1,
        buf0,
        buf1,
        XBLOCK=xblock,
        VERSION=version,
    )
    torch.cuda.synchronize()
    print("SUCCESS")
Output:
$ CUDA_LAUNCH_BLOCKING=1 python repro.py
Running with VERSION=0
SUCCESS
Running with VERSION=1
SUCCESS
Running with VERSION=2
Traceback (most recent call last):
File "/home/jansel/pytorch/repro.py", line 93, in <module>
triton_[(triton.cdiv(1024, xblock),)](
File "/home/jansel/conda/envs/pytorch/lib/python3.11/site-packages/triton/runtime/jit.py", line 550, in run
bin.c_wrapper(
RuntimeError: Triton Error [CUDA]: misaligned address
This code is originally from TorchInductor compiling the following program:
def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3):
    div = torch.ops.aten.div.Scalar(expand, 16)
    where = torch.ops.aten.where.self(arg207_1, full, div)
    convert_element_type_43 = torch.ops.prims.convert_element_type.default(
        where, torch.float32
    )
    sum_2 = torch.ops.aten.sum.dim_IntList(convert_element_type_43, [0, 2, 3])
    sub = torch.ops.aten.sub.Tensor(convert_element_type_40, arg208_1)
    mul = torch.ops.aten.mul.Tensor(convert_element_type_43, sub)
    sum_3 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3])
    mul_1 = torch.ops.aten.mul.Tensor(sum_2, 0.0078125)
    unsqueeze = torch.ops.aten.unsqueeze.default(mul_1, 0)
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3)
    mul_2 = torch.ops.aten.mul.Tensor(sum_3, 0.0078125)
    mul_4 = torch.ops.aten.mul.Tensor(mul_2, mul_3)
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(mul_4, 0)
    unsqueeze_4 = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2)
    unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
    mul_6 = torch.ops.aten.mul.Tensor(sub, unsqueeze_5)
    sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_43, mul_6)
    sub_2 = torch.ops.aten.sub.Tensor(sub_1, unsqueeze_2)
    return (sub_2,)
Which was generated by running the PyTorch minifier on a 1500+ node graph coming from selecsls42b.
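As a side note, here is a hypothetical sanity check (not part of the original report) that maps the kernel's two outputs back to the reductions in fn: run right after one of the working launches (VERSION=0 or 1) in the repro script above, it compares buf1 against sum_2 and buf0 against sum_3, with loose tolerances since the reduction order differs.

from torch.testing import assert_close

# `where` and `sub` mirror the corresponding values in fn; arg2_1 plays the role of
# `full` and arg0_1 (with broadcast strides) the role of `expand`.
where = torch.where(arg1_1, arg2_1, arg0_1.to(torch.float32) / 16.0)
sub = arg3_1 - arg4_1
assert_close(buf1, where.sum(dim=[0, 2, 3]), rtol=1e-4, atol=1e-4)          # sum_2
assert_close(buf0, (where * sub).sum(dim=[0, 2, 3]), rtol=1e-4, atol=1e-4)  # sum_3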
I think @plotfi is looking at this!
I have some findings on this.
1.
At https://github.com/openai/triton/blob/0327b9d32db6d1d63d207ccab722bd45e00a6678/python/src/llvm.cc#L173 Triton enables the SLPVectorizer with an empty target machine in order to get wider vectors. This causes sequences of IR instructions like:
%i123 = load i8, ptr addrspace(3) getelementptr ([0 x i8], ptr addrspace(3) @global_smem, i64 0, i64 97), align 1, !dbg !19
%.not97 = icmp eq i8 %i123, 0, !dbg !19
%i124 = load i8, ptr addrspace(3) getelementptr ([0 x i8], ptr addrspace(3) @global_smem, i64 0, i64 98), align 1, !dbg !19
to be turned into:
%1 = load <16 x i8>, ptr addrspace(3) getelementptr ([0 x i8], ptr addrspace(3) @global_smem, i64 0, i64 97), align 1
I tried getting the IR out of Triton before the SLPVectorizer and running it through LLVM opt with and without the target (passing -mtriple=nvptx64-nvidia-cuda); when the target is passed, opt won't produce these vector loads that ultimately result in the misaligned access. A rough sketch of this experiment is shown below.
By the way, the problematic loads originate from Python line 31, column 80: tmp0 = tl.load(in_ptr0 + (r1 + (16 * x0) + (16384 * r2)), rmask & xmask).to(tl.int1). Also, before these loads are materialized in the first place, they come from triton_gpu.convert_layout as part of the layout conversions needed to support the reshape operation.
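A hypothetical sketch of the opt experiment described above, under the assumptions (not from the original report) that the pre-SLPVectorizer IR has been dumped to a file named before_slp.ll and that LLVM's opt binary is on PATH:

import subprocess

# Run only the SLP vectorizer, once without a target triple (roughly what Triton's
# empty target machine amounts to) and once with the NVPTX triple, then check
# whether the wide <16 x i8> load shows up in the output IR.
for extra in ([], ["-mtriple=nvptx64-nvidia-cuda"]):
    result = subprocess.run(
        ["opt", "-S", "-passes=slp-vectorizer", *extra, "before_slp.ll"],
        capture_output=True, text=True, check=True,
    )
    label = "with triple" if extra else "without triple"
    print(label, "-> contains 'load <16 x i8>':", "load <16 x i8>" in result.stdout)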
2.
Later in the compilation pipeline, in LLVM's SelectionDAG, the load <16 x i8> gets turned into the DAG node NVPTXISD::LoadV4 and then lowered into the MIR instruction NVPTX::LDV_i32_v4_asi at https://github.com/llvm/llvm-project/blob/f07eb24bb003aea435da94569910529fe5e332a4/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp#L1134
This is the part that produces what ultimately becomes the PTX ld.shared.v4.u32 {%r676, %r677, %r678, %r679}, [global_smem+97];
Summary
- tl.reshape requires a triton_gpu.convert_layout, which results in a valid sequence of loads from addrspace 3 (shared memory).
- This sequence of loads is vectorized by the SLPVectorizer without regard to legal vector lengths or alignment (as part of what seems to be an expedient path to better performance).
- Later, during machine lowering, the vector load is materialized into an unaligned 4 x i32 machine load at offset 97, which results in the misaligned PTX.
Potential Solutions
Still brainstorming, but here are some ideas:
- Introduce a separate legalization pass, since we aren't providing target information to the SLPVectorizer
- Add additional legalization to the NVPTX backend
- Introduce some additional rules around misaligned accesses to the SLPVectorizer without directly involving the target machine
@Artem-B is this a bug in the NVPTX backend's legalization routine?
load <16 x i8>, ptr addrspace(3) getelementptr ([0 x i8], ptr addrspace(3) @global_smem, i64 0, i64 97), align 1
A vector aligned by 1 should not have been lowered into a vectorized load. Period. The only way to load it correctly is to split it into individual byte loads. https://godbolt.org/z/4PPovnxxb
Looks like a bug to me.
E.g. here's what we do with an array: https://godbolt.org/z/rszcqdqoc though in this case we fail to vectorize even if we align by 16.
I deleted my previous message, which was incorrect. So this is originally an align-1 load of i8 type. In this case the LLVM IR is correct, but then it should have been broken down here: https://github.com/llvm/llvm-project/blob/25e1916d88ebeef786956b678a4eb9a757e219d9/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L5730
I wonder why it didn't hit this condition. @plotfi it could be interesting to put a breakpoint there and see why it is not taking this path.