
[QST] [CuTeDSL] Unexpected behavior with async copy on Ampere

Open • simveit opened this issue 6 months ago • 5 comments

Hello,

I am trying to implement a transpose kernel and am facing a problem: cute.copy seems to copy only one element per row to shared memory.

```
# BEFORE TRANSFER
thread_global_tile_src = raw_ptr(0x0000759d4e800000: f32, gmem, align<8>) o (4,2):(8,1) = 
  ( 0.659961, -0.130726, -1.193286, -0.525738, -0.023081, 0.912704, -1.268611, -0.718438 )
thread_shared_tile_src = raw_ptr(0x0000000000000000: f32, smem, align<1024>) o (4,2):(1,4) = 
  ( 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000 )
# AFTER TRANSFER
thread_global_tile_src = raw_ptr(0x0000759d4e800000: f32, gmem, align<8>) o (4,2):(8,1) = 
  ( 0.659961, -0.130726, -1.193286, -0.525738, -0.023081, 0.912704, -1.268611, -0.718438 )
thread_shared_tile_src = raw_ptr(0x0000000000000000: f32, smem, align<1024>) o (4,2):(1,4) = 
  ( 0.659961, 0.000000, 0.000000, 0.000000, -0.023081, 0.000000, 0.000000, 0.000000 )
thread_shared_tile_dst = raw_ptr(0x0000000000000000: f32, smem, align<1024>) o (4,2):(1,4) = 
  ( 0.659961, 0.000000, 0.000000, 0.000000, -0.023081, 0.000000, 0.000000, 0.000000 )
thread_global_tile_dst_transposed = raw_ptr(0x0000759d4e800200: f32, gmem, align<16>) o (4,2):(1,8) = 
  ( 0.659961, 0.000000, 0.000000, 0.000000, -0.023081, 0.000000, 0.000000, 0.000000 )
```
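To read the dump, it may help to expand each printed layout into the element offsets it addresses. The sketch below is plain Python, not CuTeDSL; `colex_coords` and `offsets` are hypothetical helpers, and it assumes the print lists elements in colexicographic coordinate order (leftmost mode varies fastest).

```python
from itertools import product

# Hypothetical plain-Python helpers for flat CuTe layouts (not CuTeDSL APIs).
# Assumption: the debug print lists elements in colexicographic coordinate
# order, i.e. the leftmost mode varies fastest.

def colex_coords(shape):
    """All coordinates of `shape`, leftmost mode fastest."""
    # product() varies its LAST argument fastest, so reverse in and out.
    return [tuple(reversed(c))
            for c in product(*(range(n) for n in reversed(shape)))]

def offsets(shape, stride):
    """Element offset (coord . stride) of each coordinate, in print order."""
    return [sum(c * s for c, s in zip(crd, stride)) for crd in colex_coords(shape)]

SHAPE = (4, 2)
gmem_src = offsets(SHAPE, (8, 1))  # thread_global_tile_src            (4,2):(8,1)
smem = offsets(SHAPE, (1, 4))      # thread_shared_tile_src            (4,2):(1,4)
gmem_dst = offsets(SHAPE, (1, 8))  # thread_global_tile_dst_transposed (4,2):(1,8)

# Shared tile contents as printed AFTER TRANSFER.
smem_after = [0.659961, 0.0, 0.0, 0.0, -0.023081, 0.0, 0.0, 0.0]
copied = [i for i, v in enumerate(smem_after) if v != 0.0]
copied_coords = [colex_coords(SHAPE)[i] for i in copied]
copied_gmem_offsets = [gmem_src[i] for i in copied]

print(gmem_src)             # [0, 8, 16, 24, 1, 9, 17, 25]
print(smem)                 # [0, 1, 2, 3, 4, 5, 6, 7]
print(gmem_dst)             # [0, 1, 2, 3, 8, 9, 10, 11]
print(copied_coords)        # [(0, 0), (0, 1)]
print(copied_gmem_offsets)  # [0, 1]
```

Under that assumption, the two values that reached shared memory sit at coordinates (0, 0) and (0, 1) of the (4, 2) tile, which are offsets 0 and 1 of the global source tile.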

Could somebody help with this?

Minimal example to reproduce behavior: https://gist.github.com/simveit/66ff1ee4ee5a6b174e446d92d6fa40ef

simveit • Jun 11 '25 19:06

I think it's because, with a (1,1) thread layout, local_partition only gives each thread access to one element per dimension of the tensor rather than the full tensor data you expect.

The issue is probably with the thread partitioning setup.

Could you share the actual thread layout dimensions and thread block size you're using in your full implementation?

prateekshukla1108 • Jun 12 '25 09:06

Hello @prateekshukla1108, you can see the whole setup in the gist above.

simveit • Jun 12 '25 09:06

When you call local_partition it divides your tensor according to the thread layout you provide.

When your thread layout is configured as (1,1), or is otherwise improperly sized, local_partition gives each thread access to only a narrow slice of the tensor, so each thread would indeed copy only one element per row of its assigned slice.

I think you need to use proper thread layouts with multiple threads and make sure the shared memory layouts are correctly configured for the transpose operation.

prateekshukla1108 • Jun 12 '25 09:06

I don't think it's related to that. A naive version of the kernel that assigns the thread layout the same way but does not use SMEM works as expected.

See here

simveit • Jun 12 '25 10:06

Looking at the local_partition function definition here, you will find that index is used to produce a coordinate into the tile via "tile.get_flat_coord(index)". In your case the tile is a (1,1):(0,0) layout, so the only coordinate you can get is (0,0). This coordinate is later used in outer_partition to slice/index the partitioned tensor: your tensor is first partitioned into repetitions of 1x1 tiles, and then you slice/index to get the first 1x1 tile. So it looks reasonable that only the first element gets copied.

```cpp
// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index.
// This is typical at the Thread level where data is partitioned across repeated patterns of threads:
//   Tensor data = ...                                                            // (_16,_64)
//   Tensor thr_data = local_partition(data, Layout<Shape<_2,_16>>{}, thr_idx);   // ( _8, _4)
template <class Tensor, class LShape, class LStride, class Index,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,    // coord -> index
                Index                  const& index)   // index to slice for
{
  static_assert(is_integral<Index>::value);
  return outer_partition(static_cast<Tensor&&>(tensor),
                         product_each(shape(tile)),
                         tile.get_flat_coord(index));
}
```
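The index-to-coordinate step and the per-thread slicing can be mimicked in plain Python for the rank-2, evenly divisible case. This is a sketch, not CuTeDSL: `get_flat_coord` and `local_partition` below are hypothetical stand-ins modelled on the C++ quoted above and its docstring example, assuming a compact (column-major) thread layout.

```python
# Hypothetical plain-Python model of the quoted C++ (rank-2, evenly divisible).

def get_flat_coord(tile_shape, index):
    """idx2crd over a compact shape: the leftmost mode varies fastest."""
    coord = []
    for extent in tile_shape:
        coord.append(index % extent)
        index //= extent
    return tuple(coord)


def local_partition(tensor_shape, thr_shape, thr_idx):
    """Model of outer_partition: split the tensor into thr_shape tiles, take the
    element at this thread's coordinate inside every tile, keep the rest mode.
    Returns the per-thread view shape and the (m, n) coordinates it covers."""
    cm, cn = get_flat_coord(thr_shape, thr_idx)
    tm, tn = thr_shape
    rest = (tensor_shape[0] // tm, tensor_shape[1] // tn)
    coords = [(cm + tm * i, cn + tn * j)
              for j in range(rest[1]) for i in range(rest[0])]
    return rest, coords


# Docstring example: a (16, 64) tensor over a (2, 16) thread layout.
rest, coords = local_partition((16, 64), (2, 16), 5)
print(get_flat_coord((2, 16), 5))  # (1, 2)
print(rest)                        # (8, 4)
print(coords[:2])                  # [(1, 2), (3, 2)]

# A (1, 1) tile has a single coordinate, so every index maps to (0, 0).
print({get_flat_coord((1, 1), i) for i in range(128)})  # {(0, 0)}
```

The (8, 4) result matches the `( _8, _4)` annotation in the quoted docstring, and the last line reproduces the point that a (1,1) tile only ever yields coordinate (0,0).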

I tried setting TileSizeX=TileSizeY=ThreadBlockX=ThreadBlockY=8, and your program works correctly.
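As a sanity check on that configuration, the sketch below models an 8x8 tile partitioned over an 8x8 thread layout. It assumes the gist builds a compact column-major (ThreadBlockX, ThreadBlockY) thread layout and that partitioning behaves like the outer_partition model (every thread takes the element at its own coordinate inside each thread-sized tile); `thread_elements` is a hypothetical helper, not a CuTeDSL function.

```python
# Hypothetical model (not CuTeDSL) of TileSizeX = TileSizeY = 8 with an
# (8, 8) thread layout; assumes compact column-major thread indexing.

def thread_elements(tile_shape, thr_shape, thr_idx):
    """Coordinates a thread owns: its own position inside every thr_shape tile."""
    tm, tn = thr_shape
    cm, cn = thr_idx % tm, (thr_idx // tm) % tn  # column-major idx2crd
    return [(cm + tm * i, cn + tn * j)
            for j in range(tile_shape[1] // tn)
            for i in range(tile_shape[0] // tm)]


TILE, THREADS = (8, 8), (8, 8)
num_threads = THREADS[0] * THREADS[1]
per_thread = [thread_elements(TILE, THREADS, t) for t in range(num_threads)]

# Each of the 64 threads owns exactly one element ...
assert all(len(p) == 1 for p in per_thread)
# ... and together they cover the 8x8 tile with no gaps or overlaps.
owned = sorted(e for p in per_thread for e in p)
assert owned == sorted((m, n) for m in range(8) for n in range(8))
print(per_thread[9])  # [(1, 1)]
```

With matching tile and thread-layout extents, each thread's view is a single element and the threads tile the block exactly once.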

lijingticy22 • Jun 16 '25 04:06

I will take a look. I need to study CuTe layouts more closely. Thanks @lijingticy22!

simveit • Jun 17 '25 21:06