
[QST] Is it right to read a shared memory tensor directly?

Open amazingyyc opened this issue 9 months ago • 4 comments

I have some code like the following:

using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
using g2s_copy_atom = Copy_Atom<g2s_copy_traits, T>;

using G2SCopyA =
    decltype(make_tiled_copy(g2s_copy_atom{},
                             make_layout(make_shape(Int<16>{}, Int<2>{}),
                                         make_stride(Int<2>{}, Int<1>{})),
                             make_layout(make_shape(Int<1>{}, Int<8>{}))));
Tensor gA = make_tensor(make_gmem_ptr(A), make_layout(make_shape(Int<M>{}, Int<K>{}), make_stride(Int<K>{}, _1{})));
Tensor sA = make_tensor(make_smem_ptr(smemA), make_layout(make_shape(Int<M>{}, Int<K>{}), make_stride(Int<K>{}, _1{})));

G2SCopyA g2s_tiled_copy_a;
auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(threadIdx.x);
auto tAgA_copy = g2s_thr_copy_a.partition_S(gA);
auto tAsA_copy = g2s_thr_copy_a.partition_D(sA); 

cute::copy(g2s_tiled_copy_a, tAgA_copy, tAsA_copy);

// Then read sA directly; the data looks random?
for (int i = 0; i < size<0>(sA); ++i) {
  for (int j = 0; j < size<1>(sA); ++j) {
    // The value is not what I expect
    cute::print(sA[make_coord(i, j)]);
  }
}

I want to read the shared memory tensor in the kernel after the copy, but the data does not look right. Is it correct to read it like this? If not, what is the right way to read a shared memory tensor's data directly?

amazingyyc, May 21 '24 18:05

Try using Int<1>{} instead of 1 for your strides. For that reason, I'm surprised that this compiles -- I would hope that would be caught by the instruction.

ccecka, May 21 '24 18:05
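Why the static stride matters: Int<1>{} carries its value in its type, so a layout built from it has a unit stride the compiler can see, whereas a plain 1 is only known at run time. The plain C++ sketch below is an analogy, not CuTe code; StaticInt, is_static_unit_stride and vector_width_for are hypothetical names, and the fall-back-to-scalar policy is an assumption chosen for illustration (a real library may instead reject such a copy at compile time).

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical stand-in for CuTe's Int<N>: the value is part of the type.
template <int N>
using StaticInt = std::integral_constant<int, N>;

// True only when the stride type is a compile-time constant equal to 1.
template <class Stride>
struct is_static_unit_stride
    : std::integral_constant<bool, Stride::value == 1> {};

// A plain int stride might equal 1 at run time, but that cannot be
// proven at compile time, so it is treated as "not provably unit".
template <>
struct is_static_unit_stride<int> : std::false_type {};

// Hypothetical policy: use a vector of `width` elements only when the
// innermost stride is provably 1; otherwise copy one element at a time.
template <class Stride>
constexpr std::size_t vector_width_for(std::size_t width) {
  return is_static_unit_stride<Stride>::value ? width : 1;
}
```

For example, vector_width_for<StaticInt<1>>(8) evaluates to 8 at compile time, while vector_width_for<int>(8) evaluates to 1 because the stride value is unknown to the type system.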

Sorry for the typo, please ignore it. It's _1 in the actual code.

amazingyyc, May 22 '24 02:05

Are you printing on a single thread and left that out as well?

if (thread0()) {
  for (int i = 0; i < size<0>(sA); ++i) {
    for (int j = 0; j < size<1>(sA); ++j) {
      cute::print(sA[make_coord(i, j)]);
    }
  }
}

ccecka, May 22 '24 03:05
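Why the single-thread question matters: a tiled copy splits the tile across all threads, so almost everything a single thread prints was written by other threads, and those writes must be complete and visible first (for example after __syncthreads(), and, for an asynchronous copy such as the SM80_CP_ASYNC atom in the first snippet, after waiting for the outstanding copies). The sketch below is a hypothetical host-side model based on our reading of make_tiled_copy, assuming a row-major thread layout (TM, TN):(TN, 1) and a per-thread value layout (1, V) as in both snippets; owner_thread and elements_written_by are illustrative names, not CuTe functions.

```cpp
// Hypothetical model: with thread layout (TM, TN):(TN, 1) and value
// layout (1, V), thread t = m * TN + n copies row m, columns
// n*V .. n*V + V - 1 of each TM x (TN * V) tile. The tile repeats over
// larger tensors, so coordinates are reduced modulo the tile shape.
constexpr int owner_thread(int row, int col, int TM, int TN, int V) {
  const int m = row % TM;
  const int n = (col % (TN * V)) / V;
  return m * TN + n;
}

// Number of elements of one tile that a given thread writes.
constexpr int elements_written_by(int thread, int TM, int TN, int V) {
  int count = 0;
  for (int r = 0; r < TM; ++r) {
    for (int c = 0; c < TN * V; ++c) {
      if (owner_thread(r, c, TM, TN, V) == thread) ++count;
    }
  }
  return count;
}
```

With the configuration used in the complete code posted later in the thread (TM = 32, TN = 4, V = 8), elements_written_by(0, 32, 4, 8) is 8: in this model thread 0 writes only 8 of the 1024 elements of each 32 x 32 tile.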

The demo code may be a little misleading, so let me post the complete code.

template <typename T, class SmemLayoutMask>
struct _SharedStorage {
  cute::array_aligned<bool, cute::cosize_v<SmemLayoutMask>> mask;
};

template <typename Trait>
__global__ void Demo(
    const bool* __restrict__ Mask, /*[batch_size, 1, seq_len, seq_len]*/
    const int64_t batch_size,
    const int64_t num_head,
    const int64_t seq_len) {
  // BR, BC, batch, head, grid_y and tid are defined elsewhere
  // (omitted from this post).

  using _SmemLayoutAtomMask = decltype(
      composition(Swizzle<2, 3, 3>{},
                  Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, Int<1>>>{}));

  using SmemLayoutMask = decltype(
      tile_to_shape(_SmemLayoutAtomMask{}, Shape<Int<BR>, Int<BC>>{}));

  using GmemTiledCopyMask = decltype(make_tiled_copy(
      Copy_Atom<DefaultCopy, bool>{}, Layout<Shape<_32, _4>, Stride<_4, _1>>{},
      Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read

  // Shared memory.
  extern __shared__ char sBuf[];
  auto& shared_storage = *reinterpret_cast<_SharedStorage<bool, SmemLayoutMask>*>(sBuf);

  Tensor _Mask =
      make_tensor(make_gmem_ptr(reinterpret_cast<const bool*>(Mask) +
                                batch * seq_len * seq_len),
                  make_shape(seq_len, seq_len), make_stride(seq_len, Int<1>{}));

  Tensor gMask =
      local_tile(_Mask, make_tile(Int<BR>{}, Int<BC>{}),
                 make_coord(grid_y, _));  // [BR, BC, seq_len / BC]

  Tensor sMask = make_tensor(make_smem_ptr(shared_storage.mask.data()),
                             SmemLayoutMask{});  // [BR, BC]

  GmemTiledCopyMask gmem_tiled_copy_Mask;

  auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tid);

  Tensor gMask_to_sMask_src =
      gmem_thr_copy_Mask.partition_S(gMask);  // (CPY, CPY_M, CPY_K, _)
  Tensor gMask_to_sMask_dst =
      gmem_thr_copy_Mask.partition_D(sMask);  // (CPY, CPY_M, CPY_K)

  // Copy Mask to smem async.
  cute::copy(gmem_tiled_copy_Mask, gMask_to_sMask_src(_, _, _, _0{}),
             gMask_to_sMask_dst);
  __syncthreads();  // make every thread's smem writes visible before thread 0 reads sMask

  if (batch == 0 && head == 0 && grid_y == 0 && tid == 0) {
    for (int r = 0; r < BR; ++r) {
      for (int c = 0; c < BC; ++c) {
        // Here _Mask prints the expected values, but sMask is always 0.
        printf("r:%d c:%d mask:%d\n", r, c, (int)(_Mask[make_coord(r, c)]));
        printf("r:%d c:%d mask:%d\n", r, c, (int)(sMask[make_coord(r, c)]));
      }
    }
  }
}

I am trying to copy a bool mask into shared memory using make_tiled_copy. Printing the global memory tensor looks OK, but the shared memory tensor is always 0.

amazingyyc, May 22 '24 07:05
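For reference, the shared-memory layout in this code composes Swizzle<2, 3, 3> with an 8 x 32 row-major atom. The sketch below models that swizzle under our reading of cute::Swizzle<B, M, S> for S > 0: the B bits starting at bit M + S are XOR-ed into the B bits starting at bit M. The function names are hypothetical. Because M = 3, the low three bits of an offset never change, so each aligned run of 8 elements stays contiguous, which fits the (1, 8) value layout. Layout offsets count elements of the tensor's value type, so if that type is a 1-bit bool the same permutation applies to bits rather than bytes.

```cpp
#include <cstdint>

// Model of cute::Swizzle<B, M, S> for S > 0 (based on our reading of CuTe):
// the B bits starting at bit M + S are XOR-ed into the B bits starting at bit M.
template <int B, int M, int S>
constexpr std::uint32_t swizzle(std::uint32_t offset) {
  static_assert(S > 0, "this sketch only models a positive shift");
  constexpr std::uint32_t yyy_mask = ((1u << B) - 1u) << (M + S);
  return offset ^ ((offset & yyy_mask) >> S);
}

// The 8 x 32 row-major atom from the post: (r, c) -> r * 32 + c, then swizzled.
constexpr std::uint32_t atom_offset(std::uint32_t r, std::uint32_t c) {
  return swizzle<2, 3, 3>(r * 32u + c);
}

// Over the whole 256-element atom, check that the low 3 bits never change
// and that the mapping is its own inverse (hence a permutation).
constexpr bool swizzle_properties_hold() {
  for (std::uint32_t o = 0; o < 256u; ++o) {
    const std::uint32_t s = swizzle<2, 3, 3>(o);
    if ((s & 7u) != (o & 7u)) return false;
    if (swizzle<2, 3, 3>(s) != o) return false;
  }
  return true;
}
```

In this model row 0 and row 1 are left unchanged, while row 2 starts at offset 72 instead of 64: its 8-element chunks are rotated within the row.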

I see, bool is a pretty slippery type. Does it mean (1) a single 8-bit value_type, or (2) an 8x 1-bit packed subbyte type?

Your gmem tensor is being constructed with a bool*

make_gmem_ptr(reinterpret_cast<const bool*>(Mask) + batch * seq_len * seq_len), ...

which is a pointer with bool value_type and 8bit striding. This is (1).

Your smem tensor is being constructed with a cute::array_aligned<bool, cute::cosize_v<SmemLayoutMask>>::data(), which is also a bool*, but is constructed over a 1b packed subbyte array array_subbyte<bool, N>. This is (2) being viewed as a (1).

Then the copy partitioning is being constructed with a 1b packed subbyte type

Copy_Atom<DefaultCopy, bool>{}

but being applied to these 8bit bool* tensors. This is a (2) being applied to (1)s.
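A host-side analogy of the (1) versus (2) distinction, not CuTe code: the same 16 bytes of storage are written with 16 true flags under meaning (2) and then read back under meaning (1). Only the first two bytes are nonzero, so 14 of the 16 flags read back as false, which resembles (without reproducing) the all-zero sMask reported above. The LSB-first bit order and all names are assumptions of this sketch.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Storage sized for meaning (1): 16 elements x 1 byte.
using Storage = std::array<std::uint8_t, 16>;

// Meaning (1): one 8-bit value per element; element i lives in byte i.
constexpr bool get_bytewise(const Storage& s, std::size_t i) {
  return s[i] != 0;
}

// Meaning (2): packed 1-bit elements; element i lives in bit (i % 8)
// of byte (i / 8). LSB-first order is an assumption of this sketch.
constexpr bool get_packed(const Storage& s, std::size_t i) {
  return ((s[i / 8] >> (i % 8)) & 1u) != 0;
}

// Write 16 "true" flags under meaning (2): only bytes 0 and 1 change.
constexpr Storage write_16_true_packed() {
  Storage s{};
  for (std::size_t i = 0; i < 16; ++i) {
    s[i / 8] = static_cast<std::uint8_t>(s[i / 8] | (1u << (i % 8)));
  }
  return s;
}

// How many of the 16 flags read back as true under meaning (1).
constexpr int count_true_bytewise(const Storage& s) {
  int n = 0;
  for (std::size_t i = 0; i < 16; ++i) {
    if (get_bytewise(s, i)) ++n;
  }
  return n;
}
```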

So it's a units consistency problem occurring from using bool* versus higher_level_type<bool>. I suggest using uint8_t as your mask type everywhere and interpreting that as 8 packed boolean bits as that appears to be what you actually intend. If you absolutely want a 1b packed boolean valued tensor, you can recast<bool>(my_uint8_tensor).
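If the mask is stored as uint8_t with each byte interpreted as 8 packed boolean bits, whatever produces the mask has to pack it that way first. The sketch below is one hypothetical host-side packing routine, assuming the source mask holds one byte per element (0 or nonzero) and using LSB-first bit order; pack_mask and test_flag are illustrative names. Because the packed buffer is uint8_t on every side, a copy moves whole bytes and the element units stay consistent.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Pack N one-byte flags (0 = false, nonzero = true) into N / 8 uint8_t
// words, LSB-first. The bit order is an assumption of this sketch.
template <std::size_t N>
constexpr std::array<std::uint8_t, N / 8> pack_mask(
    const std::array<std::uint8_t, N>& flags) {
  static_assert(N % 8 == 0, "N must be a multiple of 8");
  std::array<std::uint8_t, N / 8> words{};
  for (std::size_t i = 0; i < N; ++i) {
    if (flags[i] != 0) {
      words[i / 8] = static_cast<std::uint8_t>(words[i / 8] | (1u << (i % 8)));
    }
  }
  return words;
}

// Read flag i back out of the packed words.
template <std::size_t W>
constexpr bool test_flag(const std::array<std::uint8_t, W>& words,
                         std::size_t i) {
  return ((words[i / 8] >> (i % 8)) & 1u) != 0;
}
```

For a 16-element mask with flags 0, 2 and 15 set, pack_mask produces the two bytes 5 and 128.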

Aside: In my opinion, array_subbyte<T,N>::data() should not exist, in analogy with std::vector<bool>::data() not existing. Removing that function would, I believe, help catch this error.
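To illustrate the aside, here is a hypothetical packed boolean array in the spirit of std::vector<bool>: element access goes through a proxy reference, and there is deliberately no data() returning bool*, so the packed storage cannot be handed to code that expects one byte per bool. This is a sketch, not CuTe's array_subbyte.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical packed boolean array: N bits stored in (N + 7) / 8 bytes.
// Deliberately no data(): a bool* would suggest one byte per element.
template <std::size_t N>
class PackedBits {
 public:
  // Proxy standing in for "bool&" to a single bit.
  class reference {
   public:
    constexpr reference(std::uint8_t& word, std::uint8_t bit)
        : word_(word), bit_(bit) {}
    constexpr operator bool() const { return (word_ & bit_) != 0; }
    constexpr reference& operator=(bool v) {
      word_ = static_cast<std::uint8_t>(v ? (word_ | bit_) : (word_ & ~bit_));
      return *this;
    }

   private:
    std::uint8_t& word_;
    std::uint8_t bit_;
  };

  constexpr reference operator[](std::size_t i) {
    return reference(words_[i / 8], static_cast<std::uint8_t>(1u << (i % 8)));
  }
  constexpr bool operator[](std::size_t i) const {
    return ((words_[i / 8] >> (i % 8)) & 1u) != 0;
  }
  static constexpr std::size_t size() { return N; }
  static constexpr std::size_t size_bytes() { return (N + 7) / 8; }

 private:
  std::array<std::uint8_t, (N + 7) / 8> words_{};
};

// Usage: set and clear bits through the proxy, read through the const overload.
constexpr bool packed_bits_demo() {
  PackedBits<16> bits;
  bits[3] = true;
  bits[9] = true;
  bits[3] = false;
  const PackedBits<16>& view = bits;
  return !view[3] && view[9] && !view[10];
}
```

As with std::vector<bool>, the cost of this design is that operator[] cannot return a real bool&, which is exactly what prevents the storage from being reinterpreted as bytes by accident.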

ccecka, May 22 '24 16:05

Thanks so much for the kind reply. I will change the mask to uint8_t.

amazingyyc, May 23 '24 05:05