cutlass

[QST] CuteDSL in memory pass

Open · Dingjifeng opened this issue 1 month ago · 1 comment

Here is a simple CuTe DSL program that casts an fp32 tensor to a bf16 tensor:


import argparse
import math
import torch
import triton
from typing import Tuple, Type, Callable, Optional, Union, Literal

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.cute.testing as testing
import cutlass.pipeline as pipeline

from cutlass.pipeline.helpers import MbarrierArray

import cutlass.torch as cutlass_torch
import cutlass.utils.hopper_helpers as sm90_utils
from cutlass.cute.runtime import from_dlpack
from cutlass import Int32, Boolean, const_expr
from cutlass.cute.nvgpu import cpasync, warp, warpgroup

from typing import Tuple, Type

class S2R:
    """Cast a single fp32 tile to bf16, staging the data through shared memory.

    ``kernel`` moves data gmem A -> smem A -> registers -> (fp32 -> bf16)
    -> smem B -> gmem B, using one CTA of 128 threads.

    NOTE(review): the tile shape used for the shared-memory layouts is read
    from the module-level globals ``m`` and ``n`` (see ``make_smem_layouts``),
    not from the ``m``/``n`` arguments of ``__call__``. Confirm they agree
    with the shapes of the tensors being passed.
    """

    def __init__(self):
        pass

    @cute.jit
    def make_smem_layouts(
        self,
        mA_dtype: cutlass.Numeric,
        mA_layout: utils.LayoutEnum,
        mB_dtype: cutlass.Numeric,
        mB_layout: utils.LayoutEnum
    ):
        """Build the swizzled shared-memory layouts for A and B.

        For each operand, an SM90 smem layout atom is chosen from the
        operand's layout, dtype and major-mode extent. That atom is then
        tiled out to an ``(m, n)`` shape.

        Args:
            mA_dtype: element type of A.
            mA_layout: layout enum of A; decides A's major mode.
            mB_dtype: element type of B.
            mB_layout: layout enum of B; decides B's major mode.

        Returns:
            ``(a_smem_layout, b_smem_layout)``. ``kernel`` consumes their
            ``.outer`` / ``.inner`` parts as composed layouts.
        """
        a_is_k_major = mA_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K
        b_is_k_major = mB_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K

        # NOTE(review): `m` and `n` are not parameters of this method; they
        # resolve to the module-level globals defined by the driver script.
        a_smem_shape = (m, n)

        # Extent of A's contiguous (major) mode: n when K-major, else m.
        a_major_mode_size = const_expr(n if a_is_k_major else m)
        print(f"w_major_mode_size: {a_major_mode_size}")

        smem_layout_atom_a = sm90_utils.get_smem_layout_atom(mA_layout, mA_dtype, a_major_mode_size)
        print(f"smem_layout_atom_a {smem_layout_atom_a}")

        a_smem_layout_atom = warpgroup.make_smem_layout_atom(
            smem_layout_atom_a,
            mA_dtype,
        )
        print(f"a_smem_layout_atom: {a_smem_layout_atom}")
        # Tile along the major mode first so the atom's contiguous direction
        # matches the operand's contiguous direction.
        a_smem_layout = cute.tile_to_shape(
            a_smem_layout_atom,
            a_smem_shape,
            order=(0, 1) if a_is_k_major else (1, 0),
        )

        b_smem_shape = (m, n)
        # B's major-mode extent must follow B's own layout. The original code
        # read `a_is_k_major` here, which is only correct when A and B happen
        # to share a major mode.
        b_major_mode_size = const_expr(n if b_is_k_major else m)
        smem_layout_atom_b = sm90_utils.get_smem_layout_atom(mB_layout, mB_dtype, b_major_mode_size)

        print(f"smem_layout_atom_b: {smem_layout_atom_b}")
        b_smem_layout_atom = warpgroup.make_smem_layout_atom(
            smem_layout_atom_b,
            mB_dtype,
        )
        print(f"b_smem_layout_atom: {b_smem_layout_atom}")

        b_smem_layout = cute.tile_to_shape(
            b_smem_layout_atom,
            b_smem_shape,
            order=(0, 1) if b_is_k_major else (1, 0),
        )

        print(f"bsmem_layout_staged: {b_smem_layout}")

        return (
            a_smem_layout,
            b_smem_layout,
        )

    def make_tiled_copy(self, dtype: cute.Numeric, major_mode_size: int, num_copy_bits: int=128):
        """Build a 128-thread tiled copy; each thread moves ``num_copy_bits`` per copy.

        Each copy moves ``num_copy_bits // dtype.width`` contiguous elements
        along mode 0. Threads are laid out column-major as
        ``(major_mode_size // elems, 128 // threads_per_col)``. As a result,
        consecutive threads cover consecutive element chunks along mode 0.

        Args:
            dtype: element type; must be Float16, BFloat16 or Float32.
            major_mode_size: extent of mode 0 covered by one column of
                threads. Only 64 is supported.
            num_copy_bits: bits moved per thread per copy (default 128).

        Returns:
            The tiled copy produced by ``cute.make_tiled_copy_tv``.
        """
        assert dtype in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]
        threads_per_cta = 128

        copy_op = cute.nvgpu.CopyUniversalOp()
        copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)

        # Value layout: `elems` contiguous elements per thread per copy.
        elems = num_copy_bits // dtype.width
        val_layout = cute.make_layout((elems, 1), stride=(1, 0))

        # Thread layout: column-major, so thread t sits at
        # (t % threads_per_col, t // threads_per_col).
        assert major_mode_size == 64
        threads_per_col = const_expr(major_mode_size // elems)
        rows_per_block = threads_per_cta // threads_per_col
        thr_layout = cute.make_layout((threads_per_col, rows_per_block), stride=(1, threads_per_col))

        tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)

        print(tiled_copy)
        return tiled_copy

    @cute.jit
    def __call__(self, mA: cute.Tensor, mB: cute.Tensor, m: int, n: int):
        """Build the copies and shared-memory layouts, then launch ``kernel``.

        Args:
            mA: source tensor; its element type must be Float32.
            mB: destination tensor; its element type must be BFloat16.
            m, n: NOTE(review) these are not used in this body. The smem tile
                shape comes from the module-level ``m``/``n`` globals instead.
        """
        threads_per_cta = 128
        a_dtype = mA.element_type
        b_dtype = mB.element_type
        assert a_dtype == cutlass.Float32
        assert b_dtype == cutlass.BFloat16
        # Both copies move 4 elements per thread: 128 bits of fp32 and 64 bits
        # of bf16. The two thread/value layouts therefore line up element for
        # element, which the in-register cast relies on.
        tiled_copy_a = self.make_tiled_copy(a_dtype, 64, 128)
        tiled_copy_b = self.make_tiled_copy(b_dtype, 64, 64)
        a_smem_layout, b_smem_layout = self.make_smem_layouts(mA.element_type, utils.LayoutEnum.from_tensor(mA), mB.element_type, utils.LayoutEnum.from_tensor(mB))
        a_smem_size = cute.cosize(a_smem_layout)
        b_smem_size = cute.cosize(b_smem_layout)

        # Both buffers are 1024-byte aligned. This is presumably required by
        # the swizzled layout atoms (TODO confirm).
        @cute.struct
        class SharedStorage:
            sA: cute.struct.Align[
                cute.struct.MemRange[a_dtype, a_smem_size],
                1024,
            ]
            sB: cute.struct.Align[
                cute.struct.MemRange[b_dtype, b_smem_size],
                1024,
            ]

        self.shared_storage = SharedStorage

        # A single CTA processes the whole tile; there is no grid-level tiling.
        self.kernel(
            tiled_copy_a,
            tiled_copy_b,
            mA, mB,
            a_smem_layout, b_smem_layout
        ).launch(
            block=(threads_per_cta, 1, 1),
            grid=(1,1,1)
        )

    @cute.kernel
    def kernel(self, tiled_copy_a: cute.TiledCopy, tiled_copy_b: cute.TiledCopy, mA: cute.Tensor, mB: cute.Tensor, a_smem_layout: cute.ComposedLayout, b_smem_layout: cute.ComposedLayout):
        """Device kernel: gmem A -> smem A -> regs -> bf16 -> smem B -> gmem B.

        Each thread reads back from shared memory only the elements it wrote
        itself, because the write and the read use the same per-thread
        partitions. For that reason the code issues no CTA-wide barrier
        between stages.
        """
        tidx, _, _ = cute.arch.thread_idx()
        thr_copy_a = tiled_copy_a.get_slice(tidx)
        thr_copy_b = tiled_copy_b.get_slice(tidx)

        smem = cutlass.utils.SmemAllocator()

        storage = smem.allocate(self.shared_storage)

        # Swizzled smem views: the outer layout plus the swizzle function.
        sA = storage.sA.get_tensor(
            a_smem_layout.outer, swizzle=a_smem_layout.inner
        )

        sB = storage.sB.get_tensor(
            b_smem_layout.outer, swizzle=b_smem_layout.inner
        )

        # This thread's slices of the global tensors: A is the copy source,
        # B is the copy destination.
        tXgA = thr_copy_a.partition_S(mA)
        tXgB = thr_copy_b.partition_D(mB)

        print(f"tXgA: {tXgA}")
        print(f"tXgB: {tXgB}")

        tXsA = thr_copy_a.partition_S(sA)
        tXsB = thr_copy_b.partition_D(sB)

        print(f"sA layout: {sA.layout}")
        print(f"sB layout: {sB.layout}")
        print(f"mA: {mA}")
        print(f"mB: {mB}")
        print(f"tXsA: {tXsA}")
        print(f"tXsB: {tXsB}")

        # Stage 1: gmem A -> smem A.
        cute.copy(thr_copy_a, tXgA, tXsA)

        tXrA = cute.make_fragment_like(tXsA)
        tXrB = cute.make_fragment_like(tXsB)

        print(f"tXrA: {tXrA}")
        tXrA = thr_copy_a.retile(tXrA)
        print(f"tXrA: {tXrA}")
        # Stage 2: smem A -> registers (same partition this thread wrote).
        cute.autovec_copy(tXsA, tXrA)

        # Stage 3: convert fp32 -> bf16 in registers.
        tXrB.store(tXrA.load().to(cutlass.BFloat16))

        # Stages 4-5: registers -> smem B -> gmem B.
        cute.autovec_copy(tXrB, tXsB)
        cute.autovec_copy(tXsB, tXgB)
        
