Bug in training loop's batch count
The step counts are a bit off. They should be consistent across epochs, e.g.:
Epoch: 0, Batch: 350, accuracy: 0.5190082 loss: 0.6815987
Epoch: 1, Batch: 341, accuracy: 0.7035355 loss: 0.6067376
Epoch: 2, Batch: 332, accuracy: 0.7509853 loss: 0.5645106
Epoch: 3, Batch: 323, accuracy: 0.7610434 loss: 0.5400541
Epoch: 4, Batch: 314, accuracy: 0.7647320 loss: 0.5240928
Epoch: 5, Batch: 355, accuracy: 0.7672051 loss: 0.5114447
Epoch: 6, Batch: 346, accuracy: 0.7690023 loss: 0.5031880
Epoch: 7, Batch: 237, accuracy: 0.7714022 loss: 0.4982935
It should be 350 for all epochs.
@seanmor5 I think this is happening because event_counts[:iteration_completed]
keeps increasing with every iteration even as training moves into the next epoch.
For example, training on MNIST with batch_size = 15000
(60000 / 15000 = 4 iterations per epoch) and the log: 5
option in the trainer gives the output below (a reproduction sketch follows it):
Epoch: 0, Batch: 0, Accuracy: 0.0750000 loss: 0.0000000
Epoch: 1, Batch: 1, Accuracy: 0.5341333 loss: 2.0990303
Epoch: 2, Batch: 2, Accuracy: 0.6975778 loss: 1.8120724
Epoch: 3, Batch: 3, Accuracy: 0.7700500 loss: 1.5773102
Epoch: 5, Batch: 0, Accuracy: 0.8228667 loss: 1.3881876
Epoch: 6, Batch: 1, Accuracy: 0.8449667 loss: 1.2417462
Epoch: 7, Batch: 2, Accuracy: 0.8608667 loss: 1.1269298
Epoch: 8, Batch: 3, Accuracy: 0.8743500 loss: 1.0358800
Epoch: 10, Batch: 0, Accuracy: 0.8869333 loss: 0.9610912
Epoch: 11, Batch: 1, Accuracy: 0.8943000 loss: 0.8996136
Epoch: 12, Batch: 2, Accuracy: 0.8975111 loss: 0.8477340
Epoch: 13, Batch: 3, Accuracy: 0.9028834 loss: 0.8036935
Epoch: 15, Batch: 0, Accuracy: 0.9069334 loss: 0.7651854
Note the missing lines for epochs 4, 9, and 14.
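For context, here is a minimal sketch of the kind of run described above. It is not the exact script behind the log; the model layers, dummy data, loss, and optimizer are illustrative assumptions. Only the shapes (60000 examples, batch_size 15000, so 4 batches per epoch) and the log: 5 trainer option come from the issue.

```elixir
# Sketch only: MNIST-shaped dummy data, 4 batches per epoch, log every 5 iterations.
model =
  Axon.input("images", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

# 60_000 examples / 15_000 per batch = 4 iterations per epoch
batch = {Nx.broadcast(0.0, {15_000, 784}), Nx.broadcast(0.0, {15_000, 10})}
data = List.duplicate(batch, 4)

model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam, log: 5)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(data, %{}, epochs: 16)
```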
I am not sure how to fix this, though. Editing the filter function could be one way; a rough sketch of the idea follows.
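Assuming the logging code can see the ever-increasing iteration_completed count and knows how many batches an epoch has (both names below are illustrative, not Axon's API), the within-epoch batch index could be derived with rem/2:

```elixir
# Hypothetical helper, not part of Axon: map the absolute
# :iteration_completed count back to a per-epoch batch index,
# given a fixed number of batches per epoch.
defmodule BatchIndex do
  def within_epoch(absolute_iteration, batches_per_epoch)
      when batches_per_epoch > 0 do
    rem(absolute_iteration, batches_per_epoch)
  end
end

# With 4 batches per epoch (60_000 / 15_000), the 10th completed
# iteration (absolute count 9) corresponds to epoch 2, batch 1:
BatchIndex.within_epoch(9, 4)
#=> 1
```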
I would honestly prefer logging to happen on every iteration, i.e. for log: 1
to be the only possibility. It should not be computationally expensive (given that all the stats are saved anyway), and we would see the total number of batches in every epoch:
Epoch: 0, Batch: 3, Accuracy: 0.2306833 loss: 2.2516294
Epoch: 1, Batch: 3, Accuracy: 0.5954500 loss: 2.0017467
Epoch: 2, Batch: 3, Accuracy: 0.7270499 loss: 1.7843739
Epoch: 3, Batch: 3, Accuracy: 0.7800333 loss: 1.5970893
Epoch: 4, Batch: 3, Accuracy: 0.8154500 loss: 1.4406831
Epoch: 5, Batch: 3, Accuracy: 0.8401667 loss: 1.3119802
Epoch: 6, Batch: 3, Accuracy: 0.8556499 loss: 1.2060242
Epoch: 7, Batch: 3, Accuracy: 0.8690001 loss: 1.1182691
Epoch: 8, Batch: 3, Accuracy: 0.8780000 loss: 1.0448862
Epoch: 9, Batch: 3, Accuracy: 0.8851334 loss: 0.9828119
Epoch: 10, Batch: 3, Accuracy: 0.8905167 loss: 0.9296926
Epoch: 11, Batch: 3, Accuracy: 0.8949667 loss: 0.8837242
Epoch: 12, Batch: 3, Accuracy: 0.8988333 loss: 0.8435498
Epoch: 13, Batch: 3, Accuracy: 0.9025500 loss: 0.8081132
Epoch: 14, Batch: 3, Accuracy: 0.9056333 loss: 0.7765985
Epoch: 15, Batch: 3, Accuracy: 0.9087667 loss: 0.7483661
The drawback would be for someone who wants to write the logs to a file.
By the way, can we maybe start counting the epochs and batches from 1? :)