
Model weights on Hugging Face

Open · katielink opened this issue on Oct 19 '23 · 3 comments

Hi! This is really amazing work! It would be awesome to have the model weights shared on Hugging Face. Here's some more information on how to do so: https://huggingface.co/docs/hub/models-uploading. Happy to help with any questions!

katielink · Oct 19 '23

Thanks for the suggestion. Hugging Face is doing a great job hosting machine learning models. However, I think this model is a bit different from your typical PyTorch model: just hosting the weights on Hugging Face would not really help any of our users.

wasserth · Oct 20 '23

Thanks for your response! I'm curious to hear why this model is different from a typical PyTorch model. I get asked quite frequently about medical image segmentation models, and I think even having this model searchable within the Hub would be useful for others in the ML community who might not yet be aware of this awesome work. If it's okay with you, I'd be happy to help put the weights for the TotalSegmentator models on HF (of course with full and proper attribution, linking back to this repository, including the appropriate license, etc.). Let me know what you think.

katielink · Oct 20 '23

@katielink To my knowledge, segmenting all organs requires running inference with several models, and some post-processing is involved as well (see https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/nnunet.py#L331).

On top of that, you'll need a decent GPU, or else inference will take a while (I'm guessing 20 to 40 min, depending on your CPU spec). So if people go the Hugging Face inference endpoint route, they will need to up the instance spec 💸. Most people/researchers using this have a tight budget and/or likely already have on-prem GPUs/HPC.
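To illustrate the point: even the package's own Python API is a single call that hides the multi-model inference and postprocessing behind it, so raw weight files on the Hub wouldn't be directly usable on their own. A rough sketch, assuming the python_api module as documented in the README (double-check the signature there):

from totalsegmentator.python_api import totalsegmentator

# One call drives the whole pipeline: several nnU-Net models plus the
# postprocessing step linked above. fast=True mirrors the CLI's --fast flag.
totalsegmentator("ct.nii.gz", "segmentations/", fast=True)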

That said, I gave it a quick shot (see the sample Hugging Face Space / Gradio implementation below), since I was curious whether the free hardware could run TotalSegmentator without errors. I think you could also convert it to a "huggingface model repo" with a custom handler (rough sketch after the app code below). One missing piece is getting the weights cached in the deployed Docker container / model repo. Also, I had to add the "--fast" flag, or else it times out at 30 min when deployed to huggingface.co on the free instance (2 CPU, 16 GB RAM).

requirements.txt (#TODO: add pinned version):

totalsegmentator

app.py:

import os
import tempfile
import subprocess

import gradio as gr
import totalsegmentator  # imported only to fail fast if the package is missing

# Download the pretrained weights once per process so repeated requests
# don't re-trigger the download.
if os.environ.get("WEIGHTS_CACHED") != "TRUE":
    print("downloading pretrained weights...")
    subprocess.call("totalseg_download_weights -t total", shell=True)
    os.environ["WEIGHTS_CACHED"] = "TRUE"

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLE_NIFTI_PATH = os.path.join(THIS_DIR, "files", "sample-image.nii.gz")

def myfunc(file_obj):
    file_list = []
    input_file_path = file_obj.name
    with tempfile.TemporaryDirectory() as tmpdir:
        output_folder_path = os.path.join(tmpdir, "segmentations")
        # --fast keeps runtime under the 30 min timeout on the free instance.
        cmd_str = f"TotalSegmentator -i {input_file_path} -o {output_folder_path} --fast"
        subprocess.call(cmd_str, shell=True)
        if os.path.exists(output_folder_path):
            file_list = os.listdir(output_folder_path)
        return {"status": file_list}  # TODO: maybe add a papaya.js element to view the NIfTI

if __name__ == "__main__":
    demo = gr.Interface(myfunc, ["file"], "json", examples=[EXAMPLE_NIFTI_PATH], cache_examples=True)
    demo.queue().launch(debug=True, show_api=True, share=False)
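For reference, turning this into a model repo with a custom handler for Inference Endpoints could look roughly like the sketch below. This is untested; the base64 "inputs" payload convention and the "masks" response field are my own assumptions, not anything TotalSegmentator or the Hub defines.

# handler.py -- rough, untested sketch of a custom Inference Endpoints
# handler wrapping the TotalSegmentator CLI.
import base64
import os
import subprocess
import tempfile
from typing import Any, Dict

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Download the weights once when the endpoint boots, so the
        # first request doesn't pay the download cost.
        subprocess.call("totalseg_download_weights -t total", shell=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Assumption: the client sends a base64-encoded NIfTI in data["inputs"].
        nifti_bytes = base64.b64decode(data["inputs"])
        with tempfile.TemporaryDirectory() as tmpdir:
            input_path = os.path.join(tmpdir, "image.nii.gz")
            output_dir = os.path.join(tmpdir, "segmentations")
            with open(input_path, "wb") as f:
                f.write(nifti_bytes)
            subprocess.call(
                f"TotalSegmentator -i {input_path} -o {output_dir} --fast",
                shell=True,
            )
            masks = os.listdir(output_dir) if os.path.exists(output_dir) else []
        return {"masks": masks}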

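And since show_api=True, the deployed Space can be called programmatically too, e.g. with gradio_client. A sketch with a made-up Space name (note that newer gradio_client versions want files wrapped in handle_file()):

from gradio_client import Client

client = Client("user/totalsegmentator-demo")  # hypothetical Space name
# Upload a local NIfTI; the Space returns the {"status": [...]} JSON above.
result = client.predict("sample-image.nii.gz", api_name="/predict")
print(result)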

pangyuteng · Dec 07 '23