Ability to roll back an epoch

peastman opened this issue on Jul 07, 2021 · 3 comments

I've been experimenting with training protocols to see what gives the best results. So far, the best I've found is to start with a large learning rate that pushes the boundaries of what's stable, then aggressively reduce it as needed. This seems to consistently give faster learning and a better final result than starting from a lower learning rate or reducing it more slowly.

Sometimes, though, the signal that your learning rate is too high can be quite dramatic. It isn't just that the model fails to learn; the loss suddenly jumps way up. Here's an example from a recent training run. Notice how at epoch 30 the training loss and validation loss both increased several-fold. It then took a few epochs for the scheduler to realize it needed to reduce the learning rate, and about 10 epochs for the loss to get back to where it had been.

26,0.00039999998989515007,9832.5087890625,10538.8203125,5701.85009765625,4130.658203125,3393.247802734375,7145.572265625,30023
27,0.00039999998989515007,9179.70703125,12511.1748046875,5192.06591796875,3987.641357421875,5366.404296875,7144.77099609375,31135
28,0.00039999998989515007,8929.75390625,9831.01171875,5093.1494140625,3836.604736328125,2955.695556640625,6875.3173828125,32247
29,0.00039999998989515007,8372.0537109375,9940.771484375,4676.4296875,3695.62353515625,3155.14306640625,6785.62890625,33359
30,0.00039999998989515007,310591.90625,31877.890625,286285.375,24306.5078125,14621.484375,17256.404296875,34471
31,0.00039999998989515007,23531.33984375,27256.47265625,12166.06640625,11365.2734375,14695.8525390625,12560.6220703125,35583
32,0.00039999998989515007,18280.423828125,17700.439453125,9990.2021484375,8290.2216796875,7149.82421875,10550.6142578125,36695
33,0.00039999998989515007,14477.421875,14724.962890625,7723.6826171875,6753.73876953125,5242.77490234375,9482.1884765625,37807
34,0.00031999999191612005,12286.7001953125,13225.00390625,6404.654296875,5882.0458984375,4444.48876953125,8780.513671875,38919
35,0.00031999999191612005,11554.501953125,16997.564453125,6239.15966796875,5315.341796875,8691.9501953125,8305.61328125,40031
36,0.00031999999191612005,10664.4365234375,16877.546875,5792.0390625,4872.3974609375,9064.205078125,7813.341796875,41143
37,0.00031999999191612005,9997.931640625,11185.5439453125,5469.81396484375,4528.11767578125,3710.89599609375,7474.64794921875,42255
38,0.00025599999935366213,8595.7841796875,10532.455078125,4355.873046875,4239.9111328125,3280.87353515625,7251.58154296875,43367
39,0.00025599999935366213,8495.14453125,14786.5244140625,4462.298828125,4032.845458984375,7734.66015625,7051.86474609375,44479
40,0.00025599999935366213,7973.798828125,10847.546875,4123.08740234375,3850.7109375,3955.357177734375,6892.18994140625,45591

What do you think about adding an option to detect this instability by checking for the training loss increasing by more than a specified amount? When that happened, it would undo the epoch, rolling the model parameters back to the end of the previous epoch, and immediately reduce the learning rate.
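As a rough illustration of the idea (this is not existing torchmd-net code; run_epoch, the spike threshold, and the decay factor below are hypothetical placeholders), a training loop with this kind of rollback might look like:

import copy

def train_with_rollback(model, optimizer, run_epoch, num_epochs,
                        spike_factor=2.0, lr_decay=0.8):
    """Illustrative sketch only: roll back an epoch whose training loss
    spikes by more than spike_factor, and lower the learning rate right
    away. run_epoch is a hypothetical callable that trains one epoch and
    returns the training loss."""
    prev_loss = float("inf")
    snapshot = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        train_loss = run_epoch(model, optimizer)

        if train_loss > spike_factor * prev_loss:
            # Undo the epoch: restore the parameters from the end of the
            # previous epoch and immediately reduce the learning rate.
            model.load_state_dict(snapshot)
            for group in optimizer.param_groups:
                group["lr"] *= lr_decay
            continue  # the next epoch starts from the restored weights

        # Epoch was stable: update the baseline loss and the snapshot.
        prev_loss = train_loss
        snapshot = copy.deepcopy(model.state_dict())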

peastman · Jul 07 '21, 23:07

It recovers automatically in just a few epochs. I don't see it as a problem.

giadefa · Jul 08 '21, 08:07

The goal is to make training as fast as possible, and to get the best final result. This helps with those goals in three ways. First, you don't waste a bunch of epochs recovering from an error that you could have spotted and undone immediately. Second, it lets you push the learning rate up to larger values that lead to even faster training but carry a larger risk of instability. And third, it lets the scheduler realize much more quickly when the learning rate needs to be decreased.

Also note that I was using the change from #27, which allowed me to set lr_patience very low. Without that change, I would have had to set it much higher, and 10 epochs wouldn't even have been enough for the scheduler to realize the learning rate needed to be decreased.
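For context, a rough sketch of that kind of low-patience configuration, assuming the standard PyTorch ReduceLROnPlateau scheduler and an existing optimizer and per-epoch val_loss (the 0.8 factor matches the LR steps visible in the log above; the patience value is illustrative):

from torch.optim.lr_scheduler import ReduceLROnPlateau

# Sketch only: a very small patience drops the learning rate after a
# single epoch with no validation improvement, instead of waiting many
# epochs before reacting.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=1)

# after each validation epoch:
scheduler.step(val_loss)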

peastman · Jul 08 '21, 15:07

@peastman maybe this helps with your idea: https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html#set-validation-check-frequency-within-1-training-epoch
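For reference, that Lightning setting makes the Trainer run validation several times within one training epoch, so a loss spike is caught sooner. A minimal sketch with an illustrative value:

import pytorch_lightning as pl

# Run validation 4 times per training epoch instead of once, so an
# unstable epoch is detected (and the LR reduced) much sooner.
trainer = pl.Trainer(val_check_interval=0.25)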

PhilippThoelke · Sep 07 '21, 15:09