MobileSAM
MobileSAM copied to clipboard
Issue with SamAutomaticMaskGenerator on Apple Silicon (MPS)
SamAutomaticMaskGenerator fails on Apple Silicon (MPS): the point coordinates returned by `predictor.transform.apply_coords` are float64, and `torch.as_tensor(..., device=mps)` cannot create a float64 tensor because the MPS backend does not support float64 (the coordinates need to be cast to float32 first). Calling `generate` produces the following traceback:
TypeError Traceback (most recent call last)
Cell In[2], line 19
17 predictor.set_image(image)
18 mask_generator = SamAutomaticMaskGenerator(mobile_sam)
---> 19 masks = mask_generator.generate(image)
File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/mobile_sam/automatic_mask_generator.py:163, in SamAutomaticMaskGenerator.generate(self, image)
138 """
139 Generates masks for the given image.
140
(...)
159 the mask, given in XYWH format.
160 """
162 # Generate masks
--> 163 mask_data = self._generate_masks(image)
165 # Filter small disconnected regions and holes in masks
166 if self.min_mask_region_area > 0:
File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/mobile_sam/automatic_mask_generator.py:206, in SamAutomaticMaskGenerator._generate_masks(self, image)
204 data = MaskData()
205 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 206 crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207 data.cat(crop_data)
209 # Remove duplicate masks between crops
File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/mobile_sam/automatic_mask_generator.py:245, in SamAutomaticMaskGenerator._process_crop(self, image, crop_box, crop_layer_idx, orig_size)
243 data = MaskData()
244 for (points,) in batch_iterator(self.points_per_batch, points_for_image):
--> 245 batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246 data.cat(batch_data)
247 del batch_data
File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/mobile_sam/automatic_mask_generator.py:277, in SamAutomaticMaskGenerator._process_batch(self, points, im_size, crop_box, orig_size)
275 # Run model on this batch
276 transformed_points = self.predictor.transform.apply_coords(points, im_size)
--> 277 in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279 masks, iou_preds, _ = self.predictor.predict_torch(
280 in_points[:, None, :],
281 in_labels[:, None],
282 multimask_output=True,
283 return_logits=True,
284 )
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.