cutlass icon indicating copy to clipboard operation
cutlass copied to clipboard

[QST] Deadlock in producer consumer loop

Open axelfeldmann opened this issue 6 months ago • 7 comments

Hi!

I'm trying to understand how to use the PipelineTmaAsync utility, so I wrote a simple program where one warp group loads tiles and another one does nothing with them but "consumes" them nonetheless.

My simple program is as follows:

import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack

# Tile configuration shared by the kernel and the launch code:
# M rows and BK columns per A tile (the local_tile tiler), and the number of
# shared-memory pipeline stages.
M = 64
BK = 128
nstages = 3

class ProducerConsumerKernel:
    """Minimal PipelineTmaAsync experiment with one producer and one consumer warpgroup.

    Warp 0 of the first warpgroup (tidx < 32) TMA-loads (M, BK) bf16 tiles of A
    into an ``nstages``-deep shared-memory ring. The second warpgroup
    (tidx >= 128) waits on each stage and immediately releases it without
    using the data.

    NOTE(review): this is the version reported to deadlock. The fixes worked
    out in the thread are:

    * use an arrive count of 1 for both CooperativeGroups;
    * pass a 4-mode ``cta_layout_vmnk`` of ``(1, 1, 1, 1)``;
    * place the consumer warps at tidx < 128 and the producer at tidx >= 128.
      The public pipeline.py picks the ``consumer_release`` signalling thread
      by threadIdx.x (``tidx < cute.size(cluster_shape_mnk)``).
    """

    def __init__(self):
        # Set by __call__ to the @cute.struct SharedStorage type before launch;
        # kernel() allocates it from shared memory.
        self.shared_storage = None

    @cute.kernel
    def kernel(self, tma_atom_a: cute.CopyAtom, mA_tma: cute.Tensor, sA_layout: cute.ComposedLayout):
        tidx, _, _ = cute.arch.thread_idx()

        # Allocate the SharedStorage struct (pipeline barriers + staged A buffer).
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)

        pipeline_barriers_ptr = storage.pipeline_barriers.data_ptr()

        # NOTE(review): CooperativeGroup size defaults to 1 here, but the
        # threads below that call into the pipeline do not match it: 32
        # producer threads and 128 consumer threads. The thread recommends
        # size 1 for the producer (only one thread issues the TMA copy). With
        # the public pipeline.py it recommends size 1 for the consumer too,
        # since a single thread signals consumer_release.
        producer_group = utils.CooperativeGroup(utils.Agent.Thread)
        consumer_group = utils.CooperativeGroup(utils.Agent.Thread)

        # Bytes expected per stage: one M x BK tile of 2-byte bf16 elements.
        tma_copy_bytes = M * BK * 2

        # The barrier storage holds nstages * 2 Int64 slots (see SharedStorage
        # in __call__), presumably one "full" and one "empty" barrier per stage.
        # NOTE(review): cta_layout_vmnk should be 4-mode; the thread recommends
        # cute.make_layout((1, 1, 1, 1)) for a single CTA.
        pipeline = utils.PipelineTmaAsync.create(
            barrier_storage=pipeline_barriers_ptr,
            num_stages=nstages,
            producer_group=producer_group,
            consumer_group=consumer_group,
            tx_count=tma_copy_bytes, 
            cta_layout_vmnk=cute.make_layout((1, 1, 1))
        )

        # View the staged shared-memory buffer with the swizzled layout from __call__.
        sa = storage.sa.get_tensor(
            sA_layout.outer, swizzle=sA_layout.inner
        )

        # All (M, BK) tiles of the first M-row block of A; mode 2 walks the K tiles.
        ga = cute.local_tile(mA_tma, tiler=(M, BK), coord=(0, None))
        # Merge each tile's two modes into one, as tma_partition expects.
        ga_for_tma_partition = cute.group_modes(ga, 0, 2)
        sa_for_tma_partition = cute.group_modes(sa, 0, 2)

        # Single-CTA partition. tAsA is indexed by stage and tAgA by K tile below.
        tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
            atom=tma_atom_a,
            cta_coord=(0,),
            cta_layout=cute.make_layout((1,)),
            smem_tensor=sa_for_tma_partition,
            gmem_tensor=ga_for_tma_partition
        )

        k_tile_cnt = cute.size(ga, mode=[2])

        # Block-wide barrier before splitting into roles (presumably so the
        # pipeline barriers are initialized before anyone waits on them).
        cute.arch.sync_threads()

        if tidx < 128:
            # producer
            producer_state = utils.make_pipeline_state(
                utils.PipelineUserType.Producer, nstages
            )

            # The whole first warpgroup runs the loop, but only warp 0 uses the pipeline.
            for k in cutlass.range_dynamic(k_tile_cnt, unroll=1):
                if tidx < 32:

                    # NOTE(review): the thread advises not to printf right before
                    # producer_acquire / consumer_wait / consumer_release, because
                    # it makes the threads diverge. Print just before advance()
                    # instead.
                    if tidx == 0:
                        cute.printf("waiting for space: %d %d %d", k, producer_state.index, producer_state.count)

                    pipeline.producer_acquire(producer_state)

                    # count keeps increasing and selects the K tile; index wraps
                    # over the nstages ring slots (e.g. count=3 -> index=0 when
                    # nstages=3).
                    tAgA_k = tAgA[(None, producer_state.count)]
                    tAsA_pipe = tAsA[(None, producer_state.index)]

                    if tidx == 0:
                        cute.printf("actually loading: %d %d %d", k, producer_state.index, producer_state.count)

                    # The TMA copy reports completion on this stage's barrier.
                    cute.copy(
                        tma_atom_a,
                        tAgA_k,
                        tAsA_pipe,
                        tma_bar_ptr=pipeline.producer_get_barrier(
                            producer_state
                        )
                    )
                    # NOTE(review): per the thread, producer_commit is a no-op for
                    # PipelineTmaAsync; the arrive is done by the TMA copy itself.
                    pipeline.producer_commit(producer_state)
                    producer_state.advance()

        else:
            # consumer
            # NOTE(review): in the public pipeline.py, consumer_release signals
            # only from threads with threadIdx.x < cute.size(cluster_shape_mnk).
            # No consumer here (tidx >= 128) qualifies, so the empty barriers are
            # never arrived on. The producer then blocks in producer_acquire on
            # its second pass over the ring, which is the reported deadlock.
            # Separate read/release states mirror the Hopper dense_gemm.py
            # example. Here they always advance together, so one state would
            # suffice.
            read_state = utils.make_pipeline_state(
                utils.PipelineUserType.Consumer, nstages
            )
            release_state = utils.make_pipeline_state(
                utils.PipelineUserType.Consumer, nstages
            )

            for k in cutlass.range_dynamic(k_tile_cnt, unroll=1):
                
                if tidx == 128:
                    cute.printf("waiting for data: %d %d %d", k, read_state.index, read_state.count)
                
                pipeline.consumer_wait(read_state)
                
                if tidx == 128:
                    cute.printf("received: %d %d %d", k, read_state.index, read_state.count)
                # The stage is released immediately; its data is never read.
                pipeline.consumer_release(release_state)
                
                if tidx == 128:
                    cute.printf("released: %d %d %d", k, release_state.index, release_state.count)
                read_state.advance()
                release_state.advance()

    @cute.jit
    def __call__(self, mA: cute.Tensor):

        # 128-byte-swizzled, K-major shared-memory layout atom for bf16.
        sw128_k_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
            kind=cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128,
            element_type=cutlass.BFloat16
        )

        # Repeat the atom to cover an (M, BK) tile for each of the nstages stages.
        sA_layout = cute.tile_to_shape(
            sw128_k_atom,
            (M, BK, nstages),
            order=(0, 1, 2)
        )

        # Build the TMA atom from a single stage (stage 0) of the staged layout.
        basic_tma_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
        tma_atom_a, tma_tensor_a = cute.nvgpu.cpasync.make_tma_tile_atom(
            op=basic_tma_op,
            gmem_tensor=mA,
            smem_layout=cute.slice_(sA_layout, (None, None, 0)),
            cta_tiler=(M, BK)
        )

        # Shared memory: nstages * 2 Int64 pipeline barriers, then the staged A
        # buffer aligned to 1024 bytes.
        @cute.struct
        class SharedStorage:
            pipeline_barriers: cute.struct.MemRange[cutlass.Int64, nstages * 2]
            sa: cute.struct.Align[
                cute.struct.MemRange[cutlass.BFloat16, cute.cosize(sA_layout)],
                1024
            ]
        self.shared_storage = SharedStorage

        # One CTA of 256 threads: tidx < 128 takes the producer branch and
        # tidx >= 128 the consumer branch.
        self.kernel(
            tma_atom_a,
            tma_tensor_a,
            sA_layout
        ).launch(
            grid=(1, 1, 1),
            block=(256, 1, 1),
            smem=self.shared_storage.size_in_bytes()
        )

