cutlass icon indicating copy to clipboard operation
cutlass copied to clipboard

[QST] About cast in cuteDSL

Open Dingjifeng opened this issue 1 month ago • 0 comments

I want to implement a GemmSM90 with A bf16 & B fp32. i want to cast B from fp32 to bf16 in smem as steps: 1. load fp32 smem to register 2. do cast using .load().to(cutlass.BFloat16) 3. store bf16 register back to bf16 smem

 wh_pipeline.consumer_wait(
                wh_mainloop_consumer_read_state, peek_wh_full_status
            )
                
            cute.printf("sH layout: {}", sH.layout)
            cute.printf("sH0 layout: {}", sH0.layout)
            # convert h0 to bf16
            
            tidx = warp_group_thread_layout(warp_group_idx)
            thr_copy_h0_s2r = tiled_copy_h0_s2r.get_slice(tidx)
            tXsH0 = thr_copy_h0_s2r.partition_S(sH0)
            tXrH0 = cute.make_fragment_like(tXsH0)
            cute.autovec_copy(tXsH0, tXrH0)
            print(f"tXsH0: {tXsH0}")
            print(f"tXrH0: {tXrH0}")

            thr_copy_h_r2s = tiled_copy_h_r2s.get_slice(tidx)
            tXsH = thr_copy_h_r2s.partition_D(sH)
            tXrH = cute.make_fragment_like(tXsH)
            print(f"tXsH: {tXsH}")
            print(f"tXrH: {tXrH}")
            
            tXrH.store(tXrH0.load().to(cutlass.BFloat16))
        
            cute.autovec_copy(tXrH, tXsH)
            
            cute.arch.sync_threads()
            
            cute.printf("sH: {}", sH)
           
            # why need fence? make sure acc visible to wgmma?
            cute.nvgpu.warpgroup.fence()
            
            for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
                k_block_coord = (
                    None,
                    None,
                    k_block_idx,
                    wh_mainloop_consumer_read_state.index,
                )
                tCrW_1phase = tCrW[k_block_coord]
                tCrH_1phase = tCrH[k_block_coord]
                
                cute.gemm(
                    tiled_mma_wh,
                    acc,
                    tCrW_1phase,
                    tCrH_1phase,
                    acc,
                )

but it doesn't to seem works for me.

Code:


import argparse
import math
import torch
import triton
from typing import Tuple, Type, Callable, Optional, Union, Literal

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.cute.testing as testing
import cutlass.pipeline as pipeline

from cutlass.pipeline.helpers import MbarrierArray

import cutlass.torch as cutlass_torch
import cutlass.utils.hopper_helpers as sm90_utils
from cutlass.cute.runtime import from_dlpack
from cutlass import Int32, Boolean, const_expr
from cutlass.cute.nvgpu import cpasync, warp, warpgroup

from typing import Tuple, Type


