ONNX export
Do you support ONNX export, and is it tested? Any examples would be appreciated.
I have the same question. Could you give an example? Thank you very much.
Hi, conversion to ONNX can be done like this:
import torch
import nnunetv2
from batchgenerators.utilities.file_and_folder_operations import join, load_json
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class


def convert_to_ONNX(
    model_dir: str,
    onnx_model_path: str,
    batch_size: int = 1,
    tile_size: int = 512,
    fold: int = 1,
    num_channels: int = 3,
):
    dataset_json = load_json(join(model_dir, 'dataset.json'))
    plans = load_json(join(model_dir, 'plans.json'))
    plans_manager = PlansManager(plans)

    # load the checkpoint(s)
    parameters = []
    use_folds = [fold]
    for i, f in enumerate(use_folds):
        f = int(f) if f != 'all' else f
        checkpoint = torch.load(join(model_dir, f'fold_{f}', 'checkpoint_best.pth'),
                                map_location=torch.device('cpu'))
        if i == 0:
            trainer_name = checkpoint['trainer_name']
            configuration_name = checkpoint['init_args']['configuration']
            inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] \
                if 'inference_allowed_mirroring_axes' in checkpoint.keys() else None
        parameters.append(checkpoint['network_weights'])

    configuration_manager = plans_manager.get_configuration(configuration_name)

    # restore network
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
    trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                trainer_name, 'nnunetv2.training.nnUNetTrainer')
    model = trainer_class.build_network_architecture(
        configuration_manager.network_arch_class_name,
        configuration_manager.network_arch_init_kwargs,
        configuration_manager.network_arch_init_kwargs_req_import,
        num_input_channels,
        plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
        enable_deep_supervision=False
    )
    for params in parameters:
        model.load_state_dict(params)
    model.eval()

    # convert to onnx; num_channels should match the dataset's number of input channels
    dummy = torch.randn(
        batch_size, num_channels, tile_size, tile_size, requires_grad=True
    )
    torch.onnx.export(
        model,
        dummy,
        onnx_model_path,
        verbose=False,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
    )
Hope it helps
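As a quick sanity check (not part of the function above, and the path here is a placeholder), the exported file can be loaded and run with onnxruntime on a dummy tile of the same shape that was used at export time:

import numpy as np
import onnx
import onnxruntime as ort

onnx_model_path = "/path/to/model.onnx"  # placeholder path

# structural check of the exported graph
onnx.checker.check_model(onnx.load(onnx_model_path))

# run a dummy tile; the spatial size is fixed to (tile_size, tile_size) at export time
session = ort.InferenceSession(onnx_model_path)
input_name = session.get_inputs()[0].name
dummy = np.random.randn(1, 3, 512, 512).astype(np.float32)  # (batch_size, num_channels, tile_size, tile_size)
logits = session.run(None, {input_name: dummy})[0]
print(logits.shape)  # (1, num_segmentation_heads, 512, 512)

Note that torch.onnx.export also accepts a dynamic_axes argument if you want the batch or spatial dimensions to stay symbolic, although the spatial sizes still have to be compatible with the network's downsampling steps.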
I haven't tested @rubencardenes' code yet, but I wanted to add that I would also be interested in an ONNX export, as I'm unable to install and use nnUNet due to #2430.
@rubencardenes wrote:
Hi, conversion to ONNX can be done like this: def convert_to_ONNX( [...] Hope it helps
Hey, I was able to make this work and get an ONNX model file. However, I'm still new to ONNX and was wondering if you could assist with the inference side.
The ONNX model doesn't preprocess the images for inference the same way nnUNet does, or I set it up wrong.
ONNX Model Creation
convert_to_ONNX(
    model_dir=join(
        nnUNet_results, "Dataset500_M662_nnUNet/nnUNetTrainer__nnUNetResEncUNetMPlans__2d/"
    ),
    onnx_model_path=join(nnUNet_results, "Dataset500_M662_nnUNet/infer2donnx/m662_2d_model.onnx"),
    batch_size=1,
    tile_size=512,
    fold="all",
    num_channels=1,
)
ONNX Inference Code
# Used nnUNet's image reader for simplicity in reading the images
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
import onnxruntime as ort

img, props = SimpleITKIO().read_images(
    [join(nnUNet_raw, "Dataset500_M662_nnUNet/imagesTs/180_0000.png")]
)

model_path = "/path/to/model.onnx"
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name

predictions = session.run(None, {input_name: img})
print(f"Inference Predictions:\n{predictions}")
This fails at session.run due to a mismatch between the image size and the model's expected input size. The model is fixed at 512 (I just copied what you had), but my images are larger than that. I assumed the ONNX model would handle this.
Error
---------------------------------------------------------------------------
InvalidArgument Traceback (most recent call last)
Cell In[38], line 10
7 session = ort.InferenceSession(model_path)
8 input_name = session.get_inputs()[0].name
---> 10 predictions = session.run(None, {input_name: img})
12 print(f"Inference Predictions:\n: {predictions}")
File ~/work/caml2/caml/.venv/lib64/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:270, in Session.run(self, output_names, input_feed, run_options)
268 output_names = [output.name for output in self._outputs_meta]
269 try:
--> 270 return self._sess.run(output_names, input_feed, run_options)
271 except C.EPFail as err:
272 if self._enable_fallback:
InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: input for the following indices
index: 2 Got: 1960 Expected: 512
index: 3 Got: 648 Expected: 512
Please fix either the inputs/outputs or the model.
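(For reference, a small sketch of how the expected input shape can be read back from the exported graph, which makes the mismatch visible before calling run; the model path is a placeholder:)

import onnxruntime as ort

# the expected input shape is baked into the exported graph and can be inspected up front
session = ort.InferenceSession("/path/to/model.onnx")  # placeholder path
inp = session.get_inputs()[0]
print(inp.name, inp.shape, inp.type)
# e.g. input [1, 1, 512, 512] tensor(float) -> the spatial dims were fixed at export time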
Are you able to point me in the right direction to resolve this? I was hoping it would bundle preprocessing into the ONNX model.
Hello,
I used @rubencardenes' convert_to_ONNX method to export an nnU-Net model to ONNX, but I had to preprocess the image and split it into patches of the size accepted by the model beforehand. I hope it helps!
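A rough sketch of what that can look like for a single-channel 2D model exported with tile_size=512; the helper name is made up, and the per-image z-scoring below is only an approximation of the preprocessing nnUNet actually configures for a given dataset:

import numpy as np
import onnxruntime as ort

def onnx_tiled_inference(img: np.ndarray, onnx_model_path: str, tile: int = 512) -> np.ndarray:
    # hypothetical helper; img is (c, 1, H, W) as returned by SimpleITKIO().read_images for a 2D png
    data = img[:, 0].astype(np.float32)                # -> (c, H, W)
    data = (data - data.mean()) / (data.std() + 1e-8)  # crude z-scoring, NOT nnUNet's exact preprocessing
    c, h, w = data.shape
    pad_h, pad_w = -h % tile, -w % tile                # pad up to a multiple of the tile size
    data = np.pad(data, ((0, 0), (0, pad_h), (0, pad_w)))

    session = ort.InferenceSession(onnx_model_path)
    input_name = session.get_inputs()[0].name

    seg = np.zeros(data.shape[1:], dtype=np.int64)
    for y in range(0, data.shape[1], tile):
        for x in range(0, data.shape[2], tile):
            patch = data[None, :, y:y + tile, x:x + tile].astype(np.float32)  # (1, c, tile, tile)
            logits = session.run(None, {input_name: patch})[0]                # (1, num_classes, tile, tile)
            seg[y:y + tile, x:x + tile] = logits.argmax(1)[0]
    return seg[:h, :w]  # crop the padding away

nnUNet's own predictor additionally uses overlapping windows with Gaussian weighting and optional mirroring, so results from this simple non-overlapping tiling will not match it exactly.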