Segmentation fault in triton==3.0.0
from typing import Optional

import torch
import triton
import triton.language as tl


@triton.jit
def _distance_bias_(diagram: tl.tensor,
                    lower_bound: tl.tensor,
                    upper_bound: tl.tensor,
                    weight: tl.tensor,
                    bias: tl.tensor,
                    c_ss: int,
                    ):
    lower_d_mask = diagram[:, :, None] > lower_bound[None, None, :]
    upper_d_mask = diagram[:, :, None] < upper_bound[None, None, :]
    d_mask = lower_d_mask * upper_d_mask
    d_mask = d_mask.to(diagram.dtype)
    tl.static_print(d_mask)
    tl.static_print(weight)
    # Uncommenting the following two lines (instead of the tl.dot) makes it work:
    # o = d_mask[:, :, :, None] * weight[None, None, :, :]
    # o = tl.sum(o, axis=2)
    # The following line causes the segmentation fault.
    o = tl.dot(d_mask, weight[None, :, :])
    tl.static_print(o)
    o = o * bias[None, None, :]
    o = tl.sum(o, axis=2)
    o = o / c_ss
    return o

@triton.jit
def distance_bias_fwd_triton(output_ptr,
                             diagram_ptr,
                             lower_bound_ptr,
                             upper_bound_ptr,
                             weight_ptr,
                             bias_ptr,
                             diagram_stride,
                             weight_stride,
                             num_rows, num_cols, num_bins, c_ss,
                             distance_default,
                             BLOCK_ROW_SIZE: tl.constexpr,
                             BLOCK_COL_SIZE: tl.constexpr,
                             BLOCK_BIN_SIZE: tl.constexpr,
                             BLOCK_SS_SIZE: tl.constexpr,
                             INF,
                             ):
    pid = tl.program_id(axis=0)
    block_row = tl.arange(0, BLOCK_ROW_SIZE)
    block_col = tl.arange(0, BLOCK_COL_SIZE)
    diagram_offset = block_row[:, None] * diagram_stride + block_col[None, :]
    diagram_mask = (block_row[:, None] < num_rows) & (block_col[None, :] < num_cols)
    diagram = tl.load(diagram_ptr + diagram_offset, mask=diagram_mask, other=INF)
    block_bins = tl.arange(0, BLOCK_BIN_SIZE)
    lower_bound = tl.load(lower_bound_ptr + block_bins, mask=(block_bins < num_bins), other=INF)
    upper_bound = tl.load(upper_bound_ptr + block_bins, mask=(block_bins < num_bins), other=INF)
    block_ss = tl.arange(0, BLOCK_SS_SIZE)
    weight_offset = block_bins[:, None] * weight_stride + block_ss[None, :]
    weight_mask = (block_bins[:, None] < num_bins) & (block_ss[None, :] < c_ss)
    weight = tl.load(weight_ptr + weight_offset, mask=weight_mask, other=0.0)
    bias = tl.load(bias_ptr + block_ss, mask=(block_ss < c_ss), other=0.0)
    # ou = _distance_bias_(diagram, lower_bound, upper_bound, weight, bias, c_ss, BLOCK_ROW_SIZE, BLOCK_COL_SIZE)
    ou = _distance_bias_(diagram, lower_bound, upper_bound, weight, bias, c_ss)
    tl.store(output_ptr + diagram_offset, ou, mask=diagram_mask)

def distance_bias_fwd(diagram: torch.Tensor,
                      lower_bound: torch.Tensor,
                      upper_bound: torch.Tensor,
                      weight: torch.Tensor,
                      bias: torch.Tensor,
                      INF: float,
                      no_bins: Optional[int] = None,
                      c_ss: Optional[int] = None,
                      ) -> torch.Tensor:
    assert diagram.is_cuda and lower_bound.is_cuda and upper_bound.is_cuda
    assert weight.is_cuda and bias.is_cuda
    num_rows, num_cols = diagram.shape
    if no_bins is None:
        no_bins = weight.shape[0]
    if c_ss is None:
        c_ss = weight.shape[1]
    distance_default = 10
    device, dtype = diagram.device, diagram.dtype
    o = torch.zeros_like(diagram, device=device, dtype=dtype)
    BLOCK_ROW_SIZE = triton.next_power_of_2(num_rows)
    BLOCK_COL_SIZE = triton.next_power_of_2(num_cols)
    grid = lambda meta: (triton.cdiv(num_rows, BLOCK_ROW_SIZE), triton.cdiv(num_cols, BLOCK_COL_SIZE))
    # grid = lambda meta: (1,)
    distance_bias_fwd_triton[grid](o,
                                   diagram,
                                   lower_bound,
                                   upper_bound,
                                   weight,
                                   bias,
                                   diagram.stride(0),
                                   weight.stride(0),
                                   num_rows, num_cols, no_bins, c_ss, distance_default,
                                   BLOCK_ROW_SIZE,
                                   BLOCK_COL_SIZE,
                                   triton.next_power_of_2(no_bins),
                                   triton.next_power_of_2(c_ss),
                                   INF,
                                   num_stages=4,
                                   )
    return o

if __name__ == '__main__':
    torch.manual_seed(42)
    dtype, device = torch.float32, 'cuda'
    inf = 1e8
    min_bin = 3.25
    max_bin = 20.75
    no_bins = 29
    c_ss = 16
    w = torch.randn((no_bins, c_ss), dtype=dtype, device=device, requires_grad=True)
    b = torch.randn((c_ss,), dtype=dtype, device=device, requires_grad=True)
    num_rows, num_cols = 16, 16
    # num_rows, num_cols = 8, 8
    bins = torch.linspace(min_bin, max_bin, no_bins,
                          dtype=dtype, device=device, requires_grad=False)
    squared_bins = bins ** 2
    upper_bins = torch.cat([squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1)
    d = 500 * torch.rand((num_rows, num_cols), dtype=dtype, device=device, requires_grad=False)
    oo = distance_bias_fwd(d, squared_bins, upper_bins, w, b, inf, no_bins, c_ss)
    # print(oo)
    # print(oo_w)
    # print(w)
    # max_diff = torch.max(torch.abs(oo_w - w))
    # print(max_diff)
Running the above code leads to the following error:

"""
fp32[constexpr[8], constexpr[16], constexpr[64]]
fp32[constexpr[64], constexpr[64]]
fp32[constexpr[8], constexpr[16], constexpr[64]]
Segmentation fault (core dumped)
"""

It seems that tl.dot is the problem, since replacing the tl.dot with the following two lines makes the code work:

o = d_mask[:, :, :, None] * weight[None, None, :, :]
o = tl.sum(o, axis=2)

I am curious about what is happening here and why. Thanks very much; I appreciate any comments.
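For reference, here is a plain PyTorch version of the same computation that can be used to sanity-check the workaround's output (a sketch based on my reading of the kernel; the helper name distance_bias_ref is just illustrative):

def distance_bias_ref(d, lower_bound, upper_bound, w, b, c_ss):
    # Bin-membership mask over the distance values: (rows, cols, bins).
    d_mask = ((d[:, :, None] > lower_bound) & (d[:, :, None] < upper_bound)).to(d.dtype)
    # Project the bins onto the c_ss channels, weight by bias, and average.
    o = torch.einsum('rcb,bs->rcs', d_mask, w)
    return (o * b).sum(dim=-1) / c_ss

# e.g.: torch.testing.assert_close(oo, distance_bias_ref(d, squared_bins, upper_bins, w, b, c_ss))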
It looks like the insert_element in MMA16816SmemLoader::loadX4 is trying to insert at index 32 into a vector that only holds 4 elements when lowering the following:
%72 = triton_gpu.local_load %68 : !tt.memdesc<16x16x32xf32, #shared> -> tensor<16x16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 32}>> loc(#loc54)
%73 = triton_gpu.local_load %71 : !tt.memdesc<1x32x16xf32, #shared1> -> tensor<1x32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 32}>> loc(#loc56)
%74 = tt.dot %72, %73, %cst, inputPrecision = tf32 : tensor<16x16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 32}>> * tensor<1x32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 32}>> -> tensor<16x16x16xf32, #mma> loc(#loc56)
The crash is happening here:
https://github.com/triton-lang/triton/blob/main/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp#L415-L420
It is crashing because the canonWidth is 32, which goes out of the bounds of the retElems SmallVector, which only contains the 4 elements for the loadX4. I think if you were using bf16 it would probably take the ldmatrix path instead.
I am not sure how to fix this one.
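In the meantime, a possible user-side workaround (an untested sketch, following the ldmatrix observation above) might be to cast the dot operands to bf16 so the 16-bit ldmatrix path is taken instead of the fp32 path that crashes. The 0/1 mask is represented exactly in bf16, but note that weight loses fp32 precision:

# Untested sketch: cast the operands so the bf16 ldmatrix lowering is used.
# tl.dot accumulates in fp32 by default, so o comes back as fp32.
o = tl.dot(d_mask.to(tl.bfloat16), weight.to(tl.bfloat16)[None, :, :])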