
KeyError: 'image_token_index' when training with LoRA on Qwen2.5-VL-3B-Instruct-bf16

Open mathav95raj opened this issue 9 months ago • 6 comments

Description

When attempting to fine-tune the Qwen2.5-VL-3B-Instruct-bf16 model using LoRA, I'm encountering a KeyError for 'image_token_index' during training.

Environment

  • mlx-vlm version: 0.1.26
  • Model: mlx-community/Qwen2.5-VL-3B-Instruct-bf16
  • OS: macOS

Dataset Structure

The dataset is in Hugging Face format with the following structure:

  • Each split (train.json, validation.json, test.json) contains entries with:
    • messages: A list of message objects in the chat format
    • images: A list of image paths relative to the dataset directory

Example of a dataset entry:

{
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's in this image? <image>"},
    {"role": "assistant", "content": "The image shows a cat sitting on a windowsill."}
  ],
  "images": ["images/cat.jpg"]
}
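Before blaming the trainer, it can help to rule out dataset problems. The following is a minimal sketch, not part of mlx-vlm. It assumes the layout described above: a `messages` list, an `images` list of paths relative to the dataset directory, and one image slot per listed image. It parses one entry, checks it, and resolves the image paths. `load_entry`, `count_image_slots`, and `ALLOWED_ROLES` are hypothetical names.

```python
import json
from pathlib import Path
from typing import Any

# Roles used in the example entry above (assumed to be the only valid ones).
ALLOWED_ROLES = {"system", "user", "assistant"}


def count_image_slots(content: Any) -> int:
    """Count image slots in either a plain string or a list of typed parts."""
    if isinstance(content, str):
        return content.count("<image>")
    return sum(1 for part in content if part.get("type") == "image")


def load_entry(raw: str, dataset_dir: str) -> dict:
    """Parse one JSON entry, sanity-check it, and resolve its image paths."""
    entry = json.loads(raw)
    if not isinstance(entry.get("messages"), list):
        raise ValueError("entry is missing a 'messages' list")
    if not isinstance(entry.get("images"), list):
        raise ValueError("entry is missing an 'images' list")
    for msg in entry["messages"]:
        if msg.get("role") not in ALLOWED_ROLES:
            raise ValueError(f"unexpected role: {msg.get('role')!r}")
    # Assumption: every path in "images" matches exactly one image slot.
    slots = sum(count_image_slots(m["content"]) for m in entry["messages"])
    if slots != len(entry["images"]):
        raise ValueError(f"{slots} image slots but {len(entry['images'])} images")
    # Image paths are stored relative to the dataset directory.
    root = Path(dataset_dir)
    entry["images"] = [str(root / p) for p in entry["images"]]
    return entry
```

A mismatch between placeholders and images surfaces here as a clear `ValueError` instead of failing later inside the trainer.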

Code Modifications

To accommodate loading the dataset from a local directory, I modified the load_dataset call in lora.py to:

dataset = load_dataset(
    'json',
    data_files={
        'train': f'{args.dataset}/train.json',
        'validation': f'{args.dataset}/validation.json',
        'test': f'{args.dataset}/test.json'
    }
)

This allows loading the dataset from local JSON files instead of using a dataset from the Hugging Face Hub.

Steps to Reproduce

1. Created a dataset in HF format with train.json, validation.json, and test.json.

2. Ran the following command:

python -m mlx_vlm.lora \
  --model-path "mlx-community/Qwen2.5-VL-3B-Instruct-bf16" \
  --dataset "/path/to/hf_dataset" \
  --output-path "/path/to/output" \
  --batch-size 4 \
  --epochs 3 \
  --learning-rate 1e-4 \
  --lora-rank 16 \
  --lora-alpha 32

Error Message

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/path/to/mlx-vlm-0.1.26/mlx_vlm/lora.py", line 203, in <module>
    main(args)
  File "/path/to/mlx-vlm-0.1.26/mlx_vlm/lora.py", line 118, in main
    dataset[i * args.batch_size : (i + 1) * args.batch_size]
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/mlx-vlm-0.1.26/mlx_vlm/trainer/trainer.py", line 89, in __getitem__
    image_token_index = self.config["image_token_index"]
                        ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
KeyError: 'image_token_index'

The error occurs when the trainer tries to access the 'image_token_index' key in the model config, but this key doesn't exist in the Qwen2.5-VL-3B-Instruct-bf16 model configuration.

It seems that the trainer.py code expects this key to be present in the model config, but different vision-language models might use different keys or approaches for handling image tokens.
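To confirm which key a given model uses, you can read its `config.json` directly. This is a minimal sketch, not mlx-vlm's code. `find_image_token_id` and `load_config` are hypothetical helpers. The candidate key list is an assumption: my understanding is that LLaVA-style configs use `image_token_index`, while Qwen2-VL / Qwen2.5-VL configs name the field `image_token_id`. Verify this against the downloaded model's `config.json`.

```python
import json
from pathlib import Path
from typing import Any, Mapping, Optional

# Assumed key names for the image placeholder token id across VLM configs.
CANDIDATE_KEYS = ("image_token_index", "image_token_id")


def find_image_token_id(config: Mapping[str, Any]) -> Optional[int]:
    """Return the image token id from the first matching key, or None."""
    for key in CANDIDATE_KEYS:
        if key in config:
            return int(config[key])
    return None


def load_config(model_dir: str) -> dict:
    """Load config.json from a local model directory."""
    with open(Path(model_dir) / "config.json", encoding="utf-8") as f:
        return json.load(f)
```

A lookup with a fallback like this avoids the hard `KeyError`. Whether the rest of the trainer then handles Qwen's image tokens correctly is a separate question.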

mathav95raj · May 16 '25 14:05

Hey @mathav95raj

Could you change the image token from <image> to the one Qwen 2.5 VL uses?

You can find examples here: https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct
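One way to act on this suggestion is to rewrite each message so that the `<image>` placeholder becomes a typed content part, mirroring the `{"type": "image", ...}` / `{"type": "text", ...}` message structure shown on the Qwen2.5-VL model card. As I understand it, Qwen's chat template then expands each image part into its own vision tokens. This is a minimal sketch only: `to_structured_content` and `convert_entry` are hypothetical helpers, and whether mlx-vlm 0.1.26's trainer accepts this exact format is an assumption to verify.

```python
from typing import Any

IMAGE_PLACEHOLDER = "<image>"


def to_structured_content(text: str, image_paths: list) -> list:
    """Split a string with <image> placeholders into typed content parts."""
    if text.count(IMAGE_PLACEHOLDER) != len(image_paths):
        raise ValueError("number of <image> placeholders and images differ")
    parts: list = []
    chunks = text.split(IMAGE_PLACEHOLDER)
    for i, chunk in enumerate(chunks):
        if chunk.strip():
            parts.append({"type": "text", "text": chunk.strip()})
        # A placeholder sat between chunk i and chunk i + 1.
        if i < len(chunks) - 1:
            parts.append({"type": "image", "image": image_paths[i]})
    return parts


def convert_entry(entry: dict) -> dict:
    """Convert messages that contain <image> placeholders; leave others unchanged."""
    images = entry.get("images", [])
    cursor = 0  # index of the next unused image path
    messages = []
    for msg in entry["messages"]:
        content: Any = msg["content"]
        if isinstance(content, str) and IMAGE_PLACEHOLDER in content:
            n = content.count(IMAGE_PLACEHOLDER)
            content = to_structured_content(content, images[cursor:cursor + n])
            cursor += n
        messages.append({"role": msg["role"], "content": content})
    return {**entry, "messages": messages}
```

Applied to the example entry above, only the user message changes. It becomes a text part followed by an image part, and the system and assistant strings are kept as they were.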

Let me know if that fixes the issue. Otherwise, you could try our dev branch for trainer v2.0 (see #261).

Blaizzy · May 16 '25 21:05

Hi there,

I'm getting the same error when fine-tuning with 0.1.27 on a Mac M4, using both Qwen2-VL-2B-Instruct-4bit and mlx-community/Qwen2-VL-2B-Instruct-bf16. I also tried meta-llama/Llama-3.2-11B-Vision-Instruct and got a different error there (see below), though I'm not sure it's related.

I'm using the following data format for the training and validation datasets:

{
    "images": img_path,
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image",
                    "image": f"file://{img_path}"
                }
            ]
        },
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "..."
                }
            ]
        }
    ]
}

Llama 3.2 error:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/lora.py", line 194, in <module>
    main(args)
    ~~~~^^^^^^
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/lora.py", line 108, in main
    loss = trainer.train_step(
        dataset[i * args.batch_size : (i + 1) * args.batch_size]
    )
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/trainer/trainer.py", line 265, in train_step
    loss, grads = loss_and_grad_fn(self.model, batch)
                  ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/.../.venv/lib/python3.13/site-packages/mlx/nn/utils.py", line 35, in wrapped_value_grad_fn
    value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
                  ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../.venv/lib/python3.13/site-packages/mlx/nn/utils.py", line 29, in inner_fn
    return fn(*args, **kwargs)
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/trainer/trainer.py", line 230, in loss_fn
    outputs = model(input_ids, pixel_values, attention_mask, **kwargs)
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/models/mllama/mllama.py", line 116, in __call__
    outputs = self.language_model(
        input_ids=input_ids,
    ...<5 lines>...
        cache=cache,
    )
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/models/mllama/language.py", line 378, in __call__
    hidden_states = self.model(
        input_ids=input_ids,
    ...<5 lines>...
        cache=cache,
    )
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/models/mllama/language.py", line 348, in __call__
    layer_outputs = decoder_layer(
        hidden_states,
        mask=mask,
        cache=c,
    )
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/models/mllama/language.py", line 229, in __call__
    hidden_states = self.self_attn(
        x=hidden_states,
        mask=mask,
        cache=cache,
    )
  File "/.../.venv/lib/python3.13/site-packages/mlx_vlm/models/mllama/language.py", line 183, in __call__
    attn_output = mx.fast.scaled_dot_product_attention(
        query_states, key_states, value_states, scale=self.scale, mask=mask
    )
ValueError: [broadcast_shapes] Shapes (1,344) and (1,32,343,343) cannot be broadcast.
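The two shapes in this error differ by one in their last dimension. The mask covers 344 positions, but the attention scores cover 343 × 343. Under NumPy-style broadcasting, which MLX follows as far as I know, trailing dimensions must be equal or 1, so 344 vs 343 fails. The minimal pure-Python sketch below reproduces that rule so the two shapes can be checked in isolation. `broadcast_shapes` here is a hypothetical stand-in, not the MLX function. A reasonable next step, which is an assumption rather than a confirmed fix, is to check whether the batch's attention mask length matches the `input_ids` length before the model call.

```python
from typing import Sequence, Tuple


def broadcast_shapes(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Combine two shapes with NumPy-style broadcasting; raise if incompatible."""
    result = []
    # Compare dimensions right to left, treating missing leading dims as 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"Shapes {tuple(a)} and {tuple(b)} cannot be broadcast.")
        # A size-1 dimension stretches to match the other one.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))
```

With a 343-long mask, the shapes combine to `(1, 32, 343, 343)`. The 344-long mask from the traceback raises, which matches the reported error.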

Any suggestions would be greatly appreciated.

vvvlad-com · Jun 22 '25 03:06


I also encountered this problem. Have you solved it?

2657666247 · Jun 30 '25 01:06

Not yet, unfortunately. How about you?

vvvlad-com · Jul 04 '25 01:07

In fact, I did not encounter this problem in versions 0.1.26 and below; it only appeared in 0.1.27 and above. Maybe you can try an earlier version.

2657666247 · Jul 04 '25 01:07

Hey guys,

We are refactoring the entire training pipeline.

Check out #261 by @Goekdeniz-Guelmez

Blaizzy · Jul 05 '25 17:07