
[FEA] Add INT8 support for sm_120 CollectiveBuilder (TmaWarpSpecialized)

Open changdong1687 opened this issue 2 months ago • 2 comments

Which component requires the feature?

CUTLASS C++

Feature Request

SM120 TmaWarpSpecialized builder support for int8

Hello. Currently I am trying to write an INT8 GEMM kernel for the Blackwell architecture (targeting sm_120), but I am encountering a compilation error.

TL;DR

How can we write an INT8 GEMM CUDA kernel using CUTLASS right now?

Environment

  • GPU: NVIDIA RTX 5090 (Blackwell architecture)
  • CUDA: 12.8
  • CUTLASS Version: 7817e47154d7869320f3fa6b409ec8c5e5958970

Description

I am trying to compile an INT8 GEMM kernel for the sm_120 architecture using the CUTLASS 3.x CollectiveBuilder. My policy struct is correctly configured to deduce ElementAccumulator as int32_t when the input type ElementAB is int8_t. Here is my core policy struct:

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm_sm120 {
  using ElementAB = ElementAB_; // int8_t
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  using ElementD = ElementD_; // e.g., bfloat16_t
  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 
      128 / cutlass::sizeof_bits<ElementD_>::value;

  // Correctly deduce accumulator type
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type; // Becomes int32_t
  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  // Use the deduced accumulator type for MMA
  using ElementAccumulator = ElementAcc; // Correctly set to int32_t

  using ElementCompute = float;
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = AlignmentD;
  
  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
          ElementD, LayoutD, AlignmentD, EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
          LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<...>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};
Compiler Args
'nvcc': [
                    '-O3',                        # optimization level
                    '-std=c++17',
                    # Torch's default NVCC flags
                    '-D__CUDA_NO_HALF_OPERATORS__',
                    '-D__CUDA_NO_HALF_CONVERSIONS__',
                    '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
                    '-D__CUDA_NO_HALF2_OPERATORS__',
                    '--expt-relaxed-constexpr',
                    '--expt-extended-lambda',
                    '-U__CUDA_NO_HALF_OPERATORS__',
                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                    '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
                    '-U__CUDA_NO_HALF2_OPERATORS__',
                    '-gencode=arch=compute_120a,code=sm_120a', # Blackwell
            ]

ERROR

When I instantiate this policy with ElementAB = int8_t, ElementD = cutlass::bfloat16_t, and KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto, the compilation fails with the following error:

cutlass/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl(82): error: static assertion failed with "SM120 TmaWarpSpecialized builder currently only supports F8F6F4 MMA."
    static_assert(detail::is_sm10x_f8f6f4_element<ElementA>() && detail::is_sm10x_f8f6f4_element<ElementB>(),
    ^
          detected during:
            instantiation of class "cutlass::gemm::collective::CollectiveBuilder<cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementA, GmemLayoutATag, AlignmentA, ElementB, GmemLayoutBTag, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, BuilderScheduleTag, ...>> [with ElementA=int8_t, ..., ElementB=int8_t, ..., ElementAccumulator=int32_t, ...]"

Analysis

The static_assert clearly indicates that the sm_120 builder (TmaWarpSpecialized) is currently implemented only for the new F8/F6/F4 data types and does not yet support int8_t.

The instantiation trace confirms my policy is correctly passing ElementA=int8_t, ElementB=int8_t, and ElementAccumulator=int32_t to the CollectiveBuilder.

Question

How can we write an INT8 GEMM kernel using CUTLASS 3.x right now, as a stopgap?

changdong1687 avatar Oct 24 '25 09:10 changdong1687

You are correct that an int8_t kernel is not planned/supported on SM120. To hack together an int8_t SM120 kernel, you could reuse the same mainloop collective and kernel implementations. You would need to write your own MMA atom trait for the int8_t kernel that ultimately wraps the int8_t PTX instructions, and modify the static asserts throughout CUTLASS that prevent an int8_t kernel from compiling for SM120.

Junkai-Wu avatar Oct 29 '25 00:10 Junkai-Wu

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] avatar Nov 28 '25 01:11 github-actions[bot]