FixMatch (semi-supervised learning) use case
Hi
Thanks for the great work and package. I have recently started using skorch :) I would like to implement FixMatch in skorch: https://arxiv.org/abs/2001.07685
So the main idea is that in each epoch, I need to do inference and add pseudo-labels. Unfortunately, I am not sure how to do this with the .fit() method. Of course, I am aware of how to write an "ugly" for loop and call .fit every time afterwards. But I am looking for a more elegant way.
Is there any notebook or example for semi-supervised learning in general? Can you please help? In case we can come up with a solution, I am willing to write a tutorial on the topic and send a pull request :)
Cheers, Ali
I'm not familiar with this approach and don't know each step that would be involved. If you already have some code to share, that would really help.
From a glance, here is my intuition of how I would proceed:
I would not bother with overriding fit or any of these methods. Instead I would write an "ugly" function (as you call it ;) that calls skorch's partial_fit repeatedly and does the pseudo-labeling and augmentation afterwards. Below is a very simplified piece of code that probably misses a few important steps and that assumes that the same loss can be used for "normal" training and FixMatch training:
# I assume FixMatch is applied after each epoch
net = NeuralNetClassifier(..., max_epochs=1)  # use CE loss

def fix_match(net, X, y, epochs=10):
    for epoch in range(epochs):
        # partial_fit keeps the learned weights instead of re-initializing the net
        net.partial_fit(X, y)
        y_proba = net.predict_proba(X)
        y_pseudo = generate_pseudo_labels(y, y_proba)
        X_aug = augment(X)
        net.partial_fit(X_aug, y_pseudo)
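Here generate_pseudo_labels and augment are placeholders. Just as an illustration, a minimal sketch of the pseudo-labeling step could look like this, assuming unlabeled samples are marked with -1 (sklearn's semi-supervised convention) and a hard confidence threshold as in the paper:

import numpy as np

def generate_pseudo_labels(y, y_proba, threshold=0.95):
    # keep real labels where present; for unlabeled samples (-1), adopt the
    # predicted class if the model is confident enough. Samples that stay -1
    # would still need to be filtered out before the next partial_fit call.
    y_pseudo = np.asarray(y).copy()
    confident = (y_pseudo == -1) & (y_proba.max(axis=1) > threshold)
    y_pseudo[confident] = y_proba[confident].argmax(axis=1)
    return y_pseudo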
It's not very beautiful (e.g. when it comes to the print log) but hopefully gets some quick results. And in case this works really well, we can still think about how to implement it more efficiently and more tightly within the net later.
Is there any notebook or example for semi-supervised learning in general?
I don't have anything handy, do you @ottonemo? IIRC we implemented stuff like mixup + semi-supervised learning but I don't have the code anymore.
Thanks @BenjaminBossan for the response :)
Yeah, that makes sense. On the other hand, in the original implementation, the pseudo-labeling should be done per batch. That means each batch should contain both labeled and unlabeled data, and pseudo-labeling should happen during training.
I was wondering if I might be able to write a model that inherits properties from NeuralNetClassifier. Is there any example for such models? Probably it makes sense to write a tutorial on that.
in the original implementation, the pseudo-labeling should be done per batch. That means each batch should contain both labeled and unlabeled data, and pseudo-labeling should happen during training.
In that case, my code wouldn't quite achieve the same outcome (though maybe it's sufficient to do this per epoch instead of per batch?). Instead, I would probably try to achieve this by overriding get_loss:
from skorch.utils import to_tensor

class MyNet(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, X=None, training=False):
        y_true = to_tensor(y_true, device=self.device)
        loss_classification = self.criterion_(y_pred, y_true)  # this is the normal loss
        if not training:  # for validation, skip the part with the pseudo-labels
            return loss_classification

        y_proba = self.infer(X)  # forward pass on the original data
        y_pseudo = generate_pseudo_labels(y_true, y_proba)
        X_aug = augment(X)
        y_proba_aug = self.infer(X_aug)
        loss_semi = ce_loss(y_proba_aug, y_pseudo)
        # weight0/weight1: hyperparameters balancing the two terms
        return weight0 * loss_classification + weight1 * loss_semi
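Again, generate_pseudo_labels, augment, ce_loss, and the weights are placeholders. A minimal sketch of what the two helpers could look like, assuming unlabeled samples are marked with -1 and hard pseudo-labels above a confidence threshold (as in the paper):

import torch
import torch.nn.functional as F

def generate_pseudo_labels(y_true, y_logits, threshold=0.95):
    # hard pseudo-labels for confident predictions on unlabeled samples;
    # everything else stays -1 and is ignored by ce_loss below
    probs = torch.softmax(y_logits.detach(), dim=1)
    conf, pred = probs.max(dim=1)
    y_pseudo = torch.full_like(y_true, -1)
    mask = (y_true == -1) & (conf > threshold)
    y_pseudo[mask] = pred[mask]
    return y_pseudo

def ce_loss(y_logits, y_pseudo):
    # cross-entropy restricted to samples that actually received a pseudo-label
    mask = y_pseudo != -1
    if not mask.any():
        return torch.tensor(0., device=y_logits.device)
    return F.cross_entropy(y_logits[mask], y_pseudo[mask])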
I was wondering if I might be able to write a model that inherits properties from NeuralNetClassifier. Is there any example for such models? Probably it makes sense to write a tutorial on that.
We have a section on customizing skorch nets in the docs. Note that the subsection on "Initialization and custom modules" refers to changes that are only on master, i.e. will be released in skorch 0.11. But for your problem, that part should not be relevant.
Hey @BenjaminBossan! Thanks again for the feedback. So I could actually solve it by overriding the loss part.
This snippet is a "quick and dirty" example of it. I assume that we have a PyTorch dataset that outputs the same image once with weak and once with strong augmentations (a rough sketch of such a dataset follows after the model code). For unlabeled data, it passes -1 as the label (consistent with sklearn). The model needs to be modified:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18

class ResNet18Modified(nn.Module):
    def __init__(self, num_classes=3, pretrained=True, progress=True, **kwargs):
        super().__init__()
        model = resnet18(pretrained=pretrained, progress=progress)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
        self.model = model

    def forward(self, x):
        x_weak = x[0]
        x_strong = x[1]
        y_weak = self.model(x_weak)
        y_strong = self.model(x_strong)
        return y_weak, y_strong
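The dataset described above is not spelled out in this thread; a minimal sketch of what it could look like, where the two transforms and the (-1)-for-unlabeled convention are assumptions based on the description:

from torch.utils.data import Dataset

class FixMatchDataset(Dataset):
    # returns the same image twice, once weakly and once strongly augmented;
    # unlabeled samples carry the label -1
    def __init__(self, images, labels, weak_transform, strong_transform):
        self.images = images
        self.labels = labels
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img = self.images[i]
        return (self.weak_transform(img), self.strong_transform(img)), self.labels[i]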
And that's how the FixMatch class is written:
import torch
import torch.nn.functional as F
from skorch import NeuralNetClassifier

class FixMatch(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        threshold = torch.tensor(0.9).cuda()
        y_weak, y_strong = y_pred
        y_weak = y_weak.cuda()
        y_strong = y_strong.cuda()
        y_true = y_true.cuda()

        # consistency loss on unlabeled samples (label -1) whose prediction
        # on the weak augmentation is confident enough; softmax over the
        # class dimension (dim=1), not the batch dimension
        indx = y_true == -1
        probs, pseudo_label = torch.softmax(y_weak, dim=1).max(dim=1)
        indx = indx & (probs > threshold)
        if indx.sum() > 0:
            consistency_loss = F.cross_entropy(y_strong[indx, :], pseudo_label[indx])
        else:
            consistency_loss = torch.tensor(0., device=y_weak.device, requires_grad=True)

        # classification loss on the labeled samples
        indx = y_true != -1
        if indx.sum() > 0:
            classification_loss = F.cross_entropy(y_weak[indx, :], y_true[indx])
        else:
            classification_loss = torch.tensor(0., device=y_weak.device, requires_grad=True)

        return classification_loss + 0.5 * consistency_loss
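A rough sketch of how the pieces could be wired together; the class count, classes list, and hyperparameters here are illustrative only:

net = FixMatch(
    ResNet18Modified,
    module__num_classes=3,
    classes=[0, 1, 2],  # needed since the labels come from the dataset, not y
    max_epochs=20,
    device='cuda',
)
net.fit(fixmatch_dataset, y=None)  # labels (including -1) come from the dataset

Since the module returns a tuple, skorch uses only the first output (the weak-augmentation logits) for predict and predict_proba.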
It worked quite nicely with torch.nn.DataParallel :)
@aliechoes I'm very happy to hear that you could make it work and it doesn't look too complicated. If you happen to create a self-contained notebook using this, I can totally see adding it to skorch.
It worked quite nicely with torch.nn.DataParallel :)
Nice. I haven't used it myself, so it's good to know that it works.
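Presumably the wrapping looks roughly like this, i.e. passing an already instantiated, wrapped module to the net (an untested sketch; skorch accepts module instances as well as classes):

# wrap the instantiated module before handing it to skorch
module = torch.nn.DataParallel(ResNet18Modified(num_classes=3))
net = FixMatch(module, max_epochs=20, device='cuda')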
I would be more than happy to write a tutorial on it. There is a famous use case of training ImageNet or CIFAR-100 with only 10 samples per class. I think it would be a nice example. How should I do it? Sending a pull request?
Yes, just create a PR with your tutorial in it. If it's a standalone notebook, you can add it to the notebooks folder of skorch. For more elaborate work (e.g. including scripts and modules), use the examples folder. There you can find some inspiration when it comes to how to write your example. And, at any time, feel free to ask questions.
Hey @aliechoes, any updates on this? :)
Since there hasn't been a reply for quite some time and the original issue was solved, I'll close this.