Recommended resolution & min_size settings for lower GPU memory usage while fine-tuning
Dear authors,
I would like to change the training resolution (currently 1008) and the min_size (currently 480) to reduce GPU memory usage. However, due to the complexity of the transform pipeline, I have not been able to get the code to run correctly after modifying these values.
Could you please suggest some alternative combinations of resolution and min_size that are known to work with the current transform implementation?
Thank you very much!
@Joey-S-Liu Have you solved it?
Changing the resolution to any other number seems to break the code. The numbers in the RoPE script are hard-coded! I have tried multiples of 14 and 32, but they did not work. And surprisingly, using a higher batch size (16 in my case) avoids the OOM error.
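For reference, if the constraint really were just divisibility by both strides, valid resolutions would be multiples of lcm(14, 32) = 224 — but note the default 1008 is itself not divisible by 32, so the actual constraint (e.g. the hard-coded RoPE tables) is likely stricter. A quick sketch of the arithmetic:

```python
import math

# Hedged sketch: if the pipeline required the resolution to be divisible by
# both the ViT patch size (14) and a stride of 32, valid values would be
# multiples of lcm(14, 32). The default 1008 = 72 * 14 is NOT divisible by
# 32, so divisibility alone cannot be the whole story.
lcm = math.lcm(14, 32)
candidates = [lcm * k for k in range(1, 6)]
print(lcm)                   # 224
print(candidates)            # [224, 448, 672, 896, 1120]
print(1008 % 14, 1008 % 32)  # 0 16
```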
@BahMoh Hello, I ran into the same issue, but using a higher batch size (16) didn't work for me. Could you tell me more about your setup and configuration?
Using a higher batch size (e.g. batch_size=16) in my case resulted in an empty train dataloader.
I think the problem was the small num_images hyper-parameter (e.g. num_images=20): combined with a larger batch size and drop_last=True, it left the train loader empty.
That was the real issue, and I was just happy not to be getting an OOM error :)
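The interaction described above comes down to the batch-count rule that `torch.utils.data.DataLoader` applies, sketched here in plain arithmetic:

```python
# With drop_last=True the loader yields floor(num_images / batch_size)
# batches, so num_images < batch_size means zero batches; with
# drop_last=False the last partial batch is kept (ceil division).
def num_batches(num_images: int, batch_size: int, drop_last: bool) -> int:
    if drop_last:
        return num_images // batch_size
    return -(-num_images // batch_size)  # ceil division

print(num_batches(20, 16, drop_last=True))   # 1
print(num_batches(20, 32, drop_last=True))   # 0 -> empty train dataloader
print(num_batches(20, 32, drop_last=False))  # 1
```

So with num_images=20 and the config's train_batch_size of 32, drop_last=True would yield an empty loader, while drop_last=False keeps the single partial batch.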
This is the configuration I used for training (at least it does not throw errors other than OOM); my task is segmentation:
# @package _global_
defaults:
- _self_
# ============================================================================
# Paths Configuration (Change this to your own paths)
# ============================================================================
paths:
roboflow_vl_100_root: /kaggle/working
experiment_log_dir: /kaggle/working/log
bpe_path: /kaggle/working/sam3/assets/bpe_simple_vocab_16e6.txt.gz # This should be under assets/bpe_simple_vocab_16e6.txt.gz
# Roboflow dataset configuration
roboflow_train:
num_images: 50 # Note: This is the number of images used for training. If null, all images are used.
supercategory: Dental_AI_3 #${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}}
# Training transforms pipeline
# train_transforms:
# - _target_: sam3.train.transforms.basic_for_api.ComposeAPI
# transforms:
# - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
# sizes: ${scratch.resolution}
# max_size:
# _target_: sam3.train.transforms.basic.get_random_resize_max_size
# size: ${scratch.resolution}
# square: true
# consistent_transform: False
# - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
# - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
# mean: ${scratch.train_norm_mean}
# std: ${scratch.train_norm_std}
train_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterCrowds
- _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox
box_noise_std: 0.1
box_noise_max: 20
- _target_: sam3.train.transforms.segmentation.DecodeRle
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes:
_target_: sam3.train.transforms.basic.get_random_resize_scales
size: ${scratch.resolution}
min_size: 480
rounded: false
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: ${scratch.consistent_transform}
- _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI
size: ${scratch.resolution}
consistent_transform: ${scratch.consistent_transform}
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.train_norm_mean}
std: ${scratch.train_norm_std}
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets
- _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries
query_filter:
_target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut
max_num_objects: ${scratch.max_ann_per_img}
# Validation transforms pipeline
val_transforms:
- _target_: sam3.train.transforms.basic_for_api.ComposeAPI
transforms:
- _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI
sizes: ${scratch.resolution}
max_size:
_target_: sam3.train.transforms.basic.get_random_resize_max_size
size: ${scratch.resolution}
square: true
consistent_transform: False
- _target_: sam3.train.transforms.basic_for_api.ToTensorAPI
- _target_: sam3.train.transforms.basic_for_api.NormalizeAPI
mean: ${scratch.train_norm_mean}
std: ${scratch.train_norm_std}
# loss config (no mask loss)
# loss:
# _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
# matcher: ${scratch.matcher}
# o2m_weight: 2.0
# o2m_matcher:
# _target_: sam3.train.matcher.BinaryOneToManyMatcher
# alpha: 0.3
# threshold: 0.4
# topk: 4
# use_o2m_matcher_on_o2m_aux: false # Another option is true
# loss_fns_find:
# - _target_: sam3.train.loss.loss_fns.Boxes
# weight_dict:
# loss_bbox: 5.0
# loss_giou: 2.0
# - _target_: sam3.train.loss.loss_fns.IABCEMdetr
# weak_loss: False
# weight_dict:
# loss_ce: 20.0 # Another option is 100.0
# presence_loss: 20.0
# pos_weight: 10.0 # Another option is 5.0
# alpha: 0.25
# gamma: 2
# use_presence: True # Change
# pos_focal: false
# pad_n_queries: 200
# pad_scale_pos: 1.0
# loss_fn_semantic_seg: null
# scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
# NOTE: Loss to be used for training in case of segmentation
loss:
_target_: sam3.train.loss.sam3_loss.Sam3LossWrapper
matcher: ${scratch.matcher}
o2m_weight: 2.0
o2m_matcher:
_target_: sam3.train.matcher.BinaryOneToManyMatcher
alpha: 0.3
threshold: 0.4
topk: 4
use_o2m_matcher_on_o2m_aux: false
loss_fns_find:
- _target_: sam3.train.loss.loss_fns.Boxes
weight_dict:
loss_bbox: 5.0
loss_giou: 2.0
- _target_: sam3.train.loss.loss_fns.IABCEMdetr
weak_loss: False
weight_dict:
loss_ce: 20.0 # Another option is 100.0
presence_loss: 20.0
pos_weight: 10.0 # Another option is 5.0
alpha: 0.25
gamma: 2
use_presence: True # Change
pos_focal: false
pad_n_queries: 200
pad_scale_pos: 1.0
- _target_: sam3.train.loss.loss_fns.Masks
focal_alpha: 0.25
focal_gamma: 2.0
weight_dict:
loss_mask: 200.0
loss_dice: 10.0
compute_aux: false
loss_fn_semantic_seg:
_target_: sam3.train.loss.loss_fns.SemanticSegCriterion
presence_head: True
presence_loss: False # Change
focal: True
focal_alpha: 0.6
focal_gamma: 2.0
downsample: False
weight_dict:
loss_semantic_seg: 20.0
loss_semantic_presence: 1.0
loss_semantic_dice: 30.0
scale_by_find_batch_size: ${scratch.scale_by_find_batch_size}
# ============================================================================
# Different helper parameters and functions
# ============================================================================
scratch:
enable_segmentation: true # NOTE: Set to true to train with segmentation (mask) losses
# Model parameters
d_model: 256
pos_embed:
_target_: sam3.model.position_encoding.PositionEmbeddingSine
num_pos_feats: ${scratch.d_model}
normalize: true
scale: null
temperature: 10000
# Box processing
use_presence_eval: True
original_box_postprocessor:
_target_: sam3.eval.postprocessors.PostProcessImage
max_dets_per_img: -1 # infinite detections
use_original_ids: true
use_original_sizes_box: true
use_presence: ${scratch.use_presence_eval}
# Matcher configuration
matcher:
_target_: sam3.train.matcher.BinaryHungarianMatcherV2
focal: true # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher
cost_class: 2.0
cost_bbox: 5.0
cost_giou: 2.0
alpha: 0.25
gamma: 2
stable: False
scale_by_find_batch_size: True
# Image processing parameters
resolution: 1008
consistent_transform: False
max_ann_per_img: 200
# Normalization parameters
train_norm_mean: [0.5, 0.5, 0.5]
train_norm_std: [0.5, 0.5, 0.5]
val_norm_mean: [0.5, 0.5, 0.5]
val_norm_std: [0.5, 0.5, 0.5]
# Training parameters
num_train_workers: 4
num_val_workers: 0
max_data_epochs: 20
target_epoch_size: 1500
hybrid_repeats: 1
context_length: 2
gather_pred_via_filesys: false
# Learning rate and scheduler parameters
lr_scale: 0.1
lr_transformer: ${times:8e-4,${scratch.lr_scale}}
lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}}
lr_language_backbone: ${times:5e-5,${scratch.lr_scale}}
lrd_vision_backbone: 0.9
wd: 0.1
scheduler_timescale: 20
scheduler_warmup: 20
scheduler_cooldown: 20
val_batch_size: 1
collate_fn_val:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: roboflow100
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
gradient_accumulation_steps: 1
train_batch_size: 32
collate_fn:
_target_: sam3.train.data.collator.collate_fn_api
_partial_: true
repeats: ${scratch.hybrid_repeats}
dict_key: all
with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks!
# ============================================================================
# Trainer Configuration
# ============================================================================
trainer:
# checkpoint:
# model_weight_initializer:
# # This is the wrapper class, NOT the utility function
# _target_: sam3.train.model_weight_initializer.ModelWeightInitializer
# # It usually takes 'path' or 'weights_path', not 'path_list'
# path: "/kaggle/working/sam3_checkpoints/sam3.pt"
# save_dir: "/kaggle/working/sam3_checkpoints/sam3.pt"
# save_freq: 1 # 0 only last checkpoint is saved.
_target_: sam3.train.trainer.Trainer
skip_saving_ckpts: False
empty_gpu_mem_cache_after_eval: True
skip_first_val: True
max_epochs: 20
accelerator: cuda
seed_value: 123
val_epoch_freq: 10
mode: train
gradient_accumulation_steps: ${scratch.gradient_accumulation_steps}
distributed:
backend: nccl
find_unused_parameters: True
gradient_as_bucket_view: True
loss:
all: ${roboflow_train.loss}
roboflow100: ${roboflow_train.loss}
# default:
# _target_: sam3.train.loss.sam3_loss.DummyLoss
data:
train:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
limit_ids: ${roboflow_train.num_images}
transforms: ${roboflow_train.train_transforms}
load_segmentation: ${scratch.enable_segmentation}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
include_negatives: false
category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
_partial_: true
max_ann_per_img: 500000
multiplier: 1
max_train_queries: 50000
max_val_queries: 50000
training: true
use_caching: False
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json #_annotations.coco.json
shuffle: True
batch_size: ${scratch.train_batch_size}
num_workers: ${scratch.num_train_workers}
pin_memory: True
drop_last: False
collate_fn: ${scratch.collate_fn}
val:
_target_: sam3.train.data.torch_dataset.TorchDataset
dataset:
_target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
limit_ids: ${roboflow_train.num_images}
load_segmentation: ${scratch.enable_segmentation}
coco_json_loader:
_target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON
include_negatives: false
category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU.
_partial_: true
img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/
ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
transforms: ${roboflow_train.val_transforms}
max_ann_per_img: 100000
multiplier: 1
training: false
shuffle: False
batch_size: ${scratch.val_batch_size}
num_workers: ${scratch.num_val_workers}
pin_memory: True
drop_last: False
collate_fn: ${scratch.collate_fn_val}
model:
_target_: sam3.model_builder.build_sam3_image_model
bpe_path: ${paths.bpe_path}
device: cpu
load_from_HF: True
eval_mode: false
enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation.
meters:
val:
roboflow100:
detection:
_target_: sam3.eval.coco_writer.PredictionDumper
iou_type: "bbox"
dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory}
merge_predictions: True
postprocessor: ${scratch.original_box_postprocessor}
gather_pred_via_filesys: ${scratch.gather_pred_via_filesys}
maxdets: 100
pred_file_evaluators:
- _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators
gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json
tide: False
iou_type: "bbox"
optim:
amp:
enabled: True
amp_dtype: bfloat16
optimizer:
_target_: torch.optim.AdamW
gradient_clip:
_target_: sam3.train.optim.optimizer.GradientClipper
max_norm: 0.1
norm_type: 2
param_group_modifiers:
- _target_: sam3.train.optim.optimizer.layer_decay_param_modifier
_partial_: True
layer_decay_value: ${scratch.lrd_vision_backbone}
apply_to: 'backbone.vision_backbone.trunk'
overrides:
- pattern: '*pos_embed*'
value: 1.0
options:
lr:
- scheduler: # transformer and class_embed
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_transformer}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_vision_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.vision_backbone.*'
- scheduler:
_target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler
base_lr: ${scratch.lr_language_backbone}
timescale: ${scratch.scheduler_timescale}
warmup_steps: ${scratch.scheduler_warmup}
cooldown_steps: ${scratch.scheduler_cooldown}
param_names:
- 'backbone.language_backbone.*'
weight_decay:
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: ${scratch.wd}
- scheduler:
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
value: 0.0
param_names:
- '*bias*'
module_cls_names: ['torch.nn.LayerNorm']
checkpoint:
save_dir: ${launcher.experiment_log_dir}/checkpoints
save_freq: 0 # 0 = only the last checkpoint is saved.
logging:
tensorboard_writer:
_target_: sam3.train.utils.logger.make_tensorboard_logger
log_dir: ${launcher.experiment_log_dir}/tensorboard
flush_secs: 120
should_log: True
wandb_writer: null
log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory}
log_freq: 10
# ============================================================================
# Launcher and Submitit Configuration
# ============================================================================
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
multiprocessing_context: forkserver
submitit:
account: null
partition: null
qos: null
timeout_hour: 72
use_cluster: False
cpus_per_task: 10
port_range: [10000, 65000]
constraint: null
# Uncomment for job array configuration
job_array:
num_tasks: 1
task_index: 0
# ============================================================================
# Available Roboflow Supercategories (for reference)
# ============================================================================
all_roboflow_supercategories:
# - -grccs
# - zebrasatasturias
# - cod-mw-warzone
# - canalstenosis
# - label-printing-defect-version-2
# - new-defects-in-wood
# - orionproducts
# - aquarium-combined
# - varroa-mites-detection--test-set
# - clashroyalechardetector
# - stomata-cells
# - halo-infinite-angel-videogame
# - pig-detection
# - urine-analysis1
# - aerial-sheep
# - orgharvest
# - actions
# - mahjong
# - liver-disease
# - needle-base-tip-min-max
# - wheel-defect-detection
# - aircraft-turnaround-dataset
# - xray
# - wildfire-smoke
# - spinefrxnormalvindr
# - ufba-425
# - speech-bubbles-detection
# - train
# - pill
# - truck-movement
# - car-logo-detection
# - inbreast
# - sea-cucumbers-new-tiles
# - uavdet-small
# - penguin-finder-seg
# - aerial-airport
# - bibdetection
# - taco-trash-annotations-in-context
# - bees
# - recode-waste
# - screwdetectclassification
# - wine-labels
# - aerial-cows
# - into-the-vale
# - gwhd2021
# - lacrosse-object-detection
# - defect-detection
# - dataconvert
# - x-ray-id
# - ball
# - tube
# - 2024-frc
# - crystal-clean-brain-tumors-mri-dataset
# - grapes-5
# - human-detection-in-floods
# - buoy-onboarding
# - apoce-aerial-photographs-for-object-detection-of-construction-equipment
# - l10ul502
# - floating-waste
# - deeppcb
# - ism-band-packet-detection
# - weeds4
# - invoice-processing
# - thermal-cheetah
# - tomatoes-2
# - marine-sharks
# - peixos-fish
# - sssod
# - aerial-pool
# - countingpills
# - asphaltdistressdetection
# - roboflow-trained-dataset
# - everdaynew
# - underwater-objects
# - soda-bottles
- dentalai
# - jellyfish
# - deepfruits
# - activity-diagrams
# - circuit-voltages
# - all-elements
# - macro-segmentation
# - exploratorium-daphnia
# - signatures
# - conveyor-t-shirts
# - fruitjes
# - grass-weeds
# - infraredimageofpowerequipment
# - 13-lkc01
# - wb-prova
# - flir-camera-objects
# - paper-parts
# - football-player-detection
# - trail-camera
# - smd-components
# - water-meter
# - nih-xray
# - the-dreidel-project
# - electric-pylon-detection-in-rsi
# - cable-damage
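On the memory side of the original question: for a ViT-style backbone the token count (and hence attention memory) grows with the square of the resolution, so even one step down pays off substantially. A rough comparison, assuming a patch size of 14 (an assumption — check the actual backbone config):

```python
# Rough token-count comparison for candidate resolutions, assuming a
# ViT-style backbone with patch size 14 (an assumption -- check the real
# backbone config). Attention cost scales with the square of this count.
PATCH = 14

def num_tokens(resolution: int, patch: int = PATCH) -> int:
    assert resolution % patch == 0, "resolution must be divisible by patch size"
    side = resolution // patch
    return side * side

for res in (1008, 896, 672):
    print(res, num_tokens(res))
# 1008 -> 5184 tokens, 896 -> 4096, 672 -> 2304 (~2.25x fewer than 1008)
```

This only quantifies why a lower resolution helps; it says nothing about which values the hard-coded RoPE tables actually accept, which is the blocker reported above.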