[BUG] calling cast_smem_ptr_to_uint(device fn) from make_gmma_desc(host device fn) is not allowed

Open lygztq opened this issue 1 year ago • 4 comments

Describe the bug: nvcc reports an error like

error: calling a __device__ function("cute::cast_smem_ptr_to_uint(const void *)") from a __host__ __device__ function("make_gmma_desc") is not allowed
    uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data()));
                             ^
          detected during instantiation of "cute::GmmaDescriptor cute::GMMA::make_gmma_desc<MajorMode,TEngine,TLayout>(const cute::Tensor<TEngine, TLayout> &) [with MajorMode=cute::GMMA::Major::K, TEngine=cute::ViewEngine<cute::swizzle_ptr<cute::Swizzle<3, 4, 3>, cute::smem_ptr<cutlass::float_e4m3_t *>>>, TLayout=cute::Layout<cute::tuple<cute::C<64>, cute::C<32>>, cute::tuple<cute::_128, cute::_1>>]"

when I try to make a partition fragment of an smem tensor with a WGMMA tiled MMA.

Steps/Code to reproduce the bug:

#include "cutlass/float8.h"

#include "cute/layout.hpp"
#include "cute/pointer.hpp"
#include "cute/tensor.hpp"

#include "cute/swizzle_layout.hpp"
#include "cute/underscore.hpp"

#include "cute/pointer_flagged.hpp"

#include "cute/arch/copy.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/atom/copy_traits_sm90_tma_swizzle.hpp"

#include "cute/arch/mma_sm90_gmma.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/mma_traits_sm90_gmma.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/gemm.h"

#include "cute/arch/mma_sm90.hpp"

#include "cutlass/numeric_conversion.h"
#include <cstdint>

__global__ void test_kernel() {
  using namespace cute;
  constexpr int M = 256;
  constexpr int N = 16;
  constexpr int K = 128;

  using Element = cutlass::float_e4m3_t;
  using AccumElement = float;
  using TileShape_MNK = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;
  using GmmaTileShape = cute::Layout<cute::Shape<cute::Int<M / 64>, cute::_1, cute::_1>>;
  using TiledGmma0 = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<Element, Element, AccumElement,
                                 cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>>(),
      GmmaTileShape{}));
  using SmemLayoutAtomA =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<
               cute::GMMA::Major::K, Element, decltype(cute::get<0>(TileShape_MNK{})),
               decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutA =
      decltype(cute::tile_to_shape(SmemLayoutAtomA{}, cute::select<0, 2>(TileShape_MNK{})));

  __shared__ uint8_t smem_a_bytes[size(select<0, 2>(TileShape_MNK{}))];

  auto tiled_mma0 = TiledGmma0{};
  auto thread_mma0 = tiled_mma0.get_thread_slice(threadIdx.x);

  auto sA = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_a_bytes)), SmemLayoutA{});
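  // Per the error message above, this partition_fragment_A call is where make_gmma_desc
  // gets instantiated for the swizzled smem tensor.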
  auto tOrA = thread_mma0.partition_fragment_A(sA);
}

int main() {
  dim3 grid(1);
  dim3 block(4 * 32);
  test_kernel<<<grid, block>>>();
  cudaDeviceSynchronize();
  return 0;
}

Expected behavior: I don't think this invocation (calling a __device__ function from a __host__ __device__ function) is one nvcc should accept. However, I notice that in some cases nvcc does accept the same invocation; why is that?
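
For reference, a minimal standalone sketch (hypothetical names, not CUTLASS code) of the execution-space rule behind this error: nvcc rejects a call to a __device__-only function from a __host__ __device__ function because the body must also compile for the host, and for templates the diagnostic is typically deferred until the template is actually instantiated, which may be why the same call sometimes appears to be accepted.

__device__ int dev_only(int x) { return x + 1; }

// Non-template case: typically rejected outright during the host compilation pass.
__host__ __device__ int hd_plain(int x) {
  return dev_only(x);  // error: calling a __device__ function from a __host__ __device__ function
}

// Template case: the check is usually deferred to instantiation, so this compiles
// until hd_templated<...> is actually used (as happens via make_gmma_desc above).
template <class T>
__host__ __device__ T hd_templated(T x) {
  return dev_only(x);
}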

Environment details: CUDA 12.4, CUTLASS 3.4

lygztq · Dec 18 '24 09:12

You're most likely using your own command line. Use the command-line flags generated by our CMake.

thakkarV · Dec 18 '24 09:12

Here is my compile command

/usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler  -I/soft/3rdparty/cutlass/include -I/soft/3rdparty/cutlass/tools/util/include -I/soft/3rdparty/cutlass/examples/common --generate-code=arch=compute_90a,code=[compute_90a,sm_90a] --use_fast_math --forward-unknown-to-host-compiler --expt-extended-lambda --expt-relaxed-constexpr --generate-line-info -Xcompiler=-fPIE -Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing -O3 -std=c++17 -MD -MT case_study/CMakeFiles/foo.dir/foo.cu.o -MF CMakeFiles/foo.dir/foo.cu.o.d -x cu -c /soft/case_study/foo.cu -o CMakeFiles/foo.dir/foo.cu.o

If you mean the missing --expt-relaxed-constexpr flag for calling a __host__ function from a __host__ __device__ function, I already have it in my command. Could you tell me which flag I can use to avoid this error?
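
For context, --expt-relaxed-constexpr only relaxes calls involving constexpr functions across the host/device boundary; it does not make a plain __device__ function callable from code that must also compile for the host, which is the situation in this error. A minimal hypothetical sketch (not CUTLASS code) of the difference:

constexpr int twice(int x) { return 2 * x; }      // implicitly a __host__ function

__device__ int dev_only(int x) { return x + 1; }

__global__ void kernel(int* out) {
  // Calling a constexpr __host__ function from device code is what the flag permits.
  out[0] = twice(static_cast<int>(threadIdx.x));
}

__host__ __device__ int hd(int x) {
  // Still an error: --expt-relaxed-constexpr does not apply to non-constexpr __device__ functions.
  return dev_only(x);
}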

lygztq · Dec 18 '24 10:12

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] · Jan 17 '25 11:01

@Junkai-Wu

hwu36 · Mar 04 '25 03:03

Any update on this? I get the same error and I cannot find any flag that fixes it. Thanks for any help!

monellz · Mar 13 '25 07:03

@thakkarV I followed the suggestion and copied the code provided by @lygztq into the examples directory to try compiling it, but I still get errors.

The steps I followed are:

# copy the code into examples/00_basic_gemm, name it 'mma_test.cu', and then modify the target in CMakeLists.txt:
#
# cutlass_example_add_executable(
#  00_basic_gemm
#  mma_test.cu
#)

mkdir build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCMAKE_BUILD_TYPE=Release
cmake --build . --target test_examples_00_basic_gemm --verbose

The error is:

Building CUDA object examples/00_basic_gemm/CMakeFiles/00_basic_gemm.dir/mma_test.cu.o
cd /home/zhongrx/gpu/nvgpu_microbench/3rd/cutlass/build/examples/00_basic_gemm && /home/spack/spack/opt/spack/linux-debian12-sapphirerapids/gcc-12.2.0/cuda-12.4.1-f3kmmeb5h7wldnd25td433vh2bg765wp/bin/nvcc -forward-unknown-to-host-compiler  --options-file CMakeFiles/00_basic_gemm.dir/includes_CUDA.rsp -DCUTLASS_VERSIONS_GENERATED -O3 -DNDEBUG -std=c++17 "--generate-code=arch=compute_90a,code=[sm_90a]" "--generate-code=arch=compute_90a,code=[compute_90a]" -Xcompiler=-fPIE -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 -DCUTLASS_ENABLE_GDC_FOR_SM100=1 --expt-relaxed-constexpr -DCUTLASS_TEST_LEVEL=0 -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1 -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1 -DCUTLASS_DEBUG_TRACE_LEVEL=0 -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing -MD -MT examples/00_basic_gemm/CMakeFiles/00_basic_gemm.dir/mma_test.cu.o -MF CMakeFiles/00_basic_gemm.dir/mma_test.cu.o.d -x cu -c /home/zhongrx/gpu/nvgpu_microbench/3rd/cutlass/examples/00_basic_gemm/mma_test.cu -o CMakeFiles/00_basic_gemm.dir/mma_test.cu.o
/home/zhongrx/gpu/nvgpu_microbench/3rd/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp(215): error: calling a __device__ function("cute::cast_smem_ptr_to_uint(const void *)") from a __host__ __device__ function("make_gmma_desc") is not allowed
    uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data()));
                             ^
          detected during instantiation of "cute::GmmaDescriptor cute::SM90::GMMA::make_gmma_desc<MajorMode,TEngine,TLayout>(const cute::Tensor<TEngine, TLayout> &) [with MajorMode=cute::SM90::GMMA::Major::K, TEngine=cute::ViewEngine<cute::swizzle_ptr<cute::Swizzle<3, 4, 3>, cute::smem_ptr<cutlass::float_e4m3_t *>>>, TLayout=cute::Layout<cute::tuple<cute::C<64>, cute::C<32>>, cute::tuple<cute::_128, cute::_1>>]" at line 56 of /home/zhongrx/gpu/nvgpu_microbench/3rd/cutlass/examples/00_basic_gemm/mma_test.cu

1 error detected in the compilation of "/home/zhongrx/gpu/nvgpu_microbench/3rd/cutlass/examples/00_basic_gemm/mma_test.cu".

Version information:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0


cutlass commit: 06e560d98a5fe8acb975db2c4c26817b6c90acb1

monellz · Mar 13 '25 10:03

We should add a #if defined(__CUDA_ARCH__) guard around the cast_smem_ptr_to_uint function to prevent it from being called from host code. I'll file a PR for this.

Junkai-Wu · Mar 14 '25 01:03

I disagree with guarding functions with __CUDA_ARCH__. There is a lot of utility in being able to prototype on host, and preventing a function like this from being used at all would also prevent creating a GMMA descriptor on host in this case, and a whole lot of functionality in other cases.

Let's make it HOST_DEVICE as suggested and guard the implementation with __CUDA_ARCH__ instead (the implementation is already guarded).
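
A rough sketch of that pattern (simplified, with a hypothetical body; not the actual CUTE source, which has its own macros and implementation): the function is annotated for both host and device so host-side prototyping still compiles, while the device-only conversion is reached only under __CUDA_ARCH__.

#include <cstdint>

__host__ __device__ inline uint32_t cast_smem_ptr_to_uint_sketch(void const* ptr) {
#if defined(__CUDA_ARCH__)
  // Device path: convert the generic shared-memory pointer to its 32-bit shared-window address.
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
#else
  // Host path: there is no shared memory to address; return a placeholder so host code still compiles.
  return 0u;
#endif
}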

ccecka · Mar 14 '25 01:03

Got it. I'll make all functions calling cast_smem_ptr_to_uint HOST_DEVICE.

Junkai-Wu · Mar 14 '25 01:03

I tried adding CUTE_HOST_DEVICE to cast_smem_ptr_to_uint and it works properly (see PR #2171).

monellz · Mar 14 '25 04:03

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] · Apr 13 '25 05:04

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot] · Jul 12 '25 06:07

This issue is fixed. Closing it.

Junkai-Wu · Jul 24 '25 02:07