addons
SequenceLoss shouldn't override the __call__ of tf.keras.losses.Loss
Everything is in the title. __call__ is considered a private API of Keras. We already had to change our code for TF 2.2, which shows that we shouldn't rely on the behavior of __call__.
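To illustrate the distinction, here is a toy pure-Python model of the pattern that tf.keras.losses.Loss follows. It is not the real Keras code, and the class names are hypothetical. The public `__call__` handles weighting and reduction, while subclasses are only expected to implement `call(y_true, y_pred)` and return per-element losses.

```python
class ToyLoss:
    """Simplified stand-in for tf.keras.losses.Loss (not the real implementation).

    The public entry point __call__ owns sample weighting and reduction;
    subclasses are only expected to implement call().
    """

    def __call__(self, y_true, y_pred, sample_weight=None):
        losses = self.call(y_true, y_pred)
        if sample_weight is None:
            sample_weight = [1.0] * len(losses)
        weighted = [l * w for l, w in zip(losses, sample_weight)]
        # Mimics a "sum over batch size" reduction: divide by the element count.
        return sum(weighted) / len(weighted)

    def call(self, y_true, y_pred):
        raise NotImplementedError


class ToyMeanSquaredError(ToyLoss):
    # The supported extension point: compute per-element losses only.
    # No sample_weight and no reduction logic live here.
    def call(self, y_true, y_pred):
        return [(t - p) ** 2 for t, p in zip(y_true, y_pred)]


loss = ToyMeanSquaredError()
print(loss([0.0, 1.0], [1.0, 1.0]))                            # 0.5
print(loss([0.0, 1.0], [1.0, 1.0], sample_weight=[0.0, 1.0]))  # 0.0
```

A subclass that overrides `__call__` instead of `call` takes over logic the base class owns, which is why it breaks when that logic changes between releases.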
Follow-up on https://github.com/tensorflow/addons/pull/1371#issuecomment-604883254
@guillaumekln @pavithrasv @qlzh727 are the people who might be concerned by this issue.
@gabrieldemarmiesse, __call__ passes some instance attributes to the sequence_loss() function. What is the right way to remove __call__? Could we add a build() method that stores these attributes in the SequenceLoss class, remove __call__, and call sequence_loss from the call() method?
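A minimal sketch of that proposal, in pure Python and under stated assumptions: `toy_sequence_loss` is a hypothetical stand-in for sequence_loss that uses squared error instead of cross-entropy, and the options are stored in `__init__` here rather than in build(). The sketch also shows the catch: `call()` only receives `y_true` and `y_pred`, so the per-timestep mask that would normally come in as `sample_weight` is not available.

```python
def toy_sequence_loss(y_true, y_pred, weights, average_across_timesteps=True):
    """Hypothetical stand-in for sequence_loss on [batch][time] inputs.

    Uses squared error per timestep for brevity and returns one value per sequence.
    """
    per_sequence = []
    for t_row, p_row, w_row in zip(y_true, y_pred, weights):
        total = sum(w * (t - p) ** 2 for t, p, w in zip(t_row, p_row, w_row))
        if average_across_timesteps:
            total_weight = sum(w_row)
            # Divide by the number of unmasked timesteps (guard against zero).
            total = total / total_weight if total_weight else 0.0
        per_sequence.append(total)
    return per_sequence


class ToySequenceLoss:
    """Options stored on the instance, loss computed in call() (hypothetical)."""

    def __init__(self, average_across_timesteps=True):
        self.average_across_timesteps = average_across_timesteps

    def call(self, y_true, y_pred):
        # call() has no sample_weight argument, so the padding mask is lost:
        # the only option is all-ones weights, which counts padded timesteps.
        ones = [[1.0] * len(row) for row in y_true]
        return toy_sequence_loss(y_true, y_pred, ones, self.average_across_timesteps)


y_true = [[1.0, 1.0, 0.0]]
y_pred = [[0.0, 1.0, 5.0]]
mask = [[1.0, 1.0, 0.0]]  # last timestep is padding
print(toy_sequence_loss(y_true, y_pred, mask))  # [0.5]
print(ToySequenceLoss().call(y_true, y_pred))   # [26/3], padding included
```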
We override the __call__ method to access the sample_weight argument and support various reduction strategies that are not covered by the reduction argument of tf.keras.losses.Loss. I'm not aware of any workaround at the moment that would keep the same features, so I'm marking this as blocked.
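One concrete example of a reduction the standard options don't cover, sketched in pure Python with hypothetical helper names: a "sum over batch size" style reduction divides the weighted sum by the number of elements, whereas a masked sequence loss (roughly what sequence_loss does when averaging across both timesteps and batch, as we understand it) divides by the sum of the weights, so padded timesteps don't dilute the average.

```python
def _weighted_flat(losses, weights):
    # Multiply per-timestep losses by their weights and flatten [batch][time].
    return [l * w for lr, wr in zip(losses, weights) for l, w in zip(lr, wr)]


def sum_over_batch_size(losses, weights):
    """Keras-style reduction: weighted sum divided by the element count."""
    flat = _weighted_flat(losses, weights)
    return sum(flat) / len(flat)


def mean_over_weights(losses, weights):
    """Sequence-style reduction: weighted sum divided by the sum of the weights."""
    flat = _weighted_flat(losses, weights)
    total_weight = sum(w for row in weights for w in row)
    return sum(flat) / total_weight if total_weight else 0.0


losses = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
mask = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]  # 0.0 marks padding
print(sum_over_batch_size(losses, mask))  # 7/6, about 1.167
print(mean_over_weights(losses, mask))    # 7/3, about 2.333
```

The second reduction needs the weights themselves at reduction time, which only `__call__` sees.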