[BUG] XLNet-CLM eval recall metric value does not match custom NumPy-based recall metric value
Bug description
When we train an XLNet model with Causal LM (CLM) masking, the model prints its own evaluation metrics (NDCG@k, recall@k, etc.) during the trainer.evaluate() step. If we apply our own custom metric function using NumPy, like the one below, the metric values do not match; however, they do match if we use MLM masking instead.
```python
import numpy as np

def recall(predicted_items: np.ndarray, real_items: np.ndarray) -> float:
    bs, top_k = predicted_items.shape
    # rows with label id 0 are padding and are excluded from the metric
    valid_rows = real_items != 0
    # reshape predictions and labels to compare
    # the top-k predicted item-ids with the label id
    real_items = real_items.reshape(bs, 1, -1)
    predicted_items = predicted_items.reshape(bs, 1, top_k)
    num_relevant = real_items.shape[-1]
    predicted_correct_sum = (predicted_items == real_items).sum(-1)
    predicted_correct_sum = predicted_correct_sum[valid_rows]
    recall_per_row = predicted_correct_sum / num_relevant
    return np.mean(recall_per_row)
```
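For reference, a quick sanity check of the custom recall function above on a toy batch (the function is repeated verbatim so the snippet runs standalone; the array values are made up for illustration, and label id 0 is assumed to mark padding, as in the function's valid_rows mask):

```python
import numpy as np

def recall(predicted_items: np.ndarray, real_items: np.ndarray) -> float:
    bs, top_k = predicted_items.shape
    # rows with label id 0 are padding and are excluded from the metric
    valid_rows = real_items != 0
    real_items = real_items.reshape(bs, 1, -1)
    predicted_items = predicted_items.reshape(bs, 1, top_k)
    num_relevant = real_items.shape[-1]
    predicted_correct_sum = (predicted_items == real_items).sum(-1)
    predicted_correct_sum = predicted_correct_sum[valid_rows]
    recall_per_row = predicted_correct_sum / num_relevant
    return float(np.mean(recall_per_row))

# toy batch: top-3 predictions for 3 sessions, one label id per session
predicted = np.array([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])
labels = np.array([[2], [0], [10]])  # session 2 has label 0 (padding)

# session 1 hits (2 is in its top-3), session 2 is excluded,
# session 3 misses (10 is not in its top-3) -> recall = 0.5
print(recall(predicted, labels))  # → 0.5
```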
Steps/Code to reproduce bug
coming soon.
Expected behavior
Environment details
- Transformers4Rec version:
- Platform:
- Python version:
- Huggingface Transformers version:
- PyTorch version (GPU?):
- Tensorflow version (GPU?):
Additional context
If I use the dev branch, I get much higher CLM accuracy metrics (~2.5x higher) than MLM on the end-to-end example with the yoochoose dataset. I don't think this is expected.
Is this bug already fixed in some T4R version? I am currently experiencing similar discrepancies when evaluating NDCG and MRR metrics on my dataset. My question is: is it worth creating a reproducible example, or are you already working on it?
@SPP3000 can you please provide more details about "I am currently experiencing similar discrepancies"?
What model are you using, and how do you evaluate? Are you using our fit_and_evaluate evaluation function?
@rnyak I just opened a new bug report with all details here.
@SPP3000 are you seeing the same issue with XLNet MLM? Did you test MLM?
Hello, are there any updates regarding this issue? @rnyak @SPP3000