
Extra loss not added to the overall model loss

Open · xiaowec opened this issue 1 year ago • 0 comments

I am trying to add an extra loss that penalizes some intermediate variables in the scoring function, but after registering it with tf.compat.v1.add_to_collection, the reported loss does not change. Here is one example:

import tensorflow as tf
import tensorflow_ranking as tfr

# create a scoring function
def scoring_fn(context_features, group_features, mode, params, config):
    # define input layer
    input_layer = tf.concat([
        tf.reshape(group_features['example_feature'], [-1, 1]),
        tf.reshape(context_features['query_feature'], [-1, 1]),
        tf.reshape(group_features['doc_feature'], [-1, 1])
    ], axis=1)
    # define hidden layers
    hidden_layer1 = tf.keras.layers.Dense(64, activation='relu')(input_layer)
    hidden_layer2 = tf.keras.layers.Dense(32, activation='relu')(hidden_layer1)
    logits = tf.keras.layers.Dense(1)(hidden_layer2)
    
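    # penalty: squared per-example sum of hidden_layer2 activations, summed over the batch;
    # registered in the graph's LOSSES collection
    extra_loss = tf.reduce_sum(tf.square(tf.reduce_sum(hidden_layer2, axis=1)))
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, extra_loss)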

    return logits

# create a ranking head: loss_fn is required, and training needs either an
# optimizer or a train_op_fn that maps the loss to a train op
ranking_head = tfr.head.create_ranking_head(
    loss_fn=tfr.losses.make_loss_fn('pairwise_logistic_loss'),
    eval_metric_fns={
        'metric': tfr.metrics.make_ranking_metric_fn(tfr.metrics.RankingMetricKey.ARP)
    },
    # learning rate is an arbitrary placeholder
    optimizer=tf.compat.v1.train.AdagradOptimizer(learning_rate=0.05))

# create a model
estimator = tf.estimator.Estimator(
    model_fn=tfr.model.make_groupwise_ranking_fn(
        group_score_fn=scoring_fn,
        group_size=1,
        ranking_head=ranking_head))

With this change, I see no difference in the reported loss, so it seems that extra_loss is not being added to the overall loss, even though the squared sum of hidden_layer2 should be quite large. Is there anything we are missing?
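A minimal sketch of what the collection call does on its own, assuming TF 2.x with the tf.compat.v1 graph APIs. The constant tensors below are hypothetical stand-ins for hidden_layer2 and for the ranking loss; they are not taken from the issue. tf.compat.v1.add_to_collection only records the tensor under GraphKeys.LOSSES. A loss tensor built independently (for example, one a head computes directly from its loss_fn) changes only if something reads the collection back, such as tf.compat.v1.get_collection or tf.compat.v1.losses.get_total_loss(). Whether the TF-Ranking head reads this collection is not confirmed in this thread.

```python
import tensorflow as tf

# Build a TF1-style graph, similar to how an Estimator runs a model_fn.
graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-ins: activations for hidden_layer2 and a fixed "ranking loss".
    hidden = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    main_loss = tf.constant(0.5)

    # Same pattern as scoring_fn: compute a penalty and register it in LOSSES.
    extra_loss = tf.reduce_sum(tf.square(tf.reduce_sum(hidden, axis=1)))
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, extra_loss)

    # add_to_collection does not modify main_loss; the penalty only reaches
    # the objective when the collection is read back and added explicitly.
    collected = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES)
    total_loss = main_loss + tf.add_n(collected)

with tf.compat.v1.Session(graph=graph) as sess:
    main_value, extra_value, total_value = sess.run([main_loss, extra_loss, total_loss])

print(main_value, extra_value, total_value)
```

Here the row sums are 1 + 2 = 3 and 3 + 4 = 7, so the penalty is 3² + 7² = 58: main_loss stays 0.5, while only the explicitly summed total becomes 58.5. In a real model the LOSSES collection may also hold losses registered by tf.compat.v1.losses functions, so summing it wholesale could double-count; a dedicated collection key for extra penalties avoids that.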

xiaowec · Apr 05 '23 21:04