tfboys
Incorrect use of tf.Metrics?
Hello, I noticed that your training loop is written as follows:
```python
train_loss = tf.metrics.Mean(name='train_loss')
train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')
for epoch in range(start_epoch, 120):
    try:
        for batch, data in enumerate(train_dataset):
            # images, labels = data['image'], data['label']
            images, labels = data
            with tf.GradientTape() as tape:
                predictions = model(images)
                print('pred: ', predictions)
                loss = loss_object(labels, predictions)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            train_loss(loss)
            # we should gather all accuracy and using average
            train_accuracy(labels, predictions)
            if batch % 50 == 0:
                logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format(
                    epoch, batch, train_loss.result(), train_accuracy.result()))
    except KeyboardInterrupt:
        logging.info('interrupted.')
        model.save_weights(ckpt_path.format(epoch=epoch))
        logging.info('model saved into: {}'.format(ckpt_path.format(epoch=epoch)))
        exit(0)
```
`train_loss` is a `tf.metrics.Mean` object: the loss of every batch is added to it, and its value is logged every 50 batches. The problem is `train_loss.result()`. `result()` computes the mean over everything accumulated in the metric's state variables since they were last reset, but the states are never reset at the end of an epoch. As a result, the logged value is the average loss over all batches of all past epochs, not just the current one, and the same applies to `train_accuracy`. I got quite strange training and validation loss curves this way, and I suspect this is the reason.
Reference: `tf.keras.metrics.Mean`, `reset_states()`: "This function is called between epochs/steps, when a metric is evaluated during training."
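To make the accumulation concrete, here is a minimal sketch that uses `tf.keras.metrics.Mean` directly with made-up per-batch losses (illustrative values, not from a real run). It assumes TF 2.5 or later, where the method is named `reset_state()`; older 2.x releases call it `reset_states()`, as in the reference above.

```python
import tensorflow as tf

# Made-up per-batch losses for two short epochs (illustrative only).
epoch_losses = [[4.0, 2.0], [1.0, 1.0]]

def log_epoch_means(reset_each_epoch):
    """Return the value train_loss.result() would report at the end of each epoch."""
    metric = tf.keras.metrics.Mean(name="train_loss")
    means = []
    for losses in epoch_losses:
        if reset_each_epoch:
            metric.reset_state()  # drop the total/count carried over from earlier epochs
        for loss in losses:
            metric.update_state(loss)  # equivalent to calling metric(loss)
        means.append(float(metric.result()))
    return means

print(log_epoch_means(reset_each_epoch=False))  # [3.0, 2.0]: epoch 2 still averages in epoch 1
print(log_epoch_means(reset_each_epoch=True))   # [3.0, 1.0]: epoch 2 reflects only its own batches
```

Without a reset, the second epoch reports (4 + 2 + 1 + 1) / 4 = 2.0 instead of its own mean of 1.0, and the gap shrinks only slowly as more epochs pile up, which smooths and lags the logged curve.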
@SuperCrystal Hi, thanks for your investigation and walkthrough!
I think you're right; the loss does come out weird this way.
Would you like to send me a PR to fix this issue? I believe it can be solved by manually resetting the metrics' state.
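A minimal sketch of that fix, assuming a toy model and random data as hypothetical stand-ins for the repository's `model`, `train_dataset`, `loss_object` and `optimizer`: both metrics are reset at the start of each epoch, so `result()` only covers the current epoch. Checkpointing and the `KeyboardInterrupt` handler are omitted for brevity, and `reset_state()` assumes TF 2.5 or later (older releases use `reset_states()`).

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data and model standing in for the repo's real ones.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 8)).astype("float32")
y = rng.integers(0, 3, size=(64,))
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(3)])
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

epoch_summaries = []
for epoch in range(3):
    # Clear accumulated totals so result() reflects this epoch only.
    train_loss.reset_state()
    train_accuracy.reset_state()
    batch_losses = []  # kept only to cross-check the metric against a manual mean
    for batch, (images, labels) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = loss_object(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss.update_state(loss)
        train_accuracy.update_state(labels, predictions)
        batch_losses.append(float(loss))
        if batch % 50 == 0:
            print(f"Epoch: {epoch}, iter: {batch}, "
                  f"loss: {float(train_loss.result()):.4f}, "
                  f"train_acc: {float(train_accuracy.result()):.4f}")
    # Per-epoch values, now unaffected by earlier epochs.
    epoch_summaries.append(
        (epoch, float(train_loss.result()), float(train_accuracy.result())))
```

Resetting at the start of the epoch rather than the end keeps the final epoch's values readable after the loop finishes; either placement fixes the logging.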
Yes of course, I will send it later, thanks :)