# Host driver: a single (M, K) bf16 matrix, i.e. K // BK = 32 tiles of width
# BK, compiled once and launched.
K = 4096
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
# assumed_align=16: presumably tells the DSL the data pointer is 16-byte
# aligned — confirm against the from_dlpack documentation.
a_tensor = from_dlpack(a, assumed_align=16)

cutlass.cuda.initialize_cuda_context()
func = cute.compile(ProducerConsumerKernel(), a_tensor)
# NOTE(review): compiled with a_tensor but invoked with the torch tensor a —
# confirm the compiled function accepts torch tensors directly.
func(a)
torch.cuda.synchronize()

This program deadlocks. I get the following trace:

waiting for data: 0 0 0
waiting for space: 0 0 0
actually loading: 0 0 0
received: 0 0 0
waiting for space: 1 1 1
released: 0 0 0 <--------- release index 0
actually loading: 1 1 1
waiting for data: 1 1 1
waiting for space: 2 2 2
received: 1 1 1
actually loading: 2 2 2
released: 1 1 1
waiting for space: 3 0 3 <------ the producer never acquires index 0
waiting for data: 2 2 2
received: 2 2 2
released: 2 2 2
waiting for data: 3 0 3 <------- so the consumer never receives any data at index 0
<nothing more>

This is quite confusing to me, because index 0 clearly gets released by the consumer, yet the producer is never able to acquire it again.

