sam3

Recommended resolution & min_size settings for lower GPU memory usage while fine-tuning

Joey-S-Liu opened this issue 3 months ago.

Dear authors,

I would like to change the training resolution (currently 1008) and the min_size (currently 480) to reduce GPU memory usage. However, due to the complexity of the transform pipeline, I have not been able to get the code to run correctly after modifying these values.

Could you please suggest some alternative combinations of resolution and min_size that are known to work with the current transform implementation?

Thank you very much!

(Joey-S-Liu, Nov 28 '25)

@Joey-S-Liu Have you solved it?

(csqqlee, Dec 04 '25)

Changing the resolution to any other value seems to break the code; the numbers in the RoPE script are hard-coded! I have tried multiples of 14 and 32, but they did not work. Surprisingly, using a higher batch size (16 in my case) avoids the OOM error.
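To see why resolution is the main lever for memory, here is a rough sketch that converts a square input resolution into a ViT patch grid. It assumes a patch size of 14 (so the default 1008 gives a 72 x 72 grid) and, following the comment above, treats the RoPE values as built for that default grid only; both the patch size and that interpretation are assumptions, not confirmed from the sam3 source. The attention figure uses the usual quadratic scaling of global self-attention in the number of tokens.

```python
# Rough estimate of how input resolution affects the ViT token count.
# ASSUMPTIONS (not confirmed from the sam3 source): patch size 14, and
# RoPE values precomputed for the default 1008 x 1008 input.

PATCH_SIZE = 14          # assumed ViT patch size
DEFAULT_RESOLUTION = 1008


def patch_grid(resolution: int, patch_size: int = PATCH_SIZE) -> int:
    """Return tokens per side for a square input; raise if it does not tile evenly."""
    if resolution % patch_size != 0:
        raise ValueError(f"{resolution} is not a multiple of patch size {patch_size}")
    return resolution // patch_size


def relative_cost(resolution: int) -> tuple[float, float]:
    """Token count and global-attention cost relative to the default resolution.

    Token count scales with side**2; global self-attention scales with tokens**2.
    """
    base_tokens = patch_grid(DEFAULT_RESOLUTION) ** 2
    tokens = patch_grid(resolution) ** 2
    token_ratio = tokens / base_tokens
    return token_ratio, token_ratio ** 2


if __name__ == "__main__":
    default_side = patch_grid(DEFAULT_RESOLUTION)
    for res in (1008, 840, 672, 504):
        side = patch_grid(res)
        tok, attn = relative_cost(res)
        print(f"{res}: {side}x{side} grid, {tok:.2f}x tokens, {attn:.2f}x global attention, "
              f"matches default grid: {side == default_side}")
```

On these assumptions, 672 would cut the token count to about 0.44x of the default, but every resolution other than 1008 yields a grid that differs from 72 x 72, which is consistent with hard-coded RoPE values rejecting other sizes. The attention column only applies to layers that attend globally.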

(BahMoh, Dec 18 '25)

@BahMoh Hello, I ran into the same problem, but using a higher batch size (16) didn't work for me. Could you tell me more about your setup and configuration?

(meng-yuan321, Dec 24 '25)

In my case, using a higher batch size (e.g. batch_size=16) resulted in an empty train dataloader. I think the cause was the small num_images hyperparameter (e.g. num_images=20): combined with the bigger batch size and drop_last = True, it left the train loader empty. That was the real problem; I was just happy not to be getting an OOM error :)
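The empty-loader effect follows from how PyTorch counts batches. The sketch below mirrors, in plain Python, the length rules of `torch.utils.data.BatchSampler` (floor division with `drop_last=True`, ceiling otherwise) and of `DistributedSampler` with its default `drop_last=False` (each rank gets `ceil(n / world_size)` samples). Whether the sam3 trainer actually shards with `DistributedSampler`, and whether one dataset item equals one image, are assumptions; `batches_per_epoch` and `effective_batch_size` are hypothetical helper names.

```python
import math


def batches_per_epoch(num_items: int, batch_size: int, drop_last: bool,
                      world_size: int = 1) -> int:
    """Number of batches each rank's DataLoader yields per epoch.

    ASSUMPTION: data is sharded like torch's DistributedSampler (drop_last=False),
    which gives every rank ceil(num_items / world_size) samples.
    """
    per_rank = math.ceil(num_items / world_size)
    if drop_last:
        return per_rank // batch_size          # incomplete final batch is discarded
    return math.ceil(per_rank / batch_size)    # incomplete final batch is kept


def effective_batch_size(batch_size: int, world_size: int, grad_accum_steps: int) -> int:
    """Samples contributing to one optimizer step across all GPUs."""
    return batch_size * world_size * grad_accum_steps
```

For example, `batches_per_epoch(20, 16, drop_last=True)` is 1 on a single GPU, but with `world_size=2` (the config below uses `gpus_per_node: 2`) it drops to 0: no training steps run, so no OOM occurs. Conversely, `effective_batch_size(4, 2, 8)` equals `effective_batch_size(32, 2, 1)`, so lowering `train_batch_size` while raising `gradient_accumulation_steps` is one way to reduce activation memory at a comparable step size, assuming the loss normalization (e.g. `scale_by_find_batch_size`) does not change the result.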

(BahMoh, Dec 24 '25)

This is the configuration I used for training (at least it does not throw any errors other than OOM). My task is segmentation:

# @package _global_
defaults:
  - _self_

# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
  roboflow_vl_100_root: /kaggle/working
  experiment_log_dir: /kaggle/working/log
  bpe_path: /kaggle/working/sam3/assets/bpe_simple_vocab_16e6.txt.gz # This should be under assets/bpe_simple_vocab_16e6.txt.gz

# Roboflow dataset configuration
roboflow_train:
  num_images: 50 # Note: This is the number of images used for training. If null, all images are used.
  supercategory: Dental_AI_3  #${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}}

  # Training transforms pipeline
  # train_transforms:
  #   - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
  #     transforms:
  #       - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
  #         sizes: ${scratch.resolution}
  #         max_size:
  #           _target_: sam3.train.transforms.basic.get_random_resize_max_size
  #           size: ${scratch.resolution}
  #         square: true
  #         consistent_transform: False
  #       - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
  #       - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
  #         mean: ${scratch.train_norm_mean}
  #         std: ${scratch.train_norm_std}
  train_transforms:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
        - _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
          box_noise_std: 0.1
          box_noise_max: 20
        - _target_: sam3.train.transforms.segmentation.DecodeRle
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes:
            _target_: sam3.train.transforms.basic.get_random_resize_scales
            size: ${scratch.resolution}
            min_size: 480
            rounded: false
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution}
          square: true
          consistent_transform: ${scratch.consistent_transform}
        - _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
          size: ${scratch.resolution}
          consistent_transform: ${scratch.consistent_transform}
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.train_norm_mean}
          std: ${scratch.train_norm_std}
        - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
          query_filter:
            _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
    - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
      query_filter:
        _target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
        max_num_objects: ${scratch.max_ann_per_img}

  # Validation transforms pipeline
  val_transforms:
    - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
      transforms:
        - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
          sizes: ${scratch.resolution}
          max_size:
            _target_: sam3.train.transforms.basic.get_random_resize_max_size
            size: ${scratch.resolution}
          square: true
          consistent_transform: False
        - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
        - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
          mean: ${scratch.train_norm_mean}
          std: ${scratch.train_norm_std}

  # loss config (no mask loss)
  # loss:
  #   _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
  #   matcher: ${scratch.matcher}
  #   o2m_weight: 2.0
  #   o2m_matcher:
  #     _target_: sam3.train.matcher.BinaryOneToManyMatcher
  #     alpha: 0.3
  #     threshold: 0.4
  #     topk: 4
  #   use_o2m_matcher_on_o2m_aux: false # Another option is true
  #   loss_fns_find:
  #     - _target_: sam3.train.loss.loss_fns.Boxes
  #       weight_dict:
  #         loss_bbox: 5.0
  #         loss_giou: 2.0
  #     - _target_: sam3.train.loss.loss_fns.IABCEMdetr
  #       weak_loss: False
  #       weight_dict:
  #         loss_ce: 20.0 # Another option is 100.0
  #         presence_loss: 20.0
  #       pos_weight: 10.0 # Another option is 5.0
  #       alpha: 0.25
  #       gamma: 2
  #       use_presence: True  # Change
  #       pos_focal: false
  #       pad_n_queries: 200
  #       pad_scale_pos: 1.0

  #   loss_fn_semantic_seg: null
  #   scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}


  # NOTE: Loss to be used for training in case of segmentation
  loss:
    _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
    matcher: ${scratch.matcher}
    o2m_weight: 2.0
    o2m_matcher:
      _target_: sam3.train.matcher.BinaryOneToManyMatcher
      alpha: 0.3
      threshold: 0.4
      topk: 4
    use_o2m_matcher_on_o2m_aux: false
    loss_fns_find:
      - _target_: sam3.train.loss.loss_fns.Boxes
        weight_dict:
          loss_bbox: 5.0
          loss_giou: 2.0
      
      - _target_: sam3.train.loss.loss_fns.IABCEMdetr
        weak_loss: False
        weight_dict:
          loss_ce: 20.0 # Another option is 100.0
          presence_loss: 20.0
        pos_weight: 10.0 # Another option is 5.0
        alpha: 0.25
        gamma: 2
        use_presence: True  # Change
        pos_focal: false
        pad_n_queries: 200
        pad_scale_pos: 1.0
      - _target_: sam3.train.loss.loss_fns.Masks
        focal_alpha: 0.25
        focal_gamma: 2.0
        weight_dict:
          loss_mask: 200.0
          loss_dice: 10.0
        compute_aux: false
    loss_fn_semantic_seg:
      _target_: sam3.train.loss.loss_fns.SemanticSegCriterion
      presence_head: True
      presence_loss: False  # Change
      focal: True
      focal_alpha: 0.6
      focal_gamma: 2.0
      downsample: False
      weight_dict:
        loss_semantic_seg: 20.0
        loss_semantic_presence: 1.0
        loss_semantic_dice: 30.0
    scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}

# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
  enable_segmentation: true # NOTE: Set to true to train with segmentation masks
  # Model parameters
  d_model: 256
  pos_embed:
    _target_: sam3.model.position_encoding.PositionEmbeddingSine
    num_pos_feats: ${scratch.d_model}
    normalize: true
    scale: null
    temperature: 10000

  # Box processing
  use_presence_eval: True
  original_box_postprocessor:
    _target_: sam3.eval.postprocessors.PostProcessImage
    max_dets_per_img: -1  # infinite detections
    use_original_ids: true
    use_original_sizes_box: true
    use_presence: ${scratch.use_presence_eval}

  # Matcher configuration
  matcher:
    _target_: sam3.train.matcher.BinaryHungarianMatcherV2
    focal: true  # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
    cost_class: 2.0
    cost_bbox: 5.0
    cost_giou: 2.0
    alpha: 0.25
    gamma: 2
    stable: False
  scale_by_find_batch_size: True

  # Image processing parameters
  resolution: 1008
  consistent_transform: False
  max_ann_per_img: 200

  # Normalization parameters
  train_norm_mean: [0.5, 0.5, 0.5]
  train_norm_std: [0.5, 0.5, 0.5]
  val_norm_mean: [0.5, 0.5, 0.5]
  val_norm_std: [0.5, 0.5, 0.5]

  # Training parameters
  num_train_workers: 4
  num_val_workers: 0
  max_data_epochs: 20
  target_epoch_size: 1500
  hybrid_repeats: 1
  context_length: 2
  gather_pred_via_filesys: false

  # Learning rate and scheduler parameters
  lr_scale: 0.1
  lr_transformer: ${times:8e-4,${scratch.lr_scale}}
  lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
  lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
  lrd_vision_backbone: 0.9
  wd: 0.1
  scheduler_timescale: 20
  scheduler_warmup: 20
  scheduler_cooldown: 20

  val_batch_size: 1
  collate_fn_val:
    _target_: sam3.train.data.collator.collate_fn_api
    _partial_: true
    repeats: ${scratch.hybrid_repeats}
    dict_key: roboflow100
    with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!

  gradient_accumulation_steps: 1
  train_batch_size: 32
  collate_fn:
    _target_: sam3.train.data.collator.collate_fn_api
    _partial_: true
    repeats: ${scratch.hybrid_repeats}
    dict_key: all
    with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!

# ============================================================================
# Trainer Configuration
# ============================================================================

trainer:
  # checkpoint:
  #   model_weight_initializer:
  #     # This is the wrapper class, NOT the utility function
  #     _target_: sam3.train.model_weight_initializer.ModelWeightInitializer
  #     # It usually takes 'path' or 'weights_path', not 'path_list'
  #     path: "/kaggle/working/sam3_checkpoints/sam3.pt"
      
  #   save_dir: "/kaggle/working/sam3_checkpoints/sam3.pt"
  #   save_freq: 1  # 0 means only the last checkpoint is saved.

  _target_: sam3.train.trainer.Trainer
  skip_saving_ckpts: False
  empty_gpu_mem_cache_after_eval: True
  skip_first_val: True
  max_epochs: 20
  accelerator: cuda
  seed_value: 123
  val_epoch_freq: 10
  mode: train
  gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}

  distributed:
    backend: nccl
    find_unused_parameters: True
    gradient_as_bucket_view: True

  loss:
    all: ${roboflow_train.loss}
    roboflow100: ${roboflow_train.loss}
    # default:
    #   _target_: sam3.train.loss.sam3_loss.DummyLoss

  data:
    train:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        limit_ids: ${roboflow_train.num_images}
        transforms: ${roboflow_train.train_transforms}
        load_segmentation: ${scratch.enable_segmentation}
        coco_json_loader: 
          _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
          include_negatives: false
          category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
          _partial_: true
        max_ann_per_img: 500000
        multiplier: 1
        max_train_queries: 50000
        max_val_queries: 50000
        training: true
        use_caching: False
        img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
        ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json #_annotations.coco.json

      shuffle: True
      batch_size: ${scratch.train_batch_size}
      num_workers: ${scratch.num_train_workers}
      pin_memory: True
      drop_last: False
      collate_fn: ${scratch.collate_fn}

    val:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        limit_ids: ${roboflow_train.num_images}
        load_segmentation: ${scratch.enable_segmentation}
        coco_json_loader: 
          _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
          include_negatives: false
          category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
          _partial_: true
        img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
        ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
        transforms: ${roboflow_train.val_transforms}
        max_ann_per_img: 100000
        multiplier: 1
        training: false

      shuffle: False
      batch_size: ${scratch.val_batch_size}
      num_workers: ${scratch.num_val_workers}
      pin_memory: True
      drop_last: False
      collate_fn: ${scratch.collate_fn_val}


  model:
    _target_: sam3.model_builder.build_sam3_image_model
    bpe_path: ${paths.bpe_path}
    device: cpus
    load_from_HF: True
    eval_mode: false
    enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.

  meters:
    val:
      roboflow100:
        detection:
          _target_: sam3.eval.coco_writer.PredictionDumper
          iou_type: "bbox"
          dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory}
          merge_predictions: True
          postprocessor: ${scratch.original_box_postprocessor}
          gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
          maxdets: 100
          pred_file_evaluators:
            - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
              gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
              tide: False
              iou_type: "bbox"

  optim:
    amp:
      enabled: True
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: sam3.train.optim.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
        _partial_: True
        layer_decay_value: ${scratch.lrd_vision_backbone}
        apply_to: 'backbone.vision_backbone.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0

    options:
      lr:
        - scheduler:  # transformer and class_embed
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_transformer}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_vision_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.vision_backbone.*'
        - scheduler:
            _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
            base_lr: ${scratch.lr_language_backbone}
            timescale: ${scratch.scheduler_timescale}
            warmup_steps: ${scratch.scheduler_warmup}
            cooldown_steps: ${scratch.scheduler_cooldown}
          param_names:
            - 'backbone.language_backbone.*'

      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: ${scratch.wd}
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']

  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0 # 0 means only the last checkpoint is saved.

  logging:
    tensorboard_writer:
      _target_: sam3.train.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: True
    wandb_writer: null
    log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory}
    log_freq: 10

# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================

launcher:
  num_nodes: 1
  gpus_per_node: 2
  experiment_log_dir: ${paths.experiment_log_dir}
  multiprocessing_context: forkserver

submitit:
  account: null
  partition: null
  qos: null
  timeout_hour: 72
  use_cluster: False
  cpus_per_task: 10
  port_range: [10000, 65000]
  constraint: null
  # Uncomment for job array configuration
  job_array:
    num_tasks: 1
    task_index: 0

# ============================================================================
# Available Roboflow Supercategories (for reference)
# ============================================================================

all_roboflow_supercategories:
  # - -grccs
  # - zebrasatasturias
  # - cod-mw-warzone
  # - canalstenosis
  # - label-printing-defect-version-2
  # - new-defects-in-wood
  # - orionproducts
  # - aquarium-combined
  # - varroa-mites-detection--test-set
  # - clashroyalechardetector
  # - stomata-cells
  # - halo-infinite-angel-videogame
  # - pig-detection
  # - urine-analysis1
  # - aerial-sheep
  # - orgharvest
  # - actions
  # - mahjong
  # - liver-disease
  # - needle-base-tip-min-max
  # - wheel-defect-detection
  # - aircraft-turnaround-dataset
  # - xray
  # - wildfire-smoke
  # - spinefrxnormalvindr
  # - ufba-425
  # - speech-bubbles-detection
  # - train
  # - pill
  # - truck-movement
  # - car-logo-detection
  # - inbreast
  # - sea-cucumbers-new-tiles
  # - uavdet-small
  # - penguin-finder-seg
  # - aerial-airport
  # - bibdetection
  # - taco-trash-annotations-in-context
  # - bees
  # - recode-waste
  # - screwdetectclassification
  # - wine-labels
  # - aerial-cows
  # - into-the-vale
  # - gwhd2021
  # - lacrosse-object-detection
  # - defect-detection
  # - dataconvert
  # - x-ray-id
  # - ball
  # - tube
  # - 2024-frc
  # - crystal-clean-brain-tumors-mri-dataset
  # - grapes-5
  # - human-detection-in-floods
  # - buoy-onboarding
  # - apoce-aerial-photographs-for-object-detection-of-construction-equipment
  # - l10ul502
  # - floating-waste
  # - deeppcb
  # - ism-band-packet-detection
  # - weeds4
  # - invoice-processing
  # - thermal-cheetah
  # - tomatoes-2
  # - marine-sharks
  # - peixos-fish
  # - sssod
  # - aerial-pool
  # - countingpills
  # - asphaltdistressdetection
  # - roboflow-trained-dataset
  # - everdaynew
  # - underwater-objects
  # - soda-bottles
  - dentalai
  # - jellyfish
  # - deepfruits
  # - activity-diagrams
  # - circuit-voltages
  # - all-elements
  # - macro-segmentation
  # - exploratorium-daphnia
  # - signatures
  # - conveyor-t-shirts
  # - fruitjes
  # - grass-weeds
  # - infraredimageofpowerequipment
  # - 13-lkc01
  # - wb-prova
  # - flir-camera-objects
  # - paper-parts
  # - football-player-detection
  # - trail-camera
  # - smd-components
  # - water-meter
  # - nih-xray
  # - the-dreidel-project
  # - electric-pylon-detection-in-rsi
  # - cable-damage
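One consequence of the training pipeline in the configuration above: `RandomResizeAPI` rescales each image (with `square: true`) to a random size drawn from scales between `min_size` and `scratch.resolution`, and `PadToSizeAPI` then pads it to `scratch.resolution`. If that reading is right, lowering `min_size` alone changes how much of the canvas is image content but not the tensor shape the model sees, so it should not reduce memory; only `resolution` does, and `min_size` must stay at or below it. The sketch below simulates that resize-then-pad step in plain Python. The scale list (steps of 32 from `min_size`, DETR-style) is a hypothetical stand-in for `get_random_resize_scales`, whose real implementation is not shown in this thread.

```python
import random


def hypothetical_resize_scales(size: int, min_size: int, step: int = 32) -> list[int]:
    """HYPOTHETICAL stand-in for get_random_resize_scales: min_size..size in steps."""
    if min_size > size:
        raise ValueError("min_size must not exceed the target resolution")
    scales = list(range(min_size, size, step))
    scales.append(size)  # always include the full resolution
    return scales


def resize_then_pad(resolution: int, min_size: int, rng: random.Random) -> tuple[int, int]:
    """Simulate a square resize to a random scale, then padding to `resolution`.

    Returns (content_side, padded_side): the resized image is content_side x content_side,
    and padding makes the tensor resolution x resolution regardless of min_size.
    """
    content_side = rng.choice(hypothetical_resize_scales(resolution, min_size))
    padded_side = resolution  # PadToSizeAPI target (size: ${scratch.resolution})
    return content_side, padded_side


if __name__ == "__main__":
    rng = random.Random(0)
    for min_size in (480, 256):
        padded = {resize_then_pad(1008, min_size, rng)[1] for _ in range(100)}
        print(f"min_size={min_size}: padded sides seen = {sorted(padded)}")
```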

(BahMoh, Dec 24 '25)