
[FEATURE] Support Tied-Augment

Open ekurtulus opened this issue 2 years ago • 3 comments

Recently, we introduced Tied-Augment, a simple framework that combines self-supervised and supervised learning by making forward passes on two augmented views of the data with tied (shared) weights. In addition to the classification loss, it adds a similarity term to enforce invariance between the features of the two augmented views. We found that the framework improves the effectiveness of both simple flips-and-crops (Crop-Flip) and aggressive augmentations (RandAugment), even for few-epoch training. As the effect of data augmentation is amplified, the sample efficiency of the data increases.

I believe Tied-Augment would be a nice addition to the timm training script. It can significantly improve mixup/RandAugment results (77.6% → 79.6%) at marginal extra cost. Here is my reference implementation.
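To make the objective described above concrete, here is a minimal sketch of a Tied-Augment-style loss: cross-entropy on both augmented views plus a feature-similarity term. The function name, the use of MSE as the similarity term, and the weighting scheme are illustrative assumptions, not taken from the paper's code (the paper may use a different similarity measure):

```python
import torch
import torch.nn.functional as F

def tied_augment_loss(features1, features2, logits1, logits2, targets,
                      similarity_weight=1.0):
    """Illustrative Tied-Augment objective: classification loss on both
    views of the batch, plus a term pulling their features together.

    Note: MSE is used here as a stand-in similarity term; the actual
    choice of similarity function is an assumption of this sketch."""
    # Supervised loss on each augmented view (weights are tied, so both
    # views come from the same model).
    ce = F.cross_entropy(logits1, targets) + F.cross_entropy(logits2, targets)
    # Invariance term: penalize disagreement between the views' features.
    sim = F.mse_loss(features1, features2)
    return ce + similarity_weight * sim
```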

ekurtulus avatar May 25 '23 23:05 ekurtulus

👍🏻 It would be great if you could implement Tied-Augment

pdedeler avatar May 27 '23 21:05 pdedeler

@ekurtulus that sounds interesting, can it be implemented similarly to augmix + jsd loss, where most of the detail w.r.t. the splits of data, etc. is in the dataset wrapper and loss?

rwightman avatar May 29 '23 18:05 rwightman

> @ekurtulus that sounds interesting, can it be implemented similarly to augmix + jsd loss, where most of the detail w.r.t. the splits of data, etc. is in the dataset wrapper and loss?

@rwightman Yes. The only difference is that Tied-Augment also needs the features of the augmented views, not just the logits, so the model itself needs a small wrapper as well.

Example (for a timm model created with num_classes=0):

import torch.nn as nn

class TimmWrapper(nn.Module):
    """Wrap a timm backbone (created with num_classes=0) so the forward
    pass can return both features and logits."""
    def __init__(self, model, num_classes):
        super().__init__()
        self.model = model
        self.fc = nn.Linear(model.num_features, num_classes)

    def forward(self, x, return_features=False):
        if self.training or return_features:
            features = self.model(x)
            logits = self.fc(features)
            return features, logits
        # Inference path: plain logits, same as a standard classifier.
        return self.fc(self.model(x))
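For completeness, here is a hypothetical training step showing how such a wrapper could be driven: the loader yields two augmented views per sample (as an augmix/jsd-style dataset wrapper might), both views pass through the same tied-weight model, and the loss adds a feature-similarity term. All names, the MSE similarity term, and the loss weighting are illustrative assumptions of this sketch, not the reference implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def tied_augment_step(model, x1, x2, targets, optimizer,
                      similarity_weight=1.0):
    """One illustrative Tied-Augment update. `model` is assumed to return
    (features, logits) in training mode, like the wrapper above."""
    model.train()
    feat1, logits1 = model(x1)  # first augmented view
    feat2, logits2 = model(x2)  # second view, same (tied) weights
    loss = (F.cross_entropy(logits1, targets)
            + F.cross_entropy(logits2, targets)
            + similarity_weight * F.mse_loss(feat1, feat2))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

The key point for the training-script integration is that only the model call and the loss differ from a standard supervised step; the batch splitting itself can live in the dataset wrapper, as with augmix + jsd.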

ekurtulus avatar May 29 '23 18:05 ekurtulus