PRIMUS trainer runs training and validation but fails at inference
Hi nnUNet team!
First of all, thank you for such an amazing tool!
The code runs training and validation with no issues. I ran training like this:
```
CUDA_VISIBLE_DEVICES=4 nnUNetv2_train 322 3d_fullres 1 -tr nnUNet_Primus_M_Trainer -p nnUNetResEncUNetLPlans
```
By the way, I checked both the nnU-Net main branch on GitHub and https://github.com/TaWald/nnUNet/tree/primus# ; both behave the same way.
However, I ran into an issue with PRIMUS inference. When running inference, I get the following:
```
CUDA_VISIBLE_DEVICES=4 nnUNetv2_predict -f 1 -i /home/volodymyr.buchakchiiskyi/input/ -o /home/volodymyr.buchakchiiskyi/outputPrimus/ -d 322 -c 3d_fullres --verbose -tr nnUNet_Primus_M_Trainer -p nnUNetResEncUNetLPlans

#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################

Traceback (most recent call last):
  File "/home/volodymyr.buchakchiiskyi/miniforge3/envs/newprimus/bin/nnUNetv2_predict", line 8, in <module>
    sys.exit(predict_entry_point())
             ~~~~~~~~~~~~~~~~~~~^^
  File "/home/volodymyr.buchakchiiskyi/miniforge3/envs/newprimus/lib/python3.13/site-packages/nnunetv2/inference/predict_from_raw_data.py", line 975, in predict_entry_point
    predictor.initialize_from_trained_model_folder(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model_folder,
        ^^^^^^^^^^^^^
        args.f,
        ^^^^^^^
        checkpoint_name=args.chk
        ^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/volodymyr.buchakchiiskyi/miniforge3/envs/newprimus/lib/python3.13/site-packages/nnunetv2/inference/predict_from_raw_data.py", line 104, in initialize_from_trained_model_folder
    network = trainer_class.build_network_architecture(
        configuration_manager.network_arch_class_name,
        ...<4 lines>...
        enable_deep_supervision=False
    )
TypeError: nnUNet_Primus_M_Trainer.build_network_architecture() missing 1 required positional argument: 'num_output_channels'
```
From my understanding, the core of the problem is that the Primus trainers define `build_network_architecture` as an (abstract) instance method with 7 parameters, while the rest of the framework expects a static method with 6 (the extra one is `self`).
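The mismatch can be reproduced without nnU-Net. The traceback shows the predictor calling `trainer_class.build_network_architecture(...)` on the class itself, with five positional arguments (one visible, four elided as `...<4 lines>...`) plus `enable_deep_supervision=False`. When an instance method is called through the class, nothing is bound to `self` automatically, so the first argument fills `self`, every later argument shifts one slot left, and `num_output_channels` is left empty. The minimal sketch below assumes that call shape; the class names and the dummy argument values are hypothetical stand-ins for the nnU-Net ones:

```python
from abc import ABC, abstractmethod


class AbstractBase(ABC):
    # Stand-in for AbstractPrimus: an abstract *instance* method (first param is self).
    @abstractmethod
    def build_network_architecture(self, architecture_class_name, arch_init_kwargs,
                                   arch_init_kwargs_req_import, num_input_channels,
                                   num_output_channels, enable_deep_supervision=True):
        raise NotImplementedError()


class InstanceMethodTrainer(AbstractBase):
    # Stand-in for nnUNet_Primus_M_Trainer: overrides it as an instance method.
    def build_network_architecture(self, architecture_class_name, arch_init_kwargs,
                                   arch_init_kwargs_req_import, num_input_channels,
                                   num_output_channels, enable_deep_supervision=True):
        return num_input_channels, num_output_channels


class StaticMethodTrainer:
    # Stand-in for nnUNetTrainer: a static method, so no self is expected.
    @staticmethod
    def build_network_architecture(architecture_class_name, arch_init_kwargs,
                                   arch_init_kwargs_req_import, num_input_channels,
                                   num_output_channels, enable_deep_supervision=True):
        return num_input_channels, num_output_channels


def call_like_predictor(trainer_class):
    # Same call shape as in the traceback: five positional arguments plus a
    # keyword argument, with the method looked up on whatever object is passed in.
    return trainer_class.build_network_architecture(
        "Primus", {}, [], 1, 3, enable_deep_supervision=False)


print(call_like_predictor(StaticMethodTrainer))  # (1, 3)

try:
    call_like_predictor(InstanceMethodTrainer)
except TypeError as err:
    # "Primus" lands in self, so num_output_channels is never filled.
    print(err)
```

On Python 3.10 or newer the second call prints the same kind of message as the traceback, ending in `missing 1 required positional argument: 'num_output_channels'`. Passing an *instance* (`call_like_predictor(InstanceMethodTrainer())`) binds `self` and the same arguments work.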
A more detailed explanation:
-> In the default `nnUNetTrainer.py`, `build_network_architecture` is a static method that takes 6 arguments:

```python
@staticmethod
def build_network_architecture(architecture_class_name: str,
                               arch_init_kwargs: dict,
                               arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                               num_input_channels: int,
                               num_output_channels: int,
                               enable_deep_supervision: bool = True) -> nn.Module:
    ...  # body omitted
```
-> `primus_trainers.py` defines an `AbstractPrimus` class with an abstract instance method:

```python
@abstractmethod
def build_network_architecture(
    self,
    architecture_class_name: str,
    arch_init_kwargs: dict,
    arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
    num_input_channels: int,
    num_output_channels: int,
    enable_deep_supervision: bool = True,
) -> nn.Module:
    raise NotImplementedError()
```
-> The child class `nnUNet_Primus_M_Trainer` then overrides this method, again with 7 arguments, where `self` is the extra one. It needs `self` to read `self.configuration_manager.patch_size`:

```python
class nnUNet_Primus_M_Trainer(AbstractPrimus):
    def build_network_architecture(
        self,
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        model = Primus(
            num_input_channels,
            864,
            (8, 8, 8),
            num_output_channels,
            16,
            12,
            self.configuration_manager.patch_size,
            drop_path_rate=0.2,
            scale_attn_inner=True,
            init_values=0.1,
        )
        return model
```
-> I got it working by reimplementing the Primus trainer class as follows:
```python
from typing import List, Tuple, Union

import torch
from torch import nn, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from dynamic_network_architectures.architectures.primus import Primus
from nnunetv2.training.nnUNetTrainer.variants.lr_schedule.nnUNetTrainer_warmup import nnUNetTrainer_warmup
from nnunetv2.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset
from nnunetv2.utilities.helpers import empty_cache, dummy_context


class MyCustomPrimusTrainer(nnUNetTrainer_warmup):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.initial_lr = 3e-4
        self.weight_decay = 5e-2
        self.enable_deep_supervision = False

    def configure_optimizers(self, stage: str = "warmup_all"):
        assert stage in ["warmup_all", "train"]

        if self.training_stage == stage:
            return self.optimizer, self.lr_scheduler

        if isinstance(self.network, DDP):
            params = self.network.module.parameters()
        else:
            params = self.network.parameters()

        if stage == "warmup_all":
            self.print_to_log_file("train whole net, warmup")
            optimizer = torch.optim.AdamW(
                params, self.initial_lr, weight_decay=self.weight_decay, amsgrad=False, betas=(0.9, 0.98), fused=True
            )
            lr_scheduler = Lin_incr_LRScheduler(optimizer, self.initial_lr, self.warmup_duration_whole_net)
            self.print_to_log_file(f"Initialized warmup_all optimizer and lr_scheduler at epoch {self.current_epoch}")
        else:
            self.print_to_log_file("train whole net, default schedule")
            if self.training_stage == "warmup_all":
                # we can keep the existing optimizer and don't need to create a new one. This will allow us to keep
                # the accumulated momentum terms which already point in a useful direction
                optimizer = self.optimizer
            else:
                optimizer = torch.optim.AdamW(
                    params,
                    self.initial_lr,
                    weight_decay=self.weight_decay,
                    amsgrad=False,
                    betas=(0.9, 0.98),
                    fused=True,
                )
            lr_scheduler = PolyLRScheduler_offset(
                optimizer, self.initial_lr, self.num_epochs, self.warmup_duration_whole_net
            )
            self.print_to_log_file(f"Initialized train optimizer and lr_scheduler at epoch {self.current_epoch}")
        self.training_stage = stage
        empty_cache(self.device)
        return optimizer, lr_scheduler

    def train_step(self, batch: dict) -> dict:
        data = batch["data"]
        target = batch["target"]

        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)
        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context():
            output = self.network(data)
            # del data
            l = self.loss(output, target)

        if self.grad_scaler is not None:
            self.grad_scaler.scale(l).backward()
            self.grad_scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            l.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
            self.optimizer.step()
        return {"loss": l.detach().cpu().numpy()}

    def set_deep_supervision_enabled(self, enabled: bool):
        pass

    @staticmethod
    def build_network_architecture(
        architecture_class_name: str,
        arch_init_kwargs: dict,
        arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
        num_input_channels: int,
        num_output_channels: int,
        enable_deep_supervision: bool = True,
    ) -> nn.Module:
        print(f"arch_init_kwargs: {arch_init_kwargs}")
        print(f"arch_init_kwargs_req_import: {arch_init_kwargs_req_import}")
        print(f"num_input_channels: {num_input_channels}")
        print(f"num_output_channels: {num_output_channels}")
        # Patch size is hardcoded: a static method has no access to self.configuration_manager
        model = Primus(
            num_input_channels,
            864,
            (8, 8, 8),
            num_output_channels,
            16,
            12,
            (192, 192, 192),
            drop_path_rate=0.2,
            scale_attn_inner=True,
            init_values=0.1,
        )
        return model
```
-> Please note that the patch size is hardcoded here: once the method no longer takes `self`, it cannot read `self.configuration_manager`, so you have to look up the patch size yourself. I found it in `nnUNetResEncUNetLPlans.json`:

```json
"3d_fullres": {
    "data_identifier": "nnUNetPlans_3d_fullres",
    "preprocessor_name": "DefaultPreprocessor",
    "batch_size": 2,
    "patch_size": [
        192,
        192,
        192
    ],
```
-> I ran the inference with this command:
```
nnUNetv2_predict -f 1 -i /home/kit/test1/ -o /home/kit/results/ -d 322 -c 3d_fullres -tr MyCustomPrimusTrainer -p nnUNetResEncUNetLPlans --verbose
```
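Instead of copying the patch size out of the plans file by hand, it could be read programmatically. This is a minimal sketch assuming the nnU-Net v2 plans layout, where each configuration (such as the `"3d_fullres"` excerpt above) sits under a top-level `"configurations"` key; `read_patch_size` is a hypothetical helper, not part of nnU-Net:

```python
import json
from typing import Tuple


def read_patch_size(plans_path: str, configuration: str = "3d_fullres") -> Tuple[int, ...]:
    """Return the patch size of one configuration from a plans JSON file.

    Assumes per-configuration settings live under the top-level "configurations" key.
    """
    with open(plans_path) as f:
        plans = json.load(f)
    # JSON stores the patch size as a list; Primus in the trainer above receives a tuple.
    return tuple(plans["configurations"][configuration]["patch_size"])
```

For the plans shown above, `read_patch_size("/path/to/nnUNetResEncUNetLPlans.json")` (path hypothetical) would give `(192, 192, 192)`. Note that this only replaces the manual lookup: the static `build_network_architecture` still does not receive the plans path, so it cannot call such a helper on its own.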
I hope I conveyed my point understandably, even though the explanation turned out to be quite cumbersome. I'm happy to provide more info if I forgot something.
Although I got it running, this is not a clean solution at all, and I hope this can be solved in a more elegant way.
As a simpler workaround, adding these lines to `inference/predict_from_raw_data.py` at line 104 (right before the `build_network_architecture` call) seems to work:

```python
if 'primus' in str(trainer_class).lower():
    trainer_class = trainer_class(plans=plans, configuration=configuration_name, fold=None, dataset_json=dataset_json, device=self.device)
```
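This workaround helps because the method is then called on a trainer *instance*, so `self` is bound and the instance's configuration (including the patch size) is available. The sketch below mimics that branch with hypothetical stub classes (`StaticTrainerStub`, `PrimusTrainerStub`) and a hypothetical `build_like_predictor` function; the stub stores only the patch size, whereas the real trainer exposes it via `self.configuration_manager`:

```python
class StaticTrainerStub:
    # Hypothetical stand-in for nnUNetTrainer: static method, no self needed.
    @staticmethod
    def build_network_architecture(architecture_class_name, arch_init_kwargs,
                                   arch_init_kwargs_req_import, num_input_channels,
                                   num_output_channels, enable_deep_supervision=True):
        return {"in": num_input_channels, "out": num_output_channels}


class PrimusTrainerStub:
    # Hypothetical stand-in for nnUNet_Primus_M_Trainer: an instance method that
    # reads configuration stored on self when the trainer is constructed.
    def __init__(self, plans, configuration, fold, dataset_json, device):
        self.patch_size = tuple(plans["configurations"][configuration]["patch_size"])

    def build_network_architecture(self, architecture_class_name, arch_init_kwargs,
                                   arch_init_kwargs_req_import, num_input_channels,
                                   num_output_channels, enable_deep_supervision=True):
        return {"in": num_input_channels, "out": num_output_channels,
                "patch_size": self.patch_size}


def build_like_predictor(trainer_class, plans, configuration_name, dataset_json, device="cpu"):
    # The workaround: instantiate Primus trainers so that self gets bound.
    if "primus" in str(trainer_class).lower():
        trainer_class = trainer_class(plans=plans, configuration=configuration_name,
                                      fold=None, dataset_json=dataset_json, device=device)
    # Same call shape as in the traceback: five positional args plus a keyword.
    return trainer_class.build_network_architecture(
        "Primus", {}, [], 1, 3, enable_deep_supervision=False)


plans = {"configurations": {"3d_fullres": {"patch_size": [192, 192, 192]}}}
print(build_like_predictor(PrimusTrainerStub, plans, "3d_fullres", {}))
print(build_like_predictor(StaticTrainerStub, plans, "3d_fullres", {}))
```

Note that `str(trainer_class)` includes the module path as well as the class name, so the substring check is a name-based heuristic. A more structural test (an idea, not something proposed in this thread) would be to check whether `inspect.getattr_static(trainer_class, "build_network_architecture")` is a `staticmethod`.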
Thank you, it works! Should this be added to the main repo?
I will close this issue, as the proposed solution works.
Hey, I just saw this issue. It is clearly a bug, and I will provide a fix.
Thanks for posting!