cutlass
[QST] What is the best way to get the original coordinates of a fragment element with CuTe?
How can I get the original (m, n) coordinates of each element of a register fragment partitioned from a shared-memory matrix? I want to use cute::gemm to compute a matrix multiply and then apply a mask, like:
C = Mask(A x B)
So I use the CuTe API as follows:
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tSrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tSrC = partition_fragment_C(tiled_mma, Shape<Int<TileM>, Int<TileN>>{}); // (MMA,MMA_M,MMA_N)
Now I have the intermediate result in C. Before writing it back to global memory, I want to do some extra computation on C, such as the mask below:
for (int i = 0; i < TileM; ++i) {
  for (int j = 0; j < TileN; ++j) {
    if (i < j) {
      // It's not right, just a demo.
      tSrC(i, j) = 0;
    }
  }
}
Or
for (int i = 0; i < size(tSrC); ++i) {
  auto crd = GetOriginCoordOfTileMAndTileN(i);  // hypothetical: the function I'm looking for
  if (crd.x < crd.y) {
    tSrC[i] = 0;
  }
}
// Write to global mem.
CuTe has little documentation. I have been reading the source code, but I can't find a way to do this.
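For a single MMA atom, the origin coordinate of each accumulator element is fixed by the hardware fragment layout. The sketch below is host-only C++ for one mma.m16n8k8 with f32 accumulators (the instruction behind the SM75_16x8x8_F32F16F16F32_TN atom used later in this thread), assuming the accumulator layout documented in the PTX ISA: one warp, one 16x8 tile, four floats per lane. accum_coord, apply_mask and masked_tile are hypothetical helpers, not CuTe APIs. With a full TiledMMA, the warp layout and the MMA_M/MMA_N repeats add further offsets on top of this.

```cpp
#include <array>
#include <cassert>

// Origin (row, col) of one accumulator element inside the 16x8 C tile.
struct Coord {
  int row;
  int col;
};

// Hypothetical helper (not a CuTe API): maps (lane, fragment index) to the
// element's origin coordinate for one mma.m16n8k8 with f32 accumulators,
// following the PTX ISA accumulator fragment layout.
Coord accum_coord(int lane, int i) {
  assert(lane >= 0 && lane < 32 && i >= 0 && i < 4);
  int group_id = lane >> 2;        // lane / 4 selects the row
  int thread_in_group = lane & 3;  // lane % 4 selects a column pair
  int row = group_id + (i >= 2 ? 8 : 0);
  int col = thread_in_group * 2 + (i & 1);
  return {row, col};
}

// C = Mask(A x B) on one lane's fragment: zero every element whose origin
// coordinate lies strictly above the diagonal (row < col).
void apply_mask(int lane, std::array<float, 4>& frag) {
  for (int i = 0; i < 4; ++i) {
    Coord c = accum_coord(lane, i);
    if (c.row < c.col) frag[i] = 0.0f;
  }
}

// Host-side emulation: every lane of the warp masks a fragment of ones, and
// the results are scattered back into the 16x8 tile by their origin coords.
std::array<std::array<float, 8>, 16> masked_tile() {
  std::array<std::array<float, 8>, 16> tile{};
  for (int lane = 0; lane < 32; ++lane) {
    std::array<float, 4> frag{1.0f, 1.0f, 1.0f, 1.0f};
    apply_mask(lane, frag);
    for (int i = 0; i < 4; ++i) {
      Coord c = accum_coord(lane, i);
      tile[c.row][c.col] = frag[i];
    }
  }
  return tile;
}
```

Hard-coding the layout like this is brittle: it breaks as soon as the atom or the TiledMMA changes, which is why deriving the coordinates from the layout itself is preferable.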
https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/0y_predication.md
This documentation should show how to do what you want.
The link returns 404. Could you send a working one?
comment updated inline
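The idea behind the predication approach can be sketched without CuTe: build a tensor whose elements are their own coordinates, partition it exactly like the data, and read each element's origin from the matching slot of the coordinate fragment. In CuTe this corresponds to make_identity_tensor followed by thr_mma.partition_C, with the components read via get<0>(...) and get<1>(...). Below is host-only C++ in which the 8x8 tile, the modulo-4 thread partition, and all helper names are hypothetical stand-ins for the real TiledMMA partitioning.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int M = 8;
constexpr int N = 8;
constexpr int kThreads = 4;

struct Coord {
  int m;
  int n;
};

// Hypothetical row-major M x N tile holding any element type.
template <class T>
using Tile = std::array<T, M * N>;

// Hypothetical stand-in for thr_mma.partition_C: thread `tid` owns every
// element whose linear index is congruent to tid modulo kThreads. It depends
// only on the index, never on the element values.
template <class T>
std::vector<T> partition_for_thread(const Tile<T>& t, int tid) {
  std::vector<T> out;
  for (int idx = tid; idx < M * N; idx += kThreads) out.push_back(t[idx]);
  return out;
}

// Analogue of make_identity_tensor: element (m, n) stores its own coordinate.
Tile<Coord> identity_tile() {
  Tile<Coord> id{};
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) id[m * N + n] = {m, n};
  return id;
}

// Masks one thread's fragment. Slot k of `frag` and slot k of `coords` refer
// to the same tile position because both were partitioned identically.
void mask_fragment(std::vector<float>& frag, const std::vector<Coord>& coords) {
  assert(frag.size() == coords.size());
  for (std::size_t k = 0; k < frag.size(); ++k)
    if (coords[k].m < coords[k].n) frag[k] = 0.0f;
}
```

Because partition_for_thread never looks at element values, the same call that splits the data also splits the coordinates; that is the property relied on when partition_C is applied to an identity tensor.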
Hi thakkarV, thanks for your kind reply. I tried this method but got weird results. My code is below:
// MNK = [64, 32, 64]
using _SmemLayoutAtom = decltype(composition(
    Swizzle<2, 3, 3>{},
    Layout<Shape<_8, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutA =
    decltype(tile_to_shape(_SmemLayoutAtom{}, Shape<Int<M>, Int<K>>{}));
using _MMAAtom = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using TiledMma =
    TiledMMA<_MMAAtom, Layout<Shape<_2, _2, _1>>, Tile<_32, _32, _16>>;
using SmemLayoutB =
    decltype(tile_to_shape(_SmemLayoutAtom{}, Shape<Int<N>, Int<K>>{}));
// Define shared tensors.
Tensor sA = make_tensor(make_smem_ptr(shared_storage.q.data()),
                        SmemLayoutA{}); // [M, K]
Tensor sB = make_tensor(make_smem_ptr(shared_storage.k.data()),
                        SmemLayoutB{}); // [N, K]
// Define the TiledMMA.
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(tid/*thread id*/);
// Define the register tensors for sA/sB via tiled_mma.
Tensor rA = thr_mma.partition_fragment_A(sA); // (MMA, MMA_M, MMA_K)
Tensor rB = thr_mma.partition_fragment_B(sB); // (MMA, MMA_N, MMA_K)
// Define the result of rC = rA X rB
Tensor rC = partition_fragment_C(
    thr_mma, make_shape(Int<M>{}, Int<N>{})); // (MMA, MMA_M, MMA_N)
// Define the identity tensor
Tensor m_x_n_identity =
    make_identity_tensor(make_shape(Int<M>{}, Int<N>{}));
Tensor _origin_coord = thr_mma.partition_C(m_x_n_identity);
for (int i = 0; i < cute::size(m_x_n_identity); ++i) {
  cute::print(m_x_n_identity(i));
  cute::print("\n");
}
for (int l = 0; l < size(rC); ++l) {
  cute::print(_origin_coord(l));
  cute::print("\n");
}
The weird thing is:
cute::print(m_x_n_identity(i)) always prints (40,44).
cute::print(_origin_coord(l)) prints (40,44) when l is 0, but for l > 0 it prints what look like random numbers, e.g. (139650861629480,139650861629484).
Am I doing something wrong? If so, how can I fix it?
After switching to CUTLASS 3.4.1 it works; with 3.5 the output is wrong. Weird...