
Feature request: backend-agnostic inference (CPU/MPS/TPU support)

Open trothe opened this issue 3 months ago • 6 comments

Summary

It would be great if SAM 3 exposed a backend-agnostic inference path so the model can run on devices other than CUDA GPUs (Apple Silicon / MPS, CPU-only servers, TPU via XLA, etc.). Right now the repo assumes CUDA everywhere, so even image-only inference fails on machines without NVIDIA GPUs.

Use case

I (and likely many others) want to use SAM 3 for bulk auto-labeling images when preparing YOLO datasets. Those labeling jobs often run on whatever hardware is already available—e.g., a MacBook Pro M3 with the Metal build of PyTorch. Because SAM 3 hard-requires CUDA 12.6, there is no way to try the model without renting a remote GPU, which makes lightweight labeling workflows much harder.

Current blockers

These are some spots that make CUDA mandatory today:

  • The README lists "CUDA-compatible GPU with CUDA 12.6 or higher" as a prerequisite and only shows CUDA install commands (README.md:58-79).
  • Core preprocessing helpers unconditionally move tensors to CUDA (e.g. load_image_as_single_frame_video calls .cuda() on the image, mean, and std tensors in sam3/model/io_utils.py:93-112). On a PyTorch build compiled without CUDA, calling these helpers fails with "Torch not compiled with CUDA enabled".
  • The video inference stack wraps methods with @torch.autocast(device_type="cuda") and raises if self.device.type != "cuda" (sam3/model/sam3_video_inference.py:797-810), so even if most of the model could run on CPU/MPS, execution aborts immediately.
  • Multi-GPU support is hard-coded to NCCL + CUDA tensors (sam3/model/sam3_video_predictor.py:420-433), again preventing CPU or other backends.
  • Performance-critical helpers require CUDA-only extensions such as torch_generic_nms and Triton kernels built with TORCH_CUDA_ARCH_LIST (sam3/perflib/nms.py:11-69). There is no fallback that keeps execution on CPU/MPS when these aren’t available.

Request / proposal

Would Meta consider supporting a backend-agnostic inference mode? Concretely, that could mean:

  1. Abstracting device selection (using tensor.to(device) instead of raw .cuda() calls, and allowing torch.device("mps"), "cpu", etc.).
  2. Guarding CUDA-specific features (NCCL, Triton kernels, torch.cuda.* logging) behind availability checks and providing CPU-friendly defaults.
  3. Documenting a CPU/MPS installation path—e.g., PyTorch nightly with Metal and a note about expected performance.
  4. (Stretch) exposing hooks so advanced users could plug in XLA/TPU devices once the general abstraction exists.
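
A minimal sketch of points 1 and 2, assuming nothing about the sam3 codebase itself: pick_device_type and pick_dist_backend are hypothetical helper names, and the availability flags are passed in as arguments so the selection logic stays trivial to test.

```python
from typing import Literal

DeviceType = Literal["cuda", "mps", "cpu"]


def pick_device_type(cuda_available: bool, mps_available: bool) -> DeviceType:
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, then plain CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def pick_dist_backend(device_type: str, nccl_available: bool) -> str:
    """Hypothetical helper: NCCL only handles CUDA tensors, so fall back to Gloo.

    Gloo runs collectives on CPU tensors; MPS tensors would probably need to be
    copied to the CPU first (an assumption worth verifying).
    """
    if device_type == "cuda" and nccl_available:
        return "nccl"
    return "gloo"
```

In practice the flags would come from torch.cuda.is_available(), torch.backends.mps.is_available(), and torch.distributed.is_nccl_available(). The resulting torch.device(...) would then be used with tensor.to(device) wherever the code currently calls .cuda(), and the backend string would be passed to torch.distributed.init_process_group(backend=...).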

I’m happy to test on Apple Silicon if guidance is provided, but an official stance on whether CPU/MPS/TPU support is on the roadmap would already help the community plan.

trothe • Nov 20 '25

It works fine (tested on a 32GB M2 MBP) once you work through the issues you've mentioned. Not many changes are needed, and you should set PYTORCH_ENABLE_MPS_FALLBACK=1.
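
From a shell, prefixing the command (PYTORCH_ENABLE_MPS_FALLBACK=1 python script.py) is enough. In a notebook, something like this hypothetical enable_mps_fallback() helper could go in the first cell. It assumes PyTorch reads the variable when torch is first imported, so it warns if torch is already loaded.

```python
import os
import sys
import warnings


def enable_mps_fallback() -> None:
    """Hypothetical helper: opt in to PyTorch's CPU fallback for ops MPS lacks.

    Assumption: the variable is read when torch loads its MPS backend, so it
    has to be set before the first `import torch` in the process.
    """
    if "torch" in sys.modules:
        warnings.warn(
            "torch is already imported; PYTORCH_ENABLE_MPS_FALLBACK may not "
            "take effect until the Python process (or notebook kernel) restarts.",
            stacklevel=2,
        )
    # setdefault keeps any value the user already exported in the shell.
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")


enable_mps_fallback()
# Only after this point: import torch, then the sam3 modules.
```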

As a quick fix, I moved a couple of the imports (decord and triton) to where they're used, but guards may be better. These seem to be used only for training(?) and are referenced in only a couple of places, but some common files import them at inference time. I've only tried the image predictor notebook so far.
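
The "guards" alternative could look roughly like this: a hypothetical optional_import() helper that turns triton and decord into soft dependencies. Inference-only imports then succeed on machines where those packages aren't installed, and the error only surfaces when a code path that really needs them runs.

```python
import importlib
from types import ModuleType
from typing import Optional


def optional_import(name: str) -> Optional[ModuleType]:
    """Hypothetical helper: return the module if it imports cleanly, else None."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Training-only / CUDA-only dependencies become soft requirements at import time.
triton = optional_import("triton")
decord = optional_import("decord")
HAS_TRITON = triton is not None
HAS_DECORD = decord is not None


def require(module: Optional[ModuleType], name: str) -> ModuleType:
    """Fail loudly only when a feature that needs the module is actually used."""
    if module is None:
        raise ImportError(f"{name} is required for this feature but is not installed")
    return module
```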

Video needs more work, but you can play with image prediction using this branch:

https://github.com/facebookresearch/sam3/compare/main...jveitchmichaelis:sam3:device-agnostic

jveitchmichaelis • Nov 21 '25

Some of the analysis is not quite correct, for example you cite nms as a blocker, while it has an explicit cpu fallback: https://github.com/facebookresearch/sam3/blob/84cc43bca4347b772f17d1078a1ddb4c054655c2/sam3/perflib/nms.py#L63-L71
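
For reference, this is the computation such a CPU path has to perform: a minimal pure-Python greedy NMS over (x0, y0, x1, y1) boxes. It is a sketch of the technique only, not the implementation in sam3/perflib/nms.py; in real code torchvision.ops.nms is the usual library choice.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x0, y0, x1, y1)


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float], iou_threshold: float) -> List[int]:
    """Visit boxes by descending score; keep one unless it overlaps a kept box too much."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep
```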

On a more general note, I think it would be quite easy to make the image predictor backend-agnostic (though we can't promise support for additional backends, especially those we don't have access to). If you want to submit a PR for this, we'll gladly accept it. For the video predictor, it would be a far more complicated endeavor, since the current predictor is highly optimized for the multi-GPU setting. We do not have plans to support this in the near future.

alcinos • Nov 21 '25

Thanks @jveitchmichaelis for the branch! I pulled it locally on a 36GB M3 MBP and can confirm the image predictor notebook (and a simple CLI) run fine on the MPS backend with PYTORCH_ENABLE_MPS_FALLBACK=1.

One extra tweak I needed was to guard the decoder FFN autocast block so it doesn’t call torch.amp.autocast(device_type=...) on unsupported devices (wrap it in a nullcontext() when tgt.device.type isn’t cuda or cpu). Without that, the image predictor hits a RuntimeError: unsupported scalarType on MPS. I can submit that as part of a small PR.

My use case is auto-labeling images for YOLO datasets, so having per-image inference on local Apple Silicon hardware is already super helpful. I’m happy to contribute a PR that combines the device-agnostic fixes + a small sample script (e.g., run SAM3 on example.png and emit masks/boxes) so others can replicate the workflow. Let me know if that would be welcome.
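
For the YOLO auto-labeling workflow, the last step is usually turning each predicted box into a line of the YOLO detection label format (class cx cy w h, normalized by image size). A minimal sketch, assuming the predictor's boxes have already been converted to pixel-space (x0, y0, x1, y1) tuples; to_yolo_lines is a hypothetical name, and the exact box format sam3 returns is not assumed here.

```python
from typing import Iterable, List, Tuple

Box = Tuple[float, float, float, float]  # (x0, y0, x1, y1) in pixels


def to_yolo_lines(boxes: Iterable[Box], class_ids: Iterable[int], img_w: int, img_h: int) -> List[str]:
    """Format pixel-space boxes as YOLO labels: 'cls cx cy w h', normalized to [0, 1]."""
    lines = []
    for (x0, y0, x1, y1), cls in zip(boxes, class_ids):
        # Clamp to the image so slightly out-of-bounds predictions stay valid.
        x0, x1 = max(0.0, min(x0, img_w)), max(0.0, min(x1, img_w))
        y0, y1 = max(0.0, min(y0, img_h)), max(0.0, min(y1, img_h))
        cx = (x0 + x1) / 2 / img_w
        cy = (y0 + y1) / 2 / img_h
        w = (x1 - x0) / img_w
        h = (y1 - y0) / img_h
        lines.append(f"{cls} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}")
    return lines
```

Each image's lines would typically be written to a .txt file with the same stem under the dataset's labels/ directory. YOLO segmentation labels use normalized polygon coordinates instead, which this sketch doesn't cover.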

trothe • Nov 21 '25

@trothe great! I've already submitted a PR, and I can add the other decoder fix. What exactly triggers it, so I can reproduce it? Could it be an M3-specific problem (similar to the op that should have worked with the fallback)?

jveitchmichaelis • Nov 21 '25

Thanks for cutting PR #173. I tested the branch on an M3 MBP (PyTorch Metal build) with PYTORCH_ENABLE_MPS_FALLBACK=1. Image predictor runs, but there’s still one MPS-only crash:

  • TransformerDecoderLayer.forward_ffn calls torch.amp.autocast(device_type='mps', enabled=False), which raises RuntimeError: unsupported scalarType. MPS doesn’t support that autocast context.

Minimal fix that works here:

# sam3/model/decoder.py
+from contextlib import nullcontext

     def forward_ffn(self, tgt):
-        with torch.amp.autocast(device_type=tgt.device.type, enabled=False):
+        device_type = tgt.device.type
+        if device_type in {"cuda", "cpu"}:
+            autocast_ctx = torch.amp.autocast(device_type=device_type, enabled=False)
+        else:
+            autocast_ctx = nullcontext()
+
+        with autocast_ctx:
             tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
         tgt = tgt + self.dropout4(tgt2)
         tgt = self.norm3(tgt)

With that guard in place, the image predictor completes on MPS. Feel free to cherry-pick or I can open a small follow-up PR if you prefer.

trothe • Nov 25 '25

@trothe I've rebased + added your fix, let me know if it works now (I don't have a machine to replicate it).

jveitchmichaelis • Dec 03 '25