
Restormer Implementation for MONAI: High-Resolution Image Restoration

Open phisanti opened this issue 1 year ago • 15 comments

Is your feature request related to a problem? Please describe.

I've noticed that MONAI currently lacks dedicated models for image denoising and restoration tasks. While MONAI provides excellent tools for medical image analysis, having specialized architectures for improving image quality would be valuable for preprocessing pipelines and enhancing low-quality medical images (microscopy, X-ray, scans...).

Describe the solution you'd like

I have implemented a well-documented version of the Restormer model (https://arxiv.org/abs/2111.09881) that could be contributed to MONAI. The implementation includes key components like:

  • Multi-DConv Head Transposed Self-Attention (MDTA) for efficient attention computation
  • Gated-DConv Feed-Forward Network (GDFN) for refined feature selection
  • Modular architecture allowing easy extension and modification
  • Support for flash attention when available
  • Comprehensive documentation of components and architecture
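To make the efficiency argument for MDTA concrete, it helps to count attention-map entries. Vanilla spatial self-attention builds one (H·W)×(H·W) map per head. MDTA's transposed attention builds one (C/heads)×(C/heads) map per head. The sketch below assumes a 256×256 input with 48 channels (the default `dim`) and a single head. The helper names are illustrative only.

```python
def spatial_attention_entries(h: int, w: int, heads: int) -> int:
    """Entries in token-wise (spatial) attention maps: one (HW x HW) map per head."""
    n_tokens = h * w
    return heads * n_tokens * n_tokens


def channel_attention_entries(channels: int, heads: int) -> int:
    """Entries in MDTA-style transposed attention maps: one (C/heads x C/heads) map per head."""
    if channels % heads:
        raise ValueError("channels must be divisible by heads")
    per_head = channels // heads
    return heads * per_head * per_head


spatial = spatial_attention_entries(256, 256, heads=1)
channel = channel_attention_entries(48, heads=1)
print(f"spatial: {spatial:,} entries")  # spatial: 4,294,967,296 entries
print(f"channel: {channel:,} entries")  # channel: 2,304 entries
```

The channel-wise map grows with C² and does not depend on image size. The remaining cost, computing Q·Kᵀ over all H·W positions, is linear in the number of pixels.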

Describe alternatives you've considered

The implementation is already structured in a modular way with clear separation of components. I'm willing to:

  • Refactor the code to meet MONAI coding standards
  • Add appropriate type hints and docstrings
  • Include unit tests
  • Provide example notebooks demonstrating usage
  • Add benchmarks comparing performance

Additional context

The code is currently functional and tested. It supports both standard and dual-pixel tasks, with configurable parameters for network depth, attention heads, and feature dimensions. The implementation prioritizes efficiency through features like flash attention support while maintaining flexibility for different use cases.

phisanti commented on Dec 09 '24

@Nic-Ma and @ericspod and @KumoLiu - this seems like an outstanding addition to MONAI - agreed?

@phisanti - if all approve, please look at our contribution guidelines. You are already doing the exact right thing by having a modular design. Whenever appropriate, please support the exploration of alternative components in this framework via that modular design and appropriate class abstractions. Please also include multiple tutorials and unit tests with your work.

Does your code currently exist in another repo that we could preliminarily review?

Thanks!

aylward commented on Dec 11 '24

You can take a look at the modular implementation of the Restormer architecture here. I have also copied the code below. As you can see, I keep many of the key blocks intact and focus on expanding functionality (flash attention) and adding modularity to the encoder/decoder blocks. I am happy to implement extra changes if there are good suggestions.

```python
"""
Restormer: Efficient Transformer for High-Resolution Image Restoration
Implementation based on: https://arxiv.org/abs/2111.09881
"""

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from monai.networks.layers import Norm


class FeedForward(nn.Module):
    """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using a gating mechanism.
    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection."""

    def __init__(self, dim: int, ffn_expansion_factor: float, bias: bool):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
                                stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        # GELU(x1) acts as a gate on x2
        return self.project_out(F.gelu(x1) * x2)


class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA). Differs from standard self-attention
    by attending across feature channels instead of spatial positions. Depth-wise convolutions
    mix local context before attention, and the C x C attention map gives complexity linear in
    the number of pixels, versus quadratic for vanilla spatial attention."""

    def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
        super().__init__()
        # Note: the `scale` argument used in _flash_attention requires PyTorch >= 2.1
        if flash_attention and not hasattr(F, 'scaled_dot_product_attention'):
            raise ValueError("Flash attention not available")

        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.flash_attention = flash_attention
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self._attention_fn = self._get_attention_fn()

    def _get_attention_fn(self):
        if self.flash_attention:
            return self._flash_attention
        return self._normal_attention

    def _flash_attention(self, q, k, v):
        """Attention via F.scaled_dot_product_attention (uses a flash kernel when available).
        The learnable per-head temperature is folded into q and scale=1.0 is passed, so the
        result matches _normal_attention and the temperature keeps receiving gradients."""
        return F.scaled_dot_product_attention(
            q * self.temperature,
            k,
            v,
            dropout_p=0.0,
            is_causal=False,
            scale=1.0,
        )

    def _normal_attention(self, q, k, v):
        """Explicit attention: softmax(temperature * Q K^T) V over the channel dimension."""
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        return attn @ v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for MDTA attention.
        1. Compute Q, K, V with a 1x1 convolution followed by a depth-wise 3x3 convolution
        2. Reshape Q, K, V for multi-head attention
        3. Compute attention using flash or normal attention
        4. Reshape and project out the attention output"""
        _, _, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        out = self._attention_fn(q, k, v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out


class TransformerBlock(nn.Module):
    """Basic transformer unit combining MDTA and GDFN with skip connections.
    Unlike standard transformers that use LayerNorm, this block uses Instance Norm
    for better adaptation to image restoration tasks. Here, LayerNorm_type only controls
    whether the InstanceNorm layers have affine parameters ('BiasFree' disables them)."""

    def __init__(self, dim: int, num_heads: int, ffn_expansion_factor: float,
                 bias: bool, LayerNorm_type: str, flash_attention: bool = False):
        super().__init__()
        use_bias = LayerNorm_type != 'BiasFree'
        self.norm1 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
        self.attn = Attention(dim, num_heads, bias, flash_attention)
        self.norm2 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class OverlapPatchEmbed(nn.Module):
    """Initial feature extraction using overlapped convolutions.
    Unlike standard patch embeddings that use non-overlapping patches,
    this approach maintains spatial continuity through 3x3 convolutions."""

    def __init__(self, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
                              stride=1, padding=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class Downsample(nn.Module):
    """Downsampling module that halves spatial dimensions while doubling channels.
    Uses PixelUnshuffle for efficient feature map manipulation."""

    def __init__(self, n_feat: int):
        super().__init__()
        self.body = nn.Sequential(
            # n_feat -> n_feat // 2, then PixelUnshuffle(2) multiplies channels by 4
            nn.Conv2d(n_feat, n_feat // 2, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.PixelUnshuffle(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


class Upsample(nn.Module):
    """Upsampling module that doubles spatial dimensions while halving channels.
    Combines convolution with PixelShuffle for efficient feature expansion."""

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.body = nn.Sequential(
            # in_channels -> 2 * in_channels, then PixelShuffle(2) divides channels by 4
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.PixelShuffle(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


##---------- Restormer -----------------------
class Restormer(nn.Module):
    """Restormer: Efficient Transformer for High-Resolution Image Restoration.

    Implements a U-Net style architecture with transformer blocks, combining:
    - Multi-scale feature processing through progressive down/upsampling
    - Efficient attention via MDTA blocks
    - Local feature mixing through GDFN
    - Skip connections for preserving spatial details

    Architecture:
        - Encoder: Progressive feature downsampling with increasing channels
        - Latent: Deep feature processing at lowest resolution
        - Decoder: Progressive upsampling with skip connections
        - Refinement: Final feature enhancement
    """

    def __init__(self,
                 inp_channels: int = 3,
                 out_channels: int = 3,
                 dim: int = 48,
                 num_blocks: Sequence[int] = (1, 1, 1, 1),
                 heads: Sequence[int] = (1, 1, 1, 1),
                 num_refinement_blocks: int = 4,
                 ffn_expansion_factor: float = 2.66,
                 bias: bool = False,
                 LayerNorm_type: str = 'WithBias',
                 dual_pixel_task: bool = False,
                 flash_attention: bool = False):
        """Initialize Restormer model.

        Args:
            inp_channels: Number of input image channels
            out_channels: Number of output image channels
            dim: Base feature dimension
            num_blocks: Number of transformer blocks at each scale (the last entry is the latent level)
            heads: Number of attention heads at each scale
            num_refinement_blocks: Number of final refinement blocks
            ffn_expansion_factor: Expansion factor for feed-forward network
            bias: Whether to use bias in convolutions
            LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree')
            dual_pixel_task: Enable dual-pixel specific processing
            flash_attention: Use flash attention if available
        """
        super().__init__()
        # Check input parameters
        assert len(num_blocks) > 1, "num_blocks must have at least two entries (encoder levels + latent)"
        assert len(num_blocks) == len(heads), "num_blocks and heads must have the same length"
        assert all(n > 0 for n in num_blocks), "Every entry of num_blocks must be greater than 0"

        # Initial feature extraction
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_levels = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        self.decoder_levels = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.reduce_channels = nn.ModuleList()
        num_steps = len(num_blocks) - 1
        self.num_steps = num_steps

        # Define encoder levels
        for n in range(num_steps):
            current_dim = dim * 2**n
            self.encoder_levels.append(
                nn.Sequential(*[
                    TransformerBlock(
                        dim=current_dim,
                        num_heads=heads[n],
                        ffn_expansion_factor=ffn_expansion_factor,
                        bias=bias,
                        LayerNorm_type=LayerNorm_type,
                        flash_attention=flash_attention
                    ) for _ in range(num_blocks[n])
                ])
            )
            self.downsamples.append(Downsample(current_dim))

        # Define latent space
        latent_dim = dim * 2**num_steps
        self.latent = nn.Sequential(*[
            TransformerBlock(
                dim=latent_dim,
                num_heads=heads[num_steps],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
                flash_attention=flash_attention
            ) for _ in range(num_blocks[num_steps])
        ])

        # Define decoder levels
        for n in reversed(range(num_steps)):
            current_dim = dim * 2**n
            next_dim = dim * 2**(n + 1)
            self.upsamples.append(Upsample(next_dim))

            # Reduce channel layers to deal with skip connections
            # (the top level keeps the concatenated 2 * dim channels, as in the original Restormer)
            if n != 0:
                self.reduce_channels.append(
                    nn.Conv2d(next_dim, current_dim, kernel_size=1, bias=bias)
                )
                decoder_dim = current_dim
            else:
                decoder_dim = next_dim

            self.decoder_levels.append(
                nn.Sequential(*[
                    TransformerBlock(
                        dim=decoder_dim,
                        num_heads=heads[n],
                        ffn_expansion_factor=ffn_expansion_factor,
                        bias=bias,
                        LayerNorm_type=LayerNorm_type,
                        flash_attention=flash_attention
                    ) for _ in range(num_blocks[n])
                ])
            )

        # Final refinement and output
        self.refinement = nn.Sequential(*[
            TransformerBlock(
                dim=decoder_dim,
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
                flash_attention=flash_attention
            ) for _ in range(num_refinement_blocks)
        ])
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim * 2**1), kernel_size=1, bias=bias)

        self.output = nn.Conv2d(int(dim * 2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of Restormer.

        Processes the input through the encoder-decoder architecture with skip connections.

        Args:
            x: Input image tensor of shape (B, inp_channels, H, W). H and W must be larger than
                and divisible by 2 ** num_steps, where num_steps = len(num_blocks) - 1.

        Returns:
            Restored image tensor of shape (B, out_channels, H, W)
        """
        factor = 2 ** self.num_steps
        h, w = x.shape[-2:]
        assert h > factor and w > factor, f"Input height and width must be larger than 2 ** num_steps = {factor}"
        assert h % factor == 0 and w % factor == 0, f"Input height and width must be divisible by 2 ** num_steps = {factor}"

        # Patch embedding
        x = self.patch_embed(x)
        skip_connections = []

        # Encoding path
        for encoder, downsample in zip(self.encoder_levels, self.downsamples):
            x = encoder(x)
            skip_connections.append(x)
            x = downsample(x)

        # Latent space
        x = self.latent(x)

        # Decoding path
        for idx in range(len(self.decoder_levels)):
            x = self.upsamples[idx](x)
            x = torch.cat([x, skip_connections[-(idx + 1)]], 1)
            if idx < len(self.decoder_levels) - 1:
                x = self.reduce_channels[idx](x)
            x = self.decoder_levels[idx](x)

        # Final refinement
        x = self.refinement(x)

        if self.dual_pixel_task:
            x = x + self.skip_conv(skip_connections[0])

        return self.output(x)


if __name__ == "__main__":
    flash_att = True
    test_model = Restormer(
        inp_channels=2,
        out_channels=2,
        dim=16,
        num_blocks=[1, 1, 1, 1],
        heads=[1, 1, 1, 1],
        num_refinement_blocks=2,
        ffn_expansion_factor=1.5,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True,
        flash_attention=flash_att
    )
    print(f'flash attention set to {flash_att}')
    input_tensor = torch.randn(8, 2, 256, 256)
    print(f"Input shape: {input_tensor.shape}")
    output = test_model(input_tensor)
    print(f"Output shape: {output.shape}")

    print('printing final model')
    from torchsummary import summary

    # torchsummary expects the per-sample input size (C, H, W), not a tensor
    summary(test_model, input_size=tuple(input_tensor.shape[1:]), device="cpu")
```

phisanti commented on Dec 16 '24

@aylward @Nic-Ma @ericspod and @KumoLiu, if you all agree and there are no comments about extra modules to add, I will implement the class as it is. To do so, I will:

  1. Fork and create branch '8261-restormer-implementation'
  2. Place architecture in MONAI/monai/networks/nets folder
  3. Add extensive documentation (docstrings + docs), using the UNet class style as a template
  4. Write unit tests following existing test patterns
  5. Create a tutorial notebook with an example dataset for the Project-MONAI/tutorials repo
  6. Submit PRs for both code and tutorial

What aspects of this approach would you modify to fully align with MONAI's contribution standards?

phisanti commented on Dec 18 '24

Hi @phisanti, thank you for sharing the comprehensive plan!

I’d recommend dividing the implementation into several PRs to simplify the review process. Additionally, I highly suggest checking if there are existing blocks in MONAI that can be reused in your network, such as upsample, downsample, attention mechanisms, etc.

  • https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/downsample.py
  • https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py
  • https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/selfattention.py
  • https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/spatialattention.py

Also, consider using Convolution, which could make your network support both 2D and 3D implementations seamlessly. https://github.com/Project-MONAI/MONAI/blob/e1e3d8ebc1c7247aad9f1bffc649c5a20084340f/monai/networks/blocks/convolutions.py#L25

KumoLiu commented on Dec 19 '24

Hi @KumoLiu and @aylward

I have just done an in-depth review of the Upsample/Downsample, SABlock and TransformerBlock classes in MONAI. From what I can see, switching to MONAI's Upsample/Downsample classes is trivial. I think something like:

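```python
from monai.networks.blocks import UpSample
from monai.utils import UpsampleMode

spatial_dims, in_channels = 2, 96  # illustrative values
upsample = UpSample(
    spatial_dims=spatial_dims,
    in_channels=in_channels,
    out_channels=in_channels // 2,
    mode=UpsampleMode.PIXELSHUFFLE,
    scale_factor=2,
    bias=False,
    apply_pad_pool=False,
)
```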

should reproduce the behaviour of the current Restormer Upsample. However, reusing MONAI's SABlock would not work: SABlock is a spatial attention mechanism based on the Dosovitskiy et al. (ViT) paper, whereas Restormer uses channel attention. See the code below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA). Differs from standard self-attention
    by attending across feature channels instead of spatial positions. Depth-wise convolutions
    mix local context before attention, and the C x C attention map gives complexity linear in
    the number of pixels, versus quadratic for vanilla spatial attention."""

    def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
        super().__init__()
        # Note: the `scale` argument used in _flash_attention requires PyTorch >= 2.1
        if flash_attention and not hasattr(F, 'scaled_dot_product_attention'):
            raise ValueError("Flash attention not available")

        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.flash_attention = flash_attention
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self._attention_fn = self._get_attention_fn()

    def _get_attention_fn(self):
        if self.flash_attention:
            return self._flash_attention
        return self._normal_attention

    def _flash_attention(self, q, k, v):
        """Attention via F.scaled_dot_product_attention (uses a flash kernel when available).
        The learnable per-head temperature is folded into q and scale=1.0 is passed, so the
        result matches _normal_attention and the temperature keeps receiving gradients."""
        return F.scaled_dot_product_attention(
            q * self.temperature,
            k,
            v,
            dropout_p=0.0,
            is_causal=False,
            scale=1.0,
        )

    def _normal_attention(self, q, k, v):
        """Explicit attention: softmax(temperature * Q K^T) V over the channel dimension."""
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        return attn @ v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for MDTA attention.
        1. Compute Q, K, V with a 1x1 convolution followed by a depth-wise 3x3 convolution
        2. Reshape Q, K, V for multi-head attention
        3. Compute attention using flash or normal attention
        4. Reshape and project out the attention output"""
        _, _, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        out = self._attention_fn(q, k, v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out
```

My suggestion is to implement this as a separate class (CABlock, Channel Attention Block). MONAI's native TransformerBlock could then support both attention mechanisms, with an argument attention_type="spatial"/"channel" for the user to choose between them. If so, the new class could be placed in the blocks module.

Let me know what do you think.

phisanti commented on Jan 10 '25

A CABlock class seems appropriate, but it would be good if we could avoid having an argument to toggle between two different variables / use-cases.

Would it be possible for CABlock and SABlock to have a member var that defines its type: spatial or channel? Then the Transformer class would query that member var of the block to determine if it is spatial- or channel-based block and adjust accordingly (assuming transformer's logic would change if self.attn points to CABlock)? Not certain if API differences could be resolved - what do you think?

If so, then perhaps the Transformer class would default (as-is) to self.attn being an SABlock, but after init, a user/caller could overwrite self.attn with a CABlock?
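A minimal sketch of that idea, under stated assumptions. Each attention block exposes a class attribute (here called `attention_type`, a hypothetical name). A generic transformer block reads it to decide whether to flatten the feature map into tokens (spatial attention) or pass it through as an image (channel attention). The two toy attention modules are simplified stand-ins, not MONAI's `SABlock` or the proposed `CABlock`. Normalization and the feed-forward sublayer are omitted for brevity.

```python
import torch
import torch.nn as nn


class ToySpatialAttention(nn.Module):
    attention_type = "spatial"  # expects tokens of shape (B, N, C)

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.mha(tokens, tokens, tokens, need_weights=False)[0]


class ToyChannelAttention(nn.Module):
    attention_type = "channel"  # expects feature maps of shape (B, C, *spatial)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        flat = x.flatten(2)  # (B, C, N)
        attn = torch.softmax(flat @ flat.transpose(1, 2) / flat.shape[-1] ** 0.5, dim=-1)  # (B, C, C)
        return (attn @ flat).reshape(x.shape)


class ToyTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int = 1):
        super().__init__()
        self.attn: nn.Module = ToySpatialAttention(dim, num_heads)  # default stays spatial

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, *spatial); the attention block's declared type selects the layout
        if getattr(self.attn, "attention_type", "spatial") == "spatial":
            tokens = x.flatten(2).transpose(1, 2)  # (B, N, C)
            out = self.attn(tokens).transpose(1, 2).reshape(x.shape)
        else:
            out = self.attn(x)
        return x + out  # residual connection


block = ToyTransformerBlock(dim=8, num_heads=2)
x = torch.randn(2, 8, 4, 4)
print(block(x).shape)  # torch.Size([2, 8, 4, 4])
block.attn = ToyChannelAttention()  # overwrite self.attn after init
print(block(x).shape)  # torch.Size([2, 8, 4, 4])
```

This keeps the default behaviour unchanged and avoids a toggle argument. The trade-off is that the transformer block has to know about both input layouts.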

s


aylward commented on Jan 11 '25

Hi @phisanti, I agree with @KumoLiu that we could reduce some code duplication here. I would also make sure class names are not duplicated, for the sake of readability. If you want to propose one or more PRs, we can discuss further there. Thanks!

ericspod commented on Jan 13 '25

Dear @aylward and @KumoLiu, thanks for your patience. Here is an overview of the progress I have made so far:

The goal was to implement the Restormer class on top of MONAI's Convolution block so that the model can operate on both 2D and 3D images. During this process, I realized that the project lacked a Downsample class and a pixel_unshuffle operation, so I created the necessary classes/functions.

  • I placed the Downsample class in monai/networks/blocks/downsample.py.
  • The pixel_unshuffle operation was placed in monai/networks/utils.py, just below the pixel_shuffle operation.

When writing pixel_unshuffle, I tried to follow your existing coding pattern as closely as possible so that everything looks integrated. I also compared its performance against the native torch.pixel_unshuffle, and the results look good.
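For readers unfamiliar with the operation: pixel unshuffle is the inverse of pixel shuffle. It moves each r×r (or r×r×r) spatial neighbourhood into the channel axis, mapping `(B, C, H, W, ...)` to `(B, C·r^d, H/r, W/r, ...)`. The sketch below is a generic N-D version written with plain reshape/permute for illustration. It is not the code from the fork, and the name `pixel_unshuffle_nd` is hypothetical. For 2D inputs it uses the same channel ordering as `torch.nn.functional.pixel_unshuffle`.

```python
import torch
import torch.nn.functional as F


def pixel_unshuffle_nd(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
    """Hypothetical N-D pixel unshuffle: (B, C, *S) -> (B, C * r**d, *(S // r))."""
    d, r = spatial_dims, scale_factor
    b, c, *spatial = x.shape
    if len(spatial) != d:
        raise ValueError(f"expected {d} spatial dims, got {len(spatial)}")
    if any(s % r for s in spatial):
        raise ValueError(f"spatial size {spatial} must be divisible by {r}")
    # Split every spatial axis s into (s // r, r): (B, C, s1//r, r, s2//r, r, ...)
    split_shape = [b, c]
    for s in spatial:
        split_shape += [s // r, r]
    x = x.reshape(split_shape)
    # Move the r-sized axes right after the channel axis: (B, C, r, ..., r, s1//r, ...)
    perm = [0, 1] + [3 + 2 * i for i in range(d)] + [2 + 2 * i for i in range(d)]
    x = x.permute(perm)
    # Merge C with the r-sized axes into the new channel axis
    return x.reshape(b, c * r**d, *[s // r for s in spatial])


x2d = torch.randn(2, 3, 8, 12)
print(torch.equal(pixel_unshuffle_nd(x2d, 2, 2), F.pixel_unshuffle(x2d, 2)))  # True
x3d = torch.randn(1, 2, 4, 4, 4)
print(pixel_unshuffle_nd(x3d, spatial_dims=3, scale_factor=2).shape)  # torch.Size([1, 16, 2, 2, 2])
```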

Next, I proceeded to implement the Channel Attention Block (CABlock) along with the FeedForward layer. Both classes were placed in monai/networks/blocks/cablock.py.
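Because the goal is to support both 2D and 3D inputs, the gated feed-forward layer only needs its convolutions to follow `spatial_dims`. The minimal sketch below is written in plain PyTorch. The class name `GatedFeedForwardND` is hypothetical, and the fork may well use MONAI's `Convolution` block instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedFeedForwardND(nn.Module):
    """GDFN-style block: 1x1 expand -> depth-wise 3x3 -> GELU gate -> 1x1 project."""

    def __init__(self, spatial_dims: int, dim: int, expansion: float = 2.66, bias: bool = False):
        super().__init__()
        conv = {2: nn.Conv2d, 3: nn.Conv3d}[spatial_dims]  # pick the convolution type
        hidden = int(dim * expansion)
        self.project_in = conv(dim, hidden * 2, kernel_size=1, bias=bias)
        self.dwconv = conv(hidden * 2, hidden * 2, kernel_size=3, padding=1, groups=hidden * 2, bias=bias)
        self.project_out = conv(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(x1) * x2)  # GELU(x1) gates x2


ffn3d = GatedFeedForwardND(spatial_dims=3, dim=16)
print(ffn3d(torch.randn(1, 16, 4, 8, 8)).shape)  # torch.Size([1, 16, 4, 8, 8])
```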

Regarding the Multi-DConv Head Transposed Self-Attention (MDTA), I was not sure where you wanted it placed. Given its particularities, I did not see a clean way to integrate it with the existing transformer block so that the user could switch between the classic spatial transformer and the MDTA. Therefore, I put the MDTA in monai/networks/nets/restormer.py, together with the OverlapPatchEmbed class and the Restormer itself.
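One observation that may help with the 2D/3D question: MDTA's attention map only involves the channel axis, so all spatial axes can simply be flattened into one, and the same code then works for 2D and 3D. The sketch below works under that assumption. The class name `ChannelAttentionND` is hypothetical, and this is not the fork's implementation. The `reshape` calls reproduce the `'b (head c) h w -> b head c (h w)'` rearrangement used in the 2D code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelAttentionND(nn.Module):
    """MDTA-style transposed attention for (B, C, *spatial) inputs with 2 or 3 spatial dims."""

    def __init__(self, spatial_dims: int, dim: int, num_heads: int, bias: bool = False):
        super().__init__()
        if dim % num_heads:
            raise ValueError("dim must be divisible by num_heads")
        conv = {2: nn.Conv2d, 3: nn.Conv3d}[spatial_dims]
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = conv(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = conv(dim * 3, dim * 3, kernel_size=3, padding=1, groups=dim * 3, bias=bias)
        self.project_out = conv(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c = x.shape[:2]
        q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)
        # (B, C, *spatial) -> (B, heads, C // heads, N), with N = product of spatial sizes
        q, k, v = (t.reshape(b, self.num_heads, c // self.num_heads, -1) for t in (q, k, v))
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        attn = ((q @ k.transpose(-2, -1)) * self.temperature).softmax(dim=-1)  # (B, heads, C/h, C/h)
        out = (attn @ v).reshape(x.shape)  # back to (B, C, *spatial)
        return self.project_out(out)


attn3d = ChannelAttentionND(spatial_dims=3, dim=16, num_heads=2)
print(attn3d(torch.randn(1, 16, 4, 8, 8)).shape)  # torch.Size([1, 16, 4, 8, 8])
```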

Finally, for every new class and function, I wrote extensive unit tests trying to cover as many edge cases as I could come up with. However, any double-check is always welcome to raise the standards.

You can check all the progress in my fork of the MONAI repo.

Please let me know what you think. After a quick review from you, I can proceed with the following steps:

  • [ ] Implement any feedback that you give me
  • [ ] Additional documentation
  • [ ] Example notebooks

Once this is done, the implementation will be ready for review and integration into the main MONAI repository.

phisanti commented on Jan 20 '25

Hi @phisanti, all that looks good. I think it's time to put this work into a PR, where it will be easier to review. We'll consider the core code itself, which should be documented as well (docstrings, and the .rst documentation files where appropriate). We'd then look at example notebooks for the tutorials repo. Please open a PR from your fork when you can and link it to this issue, and we'll review shortly. Thanks!

ericspod commented on Jan 21 '25

Hi @ericspod, thanks for the feedback. I have just opened the pull request.

phisanti commented on Jan 23 '25

@phisanti Thank you very much for the proposed content. I would like to know if there is a 3D implementation version available?

cyl0000 commented on Feb 11 '25

@cyl0000 I included support for 3D image restoration in my implementation. It is currently in a pull request awaiting review, but if you need it now, you can use the fork here. I wrote extensive unit tests to validate that it works in 3D. However, I didn't have 3D images to test it with real data. If you have some data, I would be happy to run some tests.

phisanti commented on Feb 11 '25

Dear @ericspod, I am really glad the PR finally went through. If you are OK with it, I can create the tutorial for the tutorials repo. Judging by the existing folders, I should probably create a 2D_regression folder. There, I can include some example datasets from my own work or use some online datasets (there are many for natural camera images, but I am not sure about medical image denoising). My idea is that the tutorial should cover:

  • Denoising
  • Showing noisy/clean images
  • Explain the Restormer architecture
  • Train the model for a few epochs
  • Show output
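For the "noisy/clean images" and "show output" steps, a tutorial would typically pair each clean image with a synthetically corrupted copy and report a metric such as PSNR. The small sketch below assumes additive Gaussian noise and intensities scaled to [0, 1]. The helper names are illustrative. MONAI's `monai.metrics.PSNRMetric` could be used instead of the hand-written `psnr`.

```python
import math

import torch


def add_gaussian_noise(clean: torch.Tensor, sigma: float, seed: int = 0) -> torch.Tensor:
    # Additive Gaussian noise, clipped back to the [0, 1] intensity range.
    gen = torch.Generator().manual_seed(seed)
    noise = torch.randn(clean.shape, generator=gen) * sigma
    return (clean + noise).clamp(0.0, 1.0)


def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> float:
    # PSNR = 10 * log10(max_val^2 / MSE); infinite for identical images.
    mse = torch.mean((pred - target) ** 2).item()
    return math.inf if mse == 0 else 10.0 * math.log10(max_val**2 / mse)


clean = torch.full((1, 1, 64, 64), 0.5)
noisy = add_gaussian_noise(clean, sigma=0.1)
print(f"PSNR(noisy, clean) = {psnr(noisy, clean):.1f} dB")
print(psnr(torch.full((1, 1, 4, 4), 0.1), torch.zeros(1, 1, 4, 4)))  # about 20 dB (MSE = 0.01)
```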

Please, let me know what you think and I will go ahead.

phisanti commented on Apr 01 '25


Hi @phisanti, this sounds good to me. I would choose small datasets that anyone can download online, just for ease of use. You don't have to train the model to do something useful; just demonstrate the process. You can propose the tutorial on the Tutorials repo and we can review it from there. Thanks!

ericspod commented on Apr 01 '25

Hi @ericspod I have just opened a PR with the tutorial #1987. If all is good, this should close this issue.

phisanti commented on May 04 '25