MobileSAM

Issue with SamAutomaticMaskGenerator on Apple Silicon (MPS)

Status: Open · opened by TanujW 6 months ago · 0 comments

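`SamAutomaticMaskGenerator.generate()` fails when the model runs on Apple Silicon's MPS backend. The failure happens in `_process_batch`, at `torch.as_tensor(transformed_points, device=self.predictor.device)`. The point coordinates appear to still be float64 at that point, and MPS does not support float64; casting them to float32 before the transfer should avoid the error (to be confirmed). Calling `generate()` on an image produces the following traceback: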

TypeError                                 Traceback (most recent call last)
Cell In[2], line 19
     17 predictor.set_image(image)
     18 mask_generator = SamAutomaticMaskGenerator(mobile_sam)
---> 19 masks = mask_generator.generate(image)

File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/mobile_sam/automatic_mask_generator.py:163, in SamAutomaticMaskGenerator.generate(self, image)
    138 """
    139 Generates masks for the given image.
    140 
   (...)
    159          the mask, given in XYWH format.
    160 """
    162 # Generate masks
--> 163 mask_data = self._generate_masks(image)
    165 # Filter small disconnected regions and holes in masks
    166 if self.min_mask_region_area > 0:

File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/mobile_sam/automatic_mask_generator.py:206, in SamAutomaticMaskGenerator._generate_masks(self, image)
    204 data = MaskData()
    205 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 206     crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
    207     data.cat(crop_data)
    209 # Remove duplicate masks between crops

File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/mobile_sam/automatic_mask_generator.py:245, in SamAutomaticMaskGenerator._process_crop(self, image, crop_box, crop_layer_idx, orig_size)
    243 data = MaskData()
    244 for (points,) in batch_iterator(self.points_per_batch, points_for_image):
--> 245     batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
    246     data.cat(batch_data)
    247     del batch_data

File ~/miniforge3/envs/tracking/lib/python3.10/site-packages/mobile_sam/automatic_mask_generator.py:277, in SamAutomaticMaskGenerator._process_batch(self, points, im_size, crop_box, orig_size)
    275 # Run model on this batch
    276 transformed_points = self.predictor.transform.apply_coords(points, im_size)
--> 277 in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
    278 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
    279 masks, iou_preds, _ = self.predictor.predict_torch(
    280     in_points[:, None, :],
    281     in_labels[:, None],
    282     multimask_output=True,
    283     return_logits=True,
    284 )

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Posted by TanujW on Dec 10, 2023 at 09:12.