
[QST] Getting a template error trying to use cutlass's depthwise 2D convolution with pytorch

Open ahmadsharif1 opened this issue 1 year ago • 5 comments

My high-level goal is to use one of CUTLASS's 2D depthwise convolution kernels with PyTorch tensors.

I am starting with the SIMT kernel because it can work on older devices, so I am basically copying code from this example:

https://github.com/NVIDIA/cutlass/blob/affd1b693dfc121c51118cbc8583dfd308227ca6/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu#L80

My tensors have NCHW layout and use float, so I used the following snippet in the example:

// The code section below describes datatype for input, output tensors and computation between
// elements
using ElementAccumulator = float;      // Data type of accumulator
using ElementComputeEpilogue = float;  // Data type of epilogue computation (alpha, beta)
using ElementInputA = float;           // Data type of elements in input tensor
using ElementInputB = float;           // Data type of elements in input tensor
using ElementOutput = float;           // Data type of elements in output tensor

using LayoutInputA = cutlass::layout::TensorNCHW;
using LayoutInputB = cutlass::layout::TensorNCHW;
using LayoutOutput = cutlass::layout::TensorNCHW;

This throws an error with nvcc:

Building CUDA object examples/46_depthwise_simt_conv2dfprop/CMakeFiles/46_depthwise_simt_conv2dfprop.dir/depthwise_simt_conv2dfprop.cu.o
/home/ahmads/personal/cutlass/include/cutlass/conv/kernel/direct_convolution.h(95): error: incomplete type is not allowed
          detected during:
            instantiation of class "cutlass::conv::kernel::DirectConvolutionParams<Mma_, Epilogue_, ThreadblockSwizzle_, ConvOperator, Arguments_, ConvOutputIteratorParameter_, ConvProblemSize_, GroupMode_, ThreadBlockOutputShape_> [with Mma_=cutlass::conv::threadblock::DepthwiseFpropDirectConvMultipleStage<ThreadblockShape, cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation<cutlass::MatrixShape<64, 64>, ThreadBlockOutputShape, StrideShape, DilationShape, cutlass::conv::TensorNHWCShape<1, 10, 10, 64>, ElementInputA, LayoutInputA, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 100>, 128, 4>, cutlass::AlignedArray<ElementInputA, 4, 16>>, cutlass::transform::threadblock::RegularTileAccessIteratorDirectConv<cutlass::MatrixShape<100, 64>, ElementInputA, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 100>, 128, 4>, false, 16>, cutlass::arch::CacheOperation::Global, cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized<cutlass::MatrixShape<64, 9>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 9>, 128, 4>, cutlass::AlignedArray<ElementInputA, 4, 16>>, cutlass::transform::threadblock::RegularTileAccessIteratorDirectConv<cutlass::MatrixShape<9, 64>, ElementInputB, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 9>, 128, 4>, false, 16>, cutlass::arch::CacheOperation::Global, cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<cutlass::conv::warp::MmaDepthwiseDirectConvSimt<WarpShape, FilterShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::RowMajor, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, 
cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>, cutlass::conv::IteratorAlgorithm::kFixedStrideDilation, StrideShape, DilationShape, cutlass::conv::TensorNHWCShape<1, 10, 10, 64>, 1, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, bool>, cutlass::MatrixShape<0, 0>, cutlass::MatrixShape<0, 0>, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 100>, 128, 4>, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 9>, 128, 4>, 1>, 4, cutlass::epilogue::threadblock::EpilogueDepthwise<ThreadblockShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, cutlass::conv::warp::MmaDepthwiseDirectConvSimt<WarpShape, FilterShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::RowMajor, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>, cutlass::conv::IteratorAlgorithm::kFixedStrideDilation, StrideShape, DilationShape, cutlass::conv::TensorNHWCShape<1, 10, 10, 64>, 1, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, bool>, cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, 4>, ElementOutput, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape>, cutlass::epilogue::warp::FragmentIteratorSimt<WarpShape, cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<cutlass::gemm::GemmShape<16, 2, 1>, ElementInputA, ElementInputB, ElementAccumulator, cutlass::arch::OpMultiplyAdd, bool>, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>>, 
cutlass::epilogue::warp::TileIteratorSimtDirect2dConv<WarpShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<cutlass::gemm::GemmShape<16, 2, 1>, ElementInputA, ElementInputB, ElementAccumulator, cutlass::arch::OpMultiplyAdd, bool>, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>>, cutlass::epilogue::threadblock::SharedLoadIteratorPitchLinear<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, 4>, ElementAccumulator, 16>, EpilogueOp, cutlass::MatrixShape<0, 0>>, cutlass::conv::IteratorAlgorithm::kFixedStrideDilation, bool>, Epilogue_=cutlass::epilogue::threadblock::EpilogueDepthwise<ThreadblockShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, cutlass::conv::warp::MmaDepthwiseDirectConvSimt<WarpShape, FilterShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::RowMajor, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>, cutlass::conv::IteratorAlgorithm::kFixedStrideDilation, StrideShape, DilationShape, cutlass::conv::TensorNHWCShape<1, 10, 10, 64>, 1, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, bool>, cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, 4>, ElementOutput, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape>, cutlass::epilogue::warp::FragmentIteratorSimt<WarpShape, cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<cutlass::gemm::GemmShape<16, 2, 1>, ElementInputA, ElementInputB, 
ElementAccumulator, cutlass::arch::OpMultiplyAdd, bool>, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>>, cutlass::epilogue::warp::TileIteratorSimtDirect2dConv<WarpShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<cutlass::gemm::GemmShape<16, 2, 1>, ElementInputA, ElementInputB, ElementAccumulator, cutlass::arch::OpMultiplyAdd, bool>, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>>, cutlass::epilogue::threadblock::SharedLoadIteratorPitchLinear<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, 4>, ElementAccumulator, 16>, EpilogueOp, cutlass::MatrixShape<0, 0>>, ThreadblockSwizzle_=SwizzleThreadBlock, ConvOperator=cutlass::conv::Operator::kFprop, Arguments_=cutlass::conv::kernel::DirectConvolution<cutlass::conv::threadblock::DepthwiseFpropDirectConvMultipleStage<ThreadblockShape, cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation<cutlass::MatrixShape<64, 64>, ThreadBlockOutputShape, StrideShape, DilationShape, cutlass::conv::TensorNHWCShape<1, 10, 10, 64>, ElementInputA, LayoutInputA, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 100>, 128, 4>, cutlass::AlignedArray<ElementInputA, 4, 16>>, cutlass::transform::threadblock::RegularTileAccessIteratorDirectConv<cutlass::MatrixShape<100, 64>, ElementInputA, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 100>, 128, 4>, false, 16>, cutlass::arch::CacheOperation::Global, cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized<cutlass::MatrixShape<64, 9>, 
ElementInputB, LayoutInputB, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 9>, 128, 4>, cutlass::AlignedArray<ElementInputA, 4, 16>>, cutlass::transform::threadblock::RegularTileAccessIteratorDirectConv<cutlass::MatrixShape<9, 64>, ElementInputB, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 9>, 128, 4>, false, 16>, cutlass::arch::CacheOperation::Global, cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<cutlass::conv::warp::MmaDepthwiseDirectConvSimt<WarpShape, FilterShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::RowMajor, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>, cutlass::conv::IteratorAlgorithm::kFixedStrideDilation, StrideShape, DilationShape, cutlass::conv::TensorNHWCShape<1, 10, 10, 64>, 1, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, bool>, cutlass::MatrixShape<0, 0>, cutlass::MatrixShape<0, 0>, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 100>, 128, 4>, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 9>, 128, 4>, 1>, 4, cutlass::epilogue::threadblock::EpilogueDepthwise<ThreadblockShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, cutlass::conv::warp::MmaDepthwiseDirectConvSimt<WarpShape, FilterShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::RowMajor, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>, cutlass::conv::IteratorAlgorithm::kFixedStrideDilation, 
StrideShape, DilationShape, cutlass::conv::TensorNHWCShape<1, 10, 10, 64>, 1, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, bool>, cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, 4>, ElementOutput, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape>, cutlass::epilogue::warp::FragmentIteratorSimt<WarpShape, cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<cutlass::gemm::GemmShape<16, 2, 1>, ElementInputA, ElementInputB, ElementAccumulator, cutlass::arch::OpMultiplyAdd, bool>, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>>, cutlass::epilogue::warp::TileIteratorSimtDirect2dConv<WarpShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<cutlass::gemm::GemmShape<16, 2, 1>, ElementInputA, ElementInputB, ElementAccumulator, cutlass::arch::OpMultiplyAdd, bool>, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>>, cutlass::epilogue::threadblock::SharedLoadIteratorPitchLinear<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, 4>, ElementAccumulator, 16>, EpilogueOp, cutlass::MatrixShape<0, 0>>, cutlass::conv::IteratorAlgorithm::kFixedStrideDilation, bool>, cutlass::epilogue::threadblock::EpilogueDepthwise<ThreadblockShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, cutlass::conv::warp::MmaDepthwiseDirectConvSimt<WarpShape, FilterShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::RowMajor, ElementAccumulator, 
cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>, cutlass::conv::IteratorAlgorithm::kFixedStrideDilation, StrideShape, DilationShape, cutlass::conv::TensorNHWCShape<1, 10, 10, 64>, 1, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, bool>, cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, 4>, ElementOutput, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape>, cutlass::epilogue::warp::FragmentIteratorSimt<WarpShape, cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<cutlass::gemm::GemmShape<16, 2, 1>, ElementInputA, ElementInputB, ElementAccumulator, cutlass::arch::OpMultiplyAdd, bool>, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>>, cutlass::epilogue::warp::TileIteratorSimtDirect2dConv<WarpShape, cutlass::conv::TensorNHWCShape<1, 4, 4, 2>, ThreadBlockOutputShape, cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct<cutlass::gemm::GemmShape<16, 2, 1>, ElementInputA, ElementInputB, ElementAccumulator, cutlass::arch::OpMultiplyAdd, bool>, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaSimtPolicy<cutlass::MatrixShape<1, 32>, cutlass::layout::RowMajorInterleaved<1>, cutlass::gemm::GemmShape<2, 2, 1>>>, cutlass::epilogue::threadblock::SharedLoadIteratorPitchLinear<cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<64, 64>, 128, 4>, ElementAccumulator, 16>, EpilogueOp, cutlass::MatrixShape<0, 0>>, SwizzleThreadBlock, cutlass::conv::Operator::kFprop, cutlass::conv::Conv2dProblemSize, cutlass::conv::GroupMode::kDepthwise, ThreadBlockOutputShape>::Arguments, 
ConvOutputIteratorParameter_=cutlass::epilogue::threadblock::ConvOutputIteratorParameter<LayoutInputA, cutlass::layout::RowMajor, cutlass::TensorRef<ElementInputA, LayoutInputA>, cutlass::conv::Operator::kFprop, cutlass::conv::Conv2dProblemSize>, ConvProblemSize_=cutlass::conv::Conv2dProblemSize, GroupMode_=cutlass::conv::GroupMode::kDepthwise, ThreadBlockOutputShape_=ThreadBlockOutputShape]" 
/home/ahmads/personal/cutlass/include/cutlass/conv/device/direct_convolution.h(96): here
            instantiation of class "cutlass::conv::device::DirectConvolution<DirectConvolutionKernel_> [with DirectConvolutionKernel_=DepthwiseDirect2dConv]" 
/home/ahmads/personal/cutlass/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu(481): here

/home/ahmads/personal/cutlass/include/cutlass/conv/kernel/direct_convolution.h(97): error: incomplete type is not allowed
          detected during: (identical instantiation trace to the first error above, omitted)

2 errors detected in the compilation of "/home/ahmads/personal/cutlass/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu".
make[3]: *** [examples/46_depthwise_simt_conv2dfprop/CMakeFiles/46_depthwise_simt_conv2dfprop.dir/build.make:77: examples/46_depthwise_simt_conv2dfprop/CMakeFiles/46_depthwise_simt_conv2dfprop.dir/depthwise_simt_conv2dfprop.cu.o] Error 2
make[2]: *** [CMakeFiles/Makefile2:9244: examples/46_depthwise_simt_conv2dfprop/CMakeFiles/46_depthwise_simt_conv2dfprop.dir/all] Error 2
make[1]: *** [CMakeFiles/Makefile2:9251: examples/46_depthwise_simt_conv2dfprop/CMakeFiles/46_depthwise_simt_conv2dfprop.dir/rule] Error 2
make: *** [Makefile:3177: 46_depthwise_simt_conv2dfprop] Error 2

What is your question?

Questions:

  1. How do I fix this error? Are convolutions on NCHW tensors even supported by CUTLASS?
  2. Basic question: how do I convert a PyTorch tensor into a cutlass::TensorRef? I am currently using this snippet of code:
    auto makeDeviceRef = [](const Tensor &tensor) {
      auto tensorLayout = cutlass::layout::TensorNCHW::packed({tensor.stride(0), tensor.stride(1), tensor.stride(2), tensor.stride(3)});
      auto ret = cutlass::make_TensorRef(tensor.data_ptr<scalar_t*>(), tensorLayout);
      return ret;
    };

    Does that look correct?
  3. Can CUTLASS handle convolutions with variable filter sizes that are not known at compile time? It seems like the templates need to be instantiated with concrete numbers, so CUTLASS only works for filter sizes known at compile time.
  4. How can I choose which CUDA stream the kernel runs on?
  5. Is it possible to build the kernel for all SM architectures and dynamically choose the latest one supported by the machine I am running on? It seems like that's not possible because the arch has to be passed to CUTLASS at compile time, via DefaultDepthwiseDirect2dConvFprop.
  6. How do I choose meta-parameters like pipeline stages, groups per CTA, thread block output shape, etc. in an automated way? I am afraid I may pick values that slow down my kernel. Do I have to tune everything by hand, or is there a way to choose these based on the GPU I am using? If these are compile-time constants, will I need a lookup table per GPU?
  7. A question specific to 2D depthwise convolution: what is the tensor_b_transpose variable (https://github.com/NVIDIA/cutlass/blob/affd1b693dfc121c51118cbc8583dfd308227ca6/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu#L424)? In the code it does not seem to be derived from tensor_b at all. I thought it would be the transpose of tensor_b, but it does not seem to be initialized in the example. Is it just a temporary tensor that CUTLASS fills itself? Why is it required, and why can't CUTLASS allocate it by itself?

ahmadsharif1, Feb 04 '25 18:02

@Junkai-Wu

hwu36, Feb 06 '25 03:02

@Ethan-Yan27 could you help answer some of the questions here? Thanks!

Junkai-Wu, Feb 06 '25 09:02

Q1.

Depthwise convolution only supports NHWC.
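Since only NHWC is supported, the PyTorch tensor has to be stored channels-last in memory, for example via `tensor.contiguous(at::MemoryFormat::ChannelsLast)`, and that tensor must stay alive while the kernel runs. On the snippet in question 2: `cutlass::layout::TensorNCHW::packed` (like `TensorNHWC::packed`) takes the tensor extent, not its strides, and `data_ptr` is templated on the element type (`data_ptr<scalar_t>()`, not `data_ptr<scalar_t*>()`; it already returns a pointer). The standalone sketch below uses no CUTLASS or PyTorch headers. It assumes that CUTLASS's packed `TensorNHWC` keeps strides `{C, W*C, H*W*C}` for the w, h and n dimensions (c has unit stride) and checks that a channels_last tensor's PyTorch strides line up with that. All helper names are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Tensor sizes in PyTorch's logical NCHW index order.
struct SizesNCHW {
  std::int64_t n, c, h, w;
};

// Strides PyTorch reports (in NCHW index order) for a tensor that is
// contiguous in the channels_last memory format.
std::array<std::int64_t, 4> channels_last_strides(const SizesNCHW& s) {
  return {s.h * s.w * s.c, 1, s.w * s.c, s.c};
}

// Packed NHWC strides for the (w, h, n) dimensions; c has unit stride.
// ASSUMPTION: mirrors the stride vector of cutlass::layout::TensorNHWC.
std::array<std::int64_t, 3> packed_nhwc_strides(const SizesNCHW& s) {
  return {s.c, s.w * s.c, s.h * s.w * s.c};
}

// True if a tensor with these NCHW-ordered PyTorch strides can be handed to
// an NHWC kernel as-is, without a copy.
bool is_packed_nhwc(const SizesNCHW& s,
                    const std::array<std::int64_t, 4>& stride) {
  const auto p = packed_nhwc_strides(s);
  return stride[1] == 1 && stride[3] == p[0] && stride[2] == p[1] &&
         stride[0] == p[2];
}

// Linear offset of logical element (n, c, h, w) in packed NHWC memory.
std::int64_t nhwc_offset(const SizesNCHW& s, std::int64_t n, std::int64_t c,
                         std::int64_t h, std::int64_t w) {
  assert(n < s.n && c < s.c && h < s.h && w < s.w);
  const auto p = packed_nhwc_strides(s);
  return c + w * p[0] + h * p[1] + n * p[2];
}
```

In CUTLASS terms, the extent passed to `TensorNHWC::packed` would be `cutlass::Tensor4DCoord(N, H, W, C)`, that is, PyTorch's `size(0), size(2), size(3), size(1)`. A default (NCHW-contiguous) PyTorch tensor fails the `is_packed_nhwc` check and needs the channels-last copy first.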

Q3.

Yes, the filter size needs to be a compile-time value because some computations inside the kernel depend on it.
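Because the filter size is a template parameter (the compile-time `FilterShape` visible in the error log above), one workable pattern, sketched here and not a CUTLASS API, is to instantiate the kernel for the few filter sizes you need and select one at runtime. `run_depthwise_fixed<R, S>` is a hypothetical stand-in for code that would build and launch a depthwise kernel instantiated with an R x S filter; here it only returns R * S so the dispatch logic can be tested without a GPU.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for a function that would instantiate and launch the
// depthwise kernel with a compile-time R x S filter. Here it only returns the
// number of filter taps so the dispatch can be checked in isolation.
template <int R, int S>
int run_depthwise_fixed() {
  static_assert(R > 0 && S > 0, "filter extents must be positive");
  return R * S;
}

// Maps a runtime filter size onto one of the pre-instantiated kernels.
// Returns std::nullopt for sizes that were not compiled in, so the caller can
// fall back to another implementation (e.g. PyTorch's own convolution).
std::optional<int> dispatch_depthwise(int r, int s) {
  if (r == 3 && s == 3) return run_depthwise_fixed<3, 3>();
  if (r == 5 && s == 5) return run_depthwise_fixed<5, 5>();
  if (r == 7 && s == 7) return run_depthwise_fixed<7, 7>();
  if (r == 1 && s == 7) return run_depthwise_fixed<1, 7>();
  return std::nullopt;
}
```

Each extra filter size adds another full kernel instantiation, so compile time and binary size grow with the list.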

Q4.

I think the default CUDA stream should be okay.

Q5.

The depthwise conv kernel uses cp_async_fill, so please build for sm80 or newer.
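On question 5: nvcc can embed code for several architectures in one binary (one `-gencode` flag per target), but the arch tag a CUTLASS kernel is instantiated with is still fixed at compile time. A common pattern, sketched below with hypothetical names, is to instantiate one kernel per arch tag you care about and pick among them at runtime from the device's compute capability (in real code, the `major` and `minor` fields of the `cudaDeviceProp` filled by `cudaGetDeviceProperties`). Per the answer above, devices older than sm80 have no depthwise variant and need a fallback. The `kSm90` variant is illustrative only.

```cpp
#include <cassert>

// Which pre-compiled variant to run. These names are hypothetical; each would
// correspond to a separately instantiated kernel in real code.
enum class DepthwiseVariant { kUnsupported, kSm80, kSm90 };

// Encode compute capability the way nvcc arch names do (8.6 -> 86).
constexpr int compute_capability(int major, int minor) {
  return major * 10 + minor;
}

// Picks the newest compiled variant the device can run. The depthwise kernel
// needs sm80+, so older GPUs get kUnsupported and the caller should use a
// different convolution implementation.
DepthwiseVariant select_variant(int major, int minor) {
  const int cc = compute_capability(major, minor);
  if (cc >= 90) return DepthwiseVariant::kSm90;
  if (cc >= 80) return DepthwiseVariant::kSm80;
  return DepthwiseVariant::kUnsupported;
}
```

Because PyTorch processes may use several GPUs, the selection is best done per device rather than once per process.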

Q6.

Here is an explanation of the depthwise direct-conv implementation and how to set its template parameters: https://github.com/NVIDIA/cutlass/issues/1133#issuecomment-1756668121. Several notes:

  1. The number of stages, groups per CTA, and thread block output shape all affect shared memory (SMEM) usage, which must not exceed the hardware SMEM limit.
  2. Don't forget to use the split-K feature to maximize SM utilization for better performance: https://github.com/NVIDIA/cutlass/issues/1213#issuecomment-1833431916
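As a rough illustration of note 1: to produce `P` output rows with an `R`-tap filter at stride `u` and dilation `d`, a CTA needs `(P - 1) * u + (R - 1) * d + 1` input rows (likewise for columns). An 8x8 output tile with a 3x3 filter at stride 1 and dilation 1 needs 10x10 input pixels, which is consistent with the `TensorNHWCShape<1, 10, 10, 64>` activation tile in the error log above. The sketch turns this into a back-of-the-envelope SMEM estimate under a simplifying assumption (not taken from the CUTLASS source): each pipeline stage buffers one activation tile and one filter tile for `channels_per_cta` channels. The real kernel's shared storage (epilogue buffers, padding) may differ, so treat the result as a sanity check against the device's `sharedMemPerBlockOptin`, not as the kernel's exact footprint.

```cpp
#include <cassert>
#include <cstddef>

// Number of input rows (or columns) needed to produce `out` output rows with
// an `r`-tap filter at the given stride and dilation.
constexpr int input_extent(int out, int r, int stride, int dilation) {
  return (out - 1) * stride + (r - 1) * dilation + 1;
}

struct DepthwiseTileConfig {
  int out_h, out_w;       // thread block output shape (P x Q)
  int filter_r, filter_s; // filter extent
  int stride, dilation;   // assumed equal in both spatial dimensions
  int channels_per_cta;   // "groups per CTA" for depthwise conv
  int stages;             // pipeline stages
  int element_bytes;      // e.g. 4 for float
};

// Rough SMEM estimate, ASSUMING every stage holds one activation tile and one
// filter tile. Ignores epilogue storage and any padding the real kernel uses.
std::size_t estimate_smem_bytes(const DepthwiseTileConfig& c) {
  const int in_h = input_extent(c.out_h, c.filter_r, c.stride, c.dilation);
  const int in_w = input_extent(c.out_w, c.filter_s, c.stride, c.dilation);
  const std::size_t activation =
      static_cast<std::size_t>(in_h) * in_w * c.channels_per_cta;
  const std::size_t filter =
      static_cast<std::size_t>(c.filter_r) * c.filter_s * c.channels_per_cta;
  return (activation + filter) * c.element_bytes * c.stages;
}
```

For the illustrative values in the test (8x8 output tile, 3x3 filter, stride 1, 64 channels, 4 stages, float), the estimate is 111,616 bytes (109 KiB).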

Q7.

The kernel uses this temporary buffer to reorder the filter tensor: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/kernel/direct_convolution.h#L171
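To make the filter reorder more concrete, here is a CPU sketch of one plausible reordering. It is an assumption, not taken from the CUTLASS source: the depthwise filter arrives as `K x R x S` values (one `R x S` filter per channel) and is rearranged so that, for each filter tap, all channels are contiguous. That would be consistent with the row-major `MatrixShape<9, 64>` filter tile in the error log above (9 taps by 64 channels), but check the linked `direct_convolution.h` for the actual transform. The function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// ASSUMED input layout: filter[k][r][s] for k in [0, K), one filter per
// channel in depthwise conv. Output layout: reordered[r * S + s][k], so all
// channels for a given filter tap are contiguous.
std::vector<float> reorder_depthwise_filter(const std::vector<float>& filter,
                                            int K, int R, int S) {
  assert(filter.size() == static_cast<std::size_t>(K) * R * S);
  std::vector<float> reordered(filter.size());
  for (int k = 0; k < K; ++k) {
    for (int tap = 0; tap < R * S; ++tap) {
      // Source offset k*R*S + (r*S + s); destination offset (r*S + s)*K + k.
      reordered[static_cast<std::size_t>(tap) * K + k] =
          filter[static_cast<std::size_t>(k) * R * S + tap];
    }
  }
  return reordered;
}
```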

Ethan-Yan27, Feb 06 '25 14:02

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Mar 08 '25 14:03

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot], Jun 06 '25 14:06