sam3 icon indicating copy to clipboard operation
sam3 copied to clipboard

SAM3 video training config file

Open machlovi opened this issue 2 weeks ago • 3 comments

Hi, has anyone found a workaround for video training? All of the config files for videos are intended for evaluation only. I have figured out how to start training, but the issue is that model_builder.py currently only supports an image model for training. With some tweaks I was able to start the training, but the image model has an assertion for num_samples=1, which does not work in the case of videos. The approach below seems to be working for me, but I would appreciate it if anyone could verify it. The overall takeaway is that the SAM3 image model can't process video frames, so we have to build a model using a pipeline similar to the one for Sam3Image, while allowing it to process multiple frames for video training.

This is what I did to get training working. 1. In the roboflow_c100_full_ft100.yaml file:


  data:
    train:
      _target_: sam3.train.data.torch_dataset.TorchDataset
      dataset:
        # Original image dataset target, kept commented out for reference:
        # _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset
        # Video dataset target used instead:
        _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset

  model:
    _target_: sam3.model_builder.build_sam3_image_video_model  # ← Use new builder
    bpe_path: ${paths.bpe_path}
    device: cpu
    eval_mode: false
    enable_segmentation: ${scratch.enable_segmentation}
    # With checkpoint_path null and load_from_HF false, the builder loads no checkpoint.
    checkpoint_path: null
    load_from_HF: false  # Set to true if you want to load pretrained weights
    # Extra arguments that build_sam3_image_video_model forwards to the model.
    async_all_gather: true
    gather_backbone_out: true

2. Add the following function to model_builder.py:

def build_sam3_image_video_model(
    bpe_path=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
    eval_mode=False,  # Default to False for training
    checkpoint_path=None,
    load_from_HF=True,
    enable_segmentation=True,
    enable_inst_interactivity=False,
    compile=False,
    async_all_gather=True,
    gather_backbone_out=None,
):
    """
    Build SAM3 image model with multi-GPU video support for training.

    This is the same as build_sam3_image_model but returns Sam3ImageOnVideoMultiGPU
    instead of Sam3Image, allowing training with multiple frames.

    Args:
        bpe_path: Path to the BPE tokenizer vocabulary. When None, falls back to
            "../assets/bpe_simple_vocab_16e6.txt.gz" relative to this module.
        device: Device to load the model on ('cuda' or 'cpu')
        eval_mode: Whether to set the model to evaluation mode. When False, a
            BinaryHungarianMatcherV2 matcher is created and passed to the model.
        checkpoint_path: Optional path to model checkpoint
        load_from_HF: When True and checkpoint_path is None, obtain the checkpoint
            path from download_ckpt_from_hf()
        enable_segmentation: Whether to enable segmentation head
        enable_inst_interactivity: Whether to enable instance interactivity
        compile: Whether to compile the model
        async_all_gather: Enable async all-gather for multi-GPU (video specific)
        gather_backbone_out: Whether to gather backbone features (video specific)

    Returns:
        Sam3ImageOnVideoMultiGPU: A SAM3 image model with multi-frame support
    """
    if bpe_path is None:
        bpe_path = os.path.join(
            os.path.dirname(__file__), "..", "assets", "bpe_simple_vocab_16e6.txt.gz"
        )

    # Create visual components
    compile_mode = "default" if compile else None
    vision_encoder = _create_vision_backbone(
        compile_mode=compile_mode, enable_inst_interactivity=enable_inst_interactivity
    )

    # Create text components
    text_encoder = _create_text_encoder(bpe_path)

    # Create visual-language backbone
    backbone = _create_vl_backbone(vision_encoder, text_encoder)

    # Create transformer components
    transformer = _create_sam3_transformer()

    # Create dot product scoring
    dot_prod_scoring = _create_dot_product_scoring()

    # Create segmentation head if enabled
    segmentation_head = (
        _create_segmentation_head(compile_mode=compile_mode)
        if enable_segmentation
        else None
    )

    # Create geometry encoder
    input_geometry_encoder = _create_geometry_encoder()

    # Create instance interactivity predictor if enabled
    if enable_inst_interactivity:
        sam3_pvs_base = build_tracker(apply_temporal_disambiguation=False)
        inst_predictor = SAM3InteractiveImagePredictor(sam3_pvs_base)
    else:
        inst_predictor = None

    # Create matcher for training
    matcher = None
    if not eval_mode:
        # Imported lazily so that the training package is only needed when training
        from sam3.train.matcher import BinaryHungarianMatcherV2

        matcher = BinaryHungarianMatcherV2(
            focal=True,
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            alpha=0.25,
            gamma=2,
            stable=False,
        )

    # KEY DIFFERENCE: Use Sam3ImageOnVideoMultiGPU instead of Sam3Image
    model = Sam3ImageOnVideoMultiGPU(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
        inst_interactive_predictor=inst_predictor,
        matcher=matcher,
        # Video-specific parameters
        async_all_gather=async_all_gather,
        gather_backbone_out=gather_backbone_out,
    )

    # Load checkpoint if provided
    if load_from_HF and checkpoint_path is None:
        checkpoint_path = download_ckpt_from_hf()

    if checkpoint_path is not None:
        _load_checkpoint(model, checkpoint_path)

    # Setup device and mode
    model = _setup_device_and_mode(model, device, eval_mode)

    return model

