segment-anything
Wrong image shape in apply_image_torch method
Hi,
Thank you for the awesome work!
I found a bug in the apply_image_torch method in the utils/transforms.py file:
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
    """
    Expects batched images with shape BxCxHxW and float format. This
    transformation may not exactly match apply_image. apply_image is
    the transformation expected by the model.
    """
    # Expects an image in BCHW format. May not exactly match apply_image.
    target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)  # <-- suspected bug
    return F.interpolate(
        image, target_size, mode="bilinear", align_corners=False, antialias=True
    )
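According to the docstring and the get_preprocess_shape function, the first and second arguments must be the height and width of the images in the batch. For a BxCxHxW tensor, however, image.shape[0] and image.shape[1] are the batch size and the number of channels.
So I believe those arguments should be image.shape[2] and image.shape[3].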
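To illustrate the effect, here is a minimal sketch in plain Python (no torch needed). The get_preprocess_shape below is my own standalone re-implementation, assuming it scales the longer side to long_side_length and rounds half up. target_size_buggy and target_size_fixed are hypothetical helpers that differ only in which shape indices they pass.

```python
from typing import Tuple


def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
    """Scale (oldh, oldw) so the longer side equals long_side_length.

    Assumption: re-implemented here to mirror the library helper's logic.
    """
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    # int(x + 0.5) rounds half up for positive x.
    return int(newh + 0.5), int(neww + 0.5)


def target_size_buggy(shape: Tuple[int, ...], target_length: int) -> Tuple[int, int]:
    # Hypothetical helper: indices 0 and 1 of a BxCxHxW shape are batch size and channels.
    return get_preprocess_shape(shape[0], shape[1], target_length)


def target_size_fixed(shape: Tuple[int, ...], target_length: int) -> Tuple[int, int]:
    # Hypothetical helper: indices 2 and 3 of a BxCxHxW shape are height and width.
    return get_preprocess_shape(shape[2], shape[3], target_length)


# Hypothetical batch of 2 RGB images, each 480x640, resized to a long side of 1024.
batch_shape = (2, 3, 480, 640)
print(target_size_buggy(batch_shape, 1024))  # (683, 1024): derived from B=2, C=3
print(target_size_fixed(batch_shape, 1024))  # (768, 1024): derived from H=480, W=640
```

Since torch.Size is a subclass of tuple, passing image.shape from a real tensor to these helpers should behave the same way.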
Thanks!