
PRIMUS Trainer runs training+validation, but fails to run Inference

ChepchikValeri opened this issue 6 months ago · 3 comments

Hi nnUNet team!

First of all, thank you for such an amazing tool!

The code runs training and validation with no issues. I ran training like this:

CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 322 3d_fullres 1 -tr nnUNet_Primus_M_Trainer -p nnUNetResEncUNetLPlans

By the way, I checked both the nnUNet git main branch and https://github.com/TaWald/nnUNet/tree/primus# ; both behave the same way.

I ran into an issue with PRIMUS inference. When running inference, I get the following error:


```
CUDA_VISIBLE_DEVICES=4 nnUNetv2_predict -f 1 -i /home/volodymyr.buchakchiiskyi/input/ -o /home/volodymyr.buchakchiiskyi/outputPrimus/ -d 322 -c 3d_fullres --verbose -tr nnUNet_Primus_M_Trainer -p nnUNetResEncUNetLPlans

#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################

Traceback (most recent call last):
  File "/home/volodymyr.buchakchiiskyi/miniforge3/envs/newprimus/bin/nnUNetv2_predict", line 8, in <module>
    sys.exit(predict_entry_point())
             ~~~~~~~~~~~~~~~~~~~^^
  File "/home/volodymyr.buchakchiiskyi/miniforge3/envs/newprimus/lib/python3.13/site-packages/nnunetv2/inference/predict_from_raw_data.py", line 975, in predict_entry_point
    predictor.initialize_from_trained_model_folder(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model_folder,
        ^^^^^^^^^^^^^
        args.f,
        ^^^^^^^
        checkpoint_name=args.chk
        ^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/volodymyr.buchakchiiskyi/miniforge3/envs/newprimus/lib/python3.13/site-packages/nnunetv2/inference/predict_from_raw_data.py", line 104, in initialize_from_trained_model_folder
    network = trainer_class.build_network_architecture(
        configuration_manager.network_arch_class_name,
    ...<4 lines>...
        enable_deep_supervision=False
    )
TypeError: nnUNet_Primus_M_Trainer.build_network_architecture() missing 1 required positional argument: 'num_output_channels'
```

From my understanding, the core of the problem is that the Primus trainers inherit build_network_architecture as an abstract instance method with 7 arguments, while the rest of the framework expects a static method with 6 ('self' being the extra one).
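
To see why the error message complains about num_output_channels specifically, here is a toy example (plain Python, not nnU-Net code): when a method is called on the class instead of on an instance, the first positional argument is consumed by 'self', so the last parameter ends up missing.

```python
class StaticTrainer:
    @staticmethod
    def build_network_architecture(num_input_channels, num_output_channels):
        return num_input_channels, num_output_channels


class InstanceTrainer:
    # Same signature, but as an instance method: 'self' now eats the first argument.
    def build_network_architecture(self, num_input_channels, num_output_channels):
        return num_input_channels, num_output_channels


print(StaticTrainer.build_network_architecture(1, 3))    # works: (1, 3)
print(InstanceTrainer.build_network_architecture(1, 3))  # TypeError: build_network_architecture()
                                                          # missing 1 required positional argument:
                                                          # 'num_output_channels'
```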

A more detailed explanation below:

-> In the default nnUNetTrainer.py, there is a static method build_network_architecture that takes 6 arguments:

```python
@staticmethod
def build_network_architecture(architecture_class_name: str,
                               arch_init_kwargs: dict,
                               arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                               num_input_channels: int,
                               num_output_channels: int,
                               enable_deep_supervision: bool = True) -> nn.Module:
```

-> primus_trainers.py defines an AbstractPrimus class with an abstract method:

```python
@abstractmethod
def build_network_architecture(
    self,
    architecture_class_name: str,
    arch_init_kwargs: dict,
    arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
    num_input_channels: int,
    num_output_channels: int,
    enable_deep_supervision: bool = True,
) -> nn.Module:
    raise NotImplementedError()
```

-> and the child class nnUNet_Primus_M_Trainer implements this abstract method, now with 7 arguments, 'self' being the extra one:

```python
class nnUNet_Primus_M_Trainer(AbstractPrimus):

    def build_network_architecture(
        self,
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        model = Primus(
            num_input_channels,
            864,
            (8, 8, 8),
            num_output_channels,
            16,
            12,
            self.configuration_manager.patch_size,
            drop_path_rate=0.2,
            scale_attn_inner=True,
            init_values=0.1,
        )
        return model
```

-> I made it work by reimplementing the Primus trainer class as follows:

```python
from typing import List, Tuple, Union

import torch
from torch import nn, autocast
from torch.nn.parallel import DistributedDataParallel as DDP

from dynamic_network_architectures.architectures.primus import Primus
from nnunetv2.training.nnUNetTrainer.variants.lr_schedule.nnUNetTrainer_warmup import nnUNetTrainer_warmup
from nnunetv2.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset
from nnunetv2.utilities.helpers import empty_cache, dummy_context

class MyCustomPrimusTrainer(nnUNetTrainer_warmup):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.initial_lr = 3e-4
        self.weight_decay = 5e-2
        self.enable_deep_supervision = False
        
    def configure_optimizers(self, stage: str = "warmup_all"):
        assert stage in ["warmup_all", "train"]

        if self.training_stage == stage:
            return self.optimizer, self.lr_scheduler

        if isinstance(self.network, DDP):
            params = self.network.module.parameters()
        else:
            params = self.network.parameters()

        if stage == "warmup_all":
            self.print_to_log_file("train whole net, warmup")
            optimizer = torch.optim.AdamW(
                params, self.initial_lr, weight_decay=self.weight_decay, amsgrad=False, betas=(0.9, 0.98), fused=True
            )
            lr_scheduler = Lin_incr_LRScheduler(optimizer, self.initial_lr, self.warmup_duration_whole_net)
            self.print_to_log_file(f"Initialized warmup_all optimizer and lr_scheduler at epoch {self.current_epoch}")
        else:
            self.print_to_log_file("train whole net, default schedule")
            if self.training_stage == "warmup_all":
                # we can keep the existing optimizer and don't need to create a new one. This will allow us to keep
                # the accumulated momentum terms which already point in a useful direction
                optimizer = self.optimizer
            else:
                optimizer = torch.optim.AdamW(
                    params,
                    self.initial_lr,
                    weight_decay=self.weight_decay,
                    amsgrad=False,
                    betas=(0.9, 0.98),
                    fused=True,
                )
            lr_scheduler = PolyLRScheduler_offset(
                optimizer, self.initial_lr, self.num_epochs, self.warmup_duration_whole_net
            )
            self.print_to_log_file(f"Initialized train optimizer and lr_scheduler at epoch {self.current_epoch}")
        self.training_stage = stage
        empty_cache(self.device)
        return optimizer, lr_scheduler

    def train_step(self, batch: dict) -> dict:
        data = batch["data"]
        target = batch["target"]

        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)
        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
            output = self.network(data)
            # del data
            l = self.loss(output, target)

        if self.grad_scaler is not None:
            self.grad_scaler.scale(l).backward()
            self.grad_scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            l.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
            self.optimizer.step()
        return {"loss": l.detach().cpu().numpy()}

    def set_deep_supervision_enabled(self, enabled: bool):
        pass
    @staticmethod
    def build_network_architecture(
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        print(f"arch_init_kwargs: {arch_init_kwargs}")
        print(f"arch_init_kwargs_req_import: {arch_init_kwargs_req_import}")
        print(f"num_input_channels: {num_input_channels}")
        print(f"num_output_channels: {num_output_channels}")

        model = Primus(
            num_input_channels,
            864,
            (8, 8, 8),
            num_output_channels,
            16,
            12,
            (192, 192, 192),  # hardcoded patch_size taken from nnUNetResEncUNetLPlans.json (see note below)
            drop_path_rate=0.2,
            scale_attn_inner=True,
            init_values=0.1,
        )
        return model


```

-> Please note that the patch size here is hardcoded, because once you don't use 'self', you have to look up the configuration yourself. I found it in nnUNetResEncUNetLPlans.json:

```json
"3d_fullres": {
    "data_identifier": "nnUNetPlans_3d_fullres",
    "preprocessor_name": "DefaultPreprocessor",
    "batch_size": 2,
    "patch_size": [ 192, 192, 192 ],
```
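
As a possible way to avoid hardcoding, one could read the patch size back from the plans file. A minimal sketch, assuming the 3d_fullres block sits under the top-level "configurations" key of the plans JSON (as in nnU-Net v2 plans); the path below is hypothetical and needs to be adjusted to your setup:

```python
import json

# Hypothetical path to the plans file of the trained model folder.
plans_path = "/path/to/nnUNet_results/Dataset322/nnUNetResEncUNetLPlans.json"

with open(plans_path) as f:
    plans = json.load(f)

# Assumption: per-configuration settings live under "configurations" in nnU-Net v2 plans.
patch_size = tuple(plans["configurations"]["3d_fullres"]["patch_size"])
print(patch_size)  # e.g. (192, 192, 192) for this dataset/plans combination
```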

-> I ran the inference with this command:

nnUNetv2_predict -f 1 -i /home/kit/test1/ -o /home/kit/results/ -d 322 -c 3d_fullres -tr MyCustomPrimusTrainer -p nnUNetResEncUNetLPlans --verbose

I hope I conveyed my point in an understandable way, even though the explanation turned out to be quite long. I'm happy to provide more info if I forgot something.

Even though I got it running, this is not a clean solution at all, and I hope it can be solved in a more elegant way.

ChepchikValeri avatar Jul 04 '25 12:07 ChepchikValeri

As a simpler workaround, adding the lines:

```python
if 'primus' in str(trainer_class).lower():
    trainer_class = trainer_class(plans=plans, configuration=configuration_name, fold=None, dataset_json=dataset_json, device=self.device)
```

in inference/predict_from_raw_data.py, line 104 seems to work.
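
For context, a hedged sketch of what the patched region around line 104 of initialize_from_trained_model_folder might look like. Apart from the names visible in the traceback above (trainer_class, configuration_manager, enable_deep_supervision) and in the snippet (plans, configuration_name, dataset_json, self.device), nothing here is taken from the actual file, and the elided arguments are left elided exactly as in the traceback:

```python
# Sketch only; the elided arguments (...<4 lines>...) are whatever the original call passes.
if 'primus' in str(trainer_class).lower():
    # Primus trainers implement build_network_architecture as an instance method,
    # so build an instance first and call the method on it instead of on the class.
    trainer_class = trainer_class(plans=plans, configuration=configuration_name,
                                  fold=None, dataset_json=dataset_json, device=self.device)

network = trainer_class.build_network_architecture(
    configuration_manager.network_arch_class_name,
    # ...<4 lines>... (unchanged from the original call)
    enable_deep_supervision=False
)
```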

remihndz avatar Jul 16 '25 12:07 remihndz

> As a simpler workaround, adding the lines:
>
> ```python
> if 'primus' in str(trainer_class).lower():
>     trainer_class = trainer_class(plans=plans, configuration=configuration_name, fold=None, dataset_json=dataset_json, device=self.device)
> ```
>
> in inference/predict_from_raw_data.py, line 104 seems to work.

Thank you, it works! Should this be added to the main repo?

I will close this issue, as the proposed solution works.

ChepchikValeri avatar Jul 22 '25 12:07 ChepchikValeri

Hey, I just saw that this seems to be an issue and will provide a fix as this is clearly a bug.

Thanks for posting!

TaWald avatar Aug 18 '25 12:08 TaWald