vllm

[Kernel] (1/N) Machete - Hopper Optimized Mixed Precision Linear Kernel

Open · LucasWilkinson opened this issue 1 year ago • 1 comment

Description

This PR introduces a spiritual successor to the Marlin kernel, optimized for the Hopper architecture and built on CUTLASS. It adds only the kernel implementation, unit tests, and benchmarking scripts; end-to-end integration will come in a future PR.
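For readers new to mixed-precision (e.g. w4a16) linear layers, the computation such a kernel accelerates can be written as a slow reference: dequantize low-bit weights with per-group scales, then run an ordinary matmul with a float accumulator. The sketch below is a minimal pure-Python illustration, assuming the `u4b8` convention (unsigned 4-bit storage with a bias of 8, so a stored q means q - 8) and a simple 8-nibbles-per-32-bit-word packing along K. All names are hypothetical, and Machete's actual prepacked weight layout is different.

```python
# Reference semantics of a w4a16-style mixed-precision linear layer.
# ASSUMPTIONS (not from the PR): "u4b8" weights (stored q in [0, 15] means
# q - 8), packed 8 nibbles per 32-bit word along K, one scale per group of
# `group_size` K indices. Machete's real prepacked layout differs.


def pack_u4(values):
    """Pack ints in [0, 15] into 32-bit words, 8 nibbles per word."""
    assert len(values) % 8 == 0
    words = []
    for i in range(0, len(values), 8):
        word = 0
        for j, v in enumerate(values[i:i + 8]):
            assert 0 <= v <= 15
            word |= v << (4 * j)  # nibble j holds element i + j
        words.append(word)
    return words


def unpack_u4(words, n):
    """Inverse of pack_u4: return the first n 4-bit values."""
    out = []
    for word in words:
        for j in range(8):
            out.append((word >> (4 * j)) & 0xF)
    return out[:n]


def dequant_u4b8(q, scale):
    """Remove the bias of 8, then apply the group scale."""
    return (q - 8) * scale


def mixed_precision_linear(a, packed_cols, scales, k, group_size):
    """Slow reference for y = a @ dequant(W).

    a:           M x K activations (list of rows)
    packed_cols: N weight columns, each packed along K with pack_u4
    scales:      scales[g][n] for group g = k_index // group_size
    """
    m, n = len(a), len(packed_cols)
    out = [[0.0] * n for _ in range(m)]
    for col in range(n):
        w = unpack_u4(packed_cols[col], k)
        for row in range(m):
            acc = 0.0  # plays the role of the float accumulator
            for kk in range(k):
                scale = scales[kk // group_size][col]
                acc += a[row][kk] * dequant_u4b8(w[kk], scale)
            out[row][col] = acc
    return out


packed = pack_u4([8, 9, 10, 7, 8, 8, 8, 8])  # represents [0, 1, 2, -1, 0, 0, 0, 0]
y = mixed_precision_linear([[1, 2, 3, 4, 5, 6, 7, 8]], [packed], [[0.5]], k=8, group_size=8)
# y == [[2.0]]  (2*0.5 + 3*1.0 - 4*0.5)
```

A kernel like this one fuses the unpack/dequantize step into the GEMM mainloop instead of materializing the full-precision weights, which is where the specialized weight storage layout discussed below comes in.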

Motivation

The motivation for this kernel is threefold:

  1. Marlin (v1) uses mma instructions, which are the fastest tensor core instructions available on Ampere. With Hopper, NVIDIA released a new set of wgmma instructions that are required to reach the peak FLOPs NVIDIA reports; without them (i.e. using mma instructions) you can expect to achieve at best ~75% of peak [1, 2]
  2. Marlin (v1) uses a weight storage layout specialized for the mma instructions. We want a more flexible way of defining these layouts so we can accommodate new instructions more rapidly, e.g. wgmma and any new instructions Blackwell introduces
    • MarlinV2 (since renamed Machete) achieves this by describing the weight storage scheme using CUTLASS and CuTe
  3. Marlin (v1) does not support CUTLASS epilogues. We eventually plan to investigate sub-byte weight quantization combined with activation quantization; for activation quantization we'd like to leverage the great work done by @tlrmchlsmth, @varun-sundar-rabindranath, and @ProExpertProg in writing custom CUTLASS epilogues for fp8 and int8
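Point 2 rests on CuTe's notion of a layout: a shape paired with a stride, which maps a logical coordinate to a storage offset via the inner product of coordinate and stride. Code that addresses data by logical coordinate does not change when only the layout does. The following is a simplified, flat (non-hierarchical) Python model of that idea with hypothetical function names; it is not CuTe's API and not Machete's actual prepacked weight layout.

```python
# Flat model of a CuTe-style layout: offset = sum(coord[i] * stride[i]).
# Real CuTe layouts are hierarchical (nested shapes/strides); omitted here.


def make_layout(shape, stride):
    """Return a function mapping a coordinate tuple to a storage offset."""
    assert len(shape) == len(stride)

    def offset(coord):
        assert all(0 <= c < s for c, s in zip(coord, shape))
        return sum(c * d for c, d in zip(coord, stride))

    return offset


def idx2crd(i, shape):
    """Colexicographic (first mode fastest) linear index -> coordinate."""
    coord = []
    for s in shape:
        coord.append(i % s)
        i //= s
    return tuple(coord)


def storage_order(layout, shape):
    """List logical coordinates in the order they sit in memory."""
    size = 1
    for s in shape:
        size *= s
    order = [None] * size
    for i in range(size):
        coord = idx2crd(i, shape)
        order[layout(coord)] = coord
    return order


# One 4 x 8 (K x N) weight tile, two storage schemes. Code that reads
# W[k, n] through `layout((k, n))` stays the same when the layout changes.
K, N = 4, 8
k_major = make_layout((K, N), (1, K))  # consecutive k values are adjacent
n_major = make_layout((K, N), (N, 1))  # consecutive n values are adjacent
```

Under this model, targeting a new instruction's preferred operand arrangement means changing the shape/stride description rather than the code that consumes the weights, which is the flexibility point 2 asks for.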

TODO:

  • [x] Choose a new name (candidates: wahoo, swordfish (kinda cutlass + marlin), non-fish names ...). Edit: chose machete
  • [x] Improve the heuristic, namely for 4096x4096: resolved by moving the heuristic into the C++ code

Future PRs

  • [ ] Improve BFloat16 performance (via bit shift or interleaving)
  • [ ] E2E integration
  • [ ] Improve batch size < 32 performance (potentially a future PR, likely through improving the stream-k scheduler)
  • [ ] Investigate fp8 activation support
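The stream-K scheduler mentioned in the batch-size bullet targets problems that can produce too few output tiles to keep every SM busy: instead of assigning whole output tiles to CTAs, it flattens the (tile, k-iteration) work space and gives each persistent worker an equal contiguous share, so one worker may finish part of a tile and start the next. The sketch below shows only that basic partitioning step, with hypothetical names; it is not CUTLASS's scheduler, and the reduction ("fix-up") that combines partial results of split tiles is omitted.

```python
# Simplified stream-K work partitioning (illustrative only; not CUTLASS's
# scheduler). The fix-up that combines partial results of tiles split
# across workers is deliberately omitted.


def stream_k_partition(num_tiles, iters_per_tile, num_workers):
    """Split the flattened (tile, k-iteration) space evenly across workers.

    Returns one list per worker of (tile, k_begin, k_end) segments, where
    k_end is exclusive. A tile that appears in several workers' lists is
    computed in pieces and would need a reduction afterwards.
    """
    total = num_tiles * iters_per_tile
    base, extra = divmod(total, num_workers)
    schedule = []
    start = 0
    for w in range(num_workers):
        end = start + base + (1 if w < extra else 0)
        segments = []
        it = start
        while it < end:
            tile, k_begin = divmod(it, iters_per_tile)
            k_end = min(iters_per_tile, k_begin + (end - it))
            segments.append((tile, k_begin, k_end))
            it += k_end - k_begin
        schedule.append(segments)
        start = end
    return schedule


# 3 output tiles, 8 k-iterations each, 4 workers: every worker gets 6
# iterations, so tiles 0, 1 and 2 are each split between two workers.
schedule = stream_k_partition(num_tiles=3, iters_per_tile=8, num_workers=4)
```

A purely data-parallel schedule would leave one of the four workers idle in this example; stream-K trades that idle time for the cost of reducing partial tiles.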

Current Performance

Float16

[Figure: graph_machete_bench_float16]

BFloat16

[Figure: graph_machete_bench_bfloat16]

LucasWilkinson, Aug 05 '24 23:08

👋 Hi! Thank you for contributing to the vLLM project. Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

github-actions[bot], Aug 05 '24 23:08

/ready

LucasWilkinson, Aug 16 '24 15:08

Should there be a comment/README somewhere that briefly describes what steps are needed if a new type/shape combination needs to be added?

bnellnm, Aug 19 '24 21:08

Should there be a comment/README somewhere that briefly describes what steps are needed if a new type/shape combination needs to be added?

Added an initial README. I'll likely be able to improve it as I work on w4a8 support in Machete, since that will be the first new non-w4a16 type pair added.

LucasWilkinson, Aug 20 '24 03:08
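The generated sources listed in the build log later in this thread follow a regular pattern, which gives a feel for what adding a new type pair involves: each (activation, weight) pair gets one `machete_mm_<pair>.cu`, two `machete_mm_<pair>_impl_partN.cu` files, and one `machete_prepack_<pair>.cu`. The hypothetical sketch below reproduces that file list from the pairs seen in the log. It is not the actual generator script, and real generation may depend on more than these names.

```python
from itertools import product

# Pairs observed in the generated file names in this thread's build log:
# activation type followed by weight type (e.g. "bf16" + "u4b8").
ACT_TYPES = ["bf16", "f16"]
WEIGHT_TYPES = ["u4", "u4b8", "u8", "u8b128"]


def generated_sources(act_types, weight_types, impl_parts=2):
    """Hypothetical reconstruction of the generated .cu file names."""
    pairs = [a + w for a, w in product(act_types, weight_types)]
    files = []
    for pair in pairs:
        files.append(f"machete_mm_{pair}.cu")
        files += [f"machete_mm_{pair}_impl_part{p}.cu" for p in range(impl_parts)]
    # In the log, all prepack files come after all mm files.
    files += [f"machete_prepack_{pair}.cu" for pair in pairs]
    return files


files = generated_sources(ACT_TYPES, WEIGHT_TYPES)
```

Under this reading, each additional pair (such as the planned w4a8 one) would add four generated files.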

Oops, I did not realize this was merged. Feel free to ignore, or address in the next PR.

ProExpertProg, Aug 20 '24 17:08

Hi @LucasWilkinson, since this PR was merged I can no longer build vLLM from source; the build fails with the errors below. Could you please help look into this? The errors:

Building wheels for collected packages: vllm
  Building editable for vllm (pyproject.toml) ... error
  error: subprocess-exited-with-error

  × Building editable for vllm (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> [192 lines of output]
      /tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
        cpu = _conversion_method_template(device=torch.device("cpu"))
      running editable_wheel
      creating /tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info
      writing /tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info/PKG-INFO
      writing dependency_links to /tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info/dependency_links.txt
      writing entry points to /tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info/entry_points.txt
      writing requirements to /tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info/requires.txt
      writing top-level names to /tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info/top_level.txt
      writing manifest file '/tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info/SOURCES.txt'
      reading manifest file '/tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info/SOURCES.txt'
      reading manifest template 'MANIFEST.in'
      adding license file 'LICENSE'
      writing manifest file '/tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm.egg-info/SOURCES.txt'
      creating '/tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm-0.5.4+cu118.dist-info'
      creating /tmp/pip-wheel-khv6cnct/.tmp-v8tzgcjn/vllm-0.5.4+cu118.dist-info/WHEEL
      running build_py
      running build_ext
      Using MAX_JOBS=256 as the number of jobs.
      Using NVCC_THREADS=16 as the number of nvcc threads.
      -- The CXX compiler identification is GNU 9.4.0
      -- Detecting CXX compiler ABI info
      -- Detecting CXX compiler ABI info - done
      -- Check for working CXX compiler: /usr/bin/c++ - skipped
      -- Detecting CXX compile features
      -- Detecting CXX compile features - done
      -- Build type: RelWithDebInfo
      -- Target device: cuda
      -- Found Python: /home/aiscuser/.conda/envs/myenv/bin/python (found version "3.10.14") found components: Interpreter Development.Module Development.SABIModule
      -- Found python matching: /home/aiscuser/.conda/envs/myenv/bin/python.
      -- Found CUDA: /usr/local/cuda (found version "11.8")
      -- The CUDA compiler identification is NVIDIA 11.8.89
      -- Detecting CUDA compiler ABI info
      -- Detecting CUDA compiler ABI info - done
      -- Check for working CUDA compiler: /usr/local/cuda/bin/nvcc - skipped
      -- Detecting CUDA compile features
      -- Detecting CUDA compile features - done
      -- Found CUDAToolkit: /usr/local/cuda/include (found version "11.8.89")
      -- Performing Test CMAKE_HAVE_LIBC_PTHREAD
      -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
      -- Looking for pthread_create in pthreads
      -- Looking for pthread_create in pthreads - not found
      -- Looking for pthread_create in pthread
      -- Looking for pthread_create in pthread - found
      -- Found Threads: TRUE
      -- Caffe2: CUDA detected: 11.8
      -- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
      -- Caffe2: CUDA toolkit directory: /usr/local/cuda
      -- Caffe2: Header version is: 11.8
      -- /usr/local/cuda/lib64/libnvrtc.so shorthash is 672ee683
      -- USE_CUDNN is set to 0. Compiling without cuDNN support
      -- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support
      -- Autodetected CUDA architecture(s):  8.0 8.0 8.0 8.0
      -- Added CUDA NVCC flags for: -gencode;arch=compute_80,code=sm_80
      CMake Warning at /tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message):
        static library kineto_LIBRARY-NOTFOUND not found.
      Call Stack (most recent call first):
        /tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:120 (append_torchlib_if_found)
        CMakeLists.txt:67 (find_package)


      -- Found Torch: /tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/torch/lib/libtorch.so
      -- Enabling core extension.
      -- CUDA supported arches: 7.0;7.5;8.0;8.6;8.9;9.0
      -- CUDA target arches: 80-real
      -- CMake Version: 3.30.2
      -- CUTLASS 3.5.1
      -- CUDART: /usr/local/cuda/lib64/libcudart.so
      -- CUDA Driver: /usr/local/cuda/lib64/stubs/libcuda.so
      -- NVRTC: /usr/local/cuda/lib64/libnvrtc.so
      -- Default Install Location: install
      -- Found Python3: /home/aiscuser/.conda/envs/myenv/bin/python3.10 (found suitable version "3.10.14", minimum required is "3.5") found components: Interpreter
      -- Make cute::tuple be the new standard-layout tuple type
      -- CUDA Compilation Architectures: 70;72;75;80;86;87;89;90
      -- Enable caching of reference results in conv unit tests
      -- Enable rigorous conv problem sizes in conv unit tests
      -- Using NVCC flags: --expt-relaxed-constexpr;-DCUTE_USE_PACKED_TUPLE=1;-DCUTLASS_TEST_LEVEL=0;-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1;-DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1;-DCUTLASS_DEBUG_TRACE_LEVEL=0;-Xcompiler=-Wconversion;-Xcompiler=-fno-strict-aliasing;-lineinfo
      fatal: not a git repository (or any of the parent directories): .git
      -- CUTLASS Revision: Unable to detect, Git returned code 128.
      -- Configuring cublas ...
      -- cuBLAS Disabled.
      -- Configuring cuBLAS ... done.
      CMake Error at CMakeLists.txt:245 (message):
        Machete generation failed.  Result: "1"

        Check the log for details:
        /tmp/tmp7h5m3uxb.build-temp/machete_generation.log


      -- Configuring incomplete, errors occurred!
      Traceback (most recent call last):
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 135, in run
          self._create_wheel_file(bdist_wheel)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 338, in _create_wheel_file
          files, mapping = self._run_build_commands(dist_name, unpacked, lib, tmp)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 261, in _run_build_commands
          self._run_build_subcommands()
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 288, in _run_build_subcommands
          self.run_command(name)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 316, in run_command
          self.distribution.run_command(command)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 948, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 983, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 96, in run
          _build_ext.run(self)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 359, in run
          self.build_extensions()
        File "<string>", line 222, in build_extensions
        File "<string>", line 204, in configure
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/subprocess.py", line 369, in check_call
          raise CalledProcessError(retcode, cmd)
      subprocess.CalledProcessError: Command '['cmake', '/home/aiscuser/vllm', '-G', 'Ninja', '-DCMAKE_BUILD_TYPE=RelWithDebInfo', '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=/tmp/tmp7v3db2j6.build-lib/vllm', '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=/tmp/tmp7h5m3uxb.build-temp', '-DVLLM_TARGET_DEVICE=cuda', '-DVLLM_PYTHON_EXECUTABLE=/home/aiscuser/.conda/envs/myenv/bin/python', '-DNVCC_THREADS=16', '-DCMAKE_JOB_POOL_COMPILE:STRING=compile', '-DCMAKE_JOB_POOLS:STRING=compile=16']' returned non-zero exit status 1.
      /tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py:983: _DebuggingTips: Problem in editable installation.
      !!

              ********************************************************************************
              An error happened while installing `vllm` in editable mode.

              The following steps are recommended to help debug this problem:

              - Try to install the project normally, without using the editable mode.
                Does the error still persist?
                (If it does, try fixing the problem before attempting the editable mode).
              - If you are using binary extensions, make sure you have all OS-level
                dependencies installed (e.g. compilers, toolchains, binary libraries, ...).
              - Try the latest version of setuptools (maybe the error was already fixed).
              - If you (or your project dependencies) are using any setuptools extension
                or customization, make sure they support the editable mode.

              After following the steps above, if the problem still persists and
              you think this is related to how setuptools handles editable installations,
              please submit a reproducible example
              (see https://stackoverflow.com/help/minimal-reproducible-example) to:

                  https://github.com/pypa/setuptools/issues

              See https://setuptools.pypa.io/en/latest/userguide/development_mode.html for details.
              ********************************************************************************

      !!
        cmd_obj.run()
      Traceback (most recent call last):
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
          main()
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 273, in build_editable
          return hook(wheel_directory, config_settings, metadata_directory)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 458, in build_editable
          return self._build_with_temp_dir(
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 402, in _build_with_temp_dir
          self.run_setup()
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 318, in run_setup
          exec(code, locals())
        File "<string>", line 470, in <module>
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/__init__.py", line 111, in setup
          return distutils.core.setup(**attrs)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 184, in setup
          return run_commands(dist)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 200, in run_commands
          dist.run_commands()
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 964, in run_commands
          self.run_command(cmd)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 948, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 983, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 135, in run
          self._create_wheel_file(bdist_wheel)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 338, in _create_wheel_file
          files, mapping = self._run_build_commands(dist_name, unpacked, lib, tmp)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 261, in _run_build_commands
          self._run_build_subcommands()
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 288, in _run_build_subcommands
          self.run_command(name)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 316, in run_command
          self.distribution.run_command(command)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 948, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 983, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 96, in run
          _build_ext.run(self)
        File "/tmp/pip-build-env-k21yn32m/overlay/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 359, in run
          self.build_extensions()
        File "<string>", line 222, in build_extensions
        File "<string>", line 204, in configure
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/subprocess.py", line 369, in check_call
          raise CalledProcessError(retcode, cmd)
      subprocess.CalledProcessError: Command '['cmake', '/home/aiscuser/vllm', '-G', 'Ninja', '-DCMAKE_BUILD_TYPE=RelWithDebInfo', '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=/tmp/tmp7v3db2j6.build-lib/vllm', '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=/tmp/tmp7h5m3uxb.build-temp', '-DVLLM_TARGET_DEVICE=cuda', '-DVLLM_PYTHON_EXECUTABLE=/home/aiscuser/.conda/envs/myenv/bin/python', '-DNVCC_THREADS=16', '-DCMAKE_JOB_POOL_COMPILE:STRING=compile', '-DCMAKE_JOB_POOLS:STRING=compile=16']' returned non-zero exit status 1.
      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building editable for vllm
Failed to build vllm
ERROR: ERROR: Failed to build installable wheels for some pyproject.toml based projects (vllm)

congcongchen123, Aug 21 '24 16:08

@congcongchen123 I think your issue is related to PR https://github.com/vllm-project/vllm/pull/7730. Would you mind trying it to see whether it resolves the problem?

mgoin, Aug 21 '24 16:08

@mgoin, unfortunately PR https://github.com/vllm-project/vllm/pull/7730 does not fix it for me.

congcongchen123, Aug 21 '24 17:08

@congcongchen123 what are the contents of /tmp/tmp7h5m3uxb.build-temp/machete_generation.log?

ProExpertProg, Aug 21 '24 18:08

Sorry, the temporary log was deleted automatically. It also looks like after applying PR https://github.com/vllm-project/vllm/pull/7730 the error is different. See:

ets->outlines<0.1,>=0.0.43->vllm==0.5.4+cu118) (2024.1)
Requirement already satisfied: six>=1.5 in /home/aiscuser/.local/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets->outlines<0.1,>=0.0.43->vllm==0.5.4+cu118) (1.16.0)
Building wheels for collected packages: vllm
  Building editable for vllm (pyproject.toml) ... error
  error: subprocess-exited-with-error

  × Building editable for vllm (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> [318 lines of output]
      /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
        cpu = _conversion_method_template(device=torch.device("cpu"))
      running editable_wheel
      creating /tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info
      writing /tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info/PKG-INFO
      writing dependency_links to /tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info/dependency_links.txt
      writing entry points to /tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info/entry_points.txt
      writing requirements to /tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info/requires.txt
      writing top-level names to /tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info/top_level.txt
      writing manifest file '/tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info/SOURCES.txt'
      reading manifest file '/tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info/SOURCES.txt'
      reading manifest template 'MANIFEST.in'
      adding license file 'LICENSE'
      writing manifest file '/tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm.egg-info/SOURCES.txt'
      creating '/tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm-0.5.4+cu118.dist-info'
      creating /tmp/pip-wheel-25fkij1m/.tmp-5dsix36x/vllm-0.5.4+cu118.dist-info/WHEEL
      running build_py
      running build_ext
      Using MAX_JOBS=256 as the number of jobs.
      Using NVCC_THREADS=16 as the number of nvcc threads.
      -- The CXX compiler identification is GNU 9.4.0
      -- Detecting CXX compiler ABI info
      -- Detecting CXX compiler ABI info - done
      -- Check for working CXX compiler: /usr/bin/c++ - skipped
      -- Detecting CXX compile features
      -- Detecting CXX compile features - done
      -- Build type: RelWithDebInfo
      -- Target device: cuda
      -- Found Python: /home/aiscuser/.conda/envs/myenv/bin/python (found version "3.10.14") found components: Interpreter Development.Module Development.SABIModule
      -- Found python matching: /home/aiscuser/.conda/envs/myenv/bin/python.
      -- Found CUDA: /usr/local/cuda (found version "11.8")
      -- The CUDA compiler identification is NVIDIA 11.8.89
      -- Detecting CUDA compiler ABI info
      -- Detecting CUDA compiler ABI info - done
      -- Check for working CUDA compiler: /usr/local/cuda/bin/nvcc - skipped
      -- Detecting CUDA compile features
      -- Detecting CUDA compile features - done
      -- Found CUDAToolkit: /usr/local/cuda/include (found version "11.8.89")
      -- Performing Test CMAKE_HAVE_LIBC_PTHREAD
      -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
      -- Looking for pthread_create in pthreads
      -- Looking for pthread_create in pthreads - not found
      -- Looking for pthread_create in pthread
      -- Looking for pthread_create in pthread - found
      -- Found Threads: TRUE
      -- Caffe2: CUDA detected: 11.8
      -- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
      -- Caffe2: CUDA toolkit directory: /usr/local/cuda
      -- Caffe2: Header version is: 11.8
      -- /usr/local/cuda/lib64/libnvrtc.so shorthash is 672ee683
      -- USE_CUDNN is set to 0. Compiling without cuDNN support
      -- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support
      -- Autodetected CUDA architecture(s):  8.0 8.0 8.0 8.0
      -- Added CUDA NVCC flags for: -gencode;arch=compute_80,code=sm_80
      CMake Warning at /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message):
        static library kineto_LIBRARY-NOTFOUND not found.
      Call Stack (most recent call first):
        /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:120 (append_torchlib_if_found)
        CMakeLists.txt:67 (find_package)


      -- Found Torch: /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/torch/lib/libtorch.so
      -- Enabling core extension.
      -- CUDA supported arches: 7.0;7.5;8.0;8.6;8.9;9.0
      -- CUDA target arches: 80-real
      -- CMake Version: 3.30.2
      -- CUTLASS 3.5.1
      -- CUDART: /usr/local/cuda/lib64/libcudart.so
      -- CUDA Driver: /usr/local/cuda/lib64/stubs/libcuda.so
      -- NVRTC: /usr/local/cuda/lib64/libnvrtc.so
      -- Default Install Location: install
      -- Found Python3: /home/aiscuser/.conda/envs/myenv/bin/python3.10 (found suitable version "3.10.14", minimum required is "3.5") found components: Interpreter
      -- Make cute::tuple be the new standard-layout tuple type
      -- CUDA Compilation Architectures: 70;72;75;80;86;87;89;90
      -- Enable caching of reference results in conv unit tests
      -- Enable rigorous conv problem sizes in conv unit tests
      -- Using NVCC flags: --expt-relaxed-constexpr;-DCUTE_USE_PACKED_TUPLE=1;-DCUTLASS_TEST_LEVEL=0;-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1;-DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1;-DCUTLASS_DEBUG_TRACE_LEVEL=0;-Xcompiler=-Wconversion;-Xcompiler=-fno-strict-aliasing;-lineinfo
      fatal: not a git repository (or any of the parent directories): .git
      -- CUTLASS Revision: Unable to detect, Git returned code 128.
      -- Configuring cublas ...
      -- cuBLAS Disabled.
      -- Configuring cuBLAS ... done.
      -- Machete generation completed successfully.
      -- Machete generated sources: /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4b8.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4b8_impl_part0.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4b8_impl_part1.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u8.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u8_impl_part0.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u8_impl_part1.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u8b128.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u8b128_impl_part0.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u8b128_impl_part1.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u4.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u4_impl_part0.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u4_impl_part1.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u4b8.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u4b8_impl_part0.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u4b8_impl_part1.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u8.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u8_impl_part0.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u8_impl_part1.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u8b128.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u8b128_impl_part0.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_f16u8b128_impl_part1.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_prepack_bf16u4.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_prepack_bf16u4b8.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_prepack_bf16u8.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_prepack_bf16u8b128.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_prepack_f16u4.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_prepack_f16u4b8.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_prepack_f16u8.cu;/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_prepack_f16u8b128.cu
      -- Enabling C extension.
      -- Enabling moe extension.
      -- Configuring done (11.2s)
      -- Generating done (0.1s)
      -- Build files have been written to: /tmp/tmpqtey188_.build-temp
      Using MAX_JOBS=256 as the number of jobs.
      Using NVCC_THREADS=16 as the number of nvcc threads.
      [1/66] Building CUDA object CMakeFiles/_C.dir/csrc/cuda_utils_kernels.cu.o
      [2/66] Building CXX object CMakeFiles/_core_C.dir/csrc/core/torch_bindings.cpp.o
      [3/66] Linking CXX shared module /tmp/tmp1rxbsk_2.build-lib/vllm/_core_C.abi3.so
      [4/66] Building CXX object CMakeFiles/_moe_C.dir/csrc/moe/torch_bindings.cpp.o
      [5/66] Building CXX object CMakeFiles/_C.dir/csrc/torch_bindings.cpp.o
      [6/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/compressed_tensors/int8_quant_kernels.cu.o
      [7/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/squeezellm/quant_cuda_kernel.cu.o
      [8/66] Building CUDA object CMakeFiles/_C.dir/csrc/prepare_inputs/advance_step.cu.o
      [9/66] Building CUDA object CMakeFiles/_C.dir/csrc/moe_align_block_size_kernels.cu.o
      [10/66] Building CUDA object CMakeFiles/_C.dir/csrc/layernorm_kernels.cu.o
      [11/66] Building CUDA object CMakeFiles/_C.dir/csrc/activation_kernels.cu.o
      [12/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/fp8/common.cu.o
      /home/aiscuser/vllm/csrc/quantization/fp8/common.cu:14:1: warning: ‘host’ attribute directive ignored [-Wattributes]
         14 | C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
            | ^~~~~~~~~~~~
      [13/66] Building CUDA object CMakeFiles/_C.dir/csrc/pos_encoding_kernels.cu.o
      [14/66] Building CUDA object CMakeFiles/_moe_C.dir/csrc/moe/topk_softmax_kernels.cu.o
      [15/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/aqlm/gemm_kernels.cu.o
      [16/66] Building CUDA object CMakeFiles/_C.dir/csrc/cache_kernels.cu.o
      [17/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu.o
      [18/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/awq/gemm_kernels.cu.o
      [19/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/gptq/q_gemm.cu.o
      [20/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu.o
      FAILED: CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu.o
      /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DPy_LIMITED_API=3 -DTORCH_EXTENSION_NAME=_C -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_RPC -DUSE_TENSORPIPE -D_C_EXPORTS -I/home/aiscuser/vllm/csrc -I/tmp/tmpqtey188_.build-temp/_deps/cutlass-src/include -isystem /home/aiscuser/.conda/envs/myenv/include/python3.10 -isystem /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/torch/include -isystem /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/include -DONNX_NAMESPACE=onnx_c2 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -O2 -g -DNDEBUG -std=c++17 "--generate-code=arch=compute_80,code=[sm_80]" -Xcompiler=-fPIC -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -DENABLE_FP8 --threads=16 -D_GLIBCXX_USE_CXX11_ABI=0 -MD -MT CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu.o -MF CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu.o.d -x cu -c /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu -o CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu.o
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mainloop.cuh(199): error: static assertion failed with "Unsupported Toolkit for SM90 Collective Builder
      "
                detected during:
                  instantiation of class "machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType> [with ElementATuple_=cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, GmemLayoutA=machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, AlignmentA=32, ElementB_=cutlass::bfloat16_t, GmemLayoutB=cutlass::layout::ColumnMajor, AlignmentB=8, ElementAccumulator_=float, TileShape_MNK=cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, ClusterShape_MNK=machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, StageCountType=cutlass::gemm::collective::StageCountAutoCarveout<6144>, KernelScheduleType=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_kernel.cuh(130): here
                  instantiation of class "machete::MacheteKernelTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT, ScaleT, ZeroT, KernelSchedule, ScheduleConfig, with_C, with_scales, with_zeropoints> [with ElementA_=cutlass::bfloat16_t, ElementB_=cutlass::uint4b_t, ElementD_=cutlass::bfloat16_t, AccumulatorT=float, ScaleT=cutlass::bfloat16_t, ZeroT=cutlass::bfloat16_t, KernelSchedule=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, ScheduleConfig=machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK, with_C=false, with_scales=true, with_zeropoints=true]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_launcher.cuh(30): here
                  instantiation of "at::Tensor machete::run_impl<MacheteKernel>(machete::PyTorchArguments) [with MacheteKernel=machete::MacheteKernelTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK, false, true, true>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu(36): here

      /tmp/tmpqtey188_.build-temp/_deps/cutlass-src/include/cute/atom/copy_traits_sm90_tma.hpp(932): warning #177-D: variable "smem_box_stride" was declared but never referenced
                detected during:
                  instantiation of "auto cute::detail::make_tma_copy_atom<TmaInternalType,CopyOp,GEngine,GLayout,SLayout,VShape,VStride>(CopyOp, const cute::Tensor<GEngine, GLayout> &, const SLayout &, const uint32_t &, const cute::Layout<VShape, VStride> &) [with TmaInternalType=uint8_t, CopyOp=cute::SM90_TMA_LOAD, GEngine=cute::ViewEngine<cute::subbyte_iterator<const cutlass::uint4b_t>>, GLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<int32_t, int32_t>, int32_t>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::C<8192>, int>, int32_t>>, SLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::_0, cute::_0>>>, VShape=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, VStride=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 0>, 0>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 0>, 0>, 0>, 
cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 2>, 0>, 0>, 0>>, cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 1>, 0>, 0>, cute::_0>>, cute::tuple<cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 0>, 1>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 0>, 1>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 2>, 0>, 1>, 0>>, cute::tuple<cute::_0, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 1>, 1>, 0>>>>, cute::tuple<cute::_0, cute::_0>>]"
      (1129): here
                  instantiation of "auto cute::detail::make_tma_copy_tiled<TmaInternalType,CopyOp,GEngine,GLayout,SLayout,TShape,TStride,VShape,VStride>(const CopyOp &, const cute::Tensor<GEngine, GLayout> &, const SLayout &, const cute::Layout<TShape, TStride> &, const cute::Layout<VShape, VStride> &) [with TmaInternalType=uint8_t, CopyOp=cute::SM90_TMA_LOAD, GEngine=cute::ViewEngine<cute::subbyte_iterator<const cutlass::uint4b_t>>, GLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<int32_t, int32_t>, int32_t>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::C<8192>, int>, int32_t>>, SLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::_0, cute::_0>>>, TShape=cute::_1, TStride=cute::C<0>, VShape=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, VStride=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 0>, 0>, 0>, 
cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 0>, 0>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 2>, 0>, 0>, 0>>, cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 1>, 0>, 0>, cute::_0>>, cute::tuple<cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 0>, 1>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 0>, 1>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 2>, 0>, 1>, 0>>, cute::tuple<cute::_0, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 1>, 1>, 0>>>>, cute::tuple<cute::_0, cute::_0>>]"
      (1266): here
                  instantiation of "auto cute::make_tma_copy(const CopyOp &, const cute::Tensor<GEngine, GLayout> &, const SLayout &, const CTA_Tiler &, const Cluster_Size &) [with TmaInternalType=uint8_t, CopyOp=cute::SM90_TMA_LOAD, GEngine=cute::ViewEngine<cute::subbyte_iterator<const cutlass::uint4b_t>>, GLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<int32_t, int32_t>, int32_t>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::C<8192>, int>, int32_t>>, SLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::_0, cute::_0>>>, CTA_Tiler=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, Cluster_Size=cute::C<1>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mainloop.cuh(455): here
                  instantiation of "auto machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>::make_tma_copy_A(machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>::ATensor) [with ElementATuple_=cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, GmemLayoutA=machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, AlignmentA=32, ElementB_=cutlass::bfloat16_t, GmemLayoutB=cutlass::layout::ColumnMajor, AlignmentB=8, ElementAccumulator_=float, TileShape_MNK=cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, ClusterShape_MNK=machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, StageCountType=cutlass::gemm::collective::StageCountAutoCarveout<6144>, KernelScheduleType=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mainloop.cuh(538): here
                  instantiation of class "machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>::Params [with ElementATuple_=cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, GmemLayoutA=machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, AlignmentA=32, ElementB_=cutlass::bfloat16_t, GmemLayoutB=cutlass::layout::ColumnMajor, AlignmentB=8, ElementAccumulator_=float, TileShape_MNK=cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, ClusterShape_MNK=machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, StageCountType=cutlass::gemm::collective::StageCountAutoCarveout<6144>, KernelScheduleType=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput]"
      /tmp/tmpqtey188_.build-temp/_deps/cutlass-src/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp(154): here
                  instantiation of class "cutlass::gemm::kernel::GemmUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, std::enable_if_t<std::is_base_of_v, void>>::Params [with ProblemShape_=cute::tuple<int, int, int, int>, CollectiveMainloop_=machete::MacheteCollectiveMma<cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, 32, cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, float, cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<6144>, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>, CollectiveEpilogue_=cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<1, 1, 8, false, true>, cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, cute::tuple<cute::C<128>, cute::C<16>>, void, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::bfloat16_t, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm90TmaWarpSpecialized<1, 1, 8, false, true>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, void, float, cutlass::FloatRoundStyle::round_to_nearest>, cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, cute::tuple<cute::C<128>, cute::C<16>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, cute::SM75_U16x8_LDSM_T, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, cute::SM90_U16x8_STSM_T, 
cute::Copy_Atom<cute::SM90_U32x4_STSM_N, cutlass::half_t>>, TileScheduler_=machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK::TileScheduler]"
      /tmp/tmpqtey188_.build-temp/_deps/cutlass-src/include/cutlass/gemm/device/gemm_universal_adapter.h(186): here
                  instantiation of class "cutlass::gemm::device::GemmUniversalAdapter<GemmKernel_, std::enable_if_t<cutlass::gemm::detail::IsCutlass3GemmKernel<GemmKernel_, void>::value, void>> [with GemmKernel_=cutlass::gemm::kernel::GemmUniversal<cute::tuple<int, int, int, int>, machete::MacheteCollectiveMma<cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, 32, cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, float, cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<6144>, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<1, 1, 8, false, true>, cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, cute::tuple<cute::C<128>, cute::C<16>>, void, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::bfloat16_t, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm90TmaWarpSpecialized<1, 1, 8, false, true>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, void, float, cutlass::FloatRoundStyle::round_to_nearest>, cute::tuple<cute::_128, cute::C<16>, cute::C<64>>, cute::tuple<cute::C<128>, cute::C<16>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, cute::SM75_U16x8_LDSM_T, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, cute::SM90_U16x8_STSM_T, 
cute::Copy_Atom<cute::SM90_U32x4_STSM_N, cutlass::half_t>>, machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK::TileScheduler, void>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_kernel.cuh(130): here
                  instantiation of class "machete::MacheteKernelTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT, ScaleT, ZeroT, KernelSchedule, ScheduleConfig, with_C, with_scales, with_zeropoints> [with ElementA_=cutlass::bfloat16_t, ElementB_=cutlass::uint4b_t, ElementD_=cutlass::bfloat16_t, AccumulatorT=float, ScaleT=cutlass::bfloat16_t, ZeroT=cutlass::bfloat16_t, KernelSchedule=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, ScheduleConfig=machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK, with_C=false, with_scales=true, with_zeropoints=true]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_launcher.cuh(30): here
                  instantiation of "at::Tensor machete::run_impl<MacheteKernel>(machete::PyTorchArguments) [with MacheteKernel=machete::MacheteKernelTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::sch_128x16_1x1x1_TmaMI_TmaCoop_streamK, false, true, true>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu(36): here

      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mainloop.cuh(199): error: static assertion failed with "Unsupported Toolkit for SM90 Collective Builder
      "
                detected during:
                  instantiation of class "machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType> [with ElementATuple_=cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, GmemLayoutA=machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, AlignmentA=32, ElementB_=cutlass::bfloat16_t, GmemLayoutB=cutlass::layout::ColumnMajor, AlignmentB=8, ElementAccumulator_=float, TileShape_MNK=cute::tuple<cute::_128, cute::C<32>, cute::C<64>>, ClusterShape_MNK=machete::sch_128x32_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, StageCountType=cutlass::gemm::collective::StageCountAutoCarveout<10240>, KernelScheduleType=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_kernel.cuh(130): here
                  instantiation of class "machete::MacheteKernelTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT, ScaleT, ZeroT, KernelSchedule, ScheduleConfig, with_C, with_scales, with_zeropoints> [with ElementA_=cutlass::bfloat16_t, ElementB_=cutlass::uint4b_t, ElementD_=cutlass::bfloat16_t, AccumulatorT=float, ScaleT=cutlass::bfloat16_t, ZeroT=cutlass::bfloat16_t, KernelSchedule=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, ScheduleConfig=machete::sch_128x32_1x1x1_TmaMI_TmaCoop_streamK, with_C=false, with_scales=true, with_zeropoints=true]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_launcher.cuh(30): here
                  instantiation of "at::Tensor machete::run_impl<MacheteKernel>(machete::PyTorchArguments) [with MacheteKernel=machete::MacheteKernelTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::sch_128x32_1x1x1_TmaMI_TmaCoop_streamK, false, true, true>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu(66): here

      2 errors detected in the compilation of "/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part0.cu".
      [21/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu.o
      FAILED: CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu.o
      /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DPy_LIMITED_API=3 -DTORCH_EXTENSION_NAME=_C -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_RPC -DUSE_TENSORPIPE -D_C_EXPORTS -I/home/aiscuser/vllm/csrc -I/tmp/tmpqtey188_.build-temp/_deps/cutlass-src/include -isystem /home/aiscuser/.conda/envs/myenv/include/python3.10 -isystem /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/torch/include -isystem /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/include -DONNX_NAMESPACE=onnx_c2 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -O2 -g -DNDEBUG -std=c++17 "--generate-code=arch=compute_80,code=[sm_80]" -Xcompiler=-fPIC -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -DENABLE_FP8 --threads=16 -D_GLIBCXX_USE_CXX11_ABI=0 -MD -MT CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu.o -MF CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu.o.d -x cu -c /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu -o CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu.o
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mainloop.cuh(199): error: static assertion failed with "Unsupported Toolkit for SM90 Collective Builder
      "
                detected during:
                  instantiation of class "machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType> [with ElementATuple_=cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, GmemLayoutA=machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, AlignmentA=32, ElementB_=cutlass::bfloat16_t, GmemLayoutB=cutlass::layout::ColumnMajor, AlignmentB=8, ElementAccumulator_=float, TileShape_MNK=cute::tuple<cute::_128, cute::_64, cute::C<64>>, ClusterShape_MNK=machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, StageCountType=cutlass::gemm::collective::StageCountAutoCarveout<18432>, KernelScheduleType=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_kernel.cuh(130): here
                  instantiation of class "machete::MacheteKernelTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT, ScaleT, ZeroT, KernelSchedule, ScheduleConfig, with_C, with_scales, with_zeropoints> [with ElementA_=cutlass::bfloat16_t, ElementB_=cutlass::uint4b_t, ElementD_=cutlass::bfloat16_t, AccumulatorT=float, ScaleT=cutlass::bfloat16_t, ZeroT=cutlass::bfloat16_t, KernelSchedule=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, ScheduleConfig=machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK, with_C=false, with_scales=true, with_zeropoints=true]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_launcher.cuh(30): here
                  instantiation of "at::Tensor machete::run_impl<MacheteKernel>(machete::PyTorchArguments) [with MacheteKernel=machete::MacheteKernelTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK, false, true, true>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu(36): here

      /tmp/tmpqtey188_.build-temp/_deps/cutlass-src/include/cute/atom/copy_traits_sm90_tma.hpp(932): warning #177-D: variable "smem_box_stride" was declared but never referenced
                detected during:
                  instantiation of "auto cute::detail::make_tma_copy_atom<TmaInternalType,CopyOp,GEngine,GLayout,SLayout,VShape,VStride>(CopyOp, const cute::Tensor<GEngine, GLayout> &, const SLayout &, const uint32_t &, const cute::Layout<VShape, VStride> &) [with TmaInternalType=uint8_t, CopyOp=cute::SM90_TMA_LOAD, GEngine=cute::ViewEngine<cute::subbyte_iterator<const cutlass::uint4b_t>>, GLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<int32_t, int32_t>, int32_t>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::C<8192>, int>, int32_t>>, SLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::_0, cute::_0>>>, VShape=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, VStride=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 0>, 0>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 0>, 0>, 0>, 
cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 2>, 0>, 0>, 0>>, cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 1>, 0>, 0>, cute::_0>>, cute::tuple<cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 0>, 1>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 0>, 1>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 2>, 0>, 1>, 0>>, cute::tuple<cute::_0, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 1>, 1>, 0>>>>, cute::tuple<cute::_0, cute::_0>>]"
      (1129): here
                  instantiation of "auto cute::detail::make_tma_copy_tiled<TmaInternalType,CopyOp,GEngine,GLayout,SLayout,TShape,TStride,VShape,VStride>(const CopyOp &, const cute::Tensor<GEngine, GLayout> &, const SLayout &, const cute::Layout<TShape, TStride> &, const cute::Layout<VShape, VStride> &) [with TmaInternalType=uint8_t, CopyOp=cute::SM90_TMA_LOAD, GEngine=cute::ViewEngine<cute::subbyte_iterator<const cutlass::uint4b_t>>, GLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<int32_t, int32_t>, int32_t>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::C<8192>, int>, int32_t>>, SLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::_0, cute::_0>>>, TShape=cute::_1, TStride=cute::C<0>, VShape=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, VStride=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 0>, 0>, 0>, 
cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 0>, 0>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 2>, 0>, 0>, 0>>, cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 1>, 0>, 0>, cute::_0>>, cute::tuple<cute::tuple<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 0>, 0>, 1>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 0>, 1>, 0>, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 2>, 0>, 1>, 0>>, cute::tuple<cute::_0, cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::ScaledBasis<cute::C<1>, 1>, 1>, 1>, 0>>>>, cute::tuple<cute::_0, cute::_0>>]"
      (1266): here
                  instantiation of "auto cute::make_tma_copy(const CopyOp &, const cute::Tensor<GEngine, GLayout> &, const SLayout &, const CTA_Tiler &, const Cluster_Size &) [with TmaInternalType=uint8_t, CopyOp=cute::SM90_TMA_LOAD, GEngine=cute::ViewEngine<cute::subbyte_iterator<const cutlass::uint4b_t>>, GLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<int32_t, int32_t>, int32_t>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::C<8192>, int>, int32_t>>, SLayout=cute::Layout<cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_32, cute::_128, cute::C<1024>>, cute::tuple<cute::C<4096>, cute::C<0>>>, cute::tuple<cute::tuple<cute::_1, cute::_2, cute::C<4>>, cute::tuple<cute::C<0>, cute::_8>>>, cute::tuple<cute::_0, cute::_0>>>, CTA_Tiler=cute::tuple<cute::tuple<cute::tuple<cute::tuple<cute::_4, cute::_8, cute::_4>, cute::tuple<cute::_2, cute::_1>>, cute::tuple<cute::tuple<cute::_2, cute::_2, cute::_2>, cute::tuple<cute::_1, cute::_4>>>, cute::tuple<cute::_1, cute::_1>>, Cluster_Size=cute::C<1>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mainloop.cuh(455): here
                  instantiation of "auto machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>::make_tma_copy_A(machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>::ATensor) [with ElementATuple_=cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, GmemLayoutA=machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, AlignmentA=32, ElementB_=cutlass::bfloat16_t, GmemLayoutB=cutlass::layout::ColumnMajor, AlignmentB=8, ElementAccumulator_=float, TileShape_MNK=cute::tuple<cute::_128, cute::_64, cute::C<64>>, ClusterShape_MNK=machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, StageCountType=cutlass::gemm::collective::StageCountAutoCarveout<18432>, KernelScheduleType=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mainloop.cuh(538): here
                  instantiation of class "machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>::Params [with ElementATuple_=cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, GmemLayoutA=machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, AlignmentA=32, ElementB_=cutlass::bfloat16_t, GmemLayoutB=cutlass::layout::ColumnMajor, AlignmentB=8, ElementAccumulator_=float, TileShape_MNK=cute::tuple<cute::_128, cute::_64, cute::C<64>>, ClusterShape_MNK=machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, StageCountType=cutlass::gemm::collective::StageCountAutoCarveout<18432>, KernelScheduleType=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput]"
      /tmp/tmpqtey188_.build-temp/_deps/cutlass-src/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp(154): here
                  instantiation of class "cutlass::gemm::kernel::GemmUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, std::enable_if_t<std::is_base_of_v, void>>::Params [with ProblemShape_=cute::tuple<int, int, int, int>, CollectiveMainloop_=machete::MacheteCollectiveMma<cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, 32, cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, float, cute::tuple<cute::_128, cute::_64, cute::C<64>>, machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<18432>, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>, CollectiveEpilogue_=cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<2, 2, 16, false, true>, cute::tuple<cute::_128, cute::_64, cute::C<64>>, cute::tuple<cute::C<128>, cute::C<32>>, void, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::bfloat16_t, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm90TmaWarpSpecialized<2, 2, 16, false, true>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, void, float, cutlass::FloatRoundStyle::round_to_nearest>, cute::tuple<cute::_128, cute::_64, cute::C<64>>, cute::tuple<cute::C<128>, cute::C<32>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, cute::SM75_U16x8_LDSM_T, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, cute::SM90_U16x8_STSM_T, 
cute::Copy_Atom<cute::SM90_U32x4_STSM_N, cutlass::half_t>>, TileScheduler_=machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK::TileScheduler]"
      /tmp/tmpqtey188_.build-temp/_deps/cutlass-src/include/cutlass/gemm/device/gemm_universal_adapter.h(186): here
                  instantiation of class "cutlass::gemm::device::GemmUniversalAdapter<GemmKernel_, std::enable_if_t<cutlass::gemm::detail::IsCutlass3GemmKernel<GemmKernel_, void>::value, void>> [with GemmKernel_=cutlass::gemm::kernel::GemmUniversal<cute::tuple<int, int, int, int>, machete::MacheteCollectiveMma<cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, 32, cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, float, cute::tuple<cute::_128, cute::_64, cute::C<64>>, machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<18432>, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm90TmaWarpSpecialized<2, 2, 16, false, true>, cute::tuple<cute::_128, cute::_64, cute::C<64>>, cute::tuple<cute::C<128>, cute::C<32>>, void, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::bfloat16_t, cute::tuple<cute::C<1>, int64_t, int64_t>, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm90TmaWarpSpecialized<2, 2, 16, false, true>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, void, float, cutlass::FloatRoundStyle::round_to_nearest>, cute::tuple<cute::_128, cute::_64, cute::C<64>>, cute::tuple<cute::C<128>, cute::C<32>>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, cute::SM75_U16x8_LDSM_T, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<3, 4, 3>, cute::smem_ptr_flag_bits<16>, cute::Layout<cute::tuple<cute::_64, cute::_8>, cute::tuple<cute::_1, cute::C<64>>>>, cute::SM90_U16x8_STSM_T, 
cute::Copy_Atom<cute::SM90_U32x4_STSM_N, cutlass::half_t>>, machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK::TileScheduler, void>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_kernel.cuh(130): here
                  instantiation of class "machete::MacheteKernelTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT, ScaleT, ZeroT, KernelSchedule, ScheduleConfig, with_C, with_scales, with_zeropoints> [with ElementA_=cutlass::bfloat16_t, ElementB_=cutlass::uint4b_t, ElementD_=cutlass::bfloat16_t, AccumulatorT=float, ScaleT=cutlass::bfloat16_t, ZeroT=cutlass::bfloat16_t, KernelSchedule=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, ScheduleConfig=machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK, with_C=false, with_scales=true, with_zeropoints=true]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_launcher.cuh(30): here
                  instantiation of "at::Tensor machete::run_impl<MacheteKernel>(machete::PyTorchArguments) [with MacheteKernel=machete::MacheteKernelTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::sch_128x64_1x1x1_TmaMI_TmaCoop_streamK, false, true, true>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu(36): here

      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mainloop.cuh(199): error: static assertion failed with "Unsupported Toolkit for SM90 Collective Builder
      "
                detected during:
                  instantiation of class "machete::MacheteCollectiveMma<ElementATuple_, GmemLayoutA, AlignmentA, ElementB_, GmemLayoutB, AlignmentB, ElementAccumulator_, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType> [with ElementATuple_=cute::tuple<cutlass::uint4b_t, cutlass::bfloat16_t, cutlass::bfloat16_t>, GmemLayoutA=machete::PrepackedLayoutBTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::layout::ColumnMajor, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::IlvBlkLayoutAuto>, AlignmentA=32, ElementB_=cutlass::bfloat16_t, GmemLayoutB=cutlass::layout::ColumnMajor, AlignmentB=8, ElementAccumulator_=float, TileShape_MNK=cute::tuple<cute::_128, cute::_128, cute::_64>, ClusterShape_MNK=machete::sch_128x128_1x1x1_TmaMI_TmaCoop_streamK::ClusterShape, StageCountType=cutlass::gemm::collective::StageCountAutoCarveout<18432>, KernelScheduleType=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_kernel.cuh(130): here
                  instantiation of class "machete::MacheteKernelTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT, ScaleT, ZeroT, KernelSchedule, ScheduleConfig, with_C, with_scales, with_zeropoints> [with ElementA_=cutlass::bfloat16_t, ElementB_=cutlass::uint4b_t, ElementD_=cutlass::bfloat16_t, AccumulatorT=float, ScaleT=cutlass::bfloat16_t, ZeroT=cutlass::bfloat16_t, KernelSchedule=cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, ScheduleConfig=machete::sch_128x128_1x1x1_TmaMI_TmaCoop_streamK, with_C=false, with_scales=true, with_zeropoints=true]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/../machete_mm_launcher.cuh(30): here
                  instantiation of "at::Tensor machete::run_impl<MacheteKernel>(machete::PyTorchArguments) [with MacheteKernel=machete::MacheteKernelTemplate<cutlass::bfloat16_t, cutlass::uint4b_t, cutlass::bfloat16_t, float, cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, machete::sch_128x128_1x1x1_TmaMI_TmaCoop_streamK, false, true, true>]"
      /home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu(66): here

      2 errors detected in the compilation of "/home/aiscuser/vllm/csrc/quantization/machete/generated/machete_mm_bf16u4_impl_part1.cu".
      [22/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu.o
      [23/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu.o
      [24/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu.o
      [25/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu.o
      [26/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/gptq_marlin/awq_marlin_repack.cu.o
      [27/66] Building CUDA object CMakeFiles/_C.dir/csrc/custom_all_reduce.cu.o
      [28/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/gguf/gguf_kernel.cu.o
      [29/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4.cu.o
      [30/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/machete/generated/machete_mm_bf16u4b8.cu.o
      [31/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu.o
      [32/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/fp8/fp8_marlin.cu.o
      [33/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu.o
      [34/66] Building CUDA object CMakeFiles/_C.dir/csrc/attention/attention_kernels.cu.o
      [35/66] Building CUDA object CMakeFiles/_C.dir/csrc/quantization/gptq_marlin/gptq_marlin.cu.o
      ninja: build stopped: subcommand failed.
      Traceback (most recent call last):
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 135, in run
          self._create_wheel_file(bdist_wheel)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 338, in _create_wheel_file
          files, mapping = self._run_build_commands(dist_name, unpacked, lib, tmp)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 261, in _run_build_commands
          self._run_build_subcommands()
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 288, in _run_build_subcommands
          self.run_command(name)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 316, in run_command
          self.distribution.run_command(command)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 948, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 983, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 96, in run
          _build_ext.run(self)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 359, in run
          self.build_extensions()
        File "<string>", line 238, in build_extensions
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/subprocess.py", line 369, in check_call
          raise CalledProcessError(retcode, cmd)
      subprocess.CalledProcessError: Command '['cmake', '--build', '.', '-j=16', '--target=_core_C', '--target=_moe_C', '--target=_C']' returned non-zero exit status 1.
      /tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py:983: _DebuggingTips: Problem in editable installation.
      !!

              ********************************************************************************
              An error happened while installing `vllm` in editable mode.

              The following steps are recommended to help debug this problem:

              - Try to install the project normally, without using the editable mode.
                Does the error still persist?
                (If it does, try fixing the problem before attempting the editable mode).
              - If you are using binary extensions, make sure you have all OS-level
                dependencies installed (e.g. compilers, toolchains, binary libraries, ...).
              - Try the latest version of setuptools (maybe the error was already fixed).
              - If you (or your project dependencies) are using any setuptools extension
                or customization, make sure they support the editable mode.

              After following the steps above, if the problem still persists and
              you think this is related to how setuptools handles editable installations,
              please submit a reproducible example
              (see https://stackoverflow.com/help/minimal-reproducible-example) to:

                  https://github.com/pypa/setuptools/issues

              See https://setuptools.pypa.io/en/latest/userguide/development_mode.html for details.
              ********************************************************************************

      !!
        cmd_obj.run()
      Traceback (most recent call last):
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
          main()
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 273, in build_editable
          return hook(wheel_directory, config_settings, metadata_directory)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 458, in build_editable
          return self._build_with_temp_dir(
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 402, in _build_with_temp_dir
          self.run_setup()
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 318, in run_setup
          exec(code, locals())
        File "<string>", line 474, in <module>
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/__init__.py", line 111, in setup
          return distutils.core.setup(**attrs)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 184, in setup
          return run_commands(dist)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 200, in run_commands
          dist.run_commands()
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 964, in run_commands
          self.run_command(cmd)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 948, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 983, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 135, in run
          self._create_wheel_file(bdist_wheel)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 338, in _create_wheel_file
          files, mapping = self._run_build_commands(dist_name, unpacked, lib, tmp)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 261, in _run_build_commands
          self._run_build_subcommands()
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/editable_wheel.py", line 288, in _run_build_subcommands
          self.run_command(name)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 316, in run_command
          self.distribution.run_command(command)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 948, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 983, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 96, in run
          _build_ext.run(self)
        File "/tmp/pip-build-env-7gd0j8pl/overlay/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 359, in run
          self.build_extensions()
        File "<string>", line 238, in build_extensions
        File "/home/aiscuser/.conda/envs/myenv/lib/python3.10/subprocess.py", line 369, in check_call
          raise CalledProcessError(retcode, cmd)
      subprocess.CalledProcessError: Command '['cmake', '--build', '.', '-j=16', '--target=_core_C', '--target=_moe_C', '--target=_C']' returned non-zero exit status 1.
      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building editable for vllm
Failed to build vllm
ERROR: ERROR: Failed to build installable wheels for some pyproject.toml based projects (vllm)

congcongchen123, Aug 21 '24 20:08

I'm not sure why this PR triggers a build of MacheteKernel for me locally; it looks like a bug. Can we revert this PR, since it affects other users? cc @simon-mo

congcongchen123, Aug 21 '24 20:08

It looks like you are building with CUDA 11.8 (see the log line below), which would definitely explain the failure with this kernel. I believe we should simply not build the kernel in this case.

      -- Found CUDA: /usr/local/cuda (found version "11.8")
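The suggestion above (skip the kernel when the toolkit is too old) boils down to a version gate. A minimal sketch in Python, assuming hypothetical helpers `parse_cuda_version` and `should_build_machete`: the log only establishes that CUDA 11.8 fails with "Unsupported Toolkit for SM90 Collective Builder", so the minimum supported version is a caller-supplied assumption here, not a value taken from vLLM or CUTLASS.

```python
import re
from typing import Optional, Tuple

# Matches the CMake status line seen in the build log, e.g.
#   -- Found CUDA: /usr/local/cuda (found version "11.8")
_FOUND_CUDA_RE = re.compile(r'found version "(\d+)\.(\d+)')


def parse_cuda_version(cmake_line: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from a CMake 'Found CUDA' line, or None if absent."""
    match = _FOUND_CUDA_RE.search(cmake_line)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def should_build_machete(cuda_version: Optional[Tuple[int, int]],
                         min_version: Tuple[int, int]) -> bool:
    """Hypothetical gate: build the SM90-only Machete kernels only when the
    toolkit is at least `min_version` (an assumed threshold, passed in)."""
    if cuda_version is None:
        # Unknown toolkit version: be conservative and skip the kernel.
        return False
    # Tuple comparison is lexicographic, so (11, 8) < (12, 0) < (12, 1).
    return cuda_version >= min_version


if __name__ == "__main__":
    line = '-- Found CUDA: /usr/local/cuda (found version "11.8")'
    version = parse_cuda_version(line)
    print(version)                                              # (11, 8)
    print(should_build_machete(version, min_version=(12, 0)))  # False
```

In a real build this check would more likely sit in CMake (for example by comparing `CMAKE_CUDA_COMPILER_VERSION`). The undefined `machete::GemmDispatcher` symbol reported later in this thread suggests that anything referencing the Machete kernels has to be skipped along with them, not just the kernel sources.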

mgoin, Aug 21 '24 20:08

@congcongchen123 apologies for the build issues this PR caused you. Would you mind trying https://github.com/vllm-project/vllm/pull/7757? Sorry, I don't have access to a machine with CUDA 11.8 right now (and don't want to mess up a shared machine).

LucasWilkinson, Aug 21 '24 21:08

@LucasWilkinson, https://github.com/vllm-project/vllm/pull/7757 doesn't work either. With the patch the build now succeeds, but I can't run vLLM anymore; see the error below:

(myenv) aiscuser@node-0:~/vllm/benchmarks$ python benchmark_latency.py
WARNING 08-21 15:48:59 _custom_ops.py:17] Failed to import from vllm._C with ImportError('/home/aiscuser/vllm/vllm/_C.abi3.so: undefined symbol: _ZN7machete14GemmDispatcherIN7cutlass6half_tEhS2_fS2_S2_E19supported_schedulesEv')
Namespace(model='facebook/opt-125m', speculative_model=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, tokenizer=None, quantization=None, tensor_parallel_size=1, input_len=32, output_len=128, batch_size=8, n=1, use_beam_search=False, num_iters_warmup=10, num_iters=30, trust_remote_code=False, max_model_len=None, dtype='auto', enforce_eager=False, kv_cache_dtype='auto', quantization_param_path=None, profile=False, profile_result_dir=None, device='auto', block_size=16, enable_chunked_prefill=False, enable_prefix_caching=False, use_v2_block_manager=False, ray_workers_use_nsight=False, download_dir=None, output_json=None, gpu_memory_utilization=0.9, load_format='auto', distributed_executor_backend=None, otlp_traces_endpoint=None)
INFO 08-21 15:49:03 llm_engine.py:184] Initializing an LLM engine (v0.5.4) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=facebook/opt-125m, use_v2_block_manager=False, enable_prefix_caching=False)
/home/aiscuser/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
INFO 08-21 15:49:04 model_runner.py:878] Starting to load model facebook/opt-125m...
INFO 08-21 15:49:04 weight_utils.py:236] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
/home/aiscuser/vllm/vllm/model_executor/model_loader/weight_utils.py:416: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state = torch.load(bin_file, map_location="cpu")
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  7.88it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  7.86it/s]

INFO 08-21 15:49:05 model_runner.py:889] Loading model weights took 0.2389 GB
INFO 08-21 15:49:06 gpu_executor.py:121] # GPU blocks: 128013, # CPU blocks: 7281
INFO 08-21 15:49:08 model_runner.py:1180] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 08-21 15:49:08 model_runner.py:1184] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
ERROR 08-21 15:49:08 _custom_ops.py:36] Error in calling custom op reshape_and_cache_flash: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache_flash'
ERROR 08-21 15:49:08 _custom_ops.py:36] Possibly you have built or installed an obsolete version of vllm.
ERROR 08-21 15:49:08 _custom_ops.py:36] Please try a clean build and install of vllm,or remove old built files such as vllm/*cpython*.so and build/ .
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/aiscuser/vllm/benchmarks/benchmark_latency.py", line 285, in <module>
[rank0]:     main(args)
[rank0]:   File "/home/aiscuser/vllm/benchmarks/benchmark_latency.py", line 24, in main
[rank0]:     llm = LLM(
[rank0]:   File "/home/aiscuser/vllm/vllm/entrypoints/llm.py", line 175, in __init__
[rank0]:     self.llm_engine = LLMEngine.from_engine_args(
[rank0]:   File "/home/aiscuser/vllm/vllm/engine/llm_engine.py", line 473, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/home/aiscuser/vllm/vllm/engine/llm_engine.py", line 284, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/home/aiscuser/vllm/vllm/engine/llm_engine.py", line 403, in _initialize_kv_caches
[rank0]:     self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/home/aiscuser/vllm/vllm/executor/gpu_executor.py", line 124, in initialize_cache
[rank0]:     self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/home/aiscuser/vllm/vllm/worker/worker.py", line 234, in initialize_cache
[rank0]:     self._warm_up_model()
[rank0]:   File "/home/aiscuser/vllm/vllm/worker/worker.py", line 250, in _warm_up_model
[rank0]:     self.model_runner.capture_model(self.gpu_cache)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/aiscuser/vllm/vllm/worker/model_runner.py", line 1274, in capture_model
[rank0]:     graph_runner.capture(**capture_inputs)
[rank0]:   File "/home/aiscuser/vllm/vllm/worker/model_runner.py", line 1512, in capture
[rank0]:     self.model(
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/aiscuser/vllm/vllm/model_executor/models/opt.py", line 326, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/aiscuser/vllm/vllm/model_executor/models/opt.py", line 291, in forward
[rank0]:     return self.decoder(input_ids,
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/aiscuser/vllm/vllm/model_executor/models/opt.py", line 260, in forward
[rank0]:     hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/aiscuser/vllm/vllm/model_executor/models/opt.py", line 162, in forward
[rank0]:     hidden_states = self.self_attn(hidden_states=hidden_states,
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/aiscuser/vllm/vllm/model_executor/models/opt.py", line 105, in forward
[rank0]:     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/aiscuser/vllm/vllm/attention/layer.py", line 98, in forward
[rank0]:     return self.impl.forward(query,
[rank0]:   File "/home/aiscuser/vllm/vllm/attention/backends/flash_attn.py", line 637, in forward
[rank0]:     ops.reshape_and_cache_flash(
[rank0]:   File "/home/aiscuser/vllm/vllm/_custom_ops.py", line 37, in wrapper
[rank0]:     raise e
[rank0]:   File "/home/aiscuser/vllm/vllm/_custom_ops.py", line 28, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/aiscuser/vllm/vllm/_custom_ops.py", line 522, in reshape_and_cache_flash
[rank0]:     torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
[rank0]:   File "/home/aiscuser/.local/lib/python3.10/site-packages/torch/_ops.py", line 1170, in __getattr__
[rank0]:     raise AttributeError(
[rank0]: AttributeError: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache_flash'
(myenv) aiscuser@node-0:~/vllm/benchmarks$

congcongchen123, Aug 21 '24 22:08

@congcongchen123 Apologies that didn't work; I was being a bit overly optimistic. I spun up a GCP L4 instance with CUDA 11.8 and updated https://github.com/vllm-project/vllm/pull/7757. It should be working now, but please let me know if you run into any issues.

LucasWilkinson, Aug 22 '24 05:08