Add support for mixed 4-bit/8-bit data types GEMM
More to come here: support for U4, support for generator in the CUTLASS library, etc. Still, opening a PR to solicit feedback for the S8/S4 and S4/S8 GEMMs that are now available; in particular, I'm interested in eventual suggestions for a faster approach to S4->S8 conversion.
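(For reference, a scalar sketch of the conversion in question - the actual kernel does this on packed registers, so this is only to pin down the semantics; the helper name is mine:)

```cpp
#include <cstdint>

// Each byte holds two signed 4-bit values, element 0 in the low-order bits;
// converting S4 -> S8 amounts to sign-extending each nibble.
inline void s4x2_to_s8(uint8_t packed, int8_t &lo, int8_t &hi) {
  lo = static_cast<int8_t>(static_cast<uint8_t>(packed << 4)) >> 4;  // low nibble
  hi = static_cast<int8_t>(packed) >> 4;                             // high nibble
}
```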
Added more tests.
Added generator support for S8/S4 and S4/S8.
AFAIK, implementing generator support for a given operation is not specifically documented, so I want to clarify the steps I've taken here. Basically, I've copied code from the `GenerateSM80_TensorOp_16832_TN` method into `GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_(a|b)`, and then made some changes:

- Obviously, I've changed the `math_instructions` assignments according to the data types actually used for the mixed input case.
- I'm not sure where `smem_usage = 164` in `GenerateSM80_TensorOp_16832_TN` comes from, and this variable is not used further anyway, so I skipped it in the new methods.
- I've used two sets of alignment constraints. The alignments for operands A and B are the same, but operand C (and thus the result too) could be either 32-bit or 8-bit. The code at the end of the mixed input methods, within the last `if` statement, handles the latter case, and the alignments are changed there accordingly. (Note that for `GenerateSM80_TensorOp_16816_mixed_input_upcast_(a|b)` there are snippets of code at the end of the methods doing a similar thing, but they're slightly different from each other, and also from what I did here.)
- The `tile_descriptions` were initially copied from `GenerateSM80_TensorOp_16832_TN`; I then made sure that all the relevant kernels would be compiled (by adding `CUTLASS_LIBRARY_KERNELS="*i16832gemm*"` to the CMake command line), and removed tiles that failed to compile.
I did the verification as @manishucsd suggested here: as mentioned above, I did the build with all the relevant kernels included, and then verified that `cutlass_profiler` would run all the tile variations specified in `GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_(a|b)`. Note that the profiler would produce `Disposition: Incorrect` for all the kernels with 8-bit output; I suppose it's related to saturation - I'm not sure if I should actually apply saturation somehow for this combination of input data types?
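(For context, integer CUTLASS GEMMs that produce 8-bit output typically saturate through a clamping epilogue; a hedged sketch of such an epilogue is below - whether wiring it in is the right fix for the profiler mismatch is exactly the open question:)

```cpp
// Sketch: LinearCombinationClamp saturates the int32 accumulator into the
// int8 output range instead of letting the narrowing conversion wrap.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationClamp<
    int8_t,                                     // output element type
    128 / cutlass::sizeof_bits<int8_t>::value,  // elements per vectorized access
    int32_t,                                    // accumulator type
    float>;                                     // epilogue compute type
```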
Overall, this PR now contains everything that I intended to do for S4/S8 and S8/S4 GEMM, and it's ready for review. It has grown somewhat large, so I'd suggest having it reviewed and eventually merged, and then I can add support for U4/U8 and U8/U4, and maybe U4/S8 and S8/U4, in follow-up PR(s).
Hi @alexsamardzic, thanks for working on this. Just wanted to clarify, will this kernel support int4 grouped per channel weight quantization + int8 per token dynamic activation quantization?
This kernel is just int4/int8 GEMM, producing an int32 (or int8) result. Quantization is not to be supported by CUTLASS directly, but could be implemented using an EVT epilogue. In particular, I'm trying to get this feature into CUTLASS mainly in order to have this operation supported in PyTorch, with its use alongside quantization as the primary motivator.
@manishucsd, @hwu36: Would it be possible for someone to review this PR (and eventually #1350 too)? These should not be controversial, are needed by PyTorch, and for this one I'd like to proceed with another PR to add other 4-bit/8-bit integer combinations that make sense.
working on it now.
Great job! How can I integrate this PR with PyTorch? Are there any example codes available? @alexsamardzic
The primary motivation for this PR is to have this combination of operands supported by PyTorch, so the integration should be coming soon.
I'm a beginner with CUTLASS and I have no idea how to use my own constructed s4/s8 data to run this GEMM. Could you please provide example code for testing this s4/s8 GEMM, like the official example here: https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/README.md
These changes are not for Hopper, but for the Ampere architecture. The code to run s4/s8 GEMM would be the same as for any other GEMM, for example s8/s8, except that when the GEMM template is instantiated, the data types and other arguments should be specified accordingly. For some examples of this, see the `using Gemm = cutlass::gemm::device::GemmUniversal...` template instantiations in the test cases added by this PR into the `test/unit/gemm/device` directory. As far as your data is concerned, s4 data should be provided as two successive values packed into a single byte, and that's all.
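For illustration, a host-side packing sketch (the helper name is hypothetical; this assumes the convention that element 0 occupies the low-order bits of each byte):

```cpp
#include <cstdint>
#include <vector>

// Pack 2*N signed 4-bit values (each in [-8, 7], stored one per int8_t) into
// N bytes: element 2*i goes into the low nibble of byte i, element 2*i+1
// into the high nibble.
std::vector<uint8_t> pack_s4(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed(values.size() / 2);
  for (size_t i = 0; i < packed.size(); ++i) {
    packed[i] = (static_cast<uint8_t>(values[2 * i]) & 0x0F) |
                static_cast<uint8_t>(values[2 * i + 1] << 4);
  }
  return packed;
}
```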
On a quick look, your strides may be wrong.
Thank you for your prompt reply. I don't know much about this parameter, and I can't find many references. Could you give me some more details? Thank you very much.
I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data to get s4 data before the GEMM?
No, s4 values should be packed, two values per byte.
Thanks for your help! I can get the correct result now, but I have another question:
Assuming that A is int8 and (M, K), and B is int4 and (K, N), then after the GEMM `C = A·B`, C will be (M, N). Now, I have another matrix E, which is fp32 and also (M, N). I want to perform the element-wise multiplication `E * C`. Can I complete this element-wise multiplication within this s4/s8 GEMM operation, for example by passing matrix E to `Arguments`? I am not sure how to do this.
Maybe here is an example of what I want to do: https://github.com/NVIDIA/TensorRT-LLM/blob/5d8ca2faf74c494f220c8f71130340b513eea9a9/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L131
If matrix `E` is really `MxN` (i.e. not broadcast), it doesn't seem that the code you linked is doing this exact operation. I'd say the simplest way to achieve this would be through EVT epilogues; these are exactly for the purpose of fusing matrix multiplications with arbitrary operations. For Ampere, there is the `examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu` example demonstrating how to use EVT epilogues: you'd have to remove everything related to the `Bias`/`C1` matrices in this example, use `C2` as your matrix `E`, and then replace `cutlass::plus` with `cutlass::multiplies` in `using Compute2 = ...` (also, you should take care that all of the data types in the template instantiations are correctly specified).
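For concreteness, a sketch of that one change, assuming the `ElementOutput`/`ElementCompute` aliases from that example:

```cpp
// Compute2 from example 47, with cutlass::plus swapped for
// cutlass::multiplies, so C2 (your matrix E) multiplies the accumulator
// term element-wise instead of being added to it.
using Compute2 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, ElementOutput, ElementCompute,
    cutlass::FloatRoundStyle::round_to_nearest>;
```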
Thanks, I’m trying this, but it’s not going well currently. To make it clearer, what I want to do is exactly the following:
```cpp
// inputs
// A [M, K] int8
// B [N, K] int4
// alphaCol [M, 1] fp32
// alphaRow [1, N] fp32
// outputs
// mat [M, N] fp32
```
That is: (alphaCol x alphaRow) * (A x B)
I think here is an s8/s8 example (A and B are both int8): https://github.com/NVIDIA/TensorRT-LLM/blob/5d8ca2faf74c494f220c8f71130340b513eea9a9/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L131, which also uses EVT, and the inputs are passed from here
I wonder if I could use the same EVT code and `using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>` with this s4/s8 GEMM.
Well, that's not element-wise multiplication with an `MxN` tensor, as stated initially... The example from TensorRT-LLM that you're linking to is not using EVT, but some kind of their own CUTLASS extension instead (probably because CUTLASS had no EVT support at that time). I don't have time to look into these, so I can only point you again to the `examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu` example, but now you should take a look at how `Bias` is applied, and ignore everything `C1`/`C2` related. With `cutlass::plus` replaced by `cutlass::multiplies`, your `alphaCol` will be applied the same way, and once you understand how this works, it should be easy to interchange offsets and apply `alphaRow` accordingly too.
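A rough sketch of the resulting epilogue tree with the threadblock visitors (the aliases are illustrative, and the broadcast strides in particular must match your actual layouts):

```cpp
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>

// Assumes OutputTileThreadMap and the element types are set up as in the
// example; strides below are illustrative only.
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using AlphaCol = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, float, cute::Stride<cute::_1, cute::_0, int32_t>>;
using AlphaRow = cutlass::epilogue::threadblock::VisitorRowBroadcast<
    OutputTileThreadMap, float, cute::Stride<cute::_0, cute::_1, int32_t>>;

using Mul = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, float, float,
    cutlass::FloatRoundStyle::round_to_nearest>;

// alphaCol * accum, then alphaRow * (alphaCol * accum)
using EVTScaleCol = cutlass::epilogue::threadblock::Sm80EVT<Mul, AlphaCol, Accum>;
using EVTScaleRow = cutlass::epilogue::threadblock::Sm80EVT<Mul, AlphaRow, EVTScaleCol>;
```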
I got errors with `ElementB = cutlass::int4b_t` when I tried to follow this example:
```cpp
using EVTKernelStreamK =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, LayoutC, AlignmentC,
ElementAccumulator,
ElementCompute,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
cutlass::arch::OpMultiplyAdd,
EVTEpilogueStages
  >::GemmKernel;
```
error message:
```
cutlass/include/cutlass/gemm/warp/mma_tensor_op_policy.h(58): error: incomplete type is not allowed detected during: instantiation of class "cutlass::gemm::warp::MmaTensorOpPolicy<Operator_, OpDelta_> [with Operator_=cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 32>, 32, int8_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, int32_t, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, OpDelta_=cutlass::MatrixShape<1, 1>]"
```
The complete log is here.
When I set `ElementB = int8_t`, it's OK. Is it because `DefaultGemmWithVisitor` doesn't support s8/s4?
If possible, pointers in the correct direction would be greatly appreciated, thanks!
Are you using CUTLASS main, or the branch from this PR?
I'm using this PR branch: `alexsamardzic:add-mixed-4bit-8bit-gemm`. I can get the right s4/s8 GEMM result with `using Gemm = cutlass::gemm::device::GemmUniversal` as added in `test/unit/gemm/device`:
```cpp
// ok
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA, // ElementA
cutlass::layout::RowMajor, // LayoutA
ElementB, // ElementB
cutlass::layout::ColumnMajor, // LayoutB
ElementOutput, // ElementOutput
cutlass::layout::RowMajor, // LayoutOutput
ElementAccumulator, // ElementAccumulator
cutlass::arch::OpClassTensorOp, // tag indicating Tensor Cores
cutlass::arch::Sm80, // tag indicating target GPU compute architecture
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
32, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
```
But when I try to use `GemmUniversalBase` or `GemmUniversalAdapter`, which need a `GemmKernel` specified, it couldn't work with `cutlass::int4b_t`, while `int8_t` could:
```cpp
// get errors
using EVTKernelStreamK =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, LayoutC, AlignmentC,
ElementAccumulator,
ElementCompute,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
cutlass::arch::OpMultiplyAdd,
EVTEpilogueStages
  >::GemmKernel; // this is the key part, I think
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;
```
I don't know much about warp-level computation. This PR modifies the file `include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h`, but the errors are related to `include/cutlass/gemm/warp/mma_tensor_op_policy.h(58)` and `include/cutlass/gemm/warp/mma_tensor_op.h(108)`, as you can see in the log.
Can you post your full code here?
```cpp
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include "gemm_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/gemm/device/gemm_universal.h>
#include <cutlass/util/reference/host/gemm.h>
#include <cutlass/util/reference/host/tensor_compare.h>
#include <cutlass/util/reference/host/tensor_copy.h>
#include <cutlass/util/reference/host/tensor_fill.h>
#include <cutlass/util/tensor_view_io.h>
#include <cutlass/gemm/device/gemm_universal_with_broadcast.h>
#include <cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h>
#include <cutlass/util/reference/host/error_metrics.h>
#include <cutlass/util/reference/host/tensor_foreach.h>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
torch::Tensor matmul_w4a8(const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &alphaCol, const torch::Tensor &alphaRow) {
torch::checkAllSameGPU("W4A8Matmul", {{A, "A", 0}, {B, "B", 1}});
auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1); // 4bit packing is on the columns
auto D = torch::empty({M, N}, torch::dtype(torch::kFloat32).device(A.device()));
// A matrix configuration
using ElementA = int8_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::int4b_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C1/C2/D matrix configuration
using ElementC = float; //cutlass::half_t; // Element type for C matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = float; // Element type for output matrix operands
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = int32_t; // Element type for internal accumulation
using ElementCompute = float; //cutlass::half_t; // Element type for compute
using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1;
// StreamK device GEMM implementation type with EVT
using namespace cute;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementC,
AlignmentC, //4
EVTEpilogueStages
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
// alphaCol [M, 1] fp32
using V1Broadcast = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<int32_t, _1, _0> // StrideMNL
>;
// alphaRow [1, N] fp32
using V2Broadcast = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<_0, _1, int32_t> // StrideMNL
>;
// mul
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementCompute, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
// alphaCol * accumulator
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
Compute0,
V1Broadcast,
Accum>;
// mul
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementOutput, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
// alphaRow * alphaCol * accumulator
using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
Compute1,
V2Broadcast,
EVTCompute0>;
using StoreD = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
StoreD,
EVTCompute1>;
using EVTKernelStreamK =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, LayoutC, AlignmentC,
ElementAccumulator,
ElementCompute,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
cutlass::arch::OpMultiplyAddSaturate,
EVTEpilogueStages
>::GemmKernel;
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;
// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
// Ensure the input tensors are in the correct device and layout
auto tensor_a = A.contiguous();
auto tensor_b = B.contiguous();
auto tensor_v1 = alphaCol.contiguous();
auto tensor_v2 = alphaRow.contiguous();
auto tensor_d = D.contiguous(); // EVTD
typename EVTD::Arguments callback_args{
{
{
{
{}, // Accum
{tensor_v1.data_ptr<ElementC>(), ElementC(0), {int32_t(M), _1{}, _0{}}}, // V1 Broadcast
{} // Compute0
}, // EVTCompute0
{tensor_v2.data_ptr<ElementC>(), ElementC(0), {_0{}, _1{}, int32_t(N)}}, // V2 Broadcast
{} // Compute1
}, // EVTCompute1
{} // Compute2
}, // EVTCompute2
{tensor_d.data_ptr<ElementC>(), {int32_t{N}, _1{}, int32_t{M*N}}} // D
}; // EVTD
using GemmCoord = cutlass::gemm::GemmCoord;
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
int batch_count = 1;
// Construct Gemm ProblemSize with user defined output size
// cutlass::gemm::GemmCoord problem_size = {M, N, K};
cutlass::gemm::GemmCoord problem_size(M, N, K);
int64_t stride_A = M * K;
int64_t stride_B = N * K;
// int64_t stride_C = M * N;
// int64_t stride_D = M * N; // EVTD
int avail_sms = -1;
typename DeviceGemmStreamK::Arguments arguments(
mode, // universal mode
problem_size, // problem_size
batch_count, // batch count / splitk slices
callback_args, // argument of EVT callbacks
tensor_a.data_ptr<ElementA>(), // ptr_A
(cutlass::int4b_t *)tensor_b.data_ptr<uint8_t>(), // ptr_B
nullptr, // ptr_C (unused)
nullptr, // ptr_D (unused)
stride_A, // batch_stride_A
stride_B, // batch_stride_B
0, // batch_stride_C (unused)
0, // batch_stride_D (unused)
tensor_a.stride(0), // stride_a
tensor_b.stride(0), // stride_b
0, // stride_c (unused)
0); // stride_d (unused)
DeviceGemmStreamK gemm_op;
auto stream = at::cuda::getCurrentCUDAStream(A.get_device());
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get(), stream);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op(stream);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
return tensor_d;
}
```
I got errors about `incomplete type is not allowed`.
This code uses PyTorch, can you post a reproducible example that uses CUTLASS only?
Hi @alexsamardzic , I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu#L114
You can add this example, then compile and run it:
```
# cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80
# cutlass/build$ make 61_s4s8_gemm
# cutlass/build$ ./examples/61_s4s8_gemm/61_s4s8_gemm
```
When `ElementB = int8_t`, it seems OK and I can get the result. But when `ElementB = cutlass::int4b_t`, lots of compilation errors occur.
I am really at a loss and would greatly appreciate any guidance or help you can provide. Thank you very much in advance for your time and assistance!
Replace `cutlass::arch::OpMultiplyAddSaturate` with `cutlass::arch::OpMultiplyAddMixedInputUpcast`.
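(Applied to the instantiation posted above, that is the only change - the mixed-input mainloop is selected through this math operator tag:)

```cpp
using EVTKernelStreamK =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
        ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
        ElementC, LayoutC, AlignmentC,
        ElementAccumulator,
        ElementCompute,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
        NumStages,
        cutlass::arch::OpMultiplyAddMixedInputUpcast,  // was OpMultiplyAddSaturate
        EVTEpilogueStages
    >::GemmKernel;
```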
Works for me. Thanks!
Good. Remember that CUTLASS is a heavily templated library, but actually only a small number of all the possible template argument combinations work together - so one cannot just paste pieces of code from different sources and expect them to work.
Yeah, that was a mistake. `OpMultiplyAddMixedInputUpcast` did appear in the `test` folder; I think I was a bit disoriented, there are too many arguments.
I think CUTLASS is somewhat challenging for beginners. Do you have any recommended learning paths?