
Incorrect use of tf.Metrics?

Open SuperCrystal opened this issue 6 years ago • 2 comments


Hello, I've noticed that in your training loop it is written as:

train_loss = tf.metrics.Mean(name='train_loss')
train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')
for epoch in range(start_epoch, 120):
    try:
        for batch, data in enumerate(train_dataset):
            # images, labels = data['image'], data['label']
            images, labels = data
            with tf.GradientTape() as tape:
                predictions = model(images)
                print('pred: ', predictions)
                loss = loss_object(labels, predictions)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            train_loss(loss)
            # we should gather all accuracy and using average
            train_accuracy(labels, predictions)
            if batch % 50 == 0:
                logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format(
                    epoch, batch, train_loss.result(), train_accuracy.result()))
    except KeyboardInterrupt:
        logging.info('interrupted.')
        model.save_weights(ckpt_path.format(epoch=epoch))
        logging.info('model saved into: {}'.format(ckpt_path.format(epoch=epoch)))
        exit(0)

Here 'train_loss' is a tf.keras.metrics.Mean object: the calculated loss is accumulated into it on every batch, and the running value is logged every 50 batches. The problem is in 'train_loss.result()': the 'result()' method computes the mean over all accumulated state variables, but since the states are never reset at the end of an epoch, it reports the average loss over all batches of all past epochs. I get quite a strange training and validation loss curve this way, and I guess that is the reason.

Reference: the docs for tf.keras.metrics.Mean say of reset_states: "This function is called between epochs/steps, when a metric is evaluated during training."
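To make the accumulation behavior concrete, here is a minimal pure-Python stand-in for a streaming mean metric (a sketch of the same semantics, not TensorFlow itself): without a reset between epochs, result() mixes the current epoch's batches with all previous ones.

```python
class StreamingMean:
    """Minimal stand-in for tf.keras.metrics.Mean: accumulates a
    running total and count until reset_states() is called."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def __call__(self, value):
        # update_state: fold one new value into the running mean
        self.total += value
        self.count += 1

    def result(self):
        # mean over everything accumulated since the last reset
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        self.total = 0.0
        self.count = 0


train_loss = StreamingMean()

# Epoch 1: batch losses average 4.0
for loss in (5.0, 4.0, 3.0):
    train_loss(loss)
print(train_loss.result())  # 4.0

# Epoch 2 without a reset: the epoch-2 average should be 1.0,
# but the metric still includes epoch-1 batches in its state.
for loss in (1.5, 1.0, 0.5):
    train_loss(loss)
print(train_loss.result())  # 2.5, not 1.0

# With a reset between epochs, the per-epoch average is correct.
train_loss.reset_states()
for loss in (1.5, 1.0, 0.5):
    train_loss(loss)
print(train_loss.result())  # 1.0
```

The same pattern applies to the snippet above: calling train_loss.reset_states() and train_accuracy.reset_states() at the end of each epoch makes result() cover only the current epoch.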

SuperCrystal avatar Oct 23 '19 09:10 SuperCrystal

@SuperCrystal Hi, thanks for your investigation and walkthrough!

I think you're right; it actually got a weird loss this way.

Would you like to send me a PR to fix this issue? I believe it can be solved by manually resetting its state.

lucasjinreal avatar Oct 23 '19 11:10 lucasjinreal


Yes, of course, I will send it later, thanks :)

SuperCrystal avatar Oct 24 '19 01:10 SuperCrystal