
integrate SAM (segment anything) encoder with Unet

Rusteam opened this issue 2 years ago • 68 comments

Closes #756

Added:

  • SAM to models
  • 3 SAM backbones (vit_h, vit_b and vit_l) to encoders
  • unittests and docs for SAM

Changed:

  • flake8 pre-commit hook: repo updated to the current GitHub source and pinned version bumped to the latest

Rusteam avatar May 03 '23 11:05 Rusteam

hi @qubvel is there any update on this? I've just trained a model using this branch and it worked.

Rusteam avatar May 05 '23 05:05 Rusteam

@Rusteam is the code merged into the main repo? I want to use this model to fine-tune on my data.

It's not. Not sure if @qubvel has had a chance to look into this PR. You could use my fork in the meantime. And do let me know how your fine-tuning goes, because I haven't had much success so far.

Rusteam avatar May 14 '23 06:05 Rusteam

@Rusteam how do I train a model? Can you give an outline? The author is not responding, so please help me train a model. I have sent you an email, please take a look.

Make sure you install this package from my fork with pip install git+https://github.com/Rusteam/segmentation_models.pytorch.git@sam, then initialize your model as usual with create_model("SAM", "sam-vit_b", encoder_weights=None, **kwargs) and run your training. You can pass weights="sa-1b" in kwargs if you want to fine-tune from pre-trained weights.

So far I have been able to train the model, but I can't say it's learning; I'm still struggling there. Also, I cannot fit more than 1 sample per batch on a 32 GB GPU with a 512×512 input size.

Rusteam avatar May 15 '23 05:05 Rusteam
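For reference, a minimal sketch of the workflow described above, assuming the fork is installed; the encoder name "sam-vit_b", the weights identifier "sa-1b", and the image_size kwarg follow the fork's naming in this thread and may change before merge:

```python
# pip install git+https://github.com/Rusteam/segmentation_models.pytorch.git@sam
import torch
import segmentation_models_pytorch as smp

# Build the SAM model from the fork: weights="sa-1b" fine-tunes from the pre-trained
# checkpoint, weights=None trains from scratch. image_size is a fork-specific kwarg.
model = smp.create_model(
    "SAM", "sam-vit_b", encoder_weights=None, weights="sa-1b", image_size=512, classes=1
)

# One illustrative training step (batch size 1, as noted above for a 32 GB GPU).
images = torch.randn(1, 3, 512, 512)
masks = torch.randint(0, 2, (1, 1, 512, 512)).float()
loss_fn = smp.losses.DiceLoss(mode="binary")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

optimizer.zero_grad()
loss = loss_fn(model(images), masks)
loss.backward()
optimizer.step()
```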

@Rusteam how about this: https://github.com/tianrun-chen/SAM-Adapter-PyTorch

ccl-private avatar May 16 '23 07:05 ccl-private

Thanks for sharing, I'll try it if my current approach does not work. I've been able to get some learning with this transformers notebook.

Rusteam avatar May 16 '23 07:05 Rusteam

Hi @Rusteam, thanks a lot for your contribution and sorry for the delay, I am going to review the request and will let you know

qubvel avatar May 17 '23 14:05 qubvel

Hey hey hey. While this solution worked, I can't say the model was able to learn on my data. We might need to use the version before my DDP adjustments, make the model handle points and boxes as inputs, or use the SAM image encoder with Unet or other architectures.

Rusteam avatar May 17 '23 14:05 Rusteam

Yes, I was actually thinking about just pre-trained encoder integration, did you test it?

qubvel avatar May 17 '23 15:05 qubvel

can we use this model to train on custom data??

@qubvel It doesn't work with Unet yet, but I can make it work. Which models would be essential to integrate?

Rusteam avatar May 18 '23 05:05 Rusteam

@Rusteam @qubvel can we use this model to train on custom data??

That was my intention as well, but I was unable to make it learn without passing box/point prompts. However, when passing a prompt along with the input image, it does learn. We might need to integrate multiple inputs into the forward() call for it to work, or just use SAM's image encoder with other architectures like Unet.

Rusteam avatar May 18 '23 05:05 Rusteam
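For what it's worth, a purely hypothetical sketch of what a prompt-aware forward() could look like if smp.SAM accepted box/point prompts alongside the image; the module names and call signatures are illustrative and simplified relative to the real segment-anything interfaces:

```python
import torch.nn as nn

class PromptedSAM(nn.Module):
    """Illustrative wrapper: image plus optional prompts in, masks out."""

    def __init__(self, image_encoder: nn.Module, prompt_encoder: nn.Module, mask_decoder: nn.Module):
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder

    def forward(self, image, points=None, boxes=None):
        # Encode the image once, then condition the decoder on whatever prompts were given.
        image_embeddings = self.image_encoder(image)
        sparse_embeddings, dense_embeddings = self.prompt_encoder(points=points, boxes=boxes, masks=None)
        return self.mask_decoder(image_embeddings, sparse_embeddings, dense_embeddings)
```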

It would be nice to start with integrating just the encoder first, so it could work with Unet and other decoders

qubvel avatar May 19 '23 07:05 qubvel

It would be nice to start with integrating just the encoder first, so it could work with Unet and other decoders

Actually just did that and I'm running some experiments at the moment. I'm also glad to see that it's able to learn. I'm attaching a screenshot of my current learning curves.

Rusteam avatar May 19 '23 10:05 Rusteam

Do you want to use the SAM image encoder with a Unet decoder, or do you want to fine-tune the full SAM model? If the former, it's working now; if the latter, then we have to figure out the best strategy for implementing a forward function with multiple inputs.

Rusteam avatar May 19 '23 13:05 Rusteam

Following this closely, can't wait to see it working with Unet and for @qubvel to push this through to the main repo :) Thanks for all the efforts, guys.

chefkrym avatar May 19 '23 16:05 chefkrym

Squash merge. hahaha

lixiang007666 avatar May 20 '23 03:05 lixiang007666

@qubvel can you merge this into the main repo? @Rusteam are you making any extra changes?

sushmanthreddy avatar May 21 '23 17:05 sushmanthreddy

@sushmanthreddy just use my fork in the meantime: pip install git+https://github.com/Rusteam/segmentation_models.pytorch.git@sam

Rusteam avatar May 22 '23 06:05 Rusteam

@Rusteam Could you please remove the decoder part, because it is not stable yet? Or create another PR with just the encoder, so I am able to merge it?

qubvel avatar May 24 '23 08:05 qubvel

@Rusteam Do you have any reference for learning curves with another backbone on your task, e.g. resnet18/34 for comparison?

qubvel avatar May 24 '23 08:05 qubvel

@Rusteam Do you have any reference for learning curves with another backbone on your task, e.g. resnet18/34 for comparison?

This one includes the mit_b1 backbone with Unet (screenshot of learning curves attached).

Rusteam avatar May 24 '23 09:05 Rusteam

@Rusteam Do you have any reference for learning curves with another backbone on your task, e.g. resnet18/34 for comparison?

ok, I will do it shortly

Rusteam avatar May 24 '23 09:05 Rusteam

@qubvel removed the SAM decoder. I'll see if I get time to finish its implementation in the near future and add it in a new PR if I do.

Rusteam avatar May 25 '23 06:05 Rusteam

Thanks, @Rusteam, I added the review, could you please go through it?

qubvel avatar May 25 '23 16:05 qubvel

Btw, how many output tensors does the encoder have? Is it more than the usual 5? Is it 12 outputs?

qubvel avatar May 26 '23 13:05 qubvel

Btw, how many output tensors does the encoder have? Is it more than the usual 5? Is it 12 outputs?

It's actually just one output tensor at the moment. Implementing skip connections would be a nice add-on.

Rusteam avatar May 26 '23 14:05 Rusteam

That might well be a cause of the poor performance with Unet-like architectures.

qubvel avatar May 26 '23 14:05 qubvel
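One possible direction, sketched under the assumption that the ViT exposes patch_embed and blocks attributes (as SAM's image encoder and timm ViTs do): tap intermediate transformer blocks to get several feature maps instead of one. All taps share the same stride, so a Unet decoder would still need them projected or resized to different scales; this is illustrative only, not part of the PR.

```python
import torch.nn as nn

class ViTWithTaps(nn.Module):
    """Return features from several intermediate ViT blocks instead of only the last one."""

    def __init__(self, vit: nn.Module, tap_indices=(2, 5, 8, 11)):
        super().__init__()
        self.vit = vit
        self.tap_indices = set(tap_indices)

    def forward(self, x):
        # Positional embeddings, windowed attention, and the SAM "neck" are omitted for brevity.
        h = self.vit.patch_embed(x)
        taps = []
        for i, block in enumerate(self.vit.blocks):
            h = block(h)
            if i in self.tap_indices:
                taps.append(h)
        return taps
```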

What is the dimension of this tensor compared to the original size, H//32 × W//32? Am I correct that encoder.forward(x) now produces just two tensors, [x, vit_output]?

qubvel avatar May 26 '23 14:05 qubvel
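A quick way to answer both questions empirically, assuming the fork is installed; "sam-vit_b" exists only in this PR's branch and the printed shapes are not verified here:

```python
import torch
import segmentation_models_pytorch as smp

encoder = smp.encoders.get_encoder("sam-vit_b", weights=None)
x = torch.randn(1, 3, 1024, 1024)  # 1024 is SAM's usual input resolution
features = encoder(x)

print(len(features))  # expected 2 if the encoder returns [x, vit_output]
for f in features:
    print(tuple(f.shape))  # reveals the spatial reduction relative to the input
```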

@Rusteam I have been using your repo for segmentation. The branch I was using is sam-unet. I don't know what was going wrong...

Here is the link to the notebook: https://github.com/sushmanthreddy/GSOC_2023/blob/main/notebooks/cell-nucleus-segmentor.ipynb. Please feel free to comment on where I was going wrong.

sushmanthreddy avatar May 27 '23 18:05 sushmanthreddy

@Rusteam model = smp.SAM(encoder_name="sam-vit_b", encoder_weights="sa-1b", weights=None, image_size=64, decoder_multimask_output=decoder_multiclass_output, classes=n_classes)

What should we pass for the weights argument?

sushmanthreddy avatar May 27 '23 18:05 sushmanthreddy

@Rusteam model = smp.SAM(encoder_name="sam-vit_b", encoder_weights="sa-1b", weights=None, image_size=64, decoder_multimask_output=decoder_multiclass_output, classes=n_classes)

What should we pass for the weights argument?

We've decided to remove smp.SAM for the moment and only keep the SAM image encoder with Unet. Refrain from using it, or use it at your own risk. It was supposed to be smp.SAM(..., encoder_weights=None, weights="sa-1b").

Rusteam avatar May 29 '23 06:05 Rusteam
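A minimal sketch of the path that was kept, i.e. the SAM image encoder as a Unet backbone; the encoder name "sam-vit_b" and weights id "sa-1b" follow the fork's naming and remain assumptions until the PR is merged:

```python
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="sam-vit_b",
    encoder_weights="sa-1b",  # or None to train the encoder from scratch
    in_channels=3,
    classes=2,
)
```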