fix(transformerblock): conditionally initialize cross attention compo…
Transformerblock
This PR improves the initialization logic of TransformerBlock by adding conditional layer creation for cross-attention. Previously, the cross-attention components (norm_cross_attn and cross_attn) were always initialized, even when with_cross_attention=False. This introduced unnecessary memory overhead and additional parameters that were never used during forward passes.
With this change, the cross-attention layers are initialized only when with_cross_attention=True. This ensures cleaner module definitions, reduced memory usage, and avoids confusion about unused layers in models that do not require cross-attention.
Fixes # .
Description
Before
```python
...
class TransformerBlock(nn.Module):
    def __init__(
        self,
        ...
    ) -> None:
        ...
        self.norm2 = nn.LayerNorm(hidden_size)
        self.with_cross_attention = with_cross_attention
        self.norm_cross_attn = nn.LayerNorm(hidden_size)
        self.cross_attn = CrossAttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            qkv_bias=qkv_bias,
            causal=False,
            use_flash_attention=use_flash_attention,
        )
```
After
```python
...
class TransformerBlock(nn.Module):
    def __init__(
        self,
        ...
    ) -> None:
        ...
        self.norm2 = nn.LayerNorm(hidden_size)
        self.with_cross_attention = with_cross_attention
        if with_cross_attention:
            self.norm_cross_attn = nn.LayerNorm(hidden_size)
            self.cross_attn = CrossAttentionBlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                qkv_bias=qkv_bias,
                causal=False,
                use_flash_attention=use_flash_attention,
            )
```
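To make the parameter savings concrete, here is a minimal sketch using a hypothetical stand-in, `TinyBlock`, which is not MONAI's class. It assumes `torch.nn.MultiheadAttention` in place of `CrossAttentionBlock` and a single `nn.Linear` as the MLP, so absolute numbers differ from the real `TransformerBlock`. The structure follows the "After" version: the cross-attention layers are created only when the flag is set, and `forward` touches them only under the same guard.

```python
from typing import Optional

import torch
from torch import nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for TransformerBlock; not MONAI's implementation."""

    def __init__(self, hidden_size: int, num_heads: int, with_cross_attention: bool = False) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)  # placeholder MLP
        self.with_cross_attention = with_cross_attention
        # Cross-attention layers are only created when they will be used.
        if with_cross_attention:
            self.norm_cross_attn = nn.LayerNorm(hidden_size)
            self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # Same guard as in forward: the absent attributes are never accessed.
        if self.with_cross_attention:
            h = self.norm_cross_attn(x)
            x = x + self.cross_attn(h, context, context, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


self_only = TinyBlock(64, 4, with_cross_attention=False)
with_cross = TinyBlock(64, 4, with_cross_attention=True)
print(count_params(self_only), count_params(with_cross))
print(hasattr(self_only, "cross_attn"))  # False: the attribute is never created
```

With `hidden_size=64`, the skipped layers account for one `LayerNorm` (128 parameters) plus one attention module (16,640 parameters) per block, and that saving repeats for every block in a stack.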
Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
Walkthrough
The TransformerBlock now instantiates cross-attention components only when with_cross_attention is True. Specifically, norm_cross_attn and cross_attn are now created conditionally; previously they were always initialized. The forward path already checks with_cross_attention, so runtime behavior is unchanged when the flag is False, but those attributes are absent on instances created with with_cross_attention=False.
Estimated code review effort
🎯 2 (Simple) | ⏱️ ~10 minutes
📜 Recent review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📥 Commits
Reviewing files that changed from the base of the PR and between bad0028c3f46ddc937ee3771a396049f188f4e61 and e71c4f9ef69eb045ba2ab49697f32366640ed9a7.
📒 Files selected for processing (1)
monai/networks/blocks/transformerblock.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- monai/networks/blocks/transformerblock.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: packaging
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
Hi @yunhaoli24, thanks for the contribution, but as the CodeRabbit comment notes, this will break existing stored weight files since the members will not always exist. Your change is more efficient, but I think it requires more consideration to avoid this issue, and the behavior should also be documented in the docstring. We are working on a quick release to be done shortly, and then we can return to this PR.
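The checkpoint breakage can be reproduced in a few lines. This sketch uses hypothetical stand-ins, `OldBlock` and `NewBlock`, which are not MONAI classes, with an `nn.Linear` as a placeholder for `CrossAttentionBlock`. A state dict saved when the cross-attention layers always existed contains `cross_attn.*` and `norm_cross_attn.*` keys, and PyTorch's default strict `load_state_dict` rejects those keys as unexpected once the members are no longer created.

```python
import torch
from torch import nn


class OldBlock(nn.Module):
    """Hypothetical stand-in for the previous behavior: layers always created."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.norm2 = nn.LayerNorm(hidden_size)
        self.norm_cross_attn = nn.LayerNorm(hidden_size)
        self.cross_attn = nn.Linear(hidden_size, hidden_size)  # placeholder for CrossAttentionBlock


class NewBlock(nn.Module):
    """Hypothetical stand-in for the new behavior with with_cross_attention=False."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.norm2 = nn.LayerNorm(hidden_size)


old_state = OldBlock(8).state_dict()
new_block = NewBlock(8)

try:
    new_block.load_state_dict(old_state)  # strict=True by default
    strict_ok = True
except RuntimeError as err:
    # PyTorch reports the extra cross-attention entries as unexpected keys.
    strict_ok = False
    print(err)

# strict=False loads the matching keys and reports the leftovers instead of raising.
result = new_block.load_state_dict(old_state, strict=False)
print(sorted(result.unexpected_keys))
```

`strict=False` only papers over the problem for the caller; it also hides genuinely mismatched keys, which is why handling inside the block itself is worth considering.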
Thanks for the feedback! I agree with your point about checkpoint compatibility; that’s indeed an important concern. At the same time, I’d like to note that in a DDP multi-GPU setting, initializing but not using the cross_attn layers can lead to torch errors about unused parameters during training. So this change could also help avoid those issues in addition to saving memory. I think a good compromise might be to add backward-compatibility handling when loading checkpoints, or at least documenting this behavior clearly in the docstring.
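One possible form of the backward-compatibility handling is sketched below. It assumes a hypothetical helper, `drop_unused_cross_attn_keys`, which is not an existing MONAI function. The helper strips the cross-attention entries from an old state dict before that dict is loaded into a block built with with_cross_attention=False. It works on any mapping, so plain dicts stand in for tensors here. The `prefix` argument handles blocks nested inside a larger model.

```python
from typing import Any, Dict, Mapping

# Attribute names that are only created when with_cross_attention=True.
CROSS_ATTN_PREFIXES = ("norm_cross_attn.", "cross_attn.")


def drop_unused_cross_attn_keys(
    state_dict: Mapping[str, Any], with_cross_attention: bool, prefix: str = ""
) -> Dict[str, Any]:
    """Return a copy of state_dict without cross-attention entries when they are unused.

    prefix is the block's own prefix inside a larger model, e.g. "blocks.0.".
    """
    if with_cross_attention:
        return dict(state_dict)
    unused = tuple(prefix + p for p in CROSS_ATTN_PREFIXES)
    # str.startswith accepts a tuple and matches if any prefix matches.
    return {k: v for k, v in state_dict.items() if not k.startswith(unused)}


old = {"norm2.weight": 1, "norm_cross_attn.weight": 2, "cross_attn.out_proj.weight": 3}
print(list(drop_unused_cross_attn_keys(old, with_cross_attention=False)))  # ['norm2.weight']
```

Inside the block itself, the same filtering could be applied in an override of `nn.Module._load_from_state_dict`, which receives the prefix. That is a non-public PyTorch hook, so it is an option to evaluate rather than a settled approach.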
Yes, I have had DDP errors related to this, so you're right; this is one advantage to consider. Please do update the docstring to record what's going on here. One other thing to consider is TorchScript compatibility, which I can't say is tested thoroughly enough for this class. I know it has complained in the past about absent members even when the Python code works. You may need to always define the members but use nn.Identity as their values when with_cross_attention is False (in which case the condition check isn't needed in forward either).
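The DDP failure mode follows from how unused parameters behave in a backward pass. Parameters that never take part in the forward computation receive no gradient, and `DistributedDataParallel` reports this as an error unless it is constructed with `find_unused_parameters=True`. The single-process sketch below shows those `None` gradients without needing a process group. It uses a hypothetical module, `AlwaysBuilt`, with `nn.Linear` as a placeholder for `CrossAttentionBlock`.

```python
import torch
from torch import nn


class AlwaysBuilt(nn.Module):
    """Hypothetical block that always builds cross-attention layers but may skip them."""

    def __init__(self, hidden_size: int, with_cross_attention: bool) -> None:
        super().__init__()
        self.with_cross_attention = with_cross_attention
        self.mlp = nn.Linear(hidden_size, hidden_size)
        self.norm_cross_attn = nn.LayerNorm(hidden_size)
        self.cross_attn = nn.Linear(hidden_size, hidden_size)  # placeholder for CrossAttentionBlock

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.mlp(x)
        if self.with_cross_attention:
            x = x + self.cross_attn(self.norm_cross_attn(x))
        return x


model = AlwaysBuilt(8, with_cross_attention=False)
model(torch.randn(4, 8)).sum().backward()

# Parameters that never participated in forward keep grad=None.
unused = [name for name, p in model.named_parameters() if p.grad is None]
print(unused)
```

`find_unused_parameters=True` works around the error at the cost of an extra traversal of the autograd graph on every iteration, whereas not creating the layers avoids the error entirely.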
Thanks for the suggestion! Regarding defining cross_attn as nn.Identity, I have a few concerns:
- Interface mismatch: The current forward unconditionally calls self.cross_attn(self.norm_cross_attn(x), context=context), whereas nn.Identity.forward only accepts a single input and doesn’t support a context positional or keyword argument. This would raise an “unexpected keyword argument 'context'” error in both Eager and TorchScript.
- Semantic change: Even if we change the call to pass only x to avoid the error, nn.Identity returns the normalized x as-is. That adds a non-zero residual when with_cross_attention=False, which breaks the expectation that disabling cross-attn should not alter the output.
- TorchScript compatibility: Scripting performs static signature checks. Invoking a module with a context argument when the target doesn’t accept it will fail at script/trace time or produce an inconsistent graph.
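The first two points can be checked directly in eager mode. The minimal sketch below uses arbitrary tensor shapes.

```python
import torch
from torch import nn

x = torch.randn(2, 5, 8)
context = torch.randn(2, 7, 8)
norm_cross_attn = nn.LayerNorm(8)
cross_attn = nn.Identity()

# nn.Identity.forward accepts a single input, so the keyword argument is rejected.
try:
    cross_attn(norm_cross_attn(x), context=context)
    raised = False
except TypeError as err:
    raised = True
    print(err)

# Dropping context avoids the error but changes the result: the residual adds norm(x).
out = x + cross_attn(norm_cross_attn(x))
print(torch.allclose(out, x))
```

Note that `nn.Identity` itself accepts and ignores arbitrary constructor arguments; it is only its `forward` that takes exactly one input.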
Sorry for the delay in replying. I see the issue you're mentioning about nn.Identity. One possible solution is to define a new nn.Module subclass whose forward method accepts two arguments but always returns 0; it is less computationally efficient, however. We should first double-check that TorchScript works with your change as it is and that the unit tests are thorough enough in that regard. A second solution is to define self.cross_attn as nn.Identity in the constructor (and always define self.norm_cross_attn) when with_cross_attention is False, and simply not use it in the forward method.
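A minimal sketch of the first option is shown below. It assumes a hypothetical `ZeroCrossAttention` module, which is not part of MONAI. The module's `forward` matches the `(x, context)` call but always returns zeros, so the residual `x + cross_attn(norm_cross_attn(x), context=context)` leaves `x` unchanged. Type annotations are included so the module can be passed to `torch.jit.script`. Whether the full `TransformerBlock` still scripts cleanly with this swap would need to be confirmed in the unit tests.

```python
from typing import Optional

import torch
from torch import nn


class ZeroCrossAttention(nn.Module):
    """Hypothetical placeholder: same call signature as cross-attention, contributes nothing."""

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Still allocates an output tensor, and the preceding norm is still computed.
        return torch.zeros_like(x)


norm_cross_attn = nn.LayerNorm(8)
cross_attn = ZeroCrossAttention()
x = torch.randn(2, 5, 8)
out = x + cross_attn(norm_cross_attn(x), context=torch.randn(2, 7, 8))
print(torch.equal(out, x))

scripted = torch.jit.script(cross_attn)
print(torch.equal(scripted(x), torch.zeros_like(x)))
```

Because this placeholder has no parameters, checkpoints that contain `cross_attn.*` weights would still report those keys as unexpected. This option therefore addresses the DDP and TorchScript concerns rather than checkpoint loading.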
@virginiafdez could you please assist here as well? Thanks!
I would like to mention that another PR https://github.com/Project-MONAI/MONAI/pull/8584 is looking at the same problem. I suggested we consolidate efforts here.
I’d be happy to continue working on this PR. Please let me know if there’s anything that needs to be modified or improved. @ericspod
Hi @yunhaoli24, many thanks for spotting this, and apologies for the delay. Since this issue is duplicated and also addressed in the PR @ericspod linked, which also handles the checkpoint problem, we will move the fix to that PR for now. Further tests and fixes are required, but I will keep both PRs open for the time being and will post an update as soon as the fix is done.