3. Add a class in sam3.model.sam3_image.py:

class Sam3ImageOnVideoMultiGPU(Sam3Image):
    """
    Sam3Image subclass whose `forward` accepts several frames per batch and which
    provides helpers to compute detections chunk-by-chunk across GPUs.

    `rank` and `world_size` are read from the RANK / WORLD_SIZE environment
    variables (defaulting to 0 and 1 respectively).
    """

    def __init__(
        self, *args, async_all_gather=True, gather_backbone_out=None, **kwargs
    ):
        """
        Args:
            *args, **kwargs: Forwarded unchanged to the parent constructor.
            async_all_gather: Passed as `async_op` to `torch.distributed.all_gather`
                in `_gather_tensor`.
            gather_backbone_out: Whether to also all-gather the backbone FPN
                features; when None it is derived from the backbone type.
        """
        super().__init__(*args, **kwargs)
        self.rank = int(os.getenv("RANK", "0"))
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.async_all_gather = async_all_gather

        # if gather_backbone is not set, default to gathering only for `SAM3VLBackbone`
        if gather_backbone_out is None:
            gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone)
        self.gather_backbone_out = gather_backbone_out

    def forward(self, input: BatchedDatapoint):
        """
        Override parent's forward to handle multiple frames.
        Processes frames sequentially (not distributed like forward_video_grounding_multigpu).

        Args:
            input: Batch providing `img_batch`, `find_text_batch`, `find_inputs`
                (one entry per frame) and, optionally, `find_targets`.

        Returns:
            A SAM3Output to which one list of step outputs has been appended
            per frame.
        """
        device = self.device
        backbone_out = {"img_batch_all_stages": input.img_batch}
        backbone_out.update(self.backbone.forward_image(input.img_batch))

        num_frames = len(input.find_inputs)
        # No single-frame assertion here -- multiple frames are allowed
        print(f"[Sam3ImageOnVideoMultiGPU] Processing {num_frames} frames")

        text_outputs = self.backbone.forward_text(input.find_text_batch, device=device)
        backbone_out.update(text_outputs)

        previous_stages_out = SAM3Output(
            iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE
        )

        # Same for every frame, so computed once outside the frame loop
        num_interactive_steps = 0 if self.training else self.num_interactive_steps_val

        # Process each frame sequentially
        for frame_idx in range(num_frames):
            find_input = input.find_inputs[frame_idx]
            # Targets are optional and may cover fewer frames than the inputs
            find_target = (
                input.find_targets[frame_idx]
                if input.find_targets and frame_idx < len(input.find_targets)
                else None
            )

            if (
                find_input.input_points is not None
                and find_input.input_points.numel() > 0
            ):
                print("Warning: Point prompts are ignored in PCS.")

            geometric_prompt = Prompt(
                box_embeddings=find_input.input_boxes,
                box_mask=find_input.input_boxes_mask,
                box_labels=find_input.input_boxes_label,
            )

            # Init vars that are shared across the loop
            stage_outs = []
            for cur_step in range(num_interactive_steps + 1):
                if cur_step > 0:
                    # Sample interactive geometric prompts
                    geometric_prompt, _ = self.interactive_prompt_sampler.sample(
                        geo_prompt=geometric_prompt,
                        find_target=find_target,
                        previous_out=stage_outs[-1],
                    )

                out = self.forward_grounding(
                    backbone_out=backbone_out,
                    find_input=find_input,
                    find_target=find_target,
                    geometric_prompt=geometric_prompt.clone(),
                )
                stage_outs.append(out)

            previous_stages_out.append(stage_outs)

        return previous_stages_out

    def forward_video_grounding_multigpu(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx,
        num_frames,
        # `multigpu_buffer` is a dict to cache detector's outputs in a chunk between different calls
        multigpu_buffer,
        track_in_reverse=False,
        # whether to also return the SAM2 backbone features
        return_sam2_backbone_feats=False,
        # whether to perform NMS and suppress the scores of those detections removed by NMS
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
        **kwargs,
    ):
        """
        Compute the detector's detection outputs in a distributed manner, where all GPUs process
        a chunk of frames (equal to the number of GPUs) at once and store them in cache.

        Any extra keyword arguments are accepted and ignored.

        Returns:
            A tuple `(out, backbone_out)`: `out` maps output names to the values
            cached for `frame_idx`; `backbone_out` is the argument passed in.
        """
        # Step 1: fetch the detector outputs in the current chunk from buffer
        frame_idx_curr_b = frame_idx - frame_idx % self.world_size
        frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames)
        # in case the current frame's detection results are not in the buffer yet, build the current chunk
        # (this should only happen on the first chunk, since we are also building the next chunk below)
        if frame_idx not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"):
                self._build_multigpu_buffer_next_chunk(
                    backbone_out=backbone_out,
                    find_inputs=find_inputs,
                    geometric_prompt=geometric_prompt,
                    frame_idx_begin=frame_idx_curr_b,
                    frame_idx_end=frame_idx_curr_e,
                    num_frames=num_frames,
                    multigpu_buffer=multigpu_buffer,
                    run_nms=run_nms,
                    nms_prob_thresh=nms_prob_thresh,
                    nms_iou_thresh=nms_iou_thresh,
                )

        # read out the current frame's results from `multigpu_buffer`
        out = {}
        for k, (v, handle) in multigpu_buffer[frame_idx].items():
            # NOTE(review): `_build_multigpu_buffer_next_chunk` stores the backbone
            # features under "tracker_backbone_*" keys, so this "sam2_backbone_"
            # prefix never matches and `return_sam2_backbone_feats` currently has
            # no effect. Left unchanged because callers may depend on the features
            # always being returned -- confirm the intended prefix.
            if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats:
                continue
            if handle is not None:
                handle.wait()  # wait for async all-gather to finish
            out[k] = v

        # Step 2: remove detection outputs of the previous chunk from cache to save GPU memory
        if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
            frame_idx_prev_e = frame_idx_curr_b
            frame_idx_prev_b = frame_idx_curr_b - self.world_size
        elif track_in_reverse and frame_idx_curr_e < num_frames:
            frame_idx_prev_b = frame_idx_curr_e
            frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames)
        else:
            frame_idx_prev_b = frame_idx_prev_e = None
        if frame_idx_prev_b is not None:
            for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e):
                multigpu_buffer.pop(frame_idx_rm, None)

        # Step 3: compute and cache detection outputs of the next chunk ahead of time
        # (so that we can overlap computation with all-gather transfer)
        if not track_in_reverse and frame_idx_curr_e < num_frames:
            frame_idx_next_b = frame_idx_curr_e
            frame_idx_next_e = min(frame_idx_next_b + self.world_size, num_frames)
        elif track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
            frame_idx_next_e = frame_idx_curr_b
            frame_idx_next_b = frame_idx_curr_b - self.world_size
        else:
            frame_idx_next_b = frame_idx_next_e = None
        if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"):
                self._build_multigpu_buffer_next_chunk(
                    backbone_out=backbone_out,
                    find_inputs=find_inputs,
                    geometric_prompt=geometric_prompt,
                    frame_idx_begin=frame_idx_next_b,
                    frame_idx_end=frame_idx_next_e,
                    num_frames=num_frames,
                    multigpu_buffer=multigpu_buffer,
                    run_nms=run_nms,
                    nms_prob_thresh=nms_prob_thresh,
                    nms_iou_thresh=nms_iou_thresh,
                )

        return out, backbone_out

    def _build_multigpu_buffer_next_chunk(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx_begin,
        frame_idx_end,
        num_frames,
        multigpu_buffer,
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
    ):
        """
        Compute detection outputs on a chunk of frames and store their results in multigpu_buffer.

        Each stored entry maps an output name to a `(value, handle)` pair, where
        `handle` is whatever `_gather_tensor` returned for that output.
        """
        # each GPU computes detections on one frame in the chunk (in a round-robin manner);
        # on a partial last chunk the index is clamped to the chunk's last frame
        frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1)
        # `forward_grounding` (inherited from the parent class) runs the detector on a single frame
        with torch.profiler.record_function("forward_grounding"):
            out_local = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_inputs[frame_idx_local_gpu],
                find_target=None,
                geometric_prompt=geometric_prompt,
            )
        if run_nms:
            with torch.profiler.record_function("nms_masks"):
                # run NMS as a post-processing step on top of the detection outputs
                assert nms_prob_thresh is not None and nms_iou_thresh is not None
                pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid()
                pred_masks = out_local["pred_masks"]
                # loop over text prompts (not an overhead for demo where there's only 1 prompt)
                for prompt_idx in range(pred_probs.size(0)):
                    keep = nms_masks(
                        pred_probs=pred_probs[prompt_idx],
                        pred_masks=pred_masks[prompt_idx],
                        prob_threshold=nms_prob_thresh,
                        iou_threshold=nms_iou_thresh,
                    )
                    # set a very low threshold for those detections removed by NMS
                    out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float()

        if self.gather_backbone_out:
            # gather the SAM 2 backbone features across GPUs
            feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"]
            assert len(feats["backbone_fpn"]) == 3  # SAM2 backbone always have 3 levels
            # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually
            # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP)
            backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]]
            fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0])
            fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1])
            fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2])
            # vision_pos_enc is the same on all frames, so no need to all-gather them
            vision_pos_enc = feats["vision_pos_enc"]

        # trim the detector output to only include the necessary keys
        out_local = {
            "pred_logits": out_local["pred_logits"],
            "pred_boxes": out_local["pred_boxes"],
            "pred_boxes_xyxy": out_local["pred_boxes_xyxy"],
            "pred_masks": out_local["pred_masks"],
        }

        # gather the results: after this step, each GPU will receive detector outputs on
        # all frames in the chunk and store them in `multigpu_buffer`
        out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()}
        for rank in range(self.world_size):
            frame_idx_to_save = frame_idx_begin + rank
            # ranks past the end of the video have no frame of their own to store
            if frame_idx_to_save >= num_frames:
                continue
            frame_buffer = {
                k: (v[rank], handle) for k, (v, handle) in out_gathered.items()
            }
            if self.gather_backbone_out:
                # also add gathered SAM 2 backbone features to frame_buffer
                frame_buffer["tracker_backbone_fpn_0"] = (fpn0[rank], fpn_handle0)
                frame_buffer["tracker_backbone_fpn_1"] = (fpn1[rank], fpn_handle1)
                frame_buffer["tracker_backbone_fpn_2"] = (fpn2[rank], fpn_handle2)
                frame_buffer["tracker_backbone_pos_enc"] = (vision_pos_enc, None)

            multigpu_buffer[frame_idx_to_save] = frame_buffer

    def _gather_tensor(self, x):
        """
        All-gather `x` across ranks.

        Returns:
            A tuple `(tensors, handle)`: `tensors` is a list with one entry per
            rank and `handle` is the return value of
            `torch.distributed.all_gather`. With a world size of 1 no collective
            is issued and `([x], None)` is returned.
        """
        if self.world_size == 1:
            return [x], None

        async_op = self.async_all_gather
        # here `.contiguous()` is required -- otherwise NCCL all_gather
        # sometimes gives wrong results
        x = x.contiguous()  # ensure contiguous memory for NCCL
        output_list = [torch.empty_like(x) for _ in range(self.world_size)]
        handle = torch.distributed.all_gather(output_list, x, async_op=async_op)
        return output_list, handle

machlovi avatar Dec 10 '25 01:12 machlovi