
fix transformerblock

Open yang-ze-kang opened this pull request 3 months ago • 3 comments

Fixes # monai/networks/blocks/transformerblock.py

Description

When "with_cross_attention==False", there is no need to initialize "CrossAttentionBlock" in "init"; otherwise, it will introduce unnecessary parameters to the model and may potentially cause some errors.

Types of changes

  • [x] Non-breaking change (fix or new feature that would not break existing functionality).
  • [ ] Breaking change (fix or new feature that would cause existing functionality to change).
  • [ ] New tests added to cover the changes.
  • [ ] Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • [ ] Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • [ ] In-line docstrings updated.
  • [ ] Documentation updated, tested make html command in the docs/ folder.

yang-ze-kang avatar Sep 25 '25 04:09 yang-ze-kang

Walkthrough

Instantiation of cross-attention components in monai/networks/blocks/transformerblock.py is now conditional: norm_cross_attn and cross_attn are created only when with_cross_attention is True. When False, those attributes are not instantiated and a pre-load hook is registered to drop any state_dict entries related to cross_attn and norm_cross_attn during loading. The forward path remains gated by with_cross_attention. No exported or public entity declarations were altered.
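A self-contained sketch of the loading behaviour described in the walkthrough, using PyTorch's (private but long-standing) `_register_load_state_dict_pre_hook`; the exact registration in the MONAI code may differ, and the `nn.Linear` here is only a stand-in for the real cross-attention block:

```python
import torch.nn as nn


class BlockWithOptionalCrossAttn(nn.Module):
    def __init__(self, hidden_size: int = 8, with_cross_attention: bool = False) -> None:
        super().__init__()
        self.with_cross_attention = with_cross_attention
        self.norm1 = nn.LayerNorm(hidden_size)
        if with_cross_attention:
            self.norm_cross_attn = nn.LayerNorm(hidden_size)
            self.cross_attn = nn.Linear(hidden_size, hidden_size)  # stand-in for CrossAttentionBlock
        else:
            # Drop cross-attention entries from incoming checkpoints so state dicts
            # saved with with_cross_attention=True still load without errors.
            self._register_load_state_dict_pre_hook(self._drop_cross_attn_entries)

    @staticmethod
    def _drop_cross_attn_entries(
        state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        for key in list(state_dict):
            if key.startswith(prefix + "cross_attn.") or key.startswith(prefix + "norm_cross_attn."):
                del state_dict[key]


# Weights from a block built with cross-attention load cleanly into one built without it.
src = BlockWithOptionalCrossAttn(with_cross_attention=True)
dst = BlockWithOptionalCrossAttn(with_cross_attention=False)
dst.load_state_dict(src.state_dict(), strict=True)
```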

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.33%, which is below the required threshold of 80.00%. | You can run `@coderabbitai generate docstrings` to improve docstring coverage. |
| Title Check | ❓ Inconclusive | The title is too generic and does not clearly convey the main change of conditional cross-attention initialization in TransformerBlock, so it fails to inform readers what is being fixed. | Please update the title to something more descriptive, for example "Conditionally initialize cross-attention in TransformerBlock when with_cross_attention=False", so it clearly reflects the key change. |
✅ Passed checks (1 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | The description follows the repository template by including a summary of the fix, a clear description of the change, and a types-of-changes checklist, making it largely complete and aligned with expectations. |
✨ Finishing touches
  • [ ] 📝 Generate Docstrings
🧪 Generate unit tests
  • [ ] Create PR with unit tests
  • [ ] Post copyable unit tests in a comment


coderabbitai[bot] avatar Sep 25 '25 04:09 coderabbitai[bot]

Hi @yang-ze-kang, thanks for the contribution. We have another PR attempting to address the same issue here, but it doesn't include your approach to dealing with loading weights. Would it be possible to comment on your changes there so we can consolidate the fixes into that PR? There were a few points of discussion there as well that would be relevant for your PR, so I feel it would be sensible to combine efforts. Thanks!

ericspod avatar Oct 29 '25 17:10 ericspod

Hi @yang-ze-kang, sorry for the delay in reviewing this PR. Since this PR is identical to https://github.com/Project-MONAI/MONAI/pull/8545 but adds the state_dict logic to keep the weights compatible with old checkpoints, we will incorporate the changes from this PR. However, a few changes are still required:

  1. Make sure that the CI tests run. Currently, there is a linting issue, which should get sorted with: ./runtests.sh --autofix
  2. The diffusion model code https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/diffusion_model_unet.py also includes a transformer block that presents the same problem. Could you incorporate the same logic there?
  3. Last but not least, in order to verify that things work properly, could you create a test (in tests/networks/blocks/test_transformerblock.py) showing that you can load the weights of a model with cross-attention layers into a model without them?
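
For reference, a rough sketch of the kind of compatibility test requested in point 3. It assumes the pre-load hook from this PR is in place, and that TransformerBlock takes hidden_size/mlp_dim/num_heads/dropout_rate arguments as in the current API; names and values may need adjusting:

```python
import unittest

import torch

from monai.networks.blocks.transformerblock import TransformerBlock


class TestCrossAttentionWeightCompat(unittest.TestCase):
    def test_load_cross_attn_weights_into_plain_block(self):
        kwargs = dict(hidden_size=360, mlp_dim=1024, num_heads=4, dropout_rate=0.0)
        block_with = TransformerBlock(with_cross_attention=True, **kwargs)
        block_without = TransformerBlock(with_cross_attention=False, **kwargs)
        # The extra cross_attn / norm_cross_attn entries should be dropped on load
        # rather than raising "unexpected key" errors.
        block_without.load_state_dict(block_with.state_dict())
        x = torch.randn(2, 16, kwargs["hidden_size"])
        with torch.no_grad():
            out = block_without(x)
        self.assertEqual(out.shape, x.shape)


if __name__ == "__main__":
    unittest.main()
```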

virginiafdez avatar Oct 30 '25 08:10 virginiafdez