[QST] CuteDSL in memory pass
Here is a simple CuteDSL program that casts an fp32 tensor to a bf16 tensor:
```python
import argparse
import math
import torch
import triton
from typing import Tuple, Type, Callable, Optional, Union, Literal

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.cute.testing as testing
import cutlass.pipeline as pipeline
from cutlass.pipeline.helpers import MbarrierArray
import cutlass.torch as cutlass_torch
import cutlass.utils.hopper_helpers as sm90_utils
from cutlass.cute.runtime import from_dlpack
from cutlass import Int32, Boolean, const_expr
from cutlass.cute.nvgpu import cpasync, warp, warpgroup


class S2R:
    def __init__(self):
        pass

    @cute.jit
    def make_smem_layouts(
        self,
        mA_dtype: cutlass.Numeric,
        mA_layout: utils.LayoutEnum,
        mB_dtype: cutlass.Numeric,
        mB_layout: utils.LayoutEnum,
    ):
        a_is_k_major = mA_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
        b_is_k_major = mB_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
        # m and n are the module-level tile sizes (64, 16) defined at the bottom of the script
        a_smem_shape = (m, n)
        a_major_mode_size = const_expr(n if a_is_k_major else m)
        print(f"a_major_mode_size: {a_major_mode_size}")
        smem_layout_atom_a = sm90_utils.get_smem_layout_atom(mA_layout, mA_dtype, a_major_mode_size)
        print(f"smem_layout_atom_a: {smem_layout_atom_a}")
        a_smem_layout_atom = warpgroup.make_smem_layout_atom(
            smem_layout_atom_a,
            mA_dtype,
        )
        print(f"a_smem_layout_atom: {a_smem_layout_atom}")
        a_smem_layout = cute.tile_to_shape(
            a_smem_layout_atom,
            a_smem_shape,
            order=(0, 1) if a_is_k_major else (1, 0),
        )
        b_smem_shape = (m, n)
        b_major_mode_size = const_expr(n if b_is_k_major else m)
        smem_layout_atom_b = sm90_utils.get_smem_layout_atom(mB_layout, mB_dtype, b_major_mode_size)
        print(f"smem_layout_atom_b: {smem_layout_atom_b}")
        b_smem_layout_atom = warpgroup.make_smem_layout_atom(
            smem_layout_atom_b,
            mB_dtype,
        )
        print(f"b_smem_layout_atom: {b_smem_layout_atom}")
        b_smem_layout = cute.tile_to_shape(
            b_smem_layout_atom,
            b_smem_shape,
            order=(0, 1) if b_is_k_major else (1, 0),
        )
        print(f"b_smem_layout: {b_smem_layout}")
        return (
            a_smem_layout,
            b_smem_layout,
        )

    def make_tiled_copy(self, dtype: cutlass.Numeric, major_mode_size: int, num_copy_bits: int = 128):
        assert dtype in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]
        threads_per_cta = 128
        copy_op = cute.nvgpu.CopyUniversalOp()
        copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
        # value layout: each thread moves `elems` contiguous elements per copy
        elems = num_copy_bits // dtype.width
        val_layout = cute.make_layout((elems, 1), stride=(1, 0))
        # thread layout: threads tile the major mode first, then the other mode
        assert major_mode_size == 64
        threads_per_col = const_expr(major_mode_size // elems)
        rows_per_block = threads_per_cta // threads_per_col
        thr_layout = cute.make_layout((threads_per_col, rows_per_block), stride=(1, threads_per_col))
        tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
        print(tiled_copy)
        return tiled_copy

    @cute.jit
    def __call__(self, mA: cute.Tensor, mB: cute.Tensor, m: int, n: int):
        threads_per_cta = 128
        a_dtype = mA.element_type
        b_dtype = mB.element_type
        assert a_dtype == cutlass.Float32
        assert b_dtype == cutlass.BFloat16
        tiled_copy_a = self.make_tiled_copy(a_dtype, 64, 128)
        tiled_copy_b = self.make_tiled_copy(b_dtype, 64, 64)
        a_smem_layout, b_smem_layout = self.make_smem_layouts(
            mA.element_type, utils.LayoutEnum.from_tensor(mA),
            mB.element_type, utils.LayoutEnum.from_tensor(mB),
        )
        a_smem_size = cute.cosize(a_smem_layout)
        b_smem_size = cute.cosize(b_smem_layout)

        @cute.struct
        class SharedStorage:
            sA: cute.struct.Align[
                cute.struct.MemRange[a_dtype, a_smem_size],
                1024,
            ]
            sB: cute.struct.Align[
                cute.struct.MemRange[b_dtype, b_smem_size],
                1024,
            ]

        self.shared_storage = SharedStorage
        self.kernel(
            tiled_copy_a,
            tiled_copy_b,
            mA, mB,
            a_smem_layout, b_smem_layout,
        ).launch(
            block=(threads_per_cta, 1, 1),
            grid=(1, 1, 1),
        )

    @cute.kernel
    def kernel(
        self,
        tiled_copy_a: cute.TiledCopy,
        tiled_copy_b: cute.TiledCopy,
        mA: cute.Tensor,
        mB: cute.Tensor,
        a_smem_layout: cute.ComposedLayout,
        b_smem_layout: cute.ComposedLayout,
    ):
        tidx, _, _ = cute.arch.thread_idx()
        thr_copy_a = tiled_copy_a.get_slice(tidx)
        thr_copy_b = tiled_copy_b.get_slice(tidx)
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)
        sA = storage.sA.get_tensor(
            a_smem_layout.outer, swizzle=a_smem_layout.inner
        )
        sB = storage.sB.get_tensor(
            b_smem_layout.outer, swizzle=b_smem_layout.inner
        )
        tXgA = thr_copy_a.partition_S(mA)
        tXgB = thr_copy_b.partition_D(mB)
        print(f"tXgA: {tXgA}")
        print(f"tXgB: {tXgB}")
        tXsA = thr_copy_a.partition_S(sA)
        tXsB = thr_copy_b.partition_D(sB)
        print(f"sA layout: {sA.layout}")
        print(f"sB layout: {sB.layout}")
        print(f"mA: {mA}")
        print(f"mB: {mB}")
        print(f"tXsA: {tXsA}")
        print(f"tXsB: {tXsB}")
        # gmem -> smem
        # cute.autovec_copy(tXgA, tXsA)
        cute.copy(thr_copy_a, tXgA, tXsA)
        tXrA = cute.make_fragment_like(tXsA)
        tXrB = cute.make_fragment_like(tXsB)
        print(f"tXrA: {tXrA}")
        tXrA = thr_copy_a.retile(tXrA)
        print(f"tXrA: {tXrA}")
        # smem -> registers
        cute.autovec_copy(tXsA, tXrA)
        # convert fp32 -> bf16 in registers
        tXrB.store(tXrA.load().to(cutlass.BFloat16))
        # registers -> smem
        cute.autovec_copy(tXrB, tXsB)
        # smem -> gmem
        cute.autovec_copy(tXsB, tXgB)

m = 64
n = 16
dtype = torch.float32
x = torch.randn((n, m), dtype=dtype, device="cuda").transpose_(0, 1)
print(x.shape)
print(x.is_contiguous())
y = torch.zeros((n, m), dtype=torch.bfloat16, device="cuda").transpose_(0,1)
print(y.shape)
print(y.is_contiguous())
keep = x.clone()
print(keep)
# print(2 * keep)
kernel = S2R()
kernel(from_dlpack(x, assumed_align=16), from_dlpack(y, assumed_align=16), m, n)
from cutlass.cute import KeepPTX, KeepCUBIN
hello_world_compiled_ptx_on = cute.compile[KeepPTX, KeepCUBIN](kernel,from_dlpack(x, assumed_align=16), from_dlpack(y, assumed_align=16), m, n)
# print(x)
print(y)
```
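As a quick sanity check on the numerics (my addition, not part of the original reproducer), the result can be compared against PyTorch's own cast, reusing the `x`, `y`, and `keep` tensors from the listing above:

```python
# Compare the kernel's output against torch's own fp32 -> bf16 cast.
# Assumes x, y and keep from the listing above; tolerances are torch's defaults.
torch.testing.assert_close(y, keep.to(torch.bfloat16))
print("bf16 cast matches the torch reference")
```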
But the generated PTX seems confusing:
```
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-36006120
// Cuda compilation tools, release 12.9, V12.9.83
// Based on NVVM 20.0.0
//
.version 8.8
.target sm_90a
.address_size 64
// .globl kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0
.extern .shared .align 1024 .b8 __dynamic_shmem__0[];
.visible .entry kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0(
.param .align 8 .b8 kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0_param_0[8],
.param .align 8 .b8 kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0_param_1[8]
)
.reqntid 128, 1, 1
{
.reg .b16 %rs<13>;
.reg .b32 %r<29>;
.reg .f32 %f<13>;
.reg .b64 %rd<7>;
ld.param.u64 %rd1, [kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0_param_1];
ld.param.u64 %rd2, [kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0_param_0];
mov.u32 %r1, %tid.x;
shl.b32 %r2, %r1, 2;
shl.b32 %r3, %r1, 3;
and.b32 %r4, %r3, 896;
shl.b32 %r5, %r1, 4;
and.b32 %r6, %r5, 112;
or.b32 %r7, %r6, %r4;
shl.b32 %r8, %r1, 8;
and.b32 %r9, %r8, 2048;
or.b32 %r10, %r9, %r7;
mov.u32 %r11, __dynamic_shmem__0;
add.s32 %r12, %r11, %r10;
add.s32 %r13, %r11, %r3;
add.s32 %r14, %r13, 4096;
mul.wide.u32 %rd3, %r2, 4;
add.s64 %rd4, %rd2, %rd3;
shr.u32 %r15, %r12, 3;
and.b32 %r16, %r15, 112;
xor.b32 %r17, %r16, %r12;
ld.global.v4.f32 {%f1, %f2, %f3, %f4}, [%rd4];
st.shared.v4.f32 [%r17], {%f1, %f2, %f3, %f4};
add.s32 %r18, %r12, 1024;
shr.u32 %r19, %r18, 3;
and.b32 %r20, %r19, 112;
xor.b32 %r21, %r20, %r18;
ld.global.v4.f32 {%f5, %f6, %f7, %f8}, [%rd4+2048];
st.shared.v4.f32 [%r21], {%f5, %f6, %f7, %f8};
ld.shared.v4.f32 {%f9, %f10, %f11, %f12}, [%r17];
cvt.rn.bf16.f32 %rs1, %f12;
cvt.rn.bf16.f32 %rs2, %f11;
cvt.rn.bf16.f32 %rs3, %f10;
cvt.rn.bf16.f32 %rs4, %f9;
cvt.rn.bf16.f32 %rs5, %f8;
cvt.rn.bf16.f32 %rs6, %f7;
cvt.rn.bf16.f32 %rs7, %f6;
cvt.rn.bf16.f32 %rs8, %f5;
shr.u32 %r22, %r14, 3;
and.b32 %r23, %r22, 112;
xor.b32 %r24, %r23, %r14;
st.shared.v4.b16 [%r24], {%rs4, %rs3, %rs2, %rs1};
add.s32 %r25, %r13, 5120;
shr.u32 %r26, %r25, 3;
and.b32 %r27, %r26, 112;
xor.b32 %r28, %r27, %r25;
st.shared.v4.b16 [%r28], {%rs8, %rs7, %rs6, %rs5};
mul.wide.u32 %rd5, %r2, 2;
add.s64 %rd6, %rd1, %rd5;
ld.shared.v4.b16 {%rs9, %rs10, %rs11, %rs12}, [%r24];
st.global.v4.b16 [%rd6], {%rs9, %rs10, %rs11, %rs12};
st.global.v4.b16 [%rd6+1024], {%rs8, %rs7, %rs6, %rs5};
ret;
}
```
There is only one `ld.shared.v4` feeding the converts; it looks like a memory-optimization pass reused the registers for the second one. But my question is: why is only one of the `ld.shared.v4` loads optimized away? Clearly both loads before the convert could be eliminated, since the data has to go through registers first anyway (it is already in registers right after the global load)!
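To make the discrepancy concrete, here is a per-thread tally of the vector memory instructions, comparing what I would expect from the kernel source with what actually appears in the PTX dump above (my own count, so treat it as a reading of the dump rather than anything authoritative):

```python
# "expected" is my reading of the kernel's copy stages;
# "observed" is counted from the PTX dump above.
expected = {
    "ld.global.v4.f32": 2,  # cute.copy(thr_copy_a, tXgA, tXsA), gmem side
    "st.shared.v4.f32": 2,  # cute.copy(thr_copy_a, tXgA, tXsA), smem side
    "ld.shared.v4.f32": 2,  # cute.autovec_copy(tXsA, tXrA)
    "st.shared.v4.b16": 2,  # cute.autovec_copy(tXrB, tXsB)
    "ld.shared.v4.b16": 2,  # cute.autovec_copy(tXsB, tXgB), smem side
    "st.global.v4.b16": 2,  # cute.autovec_copy(tXsB, tXgB), gmem side
}
observed = dict(expected, **{"ld.shared.v4.f32": 1, "ld.shared.v4.b16": 1})
for op, n_exp in expected.items():
    print(f"{op}: expected {n_exp}, observed {observed[op]}")
```

Only the shared-memory loads differ: each appears once instead of twice, with the second vector taken directly from registers.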
> There is only one `ld.shared.v4` feeding the converts; it looks like a memory-optimization pass reused the registers for the second one.

Do you mind elaborating on this a bit more?
I have 128 threads and a (64, 16) fp32 tile. Each thread does 128-bit copies (4 elements), so each thread performs 64 * 16 / (128 * 4) = 2 copies, i.e. two `ld.v4` per thread. So I would expect code like this:

```
ld.global.v4   # first load from gmem
st.global.v4   # first store to gmem
ld.global.v4   # second load from gmem
st.global.v4   # second store to gmem
ld.shared.v4   # first load from smem into registers
ld.shared.v4   # second load from smem into registers
```
However, the generated code seems to apply a fusing optimization (register reuse), so the second load from smem is skipped. But if there is a memory-optimization pass that reuses registers, why not reuse them for both? Why is only the second smem load optimized away? Thanks in advance @fengxie
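As a small cross-check of the arithmetic above (purely illustrative, same numbers as in the previous paragraph):

```python
# Per-thread copy count for the 64 x 16 fp32 tile with 128 threads.
tile_m, tile_n = 64, 16
threads = 128
elems_per_copy = 128 // 32  # 4 fp32 elements per 128-bit vector copy
copies_per_thread = (tile_m * tile_n) // (threads * elems_per_copy)
print(copies_per_thread)  # -> 2, i.e. two ld.v4 per thread
```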
I suspect that for global memory the compiler can't assume that no other thread writes to the same location, so it shouldn't optimize out the global memory loads.
For shared memory it may be safe to optimize them out? (But I'm not convinced.) @brandon-yujie-sun to double check whether this is a bug...
Thanks for reporting.