Implement AutoAugment for Detection

Open lezwon opened this issue 3 years ago • 19 comments

🚀 The feature

Implement "Learning Data Augmentation Strategies for Object Detection". Refers to: #3817

Motivation, pitch

It would be good to have this augmentation in Torchvision.

Alternatives

No response

Additional context

No response

cc @vfdev-5 @datumbox

lezwon avatar Jun 30 '22 13:06 lezwon

@datumbox @vfdev-5 Based on our last conversation, I suppose it would be best to start with the bbox_only_* augmentations for AA, right? I could get started on implementing these in torchvision.prototype. Do let me know if that's alright.

  1. Rotate_Only_BBoxes
  2. ShearX_Only_BBoxes
  3. ShearY_Only_BBoxes
  4. TranslateX_Only_BBoxes
  5. TranslateY_Only_BBoxes
  6. Flip_Only_BBoxes
  7. Solarize_Only_BBoxes
  8. Equalize_Only_BBoxes
  9. Cutout_Only_BBoxes

lezwon avatar Jun 30 '22 13:06 lezwon

@lezwon thanks a lot! It's worth coordinating with @vfdev-5 because he is actively landing changes on the same API.

Victor, how do you propose to do this? Would you prefer to do an implementation in References (similar to what we did with CopyPaste) and move it to the new API later on?

datumbox avatar Jun 30 '22 15:06 datumbox

@datumbox @lezwon if we structure the AA detection code like the classification one, i.e. a single transformation class calling functional ops inside, then <op>_Only_BBoxes could be implemented as a special function that produces a list of image crops to which we apply <op>_image_tensor. I think we can go directly with the new API.

By the way, what does Cutout_Only_BBoxes do? Erase data from the target bbox?

vfdev-5 avatar Jun 30 '22 15:06 vfdev-5

@vfdev-5 Yep. Cutout does refer to erasing patches of data from the target box.

lezwon avatar Jun 30 '22 15:06 lezwon

@vfdev-5 Here is a sample Cutout implementation for image classification: https://github.com/ain-soph/trojanzoo/blob/85730f629eddc17dcbc48218274b74c31daf5f99/trojanvision/utils/transform.py#L198-L226

torchvision doesn't implement Cutout directly because the existing RandomErasing is nearly equivalent. The only differences are the handling of boundary cases, the randomness of the erased area in RandomErasing, and the rather different class arguments.
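
If we do want a Cutout_Only_BBoxes kernel, I imagine it could be built on top of the existing torchvision.transforms.functional.erase. A rough sketch of the idea (the function name, the pad_size default and the sampling scheme below are just my assumptions, not an existing API):

import torch
from torchvision.transforms.functional import erase


def cutout_only_bboxes(img: torch.Tensor, bboxes: torch.Tensor, pad_size: int = 20) -> torch.Tensor:
    # bboxes are assumed to be XYXY pixel coordinates with shape (N, 4).
    out = img.clone()
    for x1, y1, x2, y2 in bboxes.to(torch.int64).tolist():
        # Sample the cutout center uniformly inside the bbox and clamp the patch to the image.
        cy = int(torch.randint(y1, max(y2, y1 + 1), ()))
        cx = int(torch.randint(x1, max(x2, x1 + 1), ()))
        top, left = max(cy - pad_size, 0), max(cx - pad_size, 0)
        h = min(cy + pad_size, img.shape[-2]) - top
        w = min(cx + pad_size, img.shape[-1]) - left
        # erase(img, i, j, h, w, v) fills the given rectangle with the value v.
        out = erase(out, top, left, h, w, v=torch.tensor(0.0))
    return out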

ain-soph avatar Jun 30 '22 23:06 ain-soph

@vfdev-5 I'm sorry for the delay. I have been a bit occupied with other tasks and haven't been able to give any time to the implementation. I don't think I can pick this issue up anytime soon either. It would be best to assign someone else to it.

lezwon avatar Aug 03 '22 08:08 lezwon

Is any help needed? I'm interested in implementing this.

ain-soph avatar Aug 11 '22 18:08 ain-soph

@ain-soph I think it would be awesome if you implement it.

Just checking again with @vfdev-5 and @pmeier that they don't have any concern.

datumbox avatar Aug 11 '22 19:08 datumbox

No problems from my side

vfdev-5 avatar Aug 11 '22 20:08 vfdev-5

@vfdev-5 @datumbox Could anyone give me some guidance to get started?

Should I refer to the TensorFlow implementation?

I see that @lezwon previously mentioned that the plan is to implement these in torchvision.prototype:

  • [ ] Rotate_Only_BBoxes
  • [ ] ShearX_Only_BBoxes
  • [ ] ShearY_Only_BBoxes
  • [ ] TranslateX_Only_BBoxes
  • [ ] TranslateY_Only_BBoxes
  • [ ] Flip_Only_BBoxes
  • [ ] Solarize_Only_BBoxes
  • [ ] Equalize_Only_BBoxes
  • [ ] Cutout_Only_BBoxes

Shall I follow this idea?

ain-soph avatar Aug 15 '22 18:08 ain-soph

@ain-soph please check the AA classes for the classification task implemented with the prototype API: https://github.com/pytorch/vision/blob/main/torchvision/prototype/transforms/_auto_augment.py The idea is to create something similar for detection. As most of the prototype functional ops already support bboxes, you can easily write the basic ops. As for the Only_BBoxes ops, please check my previous comment: https://github.com/pytorch/vision/issues/6224#issuecomment-1171347121

vfdev-5 avatar Aug 15 '22 19:08 vfdev-5

@vfdev-5 Sorry for asking some beginner's questions about object detection... I wonder whether we should make the ops process both the image and the bbox coordinates, like what is done in the reference training script: https://github.com/pytorch/vision/blob/becaba0e0546f434be64de4e9603ab66e4b160b2/references/detection/transforms.py#L30-L45

I see TensorFlow does so in https://github.com/tensorflow/tpu/blob/c75705856290a4119d609110956442449d73e0a5/models/official/detection/utils/autoaugment_utils.py#L1030-L1062
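
If I read it correctly, the reference script does roughly the following, i.e. an op receives both the image and the target dict and updates the boxes accordingly (my paraphrased sketch, not the exact contents of that file):

from typing import Dict, Optional, Tuple

import torch
from torch import Tensor, nn
import torchvision.transforms.functional as F


class RandomHorizontalFlip(nn.Module):
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if torch.rand(1) < self.p:
            image = F.hflip(image)
            if target is not None:
                # Mirror the x coordinates of the boxes to keep them aligned with the flipped image.
                _, _, width = F.get_dimensions(image)
                target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
        return image, target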

ain-soph avatar Aug 25 '22 18:08 ain-soph

@ain-soph that's exactly what we need to do. We should already have all the ops, excluding perhaps the *_Only_BBoxes ones. Here are examples from the codebase for these implementations: https://github.com/pytorch/vision/blob/becaba0e0546f434be64de4e9603ab66e4b160b2/torchvision/prototype/transforms/functional/_geometry.py#L259-L268

datumbox avatar Aug 26 '22 08:08 datumbox

I've got a helper function to deal with this, so that all *_Only_Bboxes transforms can just call it and pass the corresponding transform function. Any comments?

Or is there any way to avoid the for loop? Will vmap help?

from typing import Callable

import torch

# `features`, `convert_bounding_box_format` and the `*_image_tensor` kernels below are assumed to
# come from the torchvision.prototype.transforms.functional namespace this snippet targets.


def _transform_only_bboxes(
    img: torch.Tensor,
    bounding_box: torch.Tensor,
    format: features.BoundingBoxFormat,
    transform: Callable[..., torch.Tensor],
    **kwargs,
) -> torch.Tensor:
    bounding_box = convert_bounding_box_format(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    ).view(-1, 4)

    new_img = img.clone()
    for bbox in bounding_box:
        # XYXY is (x1, y1, x2, y2); the image is (..., H, W), so rows are indexed by y and columns by x.
        bbox_crop_img = new_img[..., bbox[1]:bbox[3], bbox[0]:bbox[2]]
        # Apply the op to the crop and write the result back into the output image in place.
        bbox_crop_img.copy_(transform(bbox_crop_img, **kwargs))
    return new_img


def horizontal_flip_only_bboxes(
    img: torch.Tensor,
    bounding_box: torch.Tensor,
    format: features.BoundingBoxFormat,
) -> torch.Tensor:
    return _transform_only_bboxes(img, bounding_box, format, transform=horizontal_flip_image_tensor)
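
For context, a hypothetical call would look like this (img being a CHW image tensor and boxes an (N, 4) tensor in XYXY format):

flipped = horizontal_flip_only_bboxes(img, bounding_box=boxes, format=features.BoundingBoxFormat.XYXY)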

ain-soph avatar Sep 02 '22 06:09 ain-soph

@ain-soph So far the kernels process each input independently. That is, they don't receive img and bounding_box together as inputs; instead each usually receives one of them along with configuration that parameterizes the operation. In the specific example you provide, the bbox seems to operate as configuration rather than input, which is an interesting approach. One other thing to keep in mind about the low-level kernels is that they must be JIT-scriptable, and I don't think the idiom you propose here is.
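
For instance, one quick sanity check (a hypothetical snippet, not something in the codebase) is to try scripting the helper; TorchScript does not support **kwargs or arbitrary Callable arguments, so I would expect it to fail here:

import torch

try:
    scripted = torch.jit.script(_transform_only_bboxes)
except Exception as e:
    # Expected: TorchScript cannot handle the Callable / **kwargs signature of the helper above.
    print(f"not scriptable: {e}")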

@vfdev-5 What are your thoughts on the above? Any alternative ideas on how these kernels should be structured?

datumbox avatar Sep 02 '22 16:09 datumbox

As this is about AutoAugment, we may not need to put such an op into the low-level ops; we can just code a transform:

import torch
from torchvision.prototype.transforms import Transform
from torchvision.prototype.transforms.functional import horizontal_flip
from torchvision.prototype.transforms._utils import query_bounding_box
from torchvision.prototype.features import Image, BoundingBox


class AADet(Transform):
    
    def _get_params(self, sample):
        
        bbox = None
        if torch.rand(()) > 0.2:        
            bbox = query_bounding_box(sample)
            bbox = bbox.to_format(format="XYXY")

        return dict(bbox=bbox, op="hflip")
    
    def _transform_image_in_bboxes(self, fn, fn_kwargs, image, bboxes):
        new_img = image.clone()
        for bbox in bboxes:
            # Crop the bbox region (rows indexed by y, columns by x), transform it, paste it back.
            bbox_img = new_img[..., bbox[1]:bbox[3], bbox[0]:bbox[2]]
            out_bbox_img = fn(bbox_img, **fn_kwargs)
            new_img[..., bbox[1]:bbox[3], bbox[0]:bbox[2]] = out_bbox_img
        return new_img

    def _transform(self, inpt, params):                
        if isinstance(inpt, Image):
            if params["op"] == "hflip" and params["bbox"] is not None:
                return self._transform_image_in_bboxes(horizontal_flip, {}, inpt, params["bbox"])
        return inpt

Usage:

image_size = (64, 76)

bboxes = [
    [10, 15, 25, 35],
    [50, 5, 70, 22],
    [45, 46, 56, 62],
    [4, 50, 10, 60],    
]
labels = [1, 2, 3, 4]

img = torch.zeros(1, 3, *image_size)
for in_box, label in zip(bboxes, labels):
    # Fill each bbox region with a distinct pattern so the per-bbox transform is visible.
    img[..., in_box[1]:in_box[3], in_box[0]:in_box[2]] = \
        (torch.arange(23, 23 + 3 * (in_box[3] - in_box[1]) * (in_box[2] - in_box[0])) % 200).reshape(1, 3, in_box[3] - in_box[1], in_box[2] - in_box[0])

img = Image(img)
bboxes = BoundingBox(bboxes, format="XYXY", image_size=image_size)

sample = [img, bboxes]

t = AADet()
out = t(sample)

import matplotlib.pyplot as plt

plt.figure()
plt.subplot(121)
plt.imshow(img[0, ...].permute(1, 2, 0) / 255.0)
plt.subplot(122)
plt.imshow(out[0][0, ...].permute(1, 2, 0) / 255.0)

[image attachment: side-by-side plot of the original image (left) and the transformed output with the bbox regions flipped (right)]

vfdev-5 avatar Sep 09 '22 15:09 vfdev-5

@ain-soph I was hoping you could confirm that Victor's reply unblocked you and you are able to continue the feature. Please let me know if there are more outstanding questions. Thank you!

datumbox avatar Sep 14 '22 14:09 datumbox

Yes, I'll open a draft PR this weekend using Victor's format.

Just one question about the hyper-parameters. For example, for rotate, where should we set the degree parameter: as an argument in the class __init__ or as a forward argument?

ain-soph avatar Sep 15 '22 06:09 ain-soph

For example, for rotate, where should we set the degree parameter?

I think it can be the same as in AA for classification: https://github.com/pytorch/vision/blob/main/torchvision/prototype/transforms/_auto_augment.py
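
In other words, following the classification AA pattern, the magnitude ranges (e.g. the rotation degrees) live inside the transform and forward only receives the sample. A very rough, hypothetical sketch of what that could look like for detection (names and values are illustrative only, not a final design):

from typing import Dict, Tuple

import torch
from torch import nn


class AutoAugmentDetection(nn.Module):
    def __init__(self, num_magnitude_bins: int = 31):
        super().__init__()
        # Per-op magnitude ranges, mirroring the classification AA design:
        # each op name maps to (magnitudes over num_magnitude_bins, signed).
        self._augmentation_space: Dict[str, Tuple[torch.Tensor, bool]] = {
            "Rotate_Only_BBoxes": (torch.linspace(0.0, 30.0, num_magnitude_bins), True),
            "ShearX_Only_BBoxes": (torch.linspace(0.0, 0.3, num_magnitude_bins), True),
            "TranslateX_Only_BBoxes": (torch.linspace(0.0, 120.0, num_magnitude_bins), True),
            "Solarize_Only_BBoxes": (torch.linspace(255.0, 0.0, num_magnitude_bins), False),
        }

    def forward(self, image, target):
        # The op and its magnitude are sampled internally from the space above,
        # so the user never passes degrees at call time.
        ...
        return image, target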

vfdev-5 avatar Sep 15 '22 07:09 vfdev-5