
[WIP] VMamba implementation

Open dmus opened this issue 1 year ago • 20 comments

What does this PR do?

Fixes #28606

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

dmus avatar Jan 22 '24 20:01 dmus

@dmus Very exciting! Let us know when the PR is ready for review.

cc @ArthurZucker as I believe there's an on-going mamba implementation that we might want to coordinate with here

amyeroberts avatar Jan 22 '24 20:01 amyeroberts

I don't have the bandwidth yet, so it would be nice if you want to do it!

ArthurZucker avatar Jan 23 '24 15:01 ArthurZucker

I have a question about this test in test_modeling_common.py

def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            if model.config.is_encoder_decoder:
                expected_arg_names = [
                    "input_ids",
                    "attention_mask",
                    "decoder_input_ids",
                    "decoder_attention_mask",
                ]
                expected_arg_names.extend(
                    ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                    if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
                    else ["encoder_outputs"]
                )
                self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
            elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and self.has_attentions:
                expected_arg_names = ["pixel_values", "output_hidden_states", "output_attentions", "return_dict"]
                self.assertListEqual(arg_names, expected_arg_names)
            elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and not self.has_attentions:
                expected_arg_names = ["pixel_values", "output_hidden_states", "return_dict"]
                self.assertListEqual(arg_names, expected_arg_names)
            else:
                expected_arg_names = [model.main_input_name]
                self.assertListEqual(arg_names[:1], expected_arg_names)

It fails because model.main_input_name is equal to 'input_ids'. I don't know why this is the case or where I can set model.main_input_name. This is a pure vision model, not an NLP one.

dmus avatar Jan 29 '24 11:01 dmus

@dmus For certain tests like these, where the default implementation doesn't apply, we override the test in the model's testing module e.g. like here for DETR.
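For reference, the failing check compares arg_names[:1] against model.main_input_name, and PreTrainedModel defaults that attribute to "input_ids". Vision models in the library usually override it on their base class; a minimal sketch, assuming the class names proposed in this PR:

from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


class VMambaConfig(PretrainedConfig):
    # stand-in for the configuration class added in this PR
    model_type = "vmamba"


class VMambaPreTrainedModel(PreTrainedModel):
    config_class = VMambaConfig
    base_model_prefix = "vmamba"
    # tells the common tests (and generic utilities) that the first forward
    # argument is pixel_values rather than the default input_ids
    main_input_name = "pixel_values"

With main_input_name set this way the generic test should pass; otherwise, overriding test_forward_signature in the model's test module, as described above, also works.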

amyeroberts avatar Jan 29 '24 12:01 amyeroberts

And for my understanding, is it expected that the forward method of VMambaForImageClassification returns an ImageClassifierOutputWithNoAttention object?

And that the VMambaModel should return a BaseModelOutput?

dmus avatar Jan 29 '24 14:01 dmus

And for my understanding, is it expected that the forward method of VMambaForImageClassification returns an ImageClassifierOutputWithNoAttention object?

And that the VMambaModel should return a BaseModelOutput?

Yep!
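As a rough illustration of the two output wrappers (class names as proposed in this PR; the tensors below are dummies, not the real forward implementations):

import torch

from transformers.modeling_outputs import BaseModelOutput, ImageClassifierOutputWithNoAttention

# dummy tensors standing in for what the real forward passes would compute
hidden_states = torch.randn(1, 196, 768)  # (batch, sequence_length, hidden_size)
logits = torch.randn(1, 1000)             # (batch, num_labels)

# VMambaModel.forward returns its hidden states wrapped in a BaseModelOutput
base_output = BaseModelOutput(last_hidden_state=hidden_states)

# VMambaForImageClassification.forward returns logits (and the loss when
# labels are given) wrapped in an ImageClassifierOutputWithNoAttention
classifier_output = ImageClassifierOutputWithNoAttention(loss=None, logits=logits)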

amyeroberts avatar Jan 29 '24 14:01 amyeroberts

Two tests that are still failing are the VMambaModelTest::test_torch_fx and VMambaModelTest::test_torch_fx_output_loss because AssertionError: Couldn't trace module: Model VMambaModel is not supported yet, supported models: AlbertForMaskedLM, AlbertForMultipleChoice, AlbertForPreTraining, AlbertForQuestionAnswering, AlbertForSequenceClassificat...

I am not sure what to do here, suggestions?

dmus avatar Jan 30 '24 07:01 dmus

Also, the test_initialization test is failing. The check that fails is:

if param.requires_grad:
    self.assertIn(
        ((param.data.mean() * 1e9).round() / 1e9).item(),
        [0.0, 1.0],
        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
    )

it fails with

AssertionError: 0.0031509040854871273 not found in [0.0, 1.0] : Parameter patch_embed.proj.weight of model <class 'transformers.models.vmamba.modeling_vmamba.VMambaModel'> seems not properly initialized

I am not sure what exactly is expected here. The test also fails if I copy _init_weights from modeling_vit.py.

dmus avatar Jan 30 '24 09:01 dmus

Two tests that are still failing are the VMambaModelTest::test_torch_fx and VMambaModelTest::test_torch_fx_output_loss because AssertionError: Couldn't trace module: Model VMambaModel is not supported yet, supported models: AlbertForMaskedLM, AlbertForMultipleChoice, AlbertForPreTraining, AlbertForQuestionAnswering, AlbertForSequenceClassificat...

I am not sure what to do here, suggestions?

In the testing suite, you can force the fx tests not to run by setting fx_compatible = False, e.g. like here
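As a sketch of where that flag lives, assuming the test class added in this PR in tests/models/vmamba/test_modeling_vmamba.py (everything else in the class is omitted):

import unittest

from transformers.testing_utils import require_torch

from ...test_modeling_common import ModelTesterMixin


@require_torch
class VMambaModelTest(ModelTesterMixin, unittest.TestCase):
    # the torch.fx tracer does not support VMambaModel, so skip the fx tests
    fx_compatible = False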

amyeroberts avatar Jan 31 '24 17:01 amyeroberts

Also, the test_initialization test is failing. The check that fails is:

if param.requires_grad:
    self.assertIn(
        ((param.data.mean() * 1e9).round() / 1e9).item(),
        [0.0, 1.0],
        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
    )

it fails with

AssertionError: 0.0031509040854871273 not found in [0.0, 1.0] : Parameter patch_embed.proj.weight of model <class 'transformers.models.vmamba.modeling_vmamba.VMambaModel'> seems not properly initialized

I am not sure what exactly is expected here. The test also fails if I copy _init_weights from modeling_vit.py.

The initialization of the weights should follow what's been done for the original model implementation, not other models in the library e.g. vit.

If the weight initialization in _init_weights is different from the default behaviour, then you need to override the test in test_modeling_vmamba.py e.g. like here for BLIP
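A sketch of such an override inside the VMamba test class, using _config_zero_init from test_modeling_common; the patch_embed skip below is an assumption and should mirror whatever the original VMamba code initializes with its own scheme:

def test_initialization(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    configs_no_init = _config_zero_init(config)

    for model_class in self.all_model_classes:
        model = model_class(config=configs_no_init)
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # skip weights the original VMamba implementation initializes with
            # its own (non-zero-mean) scheme, e.g. the patch embedding projection
            if "patch_embed" in name:
                continue
            self.assertIn(
                ((param.data.mean() * 1e9).round() / 1e9).item(),
                [0.0, 1.0],
                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
            )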

amyeroberts avatar Jan 31 '24 17:01 amyeroberts

All tests are passing locally now. (I still use a local checkpoint because the VMamba checkpoint is not uploaded to the Hub yet.) So I think this PR is ready for review.

I did implement the VMambaForImageClassification.

VMambaForSegmentation and VMambaForObjectDetection are not implemented yet because they require a bit more work, but they could follow a similar design to VMambaForImageClassification.

dmus avatar Jan 31 '24 19:01 dmus

@dmus Great - glad to hear it's at a good stage!

Regarding the other classes e.g. VMambaForObjectDetection, they can always be added in follow-up PRs.

For the tests:

  • The tests relating to natten aren't related to this PR; they were caused by an issue we encountered with new releases and library compatibility. A fix has been merged into main, so rebasing to include those commits should resolve them.
  • Make sure to read the errors printed out in the CI runs. For some of the tests, e.g. check_repo, it's telling you that the class needs to be added to the public init. The other failures, e.g. No module named 'torch', are happening because the vmamba classes don't have the import protections in src/transformers/models/vmamba/__init__.py for when packages like torch or timm aren't available (see the sketch after this list). You can look at other model PRs, e.g. this one for Swin, which shows you all the places you need to modify in the codebase to fully add a model.
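A sketch of that import-protection pattern for src/transformers/models/vmamba/__init__.py, following the structure other vision models use (model object names are the ones appearing in this PR; the VMambaConfig entry is assumed from the configuration file):

from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

_import_structure = {"configuration_vmamba": ["VMambaConfig"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_vmamba"] = [
        "VMAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "VMambaForImageClassification",
        "VMambaModel",
        "VMambaPreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_vmamba import VMambaConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vmamba import (
            VMAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
            VMambaForImageClassification,
            VMambaModel,
            VMambaPreTrainedModel,
        )
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

Note that utils/check_inits.py verifies that the _import_structure half and the TYPE_CHECKING half of this file list exactly the same objects, so the two halves need to be kept in sync.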

amyeroberts avatar Jan 31 '24 20:01 amyeroberts

For this error:

Checking all objects are properly documented.
Traceback (most recent call last):
  File "/home/derk/transformers/utils/check_repo.py", line 1181, in <module>
    check_repo_quality()
  File "/home/derk/transformers/utils/check_repo.py", line 1165, in check_repo_quality
    check_all_objects_are_documented()
  File "/home/derk/transformers/utils/check_repo.py", line 1047, in check_all_objects_are_documented
    raise Exception(
Exception: The following objects are in the public init so should be documented:

  • VMambaForImageClassification

What exactly should be documented? It looks like VMambaForImageClassification is already documented in modeling_vmamba.py.

dmus avatar Feb 01 '24 15:02 dmus

It should be listed under docs/source/en/model_doc/vmamba.md
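For example, a minimal set of entries following the [[autodoc]] convention used by the other model doc pages (the section order and options may differ in the final file):

## VMambaModel

[[autodoc]] VMambaModel
    - forward

## VMambaForImageClassification

[[autodoc]] VMambaForImageClassification
    - forward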

amyeroberts avatar Feb 01 '24 16:02 amyeroberts

Thanks. Now I encounter this error when running python utils/check_inits.py:

Traceback (most recent call last):
  File "/home/derk/transformers/utils/check_inits.py", line 370, in <module>
    check_all_inits()
  File "/home/derk/transformers/utils/check_inits.py", line 298, in check_all_inits
    raise ValueError("\n\n".join(failures))
ValueError: Problem in src/transformers/models/vmamba/__init__.py, both halves do not define the same objects.
Differences for base imports:
  from .modeling_vmamba import in TYPE_HINT but not in _import_structure.
  VMAMBA_PRETRAINED_MODEL_ARCHIVE_LIST in TYPE_HINT but not in _import_structure.
  VMambaForImageClassification in TYPE_HINT but not in _import_structure.
  VMambaModel in TYPE_HINT but not in _import_structure.
  VMambaPreTrainedModel in TYPE_HINT but not in _import_structure.
  in TYPE_HINT but not in _import_structure.
  VMAMBA_PRETRAINED_MODEL_ARCHIVE_LIST in _import_structure but not in TYPE_HINT.
  VMambaForImageClassification in _import_structure but not in TYPE_HINT.
  VMambaModel in _import_structure but not in TYPE_HINT.
  VMambaPreTrainedModel in _import_structure but not in TYPE_HINT.

What is TYPE_HINT, and where should I fix this?

dmus avatar Feb 01 '24 17:02 dmus

This PR is very exciting! Glad to see VMamba making it to the transformers library 🤗 ! @dmus is this still on your radar?

Dexterp37 avatar Mar 23 '24 08:03 Dexterp37

Yes, still on the radar. Will try to update it soon.

dmus avatar Mar 23 '24 15:03 dmus

Hi @dmus - any updates on this PR?

amyeroberts avatar Apr 19 '24 09:04 amyeroberts

Hi @dmus - any updates on this PR?

I now encounter the following error while running the tests.

Traceback (most recent call last):
  File "/home/derk/transformers/utils/check_inits.py", line 370, in <module>
    check_all_inits()
  File "/home/derk/transformers/utils/check_inits.py", line 298, in check_all_inits
    raise ValueError("\n\n".join(failures))
ValueError: Problem in src/transformers/models/vmamba/__init__.py, both halves do not define the same objects.
Differences for base imports:
  from .modeling_vmamba import in TYPE_HINT but not in _import_structure.
  VMAMBA_PRETRAINED_MODEL_ARCHIVE_LIST in TYPE_HINT but not in _import_structure.
  VMambaForImageClassification in TYPE_HINT but not in _import_structure.
  VMambaModel in TYPE_HINT but not in _import_structure.
  VMambaPreTrainedModel in TYPE_HINT but not in _import_structure.
  in TYPE_HINT but not in _import_structure.
  VMAMBA_PRETRAINED_MODEL_ARCHIVE_LIST in _import_structure but not in TYPE_HINT.
  VMambaForImageClassification in _import_structure but not in TYPE_HINT.
  VMambaModel in _import_structure but not in TYPE_HINT.
  VMambaPreTrainedModel in _import_structure but not in TYPE_HINT.

How could I solve this?

dmus avatar Apr 25 '24 13:04 dmus