
[BUG] Accuracy Error in CUTLASS GEMM operations.

Open yoon5862 opened this issue 7 months ago • 2 comments

Describe the bug

The GEMM results from CUTLASS (both TensorOp and SIMT kernels) do not fully match the results of cuBLAS GEMM.

Steps/Code to reproduce bug

using GEMM = cutlass::gemm::device::Gemm<
                            /*ElementA_ = */ float,
                            /*LayoutA_ = */ cutlass::layout::RowMajor,
                            /*ElementB_ = */ float,
                            /*LayoutB_ = */ cutlass::layout::ColumnMajor,
                            /*ElementC_ = */ float,
                            /*LayoutC_ = */ cutlass::layout::RowMajor,
                            /*ElementAccumulator_ = */ float,
                            /*OperatorClass_ = */ cutlass::arch::OpClassSimt,
                            /*ArchTag_ = */ cutlass::arch::Sm50,
                            /*ThreadblockShape_ = */ cutlass::gemm::GemmShape<128, 128, 8>,
                            /*WarpShape_ = */ cutlass::gemm::GemmShape<32, 64, 8>,
                            /*InstructionShape_ = */ cutlass::gemm::GemmShape<1, 1, 1>,
                            /*EpilogueOutputOp_ = */ cutlass::epilogue::thread::LinearCombination<float, 1, float, float, cutlass::epilogue::thread::ScaleType::Nothing, cutlass::FloatRoundStyle::round_to_nearest>,
                            /*ThreadblockSwizzle_ = */ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>,
                            /*Stages = */ 2,
                            /*AlignmentA = */ 1,
                            /*AlignmentB = */ 1,
                            /*SplitKSerial = */ false,
                            /* operator = */ cutlass::arch::OpMultiplyAdd
                            >;
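    GEMM gemm_operator;
    // Problem size (M, N, K). A is M x K row-major (lda = K), B is K x N
    // column-major (ldb = K), C and D are M x N row-major (ldc = ldd = N).
    // These leading dimensions assume contiguous tensors.
    GEMM::Arguments args({m, n, k},
                        {reinterpret_cast<float *>(A.data_ptr()), k},  // A
                        {reinterpret_cast<float *>(B.data_ptr()), k},  // B
                        {reinterpret_cast<float *>(C.data_ptr()), n},  // C (source)
                        {reinterpret_cast<float *>(C.data_ptr()), n},  // D (destination)
                        {float(1.0f), float(0.0f)});                   // alpha = 1, beta = 0
    cutlass::Status status = gemm_operator(args);
    TORCH_CHECK(status == cutlass::Status::kSuccess,
                "CUTLASS GEMM failed: ", cutlassGetStatusString(status));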

Expected behavior

I compared cuBLAS and CUTLASS GEMM results in both bfloat16 and float32 precision to check whether CUTLASS produces accurate results. The code above is the portion of my CUTLASS implementation for the float32 data type. I integrated it into PyTorch via pybind11 and verified its correctness by comparing the outputs with nn.Linear using an absolute tolerance (atol) of 1e-6 and a relative tolerance (rtol) of 1e-5. Typically, the absolute difference is around 1e-3. How can I improve this?
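For reference, PyTorch documents its closeness check (`torch.allclose`) as `|actual - expected| <= atol + rtol * |expected|`. The C++ sketch below reimplements that formula for illustration only; `first_mismatch` is a hypothetical helper, not part of CUTLASS or PyTorch. With atol = 1e-6 and rtol = 1e-5, the absolute term only dominates for outputs smaller than atol / rtol = 0.1, so whether a fixed difference of about 1e-3 passes depends on the magnitude of each output element.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the documented torch.allclose criterion:
//   |actual - expected| <= atol + rtol * |expected|
// Returns the index of the first element that fails, or -1 if all pass.
long first_mismatch(const std::vector<float>& actual,
                    const std::vector<float>& expected,
                    float atol, float rtol) {
  assert(actual.size() == expected.size());
  for (std::size_t i = 0; i < actual.size(); ++i) {
    const float diff = std::fabs(actual[i] - expected[i]);
    const float allowed = atol + rtol * std::fabs(expected[i]);
    if (diff > allowed) {
      return static_cast<long>(i);
    }
  }
  return -1;
}
```

For example, a difference of about 1e-3 fails on an output near 10 (allowed: about 1.0e-4) but passes on an output near 1000 (allowed: about 1.0e-2).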

Environment details

  • CUDA 12.4
  • HW: RTX 4070 Ti

Thank you.

yoon5862, May 26 '25

When comparing the accuracy of different kernels, there are several potential implementation differences that can lead to variations in outputs. These include compiling with fast math, differences in how reductions occur with split-K or stream-K, implicitly using TF32, etc.
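To see how reduction order alone changes a float32 result, the sketch below computes the same length-K dot product (one output element of the GEMM) with a single sequential accumulator and with a split-K style partition, and compares both against a double-precision reference. The inputs come from a hypothetical LCG helper, `make_inputs`. Nothing here reflects the actual loop order inside CUTLASS or cuBLAS; it only assumes that the two libraries may accumulate in different orders.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: deterministic pseudo-random values in [-1, 1)
// from a simple LCG, so the example needs no external data.
std::vector<float> make_inputs(std::size_t n, std::uint32_t seed) {
  std::vector<float> v(n);
  for (auto& x : v) {
    seed = seed * 1664525u + 1013904223u;
    x = static_cast<float>(seed >> 8) / 8388608.0f - 1.0f;  // 24 bits -> [-1, 1)
  }
  return v;
}

// One float32 accumulator walking k = 0 .. K-1.
float dot_sequential(const std::vector<float>& a, const std::vector<float>& b) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < a.size(); ++i) acc += a[i] * b[i];
  return acc;
}

// Split-K style: K is cut into `splits` slices, each reduced into its own
// float32 partial sum, and the partials are added at the end.
float dot_split_k(const std::vector<float>& a, const std::vector<float>& b,
                  std::size_t splits) {
  assert(splits > 0);
  const std::size_t chunk = (a.size() + splits - 1) / splits;
  float total = 0.0f;
  for (std::size_t begin = 0; begin < a.size(); begin += chunk) {
    const std::size_t end = std::min(a.size(), begin + chunk);
    float partial = 0.0f;
    for (std::size_t i = begin; i < end; ++i) partial += a[i] * b[i];
    total += partial;
  }
  return total;
}

// Double-precision reference; products of two floats are exact in double.
double dot_reference(const std::vector<float>& a, const std::vector<float>& b) {
  double acc = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i)
    acc += static_cast<double>(a[i]) * b[i];
  return acc;
}
```

Both float results stay within the classic worst-case bound of roughly K * eps * sum(|a_i * b_i|) (eps is about 1.19e-7 for float32), but they are generally not bitwise equal. Because that bound grows with K, differences of around 1e-3 at K = 28672 can be consistent with correct kernels, depending on the magnitude of the inputs.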

The most straightforward way to debug this is probably the following:

  • Use the CUTLASS profiler's reference checking or an example with a CUTLASS reference kernel to compare to the GEMM you instantiated. This will help you understand what to expect.

  • In the integration test for PyTorch, include a test that uses integers within an expected representable range. Differences in rounding, order of operations, etc., within the kernels should not result in different outcomes when using these integers. If this test produces incorrect output but the CUTLASS profiler does not, it would imply either the PyTorch reference is doing something unexpected, or the integration does not match the CUTLASS profiler.

  • Consider using CUBLAS_PEDANTIC_MATH when doing reference checking. It is designed for debugging these kinds of issues. Also, ensure PyTorch is not calling cuBLAS with an unexpected configuration (e.g., using TF32 instead).
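TF32 keeps float32's 8-bit exponent but only 10 explicit mantissa bits, so its unit roundoff is 2^-11 (about 4.9e-4) instead of 2^-24 (about 6.0e-8) for float32. The sketch below emulates that operand rounding in software to show its size. `round_to_tf32` and `dot_tf32_inputs` are hypothetical helpers, and round-to-nearest-even is an assumption, not a statement of how any particular GPU or library converts. The SIMT kernel in the issue computes with plain FP32 FMAs, so if the cuBLAS side were using TF32 (for example via PyTorch's `torch.backends.cuda.matmul.allow_tf32` setting), noticeably larger differences would be expected.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Rounds a float32 value to TF32 precision (8-bit exponent, 10 explicit
// mantissa bits) using round-to-nearest-even on the 13 dropped mantissa bits.
// The result is still stored in a float. Illustrative software emulation only.
float round_to_tf32(float x) {
  if (!std::isfinite(x)) return x;              // leave inf and NaN unchanged
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  const std::uint32_t lsb = (bits >> 13) & 1u;  // lowest mantissa bit that is kept
  bits += 0x0FFFu + lsb;                        // round half to even
  bits &= 0xFFFFE000u;                          // clear the 13 dropped bits
  std::memcpy(&x, &bits, sizeof x);
  return x;
}

// Dot product whose operands are first rounded to TF32, then multiplied and
// accumulated in float32 (a product of two TF32 values is exact in float32).
// This only approximates what a TF32 tensor-core GEMM does to its inputs.
float dot_tf32_inputs(const float* a, const float* b, std::size_t k) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < k; ++i)
    acc += round_to_tf32(a[i]) * round_to_tf32(b[i]);
  return acc;
}
```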

depaulmillz, May 26 '25

Thank you for the reply. I tested a GEMM of dimension (8192, 8192, 28672). When I test with integers within the expected representable range, the CUTLASS and cuBLAS GEMM results match. They also match when I test with floating-point inputs in the range (-0.0001, 0.0001)!
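The integer test works because float32 represents every integer with magnitude up to 2^24 exactly. If every partial sum of a length-K dot product stays within that range, each product and each addition is exact, so any summation order gives the same answer. The sketch below computes the largest safe input range r for a given K and checks the claim with forward and reverse accumulation; `max_exact_int_range`, `make_int_inputs`, `dot_forward` and `dot_reverse` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Largest r such that integer inputs drawn from [-r, r] keep every partial sum
// of a length-K dot product within [-2^24, 2^24], where float32 is exact.
std::int64_t max_exact_int_range(std::int64_t k) {
  assert(k > 0);
  const std::int64_t limit = std::int64_t{1} << 24;
  std::int64_t r = 0;
  while (k * (r + 1) * (r + 1) <= limit) ++r;
  return r;
}

// Hypothetical helper: integers in [-r, r] stored as float, from a simple LCG.
std::vector<float> make_int_inputs(std::size_t n, std::int64_t r, std::uint32_t seed) {
  assert(r >= 0);
  std::vector<float> v(n);
  const std::uint32_t span = static_cast<std::uint32_t>(2 * r + 1);
  for (auto& x : v) {
    seed = seed * 1664525u + 1013904223u;
    x = static_cast<float>(static_cast<std::int64_t>((seed >> 8) % span) - r);
  }
  return v;
}

// The same float32 dot product accumulated in two different orders.
float dot_forward(const std::vector<float>& a, const std::vector<float>& b) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < a.size(); ++i) acc += a[i] * b[i];
  return acc;
}

float dot_reverse(const std::vector<float>& a, const std::vector<float>& b) {
  float acc = 0.0f;
  for (std::size_t i = a.size(); i-- > 0;) acc += a[i] * b[i];
  return acc;
}
```

For K = 28672 this gives r = 24, since 28672 * 24^2 = 16,515,072 is at most 2^24 = 16,777,216. By similar reasoning, inputs drawn from (-0.0001, 0.0001) produce outputs and rounding differences that are very small in absolute terms, so matching under atol = 1e-6 there is likely a weaker check than the integer test.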

yoon5862, May 27 '25

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Jul 04 '25

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot], Oct 02 '25