
SAM3 Loss

Open aselimc opened this issue 1 month ago • 7 comments

Feature request

Hello,

Can you please add SAM3 to return loss in outputs (if target is provided)?

Motivation

This feature would allow fine-tuning the SAM3 model for custom datasets.

Your contribution

Currently I am not in a position to help due to my limited time, but it would be great to see this feature.

aselimc • Dec 05 '25 16:12

I'm a noob with this, but are you asking for the top-level Sam3 model's output, Sam3ImageSegmentationOutput, to be modified to include a loss just for fine-tuning? If it's only for fine-tuning, I thought we could just override the loss computation in Trainer and use a custom loss function computed from the targets, instead of adding another argument to Sam3.
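
A minimal sketch of that alternative: the model stays untouched and the loss is computed from its outputs in the training loop. The same logic could live in an overridden Trainer.compute_loss. It assumes the predicted masks are already paired one-to-one with the targets (no Hungarian matching), and ToyMaskModel is a hypothetical stand-in for Sam3Model, not its real interface.

```python
import torch
import torch.nn.functional as F
from torch import nn


def mask_loss(pred_mask_logits: torch.Tensor, target_masks: torch.Tensor) -> torch.Tensor:
    """BCE + dice for masks already paired one-to-one, shape (batch, num_masks, H, W)."""
    logits = pred_mask_logits.flatten(2)
    targets = target_masks.flatten(2).float()
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    probs = logits.sigmoid()
    # Smoothed dice per mask (+1 avoids 0/0 on empty masks), averaged over batch and masks
    dice = 1 - (2 * (probs * targets).sum(-1) + 1) / (probs.sum(-1) + targets.sum(-1) + 1)
    return bce + dice.mean()


def training_step(model: nn.Module, batch: dict, optimizer: torch.optim.Optimizer) -> float:
    """Runs one optimisation step; the targets never reach model.forward()."""
    batch = dict(batch)  # shallow copy so the caller's batch is not mutated
    targets = batch.pop("target_masks")
    outputs = model(**batch)
    # Key-based access works for plain dicts and for transformers ModelOutput objects
    loss = mask_loss(outputs["pred_masks"], targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


class ToyMaskModel(nn.Module):
    """Hypothetical stand-in for the real model: one mask logit map per query."""

    def __init__(self, num_queries: int = 2):
        super().__init__()
        self.proj = nn.Conv2d(3, num_queries, kernel_size=1)

    def forward(self, pixel_values: torch.Tensor) -> dict:
        return {"pred_masks": self.proj(pixel_values)}


model = ToyMaskModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = {
    "pixel_values": torch.randn(2, 3, 8, 8),
    "target_masks": torch.randint(0, 2, (2, 2, 8, 8)),
}
print(training_step(model, batch, optimizer))
```

The trade-off is that every user has to write the matching and loss code themselves, which is exactly what a built-in loss would avoid.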

dchou1618 • Dec 06 '25 23:12

I would expect something that is as easy to work with as SemanticSegmenterOutput (maybe a MaskSegmenterOutput or UniversalSegmenterOutput), so that a script like the following would just work:

from transformers import Sam3Model, Sam3Processor

model = Sam3Model.from_pretrained("facebook/sam3")
processor = Sam3Processor.from_pretrained("facebook/sam3")

# ---------------------------------------------------------------------------
# 1. Inference only (no targets)
# Expected: outputs.loss = None
# ---------------------------------------------------------------------------
inputs = processor(
    images=["path/to/image1.png", "path/to/image2.png"],
    return_tensors="pt",
)
outputs = model(**inputs)  # loss: None


# ---------------------------------------------------------------------------
# 2. Detection-only fine-tuning
# Expected loss: L_cls + L_l1 + L_giou (+ Align + DAC auxiliary losses)
# ---------------------------------------------------------------------------
detection_targets = [
    {"boxes": [[10, 20, 30, 40], [50, 60, 70, 80]], "labels": [1, 1]},
    {"boxes": [[5, 15, 25, 35]], "labels": [1]},
]

inputs = processor(
    images=["path/to/image1.png", "path/to/image2.png"],
    detection_targets=detection_targets,
    return_tensors="pt",
)
outputs = model(**inputs)  # loss: detection loss only


# ---------------------------------------------------------------------------
# 3. Segmentation-only fine-tuning
# Expected loss: L_mask_bce + L_mask_dice (+ semantic loss if enabled)
# ---------------------------------------------------------------------------
segmentation_targets = [
    {"masks": ["path/to/mask1.png", "path/to/mask2.png"]},
    {"masks": ["path/to/mask3.png"]},
]

inputs = processor(
    images=["path/to/image1.png", "path/to/image2.png"],
    segmentation_targets=segmentation_targets,
    return_tensors="pt",
)
outputs = model(**inputs)  # loss: segmentation loss only


# ---------------------------------------------------------------------------
# 4. Combined detection + segmentation fine-tuning
# Expected loss: detection loss + segmentation loss
# ---------------------------------------------------------------------------
combined_targets = [
    {"boxes": [[10, 20, 30, 40]], "labels": [1], "masks": ["mask1.png"]},
    {"boxes": [[50, 60, 70, 80]], "labels": [1], "masks": ["mask2.png"]},
]

inputs = processor(
    images=["path/to/image1.png", "path/to/image2.png"],
    detection_targets=combined_targets,
    segmentation_targets=combined_targets,
    return_tensors="pt",
)
outputs = model(**inputs)  # loss: full det + seg loss
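
The processor arguments above (detection_targets, segmentation_targets) are part of the proposal. One job such a processor would have is turning pixel-space boxes into the normalized coordinates that box losses usually work with. A sketch with a hypothetical collate_detection_targets helper, assuming (x0, y0, x1, y1) pixel boxes in and normalized (x0, y0, x1, y1) boxes out; the box format a SAM3 loss would actually expect is an assumption (DETR-style models in transformers use normalized center_x, center_y, width, height).

```python
import torch


def collate_detection_targets(targets, image_sizes):
    """Hypothetical helper: per-image target dicts -> tensors with normalized boxes.

    targets: list of {"boxes": [[x0, y0, x1, y1], ...] in pixels, "labels": [...]}
    image_sizes: list of (height, width), one per image
    """
    collated = []
    for target, (height, width) in zip(targets, image_sizes):
        boxes = torch.as_tensor(target["boxes"], dtype=torch.float32).reshape(-1, 4)
        # Divide x by width and y by height so coordinates lie in [0, 1]
        scale = torch.tensor([width, height, width, height], dtype=torch.float32)
        collated.append(
            {
                "boxes": boxes / scale,
                "labels": torch.as_tensor(target["labels"], dtype=torch.long),
            }
        )
    # Kept as a list: images can contain different numbers of objects
    return collated


detection_targets = [
    {"boxes": [[10, 20, 30, 40], [50, 60, 70, 80]], "labels": [1, 1]},
    {"boxes": [[5, 15, 25, 35]], "labels": [1]},
]
collated = collate_detection_targets(detection_targets, image_sizes=[(100, 200), (50, 50)])
first_boxes = collated[0]["boxes"]  # approximately [[0.05, 0.2, 0.15, 0.4], [0.25, 0.6, 0.35, 0.8]]
```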

aselimc • Dec 07 '25 09:12

Hi @aselimc,

I'd like to help with this feature! I'm a contributor who recently submitted fixes for Conditional DETR (#42679) and Flash Attention 4 detection (#42405).

I can implement the loss functions needed for SAM3 fine-tuning, following the pattern used in Mask2Former/MaskFormer:

  1. dice_loss - For mask prediction loss
  2. sigmoid_cross_entropy_loss - For binary cross-entropy on masks
  3. pair_wise_dice_loss / pair_wise_sigmoid_cross_entropy_loss - For Hungarian matching cost computation
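
Item 3 is not covered by the draft that follows. A sketch of the pairwise variants, modelled on the helpers of the same names in the Mask2Former modeling code: instead of one loss per matched pair, they return a (num_queries, num_targets) cost matrix that a Hungarian matcher can consume. Inputs are assumed to be already flattened to (N, H*W).

```python
import torch
import torch.nn.functional as F


def pair_wise_dice_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Dice cost between every prediction and every target.

    inputs: mask logits, shape (num_queries, H*W)
    labels: binary target masks, shape (num_targets, H*W)
    returns: cost matrix, shape (num_queries, num_targets)
    """
    probs = inputs.sigmoid()
    labels = labels.float()
    numerator = 2 * probs @ labels.T
    denominator = probs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    return 1 - (numerator + 1) / (denominator + 1)


def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean per-pixel BCE between every prediction and every target, shape (num_queries, num_targets)."""
    labels = labels.float()
    num_pixels = inputs.shape[1]
    # Per-pixel BCE if the target pixel were 1, and if it were 0
    loss_pos = F.binary_cross_entropy_with_logits(inputs, torch.ones_like(inputs), reduction="none")
    loss_neg = F.binary_cross_entropy_with_logits(inputs, torch.zeros_like(inputs), reduction="none")
    # Matrix products pick the correct term for each pixel of each target
    return (loss_pos @ labels.T + loss_neg @ (1 - labels).T) / num_pixels
```

The matrix products avoid materializing a (num_queries, num_targets, H*W) tensor, which matters when there are many queries.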

Here's a draft based on Mask2Former's implementation:

import torch
import torch.nn.functional as F


def dice_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
    """
    Compute the DICE loss for mask prediction.

    Args:
        inputs: Predicted mask logits of shape (num_masks, H, W)
        labels: Ground truth binary masks of shape (num_masks, H, W)
        num_masks: Number of masks for normalization

    Returns:
        Scalar dice loss
    """
    probs = inputs.sigmoid().flatten(1)
    labels = labels.flatten(1).float()
    numerator = 2 * (probs * labels).sum(-1)
    denominator = probs.sum(-1) + labels.sum(-1)
    # +1 smoothing avoids division by zero for empty masks
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
    """
    Compute sigmoid cross-entropy loss for mask prediction.

    Args:
        inputs: Predicted mask logits of shape (num_masks, H, W)
        labels: Ground truth binary masks of shape (num_masks, H, W)
        num_masks: Number of masks for normalization

    Returns:
        Scalar cross-entropy loss
    """
    loss = F.binary_cross_entropy_with_logits(inputs.flatten(1), labels.flatten(1).float(), reduction="none")
    # Average over pixels, then sum over masks and normalize
    return loss.mean(1).sum() / num_masks

Scope

This would be a partial contribution: just the loss functions. The full integration would still require:

  • [ ] Sam3Loss class with Hungarian matching
  • [ ] Modify forward() to accept labels parameter
  • [ ] Update Sam3ImageSegmentationOutput to include loss field
  • [ ] Add documentation and tests
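
For the first item, a minimal sketch of a DETR-style Hungarian matcher for a single image, using scipy.optimize.linear_sum_assignment. It assumes one logit per query (a "matches the prompted concept" score) and uses only a confidence cost plus an L1 box cost; the actual SAM3 matcher's cost terms and weights (GIoU, mask costs, presence) are not reproduced here, and hungarian_match is a hypothetical name.

```python
import torch
from scipy.optimize import linear_sum_assignment


@torch.no_grad()
def hungarian_match(
    pred_logits: torch.Tensor,
    pred_boxes: torch.Tensor,
    target_boxes: torch.Tensor,
    class_weight: float = 1.0,
    l1_weight: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Match queries to targets for one image.

    pred_logits: (num_queries,) one logit per query (assumed layout)
    pred_boxes: (num_queries, 4); target_boxes: (num_targets, 4), same box format
    returns: (query_indices, target_indices), one pair per target
    """
    num_targets = target_boxes.shape[0]
    # Confident queries are cheaper; the cost does not depend on which target it is
    class_cost = -pred_logits.sigmoid()[:, None].expand(-1, num_targets)
    # L1 distance between every predicted box and every target box
    l1_cost = torch.cdist(pred_boxes, target_boxes, p=1)
    cost = class_weight * class_cost + l1_weight * l1_cost
    rows, cols = linear_sum_assignment(cost.cpu().numpy())
    return torch.as_tensor(rows, dtype=torch.long), torch.as_tensor(cols, dtype=torch.long)


target_boxes = torch.tensor([[0.1, 0.1, 0.2, 0.2], [0.5, 0.5, 0.9, 0.9]])
pred_boxes = torch.tensor([[0.5, 0.5, 0.9, 0.9], [0.0, 0.0, 1.0, 1.0], [0.1, 0.1, 0.2, 0.2]])
query_idx, target_idx = hungarian_match(torch.zeros(3), pred_boxes, target_boxes)
```

Unmatched queries would then be supervised as "no object" by the classification term; mask costs could be added to the cost matrix as extra weighted terms.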

Would this partial contribution be helpful? I can submit a PR with the loss functions, and the full integration can be built on top.

cc @HuggingFace maintainers - is this the right approach for adding loss to SAM3?

vasanthrpjan1-boop • Dec 08 '25 14:12

cc @molbap @yonigozlan @nielsrogge - does the existing SAM3 not include loss output if labels are provided?

Rocketknight1 • Dec 08 '25 15:12

@Rocketknight1, the current implementation only returns Sam3ImageSegmentationOutput, which has no loss field (see the relevant code below).

Forward function definition of Sam3Model

def forward(
    self,
    pixel_values: Optional[torch.FloatTensor] = None,
    vision_embeds: Optional[Sam3VisionEncoderOutput] = None,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    text_embeds: Optional[torch.FloatTensor] = None,
    input_boxes: Optional[torch.FloatTensor] = None,
    input_boxes_labels: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Sam3ImageSegmentationOutput:

And the Sam3ImageSegmentationOutput dataclass (docstring removed for brevity):

@dataclass
@auto_docstring
class Sam3ImageSegmentationOutput(ModelOutput):
    pred_masks: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    pred_logits: Optional[torch.FloatTensor] = None
    presence_logits: Optional[torch.FloatTensor] = None
    semantic_seg: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_reference_boxes: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    vision_attentions: Optional[tuple[torch.FloatTensor]] = None
    detr_encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    detr_decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    mask_decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
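
For comparison, the pattern other transformers models follow is an optional labels argument on forward() and a loss field on the output that stays None at inference. A toy sketch of that pattern, assuming a plain BCE mask loss; ToySegmentationOutput and ToyMaskHead are hypothetical and not part of the Sam3 code.

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class ToySegmentationOutput:
    loss: Optional[torch.Tensor] = None
    pred_masks: Optional[torch.Tensor] = None


class ToyMaskHead(nn.Module):
    """Hypothetical model whose forward() optionally takes labels."""

    def __init__(self, num_queries: int = 2):
        super().__init__()
        self.proj = nn.Conv2d(3, num_queries, kernel_size=1)

    def forward(self, pixel_values: torch.Tensor, labels: Optional[torch.Tensor] = None) -> ToySegmentationOutput:
        pred_masks = self.proj(pixel_values)
        loss = None
        if labels is not None:
            # Only computed when targets are passed, so inference behaves as before
            loss = F.binary_cross_entropy_with_logits(pred_masks, labels.float())
        return ToySegmentationOutput(loss=loss, pred_masks=pred_masks)
```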

aselimc • Dec 08 '25 18:12

Hello @aselimc, @dchou1618 and @vasanthrpjan1-boop! Thanks for opening this issue, @aselimc. Indeed, SAM3 in transformers doesn't currently support fine-tuning or return a loss out of the box. However, given how much demand there is for this, it would be great to have SAM3 return a loss! I won't have much bandwidth for it at the moment, but I'm happy to review a PR. If we add this, though, let's stay as close as possible to the loss defined in the SAM3 paper and the original repo, which might be more involved than what @vasanthrpjan1-boop described.

yonigozlan • Dec 08 '25 20:12

@yonigozlan Thanks for the reply!

Unfortunately, I don't have much spare time to tackle this myself right now either, but I did a quick skim of the codebase. There isn't currently a generic output class such as UniversalSegmenterOutput or MaskSegmenterOutput (analogous to the existing SemanticSegmenterOutput) to facilitate training mask generation models.
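
A sketch of what such a shared output could hold, using the MaskSegmenterOutput name floated in this thread (hypothetical, not an existing transformers class). In transformers it would subclass ModelOutput; a plain dataclass is used here so the sketch runs on its own, and the loss_dict field for logging individual terms is an assumption.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MaskSegmenterOutput:
    """Hypothetical shared output for models predicting masks and boxes."""

    # Total loss; None unless targets were provided
    loss: Optional[torch.FloatTensor] = None
    # Individual loss terms (assumed field), e.g. for logging
    loss_dict: Optional[dict[str, torch.Tensor]] = None
    pred_logits: Optional[torch.FloatTensor] = None
    pred_boxes: Optional[torch.FloatTensor] = None
    pred_masks: Optional[torch.FloatTensor] = None


inference_output = MaskSegmenterOutput(pred_masks=torch.zeros(1, 2, 4, 4))
```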

My technical suggestions for whoever picks this up:

  1. Define a Shared Output Class: Create a dataclass for general mask segmentation models that supports both masks and bounding boxes.

  2. Implement SAM3 Losses: This is complex, as SAM3 utilizes a combination of focal + mask_bce + mask_dice + box_giou + box_l1 + presence losses.

  3. Fine-tuning Logic: The implementation needs to be flexible. As noted in the paper, fine-tuning for detection tasks often requires disabling specific mask losses, so the loss computation should be conditional.
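
Items 2 and 3 can be sketched together: a RetinaNet-style sigmoid focal loss (torchvision also provides one as torchvision.ops.sigmoid_focal_loss) and a hypothetical combine_losses helper that weights the individual terms and skips disabled ones, e.g. the mask terms for detection-only fine-tuning. The term names follow the list above; the weights and the alpha/gamma defaults are placeholders, not values taken from the SAM3 paper or repo.

```python
from typing import Collection

import torch
import torch.nn.functional as F


def sigmoid_focal_loss(logits: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0) -> torch.Tensor:
    """RetinaNet-style focal loss, averaged over all elements."""
    targets = targets.float()
    prob = logits.sigmoid()
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # p_t: probability assigned to the correct class; (1 - p_t)^gamma down-weights easy examples
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * ce * (1 - p_t) ** gamma).mean()


def combine_losses(
    terms: dict[str, torch.Tensor],
    weights: dict[str, float],
    disabled: Collection[str] = (),
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Weighted sum of the enabled loss terms (missing weights default to 1.0).

    terms: name -> scalar tensor, e.g. "focal", "mask_bce", "mask_dice", "box_l1", "box_giou", "presence"
    disabled: names to skip, e.g. the mask terms for detection-only fine-tuning
    """
    active = {name: value for name, value in terms.items() if name not in disabled}
    if not active:
        raise ValueError("all loss terms are disabled")
    total = sum(weights.get(name, 1.0) * value for name, value in active.items())
    return total, active


terms = {"focal": torch.tensor(1.0), "mask_bce": torch.tensor(2.0), "mask_dice": torch.tensor(3.0)}
detection_only_loss, used_terms = combine_losses(terms, weights={"focal": 2.0}, disabled={"mask_bce", "mask_dice"})
```

Returning the active terms alongside the total keeps per-term logging possible without recomputing anything.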

Regarding implementation: While an open-source contributor could handle this, given the complexity of the loss landscape here, my personal opinion is that a core HF maintainer (or potentially someone from the author team) should probably lead the implementation of the losses to ensure correctness.

aselimc • Dec 08 '25 21:12

Interesting. cc @merveenoyan as well, for awareness on everything vision fine-tuning.

molbap • Dec 15 '25 14:12