
No object detected in inference after fine-tuning on custom dataset

Open simoneriggi opened this issue 4 weeks ago • 5 comments

Dear all, I am fine-tuning SAM3 on a custom astronomical dataset in COCO format. The dataset contains 5 object classes (with custom names) and segmentation masks/bboxes, and I have added a "noun_phrase" to each annotation. I adapted the Roboflow training configuration file, making these changes:

  • set the dataset and log paths
  • enabled segmentation (loss and metrics)
  • adapted the collator to support runs with gradient accumulation > 1
  • added sam3.train.transforms.segmentation.DecodeRle to the validation transforms
  • set the Slurm job parameters (batch=8, gradacc=8, 4 A100 GPUs)
  • commented out some Roboflow settings (supercategory, task array)
  • added code to optionally freeze the backbone or other components (the freeze_cfg section, commented out below)
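The freeze_cfg code itself is not included in this issue. A minimal sketch of how such an option could work, assuming it selects parameters by name prefix: `backbone.vision_backbone.` and `backbone.language_backbone.` come from the optimizer `param_names` in the config below, while `transformer.` and `segmentation_head.` are hypothetical prefixes. The helper names are also hypothetical. It returns trainable vs frozen element counts, which is a convenient way to report trainable-parameter totals.

```python
from typing import Any, Dict, Iterable, List, Tuple

# Hypothetical mapping from freeze_cfg flags to parameter-name prefixes.
# The two backbone prefixes match the optimizer param_names in the config;
# "transformer." and "segmentation_head." are assumptions.
FREEZE_PREFIXES: Dict[str, str] = {
    "backbone": "backbone.vision_backbone.",
    "text_encoder": "backbone.language_backbone.",
    "transformer": "transformer.",
    "segmentation_head": "segmentation_head.",
}


def frozen_prefixes(freeze_cfg: dict) -> List[str]:
    """Return the parameter-name prefixes to freeze for a freeze_cfg dict."""
    prefixes = [p for flag, p in FREEZE_PREFIXES.items() if freeze_cfg.get(flag, False)]
    if not freeze_cfg.get("backbone", False):
        # Partial freeze: only the first N trunk blocks (0 = none).
        # The trailing dot keeps "blocks.1." from matching "blocks.10.".
        n_blocks = int(freeze_cfg.get("backbone_blocks", 0))
        prefixes += [f"backbone.vision_backbone.trunk.blocks.{i}." for i in range(n_blocks)]
    return prefixes


def apply_freeze(named_parameters: Iterable[Tuple[str, Any]], freeze_cfg: dict) -> Tuple[int, int]:
    """Set requires_grad=False on matching parameters; return (trainable, frozen) element counts.

    `named_parameters` is meant to be model.named_parameters() of a torch.nn.Module.
    """
    prefixes = tuple(frozen_prefixes(freeze_cfg))
    trainable = frozen = 0
    for name, param in named_parameters:
        if prefixes and name.startswith(prefixes):
            param.requires_grad = False
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
    return trainable, frozen
```

With a real model, `apply_freeze(model.named_parameters(), {"backbone": True, "text_encoder": True})` would be called before building the optimizer, so frozen parameters are excluded from updates.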

I fine-tuned for 20 epochs in two runs: one with all model components trainable (840M trainable parameters) and one with the backbone frozen (32.7M trainable parameters). In both runs I left the training configuration (learning rate, etc.) at the Roboflow defaults. Below are the final validation losses and evaluation metrics of the two runs:

Final validation results (Trainer/epoch=19, Trainer/steps_val=97090, Trainer/where=0.99979). All metric keys are prefixed with Meters_train/val_roboflow100/detection/coco_eval_

  Metric                 Full fine-tuning   Backbone frozen
  bbox_AP                0.6391             0.5609
  bbox_AP_50             0.8177             0.7588
  bbox_AP_75             0.7223             0.6316
  bbox_AP_small          0.6246             0.5386
  bbox_AP_medium         0.5391             0.4266
  bbox_AP_large          0.9500             0.9500
  bbox_AR_maxDets@1      0.5760             0.5455
  bbox_AR_maxDets@10     0.7668             0.7219
  bbox_AR_maxDets@100    0.7909             0.7489
  bbox_AR_small          0.7862             0.7444
  bbox_AR_medium         0.6864             0.6563
  bbox_AR_large          0.9500             0.9500

Losses/val_all_loss, Losses/val_default_loss and Losses/val_roboflow100_core_loss are 0 in both runs.

As far as I can tell from the metrics, the model is learning something, although I certainly need to train longer and with tuned hyperparameters. I would now like to run inference on a single image using the fine-tuned checkpoint and the example notebook https://github.com/facebookresearch/sam3/blob/main/examples/sam3_image_predictor_example.ipynb. However, when I run the inference script on training/validation images, using the same noun_phrase prompt and a low confidence threshold (0.1), no objects are detected.

When I load the model, a log message reports that many model keys are missing:

loaded [RUN DIR]/checkpoints/checkpoint.pt and found missing and/or unexpected keys: missing_keys=['backbone.vision_backbone.trunk.pos_embed', 'backbone.vision_backbone.trunk.patch_embed.proj.weight', 'backbone.vision_backbone.trunk.blocks.0.norm1.weight', 'backbone.vision_backbone.trunk.blocks.0.norm1.bias', 'backbone.vision_backbone.trunk.blocks.0.attn.freqs_cis', 'backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight', 'backbone.vision_backbone.trunk.blocks.0.attn.qkv.bias', 'backbone.vision_backbone.trunk.blocks.0.attn.proj.weight', 'backbone.vision_backbone.trunk.blocks.0.attn.proj.bias', 'backbone.vision_backbone.trunk.blocks.0.norm2.weight', 'backbone.vision_backbone.trunk.blocks.0.norm2.bias', 'backbone.vision_backbone.trunk.blocks.0.mlp.fc1.weight', ... ...
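A missing-keys list this long usually means the checkpoint's parameter names do not match the model's, so most weights are never loaded. One hypothesis worth checking (an assumption, not confirmed in this issue) is a prefix mismatch: the loader may keep only keys containing `detector.` (as in the released sam3.pt) and strip that prefix, while the trainer may save names without it. The pure-Python helpers below are hypothetical and only illustrate the check; `assumed_loader_filter` mimics that assumed loader behaviour, not the actual SAM3 code.

```python
from typing import Dict, Iterable, List, Tuple, TypeVar

V = TypeVar("V")


def key_diff(ckpt_keys: Iterable[str], model_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) in the sense of nn.Module.load_state_dict(strict=False)."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


def assumed_loader_filter(state: Dict[str, V], prefix: str = "detector.") -> Dict[str, V]:
    """ASSUMPTION: mimic a loader that keeps only keys containing `prefix` and strips it."""
    return {k.replace(prefix, ""): v for k, v in state.items() if prefix in k}


def add_prefix(state: Dict[str, V], prefix: str = "detector.") -> Dict[str, V]:
    """Prepend `prefix` to every key that does not already start with it."""
    return {(k if k.startswith(prefix) else prefix + k): v for k, v in state.items()}
```

With real objects, something like `ckpt = torch.load(checkpoint_path, map_location="cpu")`, then `state = ckpt.get("model", ckpt)` (assuming the trainer nests weights under `"model"`), and `key_diff(state.keys(), model.state_dict().keys())` would show whether the names line up. If they only match after `add_prefix`, saving `add_prefix(state)` with `torch.save` and pointing `checkpoint_path` at that file is one way to test the hypothesis, assuming the loader accepts a flat state dict. Recent PyTorch versions default `torch.load` to `weights_only=True`, which may reject full training checkpoints; pass `weights_only=False` only for files you trust.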

Could someone give me a hint about what I am doing wrong? Is the problem in the fine-tuning (category embedding, hyperparameters, etc.), in how I run inference (input normalization/transforms), or both?

Thanks a lot.

PS: My config file and inference script are below:

CONFIG FILE

# @package _global_
defaults:
  - _self_

# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
  roboflow_vl_100_root: [DATASET ROOT DIR]
  experiment_log_dir: [RUN LOG DIR]
  bpe_path: [SAM PATH]/sam3/assets/bpe_simple_vocab_16e6.txt.gz

#freeze_cfg:
#  backbone: true
#  backbone_blocks: 0      # e.g. freeze first N blocks; 0 = none
#  text_encoder: true
#  transformer: false
#  segmentation_head: false

# Roboflow dataset configuration
roboflow_train:
  num_images: null # Note: This is the number of images used for training. If null, all images are used.

  # Training transforms pipeline
  train_transforms:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
        - _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
          box_noise_std: 0.1
          box_noise_max: 20
        - _target_: sam3.train.transforms.segmentation.DecodeRle
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes:
            _target_: sam3.train.transforms.basic.get_random_resize_scales
            size: ${scratch.resolution}
            min_size: 480
            rounded: false
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution}
          square: true
          consistent_transform: ${scratch.consistent_transform}
        - _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
          size: ${scratch.resolution}
          consistent_transform: ${scratch.consistent_transform}
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.train_norm_mean}
          std: ${scratch.train_norm_std}
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
    - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
      query_filter:
        _target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
        max_num_objects: ${scratch.max_ann_per_img}

  # Validation transforms pipeline
  val_transforms:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        # 1) Decode COCO RLE/poly into mask tensors
        - _target_: sam3.train.transforms.segmentation.DecodeRle

        # 2) Resize image + masks
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes: ${scratch.resolution}
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution}
          square: true
          consistent_transform: False

        # 3) Convert to torch tensors
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI

        # 4) Normalize
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.train_norm_mean}
          std: ${scratch.train_norm_std}

  # NOTE: Loss to be used for training in case of segmentation
  loss:
     _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
     matcher: ${scratch.matcher}
     o2m_weight: 2.0
     o2m_matcher:
       _target_: sam3.train.matcher.BinaryOneToManyMatcher
       alpha: 0.3
       threshold: 0.4
       topk: 4
     use_o2m_matcher_on_o2m_aux: false
     loss_fns_find:
       - _target_: sam3.train.loss.loss_fns.Boxes
         weight_dict:
           loss_bbox: 5.0
           loss_giou: 2.0
       - _target_: sam3.train.loss.loss_fns.IABCEMdetr
         weak_loss: False
         weight_dict:
           loss_ce: 20.0 # Another option is 100.0
           presence_loss: 20.0
         pos_weight: 10.0 # Another option is 5.0
         alpha: 0.25
         gamma: 2
         use_presence: True  # Change
         pos_focal: false
         pad_n_queries: 200
         pad_scale_pos: 1.0
       - _target_: sam3.train.loss.loss_fns.Masks
         focal_alpha: 0.25
         focal_gamma: 2.0
         weight_dict:
           loss_mask: 200.0
           loss_dice: 10.0
         compute_aux: false
     loss_fn_semantic_seg:
       #_target_: sam3.losses.loss_fns.SemanticSegCriterion
       _target_: sam3.train.loss.loss_fns.SemanticSegCriterion
       presence_head: True
       presence_loss: False  # Change
       focal: True
       focal_alpha: 0.6
       focal_gamma: 2.0
       downsample: False
       weight_dict:
         loss_semantic_seg: 20.0
         loss_semantic_presence: 1.0
         loss_semantic_dice: 30.0
     scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}

# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
  enable_segmentation: True # NOTE: Enables segmentation (mask loading/collation and the model's segmentation head)
  # Model parameters
  d_model: 256
  pos_embed:
    _target_: sam3.model.position_encoding.PositionEmbeddingSine
    num_pos_feats: ${scratch.d_model}
    normalize: true
    scale: null
    temperature: 10000

  # Box processing
  use_presence_eval: True
  original_box_postprocessor:
    _target_: sam3.eval.postprocessors.PostProcessImage
    max_dets_per_img: -1  # infinite detections
    use_original_ids: true
    use_original_sizes_box: true
    use_presence: ${scratch.use_presence_eval}

  # Matcher configuration
  matcher:
    _target_: sam3.train.matcher.BinaryHungarianMatcherV2
    focal: true  # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
    cost_class: 2.0
    cost_bbox: 5.0
    cost_giou: 2.0
    alpha: 0.25
    gamma: 2
    stable: False
  scale_by_find_batch_size: True

  # Image processing parameters
  resolution: 1008
  consistent_transform: False
  max_ann_per_img: 200

  # Normalization parameters
  train_norm_mean: [0.5, 0.5, 0.5]
  train_norm_std: [0.5, 0.5, 0.5]
  val_norm_mean: [0.5, 0.5, 0.5]
  val_norm_std: [0.5, 0.5, 0.5]

  # Training parameters
  num_train_workers: 10
  num_val_workers: 0
  max_data_epochs: 20
  target_epoch_size: 1500
  hybrid_repeats: 1
  context_length: 2
  gather_pred_via_filesys: false

  # Learning rate and scheduler parameters
  lr_scale: 0.1
  lr_transformer: ${times:8e-4,${scratch.lr_scale}}
  lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
  lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
  lrd_vision_backbone: 0.9
  wd: 0.1
  scheduler_timescale: 20
  scheduler_warmup: 20
  scheduler_cooldown: 20

  val_batch_size: 1
  collate_fn_val:
    _target_: sam3.train.data.collator.collate_fn_api
    _partial_: true
    repeats: ${scratch.hybrid_repeats}
    dict_key: roboflow100
    with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!

  gradient_accumulation_steps: 8
  train_batch_size: 8

  ## USE THIS WITH GRAD ACC=1
  #collate_fn:
  #  _target_: sam3.train.data.collator.collate_fn_api
  #  _partial_: true
  #  repeats: ${scratch.hybrid_repeats}
  #  dict_key: all
  #  with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!

  ## USE THIS WITH GRAD ACC>1
  collate_fn:
    _target_: sam3.train.data.collator.collate_fn_api_with_chunking
    _partial_: true
    repeats: ${scratch.hybrid_repeats}
    dict_key: all
    with_seg_masks: ${scratch.enable_segmentation}
    num_chunks: ${scratch.gradient_accumulation_steps}

# ============================================================================
# Trainer Configuration
# ============================================================================

trainer:

  _target_: sam3.train.trainer.Trainer
  skip_saving_ckpts: false
  empty_gpu_mem_cache_after_eval: True
  skip_first_val: True
  max_epochs: 20
  accelerator: cuda
  seed_value: 123
  val_epoch_freq: 1
  mode: train
  gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}

  distributed:
    backend: nccl
    find_unused_parameters: True
    gradient_as_bucket_view: True

  loss:
    all: ${roboflow_train.loss}
    default:
      _target_: sam3.train.loss.sam3_loss.DummyLoss
    
  data:

    train:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        limit_ids: ${roboflow_train.num_images}
        transforms: ${roboflow_train.train_transforms}
        load_segmentation: ${scratch.enable_segmentation}
        max_ann_per_img: 500000
        multiplier: 1
        max_train_queries: 50000
        max_val_queries: 50000
        training: true
        use_caching: False
        img_folder: ${paths.roboflow_vl_100_root}
        ann_file: ${paths.roboflow_vl_100_root}/dataset_sam_train.json

      shuffle: True
      batch_size: ${scratch.train_batch_size}
      num_workers: ${scratch.num_train_workers}
      pin_memory: True
      drop_last: True
      collate_fn: ${scratch.collate_fn}

    val:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        load_segmentation: ${scratch.enable_segmentation}
        coco_json_loader:
          _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
          include_negatives: true
          category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
          _partial_: true
        img_folder: ${paths.roboflow_vl_100_root}
        ann_file: ${paths.roboflow_vl_100_root}/dataset_sam_val.json
        transforms: ${roboflow_train.val_transforms}
        max_ann_per_img: 100000
        multiplier: 1
        training: false

      shuffle: False
      batch_size: ${scratch.val_batch_size}
      num_workers: ${scratch.num_val_workers}
      pin_memory: True
      drop_last: False
      collate_fn: ${scratch.collate_fn_val}

  model:
    _target_: sam3.model_builder.build_sam3_image_model
    bpe_path: ${paths.bpe_path}
    device: cpus
    eval_mode: false
    enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
    checkpoint_path: [HF_HOME]/models/huggingface/hub/models--facebook--sam3/snapshots/3c879f39826c281e95690f02c7821c4de09afae7/sam3.pt

  freeze_cfg: ${freeze_cfg}

  meters:
    val:
      roboflow100:
        detection:
          _target_: sam3.eval.coco_writer.PredictionDumper
          iou_type: "bbox"
          dump_dir: ${launcher.experiment_log_dir}/dumps
          merge_predictions: True
          postprocessor: ${scratch.original_box_postprocessor}
          gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
          maxdets: 100
          pred_file_evaluators:
            - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
              gt_path: ${paths.roboflow_vl_100_root}/dataset_sam_val.json
              tide: False
              iou_type: "bbox"

  optim:
    amp:
      enabled: True
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: sam3.train.optim.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
        _partial_: True
        layer_decay_value: ${scratch.lrd_vision_backbone}
        apply_to: 'backbone.vision_backbone.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0

    options:
      lr:
        - scheduler:  # transformer and class_embed
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_transformer}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_vision_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.vision_backbone.*'
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_language_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.language_backbone.*'

      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: ${scratch.wd}
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']

  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0  # 0: only the last checkpoint is saved.

  logging:
    tensorboard_writer:
      _target_: sam3.train.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: True
    wandb_writer: null
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================

launcher:
  num_nodes: 1
  gpus_per_node: 4
  experiment_log_dir: ${paths.experiment_log_dir}

submitit:
  use_cluster: True
  account: XXX
  partition: XXX
  qos: XXX
  timeout_hour: 96
  name: sam3
  cpus_per_task: 8
  port_range: [10000, 65000]
  
# ============================================================================
# Available Roboflow Supercategories (for reference)
# ============================================================================

all_roboflow_supercategories:
  - -grccs
  - zebrasatasturias
  ...
  ...
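The validation pipeline in this config resizes to scratch.resolution (1008) and normalizes with mean = std = [0.5, 0.5, 0.5], so inference preprocessing should apply the same mapping. A NumPy sketch of that normalization, assuming ToTensorAPI scales uint8 pixels to [0, 1] as torchvision's ToTensor does; `normalize_like_config` is a hypothetical helper, not part of SAM3. With mean and std of 0.5 the mapping is (x / 255 - 0.5) / 0.5, so pixel value 0 becomes -1 and 255 becomes +1.

```python
import numpy as np


def normalize_like_config(image_hwc: np.ndarray,
                          mean=(0.5, 0.5, 0.5),
                          std=(0.5, 0.5, 0.5)) -> np.ndarray:
    """Map an HxWx3 uint8 image to a 3xHxW float32 array: (x / 255 - mean) / std."""
    if image_hwc.dtype != np.uint8 or image_hwc.ndim != 3 or image_hwc.shape[2] != 3:
        raise ValueError("expected an HxWx3 uint8 RGB array")
    x = image_hwc.astype(np.float32) / 255.0  # assumed ToTensor-style scaling to [0, 1]
    x = (x - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
    return np.transpose(x, (2, 0, 1))  # HWC -> CHW channel order
```

Comparing the value range of the tensor fed to the model at training time with the one produced at inference time is a cheap way to rule out a preprocessing mismatch.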

INFERENCE SCRIPT

import os
import sys
import matplotlib.pyplot as plt
import numpy as np

import sam3
from PIL import Image
from sam3 import build_sam3_image_model
from sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results
from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
from sam3.model.position_encoding import PositionEmbeddingSine
from sam3.eval.postprocessors import PostProcessImage

import torch
import torchvision
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

#########################
##   MAIN
#########################
sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
print("sam3_root")
print(sam3_root)
device = "cuda" if torch.cuda.is_available() else "cpu"

# - Build model
print("Loading model ...")
bpe_path = f"{sam3_root}/assets/bpe_simple_vocab_16e6.txt.gz"
checkpoint_path = "[RUN DIR]/checkpoints/checkpoint.pt"

model = build_sam3_image_model(
    bpe_path=bpe_path,
    device=device,
    eval_mode=True,
    checkpoint_path=checkpoint_path,
    load_from_HF=False,
    enable_segmentation=True,
    enable_inst_interactivity=False,
    compile=False,
)

# - Load image
print("Loading image ...")
image_path = "sidelobe0001.png"
image = Image.open(image_path).convert('RGB')
width, height = image.size
print(f"Image width={width}, height={height}")

# - Transform image
print("Transforming image ...")
resize_size = 1008
processor = Sam3Processor(model, resolution=resize_size, confidence_threshold=0.0)
inference_state = processor.set_image(image)  # set_image resizes the image internally (per the Sam3Processor code)

# - Inference
prompt = "spurious source, imaging artefact, sidelobe"
processor.reset_all_prompts(inference_state)
inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)

masks, boxes, scores = inference_state["masks"], inference_state["boxes"], inference_state["scores"]

# - Draw results
img0 = Image.open(image_path)
plot_results(img0, inference_state)
plt.show()
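To tell whether the model returns no candidates at all or only low-scoring ones, it can help to summarize the scores after `set_text_prompt`. The sketch below assumes `inference_state["scores"]` is a 1-D sequence of floats (for a 1-D tensor, iterating and calling `float()` on each element works); `score_summary` is a hypothetical helper.

```python
from typing import Dict, Iterable, Sequence


def score_summary(scores: Iterable[float],
                  thresholds: Sequence[float] = (0.05, 0.1, 0.3, 0.5)) -> Dict[str, object]:
    """Summarize detection scores: count, maximum, and how many exceed each threshold."""
    vals = sorted((float(s) for s in scores), reverse=True)
    return {
        "num_detections": len(vals),
        "max_score": vals[0] if vals else None,
        "count_above": {t: sum(v > t for v in vals) for t in thresholds},
    }
```

For example, `print(score_summary(inference_state["scores"]))` right after the prompt is set. An empty result even with `confidence_threshold=0.0` points at the model weights rather than at the threshold.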

Posted by simoneriggi on Dec 10 '25 12:12