Fix potential out-of-bounds access in 3.x grouped gemm kernel
I have identified two potential out-of-bounds accesses in the 3.x version of the grouped GEMM kernel. Both involve this line of code: problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
- Initialization of problem_shape_MNKL before entering the tile loop: if the total number of tiles is very small, some threadblocks may start with an invalid initial tile, which leads to an out-of-bounds access.
- This can be reproduced by running:
compute-sanitizer examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=1024 --n=1024 --k=640 --groups=3
- The error reported by compute-sanitizer is shown below. It reports an invalid read 12 bytes (the size of UnderlyingProblemShape) before the nearest allocation, whose 36 bytes hold exactly 3 × 12-byte entries (one per group). This indicates an access at index -1 of the problem shape array.
========= COMPUTE-SANITIZER
========= Invalid __global__ read of size 4 bytes
========= at void cutlass::device_kernel<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm90ArrayTmaGmmaWarpSpecialized<(int)9, cute::tuple<cute::C<(int)2>, cute::C<(int)2>, cute::C<(int)1>>, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>, cute::tuple<cute::C<(int)256>, cute::C<(int)128>, cute::C<(int)64>>, cutlass::float_e4m3_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::float_e5m2_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM90_64x128x32_F32E4M3E5M2_SS_TN<(cute::GMMA::ScaleIn)1, (cute::GMMA::ScaleIn)1>>, cute::Layout<cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)1>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)8>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)2, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)8>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity>, cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<cutlass::epilogue::collective::DefaultEpilogueArray<cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cutlass::epilogue::thread::LinearCombination<cutlass::half_t, (int)1, float, float, (cutlass::epilogue::thread::ScaleType::Kind)0, (cutlass::FloatRoundStyle)2, cutlass::half_t>, cutlass::epilogue::PtrArrayNoSmemWarpSpecialized>>, void, void>>(T1::Params)+0x1bf0
========= by thread (64,0,0) in block (48,0,0)
========= Address 0x7fccd83e03f4 is out of bounds
========= and is 12 bytes before the nearest allocation at 0x7fccd83e0400 of size 36 bytes
========= Saved host backtrace up to driver entry point at kernel launch time
- The fix is to exit early if the initial tile is invalid.
- Updating problem_shape_MNKL at the end of the tile loop: the last iteration always fetches an invalid tile (the one that terminates the loop), so the update at that point may also read out of bounds. This can go unnoticed because the load may be optimized away during compilation, but adding any logic between the update and the end of the loop body can trigger the out-of-bounds access, which is then hard to debug.
- This can also be reproduced by running example 57 under compute-sanitizer after adding a __syncwarp(); call right after the update of problem_shape_MNKL, as follows:
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{});
}
__syncwarp(); // adding this line triggers out-of-bounds access of the previous line
- The fix is to check that the tile is valid before updating problem_shape_MNKL.
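The two guards can be illustrated with a small host-side C++ sketch of a persistent tile loop. It does not reproduce the CUTLASS kernel or its tile scheduler API: ProblemShape, WorkTile, Scheduler, get_problem_shape and run_block are hypothetical stand-ins, and the sketch assumes (following the sanitizer analysis above) that an invalid tile carries L_idx == -1. A bounds-checked lookup throws exactly where the GPU kernel would silently read out of bounds.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for one group's (M, N, K) shape; with a 4-byte int this is
// 12 bytes, the same size as the cute::tuple<int, int, int> entries in the report.
struct ProblemShape { int m, n, k; };

// Hypothetical work tile. Assumption: an invalid tile carries L_idx == -1, so
// indexing the shape array with it would read one entry before the array.
struct WorkTile {
  int linear_idx;
  int L_idx;
  bool is_valid() const { return L_idx >= 0; }
};

// Hypothetical persistent scheduler: block b handles tiles b, b + grid_size, ...
struct Scheduler {
  std::vector<int> tiles_per_group;
  int grid_size;

  WorkTile get(int linear_idx) const {
    int remaining = linear_idx;
    for (int g = 0; g < static_cast<int>(tiles_per_group.size()); ++g) {
      if (remaining < tiles_per_group[g]) return {linear_idx, g};
      remaining -= tiles_per_group[g];
    }
    return {linear_idx, -1};  // past the last tile: no more work
  }
};

// Bounds-checked lookup: throws where the GPU kernel would read out of bounds.
ProblemShape get_problem_shape(const std::vector<ProblemShape>& shapes, int L_idx) {
  if (L_idx < 0 || L_idx >= static_cast<int>(shapes.size()))
    throw std::out_of_range("problem shape index out of range");
  return shapes[L_idx];
}

// Runs one block's tile loop and returns how many tiles it processed.
// With apply_fixes == false it mirrors the unguarded pattern described above.
int run_block(const Scheduler& sched, const std::vector<ProblemShape>& shapes,
              int block_idx, bool apply_fixes) {
  WorkTile tile = sched.get(block_idx);
  // Fix 1: a block whose first tile is already invalid exits before any lookup.
  if (apply_fixes && !tile.is_valid()) return 0;
  ProblemShape shape = get_problem_shape(shapes, tile.L_idx);

  int processed = 0;
  while (tile.is_valid()) {
    // The mainloop and epilogue would use shape here.
    assert(shape.m > 0 && shape.n > 0 && shape.k > 0);
    ++processed;
    tile = sched.get(tile.linear_idx + sched.grid_size);
    // Fix 2: refresh the shape only when the next tile is valid; the final
    // fetch of every block returns the invalid tile that ends the loop.
    if (!apply_fixes || tile.is_valid()) {
      shape = get_problem_shape(shapes, tile.L_idx);
    }
  }
  return processed;
}
```

For example, with three groups of 4 tiles each and 16 blocks, blocks 12 to 15 start on an invalid tile. With both guards they return 0 and the blocks together process all 12 tiles; without them, every block ends with a lookup at index -1.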
Can you guys take a look maybe? @hwu36 @yzhaiustc Thanks.
We actually need to make some changes to your code, but the problem you pointed out is correct.
Sounds good. cc @kongroo
@hwu36 Hello, I was wondering if there might be any updates on this issue?
We will upstream 3.5.1 first