Why does the Retrieval class use an identity matrix for labels?

Open jiunsiew opened this issue 4 years ago • 13 comments

Hi there,

Thanks so much for releasing and maintaining this code base. It's really fantastic.

I don't really have a bug to report, but I do have a question regarding the Retrieval class and how the loss function is being calculated. I've been working through the tutorials, focusing on the basic retrieval example and I understand that by default, the loss function uses a categorical cross entropy loss function.

Obviously, this implies having a label and predicted probabilities which I can see in the Retrieval class.

From that class, the scores are the matrix multiplication of the query and candidate embeddings:

scores = tf.linalg.matmul(
    query_embeddings, candidate_embeddings, transpose_b=True)

Then, the labels are derived as:

labels = tf.eye(num_queries, num_candidates)

Which is then passed to tf.keras.losses.CategoricalCrossentropy to calculate the loss:

loss = self._loss(y_true=labels, y_pred=scores, sample_weight=sample_weight)

What I don't quite understand is why is the identity matrix used as the labels? Doesn't this imply that user_i has selected candidate_i? Or am I missing something?

If I relate this back to the basic retrieval example, would the scores matrix be of size number_unique_user_ids x number_unique_movie_ids? Likewise, the rows of the loss matrix would relate to a user_id and the columns to a candidate. Wouldn't this imply that user_1 reviewed candidate_1, and so on?

Apologies if this is a fairly basic question, but I'm quite new to Tensorflow. Would appreciate any feedback or references. I've tried looking at the issues here and also on stackoverflow, but haven't really been able to find anything. Thanks.

jiunsiew avatar Jul 14 '21 10:07 jiunsiew

This is called an in-batch softmax loss. The input data is as follows:

# User embeddings         # Candidate embeddings
[[user_0],                [[candidate_clicked_by_user_0],
 ...,                      ...,
 [user_N]]                 [candidate_clicked_by_user_N]]

We then create a num users x num candidates scores matrix, where the scores for the user-candidate pairs in the input data are on the diagonal. Hence the identity matrix for labels.
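If it helps, the mechanics can be sketched in plain NumPy (a toy illustration of the idea, not the TFRS implementation; the batch size, embedding size, and the manual softmax cross-entropy are my own choices):

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, dim = 4, 8

# One positive (user, clicked-candidate) pair per row of the batch.
query_embeddings = rng.normal(size=(batch_size, dim))
candidate_embeddings = rng.normal(size=(batch_size, dim))

# Score every user in the batch against every candidate in the batch.
scores = query_embeddings @ candidate_embeddings.T  # (batch_size, batch_size)

# The true pair for row i sits at column i, hence identity labels.
labels = np.eye(batch_size)

# Row-wise softmax cross-entropy, i.e. what CategoricalCrossentropy
# computes when given logits.
log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
loss = -(labels * log_probs).sum(axis=1).mean()
print(loss)
```

Every off-diagonal entry acts as a sampled negative for the positive pair in its row, which is why no explicit negative sampling step is needed.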

maciejkula avatar Jul 21 '21 17:07 maciejkula

Thanks! What if a user had interactions with all candidates? Then for that row, shouldn't the label have multiple 1s?

xiaoyaoyang avatar Aug 06 '21 16:08 xiaoyaoyang

Hi @xiaoyaoyang, it's still an identity matrix. There may be candidates the user has interacted with that are used as negatives for this batch, even though they will be positive examples in other batches. This generally amounts to a regularisation effect.

If it helps, think of word2vec. Two different examples passed through the model could have the same context: "The front [door] was left open." and "The front [gate] was left open." For each example you sample negatives, but you don't prevent sampling negatives which are positives for other examples in your corpus.

> what if a user had interaction with all candidates

This should be an unlikely scenario for your data, assuming the interaction matrix is sparse. If this is not the case then it would seem to be more of a ranking problem than a retrieval problem.

patrickorlando avatar Aug 08 '21 23:08 patrickorlando

Just want to make sure I understand it correctly. Let's say a user had interacted with 2 candidates and the matrix is sparse; if I write down the interaction matrix, the row for that user will have two cells equal to 1 and the others equal to 0.

However, when doing the batch calculation, for that user we will randomly pick one positive example for each batch (out of the two positive cases), and thus the matrix will still be an identity matrix?

xiaoyaoyang avatar Aug 31 '21 22:08 xiaoyaoyang

> Just want to make sure I understand it correctly. Let's say a user had interacted with 2 candidates and the matrix is sparse; if I write down the interaction matrix, the row for that user will have two cells equal to 1 and the others equal to 0.

Agree.

> However, when doing the batch calculation, for that user we will randomly pick one positive example for each batch (out of the two positive cases), and thus the matrix will still be an identity matrix?

Not quite.

Let's imagine you have 3 users and 4 items and that the positive interactions are as follows

| row | user_id | item_id |
|-----|---------|---------|
| 1   | 1       | 3       |
| 2   | 1       | 4       |
| 3   | 2       | 1       |
| 4   | 2       | 3       |
| 5   | 3       | 2       |
| 6   | 3       | 4       |

As an interaction matrix you have

0  0  1  1
1  0  1  0
0  1  0  1

Assume a batch size of 2, and a query/candidate embedding of size 12. We randomly sample 2 rows from the table above, rows 1 and 5.

Your queries will be a matrix of shape (2, 12): the first row will be the query for user 1, and the second row the query for user 3. Your candidates will also be a matrix of shape (2, 12): the first row will be for item 3, and the second for item 2.

When you perform the dot product between queries and candidates you get a score matrix of shape (2, 2). The rows are the users, and the columns are the items.

(u1,i3)  (u1, i2)
(u3,i3)  (u3, i2)

The diagonal of this matrix holds the scores for the positive interactions that we sampled. All the other elements are the scores between that query and the positive items of other examples, which we then use as negatives for that example.

This matrix is not representative of the global interaction matrix. Consider we sample a batch of rows 1 and 4 instead. In this case, both users 1 and 2 interacted with item 3. The score matrix is then,

(u1,i3)  (u1, i3)
(u2,i3)  (u2, i3)

Here, the negative for each positive case is the same as the positive item. This is an accidental hit, and the TFRS library's `Retrieval` task has a `remove_accidental_hits` parameter to remove these hits if you pass the candidate IDs in.

So in summary, each row of the scores matrix corresponds to a single (user, item) pair. The diagonal holds the score for that pair, and all other columns are negatives sampled from the other pairs in the same mini-batch.
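To make the batch of rows 1 and 5 concrete, here is a small NumPy sketch (the embedding values are made up purely for illustration):

```python
import numpy as np

# Toy embeddings for the sampled batch: positive pairs
# (user 1, item 3) and (user 3, item 2).
user_vecs = np.array([[1.0, 0.0],    # user 1
                      [0.0, 1.0]])   # user 3
item_vecs = np.array([[0.9, 0.1],    # item 3 (positive for user 1)
                      [0.2, 0.8]])   # item 2 (positive for user 3)

# Score matrix of shape (2, 2):
# scores[0, 0] = (u1, i3), scores[0, 1] = (u1, i2)
# scores[1, 0] = (u3, i3), scores[1, 1] = (u3, i2)
scores = user_vecs @ item_vecs.T
print(scores)

# The sampled positives lie on the diagonal, hence identity labels.
labels = np.eye(2)
```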

patrickorlando avatar Sep 14 '21 07:09 patrickorlando

Hi @patrickorlando , thanks for the great and detailed explanation. Just to clarify and to make sure I've understood, in your last example where the batch contains rows 1 and 4, wouldn't user 2 interact with item 3 instead of item 4? If row 4 has item_id = 4, then the accidental hit would occur. Would that be correct? Thanks again. It's starting to make a lot more sense now.

jiunsiew avatar Sep 14 '21 10:09 jiunsiew

Hi @jiunsiew, Whoops, looks like I made some mistakes in the examples. 😅 You are correct. I'll edit my answer to not confuse others.

patrickorlando avatar Sep 29 '21 06:09 patrickorlando

Thanks for the explanation! This is really helpful. One last question: as for accidental hits, I assume they will happen a lot for popular candidates. I took a look at the code:

For each row in the batch, zeros the logits of negative candidates that have
the same id as the positive candidate in that row.

It seems the solution is just to "zero out" the negative-pair elements in the matrix. In the above example:

(u1,i3)  (u1, i3)
(u2,i3)  (u2, i3)

in the 2x2 matrix M, it will zero out M(0, 1) and M(1, 0).
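A rough NumPy sketch of this masking (not the TFRS code itself, which lives in `loss.py`; the variable names are mine, and the library actually pushes the logits to a huge negative constant rather than literally zeroing them):

```python
import numpy as np

MIN_FLOAT = np.finfo(np.float32).min / 100.0

# Batch of rows 1 and 4: both positives are item 3.
positive_candidate_ids = np.array([[3], [3]])          # shape (batch, 1)
scores = np.array([[2.0, 2.0],
                   [1.5, 1.5]], dtype=np.float32)

# A column is an accidental hit for a row when the candidate id in that
# column equals the row's own positive candidate id, off the diagonal.
duplicates = (positive_candidate_ids == positive_candidate_ids.T)
off_diagonal = ~np.eye(len(scores), dtype=bool)
accidental_hits = duplicates & off_diagonal

masked_scores = np.where(accidental_hits, MIN_FLOAT, scores)
print(masked_scores)
```

The diagonal (the true positives) is left untouched; only the duplicated off-diagonal negatives are masked out.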

xiaoyaoyang avatar Nov 10 '21 02:11 xiaoyaoyang

@xiaoyaoyang, correct. The optimisation attempts to push the scores for positive items higher and the scores for negative items lower. If there is an accidental hit, we artificially set the score to a very large negative number:

https://github.com/tensorflow/recommenders/blob/a4de92c66d95d0370c664f04e450b4dd294e5763/tensorflow_recommenders/layers/loss.py#L139-L147

https://github.com/tensorflow/recommenders/blob/a4de92c66d95d0370c664f04e450b4dd294e5763/tensorflow_recommenders/layers/loss.py#L23

This ensures there is no gradient that attempts to weaken the positive interaction for that example.

import numpy as np
print(np.finfo(np.float32).min / 100.0)
# -3.4028234663852885e+36

patrickorlando avatar Nov 18 '21 01:11 patrickorlando

Thank you for the wonderful explanations, @patrickorlando!

maciejkula avatar Nov 18 '21 01:11 maciejkula

Hi @patrickorlando

Is it possible to use label smoothing with `CategoricalCrossentropy` together with `remove_accidental_hits`? Setting the logit to a large negative value will still produce some gradient when combined with label smoothing. I didn't see any paper discussing the impact of label smoothing in the two-tower context, but I thought it might be a good choice since we don't know the true ground truth.

This was discussed here as well in this issue

OmarMAmin avatar Feb 11 '23 21:02 OmarMAmin

I haven't tried label smoothing myself @OmarMAmin and I haven't seen any literature on label smoothing in two tower models. However, you can write your own implementations of RemoveAccidentalHits and the Retrieval task to implement it. I wonder whether the negative sampling introduces a sort of implicit smoothing when it samples a negative that was a positive in another batch, which is why this area hasn't been explored. It would be interesting to hear if you find any benefit from it.
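For reference, this is what Keras-style label smoothing does to the identity labels, sketched in NumPy (the smoothing factor is an arbitrary illustrative value; this mirrors the transformation `CategoricalCrossentropy(label_smoothing=...)` applies internally):

```python
import numpy as np

batch_size = 4
alpha = 0.1  # smoothing factor (arbitrary illustrative value)

hard_labels = np.eye(batch_size)

# Keras-style smoothing: y = y * (1 - alpha) + alpha / num_classes.
# Every column, including any masked accidental hit, receives
# alpha / num_classes of target mass, which is the source of the
# residual gradient raised in the question above.
smooth_labels = hard_labels * (1 - alpha) + alpha / batch_size

print(smooth_labels[0])  # [0.925 0.025 0.025 0.025]
```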

patrickorlando avatar Feb 12 '23 22:02 patrickorlando