I am sure that somehow I am doing this wrong, but I'm not really sure how. Does anyone have any ideas?

Thanks!

axelfeldmann avatar Jun 15 '25 00:06 axelfeldmann

You haven't set the size of the CooperativeGroup; the default value is 1. In your program, the producer_group has a size of 32 (`if tidx < 32`), while the consumer_group has a size of 128 (256 - 128).

In this case, having 32 threads call pipeline.producer_commit(producer_state), or having 128 threads call pipeline.consumer_release(release_state), ultimately leaves the barrier phase the same as if these calls had not been made (or makes it depend on timing, which is unstable). As a result, the second round of acquire/wait gets stuck checking the phase.

keithzzzzz avatar Jun 16 '25 01:06 keithzzzzz

This makes sense. I tried fixing it by doing the following:

        producer_group = utils.CooperativeGroup(utils.Agent.Thread, size=32, alignment=32)
        consumer_group = utils.CooperativeGroup(utils.Agent.Thread, size=128, alignment=128)

(with no other changes).

Unfortunately, this deadlocks in the exact same way. Do you know if I am doing something else wrong? Also, more generally, do you have tips on how to debug these issues?

axelfeldmann avatar Jun 16 '25 02:06 axelfeldmann

I think several parts of the code can be improved. Before analyzing the deadlock, it's better to make the code more precise first.

  • the producer arrive count should be 1, since only one thread issues the TMA copy. The producer_commit of PipelineTmaAsync does nothing; the arrive is performed internally by the TMA copy.
  • cta_layout_vmnk should be ((1,1,1,1))
  • the `tidx` parameter of utils.PipelineTmaAsync.create is missing; in your case, it should be 128
  • please remove the printf calls before producer_acquire, consumer_wait, and consumer_release, because they make the threads diverge. You can add them before advance() to see the info.
  • one consumer_state is enough

Could you refine your codes to see if the deadlock disappears?

Jie-Fang avatar Jun 16 '25 03:06 Jie-Fang

the producer arrive cnt should be 1 since only one thread will do the tma copy. The producer_commit of TmaAsyncPipeline does nothing. And the arrive behavior is done in the tma copy internally.

I'm assuming you mean:

        producer_group = utils.CooperativeGroup(utils.Agent.Thread, size=1, alignment=1)
        consumer_group = utils.CooperativeGroup(utils.Agent.Thread, size=128, alignment=128)

Correct?

cta_layout_vmnk should be ((1,1,1,1))

Done.

there're missing parameter tidx of the utils.PipelineTmaAsync.create

I don't believe this parameter actually exists? I also saw this on the documentation but when I tried it, I get:

TypeError: PipelineTmaAsync.create() got an unexpected keyword argument 'tidx'

I just pip installed a fresh copy ~5 minutes ago, so I think my version is current.

please remove the printf before the producer_acquire, consumer_wait, consumer_release, it will make the threads divergent. You can add it before the advance() to see the info.

I've done this, but can you explain why this matters? Based on my understanding of cuda this should not matter at all. I'm curious what I am missing.

one consumer_state is enough

Done, but again, could you please explain why this matters? In the Hopper dense_gemm.py example, there are separate read and release states.

