mlx-examples
Can I convert nnU-Net pytorch model to MLX?
Hi all!
Is it possible to convert a medical image segmentation model trained with the nnU-Net framework and stored as a .pth file into an MLX-compatible format?
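Part of any conversion is moving the weights out of the .pth checkpoint and into MLX's layout. PyTorch stores Conv3d weights as (C_out, C_in, kD, kH, kW), while MLX's convolutions are channels-last and expect (C_out, kD, kH, kW, C_in). Below is a minimal NumPy-only sketch of that permutation. It is not an official conversion tool. The function name `convert_conv3d_weights` and the `skip_substrings` default `"transpconvs"` (meant to match nnU-Net's transposed-conv decoder blocks) are hypothetical. ConvTranspose3d weights use a different PyTorch layout, (C_in, C_out, kD, kH, kW), so they are deliberately left untouched and reported back rather than guessed at.

```python
import numpy as np


def convert_conv3d_weights(state_dict, skip_substrings=("transpconvs",)):
    """Permute PyTorch Conv3d weights into MLX's channels-last layout.

    Sketch under assumptions: every 5-D tensor whose name ends in "weight"
    is treated as a Conv3d kernel; keys containing any of `skip_substrings`
    (hypothetical default for nnU-Net's transposed convs) are copied as-is
    because ConvTranspose3d uses a different PyTorch layout.
    Returns (converted_dict, list_of_skipped_keys).
    """
    converted = {}
    skipped = []
    for name, array in state_dict.items():
        array = np.asarray(array)
        if any(s in name for s in skip_substrings):
            # Leave transposed-conv weights for separate, verified handling.
            skipped.append(name)
            converted[name] = array
        elif array.ndim == 5 and name.endswith("weight"):
            # (C_out, C_in, kD, kH, kW) -> (C_out, kD, kH, kW, C_in)
            converted[name] = np.ascontiguousarray(array.transpose(0, 2, 3, 4, 1))
        else:
            # Biases, norm parameters, etc. keep their shape.
            converted[name] = array
    return converted, skipped
```

In practice the input dict would come from something like `{k: v.numpy() for k, v in torch.load(path, map_location="cpu")["network_weights"].items()}`, where the `"network_weights"` key is an assumption about nnU-Net v2 checkpoints worth checking against your own file. The result can then be written with `np.savez("weights.npz", **converted)` and read back in MLX with `mx.load`. Re-creating the network architecture in MLX so that its parameter names match these keys is a separate task.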
Thanks!
My guess is that some of the ops a U-Net needs are not well supported yet in MLX (pooling, upsampling, transposed convolution). We will add those ops in due time, but some of them require new backend kernels to be reasonably efficient, so it's not a very quick fix.
Just a follow-up: when I try to run the trained nnU-Net model directly in PyTorch with torch.device("mps"), I get "ConvTranspose 3D is not supported on MPS":
  File "/Users/valosek/miniconda3/envs/nnunet/lib/python3.9/site-packages/dynamic_network_architectures/building_blocks/unet_decoder.py", line 109, in forward
    x = self.transpconvs[s](lres_input)
  File "/Users/valosek/miniconda3/envs/nnunet/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/valosek/miniconda3/envs/nnunet/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/valosek/miniconda3/envs/nnunet/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 1104, in forward
    return F.conv_transpose3d(
RuntimeError: ConvTranspose 3D is not supported on MPS
Related to https://github.com/pytorch/pytorch/issues/130256
OK, well, that's a PyTorch issue. MLX has a 3D convolution, so that should work.
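The failing op in the traceback is ConvTranspose3d in the decoder's `transpconvs`. Suppose those layers use kernel size equal to stride, with no padding, output padding, dilation or groups. That is an assumption about the nnU-Net plan, not something stated in this thread. In that case the transposed convolution needs no dedicated kernel. Each input voxel writes an independent kD x kH x kW block of outputs, so the layer is one matrix product over channels followed by a reshape. Below is a NumPy sketch of that equivalence. The function name is hypothetical, the input is channels-last to match MLX's convention, and the weight is kept in PyTorch's (C_in, C_out, kD, kH, kW) ConvTranspose3d layout.

```python
import numpy as np


def conv_transpose3d_stride_eq_kernel(x, weight, bias=None):
    """Transposed 3D conv for the special case kernel_size == stride.

    Assumes padding == 0, output_padding == 0, dilation == 1, groups == 1.
    x:      (N, D, H, W, C_in)        channels-last input
    weight: (C_in, C_out, kD, kH, kW) PyTorch ConvTranspose3d layout
    bias:   (C_out,) or None
    returns (N, D*kD, H*kH, W*kW, C_out)
    """
    n, d, h, w, _ = x.shape
    _, c_out, kd, kh, kw = weight.shape
    # Contract over input channels; every voxel yields its own kd*kh*kw block.
    # Axis order (n, d, i, h, j, w, k, o) lets a plain reshape interleave
    # input positions with kernel offsets: out index = d*kd + i, etc.
    y = np.einsum("ndhwc,coijk->ndihjwko", x, weight)
    y = y.reshape(n, d * kd, h * kh, w * kw, c_out)
    if bias is not None:
        y = y + bias  # broadcasts over the trailing channel axis
    return y
```

Because the blocks never overlap when kernel size equals stride, no scatter-add is needed. The same computation could in principle be written with MLX's `matmul` and `reshape`, but that port is not shown here. If the plan uses a kernel larger than the stride, this shortcut does not apply.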