nnUNet
cascade model calls lowres with default checkpoint at inference if lowres not run manually
Hi, thank you for your amazing work!
I searched the repo before posting and couldn't find anything on this.
When I request the cascade model (e.g. 3d_cascade_fullres) at inference without having run the lowres model beforehand, predict_simple.py first runs inference with the lowres model itself.
line 189, nnunet/inference/predict_simple.py
if model == "3d_cascade_fullres" and lowres_segmentations is None:
However, this call does not pass on the checkpoint I specified, so the lowres inference falls back to the default checkpoint, "model_final_checkpoint".
lines 199-203, nnunet/inference/predict_simple.py
predict_from_folder(model_folder_name, input_folder, lowres_output_folder, folds, False,
                    num_threads_preprocessing, num_threads_nifti_save, None, part_id, num_parts, not disable_tta,
                    overwrite_existing=overwrite_existing, mode=mode, overwrite_all_in_gpu=all_in_gpu,
                    mixed_precision=not args.disable_mixed_precision,
                    step_size=step_size)
This code detects that I haven't predicted the lowres segmentations yet, but it doesn't pass on the checkpoint I requested in call_inference.
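To make the pattern concrete, here is a minimal, self-contained sketch of the reported behaviour. The functions `predict_stub` and `run_cascade_buggy` are hypothetical stand-ins, not nnU-Net's API; they keep only the checkpoint handling. The sketch assumes, as the issue implies, that the fullres stage does receive the requested checkpoint while the lowres call omits it, so the stub's default applies.

```python
from typing import List, Optional, Tuple


def predict_stub(model: str, checkpoint_name: str = "model_final_checkpoint") -> Tuple[str, str]:
    # Hypothetical stand-in for predict_from_folder: just reports what it would load.
    return (model, checkpoint_name)


def run_cascade_buggy(checkpoint: str, lowres_segmentations: Optional[str] = None) -> List[Tuple[str, str]]:
    calls = []
    if lowres_segmentations is None:
        # Mirrors the lowres call quoted above: the checkpoint is NOT forwarded,
        # so the default "model_final_checkpoint" is used silently.
        calls.append(predict_stub("3d_lowres"))
    # Assumed: the fullres stage receives the requested checkpoint.
    calls.append(predict_stub("3d_cascade_fullres", checkpoint_name=checkpoint))
    return calls


print(run_cascade_buggy("model_best"))
# [('3d_lowres', 'model_final_checkpoint'), ('3d_cascade_fullres', 'model_best')]
```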
lines 604-610, nnunet/inference/predict.py
def predict_from_folder(model: str, input_folder: str, output_folder: str, folds: Union[Tuple[int], List[int]],
                        save_npz: bool, num_threads_preprocessing: int, num_threads_nifti_save: int,
                        lowres_segmentations: Union[str, None],
                        part_id: int, num_parts: int, tta: bool, mixed_precision: bool = True,
                        overwrite_existing: bool = True, mode: str = 'normal', overwrite_all_in_gpu: bool = None,
                        step_size: float = 0.5, checkpoint_name: str = "model_final_checkpoint",
                        segmentation_export_kwargs: dict = None, disable_postprocessing: bool = False):
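Because checkpoint_name has a default in this signature, leaving it out raises no error. One way to check which value a given call will actually receive is the standard library's `inspect.signature(...).bind(...)` followed by `apply_defaults()`. The function `predict_from_folder_stub` below is a hypothetical stub that copies only a few parameters of predict_from_folder (with the same defaults); it is not the real function, and `resolved_checkpoint` is likewise an illustrative helper.

```python
import inspect
from typing import List, Tuple, Union


def predict_from_folder_stub(model: str, input_folder: str, output_folder: str,
                             folds: Union[Tuple[int], List[int]],
                             step_size: float = 0.5,
                             checkpoint_name: str = "model_final_checkpoint") -> None:
    # Hypothetical stub: same defaults as the excerpt above, no body.
    pass


def resolved_checkpoint(*args, **kwargs) -> str:
    # Bind the call the way Python would, then fill in any omitted defaults.
    bound = inspect.signature(predict_from_folder_stub).bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.arguments["checkpoint_name"]


# Omitting checkpoint_name, as the lowres call does:
print(resolved_checkpoint("lowres_model", "in", "out", (0,), step_size=0.5))
# model_final_checkpoint

# Forwarding it explicitly:
print(resolved_checkpoint("lowres_model", "in", "out", (0,), checkpoint_name="model_best"))
# model_best
```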
So if I request "model_best", the lowres inference still runs with checkpoint_name="model_final_checkpoint". I would expect it to use whichever checkpoint I specified, i.e. by adding checkpoint_name=args.chk to the call to predict_from_folder on line 203 of predict_simple.py.
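The proposed fix amounts to forwarding the same checkpoint to both stages of the cascade. Below is a minimal sketch of that behaviour using hypothetical stubs (`predict_stub`, `run_cascade_fixed`), not nnU-Net code; in predict_simple.py itself the change would simply be adding checkpoint_name=args.chk to the lowres call to predict_from_folder.

```python
from typing import List, Optional, Tuple


def predict_stub(model: str, checkpoint_name: str = "model_final_checkpoint") -> Tuple[str, str]:
    # Hypothetical stand-in for predict_from_folder: reports which checkpoint it would load.
    return (model, checkpoint_name)


def run_cascade_fixed(checkpoint: str, lowres_segmentations: Optional[str] = None) -> List[Tuple[str, str]]:
    calls = []
    if lowres_segmentations is None:
        # Proposed change: forward the user's checkpoint to the lowres stage as well.
        calls.append(predict_stub("3d_lowres", checkpoint_name=checkpoint))
    calls.append(predict_stub("3d_cascade_fullres", checkpoint_name=checkpoint))
    return calls


print(run_cascade_fixed("model_best"))
# [('3d_lowres', 'model_best'), ('3d_cascade_fullres', 'model_best')]
```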