
Add training support for SigLIP

Open aliencaocao opened this issue 1 year ago • 4 comments

What does this PR do?

Add the sigmoid contrastive loss function of SigLIP from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287

This will allow training/finetuning SigLIP models.
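For reference, a minimal sketch of the sigmoid loss as formulated in the linked big_vision code and the SigLIP paper (variable names are mine; the PR's actual implementation may differ slightly):

```python
import torch
import torch.nn.functional as F

def siglip_loss(image_embeds, text_embeds, logit_scale, logit_bias):
    # Pairwise logits: t * (z_img @ z_txt^T) + b, with L2-normalized embeddings
    logits = torch.matmul(image_embeds, text_embeds.t()) * logit_scale + logit_bias
    # Labels are +1 on the diagonal (matching pairs) and -1 everywhere else
    n = logits.size(0)
    labels = 2 * torch.eye(n, device=logits.device, dtype=logits.dtype) - 1
    # Binary (sigmoid) contrastive loss: -log sigmoid(label * logit),
    # summed over all pairs and averaged over the batch
    return -F.logsigmoid(labels * logits).sum() / n
```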

Already verified to work on my own dataset.

I saw the note about using torch.distributed for the loss function and open_clip's implementation, but I'm not sure why it is needed. I ran my training with both DDP and FSDP with full sharding and it seems to work just fine, also getting the expected speedup and the ability to set a larger batch size. The only issue is https://github.com/huggingface/transformers/issues/31034 when using FSDP, but I don't think it's SigLIP-specific.

Nonetheless, I updated the docs to note that torch.distributed is not used, in case that ends up being important to some users.
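For context, the cross-device concern is that under DDP/FSDP each rank only contrasts its local images against its local texts, so the effective number of negatives per pair shrinks. A rough sketch of the kind of gathering that open_clip-style losses do (a hypothetical helper, deliberately not part of this PR):

```python
import torch
import torch.distributed as dist

def gather_text_embeds(text_embeds):
    # Collect text embeddings from every rank so each device can contrast
    # its images against the global batch of texts, not just the local one.
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(text_embeds) for _ in range(world_size)]
    dist.all_gather(gathered, text_embeds)
    # all_gather returns detached tensors; swap the local shard back in so
    # gradients still flow through this rank's own embeddings.
    gathered[dist.get_rank()] = text_embeds
    return torch.cat(gathered, dim=0)
```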

Not sure if a training test is needed.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

@amyeroberts

aliencaocao avatar Jun 19 '24 13:06 aliencaocao

@aliencaocao Could you rebase to include the upstream changes on main? This should fix the failures on the CI runs

amyeroberts avatar Jun 20 '24 16:06 amyeroberts

Added the training tests and also enabled the gradient checkpointing tests. I note that CLIP had issues with GC, but I have used it with SigLIP myself and did not find any issue with convergence/accuracy on a single RTX 3080 Ti with fp16 training and grad accum=16.
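To illustrate the kind of training step this enables (the checkpoint name and dummy data below are illustrative, not what I actually trained on):

```python
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

ckpt = "google/siglip-base-patch16-224"
model = AutoModel.from_pretrained(ckpt)
processor = AutoProcessor.from_pretrained(ckpt)
model.gradient_checkpointing_enable()  # GC worked fine in my runs
model.train()

images = [Image.new("RGB", (224, 224)) for _ in range(2)]  # dummy images
texts = ["a photo of a cat", "a photo of a dog"]
batch = processor(text=texts, images=images, padding="max_length", return_tensors="pt")

outputs = model(**batch, return_loss=True)  # loss computation added by this PR
outputs.loss.backward()
```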

Will let the tests run and see how it goes.

aliencaocao avatar Jun 21 '24 03:06 aliencaocao

@amyeroberts it seems the slow tests need to be enabled by you?

aliencaocao avatar Jun 21 '24 03:06 aliencaocao

@amyeroberts now that the GC tests are properly skipped, shall we move forward with this?

aliencaocao avatar Jun 28 '24 05:06 aliencaocao