Add LWDetr model
What does this PR do?
Adds the LWDetr model. In #36895 I started working on adding RFDetr, but after putting in some work I realized that it relies heavily on LWDetr. Adding RFDetr will essentially amount to replacing the ViT encoder with Dino, so the biggest part of the work is the implementation of LWDetr, which could also be a good standalone alternative for people's own use cases.
Who can review?
Still a work in progress, but since @yonigozlan asked for an update, here it is. All the inference code is implemented. A lot of refactoring/renaming is still needed, and I'm writing the tests to be able to do that safely. In the meantime, feel free to check the code and let me know if you have comments.
@qubvel
@yonigozlan @qubvel ready for a first review
Hey @sbucaille! Just checking in to see if I should make another pass at this. Don't hesitate to ask if you need any help!
Hey @yonigozlan, thanks for the follow-up! I finally got some time to put into the PR and have addressed most of your comments, but I'd like to reply here to the ones about the attention implementations you suggested rewriting. I ran into this when dealing with ViT as well as DeformableDetr (and other Detr models), which is indeed why there are a lot of workarounds in this PR for now, but I felt that rewriting those models' attention implementations doesn't belong in this PR. I'd be happy to contribute to updating their attention implementations to follow the Llama standards, but I think that should be a separate PR; let me know what you think. Regarding the tests, I have an issue with the following:
```
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_00_fp16_pad_left_sdpa_kernels PASSED [ 11%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_01_fp16_pad_left PASSED [ 11%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_02_fp16_pad_left_no_attn_mask_sdpa_kernels PASSED [ 12%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_03_fp16_pad_left_no_attn_mask PASSED [ 12%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_04_fp16_pad_right_sdpa_kernels PASSED [ 13%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_05_fp16_pad_right PASSED [ 13%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_06_fp16_pad_right_no_attn_mask_sdpa_kernels PASSED [ 14%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_07_fp16_pad_right_no_attn_mask PASSED [ 15%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_08_fp32_pad_left_sdpa_kernels FAILED [ 15%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_09_fp32_pad_left FAILED [ 16%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_10_fp32_pad_left_no_attn_mask_sdpa_kernels FAILED [ 16%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_11_fp32_pad_left_no_attn_mask FAILED [ 17%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_12_fp32_pad_right_sdpa_kernels FAILED [ 17%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_13_fp32_pad_right FAILED [ 18%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_14_fp32_pad_right_no_attn_mask_sdpa_kernels FAILED [ 18%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_15_fp32_pad_right_no_attn_mask FAILED [ 19%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_16_bf16_pad_left_sdpa_kernels PASSED [ 20%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_17_bf16_pad_left PASSED [ 20%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_18_bf16_pad_left_no_attn_mask_sdpa_kernels PASSED [ 21%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_19_bf16_pad_left_no_attn_mask PASSED [ 21%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_20_bf16_pad_right_sdpa_kernels PASSED [ 22%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_21_bf16_pad_right PASSED [ 22%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_22_bf16_pad_right_no_attn_mask_sdpa_kernels PASSED [ 23%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_23_bf16_pad_right_no_attn_mask PASSED [ 24%]
tests/models/lw_detr/test_modeling_lw_detr.py::LwDetrModelTest::test_eager_matches_sdpa_inference_24_fp32_pad_left_output_attentions FAILED [ 24%]
```
Notice how it fails only on the fp32 tests. Have you had a similar problem in the past? What do you think a potential solution would be?
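For reference, the failing tests essentially compare an eager attention path against torch's fused SDPA kernel under per-dtype tolerances; a toy standalone version of that comparison looks like this (shapes and names here are illustrative, not taken from the LWDetr implementation):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy tensors standing in for one attention call.
q = torch.randn(2, 4, 16, 32)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

def eager_attention(q, k, v):
    # Reference eager path: explicit matmul + softmax.
    scale = q.shape[-1] ** -0.5
    weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return weights @ v

eager_out = eager_attention(q, k, v)
sdpa_out = F.scaled_dot_product_attention(q, k, v)

# The two paths only agree up to floating-point reordering; the test
# suite applies its tightest tolerances to fp32, so small backend
# differences tend to surface there first.
print((eager_out - sdpa_out).abs().max().item())
```

This is only a sketch of the comparison; the actual test also exercises padding masks and `output_attentions`.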
That being said, you can have another pass on the PR if you have some time ! 😃
Hey @yonigozlan, thanks for the comments, I addressed them. Regarding the vision models refactor, do you have a branch somewhere so that I can anticipate the changes on my side? The branch is rebased on main and I dealt with the latest changes! For the flash attention tests, I'll try to get access to a CUDA machine to investigate.
Managed to fix the test_sdpa_can_compile_dynamic and test_sdpa_can_dispatch_on_flash tests; it was a matter of dtype in the loss function. Please let me know if the solution I found in d2fed66105f2ade8251a1fb73eb00edae456187b is correct. I still have flash attention failures on my end; I don't know if it is setup related, but I get this kind of error for test_flash_attn_2_inference_equivalence_right_padding and test_flash_attn_2_inference_equivalence:
```
E RuntimeError: schema_.has_value() INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.h":80, please report a bug to PyTorch. Tried to access the schema for which doesn't have a schema registered yet
```
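The dtype fix mentioned above is of this general shape: upcast half-precision tensors to fp32 before computing the loss so the softmax/log reductions don't run in reduced precision. A hypothetical sketch, not the actual commit:

```python
import torch
import torch.nn.functional as F

# Hypothetical logits as a half-precision model head would produce them.
logits = torch.randn(4, 10).to(torch.float16)
targets = torch.randint(0, 10, (4,))

# Upcast before the loss so the reduction happens in fp32,
# independently of the dtype the model runs in.
loss = F.cross_entropy(logits.float(), targets)
print(loss.dtype)  # torch.float32
```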
@yonigozlan addressed the few comments, and also added an additional test to cover large models, which differ in the MultiScaleProjector module.
@yonigozlan @Rocketknight1 @ArthurZucker congrats on the v5! Bumping LWDetr in case you have some bandwidth now to review it. I've polished small documentation details and updated all the model cards on the Hub; maybe @stevhliu would like to have a look!
Hey @yonigozlan @Cyrilvallez @stevhliu, thanks for your reviews. I addressed most of your comments and replied to the others; check it out and let me know what you think.
Just answered you on the config thing! Let me know when you've made the change, and I'll review again!
[For maintainers] Suggested jobs to run (before merge)
run-slow: auto, lw_detr
@Cyrilvallez done!