
AQLM CUDA support

Open jaemzfleming opened this issue 2 years ago • 4 comments

SUMMARY: Supports AQLM compressed inference, see

https://github.com/Vahe1994/AQLM
https://arxiv.org/pdf/2401.06118.pdf

Optimized supported formats are 1x16 and 2x8. Tensor parallelism is supported. Only CUDA kernels are provided. Formats other than 1x16 and 2x8 will run but at lower performance.

Also adds underlying support for all quantization schemes that require a separate fixed size codebook per layer.
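To illustrate the codebook idea in the abstract (a simplified, hypothetical sketch, not the actual CUDA kernel or vLLM code): in additive schemes like AQLM, each group of weights is stored as one or more indices into fixed per-layer codebooks, and dequantization sums the selected codebook entries. The function name and tiny dimensions below are illustrative only; real AQLM uses e.g. 2^16-entry codebooks and group size 8.

```python
def dequantize_group(codes, codebooks):
    """Reconstruct one weight group by summing codebook entries.

    codes: one index per codebook (1x16 -> one 16-bit code,
           2x8 -> two 8-bit codes).
    codebooks: list of codebooks, each a list of equal-length vectors.
    """
    group = [0.0] * len(codebooks[0][0])
    for code, book in zip(codes, codebooks):
        for i, value in enumerate(book[code]):
            group[i] += value
    return group

# Toy example: two codebooks of 4 entries each, group size 2.
books = [
    [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    [[0.5, 0.5], [0.0, 0.0], [0.25, 0.0], [0.0, 0.25]],
]
print(dequantize_group([3, 0], books))  # -> [1.5, 1.5]
```

The 1x16 and 2x8 format names encode exactly this: the number of codebooks times the bit width of each code.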

The only trickiness was that QKVParallelLinear concatenates the Q, K, and V tensors, whose sizes and offsets are determined by the number of heads, kv heads, and tensor parallelism. The corresponding codebooks all need to be present and concatenated for apply_weights. To support this we add the is_metadata attribute which, if present, concatenates the Q, K, and V tensors along the zeroth dimension, using just the size of each loaded tensor.
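A minimal sketch of that loading behavior (hypothetical names and list-of-rows stand-ins, not the actual vLLM weight loader): each codebook shard is appended along the zeroth dimension at an offset determined only by how much has been loaded so far, rather than by head counts.

```python
def load_metadata_shard(param_rows, loaded_rows, offset):
    """Append a codebook shard along dim 0, using only the loaded size.

    param_rows: preallocated destination, a list of rows.
    loaded_rows: the Q, K, or V codebook rows being loaded.
    offset: rows already filled by earlier shards.
    Returns the new offset for the next shard.
    """
    param_rows[offset:offset + len(loaded_rows)] = loaded_rows
    return offset + len(loaded_rows)

# Concatenate hypothetical Q, K, V codebooks of different sizes.
dest = [None] * 6
off = 0
for shard in ([["q0"], ["q1"], ["q2"]], [["k0"], ["k1"]], [["v0"]]):
    off = load_metadata_shard(dest, shard, off)
print(dest)  # -> [['q0'], ['q1'], ['q2'], ['k0'], ['k1'], ['v0']]
```

The point of keying off the loaded tensor's own size is that the loader does not need to recompute the Q/K/V split from head counts and tensor-parallel rank for metadata tensors.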

Here's a benchmark server graph comparing 2-bit 1x16 and 2x8 against FP16, plotting mean TPOT vs queries per second. At low query rates, the 1x16 is 1.36x faster and the 2x8 is 2.12x faster than FP16. By 15 queries per second, the 1x16 is 1.56x slower and the 2x8 is 1.16x slower. So either format is a good choice if memory is limited, especially if you are serving at low QPS. But 2x8 is best if you can afford the slightly lower accuracy.

aqlm_benchmark

Tested on several models:

  • ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf
  • ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf
  • ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf
  • ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf
  • BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf

These were tested on both single and multiple GPUs, with the associated tensor parallelism.

jaemzfleming avatar Mar 08 '24 20:03 jaemzfleming

Great work, @jaemzfleming. It seems the kernels are too inefficient - it takes 10 minutes to load a 1x16 70b on a 3090, and ~4 minutes for 2x8. Have you modified the original kernels at all? If not, I'll take a look and see if any improvements are possible.

AlpinDale avatar Mar 09 '24 23:03 AlpinDale

@AlpinDale we're going to fold in the new dequant kernels so you aren't stuck with gevm all the time during prefill. There does seem to be a really long model loading time, as you said; I'm not sure what the issue could be, but FWIW it is much slower than loading with Transformers. The kernels haven't been modified much from the originals, but we can dig in further after efficient gemm is added.

mgoin avatar Mar 11 '24 15:03 mgoin

Looks good from my testing @jaemzfleming, just some formatting cleanup. Load time is now only a few seconds at most - the issue was the initial model inference for kv cache sizing. I still get `CUDA error: out of memory` when running without enforce_eager=True, but we can look into that now/shortly after landing.

@AlpinDale please try again now that the dequant gemm is in. Loading and prefill time should be at least an order of magnitude faster now.

I added optimized dequantization kernels that are another factor of 6 to 9 (depending on format) faster than the "stock" ones, so load times will be that much faster as well; it should be no problem now.

jaemzfleming avatar Mar 21 '24 18:03 jaemzfleming

@jaemzfleming can you make yapf pass?

robertgshaw2-redhat avatar Mar 24 '24 17:03 robertgshaw2-redhat

What is missing for this PR to go through? :)

remiconnesson avatar Apr 01 '24 21:04 remiconnesson

@remiconnesson it looks like @mgoin already approved the changes. It just needs to be reviewed before it can be automatically merged.

jacobwarren avatar Apr 01 '24 22:04 jacobwarren

Great :) can't wait to try it !

remiconnesson avatar Apr 01 '24 22:04 remiconnesson

Not lying, I have been checking in on this PR periodically for a week now to see if it has merged :D Looking forward to trying 2bit Mixtral on my 3090. thanks so much :)

andysalerno avatar Apr 01 '24 23:04 andysalerno

@andysalerno haha same 😁😁😁

remiconnesson avatar Apr 01 '24 23:04 remiconnesson

Tried running this on my machine. It installed and compiled OK, but when I tried running the server, I got this undefined symbol error. I tried deleting my Python environment, creating a new one from scratch, and reinstalling everything for vLLM, but that didn't work. It did work when I ran it inside the vLLM container, not sure why. Here's the stack trace:

(ve) alyssa@alyssa-ThinkPad-X1-Extreme-Gen-3:~/lm_test/vllm$ python3 -m fastchat.serve.vllm_worker --model-path  alpindale/Mistral-7B-Instruct-v0.2-AQLM-2Bit-1x16 --quantization aqlm
WARNING 04-03 10:12:48 config.py:213] aqlm quantization is not fully optimized yet. The speed can be slower than non-quantized models.
INFO 04-03 10:12:48 llm_engine.py:74] Initializing an LLM engine (v0.4.0.post1) with config: model='alpindale/Mistral-7B-Instruct-v0.2-AQLM-2Bit-1x16', tokenizer='alpindale/Mistral-7B-Instruct-v0.2-AQLM-2Bit-1x16', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=aqlm, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
2024-04-03 10:12:51 | ERROR | stderr | Traceback (most recent call last):
2024-04-03 10:12:51 | ERROR | stderr |   File "<frozen runpy>", line 198, in _run_module_as_main
2024-04-03 10:12:51 | ERROR | stderr |   File "<frozen runpy>", line 88, in _run_code
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/ve/lib/python3.11/site-packages/fastchat/serve/vllm_worker.py", line 290, in <module>
2024-04-03 10:12:51 | ERROR | stderr |     engine = AsyncLLMEngine.from_engine_args(engine_args)
2024-04-03 10:12:51 | ERROR | stderr |              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/engine/async_llm_engine.py", line 348, in from_engine_args
2024-04-03 10:12:51 | ERROR | stderr |     engine = cls(
2024-04-03 10:12:51 | ERROR | stderr |              ^^^^
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/engine/async_llm_engine.py", line 311, in __init__
2024-04-03 10:12:51 | ERROR | stderr |     self.engine = self._init_engine(*args, **kwargs)
2024-04-03 10:12:51 | ERROR | stderr |                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/engine/async_llm_engine.py", line 422, in _init_engine
2024-04-03 10:12:51 | ERROR | stderr |     return engine_class(*args, **kwargs)
2024-04-03 10:12:51 | ERROR | stderr |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/engine/llm_engine.py", line 110, in __init__
2024-04-03 10:12:51 | ERROR | stderr |     self.model_executor = executor_class(model_config, cache_config,
2024-04-03 10:12:51 | ERROR | stderr |                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/executor/gpu_executor.py", line 37, in __init__
2024-04-03 10:12:51 | ERROR | stderr |     self._init_worker()
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/executor/gpu_executor.py", line 45, in _init_worker
2024-04-03 10:12:51 | ERROR | stderr |     from vllm.worker.worker import Worker
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/worker/worker.py", line 21, in <module>
2024-04-03 10:12:51 | ERROR | stderr |     from vllm.worker.model_runner import ModelRunner
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/worker/model_runner.py", line 17, in <module>
2024-04-03 10:12:51 | ERROR | stderr |     from vllm.model_executor.model_loader import get_model
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/model_executor/model_loader.py", line 10, in <module>
2024-04-03 10:12:51 | ERROR | stderr |     from vllm.model_executor.models.llava import LlavaForConditionalGeneration
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/model_executor/models/llava.py", line 11, in <module>
2024-04-03 10:12:51 | ERROR | stderr |     from vllm.model_executor.layers.activation import get_act_fn
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/model_executor/layers/activation.py", line 9, in <module>
2024-04-03 10:12:51 | ERROR | stderr |     from vllm._C import ops
2024-04-03 10:12:51 | ERROR | stderr | ImportError: /home/alyssa/lm_test/vllm/vllm/_C.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN8pybind116detail11type_casterIN2at6TensorEvE4loadENS_6handleEb

rationalism avatar Apr 05 '24 19:04 rationalism

@rationalism

That's weird, there seem to be references to llava and an unknown symbol at the bottom of the traceback:

2024-04-03 10:12:51 | ERROR | stderr |     from vllm.model_executor.models.llava import LlavaForConditionalGeneration
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/model_executor/models/llava.py", line 11, in <module>
2024-04-03 10:12:51 | ERROR | stderr |     from vllm.model_executor.layers.activation import get_act_fn
2024-04-03 10:12:51 | ERROR | stderr |   File "/home/alyssa/lm_test/vllm/vllm/model_executor/layers/activation.py", line 9, in <module>
2024-04-03 10:12:51 | ERROR | stderr |     from vllm._C import ops
2024-04-03 10:12:51 | ERROR | stderr | ImportError: /home/alyssa/lm_test/vllm/vllm/_C.cpython-311-x86_64-linux-gnu.so: undefined symbol:

Could it be an interaction with another unfinished unit of work or something?

remiconnesson avatar Apr 07 '24 11:04 remiconnesson

This PR has been freshly merged with main as of today (April 8th) and I tested it with a Mixtral model

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_id = "ISTA-DASLab/Mixtral-8x7B-Instruct-v0_1-AQLM-2Bit-1x16-hf"
model = LLM(model_id, enforce_eager=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
sampling_params = SamplingParams(max_tokens=300, temperature=0.8, top_p=0.95)

messages = [
    {"role": "user", "content": "What is synthetic data in machine learning?"},
]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = model.generate(formatted_prompt, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)

Output:

Synthetic data is data that is artificially generated rather than being collected from real-world sources. In machine learning, synthetic data can be used to augment existing datasets, or to create entirely new datasets, in order to increase the size and diversity of the data available for training models.

Synthetic data can be generated using a variety of techniques, such as:

* Generative models, such as generative adversarial networks (GANs), which can learn to generate new data that is similar to the training data.
* Data augmentation, which involves applying transformations to the existing data, such as rotating images or changing the brightness or contrast, in order to create new data.
* Simulation, which involves creating a model of a system or process and using it to generate data that reflects the behavior of the system.

Synthetic data can be useful in situations where it is difficult or expensive to collect real-world data, or where the data is biased or incomplete. However, it is important to be careful when using synthetic data, as it may not accurately reflect the complexities and variations of the real world. It is also important to validate the synthetic data to ensure that it is representative of the desired population and that it does not introduce any biases or errors.

Confirmed that "ISTA-DASLab/Mixtral-8x7B-Instruct-v0_1-AQLM-2Bit-1x16-hf" loads in just 12.2 GB of memory, so we can see the memory benefits:

INFO 04-08 19:33:44 model_runner.py:104] Loading model weights took 12.2120 GB
INFO 04-08 19:34:12 gpu_executor.py:99] # GPU blocks: 13244, # CPU blocks: 2048

mgoin avatar Apr 08 '24 19:04 mgoin

Hello, I'm trying to run it but there seems to be an issue with

File "/workspace/nm-vllm/vllm/model_executor/weight_utils.py", line 63, in get_lock
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
TypeError: BaseFileLock.__init__() got an unexpected keyword argument 'mode'

https://github.com/neuralmagic/nm-vllm/blob/22f7faeee16f63548b33ad6ebcc78e256de93524/vllm/model_executor/weight_utils.py#L62-L64

    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
                             mode=0o666)

Removing mode=0o666 seems to clear the problem
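A defensive workaround, sketched below under the assumption that the installed filelock release predates the `mode` keyword (`make_lock` is a hypothetical helper, not vLLM code), is to pass `mode` only when the installed FileLock class accepts it:

```python
import inspect

def make_lock(lock_cls, path, mode=0o666):
    """Create a file lock, passing `mode` only if this version supports it.

    mode 0o666 lets the lock be shared across users on newer filelock
    releases; older releases reject the keyword.
    """
    params = inspect.signature(lock_cls.__init__).parameters
    if "mode" in params:
        return lock_cls(path, mode=mode)
    return lock_cls(path)

# Demo with stand-in classes instead of the real filelock package.
class OldLock:
    def __init__(self, path):
        self.path, self.mode = path, None

class NewLock:
    def __init__(self, path, mode=0o644):
        self.path, self.mode = path, mode

print(make_lock(OldLock, "/tmp/x.lock").mode)       # -> None
print(oct(make_lock(NewLock, "/tmp/x.lock").mode))  # -> 0o666
```

Upgrading the filelock package is of course the simpler fix if the environment allows it.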


Traceback

root@86c9ebba321d:/workspace/nm-vllm/examples# python aqlm_example.py 
config.json: 100%|███████████████████████████████████████████████████████████████████| 968/968 [00:00<00:00, 9.93MB/s]
WARNING 04-14 02:07:14 config.py:222] aqlm quantization is not fully optimized yet. The speed can be slower than non-quantized models.
INFO 04-14 02:07:14 llm_engine.py:81] Initializing an LLM engine (v0.4.0.post1) with config: model='ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf', speculative_config=None, tokenizer='ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=True, quantization=aqlm, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, seed=0)
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████| 776/776 [00:00<00:00, 7.82MB/s]
tokenizer.model: 100%|█████████████████████████████████████████████████████████████| 500k/500k [00:00<00:00, 3.24MB/s]
tokenizer.json: 100%|████████████████████████████████████████████████████████████| 1.84M/1.84M [00:00<00:00, 9.20MB/s]
special_tokens_map.json: 100%|███████████████████████████████████████████████████████| 414/414 [00:00<00:00, 4.37MB/s]
INFO 04-14 02:07:16 pynccl.py:58] Loading nccl from library /root/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 04-14 02:07:16 selector.py:77] Cannot use FlashAttention backend because the flash_attn package is not found. Please install it for better performance.
INFO 04-14 02:07:16 selector.py:33] Using XFormers backend.
INFO 04-14 02:07:18 weight_utils.py:194] Using model weights format ['*.safetensors']
Exception ignored in: <function BaseFileLock.__del__ at 0x7efd01ce3f40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/filelock/_api.py", line 240, in __del__
    self.release(force=True)
  File "/usr/local/lib/python3.10/dist-packages/filelock/_api.py", line 201, in release
    with self._thread_lock:
AttributeError: 'UnixFileLock' object has no attribute '_thread_lock'
Traceback (most recent call last):
  File "/workspace/nm-vllm/examples/aqlm_example.py", line 46, in <module>
    main()
  File "/workspace/nm-vllm/examples/aqlm_example.py", line 36, in main
    model = LLM(args.model if args.model is not None else models[args.choice],
  File "/workspace/nm-vllm/vllm/entrypoints/llm.py", line 112, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
  File "/workspace/nm-vllm/vllm/engine/llm_engine.py", line 231, in from_engine_args
    engine = cls(
  File "/workspace/nm-vllm/vllm/engine/llm_engine.py", line 119, in __init__
    self.model_executor = executor_class(
  File "/workspace/nm-vllm/vllm/executor/gpu_executor.py", line 41, in __init__
    self._init_worker()
  File "/workspace/nm-vllm/vllm/executor/gpu_executor.py", line 67, in _init_worker
    self.driver_worker.load_model()
  File "/workspace/nm-vllm/vllm/worker/worker.py", line 108, in load_model
    self.model_runner.load_model()
  File "/workspace/nm-vllm/vllm/worker/model_runner.py", line 155, in load_model
    self.model = get_model(
  File "/workspace/nm-vllm/vllm/model_executor/model_loader.py", line 101, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
  File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 393, in load_weights
    for name, loaded_weight in hf_model_weights_iterator(
  File "/workspace/nm-vllm/vllm/model_executor/weight_utils.py", line 241, in hf_model_weights_iterator
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
  File "/workspace/nm-vllm/vllm/model_executor/weight_utils.py", line 197, in prepare_hf_model_weights
    with get_lock(model_name_or_path, cache_dir):
  File "/workspace/nm-vllm/vllm/model_executor/weight_utils.py", line 63, in get_lock
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
TypeError: BaseFileLock.__init__() got an unexpected keyword argument 'mode'

remiconnesson avatar Apr 14 '24 02:04 remiconnesson

Hey @remiconnesson that filelock issue seems to be unrelated to this PR and addressed on main. Please try the updated version of this branch or main

mgoin avatar Apr 16 '24 14:04 mgoin

Sorry, I should have clarified the comment about namespaces.

The exposed torch binding should not be in a cpp namespace: https://github.com/vllm-project/vllm/blob/705578ae14b648782a8a321dd0903c163bd77375/csrc/quantization/awq/gemm_kernels.cu#L389-L394

All other helpers should be included in namespace vllm: https://github.com/vllm-project/vllm/blob/705578ae14b648782a8a321dd0903c163bd77375/csrc/quantization/awq/gemm_kernels.cu#L19-L20

simon-mo avatar Apr 18 '24 18:04 simon-mo

@simon-mo thanks for the clarification, I was thinking about the wrong namespace! I got rid of the cpp file and wrapped everything except the two external functions in vllm::aqlm::

mgoin avatar Apr 18 '24 19:04 mgoin

Thanks for doing all this work @mgoin, much appreciated.

jaemzfleming avatar Apr 22 '24 19:04 jaemzfleming