[SFT VLM] Add support for Molmo models

Open lewtun opened this issue 1 year ago • 15 comments

Feature request

Extend the sft_vlm.py script to support the new Molmo models from AllenAI: https://huggingface.co/collections/allenai/molmo-66f379e6fe3b8ef090a8ca19

Paper: https://arxiv.org/abs/2409.17146

Motivation

The Molmo models are super strong VLMs across all model scales, in some cases matching or exceeding the performance of GPT-4V:

[Screenshot taken 2024-09-27 at 09:43:26]

Having the ability to fine-tune these models on custom datasets would be quite exciting for many vision-language applications (e.g. agents).

Your contribution

Open to the community!

lewtun avatar Sep 27 '24 07:09 lewtun

I'd like to contribute to it if you give me some guidance about the requirements! 😄

sergiopaniego avatar Sep 27 '24 16:09 sergiopaniego

Great! I would start by looking at the inference code from one of the models (example) and seeing how the inputs need to be provided to the model. Once you've understood that, it should be reasonably straightforward to extend the training script to include these models with trust_remote_code=True.

@edbeeching can also provide some guidance as he made the original implementation :)

lewtun avatar Sep 27 '24 16:09 lewtun

Hi @sergiopaniego, I had a look at the modelling code of Molmo, and its processor is not quite the same as the ones for llama-vision and llava, so you may find it challenging to write a single script that works for all these models.

If you would like to make a standalone script that works just for Molmo, adapted from our sft_vlm script, that would be a great first step. We can then iterate together to see if we can generalize the scripts.

edbeeching avatar Sep 30 '24 07:09 edbeeching

It might also be good to track the transformers integration which will presumably standardise the preprocessing: https://github.com/huggingface/transformers/issues/33710

lewtun avatar Sep 30 '24 07:09 lewtun

Thanks a lot for the details!

I'm currently running the script as is while trying to understand how Molmo differs. To clarify, @edbeeching the processor that you're talking about is the one in https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py? I'll first try to create a standalone script for Molmo, as you suggest 😄

sergiopaniego avatar Sep 30 '24 16:09 sergiopaniego

@sergiopaniego

To clarify, @edbeeching the processor that you're talking about is the one in https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py?

Yes that is the one.

edbeeching avatar Oct 03 '24 09:10 edbeeching

Thanks for the reaffirmation, @edbeeching!

I've created a reproducible example on Google Colab to share the code:

Colab Notebook

Currently, I'm encountering a RuntimeError: CUDA error: device-side assert triggered.

Some details:

  • I've set batch_size=1 because the processor.process function expects only one example.
  • I’ve made some modifications to the collate_fn to accommodate the processor.
  • I've also upgraded the transformers library to the latest version.
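Because processor.process handles a single example, one possible workaround (a sketch, not the code from the notebook) is to call it once per example inside collate_fn and gather the outputs field by field, so the DataLoader is no longer pinned to batch_size=1. Here `process_one` and `batch_via_single_example_processor` are hypothetical names: `process_one` stands in for a per-example call such as processor.process(...), and the toy processor in the usage lines exists only for illustration.

```python
from typing import Any, Callable, Dict, List


def batch_via_single_example_processor(
    examples: List[Dict[str, Any]],
    process_one: Callable[[Dict[str, Any]], Dict[str, Any]],
) -> Dict[str, List[Any]]:
    """Apply a processor that only accepts one example to every example in a batch.

    `process_one` is a hypothetical stand-in for a per-example processor call.
    It must return a dict of fields; the result maps each field name to a list
    with one entry per example, in the original order.
    """
    batch: Dict[str, List[Any]] = {}
    for example in examples:
        processed = process_one(example)
        for key, value in processed.items():
            batch.setdefault(key, []).append(value)
    return batch


# Toy processor (illustration only): "tokenizes" text into character codes.
def toy_process(example: Dict[str, Any]) -> Dict[str, Any]:
    return {"input_ids": [ord(c) for c in example["text"]]}


batch = batch_via_single_example_processor([{"text": "hi"}, {"text": "hey"}], toy_process)
print(batch)  # {'input_ids': [[104, 105], [104, 101, 121]]}
```

The per-field lists are ragged, so fields such as input_ids still need padding before they can be stacked into tensors.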

I’m actively investigating the issue. Do you have any suggestions on how to resolve it?

sergiopaniego avatar Oct 03 '24 16:10 sergiopaniego

I was able to load the model without the error by first importing BitsAndBytesConfig from transformers in the 5th cell and then adding this config:

from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=False, load_in_4bit=True
)

smellslikeml avatar Oct 03 '24 19:10 smellslikeml

Could you share your reproducible example?

sergiopaniego avatar Oct 04 '24 09:10 sergiopaniego

Sure, I've added those changes to your colab here and the rest should be the same.

smellslikeml avatar Oct 04 '24 12:10 smellslikeml

Hello, do you have a timeline for this?

aleSuglia avatar Oct 07 '24 16:10 aleSuglia

I tried extending the notebook with @smellslikeml's changes, but I still get the same exception. I'm continuing to investigate the root cause.

sergiopaniego avatar Oct 09 '24 17:10 sergiopaniego

Try this Colab: https://colab.research.google.com/drive/1RICZvuxLJ0g6dCIkOIf0HC5J9fJGqNTU?usp=sharing. It gets past the CUDA error and begins training before running out of memory (OOM).

smellslikeml avatar Oct 09 '24 17:10 smellslikeml

Hi, let me know if you would like me to take a look.

edbeeching avatar Oct 13 '24 19:10 edbeeching

Hi @edbeeching!

Sorry for the delay. I was busy last week, but I have more time for this now. I've reproduced @smellslikeml's idea (https://colab.research.google.com/drive/1doT9u811J-WNCnsT6-rP9-OxnDv52M6W?usp=sharing), and I'll try to open the PR this week. Should we wait until https://github.com/huggingface/transformers/pull/33962 is merged?

sergiopaniego avatar Oct 14 '24 16:10 sergiopaniego

Hi @sergiopaniego, I need to fine-tune a Molmo model for my current project, and I just stumbled upon your related pull request. Thank you for putting it together! May I ask whether you were able to successfully fine-tune Molmo? How was your experience?

I just tried to run your example script as is, but I get an error when assembling a batch in batch["input_ids"] = torch.stack(batch["input_ids"]), because the tensors in that list have different lengths. It seems that some padding is missing. Do you have any quick pointers on how to work around this?

Thank you in advance!
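The torch.stack error above is what happens when per-example outputs of different lengths are stacked directly. A minimal sketch of the missing step, assuming the pad id would come from something like processor.tokenizer.pad_token_id, is to pad every sequence to the longest one and build a matching attention mask. It uses plain lists to stay framework-agnostic; the helper name pad_input_ids is hypothetical.

```python
from typing import List, Tuple


def pad_input_ids(
    sequences: List[List[int]], pad_token_id: int, padding_side: str = "right"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad variable-length token id lists to a common length.

    Returns (padded_ids, attention_mask); the mask is 1 for real tokens and 0
    for padding. Both results are rectangular, so they can be converted with
    torch.tensor(...) instead of calling torch.stack on ragged tensors.
    """
    if padding_side not in ("right", "left"):
        raise ValueError("padding_side must be 'right' or 'left'")
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        pad = [pad_token_id] * (max_len - len(seq))
        ones = [1] * len(seq)
        zeros = [0] * len(pad)
        if padding_side == "right":
            padded.append(seq + pad)
            mask.append(ones + zeros)
        else:
            padded.append(pad + seq)
            mask.append(zeros + ones)
    return padded, mask


ids, attention_mask = pad_input_ids([[5, 6, 7], [8]], pad_token_id=0)
print(ids)             # [[5, 6, 7], [8, 0, 0]]
print(attention_mask)  # [[1, 1, 1], [1, 0, 0]]
```

In PyTorch, torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_id) covers the right-padding part for a list of 1-D tensors. Any other ragged per-example field would need to be padded consistently with input_ids.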

chisarie avatar Jan 13 '25 13:01 chisarie

Hi @chisarie!

Thanks for the comment! I've been waiting for the model integration with transformers. It looks like it's getting close to being merged, so it might be a good time to reactivate this PR.

Have you checked using this branch (https://github.com/huggingface/transformers/pull/33962)?

sergiopaniego avatar Jan 13 '25 18:01 sergiopaniego

Hi Sergio, I tried using that branch today, but I get another error:

ValueError: '<class 'transformers_modules.allenai.Molmo-7B-D-0924.1721478b71306fb7dc671176d5c204dc7a4d27d7.config_molmo.MolmoConfig'>' is already used by a Transformers model.

I guess it's because that branch integrates the -hf model, whose config is also called MolmoConfig. I also tried loading the -hf model from the branch directly, but it doesn't seem to be compatible with your script.

Do you know how to work around this? Thanks!

chisarie avatar Jan 15 '25 16:01 chisarie

Thanks for the update @sergiopaniego!

I tested the changes from huggingface/transformers#33962 in this colab notebook and successfully trained with trl version 0.13.0.dev0.

smellslikeml avatar Jan 16 '25 18:01 smellslikeml

Thanks for the updated notebook, @smellslikeml! 🙌

I've tested it on an updated Colab, which allowed me to generate a fine-tuned version. Everything seems to be working correctly!

I've also created a gist for the script.

Once the Transformers PR is merged, we’ll be all set! 😄

sergiopaniego avatar Jan 21 '25 20:01 sergiopaniego

Hi @smellslikeml and @sergiopaniego. Thank you again for your effort to make Molmo fine-tuning work!

I have a question regarding your collate_fn, especially the way you handle the labels. From what I know, the labels are usually just a copy of the input_ids, plus some -100 masking for things like padding and image tokens, for example as done here: https://github.com/huggingface/trl/blob/55a329e9f0636a2ad6522caa4a601326def44545/examples/scripts/sft_vlm.py#L102
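For comparison, the usual approach described above can be sketched as follows: copy input_ids into labels and replace padding and image tokens with -100, the default ignore_index of PyTorch's cross-entropy loss. The token ids here are toy values, the image token id is a hypothetical placeholder (the real ids depend on the model's tokenizer), and build_labels is an illustrative name.

```python
from typing import Iterable, List

IGNORE_INDEX = -100  # labels with this value are skipped by PyTorch's cross-entropy loss


def build_labels(
    input_ids: List[List[int]],
    pad_token_id: int,
    image_token_ids: Iterable[int] = (),
) -> List[List[int]]:
    """Copy input_ids into labels, masking padding and image tokens with -100."""
    masked = set(image_token_ids)
    masked.add(pad_token_id)
    return [[IGNORE_INDEX if tok in masked else tok for tok in seq] for seq in input_ids]


# Toy ids: 0 = padding, 99 = a hypothetical image placeholder token.
labels = build_labels([[99, 99, 10, 11, 0]], pad_token_id=0, image_token_ids=[99])
print(labels)  # [[-100, -100, 10, 11, -100]]
```

One caveat of masking by pad id: if the tokenizer reuses its EOS token as the pad token, the real EOS is masked too, so the model may not learn when to stop.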

Why is it done differently in your script? I tried to understand the way you construct the labels in your proposed collate_fn, but couldn't wrap my head around it.

chisarie avatar Jan 26 '25 12:01 chisarie

@chisarie This was mostly inspired by the collate_fn() function described in this recipe for fine-tuning Qwen-VL-7B.

smellslikeml avatar Jan 26 '25 16:01 smellslikeml