I tried all of these things, and my code still deadlocks. Here is an updated version of the kernel:

    # Revised producer/consumer kernel, posted after the first review round:
    # CooperativeGroup sizes set explicitly, 4-mode cta_layout_vmnk, printf
    # calls removed, and a single consumer state. Reported to still deadlock.
    # NOTE(review): with the public pipeline.py, the consumer_release
    # signalling thread is chosen by threadIdx.x (tidx < cluster size). The
    # consumers here (tidx >= 128) therefore never arrive on the empty
    # barriers. The maintainer's fix is:
    #   * put the consumers at tidx < 128 and the producer at tidx >= 128;
    #   * use an arrive count of 1 for both groups.
    @cute.kernel
    def kernel(self, tma_atom_a: cute.CopyAtom, mA_tma: cute.Tensor, sA_layout: cute.ComposedLayout):
        tidx, _, _ = cute.arch.thread_idx()

        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)

        pipeline_barriers_ptr = storage.pipeline_barriers.data_ptr()

        # Producer arrive count 1: a single thread issues the TMA copy.
        # NOTE(review): with the public pipeline.py only one consumer thread
        # signals the release, so an arrive count of 128 is presumably never
        # reached. The thread recommends size=1 here as well.
        producer_group = utils.CooperativeGroup(utils.Agent.Thread, size=1, alignment=1)
        consumer_group = utils.CooperativeGroup(utils.Agent.Thread, size=128, alignment=128)

        # Bytes expected per stage: one M x BK tile of 2-byte bf16 elements.
        tma_copy_bytes = M * BK * 2

        pipeline = utils.PipelineTmaAsync.create(
            barrier_storage=pipeline_barriers_ptr,
            num_stages=nstages,
            producer_group=producer_group,
            consumer_group=consumer_group,
            tx_count=tma_copy_bytes,
            cta_layout_vmnk=cute.make_layout((1, 1, 1, 1)),
        )

        sa = storage.sa.get_tensor(
            sA_layout.outer, swizzle=sA_layout.inner
        )

        # All (M, BK) tiles of the first M-row block of A; mode 2 walks the K tiles.
        ga = cute.local_tile(mA_tma, tiler=(M, BK), coord=(0, None))
        ga_for_tma_partition = cute.group_modes(ga, 0, 2)
        sa_for_tma_partition = cute.group_modes(sa, 0, 2)

        tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
            atom=tma_atom_a,
            cta_coord=(0,),
            cta_layout=cute.make_layout((1,)),
            smem_tensor=sa_for_tma_partition,
            gmem_tensor=ga_for_tma_partition
        )

        k_tile_cnt = cute.size(ga, mode=[2])

        cute.arch.sync_threads()

        if tidx < 128:
            # producer
            producer_state = utils.make_pipeline_state(
                utils.PipelineUserType.Producer, nstages
            )

            # The whole first warpgroup runs the loop, but only warp 0 uses the pipeline.
            for k in cutlass.range_dynamic(k_tile_cnt, unroll=1):
                if tidx < 32:
                    pipeline.producer_acquire(producer_state)

                    # count selects the K tile; index selects the ring slot.
                    tAgA_k = tAgA[(None, producer_state.count)]
                    tAsA_pipe = tAsA[(None, producer_state.index)]

                    # producer_commit is omitted: per the thread it is a no-op for
                    # PipelineTmaAsync, and the TMA copy performs the arrive.
                    cute.copy(
                        tma_atom_a,
                        tAgA_k,
                        tAsA_pipe,
                        tma_bar_ptr=pipeline.producer_get_barrier(
                            producer_state
                        )
                    )
                    producer_state.advance()
        else:
            # consumer
            # One state is used for both wait and release.
            read_state = utils.make_pipeline_state(
                utils.PipelineUserType.Consumer, nstages
            )
            for k in cutlass.range_dynamic(k_tile_cnt, unroll=1):
                pipeline.consumer_wait(read_state)
                pipeline.consumer_release(read_state)
                read_state.advance()

axelfeldmann avatar Jun 16 '25 03:06 axelfeldmann

TypeError: PipelineTmaAsync.create() got an unexpected keyword argument 'tidx'

Sorry, I didn't notice that I was looking at the latest internal code. The pip-installed package has no tidx parameter. In consumer_release of PipelineTmaAsync, a single signalling thread issues the consumer release. In the public code, tidx is always threadIdx.x, so `is_signalling_thread = tidx < cute.size(cluster_shape_mnk)` is False for your consumer threads (tidx >= 128). To make your code work, swap the producer warps (so they use tidx >= 128) and the consumer warps (so they use tidx < 128). Also make sure the arrive count is 1:

producer_group = utils.CooperativeGroup(utils.Agent.Thread, size=1, alignment=1)
consumer_group = utils.CooperativeGroup(utils.Agent.Thread, size=1, alignment=1)

but can you explain why this matters?

I just want to eliminate this impact to have a clean code to do analysis.

In the Hopper dense_gemm.py example, there are separate read and release states.

It's used for higher performance. In your case, there's no need for two states.

Jie-Fang avatar Jun 16 '25 05:06 Jie-Fang

Thanks so much! This fixes the problem. Two final questions:

  1. Is this considered a bug or an intended behavior? For example, with the public version, is it possible to have two consumer warp groups? Also, am I correct in thinking that cluster_shape_mnk is only for using CGAs?
  2. One thing I find a bit confusing about python CuTe is the "internal if statements." (another example is here) It's often not clear to me how many threads need to call something and which threads are "gated" inside of the library call. Is there any way to know these conditions? Are they documented anywhere?

axelfeldmann avatar Jun 16 '25 12:06 axelfeldmann

Is this considered a bug or an intended behavior?

Yes, this is a bug. We want to fix it in the future release. Thanks for finding this bug!

with the public version, is it possible to have two consumer warp groups?

It's possible, but remember that only one thread of a warp can issue the consumer release. Extra work is therefore needed to guarantee that both consumer warp groups have completed their work. Alternatively, you can modify pipeline.py so that all the consumer warps can issue the consumer release.

