import torch
import triton
import triton.language as tl
from torch import empty_strided
from triton import cdiv
# Reproduction kernel. The autotuner is given exactly one configuration, so
# that configuration is the one that gets compiled and launched.
@triton.autotune(
[
triton.Config(
{
"XBLOCK": 32,
"YBLOCK": 32,
# This segfaults:
"ZBLOCK": 1
# This works:
# "ZBLOCK": 2
}
)
],
# Argument names whose values key the autotuner's choice of configuration.
key=["xnumel", "ynumel", "znumel"],
)
@triton.jit
# Element-wise kernel over a 3D (x, y, z) index space.
#
# Per element it gathers a value from in_ptr1 using an index read from
# in_ptr0, adds in_ptr2, subtracts in_ptr3 / 512, multiplies by
# 1 / sqrt(in_ptr4 / 512 + 1e-06), then applies a per-z multiply (in_ptr5)
# and add (in_ptr6) and writes the result to out_ptr0.
# NOTE(review): this looks like an embedding lookup fused with a
# layer-norm-style normalization over a 512-wide axis -- presumably in_ptr3
# and in_ptr4 hold per-(x, y) reduction results; confirm against the code
# that generated this kernel.
#
# xnumel / ynumel / znumel are compile-time extents used only for bounds
# masks; XBLOCK / YBLOCK / ZBLOCK are the tile sizes from the autotune config.
# All strides (32, 22, 512, 11264) are hard-coded in the address arithmetic.
def kernel4(
in_ptr0,
in_ptr1,
in_ptr2,
in_ptr3,
in_ptr4,
in_ptr5,
in_ptr6,
out_ptr0,
xnumel: tl.constexpr,
ynumel: tl.constexpr,
znumel: tl.constexpr,
XBLOCK: tl.constexpr,
YBLOCK: tl.constexpr,
ZBLOCK: tl.constexpr,
):
# Each program instance covers one XBLOCK x YBLOCK x ZBLOCK tile. The three
# index ranges are reshaped to [XBLOCK, 1, 1], [1, YBLOCK, 1] and
# [1, 1, ZBLOCK] so that they broadcast against each other.
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1, 1])
xmask = xindex < xnumel
yoffset = tl.program_id(1) * YBLOCK
yindex = yoffset + tl.reshape(tl.arange(0, YBLOCK), [1, YBLOCK, 1])
ymask = yindex < ynumel
zoffset = tl.program_id(2) * ZBLOCK
zindex = zoffset + tl.reshape(tl.arange(0, ZBLOCK), [1, 1, ZBLOCK])
zmask = zindex < znumel
x0 = xindex
y1 = yindex
z2 = zindex
# Loads that depend only on (x, y): tmp0 is later used as a row index into
# in_ptr1; tmp4 and tmp8 feed the normalization below.
tmp0 = tl.load(in_ptr0 + x0 + (32 * y1), xmask & ymask)
# Depends only on (y, z).
tmp2 = tl.load(in_ptr2 + z2 + (512 * y1), ymask & zmask)
tmp4 = tl.load(in_ptr3 + y1 + (22 * x0), xmask & ymask)
tmp8 = tl.load(in_ptr4 + x0 + (32 * y1), xmask & ymask)
# Depend only on z: per-z multiplier and addend applied at the end.
tmp16 = tl.load(in_ptr5 + z2, zmask)
tmp18 = tl.load(in_ptr6 + z2, zmask)
# Gather: read element z of the 512-wide row selected by tmp0. The added
# tl.zeros term has the full [XBLOCK, YBLOCK, ZBLOCK] tile shape.
tmp1 = tl.load(
in_ptr1 + z2 + (512 * tmp0) + tl.zeros([XBLOCK, YBLOCK, ZBLOCK], tl.int32),
xmask & ymask & zmask,
)
tmp3 = tmp1 + tmp2
# Subtract in_ptr3 / 512 (float divisor).
tmp5 = 512.0
tmp6 = tmp4 / tmp5
tmp7 = tmp3 - tmp6
# Scale by 1 / sqrt(in_ptr4 / 512 + 1e-06). Note the divisor here is the
# integer literal 512, unlike the float 512.0 above.
tmp9 = 512
tmp10 = tmp8 / tmp9
tmp11 = 1e-06
tmp12 = tmp10 + tmp11
tmp13 = tl.sqrt(tmp12)
tmp14 = 1 / tmp13
tmp15 = tmp7 * tmp14
# Per-z multiply then add.
tmp17 = tmp15 * tmp16
tmp19 = tmp17 + tmp18
# Output is addressed as z + 512 * y + 11264 * x; only in-bounds elements
# are written.
tl.store(out_ptr0 + z2 + (512 * y1) + (11264 * x0), tmp19, xmask & ymask & zmask)
def _zeroed(shape, stride, dtype=torch.float32):
    """Allocate a CUDA tensor with the given shape/stride and zero-fill it."""
    buf = empty_strided(shape, stride, device="cuda", dtype=dtype)
    buf.zero_()
    return buf


# Two 512-element float32 vectors.
primals_113 = _zeroed((512,), (1,))
primals_114 = _zeroed((512,), (1,))

# Contiguous float32 tensors whose innermost dimension is 512 wide.
primals_187 = _zeroed((1, 200, 512), (102400, 512, 1))
primals_188 = _zeroed((9521, 512), (512, 1))

# (32, 22) int64 tensor with stride (1, 32), i.e. the first dimension is the
# fastest-varying one.
primals_190 = _zeroed((32, 22), (1, 32), dtype=torch.int64)

# Two (32, 22) float32 tensors: buf2 has stride (22, 1), buf1 has stride (1, 32).
buf2 = _zeroed((32, 22), (22, 1))
buf1 = _zeroed((32, 22), (1, 32))

# (32, 22, 512) float32 tensor with contiguous strides.
buf3 = _zeroed((32, 22, 512), (11264, 512, 1))
def grid(xnumel, ynumel=None, znumel=None):
    """Return a callback that computes the launch grid from a kernel config.

    The callback receives ``meta``, a mapping that provides ``"XBLOCK"`` and,
    when the corresponding extent is given, ``"YBLOCK"`` / ``"ZBLOCK"``. It
    returns a list with one entry per used dimension: the number of blocks
    needed to cover that extent (``cdiv(extent, block_size)``).

    The x dimension is always present; the y and z dimensions are included
    only when their extent is truthy (not ``None`` and not ``0``).
    """

    def grid_fn(meta):
        axes = [(xnumel, "XBLOCK")]
        axes += [
            (extent, block_key)
            for extent, block_key in ((ynumel, "YBLOCK"), (znumel, "ZBLOCK"))
            if extent
        ]
        return [cdiv(extent, meta[block_key]) for extent, block_key in axes]

    return grid_fn
# Extents of the x, y and z dimensions; used both to size the launch grid and
# as the kernel's xnumel / ynumel / znumel arguments.
problem_size = (32, 22, 512)

launch = kernel4[grid(*problem_size)]
launch(
    primals_190,  # in_ptr0
    primals_188,  # in_ptr1
    primals_187,  # in_ptr2
    buf2,  # in_ptr3
    buf1,  # in_ptr4
    primals_114,  # in_ptr5
    primals_113,  # in_ptr6
    buf3,  # out_ptr0
    *problem_size,  # xnumel, ynumel, znumel
)
# Observed result when running this script:
#   $ python repro.py
#   zsh: segmentation fault (core dumped) python repro.py