Personalize-SAM
Added support for MPS on Apple silicon devices for faster inference.
Hello! Thanks for your work. I have modified your code to support device switching with a --device flag, and changed it to move tensors to the selected device accordingly.
These changes also fix a bug in the current version: when using vit_t, the model itself can run inference on the CPU, but other tensors are still moved to the GPU. If torch is not compiled with CUDA, this raises an error and makes CPU inference impossible.
$ python persam_f.py --outdir ./outputs/ --sam_type vit_t
Traceback (most recent call last):
  File "/Users/junkybyte/Desktop/Personalize-SAM/persam_f.py", line 74, in persam_f
    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
AssertionError: Torch not compiled with CUDA enabled
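As a rough illustration of the kind of change involved, here is a minimal sketch of replacing the hard-coded `.cuda()` call from the traceback with a move to a user-selected device. The helper names `build_parser` and `prepare_gt_mask`, and the `--device` default of `None`, are assumptions made for this example and are not necessarily how the PR's code is organized.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser exposing a --device flag (illustrative only)."""
    parser = argparse.ArgumentParser()
    # Assumption: None means "let the script pick a default device".
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run on, e.g. 'cuda', 'mps', or 'cpu'.",
    )
    return parser


def prepare_gt_mask(gt_mask, device):
    """Same reshaping as line 74 of persam_f.py, but device-agnostic.

    `gt_mask` is expected to be a torch.Tensor. Using .to(device) instead
    of .cuda() avoids the AssertionError on builds without CUDA.
    """
    return gt_mask.float().unsqueeze(0).flatten(1).to(device)
```

With something like this in place, `python persam_f.py --device cpu` would keep every tensor on the CPU instead of failing on the `.cuda()` call.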
Since SAM works out of the box with MPS on Apple silicon devices, I chose the default device to be cuda or mps when available, falling back to cpu otherwise.
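The default selection described above can be sketched roughly as follows. The function names `pick_device` and `default_device` are hypothetical, and the priority order (CUDA first, then MPS, then CPU) is an assumption of this sketch. The torch import is done lazily so that the selection logic can be exercised without torch installed.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name given what is available.

    Assumed priority for this sketch: CUDA, then MPS, then CPU.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def default_device() -> str:
    """Query torch for available backends and pick a default device."""
    # Lazy import: only needed when actually probing the hardware.
    import torch

    return pick_device(
        torch.cuda.is_available(),
        torch.backends.mps.is_available(),
    )
```

A script could then use `device = args.device or default_device()` so that an explicit `--device` always overrides the automatic choice.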
I tested MPS on an M2 MacBook Air with torch==2.0.1 installed.
I changed the README to mirror these changes. Let me know if you are interested in a merge. Thanks!
The other changes you see are just autopep8 style corrections that my editor applied by default.