Migrate Qwen3-Omni to MLX
## Add Support for Qwen/Qwen3-Omni-30B-A3B-Instruct
This PR adds support for the Qwen3-Omni-30B-A3B-Instruct model, enabling multi-modal understanding and audio generation capabilities. The model supports text, image, and audio inputs, and can generate both text responses and audio output.
### Weight Conversion
Convert HuggingFace weights to MLX format with quantization:
```bash
python -m mlx_vlm convert \
  --hf-path Qwen/Qwen3-Omni-30B-A3B-Instruct \
  --mlx-path ./mlx_qwen3_omni_4bit \
  --quantize \
  --q-bits 4 \
  --q-group-size 64
```
### Usage Example
See the demo script for an example demonstrating audio input and audio generation.
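For a quick smoke test without the demo script, a minimal sketch using the generic `mlx_vlm` load/generate API is below. The model path comes from the conversion step above; the image path and prompt are placeholders, and the audio-input/audio-output entry points added by this PR are not shown (see the demo script for those).

```python
# Minimal sketch using the generic mlx_vlm API with the converted weights.
# The image path and prompt are placeholders; audio in/out is covered by
# the PR's demo script and is not shown here.
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "./mlx_qwen3_omni_4bit"
model, processor = load(model_path)
config = load_config(model_path)

prompt = "Describe this image."
image = ["path/to/image.png"]  # placeholder input

formatted_prompt = apply_chat_template(processor, config, prompt, num_images=len(image))
output = generate(model, processor, formatted_prompt, image, verbose=False)
print(output)
```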
### Performance Notes
During audio-generation inference without visual inputs, memory usage is approximately 20-30 GB. Testing was performed on an Apple M3 Max with 36 GB of unified memory. When visual inputs are included, inference slows down and memory usage increases accordingly.
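To reproduce the memory figures, MLX's peak-memory counter can be read around a generation call; a rough sketch (on older MLX releases the counters live under `mx.metal.*` instead):

```python
import mlx.core as mx

mx.reset_peak_memory()
# ... run a generation call here, e.g. the usage example above ...
print(f"Peak MLX memory: {mx.get_peak_memory() / 1024**3:.1f} GB")
```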
~~### Known Issues~~
~~Long audio generation may produce low-quality output; short audio works fine :)~~
Update: skipping quantization for `code_predictor` fixes the long-audio quality degradation.
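For anyone re-quantizing locally, one way to keep `code_predictor` in full precision is a quantization predicate. The sketch below uses `mlx.nn.quantize`'s `class_predicate` hook directly; it is only an illustration of the idea, not necessarily how this PR wires it through the convert script.

```python
import mlx.nn as nn

def skip_code_predictor(path: str, module: nn.Module):
    # Leave the audio code_predictor weights unquantized; quantize
    # everything else that supports it (mirrors the default behavior).
    if "code_predictor" in path:
        return False
    return hasattr(module, "to_quantized")

# `model` is the loaded Qwen3-Omni model instance.
nn.quantize(model, group_size=64, bits=4, class_predicate=skip_code_predictor)
```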
Hey @hellopahe
Awesome contribution!
Can't wait for the merge! I ran a test on the `hellopahe:qwenomni-dev` branch and ran into an issue:
```
mlx-vlm/mlx_vlm/models/qwen3_omni_moe/vision.py", line 158, in __call__
    output = output.reshape(seq_length, -1)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: [reshape] Cannot reshape array of size 22090752 into shape (16184,1364).
```
Here is a possible fix; please feel free to incorporate it if it helps:
```diff
--- a/mlx_vlm/models/qwen3_omni_moe/vision.py
+++ b/mlx_vlm/models/qwen3_omni_moe/vision.py
@@ -144,8 +144,10 @@ class Attention(nn.Module):
         k = k.transpose(0, 2, 1, 3)
         v = v.transpose(0, 2, 1, 3)
-        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
-        splits = [mx.split(tensor, lengths.tolist(), axis=2) for tensor in (q, k, v)]
+        splits = [
+            mx.split(tensor, cu_seqlens[1:-1].tolist(), axis=2) for tensor in (q, k, v)
+        ]
+
         attn_outputs = []
         for q, k, v in zip(*splits):
             output = mx.fast.scaled_dot_product_attention(
@@ -153,9 +155,8 @@ class Attention(nn.Module):
             )
             attn_outputs.append(output)
-        output = mx.concat(attn_outputs, axis=2)
-        output = output.transpose(0, 2, 1, 3)
-        output = output.reshape(seq_length, -1)
+        output = mx.concatenate(attn_outputs, axis=2)
+        output = output.transpose(0, 2, 1, 3).reshape(seq_length, -1)
         return self.proj(output)
```
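For context on why this works: `mx.split` (like `numpy.split`) interprets a list argument as split *indices* along the axis, not as segment lengths, so the cumulative boundaries `cu_seqlens[1:-1]` are what it expects. A tiny illustration:

```python
import mlx.core as mx

x = mx.arange(10)
# A list gives split indices, yielding x[:3], x[3:7], x[7:].
parts = mx.split(x, [3, 7], axis=0)
print([p.tolist() for p in parts])  # [[0, 1, 2], [3, 4, 5, 6], [7, 8, 9]]
# Passing per-segment lengths instead would be treated as indices and
# produce the wrong segments, which is what triggered the reshape error above.
```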
Thanks @ziya32!
It's gonna take a bit of time because I'm making omni a first-class citizen in mlx-vlm.