
[QST] The best way to get the original coordinates of a fragment with CuTe?

Open · amazingyyc opened this issue 9 months ago · 5 comments

How can I get the original coordinates of the elements of a register fragment partitioned from a shared-memory matrix? I want to use cute::gemm to compute a matrix multiply followed by a mask, i.e. C = Mask(A x B), so I use the CuTe API as below:

typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrA = thr_mma.partition_fragment_A(sA);                                   // (MMA,MMA_M,MMA_K)
Tensor tSrB = thr_mma.partition_fragment_B(sB);                                   // (MMA,MMA_N,MMA_K)

Tensor tSrC = partition_fragment_C(tiled_mma, Shape<Int<TileM>, Int<TileN>>{});  // (MMA,MMA_M,MMA_N)

At this point I have the intermediate result for C. Before writing it back to global memory, I want to do some additional computation on C, such as the mask below:

for (int i = 0; i < TileM; ++i) {
  for (int j = 0; j < TileN; ++j) {
    if (i < j) {
      // Not correct: tSrC is not indexed by tile (row, col); this only shows the intent.
      tSrC(i, j) = 0;
    }
  }
}

Or

for (int i = 0; i < size(tSrC); ++i) {
  // Hypothetical helper: this is the (TileM, TileN) coordinate lookup I'm looking for.
  auto crd = GetOriginCoordOfTileMAndTileN(i);
  if (crd.x < crd.y) {
    tSrC[i] = 0;
  }
}

// Write to global memory.

CuTe has little documentation. I've been reading the source code, but I can't find a way to do this...

amazingyyc, May 20 '24 18:05
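As a reference point for C = Mask(A x B), a plain host-side C++ version (no CuTe) can be used to check whatever the kernel writes back. This is a sketch under assumptions: A is row-major M x K, B is stored as N x K (the TN convention that the sB [N, K] tensor uses later in this thread), and the mask zeroes every element with row < column, as in the demo loop above. masked_gemm_ref is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host reference for C = Mask(A x B^T).
// A: M x K row-major, B: N x K row-major (TN layout), C: M x N row-major.
// Mask rule (assumed, from the demo loop): C(i, j) = 0 whenever i < j.
std::vector<float> masked_gemm_ref(const std::vector<float>& A,
                                   const std::vector<float>& B,
                                   std::size_t M, std::size_t N, std::size_t K) {
  assert(A.size() == M * K && B.size() == N * K);
  std::vector<float> C(M * N, 0.0f);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      if (i < j) {
        continue;  // masked element stays 0
      }
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[j * K + k];  // dot(row i of A, row j of B)
      }
      C[i * N + j] = acc;
    }
  }
  return C;
}
```

Comparing the kernel's global-memory output against this element by element makes it easy to spot whether a wrong value comes from the GEMM or from the mask coordinates.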

https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/0y_predication.md

This documentation should show you how to do what you want.

thakkarV, May 20 '24 18:05
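The idea in the predication doc, applied to an MMA accumulator, is roughly: build an identity tensor over the C tile (each element's value is its own (m, n) coordinate), partition it with the same thread slice as the accumulator (e.g. thr_mma.partition_C on the identity tensor), and then element i of that partitioned tensor is the origin coordinate of rC(i), so a mask can test get<0>(...) < get<1>(...). The plain C++ sketch below (no CuTe) imitates that for a single 16x8 accumulator tile, assuming the lane/register layout the PTX ISA documents for mma.m16n8k8 with f32 accumulators (row = lane / 4, plus 8 for registers 2 and 3; column = 2 * (lane % 4) + (reg & 1)), which CuTe's SM75_16x8x8_F32F16F16F32_TN atom is expected to follow. accum_coord_m16n8, mask_fragment and covers_tile_once are hypothetical names.

```cpp
#include <array>
#include <cassert>

// (m, n) coordinate inside one 16x8 accumulator tile.
struct Coord {
  int m;
  int n;
};

// Hypothetical helper: origin coordinate of accumulator register `reg` (0..3)
// held by lane `lane` (0..31), per the assumed PTX m16n8k8 f32 C layout.
Coord accum_coord_m16n8(int lane, int reg) {
  assert(lane >= 0 && lane < 32 && reg >= 0 && reg < 4);
  const int group = lane / 4;            // PTX "groupID"
  const int thread_in_group = lane % 4;  // PTX "threadID_in_group"
  const int m = group + (reg >= 2 ? 8 : 0);
  const int n = thread_in_group * 2 + (reg & 1);
  return {m, n};
}

// Mask one lane's 4 accumulators: zero every element whose origin
// coordinate has m < n, like the demo loop in the question.
void mask_fragment(std::array<float, 4>& frag, int lane) {
  for (int reg = 0; reg < 4; ++reg) {
    const Coord c = accum_coord_m16n8(lane, reg);
    if (c.m < c.n) {
      frag[reg] = 0.0f;
    }
  }
}

// Sanity check: every (m, n) of the 16x8 tile is owned by exactly one (lane, reg).
bool covers_tile_once() {
  int hits[16][8] = {};
  for (int lane = 0; lane < 32; ++lane) {
    for (int reg = 0; reg < 4; ++reg) {
      const Coord c = accum_coord_m16n8(lane, reg);
      ++hits[c.m][c.n];
    }
  }
  for (int m = 0; m < 16; ++m) {
    for (int n = 0; n < 8; ++n) {
      if (hits[m][n] != 1) {
        return false;
      }
    }
  }
  return true;
}
```

In CuTe itself the per-register coordinates come from partitioning the identity tensor rather than from a hand-written formula like this one, so the same masking loop keeps working if the MMA atom or the tiling changes.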

https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/0y_predication.m

This documentation should let you do what you want to do

The link returns 404; could you send a working one?

amazingyyc, May 20 '24 18:05

comment updated inline

thakkarV, May 20 '24 18:05

comment updated inline

Hi thakkarV, thanks for your kind reply. I tried this method but got weird results. My code is below:

// MNK = [64, 32, 64]
using _SmemLayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{},
      Layout<Shape<_8, Int<32>>, Stride<Int<32>, _1>>{}));

using SmemLayoutA =
      decltype(tile_to_shape(_SmemLayoutAtom{}, Shape<Int<M>, Int<K>>{}));

using _MMAAtom = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using TiledMma =
      TiledMMA<_MMAAtom, Layout<Shape<_2, _2, _1>>, Tile<_32, _32, _16>>;

using SmemLayoutB =
      decltype(tile_to_shape(_SmemLayoutAtom{}, Shape<Int<N>, Int<K>>{}));

// Define shared tensor.
Tensor sA = make_tensor(make_smem_ptr(shared_storage.q.data()),
                        SmemLayoutA{});  // [M, K]
Tensor sB = make_tensor(make_smem_ptr(shared_storage.k.data()),
                        SmemLayoutB{});  // [N, K]

// Define the TiledMMA and this thread's slice of it.
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(tid /*thread id*/);

// Define the register tensors for sA/sB via tiled_mma.
Tensor rA = thr_mma.partition_fragment_A(sA);  // (MMA, MMA_M, MMA_K)
Tensor rB = thr_mma.partition_fragment_B(sB);  // (MMA, MMA_N, MMA_K)

// Define the accumulator for rC = rA x rB.
Tensor rC = partition_fragment_C(
    thr_mma, make_shape(Int<M>{}, Int<N>{}));  // (MMA, MMA_M, MMA_N)

// Define an identity tensor over the M x N output tile and partition it
// exactly like rC, so that _origin_coord(l) is meant to be the (m, n) of rC(l).
Tensor m_x_n_identity =
    make_identity_tensor(make_shape(Int<M>{}, Int<N>{}));
Tensor _origin_coord = thr_mma.partition_C(m_x_n_identity);

for (int i = 0; i < cute::size(m_x_n_identity); ++i) {
  cute::print(m_x_n_identity(i));
  cute::print("\n");
}

for (int l = 0; l < size(rC); ++l) {
  cute::print(_origin_coord(l));
  cute::print("\n");
}

The weird thing is:

cute::print(m_x_n_identity(i)); always prints (40,44).
cute::print(_origin_coord(l)); prints (40,44) when l is 0, but for l > 0 it prints what look like random numbers, e.g. (139650861629480,139650861629484).

Am I doing something wrong? If so, how can I fix it?

amazingyyc, May 21 '24 03:05
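For comparison, this is what the first print loop is expected to produce when the identity tensor behaves correctly. CuTe indexes a rank-2 tensor with a single integer in colexicographic (column-major) order, so element i of the 64 x 32 identity tensor should hold (i % 64, i / 64): (0,0), (1,0), up to (63,0), then (0,1), and so on. A constant (40,44) for every i is therefore not the expected output. The plain C++ sketch below only models that expected mapping; identity_coord is a hypothetical name, not a CuTe function.

```cpp
#include <cassert>
#include <utility>

// Hypothetical model of a 1-D index into an M x N identity tensor.
// Colexicographic order: the first mode (row m) varies fastest.
std::pair<int, int> identity_coord(int i, int M, int N) {
  assert(M > 0 && N > 0 && i >= 0 && i < M * N);
  return {i % M, i / M};  // (m, n)
}
```

With M = 64 and N = 32 (the MNK = [64, 32, 64] case above), index 64 maps to (0,1) and the last index 2047 maps to (63,31).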

Switching to CUTLASS 3.4.1 makes it work; with 3.5 the output is wrong. Weird...

amazingyyc, May 21 '24 06:05