Keypoints not properly transformed through Augmentations

Open • alexanderswerdlow opened this issue 6 months ago • 7 comments

### Describe the bug

Some combinations of augmentation operations cause keypoints to be incorrectly transformed. A demo script (below) creates a white image and a grid of keypoints. In the output image, all keypoints should be on the white input, but often they are not.

It seems to happen when a translate augmentation is applied along with other augmentations [cropping, rotation, etc.] but also happens with, for example, RandomResizedCrop + RandomRotation.

### Reproduction steps

Reproduction Script [Additionally requires Matplotlib + Pillow]:


```python
import kornia.augmentation as K
import matplotlib.pyplot as plt
import torch
from kornia.augmentation.container import AugmentationSequential
from kornia.geometry.keypoints import Keypoints
from PIL import ImageDraw, Image, ImageOps
from torchvision.transforms.functional import to_pil_image

# Setup Vis Utils
def draw_keypoints_with_circles(image, keypoints, colors, radius=2):
    pil_img = to_pil_image(image)
    draw = ImageDraw.Draw(pil_img)
    for i, (x, y) in enumerate(keypoints):
        if 0 <= y <= image.shape[1] and 0 <= x <= image.shape[2]:
            color = tuple(int(c*255) for c in colors(i)[:3])
            draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], outline=color, width=2)

    return pil_img

def add_border(im):
    return ImageOps.expand(im, border=5, fill=(0, 255, 0))

# From StackOverflow
def concat_images_vertically(*images):
    width = max(image.width for image in images)
    height = sum(image.height for image in images)
    composite = Image.new('RGB', (width, height))
    y = 0
    for image in images:
        composite.paste(image, (0, y))
        y += image.height
    return composite

B, C, H, W = (10, 3, 224, 224)
in_tensor = torch.ones((B, C, H, W))

# Setup grid of spaced out keypoints for visualization
step = 16
y_coords_viz = torch.linspace(0, H - 1, (H - 1) // step + 1)
x_coords_viz = torch.linspace(0, W - 1, (W - 1) // step + 1)
y_coords_viz, x_coords_viz = torch.meshgrid(y_coords_viz, x_coords_viz, indexing='ij')
viz_coords = torch.stack((x_coords_viz, y_coords_viz), dim=-1).unsqueeze(0).repeat(B, 1, 1, 1).reshape(B, -1, 2)
input_keypoints = Keypoints(viz_coords.float())

aug = AugmentationSequential(
    K.RandomResizedCrop(size=(128, 128)),
    K.RandomTranslate(0.5, 0.5),
    data_keys=["input", "keypoints"],
    random_apply=False
)
output_tensor, output_keypoints = aug(in_tensor, input_keypoints)

num_keypoints = output_keypoints.to_tensor().shape[1]
colors = plt.cm.get_cmap('flag', num_keypoints)

imgs = []
for j in range(B):
    input_with_keypoints = draw_keypoints_with_circles(in_tensor[j], input_keypoints.to_tensor()[j], colors)
    output_with_keypoints = draw_keypoints_with_circles(output_tensor[j], output_keypoints.to_tensor()[j], colors)

    imgs.append(add_border(input_with_keypoints))
    imgs.append(add_border(output_with_keypoints))

concat_images_vertically(*imgs).save('transform.png')
```


### Expected behavior

All grid points should be in the white output image.

### Environment

```shell
Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.1.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: version 3.28.0
Libc version: N/A

Python version: 3.10.2 (main, Jan 29 2022, 22:27:44) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-14.1.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchvision==0.16.1
[conda] Could not collect
```


### Additional context

_No response_

alexanderswerdlow • Dec 13 '23 17:12

Not all augmentations support keypoints yet (see #941); contributions are welcome. @shijianjian, do we have a list somewhere of which augmentations support each case?

johnnv1 • Dec 13 '23 22:12

In theory, all augmentations that inherit from the GeometricAugmentation class should be properly supported, so the demo code you provided should work. It is worth checking where it goes wrong if you have some time, @alexanderswerdlow.

shijianjian • Dec 15 '23 23:12

Upon digging a little deeper, I think I figured out the issue, which is pretty simple albeit unintuitive for an end user. Keypoints can be moved out of the image by one transformation and then moved back into the image by subsequent transformations that pad with zeros [e.g., Translate]. The end result is that keypoints become detached from their original content, as seen below.

If you have a series of augmentations, as is common, there doesn't seem to be a simple way to determine whether a keypoint is still valid; perhaps returning a mask that denotes which ones are valid, or setting invalid keypoints to a special value [-1, NaN], would be a solution.

The example below is with RandomResizedCrop -> RandomTranslate.

[Image: transform]
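The failure mode and the proposed mask can be sketched in plain PyTorch, without Kornia. This is only an illustration under stated assumptions: each augmentation is modelled as a 3x3 homogeneous matrix plus an output canvas size, the in-bounds convention is `0 <= x <= width - 1`, and the names `in_bounds` and `track_validity` are hypothetical helpers, not part of Kornia's API.

```python
import torch


def in_bounds(points: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Boolean mask of (x, y) points that lie inside a height x width canvas."""
    x, y = points[..., 0], points[..., 1]
    return (x >= 0) & (x <= width - 1) & (y >= 0) & (y <= height - 1)


def track_validity(points: torch.Tensor, steps):
    """Apply each (matrix, (height, width)) step and AND the per-step masks.

    A point that leaves the canvas at any step stays invalid, even if a
    later step moves it back inside.
    """
    valid = torch.ones(points.shape[:-1], dtype=torch.bool)
    for matrix, (height, width) in steps:
        ones = torch.ones_like(points[..., :1])
        homogeneous = torch.cat([points, ones], dim=-1)  # (N, 3)
        points = (homogeneous @ matrix.T)[..., :2]       # back to (N, 2)
        valid &= in_bounds(points, height, width)
    return points, valid


# Step 1: crop the region x in [50, 100), y in [0, 50) of a 100x100 image,
# which shifts coordinates by (-50, 0) onto a 50x50 canvas.
crop = torch.tensor([[1.0, 0.0, -50.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
# Step 2: translate by (+30, 0) on the same 50x50 canvas (zero padding).
translate = torch.tensor([[1.0, 0.0, 30.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
steps = [(crop, (50, 50)), (translate, (50, 50))]

points = torch.tensor([[40.0, 10.0], [60.0, 10.0], [90.0, 10.0]])
out, valid = track_validity(points, steps)
final_only = in_bounds(out, 50, 50)  # naive check on the final output only
```

The first point is cropped away and then translated back onto the padded area, so `final_only` accepts it while `valid` rejects it. A check on the final coordinates alone cannot recover this information.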

alexanderswerdlow • Dec 16 '23 19:12

Ah. Yes, I remember the design now. Since we use the same point transformation for bounding boxes, we do not remove those invalid keypoints. They need to be kept; otherwise the boxes will not have four corners.

I would vote for having a visibility field in the keypoint data structure, probably 0 for invisible and 1 for visible. Do you think it would be easy to add, @alexanderswerdlow?

shijianjian • Dec 17 '23 08:12

Geometrically speaking, the keyword visibility can be a bit too image-specific. To follow the recent generic geometry data structures, I would consider valid/is_valid, which is also more idiomatic, so that we can have something like:

pts = [....]
pts_filtered = [p for p in pts if p.valid]

Besides, I'd also consider expanding the Keypoint data structure, possibly based on or inspired by Vector2: https://github.com/kornia/kornia/blob/40d07ec3aedbbdc83d7f2e613966980ae69f7bbe/kornia/geometry/vector.py#L95
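As a rough illustration of the per-point `valid` idea, here is a minimal pure-Python sketch. The `Point2D` dataclass is a hypothetical stand-in invented for this example; it is not Kornia's `Vector2` or `Keypoints`, and the coordinates are made up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Point2D:
    """Hypothetical point type carrying its own validity flag."""

    x: float
    y: float
    valid: bool = True


pts = [
    Point2D(20.0, 10.0, valid=False),  # e.g. left the image at some step
    Point2D(40.0, 10.0),
    Point2D(70.0, 10.0, valid=False),
]

# The filtering idiom from the comment above.
pts_filtered = [p for p in pts if p.valid]
```

Filtering changes the number of points, so a caller that needs correspondence with the input would still have to keep the original indices or the flags themselves.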

edgarriba • Dec 17 '23 10:12

I don't think it should be very difficult but I'm not very familiar with Kornia (this is my first time using it actually).

I added a basic implementation but I'm not sure how robust it is and whether it works for all geometric augmentations. The few I tested (Shear, Translate, Crop, Rotate, Flip) seem to work though. It's a little hacky as it checks the output_size field, specifically for cropping, so a more general implementation (maybe that can apply to 3D) would be better, but unfortunately I don't have the capacity for that atm.

I think a mask is preferable to making each point into an object with a field, at least for my use case with a very dense grid of keypoints.

It also might make sense to set the masked-out keypoints to some special value so it's clear to the user that they need to ignore these values, unless there's a use case for them. On the surface, keypoints are ostensibly used for correspondence, so it's hard to see why someone would want this broken.
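A small sketch of the two options for a dense `(B, N, 2)` grid, assuming a boolean mask of shape `(B, N)` is already available. The helper `mask_invalid` is a hypothetical name, and NaN is just one possible sentinel value.

```python
import torch


def mask_invalid(points: torch.Tensor, valid: torch.Tensor, fill: float = float("nan")) -> torch.Tensor:
    """Return a copy of points with every invalid (x, y) pair replaced by fill."""
    filler = torch.full_like(points, fill)
    # valid is (B, N); unsqueeze so it broadcasts over the trailing (x, y) axis.
    return torch.where(valid.unsqueeze(-1), points, filler)


points = torch.tensor([[[20.0, 10.0], [40.0, 10.0], [70.0, 10.0]]])  # (B=1, N=3, 2)
valid = torch.tensor([[False, True, False]])                           # (B=1, N=3)

masked = mask_invalid(points, valid)  # same shape as points, NaN where invalid
kept = points[valid]                  # (num_valid, 2), batch structure is lost
```

The sentinel version keeps the tensor shape, so index `i` in the output still corresponds to index `i` in the input. Boolean indexing is more compact but flattens the batch dimension.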

alexanderswerdlow • Dec 17 '23 20:12

link to #2689

shijianjian • Dec 18 '23 13:12