mlx-audio
Fix voxtral segments
closes #215
This may be unrelated, but I still see the following error:
uvx --from 'git+https://github.com/Blaizzy/mlx-audio@pc/fix-voxtral-segments' python -m mlx_audio.stt.generate \
--model "mlx-community/Voxtral-Mini-3B-2507-bf16" \
--audio ~/Downloads/audio.mp3 \
--output ~/Downloads/test-transcript.json \
--format json
Fetching 8 files: 100%|██████████████████████████████████████████| 8/8 [00:00<00:00, 101680.10it/s]
==========
Audio path: /Users/drewbitt/Downloads/audio.mp3
Output path: /Users/drewbitt/Downloads/test-transcript.json
Format: json
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_audio/stt/generate.py", line 258, in <module>
main()
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_audio/stt/generate.py", line 247, in main
generate(
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_audio/stt/generate.py", line 219, in generate
segments = model.generate(
^^^^^^^^^^^^^^^
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_audio/stt/models/voxtral/voxtral.py", line 405, in generate
for token, _ in self.stream_generate(
^^^^^^^^^^^^^^^^^^^^^
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_audio/stt/models/voxtral/voxtral.py", line 342, in stream_generate
for n, (token, logprobs) in enumerate(
^^^^^^^^^^
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_lm/generate.py", line 422, in generate_step
_model_call(
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_lm/generate.py", line 380, in _model_call
return model(
^^^^^^
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_audio/stt/models/voxtral/voxtral.py", line 191, in __call__
out = self.model(inputs, cache=cache, input_embeddings=input_embeddings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_lm/models/llama.py", line 193, in __call__
h = layer(h, mask, cache=cache)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_lm/models/llama.py", line 142, in __call__
r = self.self_attn(self.input_layernorm(x), mask, cache)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/drewbitt/.cache/uv/archive-v0/swwNn0ZkRE6fXlFZQdcb1/lib/python3.12/site-packages/mlx_lm/models/llama.py", line 83, in __call__
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: [reshape] Cannot infer the shape of an empty array
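The error is raised by `queries.reshape(B, L, self.n_heads, -1)` in `mlx_lm/models/llama.py`. One plausible reading, which is an assumption and not confirmed in this thread, is that a zero-length sequence (L == 0) reached the attention layer. An array with zero elements gives reshape nothing to divide, so the `-1` dimension cannot be inferred. The sketch below reproduces that behaviour with NumPy as a stand-in for MLX. `split_heads` and `infer_fails_on_empty` are hypothetical helpers for illustration only. They are not part of mlx-audio or mlx-lm.

```python
import numpy as np


def infer_fails_on_empty(x: np.ndarray, n_heads: int) -> bool:
    """Return True if reshape(B, L, n_heads, -1) raises, as in the traceback."""
    B, L, _ = x.shape
    try:
        x.reshape(B, L, n_heads, -1)
        return False
    except ValueError:
        # With zero elements, every choice for -1 gives size 0, so the
        # dimension is ambiguous and NumPy refuses to infer it.
        return True


def split_heads(x: np.ndarray, n_heads: int) -> np.ndarray:
    """Reshape (B, L, D) -> (B, n_heads, L, D // n_heads) without inferring -1."""
    B, L, D = x.shape
    # An explicit head dimension keeps the reshape well defined even when L == 0.
    head_dim = D // n_heads
    return x.reshape(B, L, n_heads, head_dim).transpose(0, 2, 1, 3)


empty = np.zeros((1, 0, 64))  # batch of 1, zero tokens, hidden size 64
print(infer_fails_on_empty(empty, 8))
print(split_heads(empty, 8).shape)
```

Passing an explicit head dimension only hides the symptom. If this reading is right, the actual fix would be to stop an empty segment from being sent to the model in the first place.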
@drewbitt This is related to #271.