
[Bug] [OpenCL] Error on OpenCL when head_dim not divisible by 32

Status: Open. sbwww opened this issue 1 year ago • 2 comments

🐛 Bug

The paged KV cache problem with certain head_dim values has been fixed for the CUDA target, but the OpenCL target still fails after #1889.

I'm trying to compile a compressed Llama-like model with hidden_size=1536 and head_dim=48 for OpenCL. After prefill, the first decoding step fails with the error "Cannot sample from the given probability distribution due to unknown reason". The probabilities after softmax are all around 1e-37. With exactly the same parameter matrices but head_dim=64, there are no issues. head_dim=48 also works fine on CUDA.
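Why near-zero probabilities make sampling impossible can be seen with a plain inverse-CDF sampler. This is a minimal sketch, assuming the sampler draws u uniformly from [0, 1) and walks the cumulative sum of the probabilities; MLC-LLM's actual sampling kernel may work differently. `sample_from_probs`, `is_valid_distribution`, and the 32000-entry vocabulary are hypothetical.

```python
import math
import random


def sample_from_probs(probs, u=None):
    """Inverse-CDF sampling: return the first index whose cumulative
    probability exceeds a uniform draw u in [0, 1)."""
    if u is None:
        u = random.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if cumulative > u:
            return i
    # Guard against float rounding when the total is ~1: fall back to the last index.
    if probs and abs(cumulative - 1.0) < 1e-6:
        return len(probs) - 1
    # The draw was never reached: the distribution does not sum to ~1.
    raise ValueError("Cannot sample from the given probability distribution")


def is_valid_distribution(probs, tol=1e-3):
    """Check that probabilities are finite, non-negative, and sum to ~1."""
    return all(math.isfinite(p) and p >= 0 for p in probs) and abs(sum(probs) - 1.0) < tol


healthy = [0.1, 0.6, 0.3]
degenerate = [1e-37] * 32000  # every entry around 1e-37, as in the report (hypothetical vocab size)
```

With every probability around 1e-37, the cumulative sum only reaches about 3.2e-33, so it never exceeds u and no token can be selected. A check like `is_valid_distribution` on the decode output would flag the broken logits before sampling is attempted.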

To Reproduce

Expected behavior

Decoding with head_dim=48 on OpenCL should work; this case was expected to be fixed by #1889.

Environment

  • Platform (e.g. WebGPU/Vulkan/IOS/Android/CUDA): Android
  • Operating system (e.g. Ubuntu/Windows/MacOS/...):
  • Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...): Qualcomm Snapdragon 8 Gen 3
  • How you installed MLC-LLM (conda, source): source
  • How you installed TVM-Unity (pip, source): source
  • Python version (e.g. 3.10): 3.10
  • GPU driver version (if applicable):
  • CUDA/cuDNN version (if applicable):
  • TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models): fbfa92658568428b27c6ee5762ab7fe2f7c0b415
  • Any other relevant information:

Additional context

sbwww, Mar 20 '24

Hi @sbwww, thank you for reporting this! Could you please provide more information about the model you are trying to run?

Kartik14, Mar 25 '24

Sorry for the late response. Could you randomly initialize a Llama MHA model with hidden_size=1536 and num_attention_heads=32 (-> head_dim=48)?
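A minimal sketch of such a configuration, assuming the Hugging Face Llama config.json field names. Only hidden_size and num_attention_heads come from this thread; num_key_value_heads is set equal to num_attention_heads because the model is MHA, and every other value is a hypothetical placeholder. Whether MLC-LLM needs additional fields is not verified here.

```python
import json

# Hypothetical minimal Llama-style config for reproducing the issue.
# Only hidden_size and num_attention_heads come from the report; the rest are placeholders.
NUM_HEADS = 32  # 1536 / 32 -> head_dim = 48 (fails on OpenCL); 24 -> head_dim = 64 (works)

config = {
    "architectures": ["LlamaForCausalLM"],
    "model_type": "llama",
    "hidden_size": 1536,
    "num_attention_heads": NUM_HEADS,
    "num_key_value_heads": NUM_HEADS,  # MHA: one KV head per query head
    "intermediate_size": 4096,         # placeholder
    "num_hidden_layers": 2,            # placeholder
    "vocab_size": 32000,               # placeholder
    "max_position_embeddings": 2048,   # placeholder
    "rms_norm_eps": 1e-5,              # placeholder
}

# Print the config so it can be redirected to config.json.
print(json.dumps(config, indent=2))
```

Redirecting the output to config.json next to randomly initialized weights (not covered here) gives the failing case; changing NUM_HEADS to 24 gives the working head_dim=64 case with the same hidden_size.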

I tried setting num_attention_heads=24 (-> head_dim=64), and it works fine, so I suspect head_dim is the key difference.
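The head_dim arithmetic behind the two configurations in this thread is sketched below. It only checks the divisibility-by-32 condition named in the issue title; the actual tiling or alignment requirement inside the OpenCL paged-attention kernel is not confirmed here. `head_dim` is a hypothetical helper.

```python
def head_dim(hidden_size: int, num_attention_heads: int) -> int:
    """Per-head dimension of a multi-head attention layer."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    return hidden_size // num_attention_heads


# Both configurations from this thread use hidden_size=1536.
for heads in (32, 24):
    d = head_dim(1536, heads)
    print(f"num_attention_heads={heads}: head_dim={d}, head_dim % 32 = {d % 32}")
# num_attention_heads=32: head_dim=48, head_dim % 32 = 16
# num_attention_heads=24: head_dim=64, head_dim % 32 = 0
```

Only the failing configuration leaves a remainder when head_dim is divided by 32, which matches the pattern reported for OpenCL.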

sbwww, Apr 02 '24