FixMatch (semi-supervised learning) use case

aliechoes opened this issue 3 years ago • 8 comments

Hi

Thanks for the great work and package. I have recently started using skorch :) I would like to implement FixMatch in skorch: https://arxiv.org/abs/2001.07685

The main idea is that in each epoch, I need to run inference and add pseudo-labels. Unfortunately, I am not sure how to do this with the .fit() method. Of course, I could write an "ugly" for loop and call .fit() after every epoch, but I am looking for a more elegant way.

Is there any notebook or example for semi-supervised learning in general? Can you please help? In case we can come up with a solution, I am willing to write a tutorial on the topic and send a pull request :)

Cheers Ali

aliechoes avatar Sep 28 '21 07:09 aliechoes

I'm not familiar with this approach and don't know each step that would be involved. If you already have some code to share, that would really help.

From a glance, here is my intuition of how I would proceed:

I would not bother with overriding fit or any of those methods. Instead, I would write an "ugly" function (as you call it ;) that calls skorch's partial_fit repeatedly and does the pseudo-labeling and augmentation in between. Below is a very simplified piece of code that probably misses a few important steps and assumes that the same loss can be used for "normal" training and FixMatch training:

# I assume FixMatch is applied after each epoch
net = NeuralNetClassifier(..., max_epochs=1)  # use CE loss

def fix_match(net, X, y, epochs=10):
    for epoch in range(epochs):
        # partial_fit keeps the trained weights instead of re-initializing the net
        net.partial_fit(X, y)
        y_proba = net.predict_proba(X)
        y_pseudo = generate_pseudo_labels(y, y_proba)

        X_aug = augment(X)
        net.partial_fit(X_aug, y_pseudo)
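
Note that generate_pseudo_labels and augment are just placeholders, not existing skorch functions. Assuming that unlabeled samples are marked with -1 (sklearn's convention for semi-supervised learning), the pseudo-labeling helper could look roughly like this:

import numpy as np

def generate_pseudo_labels(y, y_proba, threshold=0.95):
    # hypothetical helper: keep the known labels and fill in the model's
    # prediction wherever it is confident enough; everything else stays -1
    y_pseudo = np.asarray(y).copy()
    confident = y_proba.max(axis=1) > threshold
    replace = (y_pseudo == -1) & confident
    y_pseudo[replace] = y_proba.argmax(axis=1)[replace]
    return y_pseudo

Samples that are still -1 afterwards would need to be dropped, or ignored via the criterion's ignore_index, before the second partial_fit call.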

It's not very beautiful (e.g. when it comes to the print log) but hopefully gets some quick results. And in case this works really well, we can still think about how to implement it more efficiently and more tightly within the net later.

Is there any notebook or example for semi-supervised learning in general?

I don't have anything handy, do you @ottonemo? IIRC we implemented stuff like mixup + semi-supervised learning but I don't have the code anymore.

BenjaminBossan avatar Sep 28 '21 19:09 BenjaminBossan

Thanks @BenjaminBossan for the response :)
Yeah, that makes sense. However, in the original implementation, the pseudo-labeling is done per batch: each batch contains both labeled and unlabeled data, and the pseudo-labels are generated during training.

I was wondering whether I could write a model that inherits from NeuralNetClassifier. Is there any example of such a model? It probably makes sense to write a tutorial on that.

aliechoes avatar Sep 29 '21 08:09 aliechoes

in the original implementation, the pseudo-labeling is done per batch: each batch contains both labeled and unlabeled data, and the pseudo-labels are generated during training.

In that case, my code wouldn't quite achieve the same outcome (though maybe it's sufficient to do this per epoch instead of per batch?). Instead, I would probably try to achieve this by overriding get_loss:

from skorch import NeuralNetClassifier
from skorch.utils import to_tensor

class MyNet(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, X=None, training=False):
        y_true = to_tensor(y_true, device=self.device)
        loss_classification = self.criterion_(y_pred, y_true)  # this is the normal loss
        if not training:  # for validation, skip the part with the pseudo-labels
            return loss_classification

        y_proba = self.infer(X)
        y_pseudo = generate_pseudo_labels(y_true, y_proba)
        X_aug = augment(X)
        y_proba_aug = self.infer(X_aug)
        # the prediction on the augmented input should match the pseudo-labels
        loss_semi = ce_loss(y_proba_aug, y_pseudo)

        return weight0 * loss_classification + weight1 * loss_semi
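
ce_loss, augment, weight0, and weight1 are placeholders as well. Assuming the pseudo-labels use -1 for samples without a confident label, the consistency term could be sketched like this (purely illustrative):

import torch
import torch.nn.functional as F

def ce_loss(y_logits, y_pseudo):
    # hypothetical helper: cross-entropy restricted to samples that actually
    # have a (pseudo-)label; contributes nothing if there are none
    mask = y_pseudo >= 0
    if mask.sum() == 0:
        return torch.tensor(0., device=y_logits.device)
    return F.cross_entropy(y_logits[mask], y_pseudo[mask].long())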

I was wondering whether I could write a model that inherits from NeuralNetClassifier. Is there any example of such a model? It probably makes sense to write a tutorial on that.

We have a section on customizing skorch nets in the docs. Note that the subsection on "Initialization and custom modules" refers to changes that are only on master, i.e. will be released in skorch 0.11. But for your problem, that part should not be relevant.

BenjaminBossan avatar Sep 29 '21 21:09 BenjaminBossan

Hey @BenjaminBossan! Thanks again for the feedback. So I could actually solve it by overriding the loss part.

This snippet is a "quick and dirty" example of it. I assume that we have a PyTorch dataset that outputs the same image once with weak and once with strong augmentations. For unlabeled data, it passes -1 as the label (consistent with sklearn); a rough sketch of such a dataset is included after the model code below. The model needs to be modified like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18

class ResNet18Modified(nn.Module):
    def __init__(self, num_classes=3, pretrained=True, progress=True, **kwargs):
        super().__init__()
        model = resnet18(pretrained=pretrained, progress=progress)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
        self.model = model

    def forward(self, x):
        # x is a pair: the same image with weak and strong augmentations
        x_weak, x_strong = x
        y_weak = self.model(x_weak)
        y_strong = self.model(x_strong)
        return y_weak, y_strong
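
As mentioned above, the dataset itself is not part of this snippet. A rough sketch of what it could look like (the class and transform names are made up for illustration):

from torch.utils.data import Dataset

class FixMatchDataset(Dataset):
    # hypothetical dataset: returns the same image under a weak and a strong
    # augmentation, plus the label (-1 for unlabeled samples)
    def __init__(self, images, labels, weak_transform, strong_transform):
        self.images = images
        self.labels = labels
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img = self.images[i]
        X = (self.weak_transform(img), self.strong_transform(img))
        y = self.labels[i]
        return X, y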

And this is how the FixMatch net is written:

from skorch import NeuralNetClassifier

class FixMatch(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        threshold = 0.9  # confidence threshold for accepting a pseudo-label
        y_weak, y_strong = y_pred
        y_weak = y_weak.cuda()
        y_strong = y_strong.cuda()
        y_true = y_true.cuda()

        # unlabeled samples are marked with -1 (consistent with sklearn)
        indx = y_true == -1
        # softmax over the class dimension; only keep confident predictions
        probs, pseudo_label = torch.softmax(y_weak, dim=1).max(dim=1)
        indx = indx & (probs > threshold)

        # consistency loss: the prediction on the strong augmentation should
        # match the pseudo-label obtained from the weak augmentation
        if indx.sum() > 0:
            pseudo_label = pseudo_label.long()
            consistency_loss = F.cross_entropy(y_strong[indx, :], pseudo_label[indx])
        else:
            consistency_loss = torch.tensor(0., requires_grad=True, device=y_weak.device)

        # supervised loss on the labeled part of the batch
        indx = y_true != -1
        if indx.sum() > 0:
            classification_loss = F.cross_entropy(y_weak[indx, :], y_true[indx])
        else:
            classification_loss = torch.tensor(0., requires_grad=True, device=y_weak.device)

        return classification_loss + 0.5 * consistency_loss

It worked quite nicely with torch.nn.DataParallel :)
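
A rough sketch of how the pieces could be wired together (the parameter values are only examples; the classes and train_split settings account for not passing y to fit):

model = torch.nn.DataParallel(ResNet18Modified(num_classes=3))

net = FixMatch(
    model,
    optimizer=torch.optim.Adam,
    lr=1e-3,
    max_epochs=20,
    batch_size=64,
    device='cuda',
    train_split=None,   # the default validation split needs labels, so disable it
    classes=[0, 1, 2],  # classes cannot be inferred when y is not passed to fit
    iterator_train__shuffle=True,
)

net.fit(train_dataset, y=None)  # train_dataset yields ((x_weak, x_strong), y) as above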

aliechoes avatar Oct 01 '21 14:10 aliechoes

@aliechoes I'm very happy to hear that you could make it work and it doesn't look too complicated. If you happen to create a self-contained notebook using this, I can totally see adding it to skorch.

It worked quite nicely with torch.nn.DataParallel :)

Nice. I haven't used it myself, so it's good to know that it works.

BenjaminBossan avatar Oct 03 '21 11:10 BenjaminBossan

I would be more than happy to write a tutorial on it. There is a famous use case of training ImageNet or CIFAR-100 with only 10 samples per class. I think it would be a nice example. How should I do it? By sending a pull request?

aliechoes avatar Oct 07 '21 07:10 aliechoes

Yes, just create a PR with your tutorial in it. If it's a standalone notebook, you can add it to the notebooks folder of skorch. For more elaborate work (e.g. including scripts and modules), use the examples folder. There you can find some inspiration on how to write your example. And, at any time, feel free to ask questions.

BenjaminBossan avatar Oct 09 '21 17:10 BenjaminBossan

Hey @aliechoes, any updates on this? :)

ottonemo avatar Sep 16 '22 11:09 ottonemo

Since there hasn't been a reply for quite some time and the original issue was solved, I'll close this.

BenjaminBossan avatar Mar 09 '23 11:03 BenjaminBossan