SAM3 Loss
### Feature request

Hello,

Could you please make SAM3 return a loss in its outputs when targets are provided?

### Motivation

This feature would allow fine-tuning the SAM3 model on custom datasets.

### Your contribution

I'm currently not in a position to help due to limited time, but it would be great to see this feature.
I'm a noob with this, but are you asking for the Sam3ImageSegmentationOutput returned by the top-level Sam3 model to be modified to include a loss just for fine-tuning? If it's just for fine-tuning, I thought we could override the loss computation in Trainer to use a loss derived from your custom loss function and the targets, instead of adding another argument to SAM3.
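For example, something along these lines (a rough sketch; my_custom_mask_loss and the "labels" key are placeholders for whatever your data collator produces):

from transformers import Trainer

class Sam3Trainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Pull the targets out before the forward pass, since Sam3Model does not
        # accept them today, then apply a user-supplied loss function.
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        loss = my_custom_mask_loss(outputs.pred_masks, targets)  # placeholder loss fn
        return (loss, outputs) if return_outputs else loss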
I would expect something that is easy to work with (similar in spirit to SemanticSegmenterOutput, maybe a MaskSegmenterOutput or UniversalSegmenterOutput), so that the following script would work:
from transformers import Sam3Model, Sam3Processor
model = Sam3Model.from_pretrained("facebook/sam3")
processor = Sam3Processor.from_pretrained("facebook/sam3")
# ---------------------------------------------------------------------------
# 1. Inference only (no targets)
# Expected: outputs.loss = None
# ---------------------------------------------------------------------------
inputs = processor(
    images=["path/to/image1.png", "path/to/image2.png"],
    return_tensors="pt",
)
outputs = model(**inputs) # loss: None
# ---------------------------------------------------------------------------
# 2. Detection-only fine-tuning
# Expected loss: L_cls + L_l1 + L_giou (+ Align + DAC auxiliary losses)
# ---------------------------------------------------------------------------
detection_targets = [
    {"boxes": [[10, 20, 30, 40], [50, 60, 70, 80]], "labels": [1, 1]},
    {"boxes": [[5, 15, 25, 35]], "labels": [1]},
]
inputs = processor(
    images=["path/to/image1.png", "path/to/image2.png"],
    detection_targets=detection_targets,
    return_tensors="pt",
)
outputs = model(**inputs) # loss: detection loss only
# ---------------------------------------------------------------------------
# 3. Segmentation-only fine-tuning
# Expected loss: L_mask_bce + L_mask_dice (+ semantic loss if enabled)
# ---------------------------------------------------------------------------
segmentation_targets = [
    {"masks": ["path/to/mask1.png", "path/to/mask2.png"]},
    {"masks": ["path/to/mask3.png"]},
]
inputs = processor(
    images=["path/to/image1.png", "path/to/image2.png"],
    segmentation_targets=segmentation_targets,
    return_tensors="pt",
)
outputs = model(**inputs) # loss: segmentation loss only
# ---------------------------------------------------------------------------
# 4. Combined detection + segmentation fine-tuning
# Expected loss: detection loss + segmentation loss
# ---------------------------------------------------------------------------
combined_targets = [
    {"boxes": [[10, 20, 30, 40]], "labels": [1], "masks": ["mask1.png"]},
    {"boxes": [[50, 60, 70, 80]], "labels": [1], "masks": ["mask2.png"]},
]
inputs = processor(
    images=["path/to/image1.png", "path/to/image2.png"],
    detection_targets=combined_targets,
    segmentation_targets=combined_targets,
    return_tensors="pt",
)
outputs = model(**inputs) # loss: full det + seg loss
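With a loss in the outputs, a minimal fine-tuning step would then look like this (a sketch of the proposed behavior, not of anything that works today):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
outputs = model(**inputs)  # targets were passed, so outputs.loss is populated
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()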
Hi @aselimc,
I'd like to help with this feature! I'm a contributor who recently submitted fixes for Conditional DETR (#42679) and Flash Attention 4 detection (#42405).
I can implement the loss functions needed for SAM3 fine-tuning, following the pattern used in Mask2Former/MaskFormer:
- dice_loss - for mask prediction loss
- sigmoid_cross_entropy_loss - for binary cross-entropy on masks
- pair_wise_dice_loss / pair_wise_sigmoid_cross_entropy_loss - for Hungarian matching cost computation
Here's a draft based on Mask2Former's implementation:
import torch
import torch.nn.functional as F


def dice_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
    """
    Compute the DICE loss for mask prediction.

    Args:
        inputs: Predicted mask logits of shape (num_masks, H, W)
        labels: Ground truth binary masks of shape (num_masks, H, W)
        num_masks: Number of masks for normalization

    Returns:
        Scalar dice loss
    """
    probs = inputs.sigmoid().flatten(1)
    numerator = 2 * (probs * labels.flatten(1)).sum(-1)
    denominator = probs.sum(-1) + labels.flatten(1).sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
    """
    Compute sigmoid cross-entropy loss for mask prediction.

    Args:
        inputs: Predicted mask logits of shape (num_masks, H, W)
        labels: Ground truth binary masks of shape (num_masks, H, W)
        num_masks: Number of masks for normalization

    Returns:
        Scalar cross-entropy loss
    """
    loss = F.binary_cross_entropy_with_logits(inputs.flatten(1), labels.flatten(1), reduction="none")
    return loss.mean(1).sum() / num_masks
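The pair-wise variants used for the matching cost could follow Mask2Former as well; a sketch (these operate on masks already flattened to (num_queries, H*W) and (num_targets, H*W)):

import torch
import torch.nn as nn


def pair_wise_dice_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # inputs: (num_queries, H*W) mask logits; labels: (num_targets, H*W) binary masks.
    # Returns a (num_queries, num_targets) dice cost matrix.
    inputs = inputs.sigmoid()
    numerator = 2 * torch.matmul(inputs, labels.T)
    denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    return 1 - (numerator + 1) / (denominator + 1)


def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Decomposes per-pixel BCE into positive and negative parts so the full
    # (num_queries, num_targets) cost matrix can be built with two matmuls.
    height_and_width = inputs.shape[1]
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    loss_pos = criterion(inputs, torch.ones_like(inputs))
    loss_neg = criterion(inputs, torch.zeros_like(inputs))
    return torch.matmul(loss_pos / height_and_width, labels.T) + torch.matmul(
        loss_neg / height_and_width, (1 - labels).T
    )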
## Scope

This would be a partial contribution - just the loss functions. The full integration would still require:

- [ ] Sam3Loss class with Hungarian matching (see the matcher sketch after this list)
- [ ] Modify forward() to accept a labels parameter
- [ ] Update Sam3ImageSegmentationOutput to include a loss field
- [ ] Add documentation and tests
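For the matching step itself, the standard recipe (as in DETR and Mask2Former) is to combine the pair-wise costs above with box costs and solve the assignment with scipy; a minimal sketch, with placeholder weights and box costs omitted:

import torch
from scipy.optimize import linear_sum_assignment


@torch.no_grad()
def hungarian_match(pred_masks: torch.Tensor, target_masks: torch.Tensor,
                    cost_dice: float = 1.0, cost_ce: float = 1.0):
    # pred_masks: (num_queries, H*W) logits; target_masks: (num_targets, H*W).
    # Uses the pair-wise costs defined above; weights are placeholders, not SAM3's.
    cost_matrix = (cost_dice * pair_wise_dice_loss(pred_masks, target_masks)
                   + cost_ce * pair_wise_sigmoid_cross_entropy_loss(pred_masks, target_masks))
    pred_idx, target_idx = linear_sum_assignment(cost_matrix.cpu().numpy())
    return torch.as_tensor(pred_idx), torch.as_tensor(target_idx)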
Would this partial contribution be helpful? I can submit a PR with the loss functions, and the full integration can be built on top.
cc @HuggingFace maintainers - is this the right approach for adding loss to SAM3?
cc @molbap @yonigozlan @nielsrogge - does the existing SAM3 not include loss output if labels are provided?
@Rocketknight1, the current implementation only returns Sam3ImageSegmentationOutput (you can find the relevant code below).
Forward function definition of Sam3Model:

def forward(
    self,
    pixel_values: Optional[torch.FloatTensor] = None,
    vision_embeds: Optional[Sam3VisionEncoderOutput] = None,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    text_embeds: Optional[torch.FloatTensor] = None,
    input_boxes: Optional[torch.FloatTensor] = None,
    input_boxes_labels: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Sam3ImageSegmentationOutput:
and the Sam3ImageSegmentationOutput dataclass (docstring removed for brevity):
@dataclass
@auto_docstring
class Sam3ImageSegmentationOutput(ModelOutput):
    pred_masks: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    pred_logits: Optional[torch.FloatTensor] = None
    presence_logits: Optional[torch.FloatTensor] = None
    semantic_seg: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_reference_boxes: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    vision_attentions: Optional[tuple[torch.FloatTensor]] = None
    detr_encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    detr_decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    mask_decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
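So the requested change would presumably add a field along these lines, as other transformers models do (a sketch, not the actual diff):

# New field on Sam3ImageSegmentationOutput, conventionally placed first:
loss: Optional[torch.FloatTensor] = None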
Hello @aselimc, @dchou1618 and @vasanthrpjan1-boop! Thanks for opening this issue @aselimc. Indeed, SAM3 doesn't support fine-tuning and doesn't have a loss out-of-the-box in transformers at the moment. However, seeing that there's quite a bit of demand for this, it would be great to support having SAM3 return a loss! I won't have much bandwidth for that at the moment, but I'm happy to review a PR. If we add this, though, let's try to stay as close as possible to the loss defined in the SAM3 paper and the original repo, which might be more involved than what @vasanthrpjan1-boop described.
@yonigozlan Thanks for the reply!
Unfortunately, I don't have much spare time to tackle this myself right now either, but I did a quick skim of the codebase. I noticed there isn't currently a generic output class like UniversalSegmenterOutput or MaskSegmenterOutput (analogous to the existing SemanticSegmenterOutput) to facilitate training mask generation models.
My technical suggestions for whoever picks this up:
- Define a Shared Output Class: create a dataclass for general mask segmentation models that supports both masks and bounding boxes.
- Implement SAM3 Losses: this is complex, as SAM3 uses a combination of focal + mask_bce + mask_dice + box_giou + box_l1 + presence losses.
- Fine-tuning Logic: the implementation needs to be flexible. As noted in the paper, fine-tuning for detection tasks often requires disabling specific mask losses, so the loss computation should be conditional (see the sketch after this list).
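To make the conditional part concrete, one minimal shape this could take (the weight values are placeholders, not the paper's coefficients):

import torch

# Placeholder weights: setting one to 0 disables that term, e.g. zero out
# mask_bce/mask_dice for detection-only fine-tuning.
loss_weights = {
    "focal": 1.0,
    "mask_bce": 1.0,
    "mask_dice": 1.0,
    "box_giou": 1.0,
    "box_l1": 1.0,
    "presence": 1.0,
}


def combine_losses(losses: dict[str, torch.Tensor], weights: dict[str, float]) -> torch.Tensor:
    # Only terms with a non-zero weight contribute to the total loss.
    return sum(weights[name] * value for name, value in losses.items() if weights.get(name, 0.0) > 0.0)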
Regarding implementation: While an open-source contributor could handle this, given the complexity of the loss landscape here, my personal opinion is that a core HF maintainer (or potentially someone from the author team) should probably lead the implementation of the losses to ensure correctness.
Interesting. cc @merveenoyan as well for awareness on everything vision fine-tuning.