
[QST] Running CUTLASS kernels from device __global__ functions

Open • cydoroga opened this issue 2 years ago • 3 comments

Hi! I'm pretty new to CUTLASS (and CUDA, to be honest). I have a two-fold question:

  1. I'm trying to use dynamic parallelism with cutlass::gemm::device::Gemm under the hood. I want to launch the __global__ function from my own code (this line). For that I need to initialise the params_ variable myself, since it is private in your implementation. I've tried to call the constructor the same way you do here. However, the compiler says the argument list is wrong. This is how I initialise the parameters:
    using TCutlassGemm = cutlass::gemm::device::Gemm<
        float,
        cutlass::layout::ColumnMajor,
        float,
        cutlass::layout::ColumnMajor,
        float,
        cutlass::layout::ColumnMajor
    >;

    TCutlassGemm::Arguments args {
        {m, n, k},
        {chunkA, m},
        {chunkB, k},
        {chunkC, m},
        {chunkD, m},
        {alpha, beta},
    };

    TCutlassGemm::ThreadblockSwizzle threadblock_swizzle;

    const cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size, 
      {
        TCutlassGemm::ThreadblockShape::kM, 
        TCutlassGemm::ThreadblockShape::kN, 
        TCutlassGemm::ThreadblockShape::kK
      },
      args.split_k_slices
    );

    const cutlass::gemm::GemmCoord problem_size = args.problem_size;

    typename TCutlassGemm::GemmKernel::Params params;
    params = typename TCutlassGemm::GemmKernel::Params{
      problem_size,
      grid_shape,
      args.ref_A.non_const_ref(),
      args.ref_B.non_const_ref(),
      args.ref_C.non_const_ref(),
      args.ref_D,
      args.epilogue,
      static_cast<int *>(nullptr),
      const_cast<const int*>(static_cast<int *>(nullptr)),
      const_cast<const int*>(static_cast<int *>(nullptr)),
      const_cast<const int*>(static_cast<int *>(nullptr))
    };

The compiler reports that the arguments passed have the following types:

const cutlass::gemm::GemmCoord, 
const cutlass::gemm::GemmCoord, 
cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, 
cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, 
cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, 
cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, 
cutlass::epilogue::thread::LinearCombination<float, 1, float, float, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest>::Params, 
int *, 
const int *, 
const int *, 
const int *

This is the full error I get during compilation:

command /root/.ya/tools/v4/2989598506/python /home/cydoroga/arc/arcadia/build/scripts/compile_cuda.py /root/.ya/build/build_root/pi2o/000017/tools/mtime0/mtime0.so /root/.ya/tools/v4/2410761119/bin/nvcc --compiler-bindir=/root/.ya/tools/v4/1886578148/bin/clang -I/root/.ya/tools/v4/1966560555/usr/include/x86_64-linux-gnu -gencode=arch=compute_70,code=sm_70 -g -lineinfo -rdc=true -lcudadevrt -c /home/cydoroga/arc/arcadia/junk/cydoroga/dynamic_parallelism/test.cu -o /root/.ya/build/build_root/pi2o/000017/junk/cydoroga/dynamic_parallelism/test.cu.o -I/root/.ya/build/build_root/pi2o/000017 -I/home/cydoroga/arc/arcadia -I/home/cydoroga/arc/arcadia/contrib/libs/linux-headers -I/home/cydoroga/arc/arcadia/contrib/libs/linux-headers/_nf -I/home/cydoroga/arc/arcadia/contrib/libs/cxxsupp/libcxx/include -I/home/cydoroga/arc/arcadia/contrib/libs/cxxsupp/libcxxrt/include -I/home/cydoroga/arc/arcadia/contrib/libs/zlib/include -I/home/cydoroga/arc/arcadia/contrib/libs/double-conversion -I/home/cydoroga/arc/arcadia/contrib/libs/libc_compat/include/readpassphrase -I/home/cydoroga/arc/arcadia/contrib/libs/libc_compat/include/random -I/home/cydoroga/arc/arcadia/contrib/libs/nvidia/thrust -I/home/cydoroga/arc/arcadia/contrib/libs/nvidia/cub -I/home/cydoroga/arc/arcadia/contrib/libs/nvidia/cutlass/include -I/home/cydoroga/arc/arcadia/contrib/libs/nvidia/cutlass/tools/util/include --cflags --target=x86_64-linux-gnu --sysroot=/root/.ya/tools/v4/1966560555 -B/root/.ya/tools/v4/1966560555/usr/bin -fdebug-prefix-map=/root/.ya/build/build_root/pi2o/000017=/-B -Xclang -fdebug-compilation-dir -Xclang /tmp -pipe -m64 -g -ggnu-pubnames -fexceptions -fno-common -fuse-init-array -fcolor-diagnostics -faligned-allocation -fstack-protector -ffunction-sections -fdata-sections -Wall -Wextra -Wno-parentheses -Wno-implicit-const-int-float-conversion -Wno-unknown-warning-option -Werror -DFAKEID=9517719 -DARCADIA_ROOT=/home/cydoroga/arc/arcadia -DARCADIA_BUILD_ROOT=/root/.ya/build/build_root/pi2o/000017 
-D_THREAD_SAFE -D_PTHREADS -D_REENTRANT -D_LIBCPP_ENABLE_CXX17_REMOVED_FEATURES -D_LARGEFILE_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D_FILE_OFFSET_BITS=64 -D_GNU_SOURCE -D_YNDX_LIBUNWIND_ENABLE_EXCEPTION_BACKTRACE -UNDEBUG -D__LONG_LONG_SUPPORTED -DSSE_ENABLED=1 -DSSE3_ENABLED=1 -DSSSE3_ENABLED=1 -DSSE41_ENABLED=1 -DSSE42_ENABLED=1 -DPOPCNT_ENABLED=1 -DCX16_ENABLED=1 -Wno-deprecated-copy -Wno-mismatched-tags -D_libunwind_ -nostdinc++ -DLIBCXX_BUILDING_LIBCXXRT -I/root/.ya/tools/v4/2410761119/include -msse2 -msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 -std=c++20 -Woverloaded-virtual -Wimport-preprocessor-directive-pedantic -Wno-ambiguous-reversed-operator -Wno-defaulted-function-deleted -Wno-deprecated-anon-enum-enum-conversion -Wno-deprecated-enum-enum-conversion -Wno-deprecated-enum-float-conversion -Wno-deprecated-volatile -Wno-pessimizing-move -Wno-return-std-move -Wno-undefined-var-template -nostdinc++ -std=c++14 failed with exit code 1 in /root/.ya/build/build_root/pi2o/000017
/home/cydoroga/arc/arcadia/junk/cydoroga/dynamic_parallelism/test.cu(82): error: no instance of constructor "cutlass::gemm::kernel::Gemm<Mma_, Epilogue_, ThreadblockSwizzle_, SplitKSerial>::Params::Params [with Mma_=cutlass::gemm::threadblock::MmaPipelined<cutlass::gemm::GemmShape<128, 128, 8>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<128, 8>, float, cutlass::layout::RowMajor, 1, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::layout::PitchLinearShape<8, 128>, 256, 1>, 1, false>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<128, 8>, float, cutlass::layout::ColumnMajor, 1, cutlass::transform::TransposePitchLinearThreadMapSimt<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::layout::PitchLinearShape<8, 128>, 256, 1>>, 4>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<8, 128>, float, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::layout::PitchLinearShape<128, 8>, 256, 1>, 1, false>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<8, 128>, float, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::layout::PitchLinearShape<128, 8>, 256, 1>, 4>, float, cutlass::layout::RowMajor, cutlass::gemm::threadblock::MmaPolicy<cutlass::gemm::warp::MmaSimt<cutlass::gemm::GemmShape<32, 64, 8>, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<4, 8>, cutlass::layout::RowMajorInterleaved<2>, cutlass::gemm::GemmShape<4, 4, 1>>, 1, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, __nv_bool>, cutlass::MatrixShape<4, 0>, cutlass::MatrixShape<0, 0>, 1>, cutlass::NumericArrayConverter<float, float, 4, cutlass::FloatRoundStyle::round_to_nearest, cutlass::transform::thread::UnaryTransform::Identity>, cutlass::NumericArrayConverter<float, float, 
4, cutlass::FloatRoundStyle::round_to_nearest, cutlass::transform::thread::UnaryTransform::Identity>, __nv_bool>, Epilogue_=cutlass::epilogue::threadblock::Epilogue<cutlass::gemm::GemmShape<128, 128, 8>, cutlass::gemm::warp::MmaSimt<cutlass::gemm::GemmShape<32, 64, 8>, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<4, 8>, cutlass::layout::RowMajorInterleaved<2>, cutlass::gemm::GemmShape<4, 4, 1>>, 1, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, __nv_bool>, 1, cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 4, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, 256, 1, 32>, float, false, false>, cutlass::epilogue::warp::FragmentIteratorSimt<cutlass::gemm::GemmShape<32, 64, 8>, cutlass::gemm::thread::Mma<cutlass::gemm::GemmShape<8, 8, 1>, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, __nv_bool>, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<4, 8>, cutlass::layout::RowMajorInterleaved<2>, cutlass::gemm::GemmShape<4, 4, 1>>>, cutlass::epilogue::warp::TileIteratorSimt<cutlass::gemm::GemmShape<32, 64, 8>, cutlass::gemm::thread::Mma<cutlass::gemm::GemmShape<8, 8, 1>, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, __nv_bool>, float, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<4, 8>, cutlass::layout::RowMajorInterleaved<2>, cutlass::gemm::GemmShape<4, 4, 1>>>, cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 1, 4, 4, 1>, 
cutlass::epilogue::threadblock::OutputTileShape<1, 4, 2, 1, 8>, 256, 1, 32>::CompactedThreadMap, float, 4>, cutlass::epilogue::thread::LinearCombination<float, 1, float, float, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest>, cutlass::MatrixShape<0, 17>, 1, 1>, ThreadblockSwizzle_=cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, SplitKSerial=false]" matches the argument list
            argument types are: (const cutlass::gemm::GemmCoord, const cutlass::gemm::GemmCoord, cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, cutlass::epilogue::thread::LinearCombination<float, 1, float, float, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest>::Params, int *, const int *, const int *, const int *)
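For context on the grid_shape computed in the snippet above: as we read CUTLASS 2.x, GemmIdentityThreadblockSwizzle::get_tiled_shape rounds the M and N extents of the problem up to whole threadblock tiles and uses split_k_slices (1 by default) as the third extent. The plain C++ sketch below reproduces only that arithmetic on the host; Coord3, ceil_div and tiled_shape are hypothetical names, not CUTLASS API, and the 128x128x8 tile is the ThreadblockShape visible in the compiler error.

```cpp
#include <cassert>

// Hypothetical host-side stand-in for cutlass::gemm::GemmCoord (m, n, k).
struct Coord3 {
  int m, n, k;
};

// Number of tiles of size `tile` needed to cover `extent` (ceiling division).
inline int ceil_div(int extent, int tile) {
  assert(tile > 0);
  return (extent + tile - 1) / tile;
}

// Mirrors the arithmetic of GemmIdentityThreadblockSwizzle::get_tiled_shape as
// we understand it: M and N are rounded up to whole threadblock tiles, and the
// third extent of the tiled shape is the number of split-K slices.
inline Coord3 tiled_shape(Coord3 problem, Coord3 tile, int split_k_slices) {
  return Coord3{ceil_div(problem.m, tile.m),
                ceil_div(problem.n, tile.n),
                split_k_slices};
}
```

With m = 1000, n = 500 and the 128x128x8 tile, this yields an 8 x 4 x 1 tiled shape, i.e. 32 threadblocks for a child launch without split-K.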
  2. What I actually want is to implement a batched matmul fused with a custom gather step. I haven't found a batched GEMM fused with gather in your library, but I may have missed something. Maybe there is a better way to achieve this?
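Whatever kernel ends up doing the work, a naive CPU reference is a cheap way to validate a fused gather + batched GEMM. The thread does not say what the "gather" is, so the sketch below assumes one possible meaning: for batch entry b, the k columns of A_b are selected from a shared column-major source matrix S by an index array. The function name and that gather semantics are hypothetical; D_b = alpha * A_b * B_b + beta * C_b matches the LinearCombination epilogue used in the question.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive CPU reference for "gather, then batched GEMM" (hypothetical semantics).
// All matrices are column-major with leading dimension = number of rows.
//   S       : shared source matrix, m x s_cols
//   col_idx : batch * k indices; A_b(:, p) = S(:, col_idx[b * k + p])
//   B       : batch consecutive k x n matrices
//   C, D    : batch consecutive m x n matrices
// Computes D_b = alpha * A_b * B_b + beta * C_b for every batch entry b.
void gather_batched_gemm_ref(int m, int n, int k, int batch,
                             const std::vector<float>& S, int s_cols,
                             const std::vector<int>& col_idx,
                             const std::vector<float>& B,
                             const std::vector<float>& C,
                             std::vector<float>& D,
                             float alpha, float beta) {
  using std::size_t;
  const size_t mn = static_cast<size_t>(m) * n;
  const size_t kn = static_cast<size_t>(k) * n;
  assert(S.size() == static_cast<size_t>(m) * s_cols);
  assert(col_idx.size() == static_cast<size_t>(batch) * k);
  assert(B.size() == batch * kn);
  assert(C.size() == batch * mn);
  D.assign(batch * mn, 0.0f);
  for (int b = 0; b < batch; ++b) {
    const float* Bb = B.data() + b * kn;
    const float* Cb = C.data() + b * mn;
    float* Db = D.data() + b * mn;
    for (int j = 0; j < n; ++j) {
      for (int i = 0; i < m; ++i) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          const int src = col_idx[b * static_cast<size_t>(k) + p];
          assert(src >= 0 && src < s_cols);
          // The gather: A_b(i, p) is read directly from S(i, src).
          acc += S[static_cast<size_t>(src) * m + i] *
                 Bb[static_cast<size_t>(j) * k + p];
        }
        Db[static_cast<size_t>(j) * m + i] =
            alpha * acc + beta * Cb[static_cast<size_t>(j) * m + i];
      }
    }
  }
}
```

Comparing a fused kernel's output against this reference on small inputs catches indexing and layout mistakes before any performance work.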

cydoroga, Jul 04 '22 10:07

/home/cydoroga/arc/arcadia/junk/cydoroga/dynamic_parallelism/test.cu(82): error: no instance of constructor "cutlass::gemm::kernel::Gemm<Mma_, Epilogue_, ThreadblockSwizzle_, SplitKSerial>::Params::Params

Why are there two Params at the end?

hwu36, Jul 05 '22 20:07

CUTLASS GemmUniversal supports both modes of batched GEMM (strided batched and pointer-array batched). See https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/device/gemm_universal.h

Set the mode here to kBatched or kArray (https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/gemm.h#L407-L408).
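To make the two modes concrete, the plain C++ sketch below (hypothetical helper names, not the CUTLASS API) shows how the matrix for batch entry b is located in each case: kBatched uses one base pointer plus a fixed batch stride, while kArray reads a separate pointer for each entry from an array. A pointer array can therefore reference matrices that are not evenly spaced in memory; whether that helps with a gather depends on whether the gather selects whole matrices rather than individual rows or columns.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Strided batched (kBatched-style): base pointer + batch_stride elements per entry.
const float* batch_ptr_strided(const float* base, std::int64_t batch_stride, int b) {
  return base + batch_stride * b;
}

// Pointer array (kArray-style): one independent pointer per batch entry.
const float* batch_ptr_array(const float* const* ptr_array, int b) {
  return ptr_array[b];
}

// Builds a pointer array equivalent to a strided layout. Any entry could
// instead point at an arbitrary matrix, which is the extra freedom kArray gives.
std::vector<const float*> make_ptr_array(const float* base,
                                         std::int64_t batch_stride,
                                         int batch_count) {
  assert(batch_count >= 0);
  std::vector<const float*> ptrs;
  ptrs.reserve(static_cast<std::size_t>(batch_count));
  for (int b = 0; b < batch_count; ++b) {
    ptrs.push_back(batch_ptr_strided(base, batch_stride, b));
  }
  return ptrs;
}
```

On the device, kArray means the pointer array itself must live in device memory, which is one more indirection per batch entry compared with the strided mode.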

This example can gather the inputs of a GEMM: it loads row or column indices from an array and fetches data only from those rows or columns. It only supports gathering rows from a row-major matrix or columns from a column-major matrix, so this gather functionality may not be exactly what you want.
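The supported cases line up with memory layout: a row of a row-major matrix, like a column of a column-major matrix, is one contiguous run of elements, whereas gathering rows of a column-major matrix touches one element per column with a stride of ld. The plain C++ sketch below (a hypothetical helper, not the linked example's code, whose internals we do not claim to reproduce) gathers whole columns of a column-major matrix with one contiguous copy per index.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Gathers the columns listed in `idx` from a column-major matrix `src`
// (rows x cols, leading dimension ld >= rows) into a dense column-major
// matrix of rows x idx.size(). Each selected column is a single contiguous
// block of `rows` floats, so one memcpy per index suffices.
std::vector<float> gather_columns_col_major(const std::vector<float>& src,
                                            int rows, int cols, int ld,
                                            const std::vector<int>& idx) {
  assert(ld >= rows);
  assert(src.size() >= static_cast<std::size_t>(ld) * cols);
  std::vector<float> out(static_cast<std::size_t>(rows) * idx.size());
  for (std::size_t j = 0; j < idx.size(); ++j) {
    assert(idx[j] >= 0 && idx[j] < cols);
    const float* col = src.data() + static_cast<std::size_t>(idx[j]) * ld;
    std::memcpy(out.data() + j * rows, col, sizeof(float) * rows);
  }
  return out;
}
```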

hwu36, Jul 05 '22 20:07

I think this issue can be closed: I found that the Arguments signature for ColumnMajor matrices differs from the one for RowMajor matrices. The change required is here: https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/device/gemm.h#L691. With it, I was able to use your kernel for my needs.
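As far as we can tell, the linked line is the partial specialization of device::Gemm for column-major output, which forwards to a row-major kernel with the problem transposed: it exchanges A and B and swaps m and n, using the identity D^T = B^T A^T. That is consistent with the Mma iterators in the compiler error using cutlass::layout::RowMajor. The plain C++ sketch below checks the identity with a hypothetical naive row-major GEMM; it is not CUTLASS code.

```cpp
#include <cassert>

// Naive row-major GEMM: D(M x N) = A(M x K) * B(K x N).
// Row-major: element (r, c) of a matrix with leading dimension ld is at r * ld + c.
void gemm_row_major(int M, int N, int K,
                    const float* A, int lda,
                    const float* B, int ldb,
                    float* D, int ldd) {
  for (int r = 0; r < M; ++r) {
    for (int c = 0; c < N; ++c) {
      float acc = 0.0f;
      for (int p = 0; p < K; ++p) acc += A[r * lda + p] * B[p * ldb + c];
      D[r * ldd + c] = acc;
    }
  }
}

// Column-major D(m x n) = A(m x k) * B(k x n) computed with the row-major routine.
// A column-major matrix is the row-major view of its transpose, so computing
// D^T (n x m) = B^T (n x k) * A^T (k x m) in row-major writes D in column-major.
// Note the swapped operands and the swapped m/n extents.
void gemm_col_major_via_row_major(int m, int n, int k,
                                  const float* A, int lda,
                                  const float* B, int ldb,
                                  float* D, int ldd) {
  assert(lda >= m && ldb >= k && ldd >= m);
  gemm_row_major(n, m, k, B, ldb, A, lda, D, ldd);
}
```

The leading dimensions are passed through unchanged; only the roles of the operands and of m and n change.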

Unfortunately, the performance is quite poor. That's why I tried your GemmArray kernel, with which I have other problems; I'll open a new issue for them.

cydoroga, Jul 18 '22 18:07

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Aug 17 '22 19:08