nnUNet

Using custom architectures

Open • noneedanick opened this issue 1 year ago • 16 comments

I am currently trying to train different architectures within the nnUNet framework to compare them with the nnUNet baseline architectures. For a fair comparison, I think it makes sense to use the same preprocessing steps. MONAI makes it easy to build many recent architectures. So, following the so-called "quick and dirty" method, I created a custom trainer class. However, I am not sure whether this alone is sufficient to train with the nnUNetv2_train command (with -tr UNETRTrainer). I would be grateful for any ideas and help. Here is my class:


from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from monai.networks.nets import UNETR
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Union, Tuple, List


class UNETRTrainer(nnUNetTrainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.plans_manager = PlansManager(plans)
        self.configuration_manager = self.plans_manager.get_configuration(configuration)
        self.enable_deep_supervision = False
        self.dataset_json = dataset_json
        ### Some hyperparameters for you to fiddle with
        self.initial_lr = 1e-2
        self.weight_decay = 3e-5
        self.oversample_foreground_percent = 0.33
        self.num_iterations_per_epoch = 250
        self.num_val_iterations_per_epoch = 50
        self.num_epochs = 500
        self.current_epoch = 0

        ### Dealing with labels/regions
        self.label_manager = self.plans_manager.get_label_manager(dataset_json)

    def initialize(self):
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
                                                                   self.dataset_json)

            self.network = self.build_network_architecture(
                self.configuration_manager.network_arch_class_name,
                self.configuration_manager.network_arch_init_kwargs,
                self.configuration_manager.network_arch_init_kwargs_req_import,
                self.num_input_channels,
                self.label_manager.num_segmentation_heads,
                self.enable_deep_supervision
            ).to(self.device)
            # compile network for free speedup
            if self._do_i_compile():
                self.print_to_log_file('Using torch.compile...')
                self.network = torch.compile(self.network)

            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])

            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                               "That should not happen.")

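    # NOTE: because of @staticmethod, no instance is passed in. Whatever the caller
    # passes first ends up in `self` (here the architecture class name string), so
    # `self.configuration_manager` below raises an AttributeError. This is discussed
    # further down in the thread.
    @staticmethod
    def build_network_architecture(self,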
                                   architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        patch_size = self.configuration_manager.patch_size
        model = UNETR(
                        in_channels=num_input_channels,
                        out_channels=num_output_channels,
                        img_size=patch_size,
                        feature_size=16,
                        hidden_size=768,
                        mlp_dim=3072,
                        num_heads=12,
                        pos_embed="perceptron",
                        norm_name="instance",
                        res_block=True,
                        dropout_rate=0.0,
                    )
        return model

noneedanick • Apr 08 '24 03:04

Hey, so all you need to do to test different architectures is override build_network_architecture, plus any additional changes you'd like to make. Basically, what you did is exactly correct. "Quick and dirty" means that your architecture will not be considered during experiment planning. It is expected to deal with whatever patch and batch size nnU-Net gives it. Ideally it also uses the downsampling steps configured by nnU-Net, but that is optional. Best, Fabian
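Here is a minimal pure-Python sketch of what overriding build_network_architecture means. The subclass keeps the same name, the same six parameters and the @staticmethod decorator, and changes only the body. `BaseTrainer` and `CustomTrainer` are hypothetical stand-ins, not nnU-Net classes, and a dict replaces the nn.Module so the sketch runs without torch. The class-level call reflects an assumption drawn from the remark about inference later in this thread: the hook may be invoked without any trainer instance.

```python
from typing import List, Tuple, Union


class BaseTrainer:
    """Hypothetical stand-in for nnUNetTrainer; only the static hook is modeled."""

    @staticmethod
    def build_network_architecture(architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True):
        return ("default", architecture_class_name)


class CustomTrainer(BaseTrainer):
    # Same name, same six parameters, still a staticmethod: only the body changes.
    @staticmethod
    def build_network_architecture(architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True):
        # A real override would build e.g. a MONAI network here; a plain dict
        # stands in for the nn.Module so the sketch runs without torch.
        return {"in": num_input_channels, "out": num_output_channels,
                "deep_supervision": enable_deep_supervision}


args = ("PlainConvUNet", {}, [], 1, 3, False)
# Called on an instance (as during training) ...
from_instance = CustomTrainer().build_network_architecture(*args)
# ... and on the class itself, with no trainer instance at all.
from_class = CustomTrainer.build_network_architecture(*args)
print(from_instance == from_class, from_class)
```

Because nothing in the body touches an instance, both call styles give the same result.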

FabianIsensee • Apr 09 '24 14:04

It gave an error in set_deep_supervision_enabled (below), so I commented out parts of it since I won't use deep supervision. Now it's working just fine! Thanks a lot!

NOTE: I am a bit inexperienced with staticmethod, and I believe using 'self' in a staticmethod is unusual. I only added it so I could use the configured patch_size as the input image size, because that value is stored on the instance (self).

    def set_deep_supervision_enabled(self, enabled: bool):
        """
        This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
        chances you need to change this as well!
        """
        if self.is_ddp:
            mod = self.network.module
        else:
            mod = self.network
        #if isinstance(mod, OptimizedModule):
            #mod = mod._orig_mod

        #mod.decoder.deep_supervision = enabled
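The staticmethod note above can be illustrated with a minimal pure-Python sketch. The class `Demo` and its `build` method are hypothetical. A `@staticmethod` receives no instance, so a parameter named `self` is simply the first positional argument the caller passes.

```python
class Demo:
    @staticmethod
    def build(self, architecture_class_name, num_input_channels):
        # With @staticmethod nothing is bound automatically, so `self` is just
        # whatever the caller passes first.
        return self, architecture_class_name, num_input_channels


d = Demo()
received_self, name, channels = d.build("UNETR", 1, 3)
print(repr(received_self), name, channels)  # 'UNETR' 1 3

try:
    # Fails the same way `self.configuration_manager` would: `self` is a str here.
    received_self.configuration_manager
except AttributeError as err:
    print("AttributeError:", err)
```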

noneedanick • Apr 10 '24 14:04

Also, for researchers looking for implementations of the SwinUNETR and UNETR segmentation models within the nnUNet framework, here are my latest classes.

The classes are based on the following notebooks, which define these architectures: https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb#scrollTo=xQn18qtvZChG

https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb

UNETR:

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from monai.networks.nets import UNETR
from torch.optim import Adam, AdamW
import torch
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels

class nnUNetTrainerUNETR(nnUNetTrainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.enable_deep_supervision = False
        self.initial_lr = 1e-4
        self.weight_decay = 1e-5
        self.num_epochs = 200
        
    def set_deep_supervision_enabled(self, enabled: bool):
        """
        This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
        chances you need to change this as well!
        """
        if self.is_ddp:
            mod = self.network.module
        else:
            mod = self.network
        #if isinstance(mod, OptimizedModule):
            #mod = mod._orig_mod

        #mod.decoder.deep_supervision = enabled
        

    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):
        
        patch_size = (64,96,96)

        model = UNETR(
                        in_channels=num_input_channels,
                        out_channels=num_output_channels,
                        img_size=patch_size,
                        feature_size=16,
                        hidden_size=768,
                        mlp_dim=3072,
                        num_heads=12,
                        pos_embed="perceptron",
                        norm_name="instance",
                        res_block=True,
                        dropout_rate=0.0,
                    )
        return model
        
    def configure_optimizers(self):
        optimizer = AdamW(self.network.parameters(),
                          lr=self.initial_lr,
                          weight_decay=self.weight_decay,
                          amsgrad=True)

        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
        return optimizer, lr_scheduler     

SwinUNETR:

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from monai.networks.nets import SwinUNETR
from torch.optim import Adam, AdamW
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss
from nnunetv2.utilities.helpers import softmax_helper_dim1

class nnUNetTrainerSwinUNETR(nnUNetTrainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.enable_deep_supervision = False
        self.initial_lr = 1e-4
        self.weight_decay = 1e-5
        self.num_epochs = 200
        
    def set_deep_supervision_enabled(self, enabled: bool):
        """
        This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
        chances you need to change this as well!
        """
        if self.is_ddp:
            mod = self.network.module
        else:
            mod = self.network
        #if isinstance(mod, OptimizedModule):
            #mod = mod._orig_mod

        #mod.decoder.deep_supervision = enabled
        
    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):
        
        patch_size = (64,96,96)
        model = SwinUNETR(
                        img_size=patch_size,
                        in_channels=num_input_channels,
                        out_channels=num_output_channels,
                        feature_size=48,
                        drop_rate=0.0,
                        attn_drop_rate=0.0,
                        dropout_path_rate=0.0,
                        use_checkpoint=True,
                    )
        return model
        
    def _build_loss(self):
        loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice,
                                    'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp},
                            apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1)

        if self.enable_deep_supervision:
            deep_supervision_scales = self._get_deep_supervision_scales()

            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
            # this gives higher resolution outputs more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
            weights[-1] = 0

            # the lowest-resolution output gets weight 0 (set above); normalize weights so that they sum to 1
            weights = weights / weights.sum()
            # now wrap the loss
            loss = DeepSupervisionWrapper(loss, weights)
        return loss
        
    def configure_optimizers(self):
        optimizer = AdamW(self.network.parameters(),
                          lr=self.initial_lr,
                          weight_decay=self.weight_decay,
                          amsgrad=True)

        lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
        return optimizer, lr_scheduler 
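The two trainers above use different learning-rate schedules: nnU-Net's PolyLRScheduler for UNETR and PyTorch's CosineAnnealingLR for SwinUNETR. This small sketch compares the two decay curves. It assumes a poly exponent of 0.9 (check PolyLRScheduler's default in your nnU-Net version) and uses the closed-form cosine annealing formula with eta_min = 0. The function names are hypothetical.

```python
import math


def poly_lr(initial_lr: float, epoch: int, max_epochs: int, exponent: float = 0.9) -> float:
    # Polynomial decay: initial_lr * (1 - epoch / max_epochs) ** exponent
    return initial_lr * (1 - epoch / max_epochs) ** exponent


def cosine_lr(initial_lr: float, epoch: int, max_epochs: int, eta_min: float = 0.0) -> float:
    # Closed form of cosine annealing without restarts
    return eta_min + (initial_lr - eta_min) * (1 + math.cos(math.pi * epoch / max_epochs)) / 2


for epoch in (0, 50, 100, 150, 199):
    print(epoch, f"{poly_lr(1e-4, epoch, 200):.2e}", f"{cosine_lr(1e-4, epoch, 200):.2e}")
```

Both schedules start at initial_lr and reach 0 at max_epochs. The poly curve stays slightly higher through the middle of training.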

noneedanick • Apr 10 '24 14:04

Hey, thanks for sharing!

Three things to be aware of:

  • do not hard code the patch size in build_network_architecture! Use the given ConfigurationManager to read the patch size nnU-Net uses
  • do not change the signature of build_network_architecture! Otherwise it cannot be used for inference. I am aware that this is not ideal and will try to change it in the future
  • set_deep_supervision_enabled is specific to our architectures. If your architecture never uses deep supervision please add an assertion that ensures users don't accidentally set this to True.
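The third point (adding an assertion) could look like the minimal sketch below. `SingleOutputTrainer` is a hypothetical stand-in; in practice the method would live on your nnUNetTrainer subclass. As a design choice, the guard checks the trainer's configured `enable_deep_supervision` flag, which is the setting a user would accidentally flip, rather than the `enabled` argument. This rests on an assumption: the base trainer may call this hook internally with either value.

```python
class SingleOutputTrainer:
    """Hypothetical stand-in; in practice this would subclass nnUNetTrainer."""

    def __init__(self, enable_deep_supervision: bool = False):
        self.enable_deep_supervision = enable_deep_supervision

    def set_deep_supervision_enabled(self, enabled: bool):
        # The network returns a single output, so there is nothing to toggle.
        # Assumption: guard the configured flag rather than `enabled`, in case
        # the base trainer passes True/False internally at other points.
        assert not self.enable_deep_supervision, (
            "This architecture has no deep supervision outputs; "
            "keep self.enable_deep_supervision = False"
        )


SingleOutputTrainer().set_deep_supervision_enabled(False)  # fine: flag is False

try:
    SingleOutputTrainer(enable_deep_supervision=True).set_deep_supervision_enabled(True)
except AssertionError as err:
    print("caught:", err)
```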

We have our own SwinUNETR implementations internally that we were planning to release soon. It will be interesting to compare them to yours. How has your experience with these architectures been so far? We haven't found them useful so far; they just can't keep up with a standard UNet or a UNet with a residual encoder.

Best, Fabian

FabianIsensee • Apr 11 '24 12:04

Thanks Fabian for your comments!

I couldn't find a proper way to automatically infer the patch size from the plans without changing the signature of build_network_architecture :(

So far my results have been close to each other, with very similar performance, but it looks like 3D UNet is still going to be the winner :laughing:. Right now I am dealing with hardware issues to speed up the process: training one epoch takes over 300 s (the same for each model). On the validation side, both my custom classes and the unchanged 3d_fullres training take over 20 minutes per validation case. I am trying to bring these times down, because it feels like an eternity :laughing:.

I will share my trials and configurations here soon!

Best, Murat

noneedanick • Apr 11 '24 13:04

Hah, you are right. The signature is:

def build_network_architecture(architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True) -> nn.Module:

I only read your code and forgot how it is supposed to look. The easiest way to achieve what you need would be to change the arch_init_kwargs in the plans file so that they contain the patch size as well. This would need to be done as part of the experiment planner.
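Here is a sketch of that suggestion, done by editing the plans JSON after planning instead of inside the experiment planner. It assumes the nnU-Net v2 plans layout: each entry under "configurations" has a "patch_size" and an "architecture" dict holding "arch_kwargs" (the name Fabian mentions later in this thread). The function names are hypothetical, not nnU-Net API. Trainers using the default architecture also receive arch_kwargs and may reject an unexpected patch_size keyword. You may therefore prefer to apply this to a copy of the plans file. How that copy is selected at training time (for example through nnUNetv2_train's -p plans identifier) is worth checking for your nnU-Net version.

```python
import json
from pathlib import Path


def add_patch_size_to_arch_kwargs(plans: dict, configuration: str) -> dict:
    """Copy a configuration's patch_size into its architecture arch_kwargs (in place)."""
    config = plans["configurations"][configuration]
    if "patch_size" not in config or "architecture" not in config:
        # e.g. configurations that only inherit these values via "inherits_from"
        raise KeyError(f"{configuration!r} does not define patch_size/architecture directly")
    config["architecture"]["arch_kwargs"]["patch_size"] = list(config["patch_size"])
    return plans


def patch_plans_file(plans_file: Path, configuration: str) -> None:
    # e.g. patch_plans_file(Path(".../DatasetXXX_Name/nnUNetPlans.json"), "3d_fullres")
    plans = json.loads(plans_file.read_text())
    add_patch_size_to_arch_kwargs(plans, configuration)
    plans_file.write_text(json.dumps(plans, indent=4))


# Tiny in-memory example shaped like a fragment of a plans file.
plans = {"configurations": {"3d_fullres": {
    "patch_size": [64, 96, 96],
    "architecture": {"network_class_name": "...", "arch_kwargs": {"n_stages": 5},
                     "_kw_requires_import": []},
}}}
add_patch_size_to_arch_kwargs(plans, "3d_fullres")
print(plans["configurations"]["3d_fullres"]["architecture"]["arch_kwargs"])
```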

Those epoch times sound horrible. Take a look at https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/benchmarking.md, which explains how to approach debugging. Best, Fabian

FabianIsensee • Apr 12 '24 15:04

Hey, how are you progressing? Is everything working now?

FabianIsensee • May 28 '24 09:05

Hey Fabian! I am preparing to start a postdoc position in the US, so I haven't been able to share my final results here yet. It's still on my mind, and I'm waiting for the right time to share them (I'm currently packing up to move to the US). In the meantime, I solved the epoch-time issue by increasing the target spacing used for resampling, because the automatically determined spacing was producing huge arrays. With my custom setup, nnUNet achieved higher Dice scores (as expected 😆). SwinUNETR showed nearly similar performance, but I need to test it again with better-tuned parameters (such as patch size, resampling, etc.).
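This rough sketch, with hypothetical numbers, shows why coarser resampling helps so much. Per axis, the resampled size is old_size * old_spacing / new_spacing, so doubling the spacing on all three axes cuts the voxel count by a factor of 8. The number of sliding-window tiles needed at inference drops as well. The tile count assumes a tile step of half the patch size, matching the tile_step_size=0.5 default of nnU-Net's predictor. This is not nnU-Net's own code.

```python
import math
from typing import Sequence, Tuple


def resampled_shape(shape: Sequence[int], spacing: Sequence[float],
                    target_spacing: Sequence[float]) -> Tuple[int, ...]:
    # Physical extent (shape * spacing) is preserved; voxels per axis scale
    # with old_spacing / new_spacing.
    return tuple(int(round(n * s / t)) for n, s, t in zip(shape, spacing, target_spacing))


def num_sliding_windows(image_shape: Sequence[int], patch_size: Sequence[int],
                        step: float = 0.5) -> int:
    # Tiles per axis when windows advance by `step * patch` voxels.
    total = 1
    for i, p in zip(image_shape, patch_size):
        n = 1 if i <= p else math.ceil((i - p) / (p * step)) + 1
        total *= n
    return total


shape, spacing = (300, 512, 512), (1.0, 0.7, 0.7)  # hypothetical CT volume
patch = (64, 96, 96)
coarse = resampled_shape(shape, spacing, (2.0, 1.4, 1.4))
print(coarse)  # (150, 256, 256)
print(num_sliding_windows(shape, patch), num_sliding_windows(coarse, patch))  # 900 100
```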

Thanks for your interest, by the way. I am honored ❤️

noneedanick • May 28 '24 19:05

How do I read the patch size from the ConfigurationManager? A static method cannot access self, so writing it like this results in an error:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.distributions.uniform import Uniform
import numpy as np
from timm.models.layers import DropPath, trunc_normal_

# from torch.nn.init import xavier_uniform_, constant_, normal_
# import copy
# import math

from monai.networks.nets import SwinUNETR
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

class nnUNetTrainer_SwinUNETR(nnUNetTrainer):
    def __init__(
            self,
            plans: dict,
            configuration: str,
            fold: int,
            dataset_json: dict,
            unpack_dataset: bool = True,
            device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.enable_deep_supervision = False
        self.initial_lr = 1e-2
        # self.weight_decay = 3e-5
        # self.num_epochs = 1000
        self.num_epochs = 500
        # self.num_epochs = 50
        # self.patch_size =

    def set_deep_supervision_enabled(self, enabled: bool):
        """
        This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
        chances you need to change this as well!
        """
        if self.is_ddp:
            mod = self.network.module
        else:
            mod = self.network
        # if isinstance(mod, OptimizedModule):
        #     mod = mod._orig_mod
        # mod.deep_supervised = enabled
    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):

        patch_size = trainer.configuration_manager.patch_size
        from dynamic_network_architectures.initialization.weight_init import InitWeights_He
        model = SwinUNETR(img_size=nnUNetTrainer_SwinUNETR.patch_size, in_channels=num_input_channels, out_channels=num_output_channels)
        model.apply(InitWeights_He(1e-2))
        return model

Yanfeng-Zhou • Jun 01 '24 16:06

build_network_architecture is a static method, so it cannot access self.configuration_manager.patch_size!

Thank you for your reply!

Yanfeng-Zhou • Jun 01 '24 16:06

Use the arch_init_kwargs dictionary and add the patch size to that

FabianIsensee • Jun 04 '24 07:06

> Use the arch_init_kwargs dictionary and add the patch size to that

Sorry, I can't find arch_init_kwargs. Could you give me a minimal implementation? Thank you!

Yanfeng-Zhou • Jun 04 '24 08:06

You'll find it in the plans file, as part of the configuration you are trying to run. There it is called arch_kwargs. Just add the patch size there, and it will then be available in build_network_architecture as arch_init_kwargs['patch_size'].
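Once the patch size is in the plans, it can be read inside the static build_network_architecture without touching self. Below is a sketch with a hypothetical helper, `patch_size_from_arch_kwargs`, that also checks divisibility. Fabian notes the network must cope with whatever patch size nnU-Net picks, and transformer backbones constrain that. The default divisor=16 assumes UNETR's default 16-voxel ViT patches. For SwinUNETR, 32 is the usual requirement (an assumption; check what your MONAI version enforces).

```python
from typing import Any, Mapping, Tuple


def patch_size_from_arch_kwargs(arch_init_kwargs: Mapping[str, Any],
                                divisor: int = 16) -> Tuple[int, ...]:
    """Read patch_size from arch_init_kwargs and check each axis is divisible by `divisor`."""
    if "patch_size" not in arch_init_kwargs:
        raise KeyError("'patch_size' is not in arch_kwargs; add it to the plans file first")
    # JSON stores the patch size as a list; convert to a tuple of ints.
    patch_size = tuple(int(p) for p in arch_init_kwargs["patch_size"])
    bad_axes = [axis for axis, p in enumerate(patch_size) if p % divisor != 0]
    if bad_axes:
        raise ValueError(f"patch size {patch_size} is not divisible by {divisor} "
                         f"along axes {bad_axes}")
    return patch_size


# How it could be used in the trainer (needs MONAI, so shown as a comment):
#
#     @staticmethod
#     def build_network_architecture(architecture_class_name, arch_init_kwargs,
#                                    arch_init_kwargs_req_import, num_input_channels,
#                                    num_output_channels, enable_deep_supervision=True):
#         patch_size = patch_size_from_arch_kwargs(arch_init_kwargs, divisor=16)
#         return UNETR(in_channels=num_input_channels,
#                      out_channels=num_output_channels, img_size=patch_size)

print(patch_size_from_arch_kwargs({"patch_size": [64, 96, 96]}))  # (64, 96, 96)
```

The signature stays exactly as nnU-Net expects. Only the values inside arch_init_kwargs change.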

FabianIsensee • Jun 06 '24 06:06

> You find it in the plans file as part of the configuration you are trying to run. It's called arch_kwargs in there. Just add the patch size there and then it will be available in build_network_architecture as arch_init_kwargs['patch_size']

I succeeded! Here is the complete process. Open nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py and add the following code at line 33:

...
class ConfigurationManager(object):
    def __init__(self, configuration_dict: dict):
        self.configuration = configuration_dict
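        # added line; it runs before the backwards-compatibility check below,
        # so it assumes the plans contain an 'architecture' entry
        self.configuration["architecture"]["arch_kwargs"]["patch_size"] = self.configuration["patch_size"]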
        # backwards compatibility
        if 'architecture' not in self.configuration.keys():
...

Then you can use patch_size in the custom network, and other parameters can be passed the same way.

    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):

        from dynamic_network_architectures.initialization.weight_init import InitWeights_He
        from monai.networks.nets import SwinUNETR
        model = SwinUNETR(img_size=arch_init_kwargs['patch_size'], in_channels=num_input_channels, out_channels=num_output_channels)
        model.apply(InitWeights_He(1e-2))
        return model

Thank you again for your help! @FabianIsensee

Yanfeng-Zhou • Jun 07 '24 02:06

Hey @Yanfeng-Zhou, glad to hear it works now. It would be cleaner to add the patch size to the arch_kwargs in the plans file, because that would not affect the rest of nnU-Net. With your change, the patch size is always added, even when you don't need it.

FabianIsensee • Jun 18 '24 07:06