SAMed
About Semantic Segmentation
Hi, thank you for your outstanding work! How did you adjust the mask decoder to achieve the final semantic segmentation? Could you please point me to the exact location in the code? It is a bit difficult to pinpoint from the source. Thanks a lot!
@zzzyzh the original SAM uses multimask outputs to describe objects at different levels of detail. This repository appears to retrain SAM's multimask outputs so that they predict semantic classes instead. No architectural modifications are needed beyond calculating the loss between labels and predictions.
@25benjaminli You're right. But I don't see where this is implemented in the code; can you point me to the exact location, please?
@zzzyzh only minimal modifications are needed on the code side. In the mask decoder's `forward` function, you can find the following code (note that the original mask-selection logic has been commented out, so all mask channels are returned):
```python
masks, iou_pred = self.predict_masks(
    image_embeddings=image_embeddings,
    image_pe=image_pe,
    sparse_prompt_embeddings=sparse_prompt_embeddings,
    dense_prompt_embeddings=dense_prompt_embeddings,
)

# Select the correct mask or masks for output
# if multimask_output:
#     mask_slice = slice(1, None)
# else:
#     mask_slice = slice(0, 1)
# masks = masks[:, mask_slice, :, :]
# iou_pred = iou_pred[:, mask_slice]

# Prepare output
return masks, iou_pred
```
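Since the slicing above is commented out, every mask channel is returned. Here is a minimal NumPy sketch (with hypothetical shapes, not SAMed's actual tensors) of how those channels can then be read as per-class logits for semantic segmentation:

```python
import numpy as np

# Hypothetical example: treat SAM's multimask outputs as per-class logits.
# The decoder output has shape (batch, num_masks, H, W); after retraining,
# num_masks doubles as the number of semantic classes (background included).
batch, num_classes, h, w = 1, 3, 4, 4
rng = np.random.default_rng(0)
masks = rng.standard_normal((batch, num_classes, h, w))  # raw logits

# Per-pixel semantic prediction: pick the highest-scoring class channel.
pred = masks.argmax(axis=1)  # shape (batch, H, W), values in [0, num_classes)
```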
Then, during the training phase, the following code is run:
```python
outputs = model(image_batch, multimask_output, args.img_size)
loss, loss_ce, loss_dice = calc_loss(outputs, low_res_label_batch, ce_loss, dice_loss, args.dice_param)
```
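For intuition, here is a hedged NumPy sketch of what a `calc_loss`-style function combining cross-entropy and Dice might look like; the function name, shapes, and weighting are my assumptions for illustration, not the repository's exact implementation:

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def calc_loss_sketch(logits, labels, dice_weight=0.8, eps=1e-6):
    """Hypothetical stand-in for calc_loss: weighted sum of per-pixel
    cross-entropy and soft Dice computed over the class channels."""
    b, c, h, w = logits.shape
    probs = softmax(logits, axis=1)                    # (B, C, H, W)
    onehot = np.eye(c)[labels].transpose(0, 3, 1, 2)   # (B, C, H, W)
    # Pixel-wise cross-entropy against one-hot labels
    ce = -(onehot * np.log(probs + eps)).sum(axis=1).mean()
    # Soft Dice over spatial dims, averaged across classes
    inter = (probs * onehot).sum(axis=(2, 3))
    union = probs.sum(axis=(2, 3)) + onehot.sum(axis=(2, 3))
    dice = 1.0 - ((2 * inter + eps) / (union + eps)).mean()
    return (1 - dice_weight) * ce + dice_weight * dice

rng = np.random.default_rng(0)
logits = rng.standard_normal((1, 3, 8, 8))
labels = rng.integers(0, 3, size=(1, 8, 8))
loss = calc_loss_sketch(logits, labels)
```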
Because the multimask outputs are already meant to predict objects at different levels of detail, adapting them for semantic segmentation is essentially a transfer learning task.
@25benjaminli Thanks! I'll try it!