TotalSegmentator
Model weights on Hugging Face
Hi! This is really amazing work! It would be awesome to have the model weights shared on Hugging Face. Here's some more information on how to do so: https://huggingface.co/docs/hub/models-uploading Happy to help with any questions!
Thanks for the suggestion. Huggingface is doing a great job with hosting machine learning models. However, I think this model is a bit different from your typical pytorch model. Just hosting the weights on huggingface will not really help any of our users.
Thanks for your response! Curious to hear why this model is different from your typical pytorch model? I get asked quite frequently about medical image segmentation models, and I think even having this model searchable within the hub would be useful for others in the ML community who might not yet be aware of this awesome work. If it's okay with you, I'd be happy to help put the weights for the TotalSegmentator models on HF (of course, giving you full and proper attribution, linking back to this repository, including the appropriate license, etc). Let me know what you think.
@katielink To my knowledge, segmenting all organs requires running inference with several models; in addition, some post-processing is involved (see https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/nnunet.py#L331 ).
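To illustrate why several models plus post-processing do not map onto a single weights file, here is a generic sketch of merging per-model label maps into one multi-label volume. It is not taken from `nnunet.py`: the function name `merge_part_predictions`, the label IDs, and the "first model wins on overlap" rule are all assumptions made for illustration.

```python
import numpy as np


def merge_part_predictions(parts):
    """Merge label maps from several models into one multi-label volume.

    parts: list of (label_map, local_to_global) pairs, where label_map is an
    integer array and local_to_global maps that model's label IDs to IDs in
    the combined label space (hypothetical layout).
    """
    merged = np.zeros(parts[0][0].shape, dtype=np.uint8)
    for label_map, local_to_global in parts:
        for local_id, global_id in local_to_global.items():
            # Only fill voxels that are still background, so earlier models
            # take precedence where predictions overlap.
            mask = (label_map == local_id) & (merged == 0)
            merged[mask] = global_id
    return merged
```

Each model's output has to be loaded, remapped, and combined like this before a user gets a usable segmentation, which is logic that lives in the package rather than in the weights.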
On top of that, you'll need a decent GPU, or else inference will take a while (I'm guessing 20 to 40 min, depending on your CPU spec). So if people go the Hugging Face inference endpoint route, they will need to up the instance spec 💸. Most people/researchers using this have a tight budget and/or likely already have on-prem GPUs/HPC.
That said, I gave it a quick shot (see the sample Hugging Face Space / Gradio implementation below), since I was curious whether the free hardware can run TotalSegmentator without error. I think you can convert it to a "huggingface model repo" with a custom handler. One missing piece is getting the weights cached in the deployed Docker container / model repo. Also, I had to add the "--fast" flag, or else it times out at 30 min when deployed to huggingface.co on the free instance (2 CPUs, 16 GB RAM).
--
requirements.txt:
# TODO: add version
totalsegmentator

app.py:
import os
import subprocess
import tempfile

import gradio as gr
import totalsegmentator  # unused here; presumably kept to fail early if the package is missing

# Download the pretrained weights once per process. The environment variable only
# lives as long as this process, so a restarted container downloads them again.
if os.environ.get("WEIGHTS_CACHED") != "TRUE":
    print("downloading pretrained weights...")
    subprocess.call("totalseg_download_weights -t total", shell=True)
    os.environ["WEIGHTS_CACHED"] = "TRUE"

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLE_NIFTI_PATH = os.path.join(THIS_DIR, "files", "sample-image.nii.gz")


def myfunc(file_obj):
    """Run TotalSegmentator on the uploaded file and list the output files."""
    file_list = []
    input_file_path = file_obj.name
    with tempfile.TemporaryDirectory() as tmpdir:
        output_folder_path = os.path.join(tmpdir, "segmentations")
        # Pass the arguments as a list (no shell) so paths with spaces still work.
        # --fast keeps the run under the 30 min limit on the free instance.
        cmd = ["TotalSegmentator", "-i", input_file_path, "-o", output_folder_path, "--fast"]
        subprocess.call(cmd)
        if os.path.exists(output_folder_path):
            file_list = os.listdir(output_folder_path)
    return {"status": file_list}  # TODO: maybe add papaya js element to view nifti


if __name__ == "__main__":
    demo = gr.Interface(myfunc, ["file"], "json", examples=[EXAMPLE_NIFTI_PATH], cache_examples=True)
    demo.queue().launch(debug=True, show_api=True, share=False)
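
Two of the limitations mentioned above could be handled roughly as follows: the weights being downloaded again (an environment variable does not outlive the process) and the run hitting the 30 min limit. This is only a sketch. `ensure_weights`, `run_with_timeout`, and the marker-file path are hypothetical names, not part of TotalSegmentator or Gradio, and the marker file only helps for as long as the container's filesystem persists.

```python
import subprocess
from pathlib import Path
from typing import Sequence


def ensure_weights(marker: Path, download_cmd: Sequence[str]) -> bool:
    """Run download_cmd unless marker exists; return True if it was run."""
    if marker.exists():
        return False
    subprocess.run(list(download_cmd), check=True)  # raises if the download fails
    marker.parent.mkdir(parents=True, exist_ok=True)
    marker.touch()  # written only after a successful download
    return True


def run_with_timeout(cmd: Sequence[str], timeout_s: float) -> dict:
    """Run cmd and report ok / failed / timeout instead of hanging the request."""
    try:
        proc = subprocess.run(list(cmd), timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child process before re-raising.
        return {"status": "timeout"}
    if proc.returncode != 0:
        return {"status": "failed", "returncode": proc.returncode}
    return {"status": "ok"}


# Possible use in app.py (the marker path and the 25 min budget are assumptions):
# ensure_weights(Path.home() / ".totalseg_weights_ready",
#                ["totalseg_download_weights", "-t", "total"])
# run_with_timeout(["TotalSegmentator", "-i", in_path, "-o", out_dir, "--fast"], 25 * 60)
```

Returning a status dict keeps the shape of the existing `{"status": ...}` JSON output, so the Gradio interface would not need to change.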