am I correct in thinking that cluster_shape_mnk is only for using CGAs?

yes.

Are they documented anywhere?

Sorry about that. We don't have documentation on these internal if statements. It's natural for users to write `if tidx == 0` when an instruction is issued by a single thread, and without documentation this is really confusing. For now, you can refer to the Blackwell or Hopper examples to see how to use these APIs if you have questions. It seems we should update the docstrings of some APIs.

Jie-Fang avatar Jun 16 '25 13:06 Jie-Fang

Follow up question: suppose that we want to have 2 warpgroups in the consumer group. Is it possible to specify that with a single pipeline object now? More precisely, is it possible to run a loop like this correctly with:

if warp_idx == 8:
  # producer
elif warp_idx  < 8:
  # consumer

I am trying this and I don't know what consumer_group to define. Should it be:

consumer_group = utils.CooperativeGroup(utils.Agent.Thread, 2)

(that hangs forever) or

consumer_group = utils.CooperativeGroup(utils.Agent.Thread, 1)

(I think one of the warpgroups does not wait for the whole MMA transfer properly?)

axelfeldmann avatar Jun 20 '25 15:06 axelfeldmann

It seems your pipeline isn't PipelineTmaAsync, so the corresponding APIs can behave differently. You should take a look at their implementations.

Jie-Fang avatar Jun 21 '25 03:06 Jie-Fang

Sorry, I'm confused. I'm using PipelineTmaAsync from the Python DSL (same as in the original post), and I don't think I can look at the API? It's not well documented and it's not open source. I'm quite confused.

axelfeldmann avatar Jun 21 '25 03:06 axelfeldmann

I'm using PipelineTmaAsync from the python DSL (same as in the original post)

Got it.

I don't think I can look at the API?

I mean you can look at pipeline.py, which should be located in the installation folder of the Python package.

Could you provide your codes so that I can have a look?

Jie-Fang avatar Jun 21 '25 03:06 Jie-Fang

Unfortunately I don't have a small and concise reproducer, I only have my entire gemm code that I am working on.

import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack

# Fix the RNG seed so the random A and B inputs created below are reproducible.
torch.manual_seed(0)

