[QST] TF32 tensor core bank conflicts

Open capybara-club opened this issue 10 months ago • 4 comments

Hello, I'm making some progress running some of these different layouts. I copied a setup for SM80_16x8x8_F32TF32TF32F32_TN from the default_gemm_configuration.hpp file and I now have the following:

template <
	class TA, class TB, class TC,
	class Alpha, class Beta
>
void
gemm_nt(
	int m, int n, int k,
	Alpha alpha,
	TA const* A, int ldA,
	TB const* B, int ldB,
	Beta beta,
	TC      * C, int ldC,
	cudaStream_t stream = 0
) {
	using namespace cute;

	// Define shapes (dynamic)
	auto M = int(m);
	auto N = int(n);
	auto K = int(k);
	auto prob_shape = make_shape(M, N, K);

	// Define NT strides (mixed)
	auto dA = make_stride(Int<1>{}, ldA);
	auto dB = make_stride(Int<1>{}, ldB);
	auto dC = make_stride(Int<1>{}, ldC);

	// Define CTA tile sizes (static)
	auto bM = Int<128>{};
	auto bN = Int<128>{};
	auto bK = Int< 32>{};
	auto cta_tiler = make_shape(bM, bN, bK);
	auto bP = Int<3>{};

	auto swizzle_atom = composition(
		Swizzle<3,2,3>{},
		Layout<Shape <_32, _8>, Stride< _1,_32>>{}
	);

	auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP));
	auto sB = tile_to_shape(swizzle_atom, make_shape(bN,bK,bP));
	auto sC = make_layout(make_shape(bM, bN));

	TiledCopy copyA = make_tiled_copy(
		Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, tfloat32_t>{},
		Layout<Shape<_16,_8>, Stride<_1,_16>>{},
		Layout<Shape< _4,_1>>{}
	);

	TiledCopy copyB = make_tiled_copy(
		Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, tfloat32_t>{},
		Layout<Shape<_16,_8>, Stride<_1,_16>>{},
		Layout<Shape< _4,_1>>{}
	);

	TiledMMA mmaC = 
		make_tiled_mma(
			SM80_16x8x8_F32TF32TF32F32_TN{},
			Layout<Shape<_2,_2,_1>, Stride<_2, _1, _1>>{},
			Tile<_32,_32,_8>{}
		);

#if 0
	print(copyA);
	print(copyB);
	print(mmaC);
#endif

#if 0
	print_latex(copyA);
	print_latex(copyB);
	print_latex(mmaC);
#endif
	
	int smem_size = int(sizeof(SharedStorage<tfloat32_t, tfloat32_t, decltype(sA), decltype(sB)>));
	dim3 dimBlock(size(mmaC));
	dim3 dimGrid(
		size(ceil_div(M, bM)),
		size(ceil_div(N, bN))
	);

	auto kernel_fptr = 
		gemm_device<
			decltype(prob_shape), decltype(cta_tiler),
			float, decltype(dA), decltype(sA), decltype(copyA),
			float, decltype(dB), decltype(sB), decltype(copyB),
			float, decltype(dC), decltype(sC), decltype(mmaC),
			decltype(alpha), decltype(beta)
		>;

	cudaFuncSetAttribute(
		kernel_fptr,
		cudaFuncAttributeMaxDynamicSharedMemorySize, 
		smem_size
	);
	
	cudaFuncSetAttribute(
		kernel_fptr,
		cudaFuncAttributePreferredSharedMemoryCarveout, 
		100
	);

	kernel_fptr<<<dimGrid, dimBlock, smem_size, stream>>>(
		prob_shape, cta_tiler,
		A, dA, sA, copyA,
		B, dB, sB, copyB,
		C, dC, sC, mmaC,
		alpha, beta
	);
}
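The launch geometry implied by the setup above can be checked with a little host-side arithmetic. The following is a rough sketch, not the tutorial's code: `estimate` and `LaunchEstimate` are hypothetical names, and the shared memory figure assumes that `SharedStorage` holds only the A and B tiles with 4-byte elements and no padding.

```cpp
#include <cassert>
#include <cstddef>

// Integer ceiling division, as used for the grid dimensions.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Hypothetical helper type: holds the derived launch parameters.
struct LaunchEstimate {
    int grid_x, grid_y, threads;
    std::size_t smem_bytes;
};

// bM, bN, bK, bP mirror the CTA tile and pipeline depth from the post.
constexpr LaunchEstimate estimate(int M, int N) {
    constexpr int bM = 128, bN = 128, bK = 32, bP = 3;
    // Layout<Shape<_2,_2,_1>> of warp-wide MMA atoms gives 4 warps.
    constexpr int warps = 2 * 2 * 1;
    // Assumption: shared storage is exactly the A and B tiles, 4 bytes each.
    constexpr std::size_t smem = std::size_t(bM + bN) * bK * bP * 4;
    return {ceil_div(M, bM), ceil_div(N, bN), warps * 32, smem};
}

// For M = N = 4096 this gives a 32 x 32 grid of 128-thread blocks.
static_assert(estimate(4096, 4096).grid_x == 32, "grid x");
static_assert(estimate(4096, 4096).grid_y == 32, "grid y");
static_assert(estimate(4096, 4096).threads == 128, "threads per block");
// 96 KiB of dynamic shared memory, above the 48 KiB default limit.
static_assert(estimate(4096, 4096).smem_bytes == 98304, "smem bytes");
static_assert(estimate(4096, 4096).smem_bytes > 48 * 1024, "needs opt-in");
```

Under these assumptions the request is 96 KiB per block, which is why the `cudaFuncSetAttribute` call with `cudaFuncAttributeMaxDynamicSharedMemorySize` is needed before the launch.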

I ran both cublas and this for the same M=N=K=4096 and I got the following in profiler metrics:

For cublas (really cool the kernel just comes straight from cutlass)

  void cutlass::Kernel2<cutlass_80_tensorop_s1688gemm_256x128_16x3_nt_align4>(T1::Params) (256, 2, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                        0
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                        0
    -------------------------------------------------------- ----------- ------------

And for mine borrowed from the default_gemm_configuration:

void gemm_device<**A lot of templates**>(T1, T2, const T3 *, T4, T5, T6, const T7 *, T8, T9, T10, T11 *, T12, T13, T14, T15, T16) (32, 32, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
    Section: Command line profiler metrics
    -------------------------------------------------------- ----------- ------------
    Metric Name                                              Metric Unit Metric Value
    -------------------------------------------------------- ----------- ------------
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum               67,243,740
    l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum                        0
    -------------------------------------------------------- ----------- ------------

I'm trying to find a good layout as a base to write something more complex inside the device function. This kernel runs pretty fast but I'm wondering if there is a known layout like the one that cublas is pulling (or others) that avoids these conflicts. For one I read Lei Mao's blog post on cute swizzles (https://leimao.github.io/blog/CuTe-Swizzle/) and I was surprised the MBase of the swizzle is 3 where the vector size of the f32 is 4, so log_base_2(4) = 2, maybe there are many solutions?
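To make the MBase question concrete, here is a minimal host-side sketch of what a `Swizzle<BBits, MBase, SShift>` does to an element offset. It is a model written for illustration, assuming a non-negative shift, and is not CuTe's implementation; `swizzle` and `keeps_vectors` are hypothetical names. It suggests that both MBase = 2 and MBase = 3 leave 4-element (128-bit for 32-bit data) vectors intact, so there is more than one workable choice.

```cpp
#include <cassert>
#include <cstdint>

// Host-side model of a Swizzle<BBits, MBase, SShift> with SShift >= 0.
// Offsets are element indices; the low MBase bits are never modified.
template <int BBits, int MBase, int SShift>
constexpr std::uint32_t swizzle(std::uint32_t offset) {
    static_assert(BBits >= 0 && MBase >= 0 && SShift >= BBits,
                  "this sketch only models SShift >= BBits >= 0");
    constexpr std::uint32_t bits     = (1u << BBits) - 1u;
    constexpr std::uint32_t src_mask = bits << (MBase + SShift);  // bits read
    return offset ^ ((offset & src_mask) >> SShift);  // XORed in at MBase
}

// True if every aligned run of `vec` consecutive elements below `limit`
// stays consecutive after swizzling, so a vector access stays contiguous.
template <int BBits, int MBase, int SShift>
constexpr bool keeps_vectors(std::uint32_t vec, std::uint32_t limit) {
    for (std::uint32_t base = 0; base < limit; base += vec) {
        for (std::uint32_t i = 0; i < vec; ++i) {
            if (swizzle<BBits, MBase, SShift>(base + i) !=
                swizzle<BBits, MBase, SShift>(base) + i) {
                return false;
            }
        }
    }
    return true;
}

static_assert(swizzle<3, 2, 3>(32) == 36, "bit 5 is XORed into bit 2");
static_assert(swizzle<2, 3, 2>(32) == 40, "bit 5 is XORed into bit 3");
static_assert(keeps_vectors<3, 2, 3>(4, 1024), "MBase 2 keeps 4-wide vectors");
static_assert(keeps_vectors<2, 3, 2>(4, 1024), "MBase 3 keeps them as well");
```

MBase sets the smallest contiguous unit that survives the swizzle: 4 elements for MBase = 2, 8 elements for MBase = 3. Which one avoids bank conflicts depends on the access pattern on the other side of shared memory.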

I've also tried to get chatGPT to tell me something about ldmatrix and it's pretty awful. Is there an inherent reason why tf32 tensor core examples seem to avoid SM75_U32x4_LDSM_N and SM75_U16x8_LDSM_T? I've looked at default_gemm_configuration.h (the one inside the cutlass include) and it seems like only the kernels with sm_90 even use LDSM (I only see it in the sm90_common.inl) which would imply that cublas doesn't use it either?

Is there an easy way to pull this data out of those cutlass kernels? Or is there a known adjustment to the above that is just better for avoiding conflicts? I recommend a cute "model zoo" file for all the known good ones to live.

Thanks for the help!

capybara-club avatar Feb 15 '25 20:02 capybara-club

You are writing a kernel for MN major A and B inputs, meanwhile the partitioning pattern required by the MMA sources A and B inputs in a K major format. This implies that a transpose is required at some point between gmem -> tensor core.

2 places to do this transpose: during gmem->smem or during smem->rmem. LDSM_T destination partitioner is not compatible with the MMA source partitioner. This means for tf32, you either have to give up vectorization in GMEM and load to SMEM to a layout that is compatible with LDSM OR give up vectorization in smem and use 32b LDS to load them instead. Usually, the latter is the preferred strategy.

I was surprised the MBase of the swizzle is 3 where the vector size of the f32 is 4, so log_base_2(4) = 2, maybe there are many solutions?

that's right, you should only need an MBase of 2 for TF32 inputs for 128b vectors.

I recommend a cute "model zoo" file for all the known good ones to live.

entirety of CUTLASS 3.x is a model zoo for CuTe :)

Is there an easy way to pull this data out of those cutlass kernels?

Reading the builders and test cases is your best bet here.

thakkarV avatar Feb 16 '25 16:02 thakkarV

I see. Now I'm reading the ptx doc diagrams and it makes sense. I didn't realize the output was N-major for these instructions also. So expecting M and N major inputs and M major outputs requires transposing for input and then also either paying a cost to write uncoalesced or doing a transpose in smem? Does a permutation and/or tiling wider in the column direction when storing N-major help alleviate uncoalesced writes as the threads wrap to different rows?

that's right, you should only need an MBase of 2 for TF32 inputs for 128b vectors.

The Lei Mao cute swizzle post has him calculating the swizzle parameters based on the row major format. Just to clarify some confusion, the MBase can be 2, the BBits should be log_base_2(32 * 1) - MBase, so BBits should be 3? But since he quotes in row major, what should the SShift be from my layout above? Because of the issue you illustrated with the transpose does this have to stay something different to avoid breaking the gmem vectorization?

entirety of CUTLASS 3.x is a model zoo for CuTe :)

No no, I respect the template setup being very clever and D.R.Y. but a zoo has separate exhibits each with a species label.

Thank you very much for your time.

capybara-club avatar Feb 16 '25 17:02 capybara-club

The cublas one is essentially ex 14 of cutlass. If you do col x row, it uses 32b lds to transpose when loading from smem to rmem instead of ldsm.

hwu36 avatar Feb 19 '25 03:02 hwu36

@hwu36 Thanks that makes sense to me now.

I changed the swizzle to <2,3,2> and I'm down to ~6 conflicts per block now. What looks wrong about the layout I have? I've played with a ton of parameters and I pulled up that cute swizzle preview tool. Based on the layout from the ptx isa for A if I was hand coding it I would load 4 columns of 32 rows contiguously (M major). Then I would have threads 0,4,8..28 read from banks 0 to 7, threads 1,5,9..29 read from banks 8-15 etc in one go and then swizzle those chunks so that threads 1,5,9..29 read banks 0 to 7, threads 0,4,8..28 read banks 8-15. I thought that was what I was doing with this swizzle pattern (2,3,2) that looks like:

[Image: swizzle preview tool output for the (2,3,2) pattern]

but over 8 columns instead of 4. I also tried changing the tiled copy from 16x8 to 32x4 and the bank conflicts increased modestly and the kernel got pretty slow. Any idea what I'm missing?

Thanks!

capybara-club avatar Feb 20 '25 21:02 capybara-club

Ok, I roughly followed an example from f16 from you guys that had a nested shape to compose the swizzle with and I did

	auto swizzle_atom = composition(
		Swizzle<3,2,4>{},
		Layout<
			Shape<_8, Shape<_4,_8>>,
			Stride<_1, Stride<_8,_32>>
		>{}
	);

It's now reporting zero conflicts and is decently fast.
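One way to reason about why the layouts in this thread behave differently is a small host-side bank model. The sketch below is an illustration under stated assumptions, not a reproduction of the profiler numbers. It assumes 32 banks of 4 bytes and looks only at one warp's 32-bit load of A-fragment register a0 at the tile origin, where lane t reads row t / 4 and column t % 4 as in the PTX fragment diagram for m16n8k8 TF32. It ignores the cp.async stores, other fragment registers, other k-blocks and the B operand. `swizzle`, `Atom`, `atom_offset` and `worst_bank_load` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// XOR swizzle modeled on Swizzle<BBits, MBase, SShift>, non-negative shift.
constexpr std::uint32_t swizzle(int bbits, int mbase, int sshift,
                                std::uint32_t offset) {
    const std::uint32_t src_mask = ((1u << bbits) - 1u) << (mbase + sshift);
    return offset ^ ((offset & src_mask) >> sshift);
}

enum class Atom { Simple, Nested };

// Element offset of (m, k) near the tile origin; only m < 8, k < 4 is used.
constexpr std::uint32_t atom_offset(Atom atom, std::uint32_t m,
                                    std::uint32_t k) {
    return atom == Atom::Simple
        ? m + 32u * k                          // Shape<_32,_8>, Stride<_1,_32>
        : m + 8u * (k % 4u) + 32u * (k / 4u);  // Shape<_8,(_4,_8)>, Stride<_1,(_8,_32)>
}

// Largest number of lanes hitting one bank when a warp loads register a0.
// All 32 offsets are distinct, so any shared bank is a genuine conflict.
constexpr int worst_bank_load(Atom atom, int bbits, int mbase, int sshift) {
    int hits[32] = {};
    int worst = 0;
    for (std::uint32_t lane = 0; lane < 32; ++lane) {
        const std::uint32_t off = swizzle(
            bbits, mbase, sshift, atom_offset(atom, lane / 4u, lane % 4u));
        const int n = ++hits[off % 32u];
        if (n > worst) worst = n;
    }
    return worst;
}

static_assert(worst_bank_load(Atom::Simple, 0, 0, 0) == 4, "no swizzle");
static_assert(worst_bank_load(Atom::Simple, 3, 2, 3) == 2, "Swizzle<3,2,3>");
static_assert(worst_bank_load(Atom::Simple, 2, 3, 2) == 1, "Swizzle<2,3,2>");
static_assert(worst_bank_load(Atom::Nested, 3, 2, 4) == 1, "Swizzle<3,2,4>");
```

In this limited model the original Swizzle<3,2,3> atom gives a 2-way conflict on the a0 load, while Swizzle<2,3,2> and the nested atom with Swizzle<3,2,4> are both conflict-free. The model therefore cannot account for the few residual conflicts reported for Swizzle<2,3,2>; those would have to come from accesses it does not cover.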

@ccecka In case you publish more cute examples (NT TF32). This was hard fought for me. It's probably embarrassingly easy for you guys, but just in case.

Full set up:

	auto M = int(m);
	auto N = int(n);
	auto K = int(k);
	auto prob_shape = make_shape(M, N, K);

	// Define NT strides (mixed)
	auto dA = make_stride(Int<1>{}, ldA);
	auto dB = make_stride(Int<1>{}, ldB);
	auto dC = make_stride(Int<1>{}, ldC);

	// Define CTA tile sizes (static)
	auto bM = Int<128>{};
	auto bN = Int<128>{};
	auto bK = Int< 32>{};
	auto cta_tiler = make_shape(bM, bN, bK);
	auto bP = Int<3>{};

	// Best I found with simple stride
	// auto swizzle_atom = composition(
	// 	Swizzle<2,3,2>{},
	// 	Layout<
	// 		Shape<_32,_8>,
	// 		Stride<_1,_32>
	// 	>{}
	// );

	// Best I found with complex stride, no conflicts
	auto swizzle_atom = composition(
		Swizzle<3,2,4>{},
		Layout<
			Shape<_8, Shape<_4,_8>>,
			Stride<_1, Stride<_8,_32>>
		>{}
	);

	auto sA = tile_to_shape(swizzle_atom, make_shape(bM,bK,bP));
	auto sB = tile_to_shape(swizzle_atom, make_shape(bN,bK,bP));
	auto sC = make_layout(make_shape(bM, bN));

	TiledCopy copyA = make_tiled_copy(
		Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, tfloat32_t>{},
		Layout<Shape<_16,_8>, Stride<_1,_16>>{},
		Layout<Shape< _4,_1>>{}
	);

	TiledCopy copyB = make_tiled_copy(
		Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, tfloat32_t>{},
		Layout<Shape<_16, _8>, Stride<_1,_16>>{},
		Layout<Shape< _4,_1>>{}
	);

	TiledMMA mmaC = 
		make_tiled_mma(
			SM80_16x8x8_F32TF32TF32F32_TN{},
			Layout<Shape<_2,_2,_1>, Stride<_2, _1, _1>>{},
			Tile<_32,_32,_16>{}
		);

capybara-club avatar Feb 21 '25 22:02 capybara-club

That's awesome! Nice work.

I agree with you about the zoo comments and that was the original design behind DefaultGemmConfiguration -- to record layout engineering efforts. Unfortunately, that pattern was sacrificed for the "builders" with Hopper/Blackwell simply having a reduced parameter space of configurations.

ccecka avatar Feb 21 '25 22:02 ccecka

@ccecka thanks for the kind words!

I see, that's a shame. A full worked cute example is so nice and hackable off the shelf. Maybe the rtx 5 series going to TMA as well might make a lot of this swizzardry disappear.

capybara-club avatar Feb 21 '25 23:02 capybara-club

Yes RTX 5000 does support TMA (non multicast)

thakkarV avatar Feb 21 '25 23:02 thakkarV