[Bug] Wrong `__restrict__` modifiers added to pointers
```python
import torch
import tilelang
from tilelang import language as T


@tilelang.jit()
def get_buggy_kernel():
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, ), 'float'],
                     y: T.Tensor[(num_tokens, ), 'float']):
        with T.Kernel(num_tokens, threads=32) as pid:
            y[pid] = x[pid] + 1

    return buggy_kernel


if __name__ == '__main__':
    kernel = get_buggy_kernel()
    print(kernel.get_kernel_source())
    x = torch.randn((256, ), dtype=torch.float, device='cuda')
    kernel(x[32:96], x[64:128])
```
The CUDA kernel after compilation:

```cuda
extern "C" __global__ void buggy_kernel_kernel(float* __restrict__ x, float* __restrict__ y, int num_tokens);

extern "C" __global__ void __launch_bounds__(32, 1) buggy_kernel_kernel(float* __restrict__ x, float* __restrict__ y, int num_tokens) {
  y[((int)blockIdx.x)] = (x[((int)blockIdx.x)] + 0x1p+0f/*1.000000e+00*/);
}
```
Since `x` and `y` overlap (the slices `x[32:96]` and `x[64:128]` share part of the same storage), the `__restrict__` modifiers are incorrect.
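To make the aliasing concrete, the repro passes two views of one 256-element buffer, and their half-open index ranges intersect. A plain-Python sketch of the overlap arithmetic (no GPU needed):

```python
# The repro passes two views of the same 256-element buffer:
#   x[32:96]  -> storage offsets [32, 96)
#   x[64:128] -> storage offsets [64, 128)
a_start, a_stop = 32, 96
b_start, b_stop = 64, 128

# Two half-open ranges overlap iff each starts before the other ends.
overlap = max(a_start, b_start) < min(a_stop, b_stop)
overlap_range = (max(a_start, b_start), min(a_stop, b_stop))

print(overlap)        # True: the views alias storage offsets [64, 96)
print(overlap_range)  # (64, 96)
```

Writing through `y` therefore also writes 32 of the elements read through `x`, which is exactly what `__restrict__` promises the compiler cannot happen.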
It's likely that `__restrict__` is mainly beneficial for kernels that directly read from and write to global memory. This pattern is relatively rare in efficient AI workloads, so it might be reasonable to remove all `__restrict__` annotations from kernel parameters.

However, `__restrict__` may also play an important role in the performance of utility kernels. For instance, CUTLASS example kernels such as permute, gather, and scatter make extensive use of this annotation. We should discuss this further and determine the best way to incorporate it into our design.
also cc @Hzfengsy @chengyupku @Rachmanino ?
One possible solution is to introduce an annotation for manually marking restrict buffers:
```python
import torch
import tilelang
from tilelang import language as T


@tilelang.jit()
def get_buggy_kernel():
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, ), 'float'],
                     y: T.Tensor[(num_tokens, ), 'float']):
        with T.Kernel(num_tokens, threads=32) as pid:
            T.annotate_restrict_buffers(x, y)
            y[pid] = x[pid] + 1

    return buggy_kernel


if __name__ == '__main__':
    kernel = get_buggy_kernel()
    print(kernel.get_kernel_source())
    x = torch.randn((256, ), dtype=torch.float, device='cuda')
    kernel(x[32:96], x[64:128])
```
I'd prefer to annotate the non-restrict buffers instead, since overlapping buffers are the unusual case. The syntax could be:
```python
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens, ), 'float', restrict=False],
                 y: T.Tensor[(num_tokens, ), 'float', restrict=False]):
    with T.Kernel(num_tokens, threads=32) as pid:
        y[pid] = x[pid] + 1
```
One point worth discussing is whether we need a runtime check for the restrict modifier. If all tensors are contiguous, with known start addresses and shapes (static or dynamic), we could check for overlap on the host side before launching the kernel.
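As a sketch of such a host-side check (a hypothetical helper, not part of TileLang): a contiguous tensor occupies the byte range `[data_ptr, data_ptr + numel * element_size)`, and two arguments alias iff those intervals intersect. The interval logic below is plain Python so it runs without a GPU; in practice the start and size would come from `tensor.data_ptr()` and `tensor.numel() * tensor.element_size()`.

```python
def byte_range(start: int, nbytes: int) -> tuple[int, int]:
    """Half-open byte interval occupied by a contiguous tensor."""
    return (start, start + nbytes)


def buffers_may_alias(a: tuple[int, int], b: tuple[int, int]) -> bool:
    """True iff two half-open byte intervals intersect."""
    return max(a[0], b[0]) < min(a[1], b[1])


# Hypothetical usage with torch tensors (requires CUDA, so commented out):
#   a = byte_range(x.data_ptr(), x.numel() * x.element_size())
#   b = byte_range(y.data_ptr(), y.numel() * y.element_size())
#   if buffers_may_alias(a, b):
#       raise ValueError("overlapping buffers passed to restrict parameters")

# The repro's slices of a float32 buffer, with the base address taken as 0:
x_view = byte_range(32 * 4, 64 * 4)  # x[32:96]
y_view = byte_range(64 * 4, 64 * 4)  # x[64:128]
print(buffers_may_alias(x_view, y_view))  # True: this launch would be unsafe
```

Non-contiguous (strided) views would need a more conservative test, e.g. comparing the bounding byte ranges of the views, which can report false positives but never misses a real overlap.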