[FEA] Add INT8 support for sm_120 CollectiveBuilder (TmaWarpSpecialized)
Which component requires the feature?
CUTLASS C++
Feature Request
SM120 TmaWarpSpecialized builder support for int8_t.
Hello. I am trying to write an INT8 GEMM kernel for the Blackwell architecture (targeting sm_120), but I am encountering a compilation error.
TL;DR
How can we write an INT8 GEMM CUDA kernel with CUTLASS today?
Environment
- GPU: NVIDIA RTX 5090 (Blackwell architecture)
- CUDA: 12.8
- CUTLASS Version: 7817e47154d7869320f3fa6b409ec8c5e5958970
Description
I am trying to compile an INT8 GEMM kernel for the sm_120 architecture using the CUTLASS 3.x CollectiveBuilder.
My policy struct is correctly configured to deduce ElementAccumulator as int32_t when the input type ElementAB is int8_t.
Here is my core policy struct:
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm_sm120 {
  using ElementAB = ElementAB_;  // int8_t
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  using ElementD = ElementD_;  // e.g., bfloat16_t
  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD =
      128 / cutlass::sizeof_bits<ElementD_>::value;

  // Correctly deduce accumulator type
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;  // Becomes int32_t
  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  // Use the deduced accumulator type for MMA
  using ElementAccumulator = ElementAcc;  // Correctly set to int32_t
  using ElementCompute = float;
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = AlignmentD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
          ElementD, LayoutD, AlignmentD, EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
          LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<...>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};
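The type and alignment deductions in the policy can be checked in isolation, without pulling in CUTLASS. The following is a minimal sketch under stated assumptions: `sizeof_bits_sketch`, `accumulator_t`, and `alignment_128b` are hypothetical names introduced here for illustration, and `sizeof_bits_sketch` only models byte-sized types (it is not `cutlass::sizeof_bits`). A 16-bit integer stands in for `bfloat16_t`.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for cutlass::sizeof_bits<T>; valid only for
// types whose width is a whole number of bytes.
template <typename T>
struct sizeof_bits_sketch {
  static constexpr int value = static_cast<int>(sizeof(T)) * 8;
};

// Same rule as the policy struct: int8_t inputs accumulate in int32_t,
// every other input type accumulates in float.
template <typename ElementAB>
using accumulator_t =
    std::conditional_t<std::is_same_v<ElementAB, int8_t>, int32_t, float>;

// Number of elements that fit in one 128-bit access.
template <typename T>
constexpr int alignment_128b = 128 / sizeof_bits_sketch<T>::value;

// Compile-time checks of the deductions described in the issue.
static_assert(std::is_same_v<accumulator_t<int8_t>, int32_t>);
static_assert(std::is_same_v<accumulator_t<float>, float>);
static_assert(alignment_128b<int8_t> == 16);   // AlignmentA / AlignmentB
static_assert(alignment_128b<uint16_t> == 8);  // 16-bit output, like bfloat16_t
```

If these assertions hold, the policy's element types are not the cause of the failure, which points at the builder itself.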
Compile Args
'nvcc': [
    '-O3',  # optimization level
    '-std=c++17',
    # add torch's default NVCC flags
    '-D__CUDA_NO_HALF_OPERATORS__',
    '-D__CUDA_NO_HALF_CONVERSIONS__',
    '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
    '-D__CUDA_NO_HALF2_OPERATORS__',
    '--expt-relaxed-constexpr',
    '--expt-extended-lambda',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '-U__CUDA_NO_BFLOAT16_CONVERSIONS__',
    '-U__CUDA_NO_HALF2_OPERATORS__',
    '-gencode=arch=compute_120a,code=sm_120a',  # Blackwell
]
ERROR
When I instantiate this policy with ElementAB = int8_t, ElementD = cutlass::bfloat16_t, and KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto, the compilation fails with the following error:
cutlass/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl(82): error: static assertion failed with "SM120 TmaWarpSpecialized builder currently only supports F8F6F4 MMA."
static_assert(detail::is_sm10x_f8f6f4_element<ElementA>() && detail::is_sm10x_f8f6f4_element<ElementB>(),
^
detected during:
instantiation of class "cutlass::gemm::collective::CollectiveBuilder<cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementA, GmemLayoutATag, AlignmentA, ElementB, GmemLayoutBTag, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, BuilderScheduleTag, ...>> [with ElementA=int8_t, ..., ElementB=int8_t, ..., ElementAccumulator=int32_t, ...]"
Analysis
The static_assert clearly indicates that the sm_120 builder (TmaWarpSpecialized) is currently implemented only for the new F8/F6/F4 data types and does not yet support int8_t.
The instantiation trace confirms my policy is correctly passing ElementA=int8_t, ElementB=int8_t, and ElementAccumulator=int32_t to the CollectiveBuilder.
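To illustrate the mechanism behind the failure, here is a minimal sketch of a builder gated by an element-type predicate. All names (`fp8_tag`, `fp6_tag`, `fp4_tag`, `is_f8f6f4_sketch`, `builder_sketch`, `builder_accepts`) are hypothetical and only model the role that `detail::is_sm10x_f8f6f4_element` plays in the error message above; this is not the CUTLASS implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical placeholder tags standing in for the narrow float element types.
struct fp8_tag {};
struct fp6_tag {};
struct fp4_tag {};

// Hypothetical predicate: true only for the placeholder F8/F6/F4 tags.
template <typename T>
constexpr bool is_f8f6f4_sketch() {
  return std::is_same_v<T, fp8_tag> || std::is_same_v<T, fp6_tag> ||
         std::is_same_v<T, fp4_tag>;
}

// A builder that rejects every other element type at instantiation time.
template <typename ElementA, typename ElementB>
struct builder_sketch {
  static_assert(is_f8f6f4_sketch<ElementA>() && is_f8f6f4_sketch<ElementB>(),
                "sketch builder only supports F8F6F4 elements");
  using type = void;
};

// The same condition exposed as a value, so it can be queried without
// triggering the hard error.
template <typename A, typename B>
constexpr bool builder_accepts =
    is_f8f6f4_sketch<A>() && is_f8f6f4_sketch<B>();

// Accepted combination: instantiates cleanly.
static_assert(std::is_void_v<builder_sketch<fp8_tag, fp6_tag>::type>);
// builder_sketch<int8_t, int8_t>::type would fail to compile with the
// message above, mirroring the error reported in this issue.
static_assert(!builder_accepts<int8_t, int8_t>);
```

Because the check is a `static_assert` on the element types alone, no choice of tile shape, schedule, or accumulator type can get past it.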
Question
How can we write an INT8 GEMM kernel with CUTLASS 3.x right now, as an urgent workaround?
You are correct that an int8_t kernel is neither planned nor supported on SM120. To hack together an int8_t SM120 kernel, you could reuse the same mainloop collective and kernel implementations. You would need to write your own atom traits for the int8_t kernel, which ultimately wrap the int8_t PTX instructions, and modify the static asserts throughout CUTLASS that prevent an int8_t kernel from compiling for SM120.
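Any kernel hacked together this way bypasses CUTLASS's own checks, so a host-side reference is useful for validating its output. The following is a minimal sketch, assuming the layouts from the policy struct above (row-major A, column-major B, row-major D) and int32_t accumulation. `reference_gemm_s8` is a hypothetical helper written for this thread, not a CUTLASS API, and it omits the epilogue (scaling and conversion to bfloat16_t).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference D = A * B with int8_t inputs and int32_t accumulation.
//   a: M x K, row-major    (element (i, p) at a[i * k + p])
//   b: K x N, column-major (element (p, j) at b[j * k + p])
//   d: M x N, row-major    (element (i, j) at d[i * n + j])
std::vector<int32_t> reference_gemm_s8(const std::vector<int8_t>& a,
                                       const std::vector<int8_t>& b,
                                       int m, int n, int k) {
  std::vector<int32_t> d(static_cast<std::size_t>(m) * n, 0);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;
      for (int p = 0; p < k; ++p) {
        // Widen to int32_t before multiplying, as an s8 x s8 -> s32 MMA does.
        acc += static_cast<int32_t>(a[static_cast<std::size_t>(i) * k + p]) *
               static_cast<int32_t>(b[static_cast<std::size_t>(j) * k + p]);
      }
      d[static_cast<std::size_t>(i) * n + j] = acc;
    }
  }
  return d;
}
```

Comparing the device result against this reference before the epilogue is applied keeps the check exact, since integer accumulation has no rounding.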