
Multiclass Segmentation using SAM

Open DhruvAwasthi opened this issue 1 year ago • 11 comments

Can the Segment Anything Model (SAM) be fine-tuned on a multi-class segmentation task? Thanks!

DhruvAwasthi avatar Jun 08 '23 23:06 DhruvAwasthi

Did you succeed?

sharonsalabiglossai avatar Jul 05 '23 15:07 sharonsalabiglossai

Did anyone succeed with multi-class?

Zahoor-Ahmad avatar Oct 12 '23 12:10 Zahoor-Ahmad

I am also interested in this.

rafaelagrc avatar Oct 23 '23 13:10 rafaelagrc

+1

agentdr1 avatar Oct 26 '23 12:10 agentdr1

Hi,

Could you clarify what you mean by multiclass? SAM just takes an image and a prompt as input, and generates a mask for it. So if you want to generate several masks, you will need to provide several prompts.
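
To illustrate that idea, here is a minimal sketch (not from this thread) using the Hugging Face SamModel and SamProcessor with the facebook/sam-vit-base checkpoint; the image path and box coordinates are made-up placeholders, with one box prompt per object you want a mask for:

import torch
from PIL import Image
from transformers import SamModel, SamProcessor

model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

raw_image = Image.open("fruits.jpg").convert("RGB")  # placeholder image
# One bounding box prompt per object, as (x_min, y_min, x_max, y_max)
input_boxes = [[[100, 100, 300, 300], [320, 120, 500, 350]]]

inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# Post-process to one binary mask per prompted box, at the original image resolution
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)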

NielsRogge avatar Oct 26 '23 15:10 NielsRogge

> Hi,
>
> Could you clarify what you mean by multiclass? SAM just takes an image and a prompt as input, and generates a mask for it. So if you want to generate several masks, you will need to provide several prompts.

Hello @NielsRogge, my question is whether it is possible to fine-tune the model to identify different entities given a single prompt.
For example, considering this image: if I select a bounding box containing just the fruit basket as the input prompt, how can I get a multi-class segmentation that identifies each of the fruits? I want to do the same for a more specific case: identifying different structures in medical images.

[image: fruit basket example]

Thank you for your attention :)

rafaelagrc avatar Oct 26 '23 16:10 rafaelagrc

Hi, I found a way to fine-tune the Segment Anything Model on a multi-class segmentation task: change num_multimask_outputs in MaskDecoder to the number of classes you want, load SAM's state dictionary, remove the keys 'mask_decoder.mask_tokens.weight', 'mask_decoder.iou_prediction_head.layers.2.weight', and 'mask_decoder.iou_prediction_head.layers.2.bias', and then load the modified state dictionary into SAM with relaxed strictness.

import torch

# f is the path to the pretrained SAM checkpoint file
state_dict = torch.load(f)

# Keys whose shapes no longer match after changing num_multimask_outputs
keys_to_remove = ['mask_decoder.mask_tokens.weight',
                  'mask_decoder.iou_prediction_head.layers.2.weight',
                  'mask_decoder.iou_prediction_head.layers.2.bias']
for key in keys_to_remove:
    state_dict.pop(key, None)

# sam is the model built with the modified MaskDecoder;
# strict=False skips the removed and newly added keys
sam.load_state_dict(state_dict, strict=False)
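
For example, a rough sketch of how the model itself could be built with the modified num_multimask_outputs using the original segment-anything package (the checkpoint path and number of classes below are placeholders):

import torch
from segment_anything import sam_model_registry
from segment_anything.modeling import MaskDecoder, TwoWayTransformer

NUM_CLASSES = 4  # placeholder: number of classes you want

# Build the ViT-B architecture without loading the checkpoint yet
sam = sam_model_registry["vit_b"](checkpoint=None)

# Replace the mask decoder with one that predicts NUM_CLASSES masks
sam.mask_decoder = MaskDecoder(
    num_multimask_outputs=NUM_CLASSES,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    transformer_dim=256,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)

# Load the pretrained weights, dropping the keys whose shapes no longer match
state_dict = torch.load("sam_vit_b_01ec64.pth")  # placeholder checkpoint path
for key in ['mask_decoder.mask_tokens.weight',
            'mask_decoder.iou_prediction_head.layers.2.weight',
            'mask_decoder.iou_prediction_head.layers.2.bias']:
    state_dict.pop(key, None)
sam.load_state_dict(state_dict, strict=False)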

TAUIL-Abd-Elilah avatar Feb 04 '24 15:02 TAUIL-Abd-Elilah

@TAUIL-Abd-Elilah how did you load the SAM model in the first place? And how do you change num_multimask_outputs? It is not explained.

jamesheatonrdm avatar Feb 16 '24 10:02 jamesheatonrdm

> Hi, I found a way to fine-tune the Segment Anything Model on a multi-class segmentation task: change num_multimask_outputs in MaskDecoder to the number of classes you want, load SAM's state dictionary, remove the keys 'mask_decoder.mask_tokens.weight', 'mask_decoder.iou_prediction_head.layers.2.weight', and 'mask_decoder.iou_prediction_head.layers.2.bias', and then load the modified state dictionary into SAM with relaxed strictness.

Hi, can you explain?

cristian-cmyk4 avatar Feb 19 '24 18:02 cristian-cmyk4

> Hi, I found a way to fine-tune the Segment Anything Model on a multi-class segmentation task: change num_multimask_outputs in MaskDecoder to the number of classes you want, load SAM's state dictionary, remove the keys 'mask_decoder.mask_tokens.weight', 'mask_decoder.iou_prediction_head.layers.2.weight', and 'mask_decoder.iou_prediction_head.layers.2.bias', and then load the modified state dictionary into SAM with relaxed strictness.

@TAUIL-Abd-Elilah Can you explain in a bit more detail how you did this?

felixvh avatar Feb 29 '24 14:02 felixvh

I solved it after some time. My approach is slightly different from @TAUIL-Abd-Elilah's.

I first load a SAM model based on my desired architecture.

from transformers import SamConfig, SamMaskDecoderConfig, SamModel, SamPromptEncoderConfig, SamVisionConfig

# Initializing the SAM vision, prompt encoder and mask decoder configurations
vision_config = SamVisionConfig()
prompt_encoder_config = SamPromptEncoderConfig()
# num_multimask_outputs=4 gives one output mask per class (here 4 classes)
mask_decoder_config = SamMaskDecoderConfig(num_multimask_outputs=4)

# Initializing a SamConfig with "facebook/sam-vit-huge" style configuration
config = SamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
emptyModel = SamModel(config)

This model, however, does not have the pretrained weights that you might need for fine-tuning. Therefore, I load the previously saved weights into my model. Before that, I need to adjust mask_tokens so that its shape matches the checkpoint (otherwise loading fails with a size mismatch), and after loading the weights I set mask_tokens back to the shape the new architecture expects.

import torch
import torch.nn as nn

# Temporarily resize mask_tokens to the checkpoint's shape (4 tokens of dim 256)
# so that loading does not fail with a size mismatch
size = (4, 256)
parameter_tensor = nn.Parameter(torch.rand(size))
emptyModel.mask_decoder.mask_tokens.weight = parameter_tensor

# Load the previously saved weights (non-matching keys are skipped)
state_dict = torch.load("model_weights.pth")
emptyModel.load_state_dict(state_dict, strict=False)

# Resize mask_tokens back to the shape required by num_multimask_outputs=4 (5 tokens)
size = (5, 256)
parameter_tensor = nn.Parameter(torch.rand(size))
emptyModel.mask_decoder.mask_tokens.weight = parameter_tensor

model = emptyModel
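
In case it is not obvious where "model_weights.pth" comes from: one possible way to create such a file (for example starting from the pretrained facebook/sam-vit-base checkpoint; the thread does not say which checkpoint was used) is simply:

import torch
from transformers import SamModel

# Save the pretrained SAM weights to disk so they can be loaded into the modified model
pretrained = SamModel.from_pretrained("facebook/sam-vit-base")
torch.save(pretrained.state_dict(), "model_weights.pth")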

Hope this helps the others here!

felixvh avatar Mar 01 '24 13:03 felixvh