# Driver: cast an (m, n) fp32 tensor to bf16 with the S2R kernel, check the
# result, then compile again with KeepPTX/KeepCUBIN to inspect generated code.
# NOTE: `m` and `n` must stay module-level names, because
# S2R.make_smem_layouts reads them as globals to size the smem tiles.
m = 64
n = 16
dtype = torch.float32
# Allocate as (n, m) and transpose in place. This gives an (m, n) view with
# stride (1, m), i.e. not contiguous in row-major order.
x = torch.randn((n, m), dtype=dtype, device="cuda").transpose_(0, 1)

print(x.shape)
print(x.is_contiguous())

y = torch.zeros((n, m), dtype=torch.bfloat16, device="cuda").transpose_(0,1)
print(y.shape)
print(y.is_contiguous())

keep = x.clone()
print(keep)
kernel = S2R()
kernel(from_dlpack(x, assumed_align=16), from_dlpack(y, assumed_align=16), m, n)

# Validate the kernel output instead of only printing it. The kernel converts
# with round-to-nearest-even (cvt.rn.bf16.f32), which matches torch's
# fp32 -> bf16 cast, so the results must be bit-identical.
torch.testing.assert_close(y, keep.to(torch.bfloat16), rtol=0, atol=0)

from cutlass.cute import KeepPTX, KeepCUBIN
# Compile only (no launch) and keep the PTX/CUBIN for inspection.
hello_world_compiled_ptx_on = cute.compile[KeepPTX, KeepCUBIN](kernel,from_dlpack(x, assumed_align=16), from_dlpack(y, assumed_align=16), m, n)
print(y)

