fix transformerblock
Fixes # monai/networks/blocks/transformerblock.py
Description
When "with_cross_attention==False", there is no need to initialize "CrossAttentionBlock" in "init"; otherwise, it will introduce unnecessary parameters to the model and may potentially cause some errors.
Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
Walkthrough
Instantiation of cross-attention components in monai/networks/blocks/transformerblock.py is now conditional: norm_cross_attn and cross_attn are created only when with_cross_attention is True. When False, those attributes are not instantiated and a pre-load hook is registered to drop any state_dict entries related to cross_attn and norm_cross_attn during loading. The forward path remains gated by with_cross_attention. No exported or public entity declarations were altered.
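For reference, a hedged sketch of what such a pre-load hook can look like; the function name is hypothetical and the public `nn.Module.register_load_state_dict_pre_hook` API (recent PyTorch versions) is assumed rather than taken from the PR diff:

```python
# Illustrative only: drops cross-attention entries from older checkpoints so
# that strict loading into a block without cross-attention does not fail.
def _drop_cross_attention_entries(module, state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys, error_msgs):
    for key in list(state_dict.keys()):
        if key.startswith(prefix + "cross_attn.") or key.startswith(prefix + "norm_cross_attn."):
            state_dict.pop(key)


# Registered in __init__ when with_cross_attention is False, e.g.:
#     self.register_load_state_dict_pre_hook(_drop_cross_attention_entries)
```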
Estimated code review effort
🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
| Title Check | ❓ Inconclusive | The title is too generic and does not clearly convey the main change of conditional cross-attention initialization in TransformerBlock, so it fails to inform readers what is being fixed. | Please update the title to something more descriptive, for example “Conditionally initialize cross-attention in TransformerBlock when with_cross_attention=False,” so it clearly reflects the key change. |
✅ Passed checks (1 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | The description follows the repository template by including a summary of the fix, a clear description of the change, and a types of changes checklist, making it largely complete and aligned with expectations. |
✨ Finishing touches
- [ ] 📝 Generate Docstrings
🧪 Generate unit tests
- [ ] Create PR with unit tests
- [ ] Post copyable unit tests in a comment
Hi @yang-ze-kang, thanks for the contribution. We have another PR attempting to address the same issue here, but it doesn't include your approach to handling the loading of existing weights. Would it be possible to comment on your changes there so we can consolidate the fixes into that PR? There were a few points of discussion there that would also be relevant to your PR, so I feel it would be sensible to combine efforts. Thanks!
Hi @yang-ze-kang, sorry for the delay in reviewing this PR. Since this PR is identical to https://github.com/Project-MONAI/MONAI/pull/8545 but adds the state_dict logic to keep the weights compatible with old checkpoints, we will incorporate the changes from this PR. However, a few changes are still required:
- Make sure that the CI tests run. Currently there is a linting issue, which can be fixed with: `./runtests.sh --autofix`
- The diffusion model code https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/diffusion_model_unet.py also includes a transformer block with the same problem. Could you incorporate the same logic there?
- Last but not least, in order to verify that things work properly, could you create a test (in tests/networks/blocks/test_transformerblock.py) showing that you can load the weights of a model with cross-attention layers into a model without them? A rough sketch of such a test is included below.
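A rough sketch of such a test, assuming the current `TransformerBlock` keyword arguments and that the new pre-load hook lets strict loading succeed; the test name and structure would need to follow the repository's test conventions:

```python
# Hypothetical test sketch for tests/networks/blocks/test_transformerblock.py;
# it assumes the PR's pre-load hook drops the unused cross-attention keys.
import unittest

import torch

from monai.networks.blocks.transformerblock import TransformerBlock


class TestLoadWithoutCrossAttention(unittest.TestCase):
    def test_load_weights_from_cross_attention_model(self):
        # Save weights from a block that includes cross-attention layers ...
        src = TransformerBlock(hidden_size=16, mlp_dim=32, num_heads=4, with_cross_attention=True)
        weights = src.state_dict()

        # ... and load them into a block created without cross-attention.
        dst = TransformerBlock(hidden_size=16, mlp_dim=32, num_heads=4, with_cross_attention=False)
        dst.load_state_dict(weights)  # should not raise on the extra cross_attn keys

        # The loaded block should still run a plain self-attention forward pass.
        out = dst(torch.randn(2, 8, 16))
        self.assertEqual(tuple(out.shape), (2, 8, 16))


if __name__ == "__main__":
    unittest.main()
```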