[QST] About cast in cuteDSL
I want to implement a GemmSM90 where A is bf16 and B is fp32. My plan is to cast B from fp32 to bf16 through smem in three steps: 1. load the fp32 smem tile into registers, 2. cast in registers using .load().to(cutlass.BFloat16), 3. store the bf16 registers back to a bf16 smem tile. The relevant consumer-side code:
wh_pipeline.consumer_wait(
wh_mainloop_consumer_read_state, peek_wh_full_status
)
cute.printf("sH layout: {}", sH.layout)
cute.printf("sH0 layout: {}", sH0.layout)
# convert h0 to bf16
tidx = warp_group_thread_layout(warp_group_idx)
thr_copy_h0_s2r = tiled_copy_h0_s2r.get_slice(tidx)
tXsH0 = thr_copy_h0_s2r.partition_S(sH0)
tXrH0 = cute.make_fragment_like(tXsH0)
cute.autovec_copy(tXsH0, tXrH0)
print(f"tXsH0: {tXsH0}")
print(f"tXrH0: {tXrH0}")
thr_copy_h_r2s = tiled_copy_h_r2s.get_slice(tidx)
tXsH = thr_copy_h_r2s.partition_D(sH)
tXrH = cute.make_fragment_like(tXsH)
print(f"tXsH: {tXsH}")
print(f"tXrH: {tXrH}")
tXrH.store(tXrH0.load().to(cutlass.BFloat16))
cute.autovec_copy(tXrH, tXsH)
cute.arch.sync_threads()
cute.printf("sH: {}", sH)
# why need fence? make sure acc visible to wgmma?
cute.nvgpu.warpgroup.fence()
for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_block_coord = (
None,
None,
k_block_idx,
wh_mainloop_consumer_read_state.index,
)
tCrW_1phase = tCrW[k_block_coord]
tCrH_1phase = tCrH[k_block_coord]
cute.gemm(
tiled_mma_wh,
acc,
tCrW_1phase,
tCrH_1phase,
acc,
)
But it doesn't seem to work for me.
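For reference, the cast step in isolation is sketched below, using only the calls that already appear in the snippet above; the tensor and copy names (sB_f32, sB_bf16, thr_copy_s2r, thr_copy_r2s) are hypothetical placeholders, not the actual variables from the kernel.
# minimal sketch of the fp32 -> bf16 round trip described in the three steps above
# sB_f32 / sB_bf16: placeholder fp32 / bf16 smem tiles of the same shape
# thr_copy_s2r / thr_copy_r2s: thread slices of two TiledCopy objects whose per-thread value counts match
tXsB_f32 = thr_copy_s2r.partition_S(sB_f32)             # step 1: this thread's fp32 smem elements
tXrB_f32 = cute.make_fragment_like(tXsB_f32)            #         fp32 register fragment
cute.autovec_copy(tXsB_f32, tXrB_f32)                   #         smem -> registers
tXsB_bf16 = thr_copy_r2s.partition_D(sB_bf16)           # step 3 destination: bf16 smem elements
tXrB_bf16 = cute.make_fragment_like(tXsB_bf16)          #         bf16 register fragment
tXrB_bf16.store(tXrB_f32.load().to(cutlass.BFloat16))   # step 2: cast in registers
cute.autovec_copy(tXrB_bf16, tXsB_bf16)                 # step 3: registers -> smem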
Code:
import argparse
import math
import torch
import triton
from typing import Tuple, Type, Callable, Optional, Union, Literal
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.cute.testing as testing
import cutlass.pipeline as pipeline
from cutlass.pipeline.helpers import MbarrierArray
import cutlass.torch as cutlass_torch
import cutlass.utils.hopper_helpers as sm90_utils
from cutlass.cute.runtime import from_dlpack
from cutlass import Int32, Boolean, const_expr
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
class ChunkGatedDeltaFwdHKernel:
def __init__(
self,
tile_shape_mn: Tuple[int, int],
cluster_shape_mnk: Tuple[int, int, int],
chunk_size: int
):
self.buffer_align_bytes = 1024
# K is deferred in _setup_attributes
self.cta_tile_shape_mnk = (*tile_shape_mn, -1)
self.cluster_shape_mnk = cluster_shape_mnk
# For large tile size, using two warp groups is preferred because using only one warp
# group may result in register spill
# ??? How to tune the settings?
self.atom_layout_mnk = (
(2, 1, 1)
if self.cta_tile_shape_mnk[0] > 64 and self.cta_tile_shape_mnk[1] > 128
else (1, 1, 1)
)
self.mma_warp_groups = math.prod(self.atom_layout_mnk)
self.num_threads_per_warp_group = 128
self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group
print(f"threads_per_cta: {self.threads_per_cta}")
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90")
# self.num_ab_load_warps = 1
# self.ab_load_warp_id = self.mma_warp_groups * 4
# print(f"ab_load_warp_id: {self.ab_load_warp_id}")
# ? what is the acc dtype
self.acc_dtype = cutlass.Float32
self.chunk_size = chunk_size
@cute.jit
def __call__(
self,
w: cute.Tensor,
k: cute.Tensor,
u: cute.Tensor,
h0: cute.Tensor,
h: cute.Tensor,
o: cute.Tensor,
cu_seqlens: cute.Tensor | None = None,
store_final_state: bool = False,
save_new_value: bool = False
):
self.k_dtype = k.element_type
self.w_dtype = w.element_type
self.u_dtype = u.element_type
self.o_dtype = o.element_type
self.h0_dtype = h0.element_type if h0 is not None else cutlass.Float32 # ?? still have question here
# h dtype default to bf16
self.h_dtype = cutlass.BFloat16
# get row/col major
self.k_layout = utils.LayoutEnum.from_tensor(k)
self.w_layout = utils.LayoutEnum.from_tensor(w)
self.u_layout = utils.LayoutEnum.from_tensor(u)
self.h0_layout = utils.LayoutEnum.from_tensor(h0) if h0 is not None else None
# h default to col-major?
self.h_layout = self.h0_layout if h0 is not None else utils.LayoutEnum.COL_MAJOR
self.o_layout = utils.LayoutEnum.from_tensor(o) if o is not None else None
# calculate stage ?
self.wh_stage = 2
# ?? why we need to assume the stride??
# Assume all strides are divisible by 128 bits except the last stride
# new_stride = lambda t: tuple(
# cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s
# for s in t.stride
# )
# w = [
# cute.make_tensor(w.iterator, cute.make_layout(w.shape, stride=new_stride(w)))
# if w is not None
# else None
# ]
# print(f"new w: {w}")
self._setup_attributes()
# non-staged smem layout
w_smem_layout = cute.slice_(self.w_smem_layout_staged, (None, None, 0))
h0_smem_layout = cute.slice_(self.h0_smem_layout_staged, (None, None, 0))
print(w_smem_layout)
print(h0_smem_layout)
tma_atom_w, tma_tensor_w = self._make_tma_atoms_and_tensors(
w,
w_smem_layout,
(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]),
)
tma_atom_h0, tma_tensor_h0 = self._make_tma_atoms_and_tensors(
h0,
h0_smem_layout,
(self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2])
)
cute.printf(f"original_w: {w}")
cute.printf(f"original_h: {h0}")
cute.printf(f"tma_atom_w: {tma_atom_w} \ntma_tensor_w: {tma_tensor_w}")
cute.printf(f"tma_atom_h: {tma_atom_h0} \ntma_tensor_h: {tma_tensor_h0}")
self.num_tma_load_bytes = cute.size_in_bytes(self.w_dtype, w_smem_layout)
print(f"load_bytes: {self.num_tma_load_bytes}")
# why we need this?
if const_expr(h0 is not None):
self.num_tma_load_bytes += cute.size_in_bytes(self.h0_dtype, h0_smem_layout)
if const_expr(self.h0_layout is not None):
self.tiled_copy_h0_s2r = self._make_tiled_copy(self.h0_layout,self.h0_dtype,False,128)
self.tiled_copy_h_r2s = self._make_tiled_copy(self.h_layout,self.h_dtype,False,64)
print(f"load_bytes: {self.num_tma_load_bytes}")
# compute the launch grid
grid = self._compute_grid(w,h0,o)
print(f"grid: {grid}")
h0_smem_size = (
cute.cosize(self.h0_smem_layout_staged) if h0 is not None else 0
)
@cute.struct
class SharedStorage:
wh_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.wh_stage * 2]
tile_count: cute.struct.MemRange[cutlass.Int32, self.wh_stage]
sH0: cute.struct.Align[
cute.struct.MemRange[self.h0_dtype if self.h0_dtype is not None else cutlass.Int32, h0_smem_size],
self.buffer_align_bytes,
]
sW: cute.struct.Align[
cute.struct.MemRange[self.w_dtype, cute.cosize(self.w_smem_layout_staged)],
self.buffer_align_bytes,
]
sH: cute.struct.Align[
cute.struct.MemRange[self.h_dtype, cute.cosize(self.h_smem_layout_staged)],
self.buffer_align_bytes
]
self.shared_storage = SharedStorage
self.kernel(
self.tiled_mma_wh,
tma_atom_w,
tma_tensor_w,
self.w_smem_layout_staged,
tma_atom_h0,
tma_tensor_h0,
self.tiled_copy_h0_s2r,
self.tiled_copy_h_r2s,
self.h0_smem_layout_staged,
self.h_smem_layout_staged,
o,
self.cluster_layout_mnk
).launch(
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=self.cluster_shape_mnk,
)
return
def _setup_attributes(self):
if self.cta_tile_shape_mnk[0] not in [64, 128]:
raise ValueError("CTA tile shape M must be 64/128")
if self.cta_tile_shape_mnk[1] not in [64, 128, 256]:
raise ValueError("CTA tile shape N must be 64/128/256")
self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk)
self.tiled_mma_wh = sm90_utils.make_trivial_tiled_mma(
self.w_dtype,
self.h_dtype,
self.w_layout.sm90_mma_major_mode(),
self.h_layout.sm90_mma_major_mode(),
self.acc_dtype,
self.atom_layout_mnk,
tiler_mn=(64, self.cta_tile_shape_mnk[1]), # maybe we should use chunk_size instead?
)
print(f"tiled_mma_wh: {self.tiled_mma_wh}")
mma_inst_shape_k = cute.size(self.tiled_mma_wh.shape_mnk, mode=[2])
mma_inst_tile_k = 1
self.cta_tile_shape_mnk = (
self.cta_tile_shape_mnk[0],
self.cta_tile_shape_mnk[1],
mma_inst_shape_k * mma_inst_tile_k,
)
print(f"cta_tile_shape: {self.cta_tile_shape_mnk}")
(
self.w_smem_layout_staged,
self.h0_smem_layout_staged,
self.h_smem_layout_staged
) = self._make_smem_layouts(
self.cta_tile_shape_mnk,
self.w_dtype,
self.w_layout,
self.h0_dtype,
self.h0_layout,
self.h_dtype,
self.h_layout,
self.wh_stage
)
def _make_smem_layouts(
self,
cta_tile_shape_mnk: tuple[int, int, int],
w_dtype: type[cutlass.Numeric],
w_layout: utils.LayoutEnum,
h0_dtype: type[cutlass.Numeric],
h0_layout: utils.LayoutEnum,
h_dtype: type[cutlass.Numeric],
h_layout: utils.LayoutEnum,
ab_stage: int = 1,
):
# do not consider a fused smem layout for now
w_is_k_major = w_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
h0_is_k_major = h0_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
h_is_k_major = h_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
print(f"w_is_k_major: {w_is_k_major}")
print(f"h0_is_k_major: {h0_is_k_major}")
print(f"h_is_k_major: {h_is_k_major}")
w_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
w_major_mode_size = cta_tile_shape_mnk[2 if w_is_k_major else 0]
print(f"w_major_mode_size: {w_major_mode_size}")
smem_layout_atom_w = sm90_utils.get_smem_layout_atom(w_layout, w_dtype, w_major_mode_size)
print(f"smem_layout_atom_w {smem_layout_atom_w}")
w_smem_layout_atom = warpgroup.make_smem_layout_atom(
smem_layout_atom_w,
w_dtype,
)
print(f"w_smem_layout_atom: {w_smem_layout_atom}")
w_smem_layout_staged = cute.tile_to_shape(
w_smem_layout_atom,
cute.append(w_smem_shape, ab_stage),
order=(0, 1, 2) if w_is_k_major else (1, 0, 2),
)
print(f"w_smem_layout_staged: {w_smem_layout_staged}")
h0_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
h0_major_mode_size = cta_tile_shape_mnk[2 if h0_is_k_major else 1]
smem_layout_atom_h0 = sm90_utils.get_smem_layout_atom(h0_layout, h0_dtype, h0_major_mode_size)
print(f"smem_layout_atom_h0: {smem_layout_atom_h0}")
h0_smem_layout_atom = warpgroup.make_smem_layout_atom(
smem_layout_atom_h0,
h0_dtype,
)
print(f"h0_smem_layout_atom: {h0_smem_layout_atom}")
h0_smem_layout_staged = cute.tile_to_shape(
h0_smem_layout_atom,
cute.append(h0_smem_shape, ab_stage),
order=(0, 1, 2) if h0_is_k_major else (1, 0, 2),
)
print(f"h0_smem_layout_staged: {h0_smem_layout_staged}")
h_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
h_major_mode_size = cta_tile_shape_mnk[2 if h_is_k_major else 1]
smem_layout_atom_h = sm90_utils.get_smem_layout_atom(h_layout, h_dtype, h_major_mode_size)
print(f"smem_layout_atom_h: {smem_layout_atom_h}")
h_smem_layout_atom = warpgroup.make_smem_layout_atom(
smem_layout_atom_h,
h_dtype,
)
print(f"h_smem_layout_atom: {h_smem_layout_atom}")
h_smem_layout_staged = cute.tile_to_shape(
h_smem_layout_atom,
cute.append(h_smem_shape, ab_stage),
order=(0, 1, 2) if h_is_k_major else (1, 0, 2),
)
print(f"h_smem_layout_staged: {h_smem_layout_staged}")
# v_new_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
# # default to n
# v_new_major_mode_size = cta_tile_shape_mnk[1]
# # smem_layout_atom_v_new = sm90_utils.get_smem_layout_atom()
# print(f"v_new_smem_shape: {v_new_smem_shape}")
return (
w_smem_layout_staged,
h0_smem_layout_staged,
h_smem_layout_staged,
)
def _make_tma_atoms_and_tensors(
self,
tensor: cute.Tensor,
smem_layout: cute.ComposedLayout,
smem_tile: Tuple[int, int],
mcast_dim: int = -1,
) -> Tuple[cute.CopyAtom, cute.Tensor]:
# do not support mcast_dim
op = (
cpasync.CopyBulkTensorTileG2SOp()
)
tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
op,
tensor,
smem_layout,
smem_tile
)
return tma_atom, tma_tensor
def _compute_grid(
self,
mA: cute.Tensor,
mB: cute.Tensor,
mD: cute.Tensor,
) -> tuple[int, int, int]:
assert len(mA.shape) == 4 and len(mB.shape) == 4 and len(mD.shape) == 4
num_problems = mD.shape[2] * mD.shape[3]
self.B = mD.shape[2]
self.H = mD.shape[3]
# chunk_size must be divisible by cta_tile_shape_mnk[0]; only the T dimension is traversed serially
assert self.chunk_size % self.cta_tile_shape_mnk[0] == 0
problem_shape_ntile_mnl = (
cute.ceil_div(self.chunk_size, self.cta_tile_shape_mnk[0]),
cute.ceil_div(mB.shape[1], self.cta_tile_shape_mnk[1]),
num_problems
)
return problem_shape_ntile_mnl
def _make_ab_pipeline(
self,
tiled_mma: cute.TiledMma,
cluster_layout_vmnk: cute.Layout,
ab_pipeline_mbar_ptr: cute.Pointer
):
producer_cnt = 1
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
mcast_size = 1
consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, consumer_arrive_cnt
)
mainloop_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=ab_pipeline_mbar_ptr,
num_stages=self.wh_stage,
producer_group=ab_pipeline_producer_group,
consumer_group=ab_pipeline_consumer_group,
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
)
return mainloop_pipeline
@cute.kernel
def kernel(
self,
tiled_mma_wh: cute.TiledMma,
tma_atom_w: cute.CopyAtom,
mW_mkbh: cute.Tensor,
w_smem_layout_staged: cute.ComposedLayout,
tma_atom_h0: Optional[cute.CopyAtom],
mH0_nkbh: Optional[cute.Tensor],
tiled_copy_h0_s2r: Optional[cute.TiledCopy],
tiled_copy_h_r2s: Optional[cute.TiledCopy],
h0_smem_layout_staged: cute.ComposedLayout,
h_smem_layout_staged: cute.ComposedLayout,
mO_mnbh: cute.Tensor, # for now temporary used for acc check
cta_layout_mnk: cute.Layout,
):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
# prefetch Tma desc
if warp_idx == 0:
for tma_atom_desc in (tma_atom_w, tma_atom_h0):
if const_expr(tma_atom_desc is not None):
cpasync.prefetch_descriptor(tma_atom_desc)
bidx, bidy, bidz = cute.arch.block_idx()
tidx, _, _ = cute.arch.thread_idx()
b = bidz // self.H
h = bidz % self.H
tile_coord_mnkbh = (bidx, bidy, None, b, h)
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster)
print(f'cta_layout_mnk: {cta_layout_mnk}')
print(f"cluster_coord_mnk: {cluster_coord_mnk}")
# Alloc and init
smem = cutlass.utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
wh_pipeline = self._make_ab_pipeline(
tiled_mma=tiled_mma_wh,
cluster_layout_vmnk=cute.make_layout((1, *cta_layout_mnk.shape)),
ab_pipeline_mbar_ptr=storage.wh_pipeline_array_ptr.data_ptr(),
)
if warp_idx == 0:
full : MbarrierArray = wh_pipeline.sync_object_full
empty : MbarrierArray = wh_pipeline.sync_object_empty
print(f"full: {full.arrive_count}")
print(f"full.tx_count: {full.tx_count}")
print(f"empty.arrive_count: {empty.arrive_count}")
print(f"empty.tx_count: {empty.tx_count}")
sW = storage.sW.get_tensor(
w_smem_layout_staged.outer, swizzle=w_smem_layout_staged.inner
)
sH0 = storage.sH0.get_tensor(
h0_smem_layout_staged.outer, swizzle=h0_smem_layout_staged.inner
)
sH = storage.sH.get_tensor(
h_smem_layout_staged.outer, swizzle=h_smem_layout_staged.inner
)
gW_mk = cute.local_tile(
mW_mkbh, self.cta_tile_shape_mnk, tile_coord_mnkbh, proj=(1, None, 1)
)
cute.printf("gW_mk: {}", gW_mk)
if const_expr(mH0_nkbh is not None):
print(mH0_nkbh)
gH0_nk = cute.local_tile(
mH0_nkbh, self.cta_tile_shape_mnk, tile_coord_mnkbh, proj=(None, 1, 1)
)
print(gH0_nk)
gO_mn = cute.local_tile(
mO_mnbh, self.cta_tile_shape_mnk, tile_coord_mnkbh, proj=(1, 1, None)
)
# partition global tensor for TiledMMA
warp_group_idx = cute.arch.make_warp_uniform(
tidx // self.num_threads_per_warp_group
)
warp_group_thread_layout = cute.make_layout(
self.mma_warp_groups, stride=self.num_threads_per_warp_group
)
# make fragments
thr_mma_wh = tiled_mma_wh.get_slice(warp_group_thread_layout(warp_group_idx))
tCsW = thr_mma_wh.partition_A(sW)
tCsH = thr_mma_wh.partition_B(sH)
tCgO = thr_mma_wh.partition_C(gO_mn)
# mma, mma_m, mma_k, pipe
tCrW = tiled_mma_wh.make_fragment_A(tCsW)
tCrH = tiled_mma_wh.make_fragment_B(tCsH)
print(f"tCrH: {tCrH}")
print(f"tCrW: {tCrW}")
acc_shape = tiled_mma_wh.partition_shape_C(
cute.select(self.cta_tile_shape_mnk, mode=[0,1])
)
acc = cute.make_fragment(acc_shape, self.acc_dtype)
w_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape)
w_cta_crd = cluster_coord_mnk[1]
# shared memory partition needs to group from [0, 2)
sw_for_tma_partition = cute.group_modes(sW, 0, 2)
gW_for_tma_partition = cute.group_modes(gW_mk, 0, 2)
tAsW, tAgW_mk = cute.nvgpu.cpasync.tma_partition(
tma_atom_w,
w_cta_crd,
w_cta_layout,
sw_for_tma_partition,
gW_for_tma_partition
)
if const_expr(mH0_nkbh is not None):
h0_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape)
h0_cta_crd = cluster_coord_mnk[0]
sH0_for_tma_partition = cute.group_modes(sH0, 0, 2)
gH0_for_tma_partition = cute.group_modes(gH0_nk, 0, 2)
tBsH0, tBgH0_nk = cute.nvgpu.cpasync.tma_partition(
tma_atom_h0,
h0_cta_crd,
h0_cta_layout,
sH0_for_tma_partition,
gH0_for_tma_partition
)
if cute.size(self.cluster_shape_mnk) > 1:
# NOTE: wait for all CTAs in the cluster
cute.arch.cluster_wait()
else:
# NOTE: only wait for the current CTA
cute.arch.sync_threads()
k_tile_cnt = cute.size(gW_mk, mode=[2])
prefetch_k_tile_cnt = cutlass.max(cutlass.min(self.wh_stage, k_tile_cnt), 0)
wh_mainloop_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.wh_stage
)
if warp_idx == 0:
# prefetch TMA load
for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1):
wh_pipeline.producer_acquire(wh_mainloop_producer_state)
# count is used for gmem / index is used for smem
tAgW_k = tAgW_mk[(None, wh_mainloop_producer_state.count)]
tAsW_pipe = tAsW[(None, wh_mainloop_producer_state.index)]
tBgH0_k = tBgH0_nk[(None, wh_mainloop_producer_state.count)]
tBsH0_pipe = tBsH0[(None, wh_mainloop_producer_state.index)]
cute.copy(
tma_atom_w,
tAgW_k,
tAsW_pipe,
tma_bar_ptr=wh_pipeline.producer_get_barrier(
wh_mainloop_producer_state
),
mcast_mask=0,
)
cute.copy(
tma_atom_h0,
tBgH0_k,
tBsH0_pipe,
tma_bar_ptr=wh_pipeline.producer_get_barrier(
wh_mainloop_producer_state
),
mcast_mask=0,
)
wh_pipeline.producer_commit(wh_mainloop_producer_state)
wh_mainloop_producer_state.advance()
# Prologue MMAs
k_pipe_mmas = 1
wh_mainloop_consumer_read_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.wh_stage
)
wh_mainloop_consumer_release_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.wh_stage
)
peek_wh_full_status = cutlass.Boolean(1)
if wh_mainloop_consumer_read_state.count < k_tile_cnt:
peek_wh_full_status = wh_pipeline.consumer_try_wait(
wh_mainloop_consumer_read_state
)
tiled_mma_wh.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
num_k_blocks = cute.size(tCrW, mode=[2])
for k_tile in cutlass.range_constexpr(k_pipe_mmas):
wh_pipeline.consumer_wait(
wh_mainloop_consumer_read_state, peek_wh_full_status
)
cute.printf("sH layout: {}", sH.layout)
cute.printf("sH0 layout: {}", sH0.layout)
# convert h0 to bf16
tidx = warp_group_thread_layout(warp_group_idx)
thr_copy_h0_s2r = tiled_copy_h0_s2r.get_slice(tidx)
tXsH0 = thr_copy_h0_s2r.partition_S(sH0)
tXrH0 = cute.make_fragment_like(tXsH0)
cute.autovec_copy(tXsH0, tXrH0)
print(f"tXsH0: {tXsH0}")
print(f"tXrH0: {tXrH0}")
thr_copy_h_r2s = tiled_copy_h_r2s.get_slice(tidx)
tXsH = thr_copy_h_r2s.partition_D(sH)
tXrH = cute.make_fragment_like(tXsH)
print(f"tXsH: {tXsH}")
print(f"tXrH: {tXrH}")
tXrH.store(tXrH0.load().to(cutlass.BFloat16))
cute.autovec_copy(tXrH, tXsH)
cute.arch.sync_threads()
cute.printf("sH: {}", sH)
# why need fence? make sure acc visible to wgmma?
cute.nvgpu.warpgroup.fence()
for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_block_coord = (
None,
None,
k_block_idx,
wh_mainloop_consumer_read_state.index,
)
tCrW_1phase = tCrW[k_block_coord]
tCrH_1phase = tCrH[k_block_coord]
cute.gemm(
tiled_mma_wh,
acc,
tCrW_1phase,
tCrH_1phase,
acc,
)
tiled_mma_wh.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
cute.nvgpu.warpgroup.commit_group()
wh_mainloop_consumer_read_state.advance()
peek_wh_full_status = cutlass.Boolean(1)
if wh_mainloop_consumer_read_state.count < k_tile_cnt:
peek_wh_full_status = wh_pipeline.consumer_try_wait(
wh_mainloop_consumer_read_state
)
cute.printf("acc0 {}: {}", k_block_idx, acc)
cute.nvgpu.warpgroup.wait_group(0)
cute.printf("k_tile_cnt: {}",k_tile_cnt)
# mainloop
for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1):
# wait for TMA copies to complete
wh_pipeline.consumer_wait(wh_mainloop_consumer_read_state, peek_wh_full_status)
cute.nvgpu.warpgroup.fence()
for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_block_coord = (
None,
None,
k_block_idx,
wh_mainloop_consumer_read_state.index
)
tCrW_1phase = tCrW[k_block_coord]
tCrH_1phase = tCrH[k_block_coord]
cute.gemm(
tiled_mma_wh,
acc,
tCrW_1phase,
tCrH_1phase,
acc,
)
cute.nvgpu.warpgroup.commit_group()
# Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
cute.nvgpu.warpgroup.wait_group(k_pipe_mmas)
wh_pipeline.consumer_release(wh_mainloop_consumer_release_state)
cute.printf("ktile: {}", k_tile)
cute.printf("acc100{}: {}", k_tile, acc)
wh_mainloop_consumer_read_state.advance()
wh_mainloop_consumer_release_state.advance()
peek_wh_full_status = cutlass.Boolean(1)
if wh_mainloop_consumer_read_state.count < k_tile_cnt:
peek_wh_full_status = wh_pipeline.consumer_try_wait(
wh_mainloop_consumer_read_state
)
if warp_idx == 0 and wh_mainloop_producer_state.count < k_tile_cnt:
wh_pipeline.producer_acquire(wh_mainloop_producer_state)
tAgW_k = tAgW_mk[(None, wh_mainloop_producer_state.count)]
tAsW_pipe = tAsW[(None, wh_mainloop_producer_state.index)]
tBgH0_k = tBgH0_nk[(None, wh_mainloop_producer_state.count)]
tBsH0_pipe = tBsH0[(None, wh_mainloop_producer_state.index)]
cute.copy(
tma_atom_w,
tAgW_k,
tAsW_pipe,
tma_bar_ptr=wh_pipeline.producer_get_barrier(
wh_mainloop_producer_state
),
mcast_mask=0,
)
cute.copy(
tma_atom_h0,
tBgH0_k,
tBsH0_pipe,
tma_bar_ptr=wh_pipeline.producer_get_barrier(
wh_mainloop_producer_state
),
mcast_mask=0,
)
wh_pipeline.producer_commit(wh_mainloop_producer_state)
wh_mainloop_producer_state.advance()
cute.nvgpu.warpgroup.wait_group(0)
cute.printf("acc final: {}", acc)
cute.printf("tCgO: {}", tCgO)
cute.autovec_copy(acc, tCgO)
# epi
cute.printf("acc: {}", acc)
def _make_tiled_copy(self, layout : utils.LayoutEnum, dtype: cute.Numeric, is_async: bool = False, num_copy_bits=128):
# only support copy H0 which is col-major
assert layout.is_m_major_c()
assert is_async == False
assert dtype in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]
copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
# value layout: each thread handles `elems` contiguous elements per copy
elems = num_copy_bits // dtype.width
val_layout = cute.make_layout((elems, 1), stride=(1,0))
# thread layout:
# threads_per_rows = const_expr(major_mode_size // elems)
threads_per_col = const_expr(self.cta_tile_shape_mnk[1] // elems)
assert self.cta_tile_shape_mnk[1] == 64
assert self.threads_per_cta == 128
rows_per_block = self.threads_per_cta // threads_per_col
thr_layout = cute.make_layout((threads_per_col, rows_per_block), stride=(1, threads_per_col))
tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
print("construct:")
print(tiled_copy)
return tiled_copy
def run(w, k, u, h0, h, o, tile_shape_mn, cluster_shape_mnk, chunk_size):
kernel = ChunkGatedDeltaFwdHKernel(tile_shape_mn, cluster_shape_mnk, chunk_size)
w = from_dlpack(w, assumed_align=16)
k = from_dlpack(k, assumed_align=16)
u = from_dlpack(u, assumed_align=16)
h0 = from_dlpack(h0, assumed_align=16)
h = from_dlpack(h, assumed_align=16)
o = from_dlpack(o, assumed_align=16)
kernel(w, k, u, h0, h, o)
B, T, H, K, V = 1, 256, 1, 64, 64
chunk_size = 64
BT = chunk_size
dtype = torch.bfloat16
acc_dtype = torch.float32
device = 'cuda'
output_final_state = True
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
w = torch.randn((B, T, H, K), dtype=dtype, device=device)
w = w.new_ones((B, T, H, K), dtype=dtype)
k = torch.randn((B, T, H, K), dtype=dtype, device=device)
u = torch.randn((B, T, H, V), dtype=dtype, device=device)
h0 = torch.ones((B, H, K, V), dtype=torch.float32, device=device)
h = k.new_empty(B, NT, H, K, V)
o = torch.zeros((B, chunk_size, H, V), dtype=acc_dtype, device=device)
final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
tile_shape_mn = (chunk_size, 64)
cluster_shape_mnk = (1, 1, 1)
assert chunk_size == tile_shape_mn[0]
# need to permute first ??
w = w.permute(1, 3, 0, 2) # T, K, B, H
k = k.permute(1, 3, 0, 2) # ???
h0 = h0.permute(3, 2, 0, 1) #V, K, B, H
u = u.permute(2, 3, 0, 1) #T, V, B, H
o = o.permute(1, 3, 0, 2) #chunk_size, V, B, H
# by cute convention default is mk, nk, mn
run(w, k, u, h0, h, o, tile_shape_mn, cluster_shape_mnk, chunk_size)
print(f"------")
print(o)
print(f"------")
o_ref = torch.einsum('tkbh,vkbh->tvbh', w[:chunk_size].to(acc_dtype), h0.to(acc_dtype))
print(o_ref)
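# optional sanity check (not part of the kernel): max absolute difference between the
# kernel output o and the einsum reference o_ref; both are float32 with shape (chunk_size, V, B, H)
print(f"max abs diff vs reference: {(o - o_ref).abs().max().item()}")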