sam3
Add Apple Silicon (MPS) support
Adds MPS device support for both image and video predictors on Apple Silicon.
Changes:
- Add get_default_device() utility that detects MPS availability
- Fix device mismatches (coords cache, freqs_cis cache)
- Add MPS workaround for complex tensor repeat() in RoPE
- Make torch._assert_async conditional on CUDA
- Fix MPS memory leak in video predictor via synchronization points
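
A minimal sketch of what a device-detection helper along the lines of get_default_device() might look like. The CUDA, then MPS, then CPU priority order and the pick_device helper are assumptions made for illustration; the description above only states that the utility detects MPS availability.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name (assumed order: CUDA, MPS, CPU)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def get_default_device():
    """Probe PyTorch for available backends and return a torch.device."""
    import torch  # imported lazily so pick_device stays usable without PyTorch

    # torch.backends.mps may be missing on older PyTorch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return torch.device(pick_device(torch.cuda.is_available(), mps_ok))
```

Keeping the priority logic in a pure function makes it easy to check without any particular hardware present.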
Video predictor performance:
- ~3x faster than CPU
- Peak memory is ~38GB. This is due to the way MPS caches graphs. Before the synchronization points were added, the video predictor would consume all available memory.
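
The idea of periodic synchronization points can be sketched as follows. This is an illustration only, not the code in this PR: run_with_sync_points, mps_sync, and the every interval are hypothetical names, and the actual placement of the synchronization points in the video predictor may differ. torch.mps.synchronize() and torch.mps.empty_cache() are existing PyTorch calls; whether the PR uses both is an assumption.

```python
def mps_sync() -> None:
    """Wait for queued MPS work to finish, then release cached memory."""
    import torch  # imported lazily; only needed when actually syncing

    if torch.backends.mps.is_available():
        torch.mps.synchronize()
        torch.mps.empty_cache()


def run_with_sync_points(frames, step, sync=mps_sync, every=8):
    """Apply step to each frame, calling sync after every `every` frames."""
    results = []
    for i, frame in enumerate(frames, start=1):
        results.append(step(frame))
        if i % every == 0:
            # Synchronization point: lets the backend finish pending work
            # so memory does not keep accumulating across frames.
            sync()
    return results
```

Passing the sync callable as a parameter keeps the loop backend-agnostic: on CUDA or CPU a no-op can be supplied instead.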