
Question about training loop

Open · AnabetsyR opened this issue 2 years ago • 4 comments

Hi! I'm trying to fork the repo and add some functionality for an experiment, but that requires an addition to the training loop. I've read the documentation and the code, but I can't seem to work out where the training loop itself is defined. Can somebody point me in the right direction?

Thanks in advance!

AnabetsyR avatar Jun 24 '22 21:06 AnabetsyR

In trial.py see method run (looping over the epochs): https://github.com/pytorchbearer/torchbearer/blob/9d97c60ec4deb37a0627311ddecb9c6f1429cd82/torchbearer/trial.py#L946

and _fit_pass (looping over the batches): https://github.com/pytorchbearer/torchbearer/blob/9d97c60ec4deb37a0627311ddecb9c6f1429cd82/torchbearer/trial.py#L1019

I'd have thought most additions to the training loop could be made via one of the many callback hooks rather than by modifying the source itself, though.
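
For example, a minimal (untested, purely illustrative) hook that just logs the current epoch and loss after every training batch looks like this, and gets attached by passing it in the Trial's callbacks list:

import torchbearer

# illustrative only: fires after every training step and reads the current
# epoch and loss straight out of the state dictionary
@torchbearer.callbacks.on_step_training
def log_loss(state):
	print(state[torchbearer.EPOCH], state[torchbearer.LOSS].item())

# trial = torchbearer.Trial(model, optimizer, loss_fn, callbacks=[log_loss])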

jonhare avatar Jun 24 '22 21:06 jonhare

Thanks for getting back to me! I'm trying to integrate stochastic weight averaging as in swa_utils from PyTorch. The way they implemented SWA is like a wrapper on top of the torch optimizer (see their example below). Based on this, it seems I will need to pass the swa_model, the optimizer, and at least the swa_scheduler, and then handle the parameter updates once SWA kicks in during the training loop. Do you have any suggestions on how to go about this? Sorry, I'm new to both Torchbearer and SWA... I really appreciate any suggestions.

loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
for i in range(300):
	for input, target in loader:
		optimizer.zero_grad()
		loss_fn(model(input), target).backward()
		optimizer.step()
	if i > swa_start:
		swa_model.update_parameters(model)
		swa_scheduler.step()
	else:
		scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)

AnabetsyR avatar Jun 24 '22 21:06 AnabetsyR

really, really (really!) untested, but something like this should be equivalent based on the above and guessing at the correct indentation:

loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

swa_scheduler = SWALR(optimizer, swa_lr=0.05)
swa_start = 160

# runs once per epoch at the end of the training pass, matching the
# per-epoch stepping in the PyTorch example above
@torchbearer.callbacks.on_end_training
def swa_callback(state):
	if state[torchbearer.EPOCH] > swa_start:
		# or, avoiding the global access: swa_model.update_parameters(state[torchbearer.MODEL])
		swa_model.update_parameters(model)
		swa_scheduler.step()
	else:
		scheduler.step()

trial = torchbearer.Trial(model, optimizer, loss_fn, callbacks=[swa_callback])
trial = trial.with_train_generator(loader)
trial.run(300)

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
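
If you wanted to keep everything inside the Trial itself, the BN update could (equally untested) be hung off the end of the run as another callback instead of being called manually afterwards:

# illustrative sketch: run the BN statistics update once, after the whole run finishes
@torchbearer.callbacks.on_end
def update_bn_callback(state):
	torch.optim.swa_utils.update_bn(loader, swa_model)

# trial = torchbearer.Trial(model, optimizer, loss_fn, callbacks=[swa_callback, update_bn_callback])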

jonhare avatar Jun 24 '22 22:06 jonhare

@jonhare Thank you so much! I've spent the weekend playing around with it. At first it was behaving strangely: instead of passing the swa_callback alone, I was adding it to my existing list of callbacks (which has other things in it) and passing the whole list, so it seems some things don't play well together. I suspect it's due to some unnecessary schedulers etc. in there. Of course, now I have to combine the necessary portions, but it's working!

Note that I added the update_bn line before the trial as it looked like it wasn't updating properly. Hopefully this is correct!

I'm extremely grateful that you took time out of your day to help me out! You really saved me a ton of headache. And I love torchbearer so much that I didn't want to switch away from it.

AnabetsyR avatar Jun 27 '22 17:06 AnabetsyR