triton icon indicating copy to clipboard operation
triton copied to clipboard

Segfault when a 3D Triton kernel is autotuned with `ZBLOCK=1` (`ZBLOCK=2` works)

Status: open. Opened by jansel on Aug 12 '22 (0 comments).

import torch
import triton
import triton.language as tl
from torch import empty_strided
from triton import cdiv


# Minimized repro for the segfault: a 3-D tiled kernel whose only autotune
# config pins XBLOCK=32, YBLOCK=32, ZBLOCK=1. Per the inline notes in the
# config, ZBLOCK=1 crashes and ZBLOCK=2 runs.
#
# Each program handles an [XBLOCK, YBLOCK, ZBLOCK] tile and writes
#   out = (g + p - s / 512) * (1 / sqrt(q / 512 + 1e-6)) * w + b
# where g is a row of in_ptr1 gathered by the index loaded from in_ptr0,
# p comes from in_ptr2, s/q from in_ptr3/in_ptr4, and w/b from in_ptr5/in_ptr6.
# NOTE(review): this looks like a fused embedding-lookup + add + LayerNorm over
# a 512-wide last axis (s = per-row sum, q = per-row sum of squared
# deviations) -- inferred from the arithmetic only; confirm with the producer.
# NOTE(review): the 32/22/512/11264 address strides are hard-coded for the
# buffers allocated below rather than derived from xnumel/ynumel/znumel.
@triton.autotune(
    [
        triton.Config(
            {
                "XBLOCK": 32,
                "YBLOCK": 32,
                # This segfaults:
                "ZBLOCK": 1
                # This works:
                # "ZBLOCK": 2
            }
        )
    ],
    key=["xnumel", "ynumel", "znumel"],
)
@triton.jit
def kernel4(
    in_ptr0,
    in_ptr1,
    in_ptr2,
    in_ptr3,
    in_ptr4,
    in_ptr5,
    in_ptr6,
    out_ptr0,
    # Problem extents along x/y/z; used only to build the masks below.
    xnumel: tl.constexpr,
    ynumel: tl.constexpr,
    znumel: tl.constexpr,
    # Tile sizes, supplied by the autotune config above.
    XBLOCK: tl.constexpr,
    YBLOCK: tl.constexpr,
    ZBLOCK: tl.constexpr,
):
    # Axis-0 tile, shaped [XBLOCK, 1, 1] so it broadcasts against the y/z tiles.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1, 1])
    xmask = xindex < xnumel
    # Axis-1 tile, shaped [1, YBLOCK, 1]. YBLOCK=32 exceeds ynumel=22 at the
    # call site, so part of ymask is false.
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.reshape(tl.arange(0, YBLOCK), [1, YBLOCK, 1])
    ymask = yindex < ynumel
    # Axis-2 tile, shaped [1, 1, ZBLOCK]. With ZBLOCK=1 this reshapes a
    # one-element arange to [1, 1, 1] -- the configuration reported to segfault.
    zoffset = tl.program_id(2) * ZBLOCK
    zindex = zoffset + tl.reshape(tl.arange(0, ZBLOCK), [1, 1, ZBLOCK])
    zmask = zindex < znumel
    # Plain aliases; the address expressions below use x0/y1/z2.
    x0 = xindex
    y1 = yindex
    z2 = zindex
    # Partial-rank loads: each reads only the axes its address depends on and
    # is broadcast to the full tile by the arithmetic further down.
    tmp0 = tl.load(in_ptr0 + x0 + (32 * y1), xmask & ymask)
    tmp2 = tl.load(in_ptr2 + z2 + (512 * y1), ymask & zmask)
    tmp4 = tl.load(in_ptr3 + y1 + (22 * x0), xmask & ymask)
    tmp8 = tl.load(in_ptr4 + x0 + (32 * y1), xmask & ymask)
    tmp16 = tl.load(in_ptr5 + z2, zmask)
    tmp18 = tl.load(in_ptr6 + z2, zmask)
    # Indirect gather: the values in tmp0 select rows of in_ptr1 (row stride
    # 512). Adding tl.zeros(...) leaves the address unchanged and only
    # broadcasts it to the full [XBLOCK, YBLOCK, ZBLOCK] tile shape.
    tmp1 = tl.load(
        in_ptr1 + z2 + (512 * tmp0) + tl.zeros([XBLOCK, YBLOCK, ZBLOCK], tl.int32),
        xmask & ymask & zmask,
    )
    # (tmp1 + tmp2 - tmp4 / 512) * (1 / sqrt(tmp8 / 512 + 1e-6)) * tmp16 + tmp18
    tmp3 = tmp1 + tmp2
    tmp5 = 512.0
    tmp6 = tmp4 / tmp5
    tmp7 = tmp3 - tmp6
    tmp9 = 512
    tmp10 = tmp8 / tmp9
    tmp11 = 1e-06
    tmp12 = tmp10 + tmp11
    tmp13 = tl.sqrt(tmp12)
    tmp14 = 1 / tmp13
    tmp15 = tmp7 * tmp14
    tmp17 = tmp15 * tmp16
    tmp19 = tmp17 + tmp18
    # Store into an output addressed with strides (11264, 512, 1) over (x, y, z).
    tl.store(out_ptr0 + z2 + (512 * y1) + (11264 * x0), tmp19, xmask & ymask & zmask)


