Integrate SAM (Segment Anything) encoder with Unet
Closes #756
Added:
- SAM to models
- 3 SAM backbones (vit_h, vit_b and vit_l) to encoders
- unit tests and docs for SAM
Changed:
- flake8 pre-commit repo to GitHub (current) and version to latest
hi @qubvel is there any update on this? I've just trained a model using this branch and it worked.
@Rusteam is the code merged into the main repo? I want to use this model to fine-tune on my data.
It's not. Not sure if @qubvel has had a chance to look into this PR. You could use my fork in the meanwhile. And do let me know how your fine-tuning goes, because I haven't had much success so far.
@Rusteam how do I train a model, can you give some outline? As the author is not responding, please help me train a model. I have sent you an email, please take a look.
make sure you install this package from my fork: pip install git+https://github.com/Rusteam/segmentation_models.pytorch.git@sam and then initialize your model as usual: create_model("SAM", "sam-vit_b", encoder_weights=None, **kwargs) and run your training. You can pass weights="sa-1b" in kwargs if you want to fine-tune from pre-trained weights.
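For anyone new to this: once the model is built, fine-tuning is just a standard supervised segmentation loop. Here is a minimal sketch; the Conv2d below is only a stand-in so the snippet runs without the fork installed (the dummy 64x64 batch and the hyperparameters are placeholders too), and you would swap in the create_model(...) call above for the real model:

```python
import torch

# Stand-in model; with the fork installed you would instead do:
#   from segmentation_models_pytorch import create_model
#   model = create_model("SAM", "sam-vit_b", encoder_weights=None, weights="sa-1b")
model = torch.nn.Conv2d(3, 1, kernel_size=1)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.BCEWithLogitsLoss()

images = torch.randn(2, 3, 64, 64)                   # dummy batch
masks = torch.randint(0, 2, (2, 1, 64, 64)).float()  # dummy binary masks

for step in range(3):
    optimizer.zero_grad()
    logits = model(images)         # (B, classes, H, W) logits
    loss = loss_fn(logits, masks)
    loss.backward()
    optimizer.step()
```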
So far I have been able to train the model, but I can't say it's learning. I'm still struggling there. Also, I cannot fit more than 1 sample per batch on a 32 GB GPU with a 512 input size.
@Rusteam how about this: https://github.com/tianrun-chen/SAM-Adapter-PyTorch
thanks for sharing, I'll try it if my current approach does not work. I've been able to get some learning with this transformers notebook.
Hi @Rusteam, thanks a lot for your contribution and sorry for the delay, I am going to review the request and will let you know
Hey hey hey. While this solution worked, I can't say the model was able to learn on my data. We might need to use the version before my DDP adjustments, make the model handle points and boxes as inputs, or use the SAM image encoder with Unet or other architectures.
Yes, I was actually thinking about just pre-trained encoder integration, did you test it?
can we use this model to train on custom data??
@qubvel It didn't work with Unet yet, but I can make it work. Which models would be essential to integrate?
that was my intention as well, but I was unable to make it learn without passing box/point prompts. However, when passing a prompt along with the input image, it does learn. We might need to integrate multiple inputs into the forward() call for it to work, or just use SAM's image encoder with other architectures like Unet.
It would be nice to start with integrating just the encoder first, so it could work with Unet and other decoders
Actually just did that and I'm running some experiments at the moment. I'm also glad to see that it's able to learn. I'm attaching a screenshot of my current learning curves:
Do you want to use the SAM image encoder with a Unet decoder, or do you want to fine-tune the full SAM model? If the former, it's working now; if the latter, then we have to figure out the best strategy to implement a forward function with multiple inputs.
Following this closely, can't wait to see it working with UNet and for @qubvel to push this through to the main repo :) Thanks for all the efforts guys.
Squash merge. hahaha
@qubvel can you merge this into the main repo? @Rusteam are you making any extra changes?
@sushmanthreddy just use my fork in the meantime: pip install git+https://github.com/Rusteam/segmentation_models.pytorch.git@sam
@Rusteam Could you please remove the decoder part because it is not stable yet? Or create another PR with just the encoder, so I am able to merge it?
@Rusteam Do you have any reference for learning curves with another backbone on your task, e.g. resnet18/34 for comparison?
This one includes mit_b1 backbone with Unet
ok, I will do it shortly
@qubvel removed the SAM decoder. I'll see if I get time to finish its implementation in the near future and add it in a new PR if I do.
Thanks, @Rusteam, I added the review, could you please go through it?
Btw, how many output tensors does the encoder have? Is it more than the usual 5? Is it 12 outputs?
It's actually just one output tensor at the moment. The implementation of skip connections can be a nice add-on.
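To sketch what such an add-on could look like (this is not what the PR implements): a plain ViT emits a single stride-16 feature map, while a Unet-style decoder expects a multi-scale pyramid. One crude stopgap is to resample that single map to the strides the decoder expects, e.g.:

```python
import torch
import torch.nn.functional as F

def fake_pyramid(feat, input_stride=16, num_stages=5):
    """Resample a single ViT feature map to Unet-style strides 2..32.

    A crude stand-in for real skip connections: every level carries the
    same (interpolated) information, just at different resolutions.
    """
    _, _, h, w = feat.shape  # feat assumed at `input_stride`, e.g. (B, 256, 64, 64)
    full_h, full_w = h * input_stride, w * input_stride
    pyramid = []
    for i in range(num_stages):
        stride = 2 ** (i + 1)  # target strides: 2, 4, 8, 16, 32
        pyramid.append(
            F.interpolate(feat, size=(full_h // stride, full_w // stride),
                          mode="bilinear", align_corners=False)
        )
    return pyramid

# stride-16 feature map of a 1024px input -> five maps at 512..32 px
feats = fake_pyramid(torch.randn(1, 256, 64, 64))
```

Real skip connections would instead tap intermediate transformer blocks, but this shows the shape contract a Unet decoder needs.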
That might be a cause of the poor performance for Unet-like architectures.
What is the dimension of this tensor compared to the original size, H//32 x W//32?
Am I correct that encoder.forward(x) now produces just two tensors, [x, vit_output]?
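For reference on the spatial-size question: the standard SAM image encoder is a ViT with 16x16 patches, so its single feature map sits at H//16 x W//16 (the default 1024x1024 input gives a 64x64 embedding), not H//32. A quick sanity check, assuming that patch size:

```python
def sam_encoder_grid(h, w, patch_size=16):
    """Spatial size of SAM's ViT image-encoder output: one token per patch."""
    return h // patch_size, w // patch_size

print(sam_encoder_grid(1024, 1024))  # (64, 64) -- default SAM input size
print(sam_encoder_grid(512, 512))    # (32, 32) -- the 512 input mentioned earlier
```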
@Rusteam I have been using your repo for segmentation... The branch I was using is sam-unet... I don't know what was going wrong...
here is the link to the notebook: https://github.com/sushmanthreddy/GSOC_2023/blob/main/notebooks/cell-nucleus-segmentor.ipynb Please feel free to comment on where I was going wrong.
@Rusteam
model = smp.SAM( encoder_name="sam-vit_b", encoder_weights="sa-1b", weights=None, image_size=64, decoder_multimask_output=decoder_multiclass_output, classes=n_classes, )
what we should add in the weights variable?
We've decided to remove smp.SAM for the moment and only keep the SAM image encoder with Unet. Refrain from using it, or use it at your own risk. It was supposed to be smp.SAM(..., encoder_weights=None, weights="sa-1b")