[GroundingDino] Fix grounding dino loss 🚨
What does this PR do?
Fixes https://github.com/huggingface/transformers/issues/31434
As the original repo doesn't provide a loss implementation, I'm using the one implemented in Open-GroundingDino as a baseline, since the original repo pointed to it in this issue https://github.com/IDEA-Research/GroundingDINO/issues/241 as a reliable source for anyone who wants to train a GroundingDino model.
TODO:
- [x] Test that `GroundingDinoMatcher` and `GroundingDinoLoss` are working properly
Explanation of the Issue and Solution
So the issue was that `GroundingDinoLoss` and `GroundingDinoHungarianMatcher` were just copied from DeformableDetr, which is used for closed-set object detection (i.e. a fixed set of categories). In GroundingDino, by contrast, there is no fixed set of categories, and the output logits are `d_model`-dimensional, where the first `seq_len` elements have specified values and the subsequent ones are `nan`. The main differences are:
- `class_labels` are associated with the text prompt used
- The logits are associated with the tokens of the text, so the mapping is not necessarily 1-to-1
For instance, given an image with bounding boxes for fishes and jellyfishes and the prompt "fish. jellyfish.", fish should be assigned `class_label` 0 and jellyfish should be assigned 1. If the positions of jellyfish and fish in the prompt were swapped, the `class_labels` would swap as well. Moreover, jellyfish is represented by two tokens ([20919, 7529]) and fish by one token ([3869]), therefore we need to select the appropriate logits for each class.
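To make this concrete, here is a minimal, hypothetical sketch (not the PR's code) of how the two classes in "fish. jellyfish." map to token positions, using the token ids quoted above and omitting special tokens:

```python
# Hypothetical illustration of the class -> token mapping described above.
# Token ids for "fish" and "jellyfish" are the ones quoted in this PR
# description; "." is assumed to be id 1012; special tokens are omitted.
import torch

input_ids = torch.tensor([3869, 1012, 20919, 7529, 1012])  # "fish", ".", jellyfish (2 tokens), "."

# One entry per class: the positions its tokens occupy in input_ids
label_map = [
    torch.tensor([0]),     # class 0 -> "fish"
    torch.tensor([2, 3]),  # class 1 -> "jellyfish"
]

# (num_classes, seq_len) 0/1 map selecting the logits that belong to each class
positive_map = torch.zeros(len(label_map), input_ids.shape[0])
for class_id, token_positions in enumerate(label_map):
    positive_map[class_id, token_positions] = 1.0

print(positive_map)
# tensor([[1., 0., 0., 0., 0.],
#         [0., 0., 1., 1., 0.]])
```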
As the original implementation doesn't provide the training loop or the loss implementation, but does recommend other implementations for training GroundingDino in this issue https://github.com/IDEA-Research/GroundingDINO/issues/241, I took as a baseline the implementation from Open-GroundingDino, as it supports both visual grounding and object detection, and its authors trained their own GroundingDino with their code base, achieving good performance.
Things added in this PR are:
- `build_label_maps`, which generates a list of `torch.Tensor` of length `batch_size`, mapping each category to its corresponding tokens based on the `input_ids`
- `build_text_mask`, which just expands the `attention_mask` to select the appropriate tokens when computing `GroundingDino.loss_labels`
- Added `enc_topk_proposals`, `encoder_logits` and `encoder_pred_boxes` to `GroundingDinoModelOutput` and `GroundingDinoObjectDetectionOutput` to compute the first-stage loss
- Added `class_loss_coefficient` (with correct default value) and `class_loss_reduction` to `GroundingDinoConfig`. `class_loss_reduction` was added because the baseline implementation's `sigmoid_focal_loss` reduces `loss_ce` with a simple sum, which makes the losses imbalanced most of the time, while the original implementation also has a `sigmoid_focal_loss` but with `mean` reduction; therefore I decided to make it configurable and use `sum` for testing reasons (see the usage sketch after this list)
- Modifications to `GroundingDinoLoss` and `GroundingDinoHungarianMatcher`
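As a rough usage sketch (assumptions: the `IDEA-Research/grounding-dino-tiny` checkpoint, the DETR-style `labels` format, and the config fields named above being available once this PR is in), computing a loss could look like this:

```python
# Hedged sketch, not the PR's test code: configure the new loss fields and
# compute a loss by passing DETR-style `labels` (list of dicts with
# class_labels and normalized cxcywh boxes).
import torch
from PIL import Image
from transformers import AutoProcessor, GroundingDinoConfig, GroundingDinoForObjectDetection

model_id = "IDEA-Research/grounding-dino-tiny"
config = GroundingDinoConfig.from_pretrained(model_id)
config.class_loss_coefficient = 2.0   # assumed default added by this PR
config.class_loss_reduction = "sum"   # "sum" follows the Open-GroundingDino baseline

processor = AutoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id, config=config)

image = Image.new("RGB", (640, 480))  # dummy image; prompt classes: 0 -> fish, 1 -> jellyfish
inputs = processor(images=image, text="fish. jellyfish.", return_tensors="pt")

labels = [
    {
        "class_labels": torch.tensor([0, 1]),
        "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2], [0.3, 0.4, 0.1, 0.1]]),
    }
]

outputs = model(**inputs, labels=labels)
print(outputs.loss)
```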
Also added a new integration test called `test_grounding_dino_loss`, where I compare the loss obtained from 2 sample images with the baseline implementation from Open-GroundingDino.
c.c. @amyeroberts
@amyeroberts FYI, for some reason `test_cross_attention_mask` is failing locally on this branch, but it was also failing when I tested the main branch (locally)
c.c. @amyeroberts
@EduardoPach Thanks for working on this loss. Just sharing a more well-developed code base for fine-tuning GroundingDINO: https://github.com/open-mmlab/mmdetection/blob/main/configs/mm_grounding_dino/grounding_dino_swin-t_pretrain_obj365.py
c.c. @amyeroberts
Maybe @NielsRogge could have a look?
Cough cough, c.c @amyeroberts
This is the result of the current commit
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
c.c. @amyeroberts
@EduardoPach One additional comment regarding denoising.
Open-GroundingDino: has configuration for the denoising-related variables but does not use them in training https://github.com/longzw1997/Open-GroundingDino/blob/924bb6c4b93cae2dae582e1afaeccd408c72a31d/models/GroundingDINO/groundingdino.py#L750
MMDetection GroundingDino: has the configuration and also the training pipeline, such as the auxiliary and encoder parts. https://github.com/IDEA-Research/GroundingDINO/blob/856dde20aee659246248e20734ef9ba5214f5e44/groundingdino/models/GroundingDINO/groundingdino.py#L384
https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/dense_heads/grounding_dino_head.py#L601
The denoising pipeline helps with faster convergence, so it should be okay to fine-tune without this feature. The current version of this PR lacks the denoising pipeline; we should also note adding it as a TODO 👍🏼
Hey! 🤗 Thanks for your contribution to the transformers library!
Before merging this pull request, slow tests CI should be triggered. To enable this:
- Add the `run-slow` label to the PR
- When your PR is ready for merge and all reviewers' comments have been addressed, push an empty commit with the command `[run-slow]` followed by a comma-separated list of all the models to be tested, i.e. `[run_slow] model_to_test_1, model_to_test_2`
- If the pull request affects a lot of models, put at most 10 models in the commit message
- A `transformers` maintainer will then approve the workflow to start the tests
(For maintainers) The documentation for slow tests CI on PRs is here.
@qubvel Hi Pavel, I have been delegated permission from @EduardoPach, so I will also be working on getting this PR merged. There is one issue from the CI that is related to permissions; otherwise it is good to go. (Do you know about the permission error?)
Hi @SangbumChoi, got it! Can you please drop the two latest commits + rebase the branch onto main locally (instead of merging), then force-push the changes? Otherwise, I see that 862 files changed in this PR 😄
Meanwhile, I will try to figure out the test fetch error.
Thanks, the diff looks better!
I suppose the test issues are in the debugging stage:
- https://github.com/huggingface/transformers/pull/33862
Are we still waiting for the slow tests? cc @SangbumChoi
@ydshieh can you help with the tests? The fetch tests job is showing an Unauthorized error
Hi @EduardoPach and @SangbumChoi! Can you please rebase the PR on the main branch of transformers? This should resolve the issue with tests running
Slow tests failed 🥲 I suppose most of them are in the same state on main, but the loss test also failed, can you have a look? At the very least we need a better message in the assert so we can see the diff.
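As a side note, a hedged sketch (placeholder values, not the test's expected losses) of an assertion that surfaces the diff instead of a bare `False is not true`:

```python
# Placeholder values; the point is the error message, not the numbers.
import torch

expected = torch.tensor(1.2345)
computed = torch.tensor(1.2500)

try:
    # torch.testing.assert_close includes the absolute/relative difference in
    # its failure message, which is far more useful in CI logs than
    # "AssertionError: False is not true".
    torch.testing.assert_close(computed, expected, rtol=1e-4, atol=1e-4)
except AssertionError as err:
    print(err)
```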
@qubvel Sorry, the problem was that the device types of the model and the input were not synchronized.
Resolved by https://github.com/huggingface/transformers/pull/31828/commits/de356aeacc2cffbd253f1e8cd4e0f62205a25cf2
The strange thing is that the local CLI returns the following, but the CI passes on this:
root@ccb8b39e6c2d:/mnt/nas2/users/sbchoi/transformers# RUN_SLOW=1 pytest tests/models/grounding_dino/test_modeling_grounding_dino.py
========================================= short test summary info =========================================
FAILED tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_cross_attention_mask - AssertionError: False is not true
FAILED tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_grounding_dino_loss - AssertionError: False is not true
FAILED tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_inference_object_detection_head_equivalence_cpu_gpu - AssertionError: False is not true
==================== 3 failed, 46 passed, 109 skipped, 3 warnings in 211.76s (0:03:31) ====================
Can you also confirm this? -> It turns out the CI was having some delay (local was correct)
I have these failing locally as well; making a fix for cross-attention + batch equivalence
========================================= short test summary info =========================================
FAILED tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_cross_attention_mask - AssertionError: False is not true
FAILED tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_grounding_dino_loss - AssertionError: False is not true
FAILED tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_inference_object_detection_head_equivalence_cpu_gpu - AssertionError: False is not true
==================== 3 failed, 46 passed, 109 skipped, 3 warnings in 211.76s (0:03:31) ====================
Two tests failed in CI
FAILED tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_grounding_dino_loss - AssertionError: False is not true
FAILED tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_inference_object_detection_head_equivalence_cpu_gpu - AssertionError: False is not true
For the CPU-GPU test I get pretty large diffs locally, and even `nan` (on main):
{
"logits": tensor(0.0),
"pred_boxes": tensor(0.8244),
"last_hidden_state": tensor(0.9422),
"init_reference_points": tensor(0.7221),
"intermediate_hidden_states": tensor(2.5513),
"intermediate_reference_points": tensor(0.8300),
"encoder_last_hidden_state_vision": tensor(0.0001),
"encoder_last_hidden_state_text": tensor(9.5665e-06),
"enc_outputs_class": tensor(1.8090e-05),
"enc_outputs_coord_logits": tensor(nan),
}
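For context, a hedged sketch (dummy tensors, hypothetical helper, not the test's actual code) of how per-output max absolute CPU/GPU differences like the ones above can be collected:

```python
import torch

def max_abs_diffs(cpu_outputs: dict, gpu_outputs: dict) -> dict:
    """Max |cpu - gpu| per output tensor (hypothetical helper)."""
    diffs = {}
    for name, cpu_value in cpu_outputs.items():
        gpu_value = gpu_outputs[name].to("cpu")
        diffs[name] = (cpu_value.float() - gpu_value.float()).abs().max()
    return diffs

# Dummy tensors standing in for the two sets of model outputs
cpu_out = {"logits": torch.zeros(2, 3), "pred_boxes": torch.ones(2, 4)}
gpu_out = {"logits": torch.zeros(2, 3), "pred_boxes": torch.ones(2, 4) * 1.01}
print(max_abs_diffs(cpu_out, gpu_out))
```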
@qubvel
tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_inference_object_detection_head_equivalence_cpu_gpu
is failing on our CI (main), as can be seen in the #transformers-ci-daily-models channel
@ydshieh @qubvel Currently I'm debugging why this failure happens
@SangbumChoi
tests/models/grounding_dino/test_modeling_grounding_dino.py::GroundingDinoModelIntegrationTests::test_inference_object_detection_head_equivalence_cpu_gpu
has been failing for quite a long time 😅 but if you are willing to check it, we'd appreciate it a lot; it's not necessary for merging this PR, though.
(but `test_grounding_dino_loss` is necessary)
> has been failing for quite a long time 😅 but if you are willing to check it, we'd appreciate it a lot; it's not necessary for merging this PR, though.
@ydshieh Regarding the aforementioned error: I am not 100% sure, but I think it is somehow related to a seed problem. The reason I think this is that sometimes the test passes even though I did not change anything.
@SangbumChoi do you mean the loss test is random-seed dependent? In that case, we can either slightly increase the tolerance or mark it as @is_flaky()
> do you mean the loss test is random-seed dependent? In that case, we can either slightly increase the tolerance or mark it as @is_flaky()
@qubvel Still not sure about it. I will debug deeper tomorrow, but the fact is that sometimes the CI passes and sometimes it fails 😭 (Or there might be a computation issue with `torch.backends.cudnn.deterministic = True`)
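For reference, a small sketch of the determinism knobs mentioned above, useful when checking whether a failure is seed-dependent (the `cudnn.deterministic` flag is the one from the discussion; the rest are standard PyTorch settings):

```python
import torch

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)                 # no-op if CUDA is unavailable
torch.backends.cudnn.deterministic = True     # the flag discussed above
torch.backends.cudnn.benchmark = False
# Stricter: warn (or raise, without warn_only) on nondeterministic ops
torch.use_deterministic_algorithms(True, warn_only=True)
```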
Ok, sure!
Hi @SangbumChoi, in `test_grounding_dino_loss` you have `"loss_ce_enc": torch.tensor(16226.3145)`:
The scale of this loss is way too big compared to the other losses. Shouldn't a trained network have a smaller-scale loss, or is there something wrong with the implementation?
Or does the loss actually need to include `loss_ce_enc`, since the first stage is just used for the region proposals?
@stevenwudi
> does the loss actually need to include `loss_ce_enc`, since the first stage is just used for the region proposals?
Yeah; otherwise, can you explain in more detail why `loss_ce_enc` should be low? (I am also open to discussing this!)