def _zeroed_cuda(size, stride, dtype=torch.float32):
    """Allocate a CUDA tensor with an explicit size/stride and fill it with zeros."""
    return empty_strided(size, stride, device="cuda", dtype=dtype).zero_()


# Length-512 vectors passed to the kernel as in_ptr6 / in_ptr5.
primals_113 = _zeroed_cuda((512,), (1,))
primals_114 = _zeroed_cuda((512,), (1,))
# Contiguous (1, 200, 512) table, passed as in_ptr2.
primals_187 = _zeroed_cuda((1, 200, 512), (102400, 512, 1))
# (9521, 512) table whose rows are gathered through primals_190 (in_ptr1).
primals_188 = _zeroed_cuda((9521, 512), (512, 1))
# Column-major (32, 22) int64 index tensor (in_ptr0).
primals_190 = _zeroed_cuda((32, 22), (1, 32), dtype=torch.int64)
# (32, 22) float tensors: buf2 row-major (in_ptr3), buf1 column-major (in_ptr4).
buf2 = _zeroed_cuda((32, 22), (22, 1))
buf1 = _zeroed_cuda((32, 22), (1, 32))
# Contiguous (32, 22, 512) output buffer (out_ptr0).
buf3 = _zeroed_cuda((32, 22, 512), (11264, 512, 1))


def grid(xnumel, ynumel=None, znumel=None):
    """Build a Triton launch-grid callable for up to three tiled axes.

    The returned function receives the kernel's meta-parameters (the selected
    autotune config) and returns one program count per axis, computed as
    ``cdiv(numel, meta["<AXIS>BLOCK"])``.

    Args:
        xnumel: Number of elements along x; always tiled by ``XBLOCK``.
        ynumel: Optional number of elements along y, tiled by ``YBLOCK``.
        znumel: Optional number of elements along z, tiled by ``ZBLOCK``.
            Requires ``ynumel`` to be given as well.

    Returns:
        ``grid_fn(meta) -> list[int]`` with 1, 2 or 3 entries.

    Raises:
        ValueError: if ``znumel`` is given without ``ynumel``. Grid entries are
            positional, so the z count would otherwise land in the y slot (or,
            as before, be silently dropped).
    """
    # Validate eagerly so a misconfigured launch fails at the call site.
    if znumel is not None and ynumel is None:
        raise ValueError("znumel requires ynumel")

    def grid_fn(meta):
        result = [cdiv(xnumel, meta["XBLOCK"])]
        # `is not None` distinguishes "axis absent" from an explicit extent of
        # 0, which now yields a zero-sized axis instead of collapsing the grid.
        if ynumel is not None:
            result.append(cdiv(ynumel, meta["YBLOCK"]))
        if znumel is not None:
            result.append(cdiv(znumel, meta["ZBLOCK"]))
        return result

    return grid_fn


# Pointer arguments in kernel4's positional order (in_ptr0 .. in_ptr6). The
# launch stays positional because the autotuner reads its key arguments by
# position.
kernel_inputs = (
    primals_190,  # in_ptr0
    primals_188,  # in_ptr1
    primals_187,  # in_ptr2
    buf2,  # in_ptr3
    buf1,  # in_ptr4
    primals_114,  # in_ptr5
    primals_113,  # in_ptr6
)
x_extent, y_extent, z_extent = 32, 22, 512
# Block sizes are not passed here; they come from the autotune config.
kernel4[grid(x_extent, y_extent, z_extent)](
    *kernel_inputs, buf3, x_extent, y_extent, z_extent
)
$ python repro.py 
zsh: segmentation fault (core dumped)  python repro.py

Posted by jansel, Aug 12 '22 23:08.