But the generated PTX is confusing:

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-36006120
// Cuda compilation tools, release 12.9, V12.9.83
// Based on NVVM 20.0.0
//

.version 8.8
.target sm_90a
.address_size 64

	// .globl	kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0
.extern .shared .align 1024 .b8 __dynamic_shmem__0[];

// Kernel entry generated for S2R.kernel (128 threads per CTA, see .reqntid).
// param_0 (loaded into %rd2) is the fp32 source A, read with ld.global.
// param_1 (loaded into %rd1) is the bf16 destination B, written with st.global.
.visible .entry kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0(
	.param .align 8 .b8 kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0_param_0[8],
	.param .align 8 .b8 kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0_param_1[8]
)
.reqntid 128, 1, 1
{
	.reg .b16 	%rs<13>;
	.reg .b32 	%r<29>;
	.reg .f32 	%f<13>;
	.reg .b64 	%rd<7>;

	ld.param.u64 	%rd1, [kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0_param_1];
	ld.param.u64 	%rd2, [kernel_cutlass_kernel___main__S2R_object_at__TiledCopy_TilerMN64181_TVLayouttiled128441_CopyAtom_ThrID10_TVLayoutSrc1401_TVLayoutDst1401_Valuetypef32_TiledCopy_TilerMN64181_TVLayouttiled1_0_param_0];
	// Per-thread shared-memory addresses (before swizzling):
	//   %r12 = smem base + this thread's slot in sA (built from %tid.x bits),
	//   %r13 = smem base + tid*8, and %r14 = %r13 + 4096 is its slot in sB.
	//   sB starts 4096 bytes after the base.
	mov.u32 	%r1, %tid.x;
	shl.b32 	%r2, %r1, 2;
	shl.b32 	%r3, %r1, 3;
	and.b32  	%r4, %r3, 896;
	shl.b32 	%r5, %r1, 4;
	and.b32  	%r6, %r5, 112;
	or.b32  	%r7, %r6, %r4;
	shl.b32 	%r8, %r1, 8;
	and.b32  	%r9, %r8, 2048;
	or.b32  	%r10, %r9, %r7;
	mov.u32 	%r11, __dynamic_shmem__0;
	add.s32 	%r12, %r11, %r10;
	add.s32 	%r13, %r11, %r3;
	add.s32 	%r14, %r13, 4096;
	// Global address of this thread's 4 fp32 elements of A: %rd2 + tid*4*4 bytes.
	mul.wide.u32 	%rd3, %r2, 4;
	add.s64 	%rd4, %rd2, %rd3;
	// Swizzle the sA address: addr ^ ((addr >> 3) & 112).
	shr.u32 	%r15, %r12, 3;
	and.b32  	%r16, %r15, 112;
	xor.b32  	%r17, %r16, %r12;
	// Copy 1: 4 fp32 from gmem A into sA slot [%r17].
	ld.global.v4.f32 	{%f1, %f2, %f3, %f4}, [%rd4];
	st.shared.v4.f32 	[%r17], {%f1, %f2, %f3, %f4};
	// Copy 2: the next 4 fp32 (gmem +2048 bytes) go into sA slot [%r21]
	// (+1024 bytes, then swizzled the same way).
	add.s32 	%r18, %r12, 1024;
	shr.u32 	%r19, %r18, 3;
	and.b32  	%r20, %r19, 112;
	xor.b32  	%r21, %r20, %r18;
	ld.global.v4.f32 	{%f5, %f6, %f7, %f8}, [%rd4+2048];
	st.shared.v4.f32 	[%r21], {%f5, %f6, %f7, %f8};
	// smem A -> registers. Only slot [%r17] is re-read from shared memory.
	// The values for slot [%r21] are reused from %f5-%f8 (the registers just
	// stored there) in the cvt instructions below.
	// NOTE(review): presumably the compiler cannot prove that the xor-swizzled
	// addresses [%r21] and [%r17] differ. If so, the intervening store to
	// [%r21] forces a reload of [%r17], while the most recent store can be
	// forwarded. Confirm.
	ld.shared.v4.f32 	{%f9, %f10, %f11, %f12}, [%r17];
	// fp32 -> bf16 conversion, round to nearest even (.rn).
	cvt.rn.bf16.f32 	%rs1, %f12;
	cvt.rn.bf16.f32 	%rs2, %f11;
	cvt.rn.bf16.f32 	%rs3, %f10;
	cvt.rn.bf16.f32 	%rs4, %f9;
	cvt.rn.bf16.f32 	%rs5, %f8;
	cvt.rn.bf16.f32 	%rs6, %f7;
	cvt.rn.bf16.f32 	%rs7, %f6;
	cvt.rn.bf16.f32 	%rs8, %f5;
	// Swizzle the two sB slots (%r14 and %r14 + 1024) and store the bf16 values.
	shr.u32 	%r22, %r14, 3;
	and.b32  	%r23, %r22, 112;
	xor.b32  	%r24, %r23, %r14;
	st.shared.v4.b16 	[%r24], {%rs4, %rs3, %rs2, %rs1};
	add.s32 	%r25, %r13, 5120;
	shr.u32 	%r26, %r25, 3;
	and.b32  	%r27, %r26, 112;
	xor.b32  	%r28, %r27, %r25;
	st.shared.v4.b16 	[%r28], {%rs8, %rs7, %rs6, %rs5};
	// smem B -> gmem B follows the same pattern as sA. Slot [%r24] is
	// reloaded after the store to [%r28], while the second half is written
	// to gmem straight from %rs5-%rs8.
	mul.wide.u32 	%rd5, %r2, 2;
	add.s64 	%rd6, %rd1, %rd5;
	ld.shared.v4.b16 	{%rs9, %rs10, %rs11, %rs12}, [%r24];
	st.global.v4.b16 	[%rd6], {%rs9, %rs10, %rs11, %rs12};
	st.global.v4.b16 	[%rd6+1024], {%rs8, %rs7, %rs6, %rs5};
	ret;

}

