[BUG] segfault with `cute.gemm`
Hi! I'm trying to use `cute.gemm` with wgmma, and I'm getting a segfault that I don't quite know how to debug. I have the following simple code:
```python
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack

M, N, K = 128, 128, 64


@cute.kernel
def gemm_kernel(sA_layout: cute.ComposedLayout,
                sBT_layout: cute.ComposedLayout,
                sC_layout: cute.Layout,
                mma: cute.TiledMma):
    tidx, _, _ = cute.arch.thread_idx()

    # Allocate the operand and accumulator tiles in shared memory.
    smem = cutlass.utils.SmemAllocator()
    sA = smem.allocate_tensor(cutlass.BFloat16, sA_layout, 16)
    sBT = smem.allocate_tensor(cutlass.BFloat16, sBT_layout, 16)
    sC = smem.allocate_tensor(cutlass.Float32, sC_layout, 16)

    # Partition the smem tiles per thread and make register fragments.
    thr_mma = mma.get_slice(tidx)
    tCsA = thr_mma.partition_A(sA)
    tCsB = thr_mma.partition_B(sBT)
    tCsC = thr_mma.partition_shape_C(sC.shape)
    tCrA = thr_mma.make_fragment_A(tCsA)
    tCrB = thr_mma.make_fragment_B(tCsB)
    tCrC = thr_mma.make_fragment_C(tCsC)

    print(tCsA)
    print(tCsB)
    print(tCrC)

    cute.arch.barrier()
    cute.gemm(mma, tCrC, tCrA, tCrB, tCrC)


@cute.jit
def launch_gemm():
    # 128B-swizzle smem layout atoms for MN-major and K-major operands.
    layout_mn_sw128_atom = cute.make_composed_layout(
        inner=cute.make_swizzle(3, 4, 3),
        offset=0,
        outer=cute.make_layout((64, 8), stride=(1, 64))
    )
    layout_k_sw128_atom = cute.make_composed_layout(
        inner=cute.make_swizzle(3, 4, 3),
        offset=0,
        outer=cute.make_layout((8, 64), stride=(64, 1))
    )
    print(layout_mn_sw128_atom)
    print(layout_k_sw128_atom)

    # Tile the atoms up to the full CTA tile shapes.
    sA_layout = cute.tile_to_shape(layout_k_sw128_atom, (M, K), (0, 1))
    sBT_layout = cute.tile_to_shape(layout_mn_sw128_atom, (N, K), (0, 1))
    sC_layout = cute.make_layout((M, N), stride=(N, 1))
    print(sA_layout)
    print(sBT_layout)

    smem_size = (cute.size_in_bytes(cutlass.BFloat16, sA_layout) +
                 cute.size_in_bytes(cutlass.BFloat16, sBT_layout) +
                 cute.size_in_bytes(cutlass.Float32, sC_layout))

    # 64x64x16 bf16 wgmma: A sourced from smem (K-major), B MN-major.
    op = cute.nvgpu.warpgroup.MmaF16BF16Op(
        ab_dtype=cutlass.BFloat16,
        acc_dtype=cutlass.Float32,
        instruction_shape=(64, 64, 16),
        a_src=cute.nvgpu.warpgroup.OperandSource.SMEM,
        a_major_mode=cute.nvgpu.warpgroup.OperandMajorMode.K,
        b_major_mode=cute.nvgpu.warpgroup.OperandMajorMode.MN
    )
    tC = cute.make_layout((1, 1, 1), stride=(1, 1, 1))
    tiled_mma = cute.make_tiled_mma(op, tC)

    gemm_kernel(sA_layout, sBT_layout, sC_layout, tiled_mma).launch(
        grid=(1, 1, 1),
        block=(128, 1, 1),
        smem=smem_size
    )


launch_gemm()
```
And it prints the following:
```
S<3,4,3> o 0 o (64,8):(1,64)
S<3,4,3> o 0 o (8,64):(64,1)
S<3,4,3> o 0 o ((8,16),(64,1)):((64,512),(1,0))
S<3,4,3> o 0 o ((64,2),(8,8)):((1,512),(64,1024))
tensor<ptr<bf16, smem, align<1024>> o S<3,4,3> o 0 o ((64,16),2,4):((64,1),4096,16)>
tensor<ptr<bf16, smem, align<1024>> o S<3,4,3> o 0 o ((64,(8,2)),2,4):((1,(64,1024)),512,2048)>
tensor<ptr<f32, rmem, align<128>> o ((2,2,8),2,2):((1,2,4),32,64)>
Segmentation fault
```
Basically, everything up to the segfault looks like the values I want. Not really sure what I'm doing wrong here. Any ideas?
Edit: a further note is that the segfault seems to happen in the CuTe compiler, not in the kernel (based on gdb)
The layouts of `sA` and `sBT` can't be composed layouts; try adding the code below after getting `sA` and `sBT`:
```python
sA = cute.make_tensor(
    cute.recast_ptr(sA.iterator, sA_layout.inner, dtype=sA.element_type),
    sA_layout.outer)
sBT = cute.make_tensor(
    cute.recast_ptr(sBT.iterator, sBT_layout.inner, dtype=sBT.element_type),
    sBT_layout.outer)
```
You can also refer to how `_MemRangeData` gets the smem tensor.
BTW, you need to initialize the CUDA context before calling the kernel function:
```python
cutlass.cuda.initialize_cuda_context()
```
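For the repro above, that would presumably mean something like this at the call site (the placement is my assumption, not from the original report):

```python
cutlass.cuda.initialize_cuda_context()  # set up the CUDA context first
launch_gemm()                           # then compile and launch the jitted kernel
```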
Thanks for the reply! I appreciate it :) Follow-up questions:
- What is the idea of composed layouts if you can't actually use them to initialize smem tensors? Where should I be using them?
- Sorry, I don't see any example code that uses `_MemRangeData`. Could you please point me to what you are referring to?
- Do you think it would be possible to add an error message instead of a segfault? It's very hard to debug this as a user.
@axelfeldmann agreed that this is very non-intuitive, and at the very least the compiler should not crash here. We will triage and fix. I, too, would expect composed layouts here to just work (tm).
@axelfeldmann
- As far as I know, you can use composed layouts wherever you want, except for the smem of umma and gmma, which only supports a plain layout and expects the swizzle info to be in the pointer. We may support composed layouts there in a future release; it's still under discussion.
- There's no code that uses `_MemRangeData` directly; it's exposed by the `get_tensor` function, see https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm.py#L624
- Yes, there are many issues related to the error info; we keep improving it. Thanks!
I see, thanks. Is Hopper generally unsupported for the CuTe DSL?
CuTe DSL does indeed support Hopper; the Hopper GEMM example is coming soon.
For your code, I think another solution is:
```python
sA = smem.allocate_tensor(cutlass.BFloat16, sA_layout.outer, 16, sA_layout.inner)
sBT = smem.allocate_tensor(cutlass.BFloat16, sBT_layout.outer, 16, sBT_layout.inner)
```
See `allocate_tensor` for details.
@axelfeldmann there is a subtle difference here between two different kinds of swizzle layouts. The Ampere-era MMAs use position independent swizzle layouts (PISL), whereas the Hopper and Blackwell MMAs use position dependent swizzle layouts (PDSL). The former is represented with a true composed layout, and the latter is represented with an affine layout plus a swizzled pointer. In the case of the Hopper kernel, the smem layout should be affine, but the pointer in the tensor will carry the swizzle instead.
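To make that concrete, here is a minimal sketch of the two representations, reusing the `smem` allocator and the `ComposedLayout` `sA_layout` from the repro above; the `recast_ptr`/`make_tensor` calls mirror the workaround posted earlier in this thread, and the `sA_pisl`/`sA_pdsl` names are just for illustration:

```python
# PISL (Ampere-era MMAs): the swizzle lives in the layout itself, so the
# tensor pairs an ordinary smem pointer with a ComposedLayout (swizzle o affine).
sA_pisl = smem.allocate_tensor(cutlass.BFloat16, sA_layout, 16)

# PDSL (Hopper/Blackwell MMAs): the layout stays a plain affine layout and
# the swizzle is folded into the pointer instead.
swizzled_ptr = cute.recast_ptr(sA_pisl.iterator, sA_layout.inner,
                               dtype=sA_pisl.element_type)
sA_pdsl = cute.make_tensor(swizzled_ptr, sA_layout.outer)
```

Either way the same swizzled bytes sit in shared memory; the two forms differ only in where the swizzle is recorded, and Hopper's wgmma expects the second form.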
Thanks @thakkarV
Is there anywhere I can read more about this distinction and how this all works?
Closest we have is this: https://github.com/NVIDIA/cutlass/blob/main/include/cute/pointer_swizzle.hpp#L43
The most we can do to make this easier on people is provide utilities for picking the right smem tensor / layout for each arch. We are already working on improving the error message here.
FYI - this is fixed in the latest release. Feel free to let us know if you see any other issues :-)
Thanks! I think there may be a similar issue with `cute.nvgpu.cpasync.tma_partition`? If I try to use the composed layout, then the gemm works fine but the `tma_partition` breaks. Maybe I'm doing something wrong though?
@axelfeldmann Neither the CuTe C++ `tma_partition` nor the DSL one supports composed layouts.
> If I try to use the composed layout, then the gemm works fine but the tma_partition breaks. Maybe I'm doing something wrong though?

Just want to check: did you get a useful error message, or did the kernel segfault, with the composed layout? Thanks.
I think we can close this issue for now; please feel free to report back if you see any other issues. @Junkai-Wu can you help close this issue?
The fix is in the latest v4.1 release. Closed as @brandon-yujie-sun suggested.