sam3

Sam 3 Segmentation Fine Tuning

Open MohitDAngrish opened this issue 2 weeks ago • 5 comments

Does sam3 not support segmentation fine-tuning at the moment?

When I use the following config:

    # @package _global_
    defaults:
      - _self_

    # ============================================================================
    # Paths Configuration (Change this to your own paths)
    # ============================================================================
    paths:
      roboflow_vl_100_root: /home/ubuntu/sam/custom-root
      experiment_log_dir: /home/ubuntu/sam/custom-root/logs
      bpe_path: /home/ubuntu/sam/sam3/assets/bpe_simple_vocab_16e6.txt.gz

    # Roboflow dataset configuration
    roboflow_train:
      num_images: 100 # Note: This is the number of images used for training. If null, all images are used.
      supercategory: ${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}}

      # Training transforms pipeline
      train_transforms:
        - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
          transforms:
            - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
              query_filter:
                _target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
            - _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
              box_noise_std: 0.1
              box_noise_max: 20
            - _target_: sam3.train.transforms.segmentation.DecodeRle
            - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
              sizes:
                _target_: sam3.train.transforms.basic.get_random_resize_scales
                size: ${scratch.resolution}
                min_size: 480
                rounded: false
              max_size:
                _target_: sam3.train.transforms.basic.get_random_resize_max_size
                size: ${scratch.resolution}
              square: true
              consistent_transform: ${scratch.consistent_transform}
            - _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
              size: ${scratch.resolution}
              consistent_transform: ${scratch.consistent_transform}
            - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
            - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
              query_filter:
                _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
            - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
              mean: ${scratch.train_norm_mean}
              std: ${scratch.train_norm_std}
            - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
              query_filter:
                _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
            - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
              query_filter:
                _target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
                max_num_objects: ${scratch.max_ann_per_img}

      # Validation transforms pipeline
      val_transforms:
        - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
          transforms:
            - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
              sizes: ${scratch.resolution}
              max_size:
                _target_: sam3.train.transforms.basic.get_random_resize_max_size
                size: ${scratch.resolution}
              square: true
              consistent_transform: False
            - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
            - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
              mean: ${scratch.train_norm_mean}
              std: ${scratch.train_norm_std}

      # loss config (no mask loss)
      # loss:
      #   _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
      #   matcher: ${scratch.matcher}
      #   o2m_weight: 2.0
      #   o2m_matcher:
      #     _target_: sam3.train.matcher.BinaryOneToManyMatcher
      #     alpha: 0.3
      #     threshold: 0.4
      #     topk: 4
      #   use_o2m_matcher_on_o2m_aux: false # Another option is true
      #   loss_fns_find:
      #     - _target_: sam3.train.loss.loss_fns.Boxes
      #       weight_dict:
      #         loss_bbox: 5.0
      #         loss_giou: 2.0
      #     - _target_: sam3.train.loss.loss_fns.IABCEMdetr
      #       weak_loss: False
      #       weight_dict:
      #         loss_ce: 20.0 # Another option is 100.0
      #         presence_loss: 20.0
      #       pos_weight: 10.0 # Another option is 5.0
      #       alpha: 0.25
      #       gamma: 2
      #       use_presence: True # Change
      #       pos_focal: false
      #       pad_n_queries: 200
      #       pad_scale_pos: 1.0
      #   loss_fn_semantic_seg: null
      #   scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}

      # NOTE: Loss to be used for training in case of segmentation
      loss:
        _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
        matcher: ${scratch.matcher}
        o2m_weight: 2.0
        o2m_matcher:
          _target_: sam3.train.matcher.BinaryOneToManyMatcher
          alpha: 0.3
          threshold: 0.4
          topk: 4
        use_o2m_matcher_on_o2m_aux: false
        loss_fns_find:
          - _target_: sam3.train.loss.loss_fns.Boxes
            weight_dict:
              loss_bbox: 5.0
              loss_giou: 2.0
          - _target_: sam3.train.loss.loss_fns.IABCEMdetr
            weak_loss: False
            weight_dict:
              loss_ce: 20.0 # Another option is 100.0
              presence_loss: 20.0
            pos_weight: 10.0 # Another option is 5.0
            alpha: 0.25
            gamma: 2
            use_presence: True # Change
            pos_focal: false
            pad_n_queries: 200
            pad_scale_pos: 1.0
          - _target_: sam3.train.loss.loss_fns.Masks
            focal_alpha: 0.25
            focal_gamma: 2.0
            weight_dict:
              loss_mask: 200.0
              loss_dice: 10.0
            compute_aux: false
        loss_fn_semantic_seg:
          _target_: sam3.train.loss.loss_fns.SemanticSegCriterion
          presence_head: True
          presence_loss: False # Change
          focal: True
          focal_alpha: 0.6
          focal_gamma: 2.0
          downsample: False
          weight_dict:
            loss_semantic_seg: 20.0
            loss_semantic_presence: 1.0
            loss_semantic_dice: 30.0
        scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}

    # ============================================================================
    # Different helper parameters and functions
    # ============================================================================
    scratch:
      enable_segmentation: True # NOTE: This is the number of queries used for segmentation

      # Model parameters
      d_model: 256
      pos_embed:
        _target_: sam3.model.position_encoding.PositionEmbeddingSine
        num_pos_feats: ${scratch.d_model}
        normalize: true
        scale: null
        temperature: 10000

      # Box processing
      use_presence_eval: True
      original_box_postprocessor:
        _target_: sam3.eval.postprocessors.PostProcessImage
        max_dets_per_img: -1 # infinite detections
        use_original_ids: true
        use_original_sizes_box: true
        use_presence: ${scratch.use_presence_eval}

      # Matcher configuration
      matcher:
        _target_: sam3.train.matcher.BinaryHungarianMatcherV2
        focal: true # with focal: true it is equivalent to BinaryFocalHungarianMatcher
        cost_class: 2.0
        cost_bbox: 5.0
        cost_giou: 2.0
        alpha: 0.25
        gamma: 2
        stable: False
      scale_by_find_batch_size: True

      # Image processing parameters
      resolution: 1008
      consistent_transform: False
      max_ann_per_img: 200

      # Normalization parameters
      train_norm_mean: [0.5, 0.5, 0.5]
      train_norm_std: [0.5, 0.5, 0.5]
      val_norm_mean: [0.5, 0.5, 0.5]
      val_norm_std: [0.5, 0.5, 0.5]

      # Training parameters
      num_train_workers: 2
      num_val_workers: 1
      max_data_epochs: 20
      target_epoch_size: 1500
      hybrid_repeats: 1
      context_length: 2
      gather_pred_via_filesys: false

      # Learning rate and scheduler parameters
      lr_scale: 0.1
      lr_transformer: ${times:8e-4,${scratch.lr_scale}}
      lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
      lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
      lrd_vision_backbone: 0.9
      wd: 0.1
      scheduler_timescale: 20
      scheduler_warmup: 20
      scheduler_cooldown: 20

      val_batch_size: 1
      collate_fn_val:
        _target_: sam3.train.data.collator.collate_fn_api
        _partial_: true
        repeats: ${scratch.hybrid_repeats}
        dict_key: roboflow100
        with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!

      gradient_accumulation_steps: 1
      train_batch_size: 1
      collate_fn:
        _target_: sam3.train.data.collator.collate_fn_api
        _partial_: true
        repeats: ${scratch.hybrid_repeats}
        dict_key: all
        with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!

    # ============================================================================
    # Trainer Configuration
    # ============================================================================
    trainer:
      _target_: sam3.train.trainer.Trainer
      skip_saving_ckpts: False
      empty_gpu_mem_cache_after_eval: True
      skip_first_val: False
      max_epochs: 20
      accelerator: cuda
      seed_value: 123
      val_epoch_freq: 1
      mode: train
      gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}

      distributed:
        backend: nccl
        find_unused_parameters: True
        gradient_as_bucket_view: True

      loss:
        all: ${roboflow_train.loss}
        default:
          _target_: sam3.train.loss.sam3_loss.DummyLoss

      data:
        train:
          _target_: sam3.train.data.torch_dataset.TorchDataset
          dataset:
            _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
            limit_ids: ${roboflow_train.num_images}
            transforms: ${roboflow_train.train_transforms}
            load_segmentation: ${scratch.enable_segmentation}
            max_ann_per_img: 500000
            multiplier: 1
            max_train_queries: 50000
            max_val_queries: 50000
            training: true
            use_caching: False
            img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/
            ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/_annotations.coco.json

          shuffle: True
          batch_size: ${scratch.train_batch_size}
          num_workers: ${scratch.num_train_workers}
          pin_memory: True
          drop_last: True
          collate_fn: ${scratch.collate_fn}

        val:
          _target_: sam3.train.data.torch_dataset.TorchDataset
          dataset:
            _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
            load_segmentation: ${scratch.enable_segmentation}
            coco_json_loader:
              _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
              include_negatives: true
              category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
              _partial_: true
            img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/val/
            ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/val/_annotations.coco.json
            transforms: ${roboflow_train.val_transforms}
            max_ann_per_img: 100000
            multiplier: 1
            training: false

          shuffle: False
          batch_size: ${scratch.val_batch_size}
          num_workers: ${scratch.num_val_workers}
          pin_memory: True
          drop_last: False
          collate_fn: ${scratch.collate_fn_val}

      model:
        _target_: sam3.model_builder.build_sam3_image_model
        bpe_path: ${paths.bpe_path}
        device: cpus
        eval_mode: False
        enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.

      meters:
        val:
          roboflow100:
            segmentation:
              _target_: sam3.eval.coco_writer.PredictionDumper
              iou_type: "segm"
              dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory}
              merge_predictions: True
              postprocessor: ${scratch.original_box_postprocessor}
              gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
              maxdets: 800
              pred_file_evaluators:
                - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
                  gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/val/_annotations.coco.json
                  tide: False
                  iou_type: "segm"

      optim:
        amp:
          enabled: True
          amp_dtype: bfloat16

        optimizer:
          _target_: torch.optim.AdamW

        gradient_clip:
          _target_: sam3.train.optim.optimizer.GradientClipper
          max_norm: 0.1
          norm_type: 2

        param_group_modifiers:
          - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
            _partial_: True
            layer_decay_value: ${scratch.lrd_vision_backbone}
            apply_to: 'backbone.vision_backbone.trunk'
            overrides:
              - pattern: '*pos_embed*'
                value: 1.0

        options:
          lr:
            - scheduler: # transformer and class_embed
                _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
                base_lr: ${scratch.lr_transformer}
                timescale: ${scratch.scheduler_timescale}
                warmup_steps: ${scratch.scheduler_warmup}
                cooldown_steps: ${scratch.scheduler_cooldown}
            - scheduler:
                _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
                base_lr: ${scratch.lr_vision_backbone}
                timescale: ${scratch.scheduler_timescale}
                warmup_steps: ${scratch.scheduler_warmup}
                cooldown_steps: ${scratch.scheduler_cooldown}
              param_names:
                - 'backbone.vision_backbone.*'
            - scheduler:
                _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
                base_lr: ${scratch.lr_language_backbone}
                timescale: ${scratch.scheduler_timescale}
                warmup_steps: ${scratch.scheduler_warmup}
                cooldown_steps: ${scratch.scheduler_cooldown}
              param_names:
                - 'backbone.language_backbone.*'

          weight_decay:
            - scheduler:
                _target_: fvcore.common.param_scheduler.ConstantParamScheduler
                value: ${scratch.wd}
            - scheduler:
                _target_: fvcore.common.param_scheduler.ConstantParamScheduler
                value: 0.0
              param_names:
                - '*bias*'
              module_cls_names: ['torch.nn.LayerNorm']

      checkpoint:
        save_dir: ${launcher.experiment_log_dir}/checkpoints
        save_freq: 5 # 0 only last checkpoint is saved.

      logging:
        tensorboard_writer:
          _target_: sam3.train.utils.logger.make_tensorboard_logger
          log_dir: ${launcher.experiment_log_dir}/tensorboard
          flush_secs: 120
          should_log: True
        wandb_writer: null
        log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory}
        log_freq: 10

    # ============================================================================
    # Launcher and Submitit Configuration
    # ============================================================================
    launcher:
      num_nodes: 1
      gpus_per_node: 1
      experiment_log_dir: ${paths.experiment_log_dir}
      multiprocessing_context: forkserver

    submitit:
      account: null
      partition: null
      qos: null
      timeout_hour: 72
      use_cluster: True
      cpus_per_task: 10
      port_range: [10000, 65000]
      constraint: null
      # Uncomment for job array configuration
      job_array:
        num_tasks: 1
        task_index: 0

    # ============================================================================
    # Available Roboflow Supercategories (for reference)
    # ============================================================================
    all_roboflow_supercategories:
      - z22

The run fails during validation with an error saying the predictions have no masks key. During training, however, the segmentation losses are present:

{"Losses/train_all_loss": 291.40562232284805, "Losses/train_default_loss": 0, "Losses/train_all_loss_bbox": 0.04907124418765307, "Losses/train_all_loss_giou": 0.3097055367380381, "Losses/train_all_loss_bbox_o2m": 0.25392176304012537, "Losses/train_all_loss_giou_o2m": 1.178743647634983, "Losses/train_all_loss_ce": 0.017840963501075748, "Losses/train_all_ce_f1": 0.23937933709472417, "Losses/train_all_presence_loss": 0.1593843568693046, "Losses/train_all_presence_dec_acc": 0.9228571724891662, "Losses/train_all_loss_ce_o2m": 0.09227114700712263, "Losses/train_all_ce_f1_o2m": 0.4906483779847622, "Losses/train_all_presence_loss_o2m": 0.0, "Losses/train_all_presence_dec_acc_o2m": 0.0, "Losses/train_all_loss_mask": 0.003172511872544419, "Losses/train_all_loss_dice": 0.29496058978140355, "Losses/train_all_loss_mask_o2m": 0.02098744154907763, "Losses/train_all_loss_dice_o2m": 1.1210279282927513, "Losses/train_all_loss_bbox_aux_0": 0.06682779693976044, "Losses/train_all_loss_giou_aux_0": 0.4081690326333046, "Losses/train_all_loss_bbox_aux_0_o2m": 0.14454417156055568, "Losses/train_all_loss_giou_aux_0_o2m": 0.8131714349985123, "Losses/train_all_loss_ce_aux_0": 0.016816309143323452, "Losses/train_all_ce_f1_aux_0": 0.2018275634199381, "Losses/train_all_presence_loss_aux_0": 0.15791587307904592, "Losses/train_all_presence_dec_acc_aux_0": 0.9207143145799637, "Losses/train_all_loss_ce_aux_0_o2m": 0.03715709789190441, "Losses/train_all_ce_f1_aux_0_o2m": 0.22506245568394662, "Losses/train_all_presence_loss_aux_0_o2m": 0.0, "Losses/train_all_presence_dec_acc_aux_0_o2m": 0.0, "Losses/train_all_loss_bbox_aux_1": 0.05492462095804512, "Losses/train_all_loss_giou_aux_1": 0.33395033814013003, "Losses/train_all_loss_bbox_aux_1_o2m": 0.12702802307903766, "Losses/train_all_loss_giou_aux_1_o2m": 0.7043332879245281,
"Losses/train_all_loss_ce_aux_1": 0.01780572710733395, "Losses/train_all_ce_f1_aux_1": 0.20777562160044907, "Losses/train_all_presence_loss_aux_1": 0.15562092261569888, "Losses/train_all_presence_dec_acc_aux_1": 0.9200000303983689, "Losses/train_all_loss_ce_aux_1_o2m": 0.03819340365938842, "Losses/train_all_ce_f1_aux_1_o2m": 0.2233561248332262, "Losses/train_all_presence_loss_aux_1_o2m": 0.0, "Losses/train_all_presence_dec_acc_aux_1_o2m": 0.0, "Losses/train_all_loss_bbox_aux_2": 0.05241847197525203, "Losses/train_all_loss_giou_aux_2": 0.3214942040294409, "Losses/train_all_loss_bbox_aux_2_o2m": 0.12400685776025057, "Losses/train_all_loss_giou_aux_2_o2m": 0.6849668900668621, "Losses/train_all_loss_ce_aux_2": 0.01778830372611992, "Losses/train_all_ce_f1_aux_2": 0.23926191195845603, "Losses/train_all_presence_loss_aux_2": 0.15706820492800033, "Losses/train_all_presence_dec_acc_aux_2": 0.9235714572668076, "Losses/train_all_loss_ce_aux_2_o2m": 0.03885760723147541, "Losses/train_all_ce_f1_aux_2_o2m": 0.23316837422549724, "Losses/train_all_presence_loss_aux_2_o2m": 0.0, "Losses/train_all_presence_dec_acc_aux_2_o2m": 0.0, "Losses/train_all_loss_bbox_aux_3": 0.05105132736265659, "Losses/train_all_loss_giou_aux_3": 0.31483356446027755, "Losses/train_all_loss_bbox_aux_3_o2m": 0.11894042745232582, "Losses/train_all_loss_giou_aux_3_o2m": 0.6671358375251293, "Losses/train_all_loss_ce_aux_3": 0.01781949920405168, "Losses/train_all_ce_f1_aux_3": 0.2319374280422926, "Losses/train_all_presence_loss_aux_3": 0.1578532818506028, "Losses/train_all_presence_dec_acc_aux_3": 0.9228571724891662, "Losses/train_all_loss_ce_aux_3_o2m": 0.039612381584011015, "Losses/train_all_ce_f1_aux_3_o2m": 0.22166858714073898, "Losses/train_all_presence_loss_aux_3_o2m": 0.0, "Losses/train_all_presence_dec_acc_aux_3_o2m": 0.0, "Losses/train_all_loss_bbox_aux_4": 0.049818813037127256, "Losses/train_all_loss_giou_aux_4": 0.311868184953928, "Losses/train_all_loss_bbox_aux_4_o2m": 0.11900744654238224, 
"Losses/train_all_loss_giou_aux_4_o2m": 0.6695742666721344, "Losses/train_all_loss_ce_aux_4": 0.01783945863484405, "Losses/train_all_ce_f1_aux_4": 0.2419732666015625, "Losses/train_all_presence_loss_aux_4": 0.15830876644788078, "Losses/train_all_presence_dec_acc_aux_4": 0.9214286017417908, "Losses/train_all_loss_ce_aux_4_o2m": 0.03997141381725669, "Losses/train_all_ce_f1_aux_4_o2m": 0.22639159649610519, "Losses/train_all_presence_loss_aux_4_o2m": 0.0, "Losses/train_all_presence_dec_acc_aux_4_o2m": 0.0, "Losses/train_all_core_loss": 291.40562232284805, "Losses/train_all_loss_semantic_presence": 0.0, "Losses/train_all_presence_acc": 0.0, "Losses/train_all_loss_semantic_seg": 0.03686944677960127, "Losses/train_all_loss_semantic_dice": 0.3965323495678604, "Losses/train_all_miou_semantic_seg": 0.5273188393190503, "Trainer/where": 0.0495, "Trainer/epoch": 0, "Trainer/steps_train": 100}
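To pull the segmentation terms out of a log line like the one above without scanning the whole dict by eye, the line can be parsed and filtered. This is a minimal stdlib sketch, assuming each log entry is available as a single JSON string; the keyword list is chosen from the key names in this log, not taken from sam3, and segmentation_terms is a hypothetical helper.

```python
import json

# Substrings that identify mask, dice and semantic-segmentation entries in the
# log keys shown in this issue (chosen from the key names, not from sam3 itself).
SEGMENTATION_KEYWORDS = ("loss_mask", "loss_dice", "semantic")


def segmentation_terms(log_line: str) -> dict[str, float]:
    """Return the segmentation-related entries of one JSON training-log line."""
    metrics = json.loads(log_line)
    return {
        # Drop the common prefix so the keys are easier to read.
        key.removeprefix("Losses/train_all_"): value
        for key, value in metrics.items()
        if key.startswith("Losses/") and any(word in key for word in SEGMENTATION_KEYWORDS)
    }
```

On the log above this selects loss_mask, loss_dice, their _o2m variants and the semantic-segmentation entries; only loss_semantic_presence is 0.0, which is consistent with presence_loss: False in the SemanticSegCriterion config.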

Do I need to change something in the meters section of the config?

MohitDAngrish, Dec 11 '25 05:12