
Can I convert nnU-Net pytorch model to MLX?

Open valosekj opened this issue 1 year ago • 3 comments

Hi all! Is it possible to convert a medical image segmentation model trained using the nnU-Net framework and stored as a .pth file into an MLX compatible format? Thanks!

valosekj avatar Jan 24 '24 16:01 valosekj

My guess is that some of the ops a U-Net needs are not well supported yet in MLX (pooling / upsampling / transpose conv). We will add those ops in due time, but some of them require new backend kernels to be reasonably efficient, so it's not a very quick fix.

awni avatar Jan 24 '24 17:01 awni

Just a follow-up: when trying to run the trained nnU-Net model directly in PyTorch with torch.device("mps"), I get the error "ConvTranspose 3D is not supported on MPS":

  File "/Users/valosek/miniconda3/envs/nnunet/lib/python3.9/site-packages/dynamic_network_architectures/building_blocks/unet_decoder.py", line 109, in forward
    x = self.transpconvs[s](lres_input)
  File "/Users/valosek/miniconda3/envs/nnunet/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/valosek/miniconda3/envs/nnunet/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/valosek/miniconda3/envs/nnunet/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 1104, in forward
    return F.conv_transpose3d(
RuntimeError: ConvTranspose 3D is not supported on MPS

Related to https://github.com/pytorch/pytorch/issues/130256

valosekj avatar Jul 08 '24 17:07 valosekj

OK, well, that's a PyTorch issue. In MLX we have a 3D convolution, so that should work.

awni avatar Jul 08 '24 17:07 awni
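
For anyone attempting the conversion, one step that usually comes up is the layout change for 3D convolution weights. The sketch below is only an illustration and is not taken from this thread: it assumes PyTorch's Conv3d weight layout (out, in, kD, kH, kW) and a channels-last target layout (out, kD, kH, kW, in), which is what MLX's convolution layers are documented to use. The helper names are hypothetical, and NumPy arrays stand in for the tensors you would get from the .pth state dict. Transposed convolutions are not covered here.

```python
import numpy as np


def conv3d_weight_to_channels_last(weight: np.ndarray) -> np.ndarray:
    """Reorder a Conv3d weight from (out, in, kD, kH, kW) to (out, kD, kH, kW, in).

    Hypothetical helper: the target layout is an assumption about the
    channels-last convention, not something stated in this thread.
    """
    if weight.ndim != 5:
        raise ValueError(f"expected a 5-D Conv3d weight, got shape {weight.shape}")
    # Move the input-channel axis (index 1) to the end.
    return np.ascontiguousarray(weight.transpose(0, 2, 3, 4, 1))


def volume_to_channels_last(x: np.ndarray) -> np.ndarray:
    """Reorder an input volume from (N, C, D, H, W) to (N, D, H, W, C)."""
    if x.ndim != 5:
        raise ValueError(f"expected a 5-D input volume, got shape {x.shape}")
    return np.ascontiguousarray(x.transpose(0, 2, 3, 4, 1))


# Stand-in for one entry of a PyTorch state dict: 8 output channels,
# 4 input channels, 3x3x3 kernel.
torch_style_weight = np.arange(8 * 4 * 3 * 3 * 3, dtype=np.float32).reshape(8, 4, 3, 3, 3)
channels_last_weight = conv3d_weight_to_channels_last(torch_style_weight)
print(channels_last_weight.shape)
```

In a real conversion, each Conv3d weight tensor from the checkpoint would be turned into a NumPy array, passed through a helper like this, and then wrapped in an MLX array; biases are 1-D and need no reordering. Whether the remaining nnU-Net layers have MLX equivalents would still need to be checked layer by layer.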