class ChunkGatedDeltaFwdHKernel:
    def __init__(
        self,
        tile_shape_mn: Tuple[int, int],
        cluster_shape_mnk: Tuple[int, int, int],
        chunk_size: int
    ):
        self.buffer_align_bytes = 1024
        
        # K is deferred in _setup_attributes
        self.cta_tile_shape_mnk = (*tile_shape_mn, -1)
        
        self.cluster_shape_mnk = cluster_shape_mnk
        
        # For large tile size, using two warp groups is preferred because using only one warp
        # group may result in register spill
        
        # ??? How to tune the settings?
        self.atom_layout_mnk = (
            (2, 1, 1)
            if self.cta_tile_shape_mnk[0] > 64 and self.cta_tile_shape_mnk[1] > 128
            else (1, 1, 1)
        )
        
        
        self.mma_warp_groups = math.prod(self.atom_layout_mnk)
        self.num_threads_per_warp_group = 128
        self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group
        print(f"threads_per_cta: {self.threads_per_cta}")
        
        self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
        # self.num_ab_load_warps = 1
        # self.ab_load_warp_id = self.mma_warp_groups * 4
        # print(f"ab_load_warp_id: {self.ab_load_warp_id}")
        
        # ? what is the acc dtype
        self.acc_dtype = cute.Float32
        self.chunk_size = chunk_size
        
    
    @cute.jit
    def __call__(
        self,
        w: cute.Tensor,
        k: cute.Tensor,
        u: cute.Tensor,
        h0: cute.Tensor,
        h: cute.Tensor,
        o: cute.Tensor,
        cu_seqlens: cute.Tensor | None = None,
        store_final_state: bool = False,
        save_new_value: bool = False
    ):
        self.k_dtype = k.element_type
        self.w_dtype = w.element_type
        self.u_dtype = u.element_type
        self.o_dtype = o.element_type
        self.h0_dtype = h0.element_type if h0 is not None else cute.Float32 # ?? still have question here
        
        # h dtype default to bf16
        self.h_dtype = cutlass.BFloat16
        
        # get row/col major
        self.k_layout = utils.LayoutEnum.from_tensor(k)
        self.w_layout = utils.LayoutEnum.from_tensor(w)
        self.u_layout = utils.LayoutEnum.from_tensor(u)
        self.h0_layout = utils.LayoutEnum.from_tensor(h0) if h0 is not None else None
        
        # h default to col-major?
        self.h_layout = self.h0_layout if h0 is not None else utils.LayoutEnum.COL_MAJOR
        
        self.o_layout = utils.LayoutEnum.from_tensor(o) if o is not None else None
        
        
        # calculate stage ?
        self.wh_stage = 2
        
        # ?? why we need to assume the stride??
        # Assume all strides are divisible by 128 bits except the last stride
        # new_stride = lambda t: tuple(
        #     cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
        #     for s in t.stride
        # )
        
        # w = [
        #     cute.make_tensor(w.iterator, cute.make_layout(w.shape, stride=new_stride(w)))
        #     if w is not None
        #     else None
        # ]
        
        # print(f"new w: {w}")
        
        self._setup_attributes()
        
        # non-staged smem layout
        w_smem_layout = cute.slice_(self.w_smem_layout_staged, (None, None, 0))
        h0_smem_layout = cute.slice_(self.h0_smem_layout_staged, (None, None, 0))
        
        print(w_smem_layout)
        print(h0_smem_layout)
        
        tma_atom_w, tma_tensor_w = self._make_tma_atoms_and_tensors(
            w,
            w_smem_layout,
            (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
        )
        
        tma_atom_h0, tma_tensor_h0 = self._make_tma_atoms_and_tensors(
            h0,
            h0_smem_layout,
            (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2])
        )
        
        cute.printf(f"original_w: {w}")
        cute.printf(f"original_h: {h0}")
        cute.printf(f"tma_atom_w: {tma_atom_w} \ntma_tensor_w: {tma_tensor_w}")
        cute.printf(f"tma_atom_h: {tma_atom_h0} \ntma_tensor_h: {tma_tensor_h0}")
        
        self.num_tma_load_bytes = cute.size_in_bytes(self.w_dtype, w_smem_layout)
        
        print(f"load_bytes: {self.num_tma_load_bytes}")
        # why we need this?
        if const_expr(h0 is not None):
            self.num_tma_load_bytes += cute.size_in_bytes(self.h0_dtype, h0_smem_layout)

        if const_expr(self.h0_layout is not None):
            self.tiled_copy_h0_s2r = self._make_tiled_copy(self.h0_layout,self.h0_dtype,False,128) 
            self.tiled_copy_h_r2s = self._make_tiled_copy(self.h_layout,self.h_dtype,False,64)
        
        print(f"load_bytes: {self.num_tma_load_bytes}")
        
        # compute the 
        grid = self._compute_grid(w,h0,o)
        print(f"grid: {grid}")
        
        h0_smem_size = (
            cute.cosize(self.h0_smem_layout_staged) if h0 is not None else 0
        )
        
         
        @cute.struct
        class SharedStorage:
            wh_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.wh_stage * 2]
            tile_count: cute.struct.MemRange[cutlass.Int32, self.wh_stage]
            sH0: cute.struct.Align[
                cute.struct.MemRange[self.h0_dtype if self.h0_dtype is not None else cutlass.Int32, h0_smem_size],
                self.buffer_align_bytes,
            ]
            sW: cute.struct.Align[
                cute.struct.MemRange[self.w_dtype, cute.cosize(self.w_smem_layout_staged)],
                self.buffer_align_bytes,
            ]
            sH: cute.struct.Align[
                cute.struct.MemRange[self.h_dtype, cute.cosize(self.h_smem_layout_staged)],
                self.buffer_align_bytes
            ]
            
        self.shared_storage = SharedStorage

        self.kernel(
            self.tiled_mma_wh,
            tma_atom_w,
            tma_tensor_w,
            self.w_smem_layout_staged,
            tma_atom_h0,
            tma_tensor_h0,
            self.tiled_copy_h0_s2r,
            self.tiled_copy_h_r2s,
            self.h0_smem_layout_staged,
            self.h_smem_layout_staged,
            o,
            self.cluster_layout_mnk
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=self.cluster_shape_mnk,
        )

        return
        
        
    def _setup_attributes(self):
        if self.cta_tile_shape_mnk[0] not in [64, 128]:
            raise ValueError("CTA tile shape M must be 64/128")
        if self.cta_tile_shape_mnk[1] not in [64, 128, 256]:
            raise ValueError("CTA tile shape N must be 64/128/256")
        
        self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
        self.tiled_mma_wh = sm90_utils.make_trivial_tiled_mma(
            self.w_dtype,
            self.h_dtype,
            self.w_layout.sm90_mma_major_mode(),
            self.h_layout.sm90_mma_major_mode(),
            self.acc_dtype,
            self.atom_layout_mnk,
            tiler_mn=(64, self.cta_tile_shape_mnk[1]), # maybe we should use chunk_size instead?
        )
        
        print(f"tiled_mma_wh: {self.tiled_mma_wh}")
        
        mma_inst_shape_k = cute.size(self.tiled_mma_wh.shape_mnk, mode=[2])
        mma_inst_tile_k = 1
        self.cta_tile_shape_mnk = (
            self.cta_tile_shape_mnk[0],
            self.cta_tile_shape_mnk[1],
            mma_inst_shape_k * mma_inst_tile_k,
        )
        
        print(f"cta_tile_shape: {self.cta_tile_shape_mnk}")
        
        (
            self.w_smem_layout_staged,
            self.h0_smem_layout_staged,
            self.h_smem_layout_staged
        ) = self._make_smem_layouts(
            self.cta_tile_shape_mnk,
            self.w_dtype,
            self.w_layout,
            self.h0_dtype,
            self.h0_layout,
            self.h_dtype,
            self.h_layout,
            self.wh_stage
        )
        
    
    def _make_smem_layouts(
        self,
        cta_tile_shape_mnk: tuple[int, int, int],
        w_dtype: type[cutlass.Numeric],
        w_layout: utils.LayoutEnum,
        h0_dtype: type[cutlass.Numeric],
        h0_layout: utils.LayoutEnum,
        h_dtype: type[cutlass.Numeric],
        h_layout: utils.LayoutEnum,
        ab_stage: int = 1,
    ):
        # not consider fused smem first!
        
        w_is_k_major = w_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
        h0_is_k_major = h0_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
        h_is_k_major = h_layout.sm90_mma_major_mode == warpgroup.OperandMajorMode.K
        
        print(f"w_is_k_major: {w_is_k_major}")
        print(f"h0_is_k_major: {h0_is_k_major}")
        print(f"h_is_k_major: {h_is_k_major}")
        
        w_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
        
        w_major_mode_size = cta_tile_shape_mnk[2 if w_is_k_major else 0]
        print(f"w_major_mode_size: {w_major_mode_size}")
        
        smem_layout_atom_w = sm90_utils.get_smem_layout_atom(w_layout, w_dtype, w_major_mode_size)
        print(f"smem_layout_atom_w {smem_layout_atom_w}")
        
        w_smem_layout_atom = warpgroup.make_smem_layout_atom(
            smem_layout_atom_w,
            w_dtype,
        )
        print(f"w_smem_layout_atom: {w_smem_layout_atom}")
        
        w_smem_layout_staged = cute.tile_to_shape(
            w_smem_layout_atom,
            cute.append(w_smem_shape, ab_stage),
            order=(0, 1, 2) if w_is_k_major else (1, 0, 2),
        )
        
        print(f"w_smem_layout_staged: {w_smem_layout_staged}")
        
        h0_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
        h0_major_mode_size = cta_tile_shape_mnk[2 if h0_is_k_major else 1]
        smem_layout_atom_h0 = sm90_utils.get_smem_layout_atom(h0_layout, h0_dtype, h0_major_mode_size)
        
        print(f"smem_layout_atom_h0: {smem_layout_atom_h0}")
        h0_smem_layout_atom = warpgroup.make_smem_layout_atom(
            smem_layout_atom_h0,
            h0_dtype,
        )
        print(f"h0_smem_layout_atom: {h0_smem_layout_atom}")
        
        h0_smem_layout_staged = cute.tile_to_shape(
            h0_smem_layout_atom,
            cute.append(h0_smem_shape, ab_stage),
            order=(0, 1, 2) if h0_is_k_major else (1, 0, 2),
        )

        print(f"h0_smem_layout_staged: {h0_smem_layout_staged}")
        
        h_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
        h_major_mode_size = cta_tile_shape_mnk[2 if h_is_k_major else 1]
        smem_layout_atom_h = sm90_utils.get_smem_layout_atom(h_layout, h_dtype, h_major_mode_size)
        print(f"smem_layout_atom_h: {smem_layout_atom_h}")
        
        h_smem_layout_atom = warpgroup.make_smem_layout_atom(
            smem_layout_atom_h,
            h_dtype,
        )
        print(f"h_smem_layout_atom: {h_smem_layout_atom}")
        
        h_smem_layout_staged = cute.tile_to_shape(
            h_smem_layout_atom,
            cute.append(h_smem_shape, ab_stage),
            order=(0, 1, 2) if h_is_k_major else (1, 0, 2),
        )
        print(f"h_smem_layout_staged: {h_smem_layout_staged}")
        
        
        # v_new_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
        # # default to n
        # v_new_major_mode_size = cta_tile_shape_mnk[1]
        
        # # smem_layout_atom_v_new = sm90_utils.get_smem_layout_atom()
        # print(f"v_new_smem_shape: {v_new_smem_shape}")
    
        return (
            w_smem_layout_staged,
            h0_smem_layout_staged,
            h_smem_layout_staged,
        )
    
    def _make_tma_atoms_and_tensors(
        self,
        tensor: cute.Tensor,
        smem_layout: cute.ComposedLayout,
        smem_tile: Tuple[int, int],
        mcast_dim: int = -1,
    ) -> Tuple[cute.CopyAtom, cute.Tensor]:
        
        # do not support mcast_dim
        op = (
            cpasync.CopyBulkTensorTileG2SOp()
        )
        
        tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
            op,
            tensor,
            smem_layout,
            smem_tile
        )
        
        return tma_atom, tma_tensor

    def _compute_grid(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mD: cute.Tensor,
    ) -> tuple[int, int, int]:
        assert len(mA.shape) == 4 and len(mB.shape) == 4 and len(mD.shape) == 4
         
        num_problems = mD.shape[2] * mD.shape[3]

        self.B = mD.shape[2]
        self.H = mD.shape[3]
        
        # BT % tile_shape_mnk[0] and only the T dimension is serial
        assert self.chunk_size % self.cta_tile_shape_mnk[0] == 0 
        
        problem_shape_ntile_mnl = (
            cute.ceil_div(self.chunk_size, self.cta_tile_shape_mnk[0]),
            cute.ceil_div(mB.shape[1], self.cta_tile_shape_mnk[1]),
            num_problems
        )
        
        
        return problem_shape_ntile_mnl
    
    def _make_ab_pipeline(
        self,
        tiled_mma: cute.TiledMma,
        cluster_layout_vmnk: cute.Layout,
        ab_pipeline_mbar_ptr: cute.Pointer
    ):
        producer_cnt = 1
        
        ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)

        mcast_size = 1
        consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
        
        ab_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, consumer_arrive_cnt
        )
        
        mainloop_pipeline = pipeline.PipelineTmaAsync.create(
            barrier_storage=ab_pipeline_mbar_ptr,
            num_stages=self.wh_stage,
            producer_group=ab_pipeline_producer_group,
            consumer_group=ab_pipeline_consumer_group,
            tx_count=self.num_tma_load_bytes,
            cta_layout_vmnk=cluster_layout_vmnk,
        )
        
        return mainloop_pipeline
        
        
    @cute.kernel
    def kernel(
        self,
        tiled_mma_wh: cute.TiledMma,
        tma_atom_w: cute.CopyAtom,
        mW_mkbh: cute.Tensor,
        w_smem_layout_staged: cute.ComposedLayout,
        tma_atom_h0: Optional[cute.CopyAtom],
        mH0_nkbh: Optional[cute.Tensor],
        tiled_copy_h0_s2r: Optional[cute.TiledCopy],
        tiled_copy_h_r2s: Optional[cute.TiledCopy],
        h0_smem_layout_staged: cute.ComposedLayout,
        h_smem_layout_staged: cute.ComposedLayout,
        mO_mnbh: cute.Tensor, # for now temporary used for acc check
        cta_layout_mnk: cute.Layout,
    ):
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        
        # prefetch Tma desc
        if warp_idx == 0:
            for tma_atom_desc in (tma_atom_w, tma_atom_h0):
                if const_expr(tma_atom_desc is not None):
                    cpasync.prefetch_descriptor(tma_atom_desc)
        
        bidx, bidy, bidz = cute.arch.block_idx()
        tidx, _, _ = cute.arch.thread_idx()
        
        b = bidz // self.H
        h = bidz % H
        
        tile_coord_mnkbh = (bidx, bidy, None, b, h)
        
        cta_rank_in_cluster = cute.arch.make_warp_uniform(
            cute.arch.block_idx_in_cluster()
        )

        cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
        
        print(f'cta_layout_mnk: {cta_layout_mnk}')
        print(f"cluster_coord_mnk: {cluster_coord_mnk}")
        
        # Alloc and init
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)
 
        
        wh_pipeline = self._make_ab_pipeline(
            tiled_mma=tiled_mma_wh,
            cluster_layout_vmnk=cute.make_layout((1, *cta_layout_mnk.shape)),
            ab_pipeline_mbar_ptr=storage.wh_pipeline_array_ptr.data_ptr(),
        )
        
        if warp_idx == 0:
            full : MbarrierArray = wh_pipeline.sync_object_full
            empty : MbarrierArray = wh_pipeline.sync_object_empty
            print(f"full: {full.arrive_count}")
            print(f"full.tx_count: {full.tx_count}")
            print(f"empty.arrive_count: {empty.arrive_count}")
            print(f"empty.tx_count: {empty.tx_count}")
            
        sW = storage.sW.get_tensor(
            w_smem_layout_staged.outer, swizzle=w_smem_layout_staged.inner
        )
        
        sH0 = storage.sH0.get_tensor(
            h0_smem_layout_staged.outer, swizzle=h0_smem_layout_staged.inner
        )
        
        sH = storage.sH.get_tensor(
            h_smem_layout_staged.outer, swizzle=h_smem_layout_staged.inner
        )
    
        gW_mk = cute.local_tile(
            mW_mkbh, self.cta_tile_shape_mnk, tile_coord_mnkbh, proj=(1, None, 1)
        )
        cute.printf("gW_mk: {}", gW_mk)
        
        if const_expr(mH0_nkbh is not None):
            print(mH0_nkbh)
            gH0_nk = cute.local_tile(
                mH0_nkbh, self.cta_tile_shape_mnk, tile_coord_mnkbh, proj=(None, 1, 1)
            )
            print(gH0_nk)
            
        gO_mn = cute.local_tile(
            mO_mnbh, self.cta_tile_shape_mnk, tile_coord_mnkbh, proj=(1, 1, None)
        )

        # partition global tensor for TiledMMA
        warp_group_idx = cute.arch.make_warp_uniform(
            tidx // self.num_threads_per_warp_group
        )
        
        warp_group_thread_layout = cute.make_layout(
            self.mma_warp_groups, stride=self.num_threads_per_warp_group
        )
        
        
        # make fragments 
        thr_mma_wh = tiled_mma_wh.get_slice(warp_group_thread_layout(warp_group_idx))
        tCsW = thr_mma_wh.partition_A(sW)
        tCsH = thr_mma_wh.partition_B(sH)
        
        tCgO = thr_mma_wh.partition_C(gO_mn)
        
        # mma, mma_m, mma_k, pipe
        tCrW = tiled_mma_wh.make_fragment_A(tCsW)
        tCrH = tiled_mma_wh.make_fragment_B(tCsH)
        
        print(f"tCrH: {tCrH}")
        print(f"tCrW: {tCrW}")
        acc_shape = tiled_mma_wh.partition_shape_C(
            cute.select(self.cta_tile_shape_mnk, mode=[0,1])
        )
        acc = cute.make_fragment(acc_shape, self.acc_dtype)
        
        w_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
        w_cta_crd = cluster_coord_mnk[1]
        # shared memory partition needs to group from [0, 2)
        sw_for_tma_partition = cute.group_modes(sW, 0, 2)
        gW_for_tma_partition = cute.group_modes(gW_mk, 0, 2)
        tAsW, tAgW_mk = cute.nvgpu.cpasync.tma_partition(
            tma_atom_w,
            w_cta_crd,
            w_cta_layout,
            sw_for_tma_partition,
            gW_for_tma_partition
        )
        
        if const_expr(mH0_nkbh is not None):
            h0_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
            h0_cta_crd = cluster_coord_mnk[0] 
            sH0_for_tma_partition = cute.group_modes(sH0, 0, 2)
            gH0_for_tma_partition = cute.group_modes(gH0_nk, 0, 2)
            tBsH0, tBgH0_nk = cute.nvgpu.cpasync.tma_partition(
                tma_atom_h0,
                h0_cta_crd,
                h0_cta_layout,
                sH0_for_tma_partition,
                gH0_for_tma_partition
            )
            
        if cute.size(self.cluster_shape_mnk) > 1:
            # NOTE: wait for all CTAs in the cluster
            cute.arch.cluster_wait()
        else:
            # NOTE: only wait for the current CTA
            cute.arch.sync_threads()
        
        k_tile_cnt = cute.size(gW_mk, mode=[2])
        prefetch_k_tile_cnt = cutlass.max(cutlass.min(self.wh_stage, k_tile_cnt), 0)
        
        wh_mainloop_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.wh_stage
        )
        
        
        if warp_idx == 0:
            # prefetch TMA load
            for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1):
                wh_pipeline.producer_acquire(wh_mainloop_producer_state)
                
                # count is used for gmem / index is used for smem 
                tAgW_k = tAgW_mk[(None, wh_mainloop_producer_state.count)]
                tAsW_pipe = tAsW[(None, wh_mainloop_producer_state.index)]
                
                tBgH0_k = tBgH0_nk[(None, wh_mainloop_producer_state.count)]
                tBsH0_pipe = tBsH0[(None, wh_mainloop_producer_state.index)]
                cute.copy(
                    tma_atom_w,
                    tAgW_k,
                    tAsW_pipe,
                    tma_bar_ptr=wh_pipeline.producer_get_barrier(
                        wh_mainloop_producer_state
                    ),
                    mcast_mask=0,
                )
                cute.copy(
                    tma_atom_h0,
                    tBgH0_k,
                    tBsH0_pipe,
                    tma_bar_ptr=wh_pipeline.producer_get_barrier(
                        wh_mainloop_producer_state
                    ),
                    mcast_mask=0,
                )
                
                wh_pipeline.producer_commit(wh_mainloop_producer_state)
                wh_mainloop_producer_state.advance()
                
                
        # Prologue MMAs
        k_pipe_mmas = 1

        wh_mainloop_consumer_read_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.wh_stage
        )
        
        wh_mainloop_consumer_release_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.wh_stage
        )
        
        peek_wh_full_status = cutlass.Boolean(1)
        if wh_mainloop_consumer_read_state.count < k_tile_cnt:
            peek_wh_full_status = wh_pipeline.consumer_try_wait(
                wh_mainloop_consumer_read_state
            )
        
        tiled_mma_wh.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
        num_k_blocks = cute.size(tCrW, mode=[2])
        
        for k_tile in cutlass.range_constexpr(k_pipe_mmas):
            wh_pipeline.consumer_wait(
                wh_mainloop_consumer_read_state, peek_wh_full_status
            )
                
            cute.printf("sH layout: {}", sH.layout)
            cute.printf("sH0 layout: {}", sH0.layout)
            # convert h0 to bf16
            
            tidx = warp_group_thread_layout(warp_group_idx)
            thr_copy_h0_s2r = tiled_copy_h0_s2r.get_slice(tidx)
            tXsH0 = thr_copy_h0_s2r.partition_S(sH0)
            tXrH0 = cute.make_fragment_like(tXsH0)
            cute.autovec_copy(tXsH0, tXrH0)
            print(f"tXsH0: {tXsH0}")
            print(f"tXrH0: {tXrH0}")

            thr_copy_h_r2s = tiled_copy_h_r2s.get_slice(tidx)
            tXsH = thr_copy_h_r2s.partition_D(sH)
            tXrH = cute.make_fragment_like(tXsH)
            print(f"tXsH: {tXsH}")
            print(f"tXrH: {tXrH}")
            
            tXrH.store(tXrH0.load().to(cutlass.BFloat16))
        
            cute.autovec_copy(tXrH, tXsH)
            
            cute.arch.sync_threads()
            
            cute.printf("sH: {}", sH)
           
            # why need fence? make sure acc visible to wgmma?
            cute.nvgpu.warpgroup.fence()
            
            for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
                k_block_coord = (
                    None,
                    None,
                    k_block_idx,
                    wh_mainloop_consumer_read_state.index,
                )
                tCrW_1phase = tCrW[k_block_coord]
                tCrH_1phase = tCrH[k_block_coord]
                
                cute.gemm(
                    tiled_mma_wh,
                    acc,
                    tCrW_1phase,
                    tCrH_1phase,
                    acc,
                )
                
                tiled_mma_wh.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
                cute.nvgpu.warpgroup.commit_group()
                wh_mainloop_consumer_read_state.advance()
                peek_wh_full_status = cutlass.Boolean(1)
                if wh_mainloop_consumer_read_state.count < k_tile_cnt:
                    peek_wh_full_status = wh_pipeline.consumer_try_wait(
                        wh_mainloop_consumer_read_state
                    )
                cute.printf("acc0 {}: {}", k_block_idx, acc)
                cute.nvgpu.warpgroup.wait_group(0)
                
        
        cute.printf("k_tile_cnt: {}",k_tile_cnt)
        # mainloop
        for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
            
            # wait for TMA copies to complete
            wh_pipeline.consumer_wait(wh_mainloop_consumer_read_state, peek_wh_full_status)
            
            cute.nvgpu.warpgroup.fence()
            for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
                k_block_coord = (
                    None,
                    None,
                    k_block_idx,
                    wh_mainloop_consumer_read_state.index
                )
                tCrW_1phase = tCrW[k_block_coord]
                tCrH_1phase = tCrH[k_block_coord]
                
                cute.gemm(
                    tiled_mma_wh,
                    acc,
                    tCrW_1phase,
                    tCrH_1phase,
                    acc,
                )
                
                cute.nvgpu.warpgroup.commit_group()
                # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
                cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
                
                wh_pipeline.consumer_release(wh_mainloop_consumer_release_state)
                
                cute.printf("ktile: {}", k_tile)
                cute.printf("acc100{}: {}", k_tile, acc)

                wh_mainloop_consumer_read_state.advance()
                wh_mainloop_consumer_release_state.advance()

                peek_wh_full_status = cutlass.Boolean(1)
                if wh_mainloop_consumer_read_state.count < k_tile_cnt:
                    peek_wh_full_status = wh_pipeline.consumer_try_wait(
                        wh_mainloop_consumer_read_state
                    )
                    
                if warp_idx == 0 and wh_mainloop_producer_state.count < k_tile_cnt:
                    wh_pipeline.producer_acquire(wh_mainloop_producer_state)
                    
                    tAgW_k = tAgW_mk[(None, wh_mainloop_producer_state.count)]
                    tAsW_pipe = tAsW[(None, wh_mainloop_producer_state.index)]
                    
                    tBgH0_k = tBgH0_nk[(None, wh_mainloop_producer_state.count)]
                    tBsH0_pipe = tBsH0[(None, wh_mainloop_producer_state.index)]
                    cute.copy(
                        tma_atom_w,
                        tAgW_k,
                        tAsW_pipe,
                        tma_bar_ptr=wh_pipeline.producer_get_barrier(
                            wh_mainloop_producer_state
                        ),
                        mcast_mask=0,
                    )
                    cute.copy(
                        tma_atom_h0,
                        tBgH0_k,
                        tBsH0_pipe,
                        tma_bar_ptr=wh_pipeline.producer_get_barrier(
                            wh_mainloop_producer_state
                        ),
                        mcast_mask=0,
                    )
                    
                    wh_pipeline.producer_commit(wh_mainloop_producer_state)
                    wh_mainloop_producer_state.advance()
        
        cute.nvgpu.warpgroup.wait_group(0)
        cute.printf("acc final: {}", acc)
        cute.printf("tCgO: {}", tCgO)
        cute.autovec_copy(acc, tCgO)
        # epi      
        cute.printf("acc: {}", acc)      
        
    def _make_tiled_copy(self, layout : utils.LayoutEnum, dtype: cute.Numeric, is_async: bool = False, num_copy_bits=128):
        # only support copy H0 which is col-major
        assert layout.is_m_major_c()
        assert is_async == False
        assert dtype in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]
        copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
        copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
        
        # value layout:
        elems = num_copy_bits // dtype.width
        val_layout = cute.make_layout((elems, 1), stride=(1,0))
          
        # thread layout: 
        # threads_per_rows = const_expr(major_mode_size // elems)
        threads_per_col = const_expr(self.cta_tile_shape_mnk[1] // elems)
        assert self.cta_tile_shape_mnk[1] == 64
        assert self.threads_per_cta == 128
        rows_per_block = self.threads_per_cta // threads_per_col
        thr_layout = cute.make_layout((threads_per_col, rows_per_block), stride=(1, threads_per_col))
        
        tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
        print("construct:")
        print(tiled_copy)
        return tiled_copy

        
def run(w, k, u, h0, h, o, tile_shape_mn, cluster_shape_mnk, chunk_size):
    kernel = ChunkGatedDeltaFwdHKernel(tile_shape_mn, cluster_shape_mnk, chunk_size) 
    w = from_dlpack(w, assumed_align=16)
    k = from_dlpack(k, assumed_align=16)
    u = from_dlpack(u, assumed_align=16)
    h0 = from_dlpack(h0, assumed_align=16)
    h = from_dlpack(h, assumed_align=16)
    o = from_dlpack(o, assumed_align=16)
    kernel(w, k, u, h0, h, o)

B, T, H, K, V = 1, 256, 1, 64, 64
chunk_size = 64
BT = chunk_size
dtype = torch.bfloat16
acc_dtype = torch.float32
device = 'cuda'
output_final_state = True

N, NT, chunk_offsets = B, triton.cdiv(T, BT), None


w = torch.randn((B, T, H, K), dtype=dtype, device=device)
w = w.new_ones((B, T, H, K), dtype=dtype)
k = torch.randn((B, T, H, K), dtype=dtype, device=device)
u = torch.randn((B, T, H, V), dtype=dtype, device=device)
h0 = torch.ones((B, H, K, V), dtype=torch.float32, device=device)
h = k.new_empty(B, NT, H, K, V)

o = torch.zeros((B, chunk_size, H, V), dtype=acc_dtype, device=device)
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
tile_shape_mn = (chunk_size, 64)
cluster_shape_mnk = (1, 1, 1)

assert chunk_size == tile_shape_mn[0]

# need to permute first ??
w = w.permute(1, 3, 0, 2) # T, K, B, H
k = k.permute(1, 3, 0, 2) # ???
h0 = h0.permute(3, 2, 0, 1) #V, K, B, H
u = u.permute(2, 3, 0, 1) #T, V, B, H
o = o.permute(1, 3, 0, 2) #chunk_size, V, B, H


# by cute convention default is mk, nk, mn
run(w, k, u, h0, h, o, tile_shape_mn, cluster_shape_mnk, chunk_size)

print(f"------")

print(o)

print(f"------")

o_ref = torch.einsum('tkbh,vkbh->tvbh', w[:chunk_size].to(acc_dtype), h0.to(acc_dtype))

print(o_ref)

Dingjifeng avatar Dec 04 '25 13:12 Dingjifeng