
Overhaul and adding full-weight fine-tuning (quantized training, GRPO, DPO, ORPO)

Goekdeniz-Guelmez opened this issue 11 months ago • 35 comments

Goekdeniz-Guelmez avatar Mar 20 '25 21:03 Goekdeniz-Guelmez

Is there a good dataset I can use here?

Goekdeniz-Guelmez avatar Mar 20 '25 21:03 Goekdeniz-Guelmez

> Is there a good dataset I can use here?

You can use this one: https://huggingface.co/datasets/5CD-AI/LLaVA-CoT-o1-Instruct

Blaizzy avatar Mar 20 '25 21:03 Blaizzy

You are fast 🔥 🚀

Blaizzy avatar Mar 20 '25 21:03 Blaizzy

Btw, since you are doing an overhaul, I think we need to move closer to the PEFT package API, because more multimodal models are using task LoRAs, such as Phi-4-mm:

Features:

  • Fuse LoRA
  • Disable LoRA
  • Add named LoRAs

Check #225
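For illustration, a PEFT-style named-adapter API along those lines could look roughly like this. This is a hypothetical sketch, not mlx-vlm's actual implementation; the class and method names (`LoRALinear`, `add_adapter`, `disable_adapter`, `fuse`) are assumptions:

```python
# Hypothetical sketch of a PEFT-style adapter API for a linear layer.
# Names and shapes are illustrative, not mlx-vlm's real code.
import numpy as np

class LoRALinear:
    """A linear layer that can carry several named LoRA adapters."""

    def __init__(self, in_dim, out_dim, rank=4):
        self.weight = np.zeros((out_dim, in_dim))
        self.rank = rank
        self.adapters = {}  # name -> (A, B) low-rank factors
        self.active = None

    def add_adapter(self, name):
        # B starts at zero, so a fresh adapter is a no-op (standard LoRA init).
        a = np.random.randn(self.rank, self.weight.shape[1]) * 0.01
        b = np.zeros((self.weight.shape[0], self.rank))
        self.adapters[name] = (a, b)
        self.active = name

    def disable_adapter(self):
        self.active = None

    def fuse(self, name):
        # Merge the adapter delta into the base weight, then drop the adapter.
        a, b = self.adapters.pop(name)
        self.weight = self.weight + b @ a
        if self.active == name:
            self.active = None

    def __call__(self, x):
        y = x @ self.weight.T
        if self.active is not None:
            a, b = self.adapters[self.active]
            y = y + x @ a.T @ b.T
        return y
```

With something like this, per-task adapters (e.g. a vision LoRA vs. a speech LoRA in Phi-4-mm) can be added, switched, and fused independently of the base weights.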

Blaizzy avatar Mar 20 '25 21:03 Blaizzy

you guys are FAST!

lin72h avatar Mar 20 '25 23:03 lin72h

(mlx-dev) (base) ~/Desktop/mlx-vlm/ [adding-full-finetuning*] python -m mlx_vlm.lora --full-weight-training --dataset 5CD-AI/LLaVA-CoT-o1-Instruct
INFO:main:Loading model from mlx-community/Qwen2-VL-2B-Instruct-bf16
Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 41828.96it/s]
Using a slow image processor as use_fast is unset and a slow processor was saved with this model. use_fast=True will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with use_fast=False.
Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 16789.43it/s]
INFO:mlx_vlm.trainers.datasets:Loading dataset from 5CD-AI/LLaVA-CoT-o1-Instruct
Resolving data files: 100%|██████████| 24/24 [00:00<00:00, 41.03it/s]
Resolving data files: 100%|██████████| 24/24 [00:00<00:00, 258774.54it/s]
INFO:mlx_vlm.trainers.datasets:Preparing and maping dataset
INFO:mlx_vlm.trainers.datasets:Applying chat template to the dataset
INFO:main:Using full weight training (all parameters will be trained)
#trainable params: 2208.9856 M || all params: 2208.9856 M || trainable%: 100.000%
INFO:main:Setting up optimizer
INFO:main:Setting up trainer
INFO:main:Training model
{'Epoch': 0, 'Step': 0, 'Loss': '8.1304'}
{'Epoch': 0, 'Step': 10, 'Loss': '6.1346'}

Goekdeniz-Guelmez avatar Mar 21 '25 21:03 Goekdeniz-Guelmez

@Blaizzy would you mind trying it for yourself as well? Seems to be working for me :D

Goekdeniz-Guelmez avatar Mar 24 '25 10:03 Goekdeniz-Guelmez

Thanks @Goekdeniz-Guelmez, you rock and ship crazy fast!

I'm testing it and will give you feedback. Feel free to ping me if you don't hear anything in the next 24h.

Blaizzy avatar Mar 24 '25 22:03 Blaizzy

Hey @Goekdeniz-Guelmez, this is what I got: [Screenshot: 2025-03-24, 11:53 PM]

I noticed the loss went up and it ran out of memory (96 GB M3 Max). Did this happen to you?

Blaizzy avatar Mar 25 '25 20:03 Blaizzy

Yeah, I just got the same error. I changed it to be like MLX-LM, which I believe is more scalable, especially when we want to add other training algorithms. The new version gives me this:

INFO:mlx_vlm.trainers.datasets:Loading dataset from 5CD-AI/LLaVA-CoT-o1-Instruct
Resolving data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 96.24it/s]
Resolving data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 192105.53it/s]
INFO:mlx_vlm.trainers.datasets:Preparing and maping dataset
INFO:mlx_vlm.trainers.datasets:Applying chat template to the dataset
INFO:__main__:Using full weight training (all parameters will be trained)
#trainable params: 2208.9856 M || all params: 2208.9856 M || trainable%: 100.000%
INFO:__main__:Setting up optimizer
INFO:__main__:Setting up training arguments
INFO:__main__:Training model
  0%|          | 0/1000 [00:00<?, ?it/s]
Starting training..., iters: 1000
Iter 10: Train loss 14.786, Learning Rate 1.000e-04, It/sec 0.189, Tokens/sec 194.285, Trained Tokens 10274, Peak mem 26.404 GB
{'Step': 10, 'Loss': '14.7856', 'Learning Rate': '1.00e-04', 'Tokens/sec': '194'}
  1%|▊         | 10/1000 [00:52<1:27:15, 5.29s/it, Loss=14.7856, LR=1.00e-04, Tokens/sec=194]
Iter 20: Train loss 5.563, Learning Rate 1.000e-04, It/sec 0.223, Tokens/sec 204.416, Trained Tokens 19457, Peak mem 26.404 GB
{'Step': 20, 'Loss': '5.5630', 'Learning Rate': '1.00e-04', 'Tokens/sec': '204'}
  2%|█▋        | 20/1000 [01:37<1:18:43, 4.82s/it, Loss=5.5630, LR=1.00e-04, Tokens/sec=204]
Iter 30: Train loss 4.231, Learning Rate 1.000e-04, It/sec 0.168, Tokens/sec 204.660, Trained Tokens 31607, Peak mem 26.404 GB
{'Step': 30, 'Loss': '4.2310', 'Learning Rate': '1.00e-04', 'Tokens/sec': '205'}
  3%|██▍       | 30/1000 [02:37<1:26:10, 5.33s/it, Loss=4.2310, LR=1.00e-04, Tokens/sec=205]
Iter 40: Train loss 3.811, Learning Rate 1.000e-04, It/sec 0.172, Tokens/sec 220.804, Trained Tokens 44454, Peak mem 26.404 GB
{'Step': 40, 'Loss': '3.8108', 'Learning Rate': '1.00e-04', 'Tokens/sec': '221'}
  4%|███▎      | 40/1000 [03:35<1:28:21, 5.52s/it, Loss=3.8108, LR=1.00e-04, Tokens/sec=221]
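The MLX-LM-style batching mentioned above (visible in the `iterate_batches` frames of the traceback further down) essentially amounts to shuffling sample indices and slicing them into fixed-size batches. A minimal sketch under that assumption (not the actual trainer code):

```python
# Minimal sketch of an MLX-LM-style batch iterator.
# The function name matches the traceback below, but the body is illustrative.
import random

def iterate_batches(dataset, batch_size, shuffle=True):
    """Yield lists of samples: shuffle indices once, then slice fixed batches."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.shuffle(indices)
    # Drop the last partial batch so every step sees a full batch.
    for start in range(0, len(indices) - batch_size + 1, batch_size):
        batch_indices = indices[start : start + batch_size]
        yield [dataset[idx] for idx in batch_indices]
```

Keeping batching index-based like this makes it easy to plug in other training algorithms later, since the trainer only ever asks the dataset for individual items.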

Goekdeniz-Guelmez avatar Mar 25 '25 21:03 Goekdeniz-Guelmez

LoRA as well:

INFO:mlx_vlm.trainers.datasets:Preparing and maping dataset
INFO:mlx_vlm.trainers.datasets:Applying chat template to the dataset
INFO:__main__:Setting up LoRA
#trainable params: 11.54048 M || all params: 2208.9856 M || trainable%: 0.522%
INFO:__main__:Setting up optimizer
INFO:__main__:Setting up training arguments
INFO:__main__:Training model
  0%|          | 0/1000 [00:00<?, ?it/s]
Starting training..., iters: 1000
Iter 10: Train loss 3.438, Learning Rate 1.000e-04, It/sec 0.307, Tokens/sec 301.622, Trained Tokens 9817, Peak mem 22.427 GB
{'Step': 10, 'Loss': '3.4376', 'Learning Rate': '1.00e-04', 'Tokens/sec': '302'}
  1%|▊         | 10/1000 [00:32<53:42, 3.25s/it, Loss=3.4376, LR=1.00e-04, Tokens/sec=302]
Iter 20: Train loss 3.142, Learning Rate 1.000e-04, It/sec 0.283, Tokens/sec 302.828, Trained Tokens 20518, Peak mem 22.427 GB
{'Step': 20, 'Loss': '3.1416', 'Learning Rate': '1.00e-04', 'Tokens/sec': '303'}
  2%|█▋        | 20/1000 [01:07<55:50, 3.42s/it, Loss=3.1416, LR=1.00e-04, Tokens/sec=303]
Iter 30: Train loss 2.803, Learning Rate 1.000e-04, It/sec 0.453, Tokens/sec 319.994, Trained Tokens 27586, Peak mem 22.427 GB
{'Step': 30, 'Loss': '2.8035', 'Learning Rate': '1.00e-04', 'Tokens/sec': '320'}
  3%|██▌       | 30/1000 [01:29<46:20, 2.87s/it, Loss=2.8035, LR=1.00e-04, Tokens/sec=320]
Iter 40: Train loss 3.789, Learning Rate 1.000e-04, It/sec 0.331, Tokens/sec 302.422, Trained Tokens 36736, Peak mem 22.541 GB
{'Step': 40, 'Loss': '3.7891', 'Learning Rate': '1.00e-04', 'Tokens/sec': '302'}
  4%|███▍      | 40/1000 [02:00<46:52, 2.93s/it, Loss=3.7891, LR=1.00e-04, Tokens/sec=302]
Iter 50: Train loss 3.981, Learning Rate 1.000e-04, It/sec 0.254, Tokens/sec 306.186, Trained Tokens 48780, Peak mem 22.541 GB
{'Step': 50, 'Loss': '3.9812', 'Learning Rate': '1.00e-04', 'Tokens/sec': '306'}
  5%|████▎     | 50/1000 [02:39<52:06, 3.29s/it, Loss=3.9812, LR=1.00e-04, Tokens/sec=306]

Goekdeniz-Guelmez avatar Mar 25 '25 21:03 Goekdeniz-Guelmez

@Blaizzy Can you try that again? I made some major updates that should work really well :D

Goekdeniz-Guelmez avatar Mar 26 '25 07:03 Goekdeniz-Guelmez

Hey @Goekdeniz-Guelmez

Just took it for a spin again and it works really well!

I have been running it for a good solid 30 minutes; the loss has converged to around ~2.4 and is still decreasing.

https://x.com/Prince_Canuma/status/1905049248476827696

[Screenshot: 2025-03-27, 1:26 AM]

Blaizzy avatar Mar 27 '25 00:03 Blaizzy

@Goekdeniz-Guelmez I got this error towards the end.

The error happens because a particular sample is too small.

Can we skip such samples when they show up instead of throwing an error?

{'Step': 960, 'Loss': '2.3301', 'Learning Rate': '1.00e-04', 'Tokens/sec': '421'}  
 96%|▉| 960/1000 [44:26<01:44, 2.62s/it, Loss=2.3301, LR=1.00e-04, Tokens/sec=421]
Warning: Failed to process inputs with error: height:24 or width:198 must be larger than factor:28 Trying to process inputs with return_tensors='pt'
Traceback (most recent call last):
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 780, in process_inputs_with_fallback
    inputs = process_inputs(
             ^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 772, in process_inputs
    inputs = processor(
             ^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/processing_qwen2_vl.py", line 126, in __call__
    image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/image_processing_utils.py", line 42, in __call__
    return self.preprocess(images, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 431, in preprocess
    patches, image_grid_thw = self._preprocess(
                              ^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 252, in _preprocess
    resized_height, resized_width = smart_resize(
                                    ^^^^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 69, in smart_resize
    raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
ValueError: height:24 or width:198 must be larger than factor:28

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 789, in process_inputs_with_fallback
    inputs = process_inputs(processor, images, prompts, return_tensors="pt")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 772, in process_inputs
    inputs = processor(
             ^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/processing_qwen2_vl.py", line 126, in __call__
    image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/image_processing_utils.py", line 42, in __call__
    return self.preprocess(images, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 431, in preprocess
    patches, image_grid_thw = self._preprocess(
                              ^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 252, in _preprocess
    resized_height, resized_width = smart_resize(
                                    ^^^^^^^^^^^^^
  File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 69, in smart_resize
    raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
ValueError: height:24 or width:198 must be larger than factor:28

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/lora.py", line 182, in <module>
    main(args)
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/lora.py", line 96, in main
    train(
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/trainers/sft_trainer.py", line 347, in train
    batch = next(dataset_iterator)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/trainers/sft_trainer.py", line 130, in iterate_batches
    batch_samples = [dataset[idx] for idx in batch_indices]
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/trainers/sft_trainer.py", line 130, in <listcomp>
    batch_samples = [dataset[idx] for idx in batch_indices]
                     ~~~~~~~^^^^^
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/trainers/datasets.py", line 99, in __getitem__
    inputs = prepare_inputs(
             ^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 845, in prepare_inputs
    inputs = process_inputs_with_fallback(processor, images, prompts)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 791, in process_inputs_with_fallback
    raise ValueError(
ValueError: Failed to process inputs with error: height:24 or width:198 must be larger than factor:28. Please install PyTorch and try again.
 96%|▉| 960/1000 [44:44<01:51,  2.80s/it, Loss=2.3301, LR=1.00e-04, Tokens/sec=421]
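For reference, the check that fails here is Qwen2-VL's `smart_resize` precondition: both image sides must be at least the patch factor (28 in this run; the failing sample is 24×198). One way to avoid hitting it would be pre-filtering the dataset with a small helper; `is_processable` and `filter_samples` below are hypothetical names, not part of mlx-vlm:

```python
# Hypothetical pre-filter mirroring the smart_resize size check that raised above.
def is_processable(height: int, width: int, factor: int = 28) -> bool:
    """Both sides must be at least `factor`, or smart_resize raises ValueError."""
    return height >= factor and width >= factor

def filter_samples(samples, factor: int = 28):
    """Drop samples whose image is too small to be split into patches."""
    return [s for s in samples if is_processable(s["height"], s["width"], factor)]
```

This trades a one-time pass over the dataset for never tripping the processor mid-run.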

Blaizzy avatar Mar 27 '25 01:03 Blaizzy

Amazing work! You two are the Dream Team!

ivanfioravanti avatar Mar 27 '25 06:03 ivanfioravanti

> @Goekdeniz-Guelmez I got this error towards the end.
>
> The error happens because a particular sample is too small.
>
> Can we skip such samples when they show up instead of throwing an error?

Definitely!! This is a good idea. I'll push the update later today :D

Goekdeniz-Guelmez avatar Mar 27 '25 07:03 Goekdeniz-Guelmez

So I added a try/except block around prepare_inputs in the SFTDataset class: if processing a sample fails, it is skipped and the next sample is used instead.
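A minimal sketch of that skip logic (the class and attribute names here are illustrative stand-ins for SFTDataset and its prepare_inputs call, not the actual mlx-vlm code):

```python
# Sketch: wrap per-sample preprocessing and skip samples that fail to process.
import logging

logger = logging.getLogger(__name__)

class SkippingDataset:
    """Dataset wrapper that falls through to the next sample on failure."""

    def __init__(self, samples, prepare_inputs):
        self.samples = samples
        self.prepare_inputs = prepare_inputs

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Try successive samples (wrapping around) until one processes cleanly.
        for offset in range(len(self.samples)):
            sample = self.samples[(idx + offset) % len(self.samples)]
            try:
                return self.prepare_inputs(sample)
            except ValueError as e:
                logger.warning("Skipping sample %d: %s", idx + offset, e)
        raise RuntimeError("No processable samples in dataset")
```

Falling through to a neighboring sample (rather than returning None) keeps the batch size stable, so the training loop never has to special-case missing items.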

Goekdeniz-Guelmez avatar Mar 27 '25 17:03 Goekdeniz-Guelmez

@Goekdeniz-Guelmez awesome, this looks good to me.

Thank you very very much, for the amazing work and speed!❤️

What's missing, and what's next here? (GRPO maybe?)

Blaizzy avatar Mar 27 '25 23:03 Blaizzy

Can't wait for GRPO integration!

lin72h avatar Mar 27 '25 23:03 lin72h

@Blaizzy Always a pleasure! I think this should get merged first, and then I'll start adding GRPO. Do you think DPO or ORPO should go in as well? And reporting to WandB?

Goekdeniz-Guelmez avatar Mar 28 '25 14:03 Goekdeniz-Guelmez

Yes, absolutely!

All of those are much needed and welcome ❤️

Blaizzy avatar Mar 28 '25 15:03 Blaizzy

Of course, just move it to ready and we can get started with the review.

Blaizzy avatar Mar 28 '25 15:03 Blaizzy

@Blaizzy Great!!! I'll start adding the rest, and this PR should be ready to review :D.

Goekdeniz-Guelmez avatar Mar 28 '25 15:03 Goekdeniz-Guelmez

@Goekdeniz-Guelmez: Issue #284 is happening with your branch. Is there something I need to do separately to get it working with your branch?

keshavpeswani avatar Apr 01 '25 07:04 keshavpeswani

I think the datasets ydeng9/llavaone_grpo_v2 and ydeng9/llavaone_grpo_v1 are well suited, so I'll use them. If someone finds better or different ones, feel free to post them here too.

Goekdeniz-Guelmez avatar Apr 10 '25 18:04 Goekdeniz-Guelmez

OK, you can now report your training to WandB via --wandb-project project-name.
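Wiring up that kind of optional reporting typically looks like the sketch below. `wandb.init` and `Run.log` are the real wandb API, but the `make_logger` helper and its fallback are illustrative, not mlx-vlm's actual code:

```python
# Sketch: optional WandB reporting with a local fallback when it's disabled.
try:
    import wandb
except ImportError:
    wandb = None  # wandb stays optional; training works without it

def make_logger(project=None):
    """Return a callable that reports a metrics dict for a given step."""
    if project and wandb is not None:
        run = wandb.init(project=project)
        return lambda metrics, step: run.log(metrics, step=step)
    # Fallback: collect metrics locally when wandb is unavailable or disabled.
    history = []
    def log(metrics, step):
        history.append((step, metrics))
    log.history = history
    return log
```

The trainer can then call `log({'Loss': 2.33, 'Learning Rate': 1e-4}, step)` each reporting interval, regardless of whether `--wandb-project` was passed.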

Goekdeniz-Guelmez avatar Apr 12 '25 11:04 Goekdeniz-Guelmez

So I got pretty far with the GRPO implementation, but I still get no generated outputs. It's pretty late in Germany now, so I'll go to sleep and continue working on it later. Gn.

Edit: For the DPO and ORPO implementations, I will be using ucsahin/TR-VLM-DPO-Dataset.

Goekdeniz-Guelmez avatar Apr 12 '25 20:04 Goekdeniz-Guelmez

Hey Goekdeniz

Hope you got your deserved rest,

Any updates?

Blaizzy avatar Apr 19 '25 12:04 Blaizzy

Hey Prince,

Sorry for the infrequent updates; yes, I'm getting closer. I worked on it this morning too, but I'm still getting an error in the generate function for GRPO. I'm off with the family and will start working on it again on Monday when I'm back.

Goekdeniz-Guelmez avatar Apr 19 '25 16:04 Goekdeniz-Guelmez

Awesome, enjoy the holiday and time with family this can wait 💪🏾

I wish I was in your shoes

Blaizzy avatar Apr 19 '25 17:04 Blaizzy