Fix potential out-of-bounds access in 3.x grouped gemm kernel

Open kongroo opened this issue 1 year ago • 5 comments

I have identified two potential out-of-bounds accesses in the 3.x grouped gemm kernel. Both come from this line of code: problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});

  1. Initialization of problem_shape_MNKL before entering the tile loop: If the total number of tiles is very small, some threadblocks may start with an invalid initial tile, leading to an out-of-bounds access.
  • This can be reproduced with compute-sanitizer examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=1024 --n=1024 --k=640 --groups=3
  • The error message reported by compute-sanitizer is shown below. It reports an invalid read 12 bytes (the size of UnderlyingProblemShape) before the nearest allocation, which indicates that the problem shape array is being accessed with an index of -1.
========= COMPUTE-SANITIZER
========= Invalid __global__ read of size 4 bytes
=========     at void cutlass::device_kernel<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90ArrayTmaGmmaWarpSpecialized<(int)9, cute::tuple<cute::C<(int)2>, cute::C<(int)2>, cute::C<(int)1>>, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>, cute::tuple<cute::C<(int)256>, cute::C<(int)128>, cute::C<(int)64>>, cutlass::float_e4m3_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::float_e5m2_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM90_64x128x32_F32E4M3E5M2_SS_TN<(cute::GMMA::ScaleIn)1, (cute::GMMA::ScaleIn)1>>, cute::Layout<cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)1>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)8>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)8>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity>, cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<cutlass::epilogue::collective::DefaultEpilogueArray<cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cutlass::epilogue::thread::LinearCombination<cutlass::half_t, (int)1, float, float, (cutlass::epilogue::thread::ScaleType::Kind)0, (cutlass::FloatRoundStyle)2, cutlass::half_t>, cutlass::epilogue::PtrArrayNoSmemWarpSpecialized>>, void, void>>(T1::Params)+0x1bf0
=========     by thread (64,0,0) in block (48,0,0)
=========     Address 0x7fccd83e03f4 is out of bounds
=========     and is 12 bytes before the nearest allocation at 0x7fccd83e0400 of size 36 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
  • The solution is to exit early if the initial tile is invalid.
  2. Update of problem_shape_MNKL at the end of the tile loop: The tile fetched in the last iteration is always invalid, which can lead to an out-of-bounds access. This may go unnoticed because the compiler can optimize the access away, but adding logic between the update and the end of the loop body can trigger the out-of-bounds access, which is hard to debug.
  • This can also be reproduced by running example 57 under compute-sanitizer after adding a __syncwarp(); line right after the update of problem_shape_MNKL, as follows:
          if constexpr (IsGroupedGemmKernel) {
            problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
          }
          __syncwarp();  // adding this line triggers out-of-bounds access of the previous line
  • The solution is to check that the tile is valid before updating problem_shape_MNKL.
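
To make the two proposed guards concrete, here is a minimal host-side C++ model of the persistent tile loop. It is only a sketch, not the CUTLASS code. WorkTileInfo, Scheduler, ProblemShapes and run_block are hypothetical stand-ins, the round-robin tile assignment is an assumption, and an invalid tile is modelled with L_idx = -1 (the index the sanitizer output above suggests). The toy sizes (3 groups, 4 tiles per group, 16 blocks) are illustrative and do not correspond to the tile counts of the repro command.

```cpp
// Host-side model only: every type and function here is hypothetical, not CUTLASS API.
#include <cassert>
#include <utility>
#include <vector>

struct ProblemShape { int m, n, k; };  // 12 bytes, like cute::tuple<int, int, int>

struct WorkTileInfo {
  int L_idx;                           // group index; -1 models an invalid tile (assumption)
  bool is_valid() const { return L_idx >= 0; }
};

// Stand-in for params.problem_shape.get_problem_shape(idx) that records
// out-of-range indices instead of reading outside the array.
struct ProblemShapes {
  std::vector<ProblemShape> shapes;
  mutable int oob_reads;
  explicit ProblemShapes(std::vector<ProblemShape> s)
      : shapes(std::move(s)), oob_reads(0) {}
  ProblemShape get(int idx) const {
    if (idx < 0 || idx >= static_cast<int>(shapes.size())) {
      ++oob_reads;                     // on the GPU this would be the invalid global read
      return ProblemShape{0, 0, 0};
    }
    return shapes[idx];
  }
};

// Toy scheduler: tiles are numbered linearly and handed out round-robin to blocks.
struct Scheduler {
  int num_groups, tiles_per_group, num_blocks;
  WorkTileInfo tile_at(int linear_idx) const {
    if (linear_idx >= num_groups * tiles_per_group) return WorkTileInfo{-1};
    return WorkTileInfo{linear_idx / tiles_per_group};
  }
};

// Runs the tile loop of one block and returns the number of tiles it processed.
// With guarded == false the shape lookups are unconditional, as in the report.
int run_block(const Scheduler& sched, const ProblemShapes& shapes, int block, bool guarded) {
  int linear = block;
  WorkTileInfo tile = sched.tile_at(linear);
  if (guarded && !tile.is_valid()) return 0;   // guard 1: exit early on an invalid initial tile
  ProblemShape shape = shapes.get(tile.L_idx);
  int processed = 0;
  while (tile.is_valid()) {
    (void)shape;                               // the mainloop and epilogue would use the shape here
    ++processed;
    linear += sched.num_blocks;
    tile = sched.tile_at(linear);
    if (!guarded || tile.is_valid()) {         // guard 2: update the shape only for a valid tile
      shape = shapes.get(tile.L_idx);
    }
  }
  return processed;
}
```

In this model, calling run_block for every block with guarded set to false records one out-of-range lookup per block, either at initialization or at the final update, while guarded set to true records none and processes the same number of tiles.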

kongroo avatar May 21 '24 04:05 kongroo

Can you guys take a look maybe? @hwu36 @yzhaiustc Thanks.

wenlei-bao avatar May 28 '24 23:05 wenlei-bao

We actually need to make some changes to your code. The problem you pointed out is correct.

hwu36 avatar May 28 '24 23:05 hwu36

Sounds good. cc @kongroo

wenlei-bao avatar May 28 '24 23:05 wenlei-bao

> We actually need to make some changes to your code. The problem you pointed out is correct.

@hwu36 Hello, I was wondering if there might be any updates on this issue?

kongroo avatar Jun 27 '24 02:06 kongroo

We will upstream 3.5.1 first

hwu36 avatar Jun 27 '24 02:06 hwu36