There is only one ld.shared.v4 per tensor, so it seems a memory optimization pass reuses the registers instead of reloading from shared memory. My question is: why is only one of the two ld.shared.v4 loads eliminated? Clearly both could be eliminated in the conversion, since the data must go through registers first anyway!

Dingjifeng avatar Dec 05 '25 08:12 Dingjifeng

There is only one ld.shared.v4, it seems to do a memory optimize pass to reuse the register.

Do you mind elaborating on this a bit more?

fengxie avatar Dec 08 '25 16:12 fengxie

There is only one ld.shared.v4, it seems to do a memory optimize pass to reuse the register.

Do you mind elaborating this a bit more?

I have 128 threads and a tile of shape (64, 16) in fp32. Each thread performs 128-bit copies (4 elements), so each thread does $64 \times 16 / (128 \times 4) = 2$ copies, i.e. 2 ld.v4 instructions. So I would expect code like this:

ld.global.v4 # first load from gmem

st.shared.v4 # first store to smem

ld.global.v4 # second load from gmem

st.shared.v4 # second store to smem

ld.shared.v4 # first load from smem to registers

ld.shared.v4 # second load from smem to registers

However, it seems the compiler applied a fusion optimization (register reuse), so the second load from smem is skipped. But if there is a memory optimization pass that reuses registers, why doesn't it reuse both? Why is only the second smem load optimized away? Thanks in advance @fengxie

Dingjifeng avatar Dec 18 '25 15:12 Dingjifeng

I suspect that for global memory, the compiler can't assume no other thread writes to the same location, so it shouldn't optimize out the global memory load.

For shared memory, is it safe to optimize it out? (I'm not convinced.) @brandon-yujie-sun, could you double-check whether this is a bug?

Thanks for reporting.

fengxie avatar Dec 19 '25 07:12 fengxie