skorch
Stochastic Weight Averaging
PyTorch recently added methods to implement Stochastic Weight Averaging (SWA): https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
This method can improve many models' performance by creating a new model whose weights are averaged over the last few training epochs. Paper here: https://arxiv.org/abs/1803.05407
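For intuition, the averaging is just an equal-weight running mean over the collected model snapshots; below is a minimal sketch of the update rule that `torch.optim.swa_utils.AveragedModel` applies by default (the function name here is purely illustrative):

```python
# Minimal sketch of SWA's default parameter update:
# after n snapshots have been averaged, fold in one more set of weights.
def swa_update(avg_param, new_param, n_averaged):
    # avg_{n+1} = avg_n + (w_{n+1} - avg_n) / (n + 1)
    return avg_param + (new_param - avg_param) / (n_averaged + 1)
```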
The PyTorch implementation requires calling methods within a training loop, but I wanted to use SWA with a skorch network, so I wrote a callback to do it. I wondered if this would be of some use to others.
```python
import torch
from skorch.callbacks import Callback

# train_loader and skorch_model are defined elsewhere
train_loader, skorch_model = ...


class StochasticWeightAveraging(Callback):
    def on_train_begin(self, skorch_model, **kwargs):
        skorch_model.swa_model = torch.optim.swa_utils.AveragedModel(skorch_model.module_)

    def on_epoch_end(self, skorch_model, **kwargs):
        if skorch_model.history[-1, 'epoch'] >= skorch_model.module__swa_start * skorch_model.max_epochs:
            skorch_model.swa_model.update_parameters(skorch_model.module_)

    def on_train_end(self, skorch_model, **kwargs):
        torch.optim.swa_utils.update_bn(train_loader, skorch_model.swa_model, device=skorch_model.device)
```
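A hedged usage sketch for the callback above, assuming a module whose constructor accepts an `swa_start` fraction (since the callback reads `skorch_model.module__swa_start`) and a `train_loader` available at module scope; `MyModule`, `X_train`, and `y_train` are placeholders, not part of the original snippet:

```python
# Hypothetical usage of the callback above; MyModule, X_train, y_train,
# and train_loader are placeholders.
from skorch import NeuralNetClassifier

skorch_model = NeuralNetClassifier(
    MyModule,
    module__swa_start=0.75,  # start averaging 75% of the way through training
    max_epochs=100,
    callbacks=[StochasticWeightAveraging()],
)
skorch_model.fit(X_train, y_train)
# After training, the averaged weights live in skorch_model.swa_model.
```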
I didn't know about stochastic weight averaging, thanks a lot. I looked at your code and the PyTorch example and came up with a slightly different implementation based on yours:
```python
from skorch.callbacks import Callback
from torch.optim import swa_utils


class StochasticWeightAveraging(Callback):
    def __init__(
            self,
            swa_utils,
            swa_start=10,
            verbose=0,
            sink=print,
            **kwargs  # additional arguments to swa_utils.SWALR
    ):
        self.swa_utils = swa_utils
        self.swa_start = swa_start
        self.verbose = verbose
        self.sink = sink
        vars(self).update(kwargs)

    @property
    def kwargs(self):
        # These are the parameters that are passed to SWALR.
        # Parameters that don't belong there must be excluded.
        excluded = {'swa_utils', 'swa_start', 'verbose', 'sink'}
        kwargs = {key: val for key, val in vars(self).items()
                  if not (key in excluded or key.endswith('_'))}
        return kwargs

    def on_train_begin(self, net, **kwargs):
        self.optimizer_swa_ = self.swa_utils.SWALR(net.optimizer_, **self.kwargs)
        if not hasattr(net, 'module_swa_'):
            net.module_swa_ = self.swa_utils.AveragedModel(net.module_)

    def on_epoch_begin(self, net, **kwargs):
        if self.verbose and len(net.history) == self.swa_start + 1:
            self.sink("Using SWA to update parameters")

    def on_epoch_end(self, net, **kwargs):
        if len(net.history) >= self.swa_start + 1:
            net.module_swa_.update_parameters(net.module_)
            self.optimizer_swa_.step()

    def on_train_end(self, net, X, y=None, **kwargs):
        if self.verbose:
            self.sink("Using training data to update batch norm statistics of the SWA model")
        loader = net.get_iterator(net.get_dataset(X, y))
        self.swa_utils.update_bn(loader, net.module_swa_, device=net.device)
```
Let me explain some of the changes:
- I want to pass `swa_utils` as a parameter, because this way the skorch code still works with PyTorch versions < 1.6 ("works" in the sense that there won't be an import error, but the callback still won't be usable); also, a user could in theory provide their own implementations of `update_bn` etc.
- I made `swa_start` a parameter of the callback instead of the module.
- I added verbosity to get a better feel of what's happening.
- Your code has this line: `if skorch_model.history[-1, 'epoch'] >= skorch_model.module__swa_start * skorch_model.max_epochs` -- I think the logic is wrong there. Why multiply the epochs with `swa_start`? (Is it meant as a fraction?)
What your example is missing compared to the PyTorch example is the use of `swa_utils.SWALR`. In my code, I tried to work it in. However, my code differs from the PyTorch example: there, SWALR is used *instead of* the normal lr scheduler, whereas in my code it's used *in addition* to it.
Below is a working example using the callback as implemented above:
```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import swa_utils

from skorch import NeuralNetClassifier
from skorch.callbacks import Callback, LRScheduler, EpochScoring

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int64)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

SWA_START = 5
MAX_EPOCHS = 100
LR = 0.01
LR_SWA = 0.05

# skorch implementation
# (ClassifierModule is a small classification module that outputs probabilities;
# its definition is omitted here)
class StochasticWeightAveraging(Callback):
    ...  # as defined above

torch.manual_seed(0)
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=50,
    lr=LR,
    callbacks=[
        LRScheduler(CosineAnnealingLR, T_max=MAX_EPOCHS),
        StochasticWeightAveraging(swa_utils, swa_start=SWA_START, verbose=1, swa_lr=LR_SWA),
        EpochScoring('accuracy', lower_is_better=False, on_train=True, name='train_acc'),
    ],
    train_split=False,
)
net.fit(X_train, y_train)
test_accuracy = (net.predict(X_test) == y_test).mean()

# PyTorch implementation inspired by the linked example
torch.manual_seed(0)
loader = net.get_iterator(net.get_dataset(X, y))
model = ClassifierModule()
optimizer = torch.optim.SGD(model.parameters(), LR)
loss_fn = torch.nn.NLLLoss()
swa_model = swa_utils.AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)
swa_scheduler = swa_utils.SWALR(optimizer, swa_lr=LR_SWA)

for epoch in range(MAX_EPOCHS):
    losses = []
    for input, target in loader:
        optimizer.zero_grad()
        loss = loss_fn(torch.log(model(input)), target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    if epoch == 1 + SWA_START:
        print("starting SWA")
    if epoch > SWA_START:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

    preds = swa_model(torch.as_tensor(X))
    print("epoch: {:>2} | train loss: {:.4f} | train acc: {:.2f} %".format(
        epoch, np.mean(losses), 100 * (preds.detach().numpy().argmax(-1) == y).mean()))

swa_utils.update_bn(loader, swa_model)
test_accuracy = (swa_model(torch.as_tensor(X_test)).detach().numpy().argmax(-1) == y_test).mean()
```
The skorch version gets a train loss of 0.548, train accuracy of 0.733, and test accuracy of 0.752.
The PyTorch version gets a train loss of 0.460, train accuracy of 0.727, and test accuracy of 0.736.
So there seems to be a significant difference in train loss, but I'm not sure where it's coming from. It's not due to the described difference, as introducing the same deviation in the PyTorch code doesn't make a difference. Do you have any idea?
@WillCastle Did you have an opportunity to test this out yet?
I believe it might not even be necessary to store `swa_model` (`module_swa_` in my code) on the net; it could be stored on the callback instead.
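A minimal sketch of what that could look like, as a hypothetical variant of the callback above (untested, just to illustrate where the averaged model would live):

```python
# Hypothetical variant: keep the averaged model on the callback instead of the net.
class StochasticWeightAveragingOnCallback(StochasticWeightAveraging):
    def on_train_begin(self, net, **kwargs):
        self.optimizer_swa_ = self.swa_utils.SWALR(net.optimizer_, **self.kwargs)
        # stored on the callback, not on the net
        self.module_swa_ = self.swa_utils.AveragedModel(net.module_)

    def on_epoch_end(self, net, **kwargs):
        if len(net.history) >= self.swa_start + 1:
            self.module_swa_.update_parameters(net.module_)
            self.optimizer_swa_.step()

    def on_train_end(self, net, X, y=None, **kwargs):
        loader = net.get_iterator(net.get_dataset(X, y))
        self.swa_utils.update_bn(loader, self.module_swa_, device=net.device)
```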
@BenjaminBossan Hi, sorry, I have been a little caught up in some job applications; I will have a look at this next week. I am not sure about the discrepancy in training loss, I'll run some test cases and try to work out where it's coming from. The changes you proposed look good. As to multiplying by `swa_start`, I did intend it as a fraction. I only did it this way because most of the examples I have read suggest beginning the averaging 75% of the way through training; it might also be helpful when running several models for different numbers of epochs (say, when tuning hyperparameters). It could probably do with a clearer name if used in this way, though.
@BenjaminBossan Just had another look and noticed a couple of things. In your skorch example, you create and fit the `net` object but then check the accuracy of this same network:
`test_accuracy = (net.predict(X_test) == y_test).mean()`
I believe that the `net` object remains unchanged by SWA and the one we want to evaluate is `net.module_swa_`, which is actually a new model that is a sort of ensemble of some of the training iterations of `net`.
The SWA model is a PyTorch module, so I follow its creation with a conversion to a skorch model with something like `NeuralNetBinaryClassifier(module=net.module_swa_)`.
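Alternatively, for this example one could evaluate the averaged module directly, in the same spirit as the test-accuracy line of the PyTorch snippet above. This is only a sketch, not part of the original example, and it assumes the module outputs class probabilities as in the example:

```python
# Sketch: evaluate the SWA-averaged module on the test set, bypassing the net.
swa_module = net.module_swa_
swa_module.eval()
with torch.no_grad():
    probs = swa_module(torch.as_tensor(X_test))
swa_test_accuracy = (probs.numpy().argmax(-1) == y_test).mean()
```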
Also, it looks like your skorch network is training for 50 epochs, since you initialize it with `max_epochs=50`, whereas the PyTorch network trains using your `MAX_EPOCHS = 100` variable. That might be the reason for the difference in training loss.
Thanks for taking another look @WillCastle
> as to multiplying by `swa_start`, I did intend it as a fraction
Okay, this makes sense. I would probably allow both possibilities: if int, take it as absolute value, if float, as relative value. This is consistent with how sklearn works in some places, e.g. the train_size argument in train_test_split.
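A possible way to support both interpretations (a sketch only; the helper name is made up):

```python
# Hypothetical helper: interpret swa_start as an absolute epoch (int)
# or as a fraction of max_epochs (float), similar to sklearn's train_size.
def resolve_swa_start(swa_start, max_epochs):
    if isinstance(swa_start, float):
        if not 0.0 < swa_start < 1.0:
            raise ValueError("A float swa_start must lie in (0, 1).")
        return int(swa_start * max_epochs)
    return swa_start
```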
> I believe that the `net` object remains unchanged by SWA and the one we want to evaluate is `net.module_swa_`, which is actually a new model that is a sort of ensemble of some of the training iterations of `net`.

Yes, you're right; to be more precise, it's not the `net` object but `net.module_` (which is the PyTorch module itself).
Unfortunately, that's not the reason for the discrepancy. I tested both the original `module_` and the new `module_swa_`, and they still give different results for the skorch and the pure PyTorch implementations (I made sure to fix the seeds and use the exact same data loader):
| | PyTorch | skorch |
|---|---|---|
| test accuracy module | 0.752 | 0.772 |
| test accuracy swa | 0.704 | 0.760 |
Quite different validation results between PyTorch and skorch. Does skorch do some weight initialization automatically by default?
> Quite different validation results between PyTorch and skorch.

Yes, I need to investigate further, or perhaps someone else can spot a mistake.

> Does skorch do some weight initialization automatically by default?

No, this is left completely to the user. The module is initialized exactly the same way, as is the data loader.
I'm getting this error: `SkorchAttributeError: Trying to set torch compoment 'module_swa_' outside of an initialize method. Consider defining it inside 'initialize_module'`. Any update on this issue? @BenjaminBossan
I assume you have used the code I posted above and now encountered this error. In that case, could you please replace the line:
`net.module_swa_ = self.swa_utils.AveragedModel(net.module_)`
with
```python
with net._current_init_context('module'):
    net.module_swa_ = self.swa_utils.AveragedModel(net.module_)
```
and see if that fixes the issue?
@BenjaminBossan Yes it fixed the issue.