
sample_weight parameter causes a rank error when TripletSemiHardLoss is used in a multi-output model.

Open ubless607 opened this issue 2 years ago • 0 comments

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04.3 LTS
  • TensorFlow version and how it was installed (source or binary): 2.9.1 (pypi)
  • TensorFlow-Addons version and how it was installed (source or binary): 0.17.1 (pypi)
  • Python version: 3.9.13
  • Is GPU used? (yes/no): yes

Describe the bug


When TripletSemiHardLoss is the only loss of a single-output model, passing sample_weight raises no error. However, when TripletSemiHardLoss is used in a multi-output model, passing sample_weight raises a rank error.

Code to reproduce the issue


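import tensorflow as tf
import tensorflow_addons as tfa
# `model`: two-output Keras model (its definition is not included in this report)
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss={'embedding_output': tfa.losses.TripletSemiHardLoss(),
          'clf_output': 'sparse_categorical_crossentropy'})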

Since class_weight is not supported for multi-output models in TF 2.1+, I applied class weighting by mapping _make_class_weight_map_fn from keras.engine over the dataset.
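A minimal sketch of that class-weighting step for a multi-output dataset, assuming batches of the form `(x, y)` where `y` is a dict keyed by the output names used in `compile()`, and that class weights come from the integer labels of `'clf_output'`. Like `_make_class_weight_map_fn` for a single output, it gathers one weight per label. The helper name `make_clf_class_weight_fn` and the weights in `CLASS_WEIGHT` are hypothetical; this is not the reporter's exact code.

```python
import tensorflow as tf

# Hypothetical class weights for the classifier head: class id -> weight.
CLASS_WEIGHT = {0: 1.0, 1: 3.0}


def make_clf_class_weight_fn(class_weight, output_name):
    """Build a Dataset.map function that appends per-example class weights.

    Hypothetical helper (not a Keras API). Assumes each batch is (x, y) where y
    is a dict keyed by output name and y[output_name] holds integer class ids.
    """
    # Dense lookup table: position i holds the weight for class i (default 1.0).
    num_classes = max(class_weight) + 1
    table = tf.constant(
        [class_weight.get(i, 1.0) for i in range(num_classes)], dtype=tf.float32)

    def map_fn(x, y):
        # Flatten labels so both [batch] and [batch, 1] label shapes work.
        labels = tf.cast(tf.reshape(y[output_name], [-1]), tf.int32)
        sample_weight = tf.gather(table, labels)  # one weight per example
        return x, y, sample_weight

    return map_fn


# Toy batch: 4 examples, 2 features, labels for both heads.
x = tf.zeros([4, 2])
y = {'embedding_output': tf.constant([0, 1, 1, 0]),
     'clf_output': tf.constant([0, 1, 1, 0])}
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)
ds = ds.map(make_clf_class_weight_fn(CLASS_WEIGHT, 'clf_output'))

_, _, w = next(iter(ds))
print(w.numpy())  # [1. 3. 3. 1.]
```

Because this map function returns a single sample-weight tensor rather than a dict keyed by output name, Keras would apply it to every output, including the TripletSemiHardLoss head.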

model.fit() then raises the error below. TripletSemiHardLoss with sample_weight works fine in a single-output model, and other built-in TensorFlow loss functions with sample_weight also work in a multi-output model. Only when TripletSemiHardLoss is combined with another loss function and sample_weight is passed does the following error occur:

ValueError: Shapes must be equal rank, but are 2 and 0
    From merging shape 0 with other shapes. for '{{node AddN}} = AddN[N=2, T=DT_FLOAT](TripletSemiHardLoss/weighted_loss/Mul, sparse_categorical_crossentropy/weighted_loss/value)' with input shapes: [?,1], [].
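The node names in this log suggest why only the multi-output case fails: the per-output losses are summed by an `AddN` op, the TripletSemiHardLoss input is the unreduced `weighted_loss/Mul` result with shape [?,1], and the cross-entropy input has already been reduced to a scalar (`weighted_loss/value`). `AddN` requires all inputs to have the same shape and does not broadcast; with a single output there would be nothing to merge. Below is a simplified pure-Python sketch of that shape check next to the broadcasting rule an elementwise `+` would use. It is not TensorFlow's implementation, and the helper names are hypothetical; `None` stands for an unknown dimension (`?`).

```python
from typing import Optional, Sequence, Tuple

Dim = Optional[int]  # None plays the role of an unknown dimension ("?")
Shape = Tuple[Dim, ...]


def merge_shapes(a: Sequence[Dim], b: Sequence[Dim]) -> Shape:
    """Merge two shapes: ranks must match and known dimensions must agree."""
    if len(a) != len(b):
        raise ValueError(
            f"Shapes must be equal rank, but are {len(a)} and {len(b)}")
    merged = []
    for da, db in zip(a, b):
        if da is not None and db is not None and da != db:
            raise ValueError(f"Dimensions {da} and {db} are not compatible")
        merged.append(da if da is not None else db)
    return tuple(merged)


def add_n_shape(shapes: Sequence[Sequence[Dim]]) -> Shape:
    """Output shape of an AddN-style sum: every input is merged, no broadcasting."""
    result: Shape = tuple(shapes[0])
    for s in shapes[1:]:
        result = merge_shapes(result, s)
    return result


def broadcast_shape(a: Sequence[Dim], b: Sequence[Dim]) -> Shape:
    """Output shape of an elementwise `+`, which broadcasts (NumPy-style rules)."""
    # Right-align the shapes by padding the shorter one with 1s on the left.
    n = max(len(a), len(b))
    pa = (1,) * (n - len(a)) + tuple(a)
    pb = (1,) * (n - len(b)) + tuple(b)
    out = []
    for da, db in zip(pa, pb):
        if da == 1:
            out.append(db)
        elif db == 1 or da == db:
            out.append(da)
        elif da is None or db is None:
            out.append(None)  # simplified: left unknown until run time
        else:
            raise ValueError(f"Cannot broadcast {da} with {db}")
    return tuple(out)


print(add_n_shape([(None, 1), (None, 1)]))  # (None, 1)
print(broadcast_shape((None, 1), ()))  # (None, 1)
try:
    # The two loss terms from the log: [?,1] and a scalar.
    add_n_shape([(None, 1), ()])
except ValueError as err:
    print(err)  # Shapes must be equal rank, but are 2 and 0
```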

Other info / logs


ubless607, Sep 17 '22 08:09