Restormer Implementation for MONAI: High-Resolution Image Restoration
Is your feature request related to a problem? Please describe.
I've noticed that MONAI currently lacks dedicated models for image denoising and restoration tasks. While MONAI provides excellent tools for medical image analysis, having specialized architectures for improving image quality would be valuable for preprocessing pipelines and enhancing low-quality medical images (microscopy, X-ray, scans...).
Describe the solution you'd like
I have implemented a well-documented version of the Restormer model (https://arxiv.org/abs/2111.09881) that could be contributed to MONAI. The implementation includes key components like:
- Multi-DConv Head Transposed Self-Attention (MDTA) for efficient attention computation
- Gated-DConv Feed-Forward Network (GDFN) for refined feature selection
- Modular architecture allowing easy extension and modification
- Support for flash attention when available
- Comprehensive documentation of components and architecture
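To make the efficiency argument for MDTA concrete, here is a back-of-the-envelope sketch (not part of the proposed code) that counts attention-map entries for one image. Spatial (ViT-style) self-attention builds an (H·W)×(H·W) map per head, while transposed (channel) attention builds a (C/heads)×(C/heads) map per head. The sizes used (256×256 image, 48 channels, one head) are illustrative assumptions.

```python
def spatial_attention_entries(height: int, width: int, num_heads: int) -> int:
    """Entries in the attention maps of spatial (token) self-attention."""
    tokens = height * width
    return num_heads * tokens * tokens


def channel_attention_entries(channels: int, num_heads: int) -> int:
    """Entries in the attention maps of transposed (channel) attention, as in MDTA."""
    if channels % num_heads:
        raise ValueError("channels must be divisible by num_heads")
    per_head = channels // num_heads
    return num_heads * per_head * per_head


if __name__ == "__main__":
    h, w, c, heads = 256, 256, 48, 1
    print("spatial:", spatial_attention_entries(h, w, heads))  # spatial: 4294967296
    print("channel:", channel_attention_entries(c, heads))     # channel: 2304
```

The channel map size does not depend on H and W at all; the products q @ kᵀ and attn @ v still touch every pixel once, which is where the linear (rather than quadratic) cost in the number of pixels comes from.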
Describe alternatives you've considered
The implementation is already structured in a modular way with clear separation of components. I'm willing to:
- Refactor the code to meet MONAI coding standards
- Add appropriate type hints and docstrings
- Include unit tests
- Provide example notebooks demonstrating usage
- Add benchmarks comparing performance
Additional context
The code is currently functional and tested. It supports both standard and dual-pixel tasks, with configurable parameters for network depth, attention heads, and feature dimensions. The implementation prioritizes efficiency through features like flash attention support while maintaining flexibility for different use cases.
@Nic-Ma and @ericspod and @KumoLiu - this seems like an outstanding addition to MONAI - agreed?
@phisanti - if all approve, please look at our contribution guidelines. You are already doing the exact right thing by having a modular design. Whenever appropriate, please support the exploration of alternative components in this framework via that modular design and appropriate class abstractions. Please also include multiple tutorials and unit tests with your work.
Does your code currently exist in another repo that we could preliminarily review?
Thanks!
You can take a look at the modular implementation of the Restormer architecture here. I have also copied the code below. As you can see, I keep many of the key blocks intact and focus on expanding functionality (flash attention) and adding modularity to the encoder/decoder blocks. I am happy to implement extra changes if a good suggestion is made.
"""
Restormer: Efficient Transformer for High-Resolution Image Restoration
Implementation based on: https://arxiv.org/abs/2111.09881
"""
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.layers import Norm
from einops import rearrange
class FeedForward(nn.Module):
"""Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.
Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection."""
def __init__(self, dim: int, ffn_expansion_factor: float, bias: bool):
super().__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3,
stride=1, padding=1, groups=hidden_features*2, bias=bias)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
return self.project_out(F.gelu(x1) * x2)
class Attention(nn.Module):
"""Multi-DConv Head Transposed Self-Attention (MDTA) Differs from standard self-attention
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
convolutions for local mixing before attention, achieving linear complexity vs quadratic
in vanilla attention."""
def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
super().__init__()
if flash_attention and not hasattr(F, 'scaled_dot_product_attention'):
raise ValueError("Flash attention not available")
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.flash_attention = flash_attention
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1,
padding=1, groups=dim*3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self._attention_fn = self._get_attention_fn()
def _get_attention_fn(self):
if self.flash_attention:
return self._flash_attention
return self._normal_attention
def _flash_attention(self, q, k, v):
"""Flash attention implementation using scaled dot-product attention."""
scale = float(self.temperature.mean())
out = F.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
dropout_p=0.0,
is_causal=False
)
return out
def _normal_attention(self, q, k, v):
"""Attention matrix multiplication with depth-wise convolutions."""
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
return attn @ v
def forward(self, x):
"""Forward pass for MDTA attention.
1. Apply depth-wise convolutions to Q, K, V
2. Reshape Q, K, V for multi-head attention
3. Compute attention matrix using flash or normal attention
4. Reshape and project out attention output"""
b,c,h,w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q,k,v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
out = self._attention_fn(q, k, v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
class TransformerBlock(nn.Module):
"""Basic transformer unit combining MDTA and GDFN with skip connections.
Unlike standard transformers that use LayerNorm, this block uses Instance Norm
for better adaptation to image restoration tasks."""
def __init__(self, dim: int, num_heads: int, ffn_expansion_factor: float,
bias: bool, LayerNorm_type: str, flash_attention: bool = False):
super().__init__()
use_bias = LayerNorm_type != 'BiasFree'
self.norm1 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
self.attn = Attention(dim, num_heads, bias, flash_attention)
self.norm2 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
#print(f'x shape in transformer block: {x.shape}')
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
class OverlapPatchEmbed(nn.Module):
"""Initial feature extraction using overlapped convolutions.
Unlike standard patch embeddings that use non-overlapping patches,
this approach maintains spatial continuity through 3x3 convolutions."""
def __init__(self, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
super().__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
stride=1, padding=1, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x)
class Downsample(nn.Module):
"""Downsampling module that halves spatial dimensions while doubling channels.
Uses PixelUnshuffle for efficient feature map manipulation."""
def __init__(self, n_feat: int):
super().__init__()
self.body = nn.Sequential(
nn.Conv2d(n_feat, n_feat//2, kernel_size=3,
stride=1, padding=1, bias=False),
nn.PixelUnshuffle(2)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.body(x)
class Upsample(nn.Module):
"""Upsampling module that doubles spatial dimensions while halving channels.
Combines convolution with PixelShuffle for efficient feature expansion."""
def __init__(self, in_channels: int) -> None:
super().__init__()
self.body = nn.Sequential(
nn.Conv2d(in_channels, in_channels * 2, kernel_size=3,
stride=1, padding=1, bias=False),
nn.PixelShuffle(2)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.body(x)
##---------- Restormer -----------------------
class Restormer(nn.Module):
"""Restormer: Efficient Transformer for High-Resolution Image Restoration.
Implements a U-Net style architecture with transformer blocks, combining:
- Multi-scale feature processing through progressive down/upsampling
- Efficient attention via MDTA blocks
- Local feature mixing through GDFN
- Skip connections for preserving spatial details
Architecture:
- Encoder: Progressive feature downsampling with increasing channels
- Latent: Deep feature processing at lowest resolution
- Decoder: Progressive upsampling with skip connections
- Refinement: Final feature enhancement
"""
def __init__(self,
inp_channels=3,
out_channels=3,
dim=48,
num_blocks=[1, 1, 1, 1],
heads=[1, 1, 1, 1],
num_refinement_blocks=4,
ffn_expansion_factor=2.66,
bias=False,
LayerNorm_type='WithBias',
dual_pixel_task=False,
flash_attention=False):
super().__init__()
"""Initialize Restormer model.
Args:
inp_channels: Number of input image channels
out_channels: Number of output image channels
dim: Base feature dimension
num_blocks: Number of transformer blocks at each scale
num_refinement_blocks: Number of final refinement blocks
heads: Number of attention heads at each scale
ffn_expansion_factor: Expansion factor for feed-forward network
bias: Whether to use bias in convolutions
LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree')
dual_pixel_task: Enable dual-pixel specific processing
flash_attention: Use flash attention if available
"""
# Check input parameters
assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal"
assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0"
# Initial feature extraction
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_levels = nn.ModuleList()
self.downsamples = nn.ModuleList()
self.decoder_levels = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.reduce_channels = nn.ModuleList()
num_steps = len(num_blocks) - 1
self.num_steps = num_steps
# Define encoder levels
for n in range(num_steps):
current_dim = dim * 2**n
self.encoder_levels.append(
nn.Sequential(*[
TransformerBlock(
dim=current_dim,
num_heads=heads[n],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
flash_attention=flash_attention
) for _ in range(num_blocks[n])
])
)
self.downsamples.append(Downsample(current_dim))
# Define latent space
latent_dim = dim * 2**num_steps
self.latent = nn.Sequential(*[
TransformerBlock(
dim=latent_dim,
num_heads=heads[num_steps],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
flash_attention=flash_attention
) for _ in range(num_blocks[num_steps])
])
# Define decoder levels
for n in reversed(range(num_steps)):
current_dim = dim * 2**n
next_dim = dim * 2**(n+1)
self.upsamples.append(Upsample(next_dim))
# Reduce channel layers to deal with skip connections
if n != 0:
self.reduce_channels.append(
nn.Conv2d(next_dim, current_dim, kernel_size=1, bias=bias)
)
decoder_dim = current_dim
else:
decoder_dim = next_dim
self.decoder_levels.append(
nn.Sequential(*[
TransformerBlock(
dim=decoder_dim,
num_heads=heads[n],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
flash_attention=flash_attention
) for _ in range(num_blocks[n])
])
)
# Final refinement and output
self.refinement = nn.Sequential(*[
TransformerBlock(
dim=decoder_dim,
num_heads=heads[0],
ffn_expansion_factor=ffn_expansion_factor,
bias=bias,
LayerNorm_type=LayerNorm_type,
flash_attention=flash_attention
) for _ in range(num_refinement_blocks)
])
self.dual_pixel_task = dual_pixel_task
if self.dual_pixel_task:
self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
def forward(self, x):
"""Forward pass of Restormer.
Processes input through encoder-decoder architecture with skip connections.
Args:
inp_img: Input image tensor of shape (B, C, H, W)
Returns:
Restored image tensor of shape (B, C, H, W)
"""
assert x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2 ** self.num_steps, "Input dimensions should be larger than 2^number_of_step"
# Patch embedding
x = self.patch_embed(x)
skip_connections = []
# Encoding path
for idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
x = encoder(x)
skip_connections.append(x)
x = downsample(x)
# Latent space
x = self.latent(x)
# Decoding path
for idx in range(len(self.decoder_levels)):
x = self.upsamples[idx](x)
x = torch.concat([x, skip_connections[-(idx + 1)]], 1)
if idx < len(self.decoder_levels) - 1:
x = self.reduce_channels[idx](x)
x = self.decoder_levels[idx](x)
# Final refinement
x = self.refinement(x)
if self.dual_pixel_task:
x = x + self.skip_conv(skip_connections[0])
x = self.output(x)
else:
x = self.output(x)
return x
if __name__ == "__main__":
flash_att = True
test_model = Restormer(
inp_channels=2,
out_channels=2,
dim=16,
num_blocks=[1,1,1,1],
heads=[1,1,1,1],
num_refinement_blocks=2,
ffn_expansion_factor=1.5,
bias=False,
LayerNorm_type='WithBias',
dual_pixel_task=True,
flash_attention=flash_att
)
print(f'flash attention set to {flash_att}')
input_tensor = torch.randn(8, 2, 256, 256)
print(f"Input shape: {input_tensor.shape}")
output = test_model(input_tensor)
print(f"Output shape: {output.shape}")
print(f'printing final model')
from torchsummary import summary
summary(test_model, input_size=input_tensor)
```
@aylward @Nic-Ma @ericspod and @KumoLiu, if you all agree and there are no comments on extra modules to be added, I will implement the class as it is. For that, I will:
- Fork and create branch '8261-restormer-implementation'
- Place architecture in MONAI/monai/networks/nets folder
- Add extensive documentation (docstrings + docs) using the UNet class style as a template
- Write unit tests following existing test patterns
- Create tutorial notebook with example dataset for the Project-MONAI/tutorials
- Submit PRs for both code and tutorial
What aspects of this approach would you modify to fully align with MONAI's contribution standards?
Hi @phisanti, thank you for sharing the comprehensive plan!
I’d recommend dividing the implementation into several PRs to simplify the review process. Additionally, I highly suggest checking if there are existing blocks in MONAI that can be reused in your network, such as upsample, downsample, attention mechanisms, etc.
- https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/downsample.py
- https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py
- https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/selfattention.py
- https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/spatialattention.py
Also, consider using Convolution, which could make your network support both 2D and 3D implementations seamlessly. https://github.com/Project-MONAI/MONAI/blob/e1e3d8ebc1c7247aad9f1bffc649c5a20084340f/monai/networks/blocks/convolutions.py#L25
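MONAI's Convolution block (and the `Conv` layer factory behind it) selects the 1D/2D/3D convolution class from `spatial_dims`. A torch-only sketch of the same idea applied to the GDFN, so one class serves 2D and 3D inputs; the dictionary lookup stands in for MONAI's factory, and `GatedDConvFeedForwardND` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for MONAI's Conv[Conv.CONV, spatial_dims] factory lookup
_CONV = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}


class GatedDConvFeedForwardND(nn.Module):
    """Hypothetical dimension-agnostic GDFN: 1x1 expand, depth-wise 3^d conv, GELU gate, 1x1 project."""

    def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool = False):
        super().__init__()
        conv = _CONV[spatial_dims]
        hidden = int(dim * ffn_expansion_factor)
        self.project_in = conv(dim, hidden * 2, kernel_size=1, bias=bias)
        self.dwconv = conv(hidden * 2, hidden * 2, kernel_size=3, padding=1, groups=hidden * 2, bias=bias)
        self.project_out = conv(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One half is GELU-activated and gates the other half
        x1, x2 = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(x1) * x2)


if __name__ == "__main__":
    ffn2d = GatedDConvFeedForwardND(2, dim=16, ffn_expansion_factor=2.66)
    ffn3d = GatedDConvFeedForwardND(3, dim=16, ffn_expansion_factor=2.66)
    print(ffn2d(torch.randn(1, 16, 32, 32)).shape)     # torch.Size([1, 16, 32, 32])
    print(ffn3d(torch.randn(1, 16, 8, 32, 32)).shape)  # torch.Size([1, 16, 8, 32, 32])
```

Only the convolution class changes with the dimensionality; the gating logic is identical, which is why a factory-based approach keeps a single code path for 2D and 3D.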
Hi @KumoLiu and @aylward
I have just done an in-depth review of the Upsample/Downsample, SABlock and Transformer blocks present in MONAI. From what I can see, switching to MONAI's own Upsample/Downsample classes is trivial. I think something like:
```python
from monai.networks.blocks import UpSample
from monai.utils import UpsampleMode

spatial_dims, in_channels = 2, 64  # example values

upsample = UpSample(
    spatial_dims=spatial_dims,
    in_channels=in_channels,
    out_channels=in_channels // 2,
    mode=UpsampleMode.PIXELSHUFFLE,
    scale_factor=2,
    bias=False,
    apply_pad_pool=False,
)
```
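For reference, the channel bookkeeping that this UpSample configuration has to reproduce can be checked with plain PyTorch. In 2D, a factor-2 pixel shuffle trades 4x channels for 2x along each spatial axis, so a pre-convolution to (in_channels // 2) * 4 channels followed by the shuffle yields in_channels // 2 channels at twice the resolution, and the matching PixelUnshuffle downsample returns to the original shape. This mirrors the Restormer Upsample/Downsample classes; in_channels=64 and the input size are arbitrary.

```python
import torch
import torch.nn as nn

in_channels, r = 64, 2

# Upsample path: conv to (in_channels // 2) * r**2 channels, then PixelShuffle -> in_channels // 2
up = nn.Sequential(
    nn.Conv2d(in_channels, (in_channels // 2) * r**2, kernel_size=3, padding=1, bias=False),
    nn.PixelShuffle(r),
)
# Downsample path: conv to in_channels // 4 channels, then PixelUnshuffle -> in_channels
down = nn.Sequential(
    nn.Conv2d(in_channels // 2, in_channels // 4, kernel_size=3, padding=1, bias=False),
    nn.PixelUnshuffle(r),
)

x = torch.randn(1, in_channels, 16, 16)
y = up(x)
print(y.shape)        # torch.Size([1, 32, 32, 32])
print(down(y).shape)  # torch.Size([1, 64, 16, 16])
```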
should mirror the behaviour of the current Restormer Upsample. However, MONAI's SABlock cannot be reused for the attention: SABlock is a spatial attention mechanism based on the ViT paper by Dosovitskiy et al., whereas Restormer uses channel (transposed) attention. See the code below:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA).
    Differs from standard self-attention by attending across feature channels instead of
    spatial positions. Depth-wise convolutions mix local context before attention, and the
    cost grows linearly with the number of pixels instead of quadratically as in vanilla attention."""

    def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
        super().__init__()
        if flash_attention and not hasattr(F, 'scaled_dot_product_attention'):
            raise ValueError("Flash attention requires torch.nn.functional.scaled_dot_product_attention")
        self.num_heads = num_heads
        # Learnable per-head temperature that rescales the attention logits
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.flash_attention = flash_attention
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self._attention_fn = self._get_attention_fn()

    def _get_attention_fn(self):
        if self.flash_attention:
            return self._flash_attention
        return self._normal_attention

    def _flash_attention(self, q, k, v):
        """Flash attention using scaled dot-product attention (the `scale` argument needs PyTorch >= 2.1).
        The per-head temperature is folded into q and `scale` is fixed to 1.0, so the result matches
        _normal_attention and the temperature keeps receiving gradients."""
        return F.scaled_dot_product_attention(
            q * self.temperature,
            k,
            v,
            scale=1.0,
            dropout_p=0.0,
            is_causal=False,
        )

    def _normal_attention(self, q, k, v):
        """Softmax attention over channels, scaled by the learnable per-head temperature."""
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        return attn @ v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for MDTA attention.
        1. Apply 1x1 and depth-wise 3x3 convolutions to produce Q, K, V
        2. Reshape Q, K, V for multi-head attention
        3. Compute attention using flash or normal attention
        4. Reshape and project out the attention output"""
        h, w = x.shape[-2:]
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        # L2-normalize along the pixel axis so q @ k^T is a cosine similarity between channels
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        out = self._attention_fn(q, k, v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        return self.project_out(out)
```
My suggestion is to implement it as a separate class (CABlock, Channel Attention Block). Then MONAI's native TransformerBlock could support both attention mechanisms, with an argument attention_type="spatial"/"channel" letting the user choose between them. The new class could then live in the blocks module.
Let me know what you think.
A CABlock class seems appropriate, but it would be good if we could avoid having an argument to toggle between two different variables / use-cases.
Would it be possible for CABlock and SABlock to have a member var that defines its type: spatial or channel? Then the Transformer class would query that member var of the block to determine if it is spatial- or channel-based block and adjust accordingly (assuming transformer's logic would change if self.attn points to CABlock)? Not certain if API differences could be resolved - what do you think?
If so, then perhaps the Transformer class would default (as-is) to self.attn being an SABlock, but after init, a user/caller could overwrite self.attn with a CABlock?
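One way the member-variable idea could look, as a hypothetical sketch rather than MONAI's actual TransformerBlock API: each attention block declares an `attention_type` class attribute, and the transformer block reads it at run time to decide whether to pass a (B, C, H, W) feature map or a (B, H*W, C) token sequence. Both attention classes are minimal stand-ins (`nn.MultiheadAttention` substitutes for SABlock), and `GroupNorm` is a placeholder normalization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialAttention(nn.Module):
    """Stand-in for SABlock: attends over spatial tokens shaped (B, N, C)."""

    attention_type = "spatial"

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.mha(tokens, tokens, tokens, need_weights=False)[0]


class ChannelAttention(nn.Module):
    """Stand-in for a CABlock: attends over the channels of a (B, C, H, W) map."""

    attention_type = "channel"

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        q, k, v = self.qkv(x).reshape(b, 3, self.num_heads, c // self.num_heads, h * w).unbind(1)
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)).softmax(dim=-1)
        return (attn @ v).reshape(b, c, h, w)


class PluggableTransformerBlock(nn.Module):
    """Hypothetical block that adapts its tensor layout to self.attn.attention_type."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim)  # placeholder normalization over (C, H, W)
        self.attn = SpatialAttention(dim, num_heads)  # default; may be replaced after init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        y = self.norm(x)
        if self.attn.attention_type == "channel":
            return x + self.attn(y)
        tokens = y.flatten(2).transpose(1, 2)  # (B, H*W, C)
        out = self.attn(tokens).transpose(1, 2).reshape(b, c, h, w)
        return x + out


block = PluggableTransformerBlock(dim=16, num_heads=2)
x = torch.randn(1, 16, 8, 8)
print(block(x).shape)  # torch.Size([1, 16, 8, 8]), spatial attention
block.attn = ChannelAttention(dim=16, num_heads=2)  # overwrite after init
print(block(x).shape)  # torch.Size([1, 16, 8, 8]), channel attention
```

The only branching lives in the block's forward, and swapping `self.attn` after construction needs no extra constructor argument, which is the property the comment above is after.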
Hi @phisanti, I agree with @KumoLiu that we could reduce some code duplication here. I would also make sure there is no duplication of class names, for the sake of readability. If you want to propose a PR, or multiple PRs, we can discuss further there. Thanks!
Dear @aylward and @KumoLiu, thanks for your patience. Here is an overview of the progress I have made so far:
The goal was to implement the Restormer class with the MONAI convolution as the base so that the model can operate with both 2D and 3D images. During this process, I realized that the project lacked a Downsample class and the pixel_unshuffle operation. Thus, I created the necessary classes/functions.
- I placed the Downsample class in monai/networks/blocks/downsample.py.
- The pixel_unshuffle operation was placed in monai/networks/utils.py, just below the pixel_shuffle operation.
When writing pixel_unshuffle, I tried to follow your coding patterns as closely as possible so that everything looks integrated. I also ran some performance tests against the native torch.pixel_unshuffle, and the results look good.
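For reference, an N-D pixel unshuffle fits in a single reshape and permute: split every spatial axis of size s into (s // r, r), then move all the r-sized axes next to the channel axis. This is a sketch under my own naming (`pixel_unshuffle_nd` is hypothetical and not necessarily what was added to monai/networks/utils.py); in 2D it reproduces the channel ordering of torch.nn.functional.pixel_unshuffle.

```python
import torch


def pixel_unshuffle_nd(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
    """Hypothetical N-D pixel unshuffle: (B, C, *S) -> (B, C * r**d, *(s // r for s in S))."""
    b, c, *spatial = x.shape
    if len(spatial) != spatial_dims:
        raise ValueError(f"expected {spatial_dims} spatial dims, got {len(spatial)}")
    r = scale_factor
    if any(s % r for s in spatial):
        raise ValueError(f"every spatial size must be divisible by {r}, got {spatial}")
    reduced = [s // r for s in spatial]
    # Split each spatial axis into (s // r, r): (B, C, d1, r, d2, r, ...)
    split_shape = [b, c]
    for d in reduced:
        split_shape += [d, r]
    x = x.reshape(split_shape)
    # Put the r-sized axes right after C and the reduced axes at the end
    factor_axes = [3 + 2 * i for i in range(spatial_dims)]
    reduced_axes = [2 + 2 * i for i in range(spatial_dims)]
    x = x.permute(0, 1, *factor_axes, *reduced_axes)
    return x.reshape(b, c * r**spatial_dims, *reduced)


x2d = torch.randn(2, 3, 8, 12)
print(torch.equal(pixel_unshuffle_nd(x2d, 2, 2), torch.nn.functional.pixel_unshuffle(x2d, 2)))  # True
print(pixel_unshuffle_nd(torch.randn(1, 2, 4, 6, 8), 3, 2).shape)  # torch.Size([1, 16, 2, 3, 4])
```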
Next, I proceeded to implement the Channel Attention Block (CABlock) along with the FeedForward layer. Both classes were placed in monai/networks/blocks/cablock.py.
Regarding the Multi-DConv Head Transposed Self-Attention (MDTA), I was not sure where you wanted it placed. Given its particularities, I did not have a clear idea of how to integrate it with the existing transformer block so that the user could switch between the classic spatial transformer and the MDTA. Therefore, I put the MDTA in monai/networks/nets/restormer.py, together with the OverlapPatchEmbed class and the Restormer itself.
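Because MDTA attends over channels, the spatial axes only need to be flattened into one pixel axis, which is what makes a 2D/3D version straightforward. A hypothetical torch-only sketch of a dimension-agnostic MDTA-style block (the class name and the nn.Conv2d/nn.Conv3d lookup are illustrative, standing in for MONAI's Convolution):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

_CONV = {2: nn.Conv2d, 3: nn.Conv3d}


class ChannelAttentionND(nn.Module):
    """Hypothetical MDTA-style block for (B, C, *spatial) inputs with 2 or 3 spatial dims."""

    def __init__(self, spatial_dims: int, dim: int, num_heads: int, bias: bool = False):
        super().__init__()
        if dim % num_heads:
            raise ValueError("dim must be divisible by num_heads")
        conv = _CONV[spatial_dims]
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = conv(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = conv(dim * 3, dim * 3, kernel_size=3, padding=1, groups=dim * 3, bias=bias)
        self.project_out = conv(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, *spatial = x.shape
        qkv = self.qkv_dwconv(self.qkv(x))
        # (B, 3C, *spatial) -> three tensors of shape (B, heads, C // heads, prod(spatial))
        q, k, v = qkv.reshape(b, 3, self.num_heads, c // self.num_heads, -1).unbind(1)
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        attn = ((q @ k.transpose(-2, -1)) * self.temperature).softmax(dim=-1)
        out = (attn @ v).reshape(b, c, *spatial)
        return self.project_out(out)


if __name__ == "__main__":
    print(ChannelAttentionND(2, dim=16, num_heads=2)(torch.randn(1, 16, 8, 8)).shape)
    print(ChannelAttentionND(3, dim=16, num_heads=4)(torch.randn(1, 16, 4, 8, 8)).shape)
```

The attention map stays (C/heads)×(C/heads) regardless of whether the input is 2D or 3D; only the convolutions and the flattening change.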
Finally, for every new class and function, I wrote extensive unit tests trying to cover as many edge cases as I could come up with. However, an extra review is always welcome to raise the standards.
You can check all the progress in my forked MONAI repo.
Please let me know what you think. Once you have had a quick look, I can proceed with the following steps:
- [ ] Implement any feedback that you give me
- [ ] Additional documentation
- [ ] Example notebooks
Once this is done, the implementation will be ready for review and integration into the main MONAI repository.
Hi @phisanti all that looks good. I think it's time to put this work into a PR where it would be easier to review. We'll consider the core code itself which should be documented as well (docstrings and in the .rst document files where appropriate). We'd then look at example notebooks for the tutorials repo. Please open a PR from fork when you can and link it to this issue, and we'll review shortly. Thanks!
Hi @ericspod, thanks for the feedback. I have just opened the pull request.
@phisanti Thank you very much for the proposed content. Is there a 3D version of the implementation available?
@cyl0000 I included 3D image restoration in the implementation that I developed. It is currently in a pull request awaiting review, but if you need it now, you can use the fork here. I wrote extensive unit tests to validate that it works in 3D. However, I didn't have 3D images to test it with real data. If you have some data, I would happily run some tests.
Dear @ericspod, I am really glad the PR finally went through. If you are OK with it, I can create the tutorial for the tutorials repo. Judging by the existing folders, I guess I should create a 2D_regression folder. There, I can include some example datasets from my own work or use some online datasets (there are many for natural camera images, but I am not sure about medical image denoising). My idea is that the tutorial should explain:
- Denoising
- Showing noisy/clean images
- Explain the Restormer architecture
- Train the model for a few epochs
- Show output
Please let me know what you think and I will go ahead.
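As a sketch of the "train the model for a few epochs" step, the loop below corrupts clean images with Gaussian noise on the fly and trains with an L1 loss. Everything here is an assumption for illustration: random tensors stand in for a real dataset, a two-layer CNN stands in for Restormer (which could be dropped in with matching channel counts and an input size divisible by 2**num_steps), and the noise level, batch size and learning rate are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
clean = torch.rand(16, 1, 64, 64)  # placeholder "clean" images in [0, 1]
# Stand-in denoiser; e.g. Restormer(inp_channels=1, out_channels=1, ...) could replace it
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 1, kernel_size=3, padding=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.L1Loss()
sigma = 0.1  # noise standard deviation (arbitrary)

losses = []
for epoch in range(3):
    for batch in clean.split(4):
        noisy = batch + sigma * torch.randn_like(batch)  # corrupt on the fly
        loss = loss_fn(model(noisy), batch)  # predict the clean image from the noisy one
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    losses.append(loss.item())
    print(f"epoch {epoch}: last-batch L1 = {loss.item():.4f}")
```

Generating the noise inside the loop means every epoch sees a fresh corruption of the same clean images, which keeps the tutorial self-contained without shipping paired noisy/clean data.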
Hi @phisanti this sounds good to me. I would try to choose small datasets available to anyone online just for ease of use, you don't have to train the model to do something useful but just demonstrate the process. You can propose the tutorial on the Tutorials repo and we can review from there. Thanks!
Hi @ericspod I have just opened a PR with the tutorial #1987. If all is good, this should close this issue.