
Add support for mixed 4-bit/8-bit data types GEMM

Open · alexsamardzic opened this issue 11 months ago • 35 comments

alexsamardzic avatar Mar 19 '24 14:03 alexsamardzic

More to come here: support for U4, support for the generator in the CUTLASS library, etc. Still, I'm opening the PR to solicit feedback on the S8/S4 and S4/S8 GEMMs that are now available; in particular, I'm interested in any suggestions for a faster approach to S4->S8 conversion.
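
For reference, this is what the conversion does per packed byte, in plain scalar form (the kernel performs it with a vectorized device-side sequence; the function below is only an illustration and assumes the first value sits in the low nibble and an arithmetic right shift on signed types):

    #include <cstdint>

    // Split one packed byte into two sign-extended 8-bit values.
    inline void s4x2_to_s8(uint8_t packed, int8_t& lo, int8_t& hi) {
      lo = static_cast<int8_t>(static_cast<uint8_t>(packed << 4)) >> 4;  // low nibble
      hi = static_cast<int8_t>(packed) >> 4;                             // high nibble
    }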

alexsamardzic avatar Mar 19 '24 14:03 alexsamardzic

Added more tests.

alexsamardzic avatar Mar 20 '24 10:03 alexsamardzic

Added generator support for S8/S4 and S4/S8.


AFAIK, implementing generator support for a given operation is not specifically documented, so I want to clarify the steps I've taken here. Basically, I've copied the code from the GenerateSM80_TensorOp_16832_TN method into GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_(a|b), and then made some changes:

  • Obviously, I've changed the math_instructions assignments according to the data types actually used for the mixed input variants.
  • I'm not sure where smem_usage = 164 in GenerateSM80_TensorOp_16832_TN comes from, and this variable is not used further anyway, so I skipped it in the new methods.
  • I've used two sets of alignment constraints. The alignments for operands A and B are the same, but operand C (and thus the result too) could be either 32-bit or 8-bit. The code at the end of the mixed input methods, within the last if statement, handles the latter case, and the alignments are changed there accordingly. (Note that for GenerateSM80_TensorOp_16816_mixed_input_upcast_(a|b) there are snippets of code at the end of the methods doing a similar thing, but they're slightly different from each other, and also from what I did here.)
  • The tile_descriptions were initially copied from GenerateSM80_TensorOp_16832_TN; I then made sure that all the relevant kernels were compiled (by adding CUTLASS_LIBRARY_KERNELS="*i16832gemm*" to the CMake command line), and removed the tiles that failed to compile.

I did the verification as @manishucsd suggested here: as mentioned above, I did the build with all the relevant kernels included, and then verified that cutlass_profiler runs all the tile variations specified in GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_(a|b). Note that the profiler reports Disposition: Incorrect for all the kernels with 8-bit output; I suppose it's related to saturation - I'm not sure whether I should actually apply saturation somehow for this combination of input data types?
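
If saturation should indeed be applied, one direction I could try is the clamping epilogue; a sketch only, assuming the standard LinearCombinationClamp works with these mixed-input kernels (not verified):

    #include "cutlass/epilogue/thread/linear_combination_clamp.h"
    #include "cutlass/numeric_types.h"

    // Clamp the int32 accumulator into the int8 output range in the epilogue.
    using EpilogueOp = cutlass::epilogue::thread::LinearCombinationClamp<
        int8_t,                                     // ElementOutput
        128 / cutlass::sizeof_bits<int8_t>::value,  // elements per vectorized access
        int32_t,                                    // ElementAccumulator
        float>;                                     // ElementCompute (alpha/beta)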


Overall, this PR now contains everything that I intended to do for S4/S8 and S8/S4 GEMM, and it's ready for review. It has grown somewhat large, so I'd suggest having it reviewed and eventually merged; I can then add U4/U8 and U8/U4, and maybe U4/S8 and S8/U4, support in follow-up PR(s).

alexsamardzic avatar Mar 22 '24 15:03 alexsamardzic

Hi @alexsamardzic, thanks for working on this. Just wanted to clarify, will this kernel support int4 grouped per channel weight quantization + int8 per token dynamic activation quantization?

andrewor14 avatar Mar 22 '24 19:03 andrewor14

Hi @alexsamardzic, thanks for working on this. Just wanted to clarify, will this kernel support int4 grouped per channel weight quantization + int8 per token dynamic activation quantization?

This kernel is just an int4/int8 GEMM, producing an int32 (or int8) result. Quantization is not supported by CUTLASS directly, but it could be implemented using an EVT epilogue. In particular, I'm trying to get this feature into CUTLASS mainly in order to have this operation supported in PyTorch, with its use alongside quantization as the primary motivator.

alexsamardzic avatar Mar 22 '24 19:03 alexsamardzic

@manishucsd, @hwu36: Would it be possible for someone to review this PR (and perhaps #1350 too)? These should not be controversial and are needed by PyTorch; once this one is in, I'd like to proceed with another PR to add the other 4-bit/8-bit integer combinations that make sense.

alexsamardzic avatar Apr 02 '24 06:04 alexsamardzic

working on it now.

hwu36 avatar Apr 18 '24 18:04 hwu36

Hi @alexsamardzic, thanks for working on this. Just wanted to clarify, will this kernel support int4 grouped per channel weight quantization + int8 per token dynamic activation quantization?

This kernel is just an int4/int8 GEMM, producing an int32 (or int8) result. Quantization is not supported by CUTLASS directly, but it could be implemented using an EVT epilogue. In particular, I'm trying to get this feature into CUTLASS mainly in order to have this operation supported in PyTorch, with its use alongside quantization as the primary motivator.

Great job! How can I integrate this PR with PyTorch? Is there any example code available? @alexsamardzic

Hongbosherlock avatar May 06 '24 09:05 Hongbosherlock

How can I integrate this PR with PyTorch? Is there any example code available? @alexsamardzic

The primary motivation for this PR is to have this combination of operands supported by PyTorch, so the integration should be coming soon.

alexsamardzic avatar May 06 '24 15:05 alexsamardzic

How can I integrate this PR with PyTorch? Is there any example code available? @alexsamardzic

The primary motivation for this PR is to have this combination of operands supported by PyTorch, so the integration should be coming soon.

I'm a beginner with CUTLASS, and I have no idea how to use my own constructed s4/s8 data to run this GEMM. Could you please provide example code for testing this s4/s8 GEMM, like the official example here: https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/README.md

Hongbosherlock avatar May 08 '24 09:05 Hongbosherlock

I'm a beginner with CUTLASS, and I have no idea how to use my own constructed s4/s8 data to run this GEMM. Could you please provide example code for testing this s4/s8 GEMM, like the official example here: https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/README.md

These changes are not for Hopper, but for the Ampere architecture. The code to run an s4/s8 GEMM is the same as for any other GEMM, for example s8/s8, except that when the GEMM template is instantiated, the data types and other arguments have to be specified accordingly. For some examples of this, see the using Gemm = cutlass::gemm::device::GemmUniversal... template instantiations in the test cases this PR adds under the test/unit/gemm/device directory. As far as your data is concerned, s4 data should be provided as two successive values packed into a single byte, and that's all.

alexsamardzic avatar May 08 '24 13:05 alexsamardzic

On a quick look, your strides may be wrong.

alexsamardzic avatar May 13 '24 10:05 alexsamardzic

On a quick look, your strides may be wrong.

Thank you for your prompt reply. I don't know much about this parameter, and I can't find many references. Could you give me some more details? Thank you very much.

zkf331 avatar May 13 '24 12:05 zkf331

I'm a beginner with CUTLASS, and I have no idea how to use my own constructed s4/s8 data to run this GEMM. Could you please provide example code for testing this s4/s8 GEMM, like the official example here: https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/README.md

These changes are not for Hopper, but for the Ampere architecture. The code to run an s4/s8 GEMM is the same as for any other GEMM, for example s8/s8, except that when the GEMM template is instantiated, the data types and other arguments have to be specified accordingly. For some examples of this, see the using Gemm = cutlass::gemm::device::GemmUniversal... template instantiations in the test cases this PR adds under the test/unit/gemm/device directory. As far as your data is concerned, s4 data should be provided as two successive values packed into a single byte, and that's all.

I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data to get the s4 values before the GEMM?

Hongbosherlock avatar May 16 '24 03:05 Hongbosherlock

I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data to get the s4 values before the GEMM?

No, s4 values should be packed, two values per byte.
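
For illustration, host-side packing could look like the sketch below (it assumes the first logical value goes into the low nibble, which should be verified against CUTLASS's sub-byte array layout):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Pack signed 4-bit values two per byte, first value in the low nibble.
    std::vector<uint8_t> pack_s4(const std::vector<int8_t>& vals) {
      std::vector<uint8_t> out((vals.size() + 1) / 2, 0);
      for (std::size_t i = 0; i < vals.size(); ++i) {
        uint8_t nib = static_cast<uint8_t>(vals[i]) & 0xF;  // two's-complement low 4 bits
        out[i / 2] |= (i % 2 == 0) ? nib : static_cast<uint8_t>(nib << 4);
      }
      return out;
    }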

alexsamardzic avatar May 16 '24 10:05 alexsamardzic

I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data to get the s4 values before the GEMM?

No, s4 values should be packed, two values per byte.

Thanks for your help! I can get the correct result now, but I have another question. Assuming that A is int8 and (M, K), and B is int4 and (K, N), after the GEMM C = A·B, C will be (M, N). Now, I have another matrix E, which is fp32 and also (M, N). I want to perform the element-wise multiplication E * C. Can I complete this element-wise multiplication within this s4/s8 GEMM operation, for example by passing matrix E to Arguments? I am not sure how to do this. Maybe here is an example of what I want to do: https://github.com/NVIDIA/TensorRT-LLM/blob/5d8ca2faf74c494f220c8f71130340b513eea9a9/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L131

Hongbosherlock avatar May 23 '24 09:05 Hongbosherlock

Assuming that A is int8 and (M, K), and B is int4 and (K, N), after the GEMM C = A·B, C will be (M, N). Now, I have another matrix E, which is fp32 and also (M, N). I want to perform the element-wise multiplication E * C. Can I complete this element-wise multiplication within this s4/s8 GEMM operation, for example by passing matrix E to Arguments?

If matrix E is really MxN (i.e. not broadcast), it doesn't seem that the code you linked is doing this exact operation. I'd say the simplest way to achieve this is through EVT epilogues; these exist exactly for the purpose of fusing matrix multiplications with arbitrary operations. For Ampere, there is the examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu example demonstrating how to use EVT epilogues: you'd have to remove everything related to the Bias/C1 matrices in this example, use C2 as your matrix E, and then replace cutlass::plus with cutlass::multiplies in using Compute2 = ... (also, take care that all the data types in the template instantiations are correctly specified).
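
The relevant change would look something like this (a sketch against example 47's type names; I haven't compiled it here):

    // Same node shape as example 47's Compute2, with the operation swapped;
    // ElementOutput and ElementCompute are the example's existing type aliases.
    using Compute2 = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::multiplies,  // was cutlass::plus in the original example
        ElementOutput, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest>;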

alexsamardzic avatar May 23 '24 10:05 alexsamardzic

Assuming that A is int8 and (M, K), and B is int4 and (K, N), after the GEMM C = A·B, C will be (M, N). Now, I have another matrix E, which is fp32 and also (M, N). I want to perform the element-wise multiplication E * C. Can I complete this element-wise multiplication within this s4/s8 GEMM operation, for example by passing matrix E to Arguments?

If matrix E is really MxN (i.e. not broadcast), it doesn't seem that the code you linked is doing this exact operation. I'd say the simplest way to achieve this is through EVT epilogues; these exist exactly for the purpose of fusing matrix multiplications with arbitrary operations. For Ampere, there is the examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu example demonstrating how to use EVT epilogues: you'd have to remove everything related to the Bias/C1 matrices in this example, use C2 as your matrix E, and then replace cutlass::plus with cutlass::multiplies in using Compute2 = ... (also, take care that all the data types in the template instantiations are correctly specified).

Thanks, I’m trying this, but it’s not going well currently. To make it clearer, what I want to do is exactly the following:

    // inputs
    //     A           [M, K]    int8
    //     B           [N, K]    int4
    //     alphaCol    [M, 1]    fp32
    //     alphaRow    [1, N]    fp32
    // outputs
    //     mat [M, N]            fp32

That is: (alphaCol x alphaRow) * (A x B). I think here is an s8/s8 example (A and B are both int8): https://github.com/NVIDIA/TensorRT-LLM/blob/5d8ca2faf74c494f220c8f71130340b513eea9a9/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L131, which also uses EVT, and the inputs are passed from here. I wonder if I could use the same EVT code and using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel> with this s4/s8 GEMM.
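
Written elementwise (with B stored as [N, K]), the target computation is:

    D_{ij} = \alpha^{\mathrm{col}}_{i} \, \alpha^{\mathrm{row}}_{j} \sum_{k=0}^{K-1} A_{ik} B_{jk}, \qquad 0 \le i < M, \; 0 \le j < N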

Hongbosherlock avatar May 28 '24 07:05 Hongbosherlock

Thanks, I’m trying this, but it’s not going well currently. To make it clearer, what I want to do is exactly the following:

    // inputs
    //     A           [M, K]    int8
    //     B           [N, K]    int4
    //     alphaCol    [M, 1]    fp32
    //     alphaRow    [1, N]    fp32
    // outputs
    //     mat [M, N]            fp32

Well, that's not an element-wise multiplication with an MxN tensor, as stated initially... The example from TensorRT-LLM that you linked is not using EVT, but some kind of CUTLASS extension of their own instead (probably because CUTLASS had no EVT support at that time). I don't have time to look into these, so I can only point you again to the examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu example, but now you should take a look at how Bias is applied, and ignore everything C1/C2 related. With cutlass::plus replaced by cutlass::multiplies, your alphaCol will be applied the same way, and once you understand how this works, it should be easy to interchange the offsets and apply alphaRow accordingly too.

alexsamardzic avatar May 28 '24 13:05 alexsamardzic

Thanks, I’m trying this, but it’s not going well currently. To make it clearer, what I want to do is exactly the following:

    // inputs
    //     A           [M, K]    int8
    //     B           [N, K]    int4
    //     alphaCol    [M, 1]    fp32
    //     alphaRow    [1, N]    fp32
    // outputs
    //     mat [M, N]            fp32

Well, that's not an element-wise multiplication with an MxN tensor, as stated initially... The example from TensorRT-LLM that you linked is not using EVT, but some kind of CUTLASS extension of their own instead (probably because CUTLASS had no EVT support at that time). I don't have time to look into these, so I can only point you again to the examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu example, but now you should take a look at how Bias is applied, and ignore everything C1/C2 related. With cutlass::plus replaced by cutlass::multiplies, your alphaCol will be applied the same way, and once you understand how this works, it should be easy to interchange the offsets and apply alphaRow accordingly too.

I got errors with ElementB = cutlass::int4b_t when I tried to follow this example:

    using EVTKernelStreamK =
        typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
        ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
        ElementC, LayoutC, AlignmentC,
        ElementAccumulator,
        ElementCompute,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
        NumStages,
        cutlass::arch::OpMultiplyAdd,
        EVTEpilogueStages
    >::GemmKernel;

error message:

    cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h(58): error: incomplete type is not allowed
      detected during: instantiation of class "cutlass::gemm::warp::MmaTensorOpPolicy<Operator_, OpDelta_>
      [with Operator_=cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 32>, 32, int8_t, cutlass::layout::RowMajor,
                                         cutlass::int4b_t, cutlass::layout::ColumnMajor, int32_t, cutlass::layout::RowMajor,
                                         cutlass::arch::OpMultiplyAdd>,
       OpDelta_=cutlass::MatrixShape<1, 1>]"

The complete log is here. When I set ElementB = int8_t, it's OK. Is it because DefaultGemmWithVisitor doesn't support s8/s4? Any pointers in the right direction would be greatly appreciated, thanks!

Hongbosherlock avatar Jun 03 '24 13:06 Hongbosherlock

I got errors with ElementB = cutlass::int4b_t when I tried to follow this example:

Are you using CUTLASS main, or the branch from this PR?

alexsamardzic avatar Jun 03 '24 13:06 alexsamardzic

I got errors with ElementB = cutlass::int4b_t when I tried to follow this example:

Are you using CUTLASS main, or the branch from this PR?

I'm using this PR branch: alexsamardzic:add-mixed-4bit-8bit-gemm. I can get the right s4/s8 GEMM result with using Gemm = cutlass::gemm::device::GemmUniversal, as added in test/unit/gemm/device.

// ok
using Gemm = cutlass::gemm::device::GemmUniversal<
  ElementA,                                // ElementA
  cutlass::layout::RowMajor,               // LayoutA
  ElementB,                                // ElementB
  cutlass::layout::ColumnMajor,            // LayoutB
  ElementOutput,                           // ElementOutput
  cutlass::layout::RowMajor,               // LayoutOutput
  ElementAccumulator,                      // ElementAccumulator
  cutlass::arch::OpClassTensorOp,          // tag indicating Tensor Cores
  cutlass::arch::Sm80,                     // tag indicating target GPU compute architecture
  cutlass::gemm::GemmShape<128, 128, 64>,  // ThreadblockShape
  cutlass::gemm::GemmShape<64, 64, 64>,    // WarpShape
  cutlass::gemm::GemmShape<16, 8, 32>,     // InstructionShape
  cutlass::epilogue::thread::LinearCombination<
      ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator, ElementAccumulator>,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  4,                                       // Stages
  16,                                      // AlignmentA
  32,                                      // AlignmentB
  cutlass::arch::OpMultiplyAddMixedInputUpcast,
  cutlass::ComplexTransform::kNone,
  cutlass::ComplexTransform::kNone>;


But when I try to use GemmUniversalBase or GemmUniversalAdapter, which need a GemmKernel to be specified, it doesn't work with cutlass::int4b_t, while int8_t does.

// get errors
    using EVTKernelStreamK =
        typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
        ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
        ElementC, LayoutC, AlignmentC,
        ElementAccumulator,
        ElementCompute,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
        NumStages,
        cutlass::arch::OpMultiplyAdd,
        EVTEpilogueStages
    >::GemmKernel;   // here is the key, I think

    using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;

I don't know much about warp-level computation. This PR modifies the file include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h, but the errors are related to include/cutlass/gemm/warp/mma_tensor_op_policy.h(58) and include/cutlass/gemm/warp/mma_tensor_op.h(108), as you can see in the log.

Hongbosherlock avatar Jun 04 '24 02:06 Hongbosherlock

But when I try to use GemmUniversalBase or GemmUniversalAdapter, which need a GemmKernel to be specified, it doesn't work with cutlass::int4b_t, while int8_t does.

Can you post your full code here?

alexsamardzic avatar Jun 04 '24 13:06 alexsamardzic

But when I try to use GemmUniversalBase or GemmUniversalAdapter, which need a GemmKernel to be specified, it doesn't work with cutlass::int4b_t, while int8_t does.

Can you post your full code here?

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include "gemm_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>

#include <cutlass/gemm/device/gemm_universal.h>
#include <cutlass/util/reference/host/gemm.h>
#include <cutlass/util/reference/host/tensor_compare.h>
#include <cutlass/util/reference/host/tensor_copy.h>
#include <cutlass/util/reference/host/tensor_fill.h>
#include <cutlass/util/tensor_view_io.h>

#include <cutlass/gemm/device/gemm_universal_with_broadcast.h>
#include <cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h>

#include <cutlass/util/reference/host/error_metrics.h>
#include <cutlass/util/reference/host/tensor_foreach.h>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>


torch::Tensor matmul_w4a8(const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &alphaCol, const torch::Tensor &alphaRow) {
    torch::checkAllSameGPU("W4A8Matmul", {{A, "A", 0}, {B, "B", 1}});
    auto M = A.size(0);
    auto N = B.size(0);
    auto K = A.size(1);  // 4bit packing is on the columns
    auto D = torch::empty({M, N}, torch::dtype(torch::kFloat32).device(A.device())); 

    // A matrix configuration
    using         ElementA         = int8_t;                                    // Element type for A matrix operand
    using         LayoutA          = cutlass::layout::RowMajor;                        // Layout type for A matrix operand
    constexpr int AlignmentA       = 128 / cutlass::sizeof_bits<ElementA>::value;      // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

    // B matrix configuration
    using         ElementB         = cutlass::int4b_t;                                  // Element type for B matrix operand
    using         LayoutB          = cutlass::layout::ColumnMajor;                        // Layout type for B matrix operand
    constexpr int AlignmentB       = 128 / cutlass::sizeof_bits<ElementB>::value;      // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

    // C1/C2/D matrix configuration
    using         ElementC         = float;        //cutlass::half_t;                   // Element type for C matrix operands
    using         LayoutC          = cutlass::layout::RowMajor;                        // Layout type for C matrix operands
    constexpr int AlignmentC       = 128 / cutlass::sizeof_bits<ElementC>::value;      // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)

    // Output matrix configuration
    using         ElementOutput    = float;                                          // Element type for output matrix operands
    using         LayoutOutput     = cutlass::layout::RowMajor;                        // Layout type for output matrix operands
    // constexpr int AlignmentOutput  = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)

    // Multiply-accumulate blocking/pipelining details
    using ElementAccumulator  = int32_t;                                 // Element type for internal accumulation
    using ElementCompute      = float;  //cutlass::half_t;                          // Element type for compute
    using ArchTag             = cutlass::arch::Sm80;                      // Tag indicating the minimum SM that supports the intended feature
    using OperatorClass       = cutlass::arch::OpClassTensorOp;           // Operator class tag
    using ThreadblockShape    = cutlass::gemm::GemmShape<128, 128, 64>;   // Threadblock-level tile size (concept: GemmShape)
    using WarpShape           = cutlass::gemm::GemmShape<64, 64, 64>;     // Warp-level tile size (concept: GemmShape)
    using InstructionShape    = cutlass::gemm::GemmShape<16, 8, 32>;      // Instruction-level tile size (concept: GemmShape)
    constexpr int NumStages   = 4;                                        // Number of global->shared pipeline stages used in the GEMM mainloop
    constexpr int EVTEpilogueStages = 1;   

    // StreamK device GEMM implementation type with EVT
    using namespace cute;

    using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    ThreadblockShape, 
    WarpShape, 
    ElementC,  
    AlignmentC,  //4
    EVTEpilogueStages
    >;

    using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

    // alphaCol    [M, 1]    fp32
    using V1Broadcast = cutlass::epilogue::threadblock::VisitorColBroadcast<
        OutputTileThreadMap, ElementC,
        cute::Stride<int32_t, _1, _0>  // StrideMNL
    >;

    // alphaRow    [1, N]    fp32
    using V2Broadcast = cutlass::epilogue::threadblock::VisitorRowBroadcast<
        OutputTileThreadMap, ElementC,
        cute::Stride<_0, _1, int32_t>  // StrideMNL
    >;

    // mul
    using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::multiplies, ElementCompute, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest
    >;

    // alphaCol * accumulator
    using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
        Compute0,
        V1Broadcast,
        Accum>;

    // mul
    using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::multiplies, ElementOutput, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest
    >;

    // alphaRow * alphaCol * accumulator
    using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
        Compute1,
        V2Broadcast,
        EVTCompute0>;

    using StoreD = cutlass::epilogue::threadblock::VisitorAuxStore<
        OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
        cute::Stride<int64_t, _1, int64_t> // StrideMNL
    >;

    using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
        StoreD,
        EVTCompute1>;

    using EVTKernelStreamK =
        typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
        ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
        ElementC, LayoutC, AlignmentC,
        ElementAccumulator,
        ElementCompute,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
        NumStages,
        cutlass::arch::OpMultiplyAddSaturate,
        EVTEpilogueStages
    >::GemmKernel;

    using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;

    // Populates a DeviceGemmStreamK::Arguments structure from the given commandline options

    // Ensure the input tensors are in the correct device and layout
    auto tensor_a = A.contiguous();
    auto tensor_b = B.contiguous();
    auto tensor_v1 = alphaCol.contiguous();
    auto tensor_v2 = alphaRow.contiguous();
    auto tensor_d = D.contiguous();  // EVTD

    typename EVTD::Arguments callback_args{
        {
            {
                {
                    {},                                                                                                          // Accum
                    {tensor_v1.data_ptr<ElementC>(), ElementC(0), {int32_t(M), _1{}, _0{}}},                                    // V1 Broadcast
                    {}                                                                                                           // Compute0
                },                                                                                                             // EVTCompute0
                {tensor_v2.data_ptr<ElementC>(), ElementC(0), {_0{}, _1{}, int32_t(N)}},                                      // V2 Broadcast
                {}                                                                                                             // Compute1
            },                                                                                                               // EVTCompute1
            {}                                                                                                               // Compute2
        },                                                                                                                // EVTCompute2
        {tensor_d.data_ptr<ElementC>(), {int32_t{N}, _1{}, int32_t{M*N}}}                                                                  // D
    };                                                                                                                   // EVTD
    
                  
    using GemmCoord = cutlass::gemm::GemmCoord;
    cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
    int batch_count = 1;
    // Construct Gemm ProblemSize with user defined output size
    // cutlass::gemm::GemmCoord problem_size = {M, N, K};
    cutlass::gemm::GemmCoord problem_size(M, N, K);

    int64_t stride_A = M * K;
    int64_t stride_B = N * K;
    // int64_t stride_C = M * N;
    // int64_t stride_D = M * N;                                                                                               // EVTD
    int     avail_sms = -1;
    typename DeviceGemmStreamK::Arguments arguments(
        mode,                                     // universal mode
        problem_size,                             // problem_size
        batch_count,                              // batch count / splitk slices
        callback_args,                            // argument of EVT callbacks
        tensor_a.data_ptr<ElementA>(),            // ptr_A
        (cutlass::int4b_t *)tensor_b.data_ptr<uint8_t>(),            // ptr_B
        nullptr,                                  // ptr_C (unused)
        nullptr,                                  // ptr_D (unused)
        stride_A,                                 // batch_stride_A
        stride_B,                                 // batch_stride_B
        0,                                        // batch_stride_C (unused)
        0,                                        // batch_stride_D (unused)
        tensor_a.stride(0),                       // stride_a
        tensor_b.stride(0),                       // stride_b
        0,                                        // stride_c (unused)
        0);                                      // stride_d (unused)
                            

    DeviceGemmStreamK gemm_op;
    
    auto stream = at::cuda::getCurrentCUDAStream(A.get_device());

    // Using the arguments, query for extra workspace required for matrix
    // multiplication computation
    size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);

    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // Check the problem size is supported or not
    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot implement");
    }

    // Initialize CUTLASS kernel with arguments and workspace pointer
    status = gemm_op.initialize(arguments, workspace.get(), stream);
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot initialize");
    }

    status = gemm_op(stream);
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot run");
    }

    return tensor_d;
}

I got errors about "incomplete type is not allowed".

Hongbosherlock avatar Jun 04 '24 13:06 Hongbosherlock

This code uses PyTorch; can you post a reproducible example that uses CUTLASS only?

alexsamardzic avatar Jun 04 '24 14:06 alexsamardzic

This code uses PyTorch; can you post a reproducible example that uses CUTLASS only?

Hi @alexsamardzic, I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu#L114

You can add this example, then compile and run it:

cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80
cutlass/build$ make 61_s4s8_gemm
cutlass/build$ ./examples/61_s4s8_gemm/61_s4s8_gemm

When ElementB = int8_t, it seems OK; you can see the correct result (screenshot omitted).

But when ElementB = cutlass::int4b_t, lots of compilation errors occur (screenshot of the error log omitted).

I am really at a loss and would greatly appreciate any guidance or help you can provide. Thank you very much in advance for your time and assistance!

Hongbosherlock avatar Jun 05 '24 09:06 Hongbosherlock

I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu

Replace cutlass::arch::OpMultiplyAddSaturate with cutlass::arch::OpMultiplyAddMixedInputUpcast.

alexsamardzic avatar Jun 06 '24 11:06 alexsamardzic

I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu

Replace cutlass::arch::OpMultiplyAddSaturate with cutlass::arch::OpMultiplyAddMixedInputUpcast.

Works for me. Thanks!

Hongbosherlock avatar Jun 06 '24 14:06 Hongbosherlock

Works for me. Thanks!

Good. Remember that CUTLASS is a heavily templated library, but only a small number of all the possible template argument combinations actually work together - so one cannot just paste pieces of code from different sources and expect them to work.

alexsamardzic avatar Jun 06 '24 14:06 alexsamardzic

Works for me. Thanks!

Good. Remember that CUTLASS is a heavily templated library, but only a small number of all the possible template argument combinations actually work together - so one cannot just paste pieces of code from different sources and expect them to work.

Yeah, that was a mistake; OpMultiplyAddMixedInputUpcast does appear in the test folder. I think I was a bit disoriented - there are too many arguments. I find CUTLASS somewhat challenging for beginners. Do you have any recommended learning paths?

Hongbosherlock avatar Jun 07 '24 03:06 Hongbosherlock