
Bug in training loops batch count

Open seanmor5 opened this issue 9 months ago • 1 comment

Step counts are a bit off. The last logged batch number should be consistent across epochs, e.g.:

Epoch: 0, Batch: 350, accuracy: 0.5190082 loss: 0.6815987
Epoch: 1, Batch: 341, accuracy: 0.7035355 loss: 0.6067376
Epoch: 2, Batch: 332, accuracy: 0.7509853 loss: 0.5645106
Epoch: 3, Batch: 323, accuracy: 0.7610434 loss: 0.5400541
Epoch: 4, Batch: 314, accuracy: 0.7647320 loss: 0.5240928
Epoch: 5, Batch: 355, accuracy: 0.7672051 loss: 0.5114447
Epoch: 6, Batch: 346, accuracy: 0.7690023 loss: 0.5031880
Epoch: 7, Batch: 237, accuracy: 0.7714022 loss: 0.4982935

It should be 350 for all epochs.
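For reference, output of this shape comes from a loop roughly like the sketch below; `model`, `data`, the loss, and the optimizer are placeholders, and `data` is assumed to be an enumerable of `{input, target}` batches:

```elixir
# Minimal repro sketch (placeholder model/data). With the trainer's default
# log interval applied to a *global* iteration count, the last line printed
# in each epoch drifts instead of always landing on the same batch index.
model
|> Axon.Loop.trainer(:binary_cross_entropy, :adam)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(data, %{}, epochs: 8)
```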

seanmor5 · Sep 20 '23

@seanmor5 I think this is happening because event_counts[:iteration_completed] keeps increasing with every iteration rather than resetting when training moves to the next epoch.

For example, training on MNIST with batch_size = 15000 (60000 / 15000 = 4 iterations per epoch) and the log: 5 option passed to the trainer gives:

Epoch: 0, Batch: 0, Accuracy: 0.0750000 loss: 0.0000000
Epoch: 1, Batch: 1, Accuracy: 0.5341333 loss: 2.0990303
Epoch: 2, Batch: 2, Accuracy: 0.6975778 loss: 1.8120724
Epoch: 3, Batch: 3, Accuracy: 0.7700500 loss: 1.5773102

Epoch: 5, Batch: 0, Accuracy: 0.8228667 loss: 1.3881876
Epoch: 6, Batch: 1, Accuracy: 0.8449667 loss: 1.2417462
Epoch: 7, Batch: 2, Accuracy: 0.8608667 loss: 1.1269298
Epoch: 8, Batch: 3, Accuracy: 0.8743500 loss: 1.0358800

Epoch: 10, Batch: 0, Accuracy: 0.8869333 loss: 0.9610912
Epoch: 11, Batch: 1, Accuracy: 0.8943000 loss: 0.8996136
Epoch: 12, Batch: 2, Accuracy: 0.8975111 loss: 0.8477340
Epoch: 13, Batch: 3, Accuracy: 0.9028834 loss: 0.8036935

Epoch: 15, Batch: 0, Accuracy: 0.9069334 loss: 0.7651854

Note the empty lines: epochs 4, 9, and 14 are never logged at all.
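The pattern is just the arithmetic of a global counter: log: 5 fires on global iterations 0, 5, 10, 15, 20, ..., and each of those maps to a different (epoch, batch) pair. A quick plain-Elixir sketch (not Axon's actual source) reproduces the output above:

```elixir
iterations_per_epoch = 4
log_every = 5

# Fire the "logger" whenever the global iteration count is a multiple of 5,
# then recover the (epoch, batch) pair that iteration corresponds to.
for global_iter <- 0..20, rem(global_iter, log_every) == 0 do
  epoch = div(global_iter, iterations_per_epoch)
  batch = rem(global_iter, iterations_per_epoch)
  IO.puts("Epoch: #{epoch}, Batch: #{batch}")
end
# Prints (0,0), (1,1), (2,2), (3,3), (5,0): epoch 4 is skipped entirely,
# which is where the empty lines come from.
```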

I am not sure how to fix this, though. Editing the filter function could be one way (see the sketch after the output below). I would honestly prefer logging to happen on every iteration, i.e. for log: 1 to be the only possibility. It should not be computationally expensive (given that all the stats are already accumulated), and we would see the total number of batches in every epoch:

Epoch: 0, Batch: 3, Accuracy: 0.2306833 loss: 2.2516294
Epoch: 1, Batch: 3, Accuracy: 0.5954500 loss: 2.0017467
Epoch: 2, Batch: 3, Accuracy: 0.7270499 loss: 1.7843739
Epoch: 3, Batch: 3, Accuracy: 0.7800333 loss: 1.5970893
Epoch: 4, Batch: 3, Accuracy: 0.8154500 loss: 1.4406831
Epoch: 5, Batch: 3, Accuracy: 0.8401667 loss: 1.3119802
Epoch: 6, Batch: 3, Accuracy: 0.8556499 loss: 1.2060242
Epoch: 7, Batch: 3, Accuracy: 0.8690001 loss: 1.1182691
Epoch: 8, Batch: 3, Accuracy: 0.8780000 loss: 1.0448862
Epoch: 9, Batch: 3, Accuracy: 0.8851334 loss: 0.9828119
Epoch: 10, Batch: 3, Accuracy: 0.8905167 loss: 0.9296926
Epoch: 11, Batch: 3, Accuracy: 0.8949667 loss: 0.8837242
Epoch: 12, Batch: 3, Accuracy: 0.8988333 loss: 0.8435498
Epoch: 13, Batch: 3, Accuracy: 0.9025500 loss: 0.8081132
Epoch: 14, Batch: 3, Accuracy: 0.9056333 loss: 0.7765985
Epoch: 15, Batch: 3, Accuracy: 0.9087667 loss: 0.7483661
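For the filter approach, something like the following might work. It is only a sketch: it assumes Axon.Loop.handle_event/4 (Axon.Loop.handle/4 in older versions) accepts an arity-1 filter function, that state.iteration resets at the start of each epoch, and that log: 0 silences the built-in logger; model and data are placeholders.

```elixir
log_every = 5

model
# log: 0 should disable the trainer's built-in logger (assumption)
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam, log: 0)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.handle_event(
  :iteration_completed,
  fn %Axon.Loop.State{epoch: epoch, iteration: iter, metrics: metrics} = state ->
    IO.puts("Epoch: #{epoch}, Batch: #{iter}, metrics: #{inspect(metrics)}")
    {:continue, state}
  end,
  # Filter on the per-epoch iteration count instead of the global
  # event count, so the logged batch index no longer drifts.
  fn %Axon.Loop.State{iteration: iter} -> rem(iter, log_every) == 0 end
)
|> Axon.Loop.run(data, %{}, epochs: 16)
```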

The drawback would be for anyone who wants to write the logs to a file, since logging every iteration produces far more output.

By the way, can we maybe start counting the epochs and batches from 1? :)

krstopro · Oct 14 '23