
[BUG] Missing copy_if Implementation for AutoVectorizingCopyWithAssumedAlignment

Open kitecats opened this issue 7 months ago • 2 comments

Describe the bug When using Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<sizeof(uint128_t)*8>, half> together with the copy_if function for data loading, the program fails at runtime. The issue occurs specifically when performing conditional vectorized copies of half-precision data with an assumed 128-bit alignment.
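For context, the intent is a predicated copy that still vectorizes: when a thread's elements are all in bounds, its 8 halves (16 bytes) should move as a single 128-bit transaction. A minimal plain-C++ sketch of that idea (copy_chunk_if is a hypothetical name invented here, not a CuTe API):

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical illustration of a predicated 128-bit copy:
// if pred is true, move 8 consecutive 16-bit elements (16 bytes)
// as one aligned chunk; otherwise leave the destination untouched.
void copy_chunk_if(bool pred, const std::uint16_t* src, std::uint16_t* dst) {
  if (pred) {
    std::memcpy(dst, src, 16);  // one 128-bit transaction when both pointers are 16-byte aligned
  }
}
```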

Steps/Code to reproduce bug

template <class TiledCopy_, int NumElemPerBlock, int kNumElemPerThread = 8>
__global__ void vector_add_local_tile_multi_elem_per_thread_half_prefill_128bit(
    half *z, int num, const half *x, const half *y, const half a, const half b,
    const half c) {
  using namespace cute;

  int bidx   = blockIdx.x;
  int thridx = threadIdx.x;

  Tensor Pre = make_identity_tensor(make_shape(num));
  Tensor Z   = make_tensor(make_gmem_ptr(z), make_shape(num));
  Tensor X   = make_tensor(make_gmem_ptr(x), make_shape(num));
  Tensor Y   = make_tensor(make_gmem_ptr(y), make_shape(num));

  Tensor BlockPre = local_tile(Pre, make_shape(Int<NumElemPerBlock>{}), make_coord(bidx));
  Tensor BlockZ   = local_tile(Z,   make_shape(Int<NumElemPerBlock>{}), make_coord(bidx));
  Tensor BlockX   = local_tile(X,   make_shape(Int<NumElemPerBlock>{}), make_coord(bidx));
  Tensor BlockY   = local_tile(Y,   make_shape(Int<NumElemPerBlock>{}), make_coord(bidx));

  // Per-thread register fragments
  Tensor tzR = make_tensor<half>(make_shape(Int<1>{}, Int<kNumElemPerThread>{}));
  Tensor txR = make_tensor<half>(make_shape(Int<1>{}, Int<kNumElemPerThread>{}));
  Tensor tyR = make_tensor<half>(make_shape(Int<1>{}, Int<kNumElemPerThread>{}));

  clear(tzR);

  TiledCopy_ tiled_copy;
  auto thr_copy = tiled_copy.get_slice(thridx);

  auto pre_r = thr_copy.partition_S(BlockPre);

  auto tzr      = thr_copy.partition_D(BlockZ);
  auto tzR_view = thr_copy.retile_S(tzR);

  auto txr      = thr_copy.partition_S(BlockX);
  auto txR_view = thr_copy.retile_D(txR);

  auto tyr      = thr_copy.partition_S(BlockY);
  auto tyR_view = thr_copy.retile_D(tyR);

  // Predicate: ignores the passed-in coordinates and only checks that the
  // last element of this thread's tile is in bounds
  auto pre_ = [&](auto... coords) {
    return cute::elem_less(pre_r(Int<kNumElemPerThread - 1>{}), shape(num));
  };
  copy_if(tiled_copy, pre_, txr, txR_view);
  copy_if(tiled_copy, pre_, tyr, tyR_view);

  half2 a2 = {a, a};
  half2 b2 = {b, b};
  half2 c2 = {c, c};

  auto tzR2 = recast<half2>(tzR);
  auto txR2 = recast<half2>(txR);
  auto tyR2 = recast<half2>(tyR);

#pragma unroll
  for (int i = 0; i < size(tzR2); ++i) {
    // two HFMA2 instructions
    tzR2(i) = txR2(i) * a2 + (tyR2(i) * b2 + c2);
  }

  // STG.128
  copy(tiled_copy, tzR_view, tzr);
}


void forward_vector_add(torch::Tensor X, torch::Tensor Y, torch::Tensor Z, const int num)
{
  using namespace cute;
  constexpr int NumThreadPerBlock = 32;
  constexpr int kNumElemPerThread = 8;
  constexpr int NumElemPerBlock   = NumThreadPerBlock * kNumElemPerThread;

  // 128-bit (16-byte) assumed alignment for the auto-vectorizing copy
  using CopyAtom = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<sizeof(uint128_t)*8>, half>;

  using TiledCopy = decltype(make_tiled_copy(
      CopyAtom{},
      make_layout(Shape<Int<NumThreadPerBlock>>{}, GenRowMajor{}),
      make_layout(Shape<Int<kNumElemPerThread>>{}, GenRowMajor{})));

  half a = __float2half(1.0f), b = __float2half(1.0f), c = __float2half(0.0f);

  dim3 grid(ceil_div(num, NumElemPerBlock));
  dim3 block(NumThreadPerBlock);
  vector_add_local_tile_multi_elem_per_thread_half_prefill_128bit
      <TiledCopy, NumElemPerBlock, kNumElemPerThread><<<grid, block>>>(
          reinterpret_cast<half *>(Z.data_ptr()),
          num,
          reinterpret_cast<half *>(X.data_ptr()),
          reinterpret_cast<half *>(Y.data_ptr()),
          a, b, c);
}

The above code compiles successfully, but the following error occurs at runtime:


Thread 1 "vector_add_main" received signal CUDA_EXCEPTION_6, Warp Misaligned Address.
0x00007f623e85a780 in void vector_add_local_tile_multi_elem_per_thread_half_prefill_128bit<cute::TiledCopy<cute::Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<128>, __half>, cute::Layout<cute::tuple<cute::C<32>, cute::C<8> >, cute::tuple<cute::C<8>, cute::C<1> > >, cute::tuple<cute::C<256> > >, 256, 8>(__half*, int, __half const*, __half const*, __half, __half, __half)<<<(1,1,1),(32,1,1)>>> ()
cuda block (0, 0, 0) thread (0, 0, 0)

Additional context When I added the corresponding copy_if overloads to the cute/algorithm/copy.hpp file, the program executed normally.

template <int MaxVecBits, class... Args,
          class PrdTensor,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
        PrdTensor                                           const& pred,
        Tensor<SrcEngine, SrcLayout>                        const& src,
        Tensor<DstEngine, DstLayout>                             & dst)
{
  constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst));
  constexpr int align_bits  = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int<MaxVecBits>{}));
  static_assert(is_integral<decltype(Int<common_elem>{} * sizeof_bits_v<typename SrcEngine::value_type>)>::value, "Error: Attempting a subbit copy!");
  constexpr int vec_bits    = gcd(common_elem * sizeof_bits_v<typename SrcEngine::value_type>, align_bits);

  if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) {
    // If more than one element vectorizes to 8bits or more, then recast and copy
    using VecType = uint_bit_t<vec_bits>;
    // Preserve volatility
    using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
    using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType       volatile, VecType      >;

    // Recast
    Tensor src_v = recast<SrcVecType>(src);
    Tensor dst_v = recast<DstVecType>(dst);
    return copy_if(pred, src_v, dst_v);
  } else {
    return copy_if(pred, src, dst);
  }
}

template <int MaxVecBits, class... Args,
          class PrdTensor,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>, Args...> const&,
        PrdTensor                                                               const& pred,
        Tensor<SrcEngine, SrcLayout>                                            const& src,
        Tensor<DstEngine, DstLayout>                                                 & dst)
{
  return copy_if(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>{}, pred, src, dst);
}

So, I suspect the issue is caused by the lack of a corresponding copy_if overload.

kitecats avatar May 31 '25 09:05 kitecats

This is a known deficiency and will be addressed soon.

I'll note that the above implementation of copy_if() is buggy because PrdTensor (here a lambda, not a tensor) is not transformed symmetrically along with SrcTensor and DstTensor. The domains of SrcTensor and DstTensor are transformed by the recast without an equivalent transformation of PrdTensor's domain, so the coordinates of the three do not necessarily correspond.
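The asymmetry can be illustrated outside CuTe: recasting a buffer of 16-bit elements to 128-bit words divides the index domain by 8, so a predicate written against element indices no longer lines up with the vectorized indices. A hypothetical plain-C++ sketch (pred_elem and pred_word are invented names, not CuTe API):

```cpp
#include <cstddef>
#include <cstdint>

// 32 half-sized (16-bit) elements viewed as 128-bit words: recasting shrinks
// the index domain from 32 element indices to 4 word indices, so a predicate
// written for element index i does not correspond to word index i.
constexpr std::size_t kElems   = 32;
constexpr std::size_t kPerWord = 16 / sizeof(std::uint16_t);  // 8 elements per 128-bit word
constexpr std::size_t kWords   = kElems / kPerWord;           // 4 words

// Element-domain predicate: valid while the element index is below a bound.
bool pred_elem(std::size_t i, std::size_t bound) { return i < bound; }

// Word-domain predicate: transform the coordinate before testing, guarding
// the last element the word covers.
bool pred_word(std::size_t w, std::size_t bound) {
  return pred_elem(w * kPerWord + (kPerWord - 1), bound);
}
```

With bound = 20, pred_elem(3, 20) is true, but word 3 covers elements 24..31, so pred_word(3, 20) is false: reusing the untransformed element-domain predicate on word indices would wrongly approve an out-of-bounds vectorized access.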

There is an update coming shortly that improves this design (with examples) so copy_if can be written correctly and more reliably.

ccecka avatar May 31 '25 16:05 ccecka

Thank you for explaining the issue and the planned resolution. We understand the current limitations of copy_if() and will await the updated version for further testing.

kitecats avatar Jun 01 '25 01:06 kitecats

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] avatar Jul 03 '25 18:07 github-actions[bot]

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot] avatar Oct 11 '25 12:10 github-actions[bot]