
Loss calculated incorrectly in networks_seq2seq_nmt.ipynb

Open • martingoodson opened this issue 3 years ago • 4 comments

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): N/A
  • TensorFlow version and how it was installed (source or binary): N/A
  • TensorFlow-Addons version and how it was installed (source or binary): N/A
  • Python version: N/A
  • Is GPU used? (yes/no): N/A

Describe the bug

This bug is in https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb

The loss is not reduced correctly. The mean should be taken only over the non-masked (non-padding) positions, but tf.reduce_mean divides the masked sum by the total number of positions, padding included. This line should be replaced:

loss = tf.reduce_mean(loss)

with this:

loss = tf.math.reduce_sum(loss) / tf.math.reduce_sum(mask)

With this change, the loss gives the same result as keras.metrics.SparseCategoricalCrossentropy(from_logits=True) when the padding mask is applied as the sample weight, as expected.
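A small worked example makes the bias concrete. The per-token losses below are made up for illustration (0 in real marks padding, as in the notebook), and the sketch uses plain Python rather than TensorFlow so the arithmetic is easy to follow.

```python
# Made-up per-token losses for a batch of 2 sequences padded to length 4.
# Token id 0 in `real` marks padding, as in the notebook.
real = [[5, 7, 2, 0],
        [3, 0, 0, 0]]
token_loss = [[1.0, 2.0, 3.0, 9.0],
              [4.0, 9.0, 9.0, 9.0]]

# 0.0 for padding, 1.0 for real tokens (same convention as the notebook's mask).
mask = [[0.0 if t == 0 else 1.0 for t in row] for row in real]
masked = [[m * l for m, l in zip(mr, lr)] for mr, lr in zip(mask, token_loss)]

total = sum(sum(row) for row in masked)        # 1 + 2 + 3 + 4 = 10
n_positions = sum(len(row) for row in masked)  # 8, padding included
n_real = sum(sum(row) for row in mask)         # 4, padding excluded

original = total / n_positions  # what tf.reduce_mean(loss) computes: 10 / 8
fixed = total / n_real          # reduce_sum(loss) / reduce_sum(mask): 10 / 4
print(original, fixed)          # 1.25 2.5
```

With the original reduction, padding the same batch to a longer length would lower the reported loss further even though no prediction changed; the masked mean does not depend on the amount of padding.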

import tensorflow as tf

def loss_function(real, pred):
  # real shape = (BATCH_SIZE, max_length_output)
  # pred shape = (BATCH_SIZE, max_length_output, tar_vocab_size)
  cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
  loss = cross_entropy(y_true=real, y_pred=pred)
  mask = tf.logical_not(tf.math.equal(real, 0))  # False for padding (y=0), True otherwise
  mask = tf.cast(mask, dtype=loss.dtype)  # 0.0 for padding, 1.0 otherwise
  loss = mask * loss
  loss = tf.reduce_mean(loss)  # bug: divides by every position, padding included
  return loss
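For reference, here is a sketch of loss_function with the proposed reduction applied, followed by a check against keras.metrics.SparseCategoricalCrossentropy with the padding mask passed as sample_weight. The toy tensors are made up, and the check assumes TensorFlow 2.x with eager execution.

```python
import tensorflow as tf


def loss_function(real, pred):
  # real shape = (BATCH_SIZE, max_length_output); token id 0 is padding
  # pred shape = (BATCH_SIZE, max_length_output, tar_vocab_size), raw logits
  cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction='none')
  loss = cross_entropy(y_true=real, y_pred=pred)  # per-token loss
  # 1.0 for real tokens, 0.0 for padding
  mask = tf.cast(tf.math.not_equal(real, 0), dtype=loss.dtype)
  loss = mask * loss
  # Mean over non-padding tokens only, as proposed in this issue.
  return tf.math.reduce_sum(loss) / tf.math.reduce_sum(mask)


# Toy check: 2 sequences padded to length 4, vocabulary of 10 (made-up values).
real = tf.constant([[5, 7, 2, 0],
                    [3, 0, 0, 0]])
pred = tf.random.stateless_normal((2, 4, 10), seed=(1, 2))

ours = loss_function(real, pred)

# Keras metric with the padding mask passed as sample_weight.
metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
metric.update_state(
    real, pred,
    sample_weight=tf.cast(tf.math.not_equal(real, 0), tf.float32))
keras_value = metric.result()

print(float(ours), float(keras_value))  # the two values should agree
```

If a batch could ever consist entirely of padding, tf.math.divide_no_nan avoids a NaN from dividing by zero; the fix proposed here uses plain division.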

Code to reproduce the issue

Provide a reproducible test case that is the bare minimum necessary to generate the problem.

Other info / logs

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

martingoodson • Dec 23 '21

Yes, the loss is not correctly reduced. Could you send a PR with your change?

guillaumekln • Dec 28 '21

@guillaumekln I would like to contribute to this if no one is working on it.

MrinalTyagi • Jan 28 '22

I will submit this pull request shortly. I've been on holiday.

-- Martin Goodson @martingoodson


martingoodson • Jan 28 '22


Sorry, I thought it was available for contribution.

MrinalTyagi • Jan 28 '22