overhaul and adding full-weight fine-tuning (quantized training, GRPO, DPO, ORPO)
Is there a good dataset I can use here?
You can use this one: https://huggingface.co/datasets/5CD-AI/LLaVA-CoT-o1-Instruct
You are fast 🔥 🚀
Btw, since you're doing an overhaul, I think we need to get closer to the PEFT package API, because more multimodal models are using task LoRAs, such as Phi-4-mm:
Features:
- Fuse lora
- Disable lora
- Add named loras
Check #225
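To make the feature list above concrete, here is a minimal, self-contained sketch of what a PEFT-style adapter interface could look like. All names here (LoRALinear, add_adapter, fuse, the enabled flag) are illustrative assumptions, not the actual mlx-vlm or PEFT API; it just demonstrates the three behaviors: named adapters, disabling, and fusing.

```python
# Hypothetical sketch of PEFT-style LoRA controls: named adapters,
# disable, and fuse. NOT the real mlx-vlm or PEFT API.
import numpy as np

class LoRALinear:
    def __init__(self, weight, rank=8, alpha=16):
        self.weight = weight          # frozen base weight, shape (out, in)
        self.rank = rank
        self.scale = alpha / rank
        self.adapters = {}            # name -> (A, B) low-rank pair
        self.active = None
        self.enabled = True

    def add_adapter(self, name):
        out_dim, in_dim = self.weight.shape
        A = np.random.normal(0, 0.01, (self.rank, in_dim))
        B = np.zeros((out_dim, self.rank))   # B starts at zero => adapter is a no-op
        self.adapters[name] = (A, B)
        self.active = name

    def __call__(self, x):
        y = x @ self.weight.T
        if self.enabled and self.active is not None:
            A, B = self.adapters[self.active]
            y = y + (x @ A.T) @ B.T * self.scale
        return y

    def fuse(self):
        # Merge the active adapter into the base weight: W <- W + scale * B @ A,
        # so the fused layer gives the same output with no extra matmuls.
        A, B = self.adapters[self.active]
        self.weight = self.weight + self.scale * (B @ A)
        self.enabled = False
```

Because x @ W.T + scale * (x @ A.T) @ B.T equals x @ (W + scale * B @ A).T, fusing changes nothing numerically; it just removes the adapter overhead at inference time.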
you guys are FAST!
(mlx-dev) (base) ~/Desktop/mlx-vlm/ [adding-full-finetuning*] python -m mlx_vlm.lora --full-weight-training --dataset 5CD-AI/LLaVA-CoT-o1-Instruct
INFO:main:Loading model from mlx-community/Qwen2-VL-2B-Instruct-bf16
Fetching 11 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 41828.96it/s]
Using a slow image processor as use_fast is unset and a slow processor was saved with this model. use_fast=True will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with use_fast=False.
Fetching 11 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 16789.43it/s]
INFO:mlx_vlm.trainers.datasets:Loading dataset from 5CD-AI/LLaVA-CoT-o1-Instruct
Resolving data files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 41.03it/s]
Resolving data files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 258774.54it/s]
INFO:mlx_vlm.trainers.datasets:Preparing and maping dataset
INFO:mlx_vlm.trainers.datasets:Applying chat template to the dataset
INFO:main:Using full weight training (all parameters will be trained)
#trainable params: 2208.9856 M || all params: 2208.9856 M || trainable%: 100.000%
INFO:main:Setting up optimizer
INFO:main:Setting up trainer
INFO:main:Training model
{'Epoch': 0, 'Step': 0, 'Loss': '8.1304'}
{'Epoch': 0, 'Step': 10, 'Loss': '6.1346'}
@Blaizzy would you mind trying it for yourself as well? Seems to be working for me :D
Thanks @Goekdeniz-Guelmez, you rock and ship crazy fast!
I'm testing it and will give you feedback. Feel free to ping me if you don't hear anything in the next 24h.
Hey @Goekdeniz-Guelmez
This is what I got:
I noticed the loss went up and it ran out of memory (96GB M3 Max). Did this happen to you?
Yeah, I just got the same error. I changed it to be like MLX-LM, which I believe is more scalable, especially when we want to add other training algorithms. The new version gives me this:
INFO:mlx_vlm.trainers.datasets:Loading dataset from 5CD-AI/LLaVA-CoT-o1-Instruct
Resolving data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 96.24it/s]
Resolving data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 192105.53it/s]
INFO:mlx_vlm.trainers.datasets:Preparing and maping dataset
INFO:mlx_vlm.trainers.datasets:Applying chat template to the dataset
INFO:__main__:Using full weight training (all parameters will be trained)
#trainable params: 2208.9856 M || all params: 2208.9856 M || trainable%: 100.000%
INFO:__main__:Setting up optimizer
INFO:__main__:Setting up training arguments
INFO:__main__:Training model
0%| | 0/1000 [00:00<?, ?it/s]Starting training..., iters: 1000
Iter 10: Train loss 14.786, Learning Rate 1.000e-04, It/sec 0.189, Tokens/sec 194.285, Trained Tokens 10274, Peak mem 26.404 GB
{'Step': 10, 'Loss': '14.7856', 'Learning Rate': '1.00e-04', 'Tokens/sec': '194'}
1%|▊ | 10/1000 [00:52<1:27:15, 5.29s/it, Loss=14.7856, LR=1.00e-04, Tokens/sec=194]Iter 20: Train loss 5.563, Learning Rate 1.000e-04, It/sec 0.223, Tokens/sec 204.416, Trained Tokens 19457, Peak mem 26.404 GB
{'Step': 20, 'Loss': '5.5630', 'Learning Rate': '1.00e-04', 'Tokens/sec': '204'}
2%|█▋ | 20/1000 [01:37<1:18:43, 4.82s/it, Loss=5.5630, LR=1.00e-04, Tokens/sec=204]Iter 30: Train loss 4.231, Learning Rate 1.000e-04, It/sec 0.168, Tokens/sec 204.660, Trained Tokens 31607, Peak mem 26.404 GB
{'Step': 30, 'Loss': '4.2310', 'Learning Rate': '1.00e-04', 'Tokens/sec': '205'}
3%|██▍ | 30/1000 [02:37<1:26:10, 5.33s/it, Loss=4.2310, LR=1.00e-04, Tokens/sec=205]Iter 40: Train loss 3.811, Learning Rate 1.000e-04, It/sec 0.172, Tokens/sec 220.804, Trained Tokens 44454, Peak mem 26.404 GB
{'Step': 40, 'Loss': '3.8108', 'Learning Rate': '1.00e-04', 'Tokens/sec': '221'}
4%|███▎ | 40/1000 [03:35<1:28:21, 5.52s/it, Loss=3.8108, LR=1.00e-04, Tokens/sec=221]
LoRA works as well:
INFO:mlx_vlm.trainers.datasets:Preparing and maping dataset
INFO:mlx_vlm.trainers.datasets:Applying chat template to the dataset
INFO:__main__:Setting up LoRA
#trainable params: 11.54048 M || all params: 2208.9856 M || trainable%: 0.522%
INFO:__main__:Setting up optimizer
INFO:__main__:Setting up training arguments
INFO:__main__:Training model
0%| | 0/1000 [00:00<?, ?it/s]Starting training..., iters: 1000
Iter 10: Train loss 3.438, Learning Rate 1.000e-04, It/sec 0.307, Tokens/sec 301.622, Trained Tokens 9817, Peak mem 22.427 GB
{'Step': 10, 'Loss': '3.4376', 'Learning Rate': '1.00e-04', 'Tokens/sec': '302'}
1%|▊ | 10/1000 [00:32<53:42, 3.25s/it, Loss=3.4376, LR=1.00e-04, Tokens/sec=302]Iter 20: Train loss 3.142, Learning Rate 1.000e-04, It/sec 0.283, Tokens/sec 302.828, Trained Tokens 20518, Peak mem 22.427 GB
{'Step': 20, 'Loss': '3.1416', 'Learning Rate': '1.00e-04', 'Tokens/sec': '303'}
2%|█▋ | 20/1000 [01:07<55:50, 3.42s/it, Loss=3.1416, LR=1.00e-04, Tokens/sec=303]Iter 30: Train loss 2.803, Learning Rate 1.000e-04, It/sec 0.453, Tokens/sec 319.994, Trained Tokens 27586, Peak mem 22.427 GB
{'Step': 30, 'Loss': '2.8035', 'Learning Rate': '1.00e-04', 'Tokens/sec': '320'}
3%|██▌ | 30/1000 [01:29<46:20, 2.87s/it, Loss=2.8035, LR=1.00e-04, Tokens/sec=320]Iter 40: Train loss 3.789, Learning Rate 1.000e-04, It/sec 0.331, Tokens/sec 302.422, Trained Tokens 36736, Peak mem 22.541 GB
{'Step': 40, 'Loss': '3.7891', 'Learning Rate': '1.00e-04', 'Tokens/sec': '302'}
4%|███▍ | 40/1000 [02:00<46:52, 2.93s/it, Loss=3.7891, LR=1.00e-04, Tokens/sec=302]Iter 50: Train loss 3.981, Learning Rate 1.000e-04, It/sec 0.254, Tokens/sec 306.186, Trained Tokens 48780, Peak mem 22.541 GB
{'Step': 50, 'Loss': '3.9812', 'Learning Rate': '1.00e-04', 'Tokens/sec': '306'}
5%|████▎ | 50/1000 [02:39<52:06, 3.29s/it, Loss=3.9812, LR=1.00e-04, Tokens/sec=306]
@Blaizzy Can you try that again? I made some major updates that should work really well :D
Hey @Goekdeniz-Guelmez
Just took it for a spin again and it works really well!
I have been running it for a good solid 30 minutes; the loss has converged to around ~2.4 and is still decreasing.
https://x.com/Prince_Canuma/status/1905049248476827696
@Goekdeniz-Guelmez I got this error towards the end.
The error happens because a particular sample is too small.
Can we skip such samples when they show up instead of throwing an error?
{'Step': 960, 'Loss': '2.3301', 'Learning Rate': '1.00e-04', 'Tokens/sec': '421'}
96%|▉| 960/1000 [44:26<01:44, 2.62s/it, Loss=2.3301, LR=1.00e-04, Tokens/sec=421]Warning: Failed to process inputs with error: height:24 or width:198 must be larger than factor:28 Trying to process inputs with return_tensors='pt'
Traceback (most recent call last):
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 780, in process_inputs_with_fallback
inputs = process_inputs(
^^^^^^^^^^^^^^^
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 772, in process_inputs
inputs = processor(
^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/processing_qwen2_vl.py", line 126, in __call__
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/image_processing_utils.py", line 42, in __call__
return self.preprocess(images, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 431, in preprocess
patches, image_grid_thw = self._preprocess(
^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 252, in _preprocess
resized_height, resized_width = smart_resize(
^^^^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 69, in smart_resize
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
ValueError: height:24 or width:198 must be larger than factor:28
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 789, in process_inputs_with_fallback
inputs = process_inputs(processor, images, prompts, return_tensors="pt")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 772, in process_inputs
inputs = processor(
^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/processing_qwen2_vl.py", line 126, in __call__
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/image_processing_utils.py", line 42, in __call__
return self.preprocess(images, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 431, in preprocess
patches, image_grid_thw = self._preprocess(
^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 252, in _preprocess
resized_height, resized_width = smart_resize(
^^^^^^^^^^^^^
File "/Users/prince_canuma/transformers/transformers/transformers/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py", line 69, in smart_resize
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
ValueError: height:24 or width:198 must be larger than factor:28
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/lora.py", line 182, in <module>
main(args)
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/lora.py", line 96, in main
train(
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/trainers/sft_trainer.py", line 347, in train
batch = next(dataset_iterator)
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/trainers/sft_trainer.py", line 130, in iterate_batches
batch_samples = [dataset[idx] for idx in batch_indices]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/trainers/sft_trainer.py", line 130, in <listcomp>
batch_samples = [dataset[idx] for idx in batch_indices]
~~~~~~~^^^^^
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/trainers/datasets.py", line 99, in __getitem__
inputs = prepare_inputs(
^^^^^^^^^^^^^^^
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 845, in prepare_inputs
inputs = process_inputs_with_fallback(processor, images, prompts)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/prince_canuma/Documents/Projects/LLMs/mlx-vlm-prs/mlx-vlm/mlx_vlm/utils.py", line 791, in process_inputs_with_fallback
raise ValueError(
ValueError: Failed to process inputs with error: height:24 or width:198 must be larger than factor:28. Please install PyTorch and try again.
96%|▉| 960/1000 [44:44<01:51, 2.80s/it, Loss=2.3301, LR=1.00e-04, Tokens/sec=421]
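For context on the failure above: Qwen2-VL's smart_resize snaps each image side to a multiple of the patch factor (28 here) and rejects images with any side smaller than that factor, which is exactly what the 24-pixel-tall sample triggered. A minimal sketch of that guard (function name and rounding details are illustrative, simplified from the transformers implementation):

```python
# Simplified sketch of the size guard in Qwen2-VL's smart_resize.
# FACTOR is the patch size times the spatial merge size (28 for Qwen2-VL).
FACTOR = 28

def check_min_size(height, width, factor=FACTOR):
    """Reject images smaller than the patch factor, else round each
    side to the nearest multiple of the factor (as smart_resize does)."""
    if height < factor or width < factor:
        raise ValueError(
            f"height:{height} or width:{width} must be larger than factor:{factor}"
        )
    return round(height / factor) * factor, round(width / factor) * factor
```

Any sample whose image fails this check cannot be patchified, so the dataset either has to upscale such images beforehand or skip them, as suggested above.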
Amazing work! You two are the Dream Team!
Definitely!! This is a good idea. I'll push the update later today :D
So I added a try/except block around prepare_inputs in the SFTDataset class: if processing fails, the sample is skipped and we move on to the next one.
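The skip-on-failure logic could be sketched roughly like this. The class and callable names are stand-ins for the real mlx-vlm SFTDataset and prepare_inputs, and the walk-forward-to-the-next-index strategy is one possible interpretation of "skip and go to the next sample":

```python
# Hedged sketch of skip-on-failure sample loading: if preparing a sample
# raises (e.g. the image is smaller than the patch factor), fall through
# to the next index instead of crashing the training run.
import logging

logger = logging.getLogger(__name__)

class SkippingDataset:
    def __init__(self, samples, prepare_fn):
        self.samples = samples
        self.prepare_fn = prepare_fn  # may raise ValueError on bad samples

    def __getitem__(self, idx):
        # Walk forward (wrapping around) until a sample processes cleanly.
        for offset in range(len(self.samples)):
            i = (idx + offset) % len(self.samples)
            try:
                return self.prepare_fn(self.samples[i])
            except ValueError as e:
                logger.warning("Skipping sample %d: %s", i, e)
        raise RuntimeError("No processable samples in dataset")
```

One trade-off of substituting a neighboring sample rather than returning None: batch sizes stay constant, but a bad sample's neighbor can be seen twice per epoch.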
@Goekdeniz-Guelmez awesome, this looks good to me.
Thank you very much for the amazing work and speed! ❤️
What's missing, and what's next here? (GRPO, maybe?)
Can't wait for GRPO integration!
@Blaizzy Always a pleasure! I think this should get merged first, and then I'll start adding GRPO. Do you think DPO or ORPO should go in as well? As well as reporting to WandB?
Yes, absolutely!
All of those are much needed and welcome ❤️
Of course, just move it to ready and we can get started with the review.
@Blaizzy Great!!! I'll start adding the rest, and this PR should be ready to review :D.
@Goekdeniz-Guelmez: Issue #284 is happening with your branch. Is there something I need to do separately to get it working with your branch?
I think the datasets ydeng9/llavaone_grpo_v2 and ydeng9/llavaone_grpo_v1 are well suited, so I'll use them. If someone finds better or different ones, feel free to post them here too.
OK, you can now report your training runs to WandB via --wandb-project project-name
So I got pretty far with the GRPO implementation, but I still get no generated outputs. It's pretty late in Germany now, so I'll go to sleep and continue working on it later. Gn.
Edit: For the DPO and ORPO implementation, I will be using ucsahin/TR-VLM-DPO-Dataset
Hey Goekdeniz
Hope you got your well-deserved rest.
Any updates?
Hey Prince,
Sorry for the infrequent updates; yes, I'm getting closer. I worked on it this morning too, but I'm still getting an error in the generate function for GRPO. I'm off with the family and will start working on it again on Monday when I'm back.
Awesome, enjoy the holiday and time with family; this can wait 💪🏾
I wish I was in your shoes