class GemmKernel:
    """C = A @ B for bf16 A (M x K) and B (K x N) with a float32 C (M x N).

    Each CTA computes one BM x BN tile of C.

    * Producer: warp 8. Per K step it TMA-loads one BM x BK tile of A and one
      BN x BK tile of B^T into an ``nstages``-deep shared-memory ring managed
      by PipelineTmaAsync.
    * Consumers: warps 0-7 (two warpgroups). They run a (2, 1, 1) tiled
      warpgroup MMA on each stage.
    * The launch uses 128 + 256 threads (12 warps), so warps 9-11 take neither
      branch and idle.

    NOTE(review): this version is reported to give nondeterministically wrong
    results. Likely causes visible here:

    * ``consumer_release`` is called right after ``commit_group``, while the
      MMAs may still be reading the stage. ``wait_group(0)`` only runs after
      the K loop.
    * With the public pipeline.py only one thread (tidx 0, in the first
      consumer warpgroup) signals the release. Nothing makes that release wait
      for the second warpgroup.

    The thread suggests either waiting on the MMAs before releasing
    (per-warpgroup waits), or patching pipeline.py so that one lane of each
    consumer warp signals, with ``consumer_group`` size 8.
    """
    def __init__(self, BM: int, BN: int, BK: int, nstages: int = 3):
        """Store the CTA tile sizes and the pipeline depth.

        Args:
            BM: rows of A / C per CTA tile.
            BN: columns of B / C per CTA tile.
            BK: K extent of each pipelined A / B^T tile.
            nstages: number of shared-memory pipeline stages.
        """
        self.BM = BM
        self.BN = BN
        self.BK = BK
        self.nstages = nstages
        # Number of C tile-rows per group in the block-id swizzle in kernel().
        self.GROUP_SIZE_M = 8

    @cute.kernel
    def kernel(self,
               tma_atom_a: cute.CopyAtom,
               mA_tma: cute.Tensor,
               tma_atom_bT: cute.CopyAtom,
               mBT_tma: cute.Tensor,
               sA_layout: cute.ComposedLayout,
               sBT_layout: cute.ComposedLayout,
               mC: cute.Tensor,
               mma: cute.TiledMma):
        
        tidx, _, _ = cute.arch.thread_idx()
        pid, _, _ = cute.arch.block_idx()

        warp_idx = cute.arch.warp_idx()
        warp_idx = cute.arch.make_warp_uniform(warp_idx)

        M, N = mC.shape

        # NOTE(review): floor division assumes M % BM == 0 and N % BN == 0.
        # Remainder tiles are not computed (the launch grid divides the same way).
        num_pid_m, num_pid_n = M // self.BM, N // self.BN
        # Grouped block-id ordering: consecutive pids walk down group_size_m
        # tile-rows of C before moving to the next tile-column (presumably for
        # cache reuse of the B tiles).
        num_pid_in_group = self.GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * self.GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, self.GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid %num_pid_in_group) // group_size_m

        # Tile coordinates of this CTA's C tile (row block, column block).
        bidy, bidx = pid_m, pid_n

        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)

        pipeline_barriers_ptr = storage.pipeline_barriers.data_ptr()

        # Default CooperativeGroup size (1). Per the thread, the public
        # consumer_release has a single signalling thread (tidx 0) even with two
        # consumer warpgroups.
        consumer_group = utils.CooperativeGroup(utils.Agent.Thread)
        producer_group = utils.CooperativeGroup(utils.Agent.Thread)

        # Bytes expected per stage: one BM x BK tile of A plus one BN x BK tile
        # of B^T, at 2 bytes per bf16 element.
        tma_copy_bytes = self.BM * self.BK * 2 + self.BN * self.BK * 2

        # Single CTA, no cluster: 4-mode cta_layout_vmnk of all ones.
        pipeline = utils.PipelineTmaAsync.create(
            barrier_storage=pipeline_barriers_ptr,
            num_stages=self.nstages,
            producer_group=producer_group,
            consumer_group=consumer_group,
            tx_count=tma_copy_bytes,
            cta_layout_vmnk=cute.make_layout((1, 1, 1, 1))
        )

        sA = storage.sA.get_tensor(sA_layout.outer, swizzle=sA_layout.inner)
        sBT = storage.sBT.get_tensor(sBT_layout.outer, swizzle=sBT_layout.inner)

        # gA / gBT: this CTA's row block of A and B^T, one tile per K step
        # (mode 2). gC: this CTA's BM x BN output tile.
        gA = cute.local_tile(mA_tma, tiler=(self.BM, self.BK), coord=(bidy, None))
        gBT = cute.local_tile(mBT_tma, tiler=(self.BN, self.BK), coord=(bidx, None))
        gC = cute.local_tile(mC, tiler=(self.BM, self.BN), coord=(bidy, bidx))

        tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
            atom=tma_atom_a,
            cta_coord=(0, 0),
            cta_layout=cute.make_layout((1, 1)),
            smem_tensor=cute.group_modes(sA, 0, 2),
            gmem_tensor=cute.group_modes(gA, 0, 2)
        )
        tBsBT, tBgBT = cute.nvgpu.cpasync.tma_partition(
            atom=tma_atom_bT,
            cta_coord=(0, 0),
            cta_layout=cute.make_layout((1, 1)),
            smem_tensor=cute.group_modes(sBT, 0, 2),
            gmem_tensor=cute.group_modes(gBT, 0, 2)
        )

        k_tile_cnt = cute.size(gA, mode=[2])

        if warp_idx == 8:
            # Producer warp.
            producer_state = utils.make_pipeline_state(
                utils.PipelineUserType.Producer, self.nstages
            )

            for k in cutlass.range_dynamic(k_tile_cnt, unroll=1):

                pipeline.producer_acquire(producer_state)

                # count selects the K tile; index selects the ring slot.
                tAgA_k = tAgA[(None, producer_state.count)]
                tAsA_pipe =  tAsA[(None, producer_state.index)]

                tBgBT_k = tBgBT[(None, producer_state.count)]
                tBsBT_pipe = tBsBT[(None, producer_state.index)]

                # Both copies report on the same stage barrier; tx_count above
                # covers their combined bytes.
                cute.copy(
                    tma_atom_a,
                    tAgA_k,
                    tAsA_pipe,
                    tma_bar_ptr=pipeline.producer_get_barrier(producer_state)
                )
                cute.copy(
                    tma_atom_bT,
                    tBgBT_k,
                    tBsBT_pipe,
                    tma_bar_ptr=pipeline.producer_get_barrier(producer_state)
                )

                # NOTE(review): per the thread, producer_commit is a no-op for
                # PipelineTmaAsync (the TMA copies perform the arrive).
                pipeline.producer_commit(producer_state)
                producer_state.advance()

        elif warp_idx < 8: # consumer
            # tidx spans 0..255 across the two consumer warpgroups.
            thr_mma = mma.get_slice(tidx)
            tCsA = thr_mma.partition_A(sA)
            tCsB = thr_mma.partition_B(sBT)
            tCgC = thr_mma.partition_C(gC)

            # MMA operand views for every stage (last mode indexed by stage
            # below), plus a float32 accumulator shaped like this thread's C slice.
            tCrA = thr_mma.make_fragment_A(tCsA)
            tCrB = thr_mma.make_fragment_B(tCsB)
            tCrC = cute.make_fragment(tCgC.shape, cutlass.Float32)
            # Accumulate into tCrC, which is zero-filled first.
            mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
            tCrC.fill(0.0)

            # NOTE(review): fence() is issued once before the loop. The
            # maintainer's suggested loop issues it every iteration before
            # cute.gemm.
            cute.nvgpu.warpgroup.fence()

            read_state = utils.make_pipeline_state(
                utils.PipelineUserType.Consumer, self.nstages
            )

            for k in cutlass.range_dynamic(k_tile_cnt, unroll=1):
                pipeline.consumer_wait(read_state)

                tCrA_k = tCrA[(None, None, None, read_state.index)]
                tCrBT_k = tCrB[(None, None, None, read_state.index)]

                cute.gemm(mma, tCrC, tCrA_k, tCrBT_k, tCrC)
                cute.nvgpu.warpgroup.commit_group()

                # NOTE(review): the stage is released without a wait_group, so
                # the committed MMAs may still be reading it when the producer
                # refills it. Also, only one consumer thread signals the
                # release, so it does not account for the other warpgroup.
                pipeline.consumer_release(read_state)
                read_state.advance()

            # Wait for all committed MMA groups before reading the accumulators.
            cute.nvgpu.warpgroup.wait_group(0)

            # Store the float32 accumulators from registers straight to global
            # memory, using 32-bit copies.
            C_atom_r2g = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(),
                tCrC.element_type,
                num_bits_per_copy=32
            )
            cute.copy(C_atom_r2g, tCrC, tCgC)

    @cute.jit
    def __call__(self, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):

        M, K = mA.shape
        _, N = mB.shape

        # (N, K) view of B with stride (1, N), i.e. B^T without a copy (assumes
        # B is a contiguous row-major (K, N) tensor — confirm at the call site).
        mBT_layout = cute.make_layout((N, K), stride=(1, N))
        mBT = cute.make_tensor(mB.iterator, layout=mBT_layout)

        # 128-byte-swizzled smem atoms: K-major for A, MN-major for B^T (N is
        # the contiguous dimension, matching b_major_mode=MN below).
        sw128_k_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
            kind=cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128,
            element_type=cutlass.BFloat16
        )
        sw128_mn_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
            kind=cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128,
            element_type=cutlass.BFloat16
        )

        # Staged shared-memory layouts: one tile per pipeline stage in mode 2.
        sA_layout = cute.tile_to_shape(
            sw128_k_atom,
            (self.BM, self.BK, self.nstages),
            order=(0, 1, 2)
        )
        sBT_layout = cute.tile_to_shape(
            sw128_mn_atom,
            (self.BN, self.BK, self.nstages),
            order=(0, 1, 2)
        )

        # TMA atoms are built from a single stage (stage 0) of each layout.
        basic_tma_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
        tma_atom_a, tma_tensor_a = cute.nvgpu.cpasync.make_tma_tile_atom(
            op=basic_tma_op,
            gmem_tensor=mA,
            smem_layout=cute.slice_(sA_layout, (None, None, 0)),
            cta_tiler=(self.BM, self.BK)
        )
        tma_atom_bT, tma_tensor_bT = cute.nvgpu.cpasync.make_tma_tile_atom(
            op=basic_tma_op,
            gmem_tensor=mBT,
            smem_layout=cute.slice_(sBT_layout, (None, None, 0)),
            cta_tiler=(self.BN, self.BK)
        )

        # Shared memory: nstages * 2 Int64 pipeline barriers, then the staged A
        # and B^T buffers, each aligned to 1024 bytes.
        @cute.struct
        class SharedStorage:
            pipeline_barriers: cute.struct.MemRange[cutlass.Int64, self.nstages * 2]
            sA: cute.struct.Align[
                cute.struct.MemRange[cutlass.BFloat16, cute.cosize(sA_layout)],
                1024
            ]
            sBT: cute.struct.Align[
                cute.struct.MemRange[cutlass.BFloat16, cute.cosize(sBT_layout)],
                1024
            ]
        self.shared_storage = SharedStorage

        # bf16 x bf16 -> f32 warpgroup MMA with a 64x128x16 instruction, both
        # operands read from shared memory.
        op = cute.nvgpu.warpgroup.MmaF16BF16Op(
            ab_dtype=cutlass.BFloat16,
            acc_dtype=cutlass.Float32,
            instruction_shape=(64, 128, 16),
            a_src=cute.nvgpu.warpgroup.OperandSource.SMEM,
            a_major_mode=cute.nvgpu.warpgroup.OperandMajorMode.K,
            b_major_mode=cute.nvgpu.warpgroup.OperandMajorMode.MN
        )
        # Two atoms along M (presumably one per consumer warpgroup).
        tC = cute.make_layout((2, 1, 1))
        tiled_mma = cute.make_tiled_mma(op, tC)

        # One CTA per BM x BN tile of C. 256 consumer threads (warps 0-7) plus
        # one producer warpgroup's worth of threads, of which only warp 8 works.
        self.kernel(
            tma_atom_a,
            tma_tensor_a,
            tma_atom_bT,
            tma_tensor_bT,
            sA_layout,
            sBT_layout,
            mC,
            tiled_mma
        ).launch(
            grid=((M // self.BM) * (N // self.BN), 1, 1),
            block=(128 + 256, 1, 1),
            smem=self.shared_storage.size_in_bytes()
        )

# Host driver: a 16384^3 GEMM checked against torch's bf16 matmul.
M, N, K = 16384, 16384, 16384

A = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
B = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
C = torch.zeros(M, N, dtype=torch.float32, device='cuda')

A_tensor = from_dlpack(A, assumed_align=16)
B_tensor = from_dlpack(B, assumed_align=16)
C_tensor = from_dlpack(C, assumed_align=16)

cutlass.cuda.initialize_cuda_context()
gemm = cute.compile(GemmKernel(BM=128, BN=256, BK=64, nstages=3), A_tensor, B_tensor, C_tensor)
# NOTE(review): compiled with the from_dlpack tensors but invoked with the
# torch tensors — confirm the compiled function accepts torch tensors directly.
gemm(A, B, C)

torch.cuda.synchronize()

# NOTE(review): with K = 16384 and randn inputs, entries of A @ B have a
# magnitude of roughly sqrt(K) ~ 128, where one bf16 ulp is 1.0. Rounding
# both sides to bf16 after different accumulation orders can therefore exceed
# an absolute 1e-3 even for a correct kernel. Consider comparing in float32
# against A.float() @ B.float() with a relative tolerance.
err = (C.to(A.dtype) - A @ B).abs()
assert err.max() < 1e-3, f"{err.max() = }"

The code seems to have a synchronization bug/race condition. It fails the assertion, with different numeric values from run to run, which suggests something nondeterministic and synchronization-related.

My code does work if I change my MMA atom to have only 1 warpgroup (and then change the producer to warp_idx == 4 and the consumer to warp_idx < 4).

I think this may come from using the pipeline API incorrectly with 2 consumer warpgroups? However I'm not sure how to use it correctly and fix the bug. Any ideas would be really helpful :)

axelfeldmann avatar Jun 21 '25 03:06 axelfeldmann

As discussed before, the public code only allows one thread (tidx=0) to issue the consumer_release, so

consumer_group = utils.CooperativeGroup(utils.Agent.Thread)

is correct.

The arrive_cnt of consumer_release is 1 even if the number of consumer warpgroup is 2.

But we need to remember there's only one thread of a warp can issue the consumer release, we need extra work to guarantee that the two consumer warp groups complete their work. Or you can hack the pipeline.py to let all threads in the consumer warps can issue the consumer release.

Your code doesn't include such logic. There are two ways to make it work:

  1. each consumer warp group has its own branch
elif warp_idx < 4:
      for k in cutlass.range_dynamic(k_tile_cnt, unroll=1):
                pipeline.consumer_wait(read_state)

                cute.nvgpu.warpgroup.fence()

                tCrA_k = tCrA[(None, None, None, read_state.index)]
                tCrBT_k = tCrB[(None, None, None, read_state.index)]

                cute.gemm(mma, tCrC, tCrA_k, tCrBT_k, tCrC)

                cute.nvgpu.warpgroup.commit_group()
                
                cute.nvgpu.warpgroup.wait_group(0)

                pipeline.consumer_release(read_state)
                read_state.advance()
elif warp_idx >= 4 and warp_idx < 8:
 for k in cutlass.range_dynamic(k_tile_cnt, unroll=1):
        pipeline.consumer_wait(read_state)

                cute.nvgpu.warpgroup.fence()

                tCrA_k = tCrA[(None, None, None, read_state.index)]
                tCrBT_k = tCrB[(None, None, None, read_state.index)]

                cute.gemm(mma, tCrC, tCrA_k, tCrBT_k, tCrC)

                cute.nvgpu.warpgroup.commit_group()
                
                cute.nvgpu.warpgroup.wait_group(1)  <---- ensure another warp group is completed

                pipeline.consumer_release(read_state)
                read_state.advance()
  2. hack pipeline.py

In the init_empty_barrier_arrive_signal function, add one line of code:

tidx = tidx % 32 <--- new line
is_signalling_thread = tidx < cute.size(cluster_shape_mnk)
dst_rank = tidx % cute.size(cluster_shape_mnk)
m = cluster_shape_mnk[0]

And update the consumer_group:

consumer_group = utils.CooperativeGroup(utils.Agent.Thread, 8, 8)

Jie-Fang avatar Jun 21 '25 06:06 Jie-Fang

Thank you so much! Question: does the tidx = tidx %32 work well with the uniform registers?

axelfeldmann avatar Jun 21 '25 18:06 axelfeldmann

tidx isn't stored in uniform registers; it differs across the lanes of a warp.

Jie-Fang avatar Jun 22 '25 00:06 Jie-Fang

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] avatar Jul 22 '25 01:07 github-actions[bot]

@axelfeldmann, Please reopen the issue if you still have further question.

jwu1980 avatar Aug 07 '25 01:08 jwu1980