[BUG] Modifying the block/warptile shapes and the output datatype in the unit test causes the tests to fail.
Describe the bug
I modified the threadblock/warp tile shapes and the output datatype in https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu and found that some shapes cause the tests to fail. I changed ElementOutput to cutlass::half_t and tested various threadblock/warp tile shapes. Some shapes passed; others failed.
Shapes that passed (threadblock/warp tile) include, for example: <32,128,64>/<16,64,64>, <32,128,64>/<32,32,64>, <32,128,64>/<32,64,64>, <32,256,64>/<32,64,64>, <64,32,64>/<32,32,64>, and <64,32,64>/<64,32,64>.
Shapes that failed (threadblock/warp tile): <16,16,64>/<16,16,64>, <16,128,64>/<16,128,64>, <16,256,64>/<16,128,64>, <32,16,64>/<32,16,64>, <32,128,64>/<16,128,64>, <32,128,64>/<32,128,64>, <32,256,64>/<16,128,64>, <32,256,64>/<32,128,64>, <64,16,64>/<64,16,64>, <64,128,64>/<16,128,64>, <64,128,64>/<32,128,64>, <64,128,64>/<64,128,64>, <64,256,64>/<16,128,64>, <64,256,64>/<32,128,64>, <64,256,64>/<64,128,64>, <128,16,64>/<128,16,64>, <128,32,64>/<128,32,64>, <128,64,64>/<128,64,64>, <128,128,64>/<16,128,64>, <128,128,64>/<32,128,64>, <128,128,64>/<64,128,64>, <128,256,64>/<128,128,64>, <256,32,64>/<128,32,64>, <256,64,64>/<128,64,64>, <256,128,64>/<32,128,64>, <256,128,64>/<64,128,64>, <256,128,64>/<128,128,64>.
Steps/Code to reproduce bug
CUTLASS_TEST_L1(SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x64_32x128x64, {
  using ElementOutput = cutlass::half_t;   // changed from int32_t in the original test
  using ElementAccumulator = int32_t;
  using ElementCompute = int32_t;

  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor,        // A
      int8_t, cutlass::layout::ColumnMajor,     // B
      ElementOutput, cutlass::layout::RowMajor, // C/D
      ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<64, 128, 64>,    // threadblock tile
      cutlass::gemm::GemmShape<32, 128, 64>,    // warp tile
      cutlass::gemm::GemmShape<16, 8, 32>,      // instruction shape
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
          ElementAccumulator, ElementCompute>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3>;                                       // pipeline stages

  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
} )
Your observation is correct: not all combinations are supported. The ones listed in the unit tests and in the profiler's generator.py are the common ones. Some combinations not listed there also work, but many are not supported, especially those with very small or very large m or n dimensions.
Usually we want 4 or 8 warps per threadblock, which means (threadblock.m / warp.m) * (threadblock.n / warp.n) = 4 or 8. threadblock.k is usually the same as warp.k. For good performance, the threadblock tile and the warp tile should preferably be roughly square.