nnUNet

ONNX export

Open ogencoglu opened this issue 1 year ago • 5 comments

Do you support ONNX export and is it tested? Any examples of it would be appreciated.

ogencoglu, Jul 29 '24 17:07

I have the same question. Could you give an example? Thank you very much.

ZxnSnowy, Aug 07 '24 09:08

Hi, conversion to ONNX can be done like this:

from typing import Optional, Union

import torch
import nnunetv2
from batchgenerators.utilities.file_and_folder_operations import join, load_json
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager


def convert_to_ONNX(
    model_dir: str,
    onnx_model_path: str,
    batch_size: int = 1,
    tile_size: int = 512,
    fold: Union[int, str] = 1,  # a fold index, or "all"
    num_channels: Optional[int] = None,  # None: use the channel count derived from the plans
):
    dataset_json = load_json(join(model_dir, 'dataset.json'))
    plans = load_json(join(model_dir, 'plans.json'))
    plans_manager = PlansManager(plans)

    # Load the checkpoint of the requested fold on the CPU
    f = int(fold) if fold != 'all' else fold
    checkpoint_path = join(model_dir, f'fold_{f}', 'checkpoint_best.pth')
    # nnU-Net checkpoints also store non-tensor metadata (trainer name, init args),
    # so weights_only=False is needed on PyTorch versions that default it to True
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)
    trainer_name = checkpoint['trainer_name']
    configuration_name = checkpoint['init_args']['configuration']

    # Restore the network the same way the trainer built it, without deep supervision
    configuration_manager = plans_manager.get_configuration(configuration_name)
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
    trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                trainer_name, 'nnunetv2.training.nnUNetTrainer')
    model = trainer_class.build_network_architecture(
        configuration_manager.network_arch_class_name,
        configuration_manager.network_arch_init_kwargs,
        configuration_manager.network_arch_init_kwargs_req_import,
        num_input_channels,
        plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
        enable_deep_supervision=False
    )
    model.load_state_dict(checkpoint['network_weights'])
    model.eval()

    # Convert to ONNX. The dummy input fixes the exported input shape:
    # (batch, channels, tile_size, tile_size) is a 2D layout; a 3D configuration
    # needs a third spatial dimension.
    if num_channels is None:
        num_channels = num_input_channels
    dummy = torch.randn(batch_size, num_channels, tile_size, tile_size)
    torch.onnx.export(
        model,
        dummy,
        onnx_model_path,
        verbose=False,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
    )

Hope it helps
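The dummy input above hard-codes a square 512×512 2D tile. A possible alternative, sketched here under the assumption that plans.json stores each configuration's patch size at plans["configurations"][name]["patch_size"], is to derive the export shape from the patch size the network was trained with. The helper name dummy_input_shape and the example plans dict are hypothetical. Inside convert_to_ONNX, configuration_manager.patch_size should give the same value and also covers configurations that inherit from another one.

```python
from typing import Tuple


def dummy_input_shape(plans: dict, configuration: str, num_channels: int,
                      batch_size: int = 1) -> Tuple[int, ...]:
    # Hypothetical helper: (batch, channels, *patch_size) for one configuration.
    # Assumes the patch size is stored directly under this configuration
    # (configurations using "inherits_from" may not have it).
    patch_size = plans["configurations"][configuration]["patch_size"]
    return (batch_size, num_channels, *(int(s) for s in patch_size))


# Trimmed-down, made-up plans dict for illustration only
example_plans = {
    "configurations": {
        "2d": {"patch_size": [1024, 384]},
        "3d_fullres": {"patch_size": [128, 128, 128]},
    }
}
print(dummy_input_shape(example_plans, "2d", num_channels=1))          # (1, 1, 1024, 384)
print(dummy_input_shape(example_plans, "3d_fullres", num_channels=1))  # (1, 1, 128, 128, 128)
```

The tuple can be passed as torch.randn(*shape). Since the network is fully convolutional, other sizes may also trace fine, but the trained patch size is the tile size nnU-Net itself uses for sliding-window inference.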

rubencardenes, Oct 22 '24 10:10

I haven't tested @rubencardenes' code yet, but I would also be interested in ONNX export, as I'm unable to install and use nnUNet due to #2430.

vmiller987, Mar 14 '25 12:03

@rubencardenes

Hey, I was able to make this work and get an ONNX model file. However, I'm still new to ONNX and I was wondering if you could assist with the inference side.

Either the ONNX model doesn't preprocess the images for inference the same way nnUNet does, or I set it up wrong.
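Part of the mismatch is likely normalization: nnU-Net normalizes each input channel according to the "normalization_schemes" entry of the configuration in plans.json before the network sees it, and the exported network does not contain that step. A minimal sketch, assuming the scheme is ZScoreNormalization without a mask (subtract the image mean, divide by the standard deviation); CT data (CTNormalization) instead uses dataset-wide statistics stored in plans.json, and resampling to the target spacing is not covered here. The function name is hypothetical.

```python
import numpy as np


def zscore_normalize(image: np.ndarray) -> np.ndarray:
    # Hypothetical helper: per-channel z-score normalization of a (c, ...) array,
    # assuming the plans list ZScoreNormalization and no normalization mask.
    out = image.astype(np.float32, copy=True)
    for c in range(out.shape[0]):
        mean = out[c].mean()
        std = out[c].std()
        # Clamp the std so constant images do not divide by zero
        out[c] = (out[c] - mean) / max(std, 1e-8)
    return out
```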

ONNX Model Creation

convert_to_ONNX(
    model_dir=join(
        nnUNet_results, "Dataset500_M662_nnUNet/nnUNetTrainer__nnUNetResEncUNetMPlans__2d/"
    ),
    onnx_model_path=join(nnUNet_results, "Dataset500_M662_nnUNet/infer2donnx/m662_2d_model.onnx"),
    batch_size=1,
    tile_size=512,
    fold="all",
    num_channels=1,
)

ONNX Inference Code

# Use nnUNet's SimpleITKIO reader for simplicity when reading the images
img, props = SimpleITKIO().read_images(
    [join(nnUNet_raw, "Dataset500_M662_nnUNet/imagesTs/180_0000.png")]
)

import onnxruntime as ort

model_path = "/path/to/model.onnx"

session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name

predictions = session.run(None, {input_name: img})

print(f"Inference Predictions:\n: {predictions}")

This fails at session.run because the image size does not match the input size the model expects. The model was exported for 512×512 inputs (I just copied your values), but my images are larger (1960×648 in this case). I assumed the ONNX model would handle this.

Error

---------------------------------------------------------------------------
InvalidArgument                           Traceback (most recent call last)
Cell In[38], line 10
      7 session = ort.InferenceSession(model_path)
      8 input_name = session.get_inputs()[0].name
---> 10 predictions = session.run(None, {input_name: img})
     12 print(f"Inference Predictions:\n: {predictions}")

File ~/work/caml2/caml/.venv/lib64/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:270, in Session.run(self, output_names, input_feed, run_options)
    268     output_names = [output.name for output in self._outputs_meta]
    269 try:
--> 270     return self._sess.run(output_names, input_feed, run_options)
    271 except C.EPFail as err:
    272     if self._enable_fallback:

InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: input for the following indices
 index: 2 Got: 1960 Expected: 512
 index: 3 Got: 648 Expected: 512
 Please fix either the inputs/outputs or the model.
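The error reports a 4-D input with the spatial sizes at indices 2 and 3, which fits nnU-Net's 2D readers returning the image as (channels, 1, x, y), with a dummy axis at position 1. For a single-channel image that happens to coincide with the (batch, channels, H, W) layout the exported model expects, so the remaining problem is only the spatial size; with several channels, however, the channels would land on the batch axis. A minimal sketch of the rearrangement, assuming that (c, 1, x, y) layout; the function name is hypothetical.

```python
import numpy as np


def to_2d_model_input(img: np.ndarray) -> np.ndarray:
    # Hypothetical helper. Assumes img has shape (c, 1, x, y): channels first,
    # then a singleton dummy axis added by the 2D reader.
    if img.ndim != 4 or img.shape[1] != 1:
        raise ValueError(f"expected shape (c, 1, x, y), got {img.shape}")
    # Drop the dummy axis and add a batch axis -> (1, c, x, y), as float32
    return np.ascontiguousarray(img[:, 0][None], dtype=np.float32)
```

Even with the right layout, the 1960×648 image still has to be cut into 512×512 tiles (or the model exported with a matching input size) before session.run accepts it.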

Are you able to point me in the right direction to resolve this? I was hoping it would bundle preprocessing into the ONNX model.

vmiller987, Mar 21 '25 16:03

Hello

I used @rubencardenes' convert_to_ONNX function to export an nnUNet model to ONNX, but I had to preprocess the image beforehand and split it into patches of the size the model accepts. I hope it helps!
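Splitting into patches is roughly what nnU-Net does internally with its sliding-window prediction. A minimal numpy sketch under stated assumptions: a 2D model with a fixed square input of `tile` pixels, an image already preprocessed to shape (c, H, W), tiles overlapping by half a tile (nnU-Net's default step of 0.5), and overlapping predictions averaged uniformly. nnU-Net itself weights each tile with a Gaussian and can add mirroring, so results may differ slightly. tile_starts, sliding_window_predict and predict_fn are hypothetical names.

```python
import itertools
from typing import Callable, List

import numpy as np


def tile_starts(size: int, tile: int, step_fraction: float = 0.5) -> List[int]:
    # Start offsets so that tiles of length `tile` cover [0, size) with overlap;
    # the first tile starts at 0 and the last one ends exactly at `size`.
    if size <= tile:
        return [0]
    n = int(np.ceil((size - tile) / (tile * step_fraction))) + 1
    return [int(round(i * (size - tile) / (n - 1))) for i in range(n)]


def sliding_window_predict(image: np.ndarray,
                           predict_fn: Callable[[np.ndarray], np.ndarray],
                           tile: int, num_classes: int) -> np.ndarray:
    # image: (c, H, W). predict_fn maps (1, c, tile, tile) -> (1, num_classes, tile, tile).
    _, h, w = image.shape
    # Zero-pad at the bottom/right so both spatial dims are at least one tile
    padded = np.pad(image, ((0, 0), (0, max(tile - h, 0)), (0, max(tile - w, 0))))
    ph, pw = padded.shape[1:]
    logits = np.zeros((num_classes, ph, pw), dtype=np.float32)
    counts = np.zeros((ph, pw), dtype=np.float32)
    for y, x in itertools.product(tile_starts(ph, tile), tile_starts(pw, tile)):
        patch = np.ascontiguousarray(padded[None, :, y:y + tile, x:x + tile], dtype=np.float32)
        logits[:, y:y + tile, x:x + tile] += predict_fn(patch)[0]
        counts[y:y + tile, x:x + tile] += 1.0
    logits /= counts  # every pixel is covered by at least one tile
    return logits[:, :h, :w]  # drop the padding again
```

With ONNX Runtime, predict_fn could be `lambda p: session.run(None, {input_name: p})[0]`, and a label map then follows from `logits.argmax(0)`.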

smercav, Mar 21 '25 16:03