Retrieval metrics are misleading - here's how it should be done instead
🚀 Request
Retrieval metrics should be better aligned with typical practice. A recommendation is below.
Explanation
The current best practice for calculating retrieval metrics follows this process:
- Calculate ground truth search labels. In the case of vector search, this would be via exact KNN search.
- Calculate predicted results. In the case of vector search, this would be retrieving ANN search results.
- Calculate the score. A predicted result is considered relevant if it appears in the ground truth list; some implementations also take its position into account.
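As a concrete sketch of the first two steps for vector search (toy embedding tensors are assumed; the exact KNN is done by brute force in plain PyTorch, and the ANN ids would come from whatever index is being evaluated):

```python
import torch

# Toy corpus/query embeddings; shapes are (n_docs, dim) and (n_queries, dim).
corpus = torch.randn(10_000, 128)
queries = torch.randn(100, 128)
k = 10

# Ground truth: exact KNN via brute-force distance computation.
dists = torch.cdist(queries, corpus)             # (n_queries, n_docs) Euclidean distances
knn_ids = dists.topk(k, largest=False).indices   # ids of the k exact nearest neighbors per query

# Predicted: the top-k ids returned by the ANN index under test (placeholder call).
# ann_ids = ann_index.search(queries, k)
```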
Example (vector search recall)
In the case of top k recall for vector search:
- Obtain the top k exact KNN results for the given query
- Obtain the top k ANN results for the given query
- Calculate the percentage of ANN results that appear in the KNN result list (sketched below)
- Repeat for every query in the test set
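A minimal sketch of that recall computation, assuming knn_ids and ann_ids are (n_queries, k) tensors of result ids as in the sketch above:

```python
import torch

def ann_recall_at_k(knn_ids: torch.Tensor, ann_ids: torch.Tensor) -> torch.Tensor:
    """Fraction of ANN results that appear in the exact KNN list, averaged over queries."""
    per_query = [
        len(set(pred) & set(truth)) / len(truth)
        for truth, pred in zip(knn_ids.tolist(), ann_ids.tolist())
    ]
    return torch.tensor(per_query).mean()
```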
Current implementation
The current implementation in torchmetrics expects indexes (which map each prediction to its query), preds (the scores used to rank, e.g. probabilities or similarity scores), and target (the ground truth relevance labels).
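For reference, the existing class-based interface is used roughly like this (the values are made up; RetrievalRecall and its indexes argument are the existing torchmetrics API):

```python
import torch
from torchmetrics.retrieval import RetrievalRecall

# Two queries (indexes 0 and 1); preds are ranking scores, target is a relevance label per item.
indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, True])

recall_at_2 = RetrievalRecall(top_k=2)
print(recall_at_2(preds, target, indexes=indexes))
```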
What's misleading:
The problem is that the current implementation assumes a 1:1 mapping between predictions and ground truth labels, which does not align with industry practice.
How it should work instead
The method signature should look more like this: query/index (tensor), true_preds (tensor), true_targets (tensor), actual_preds (tensor), actual_targets (tensor), threshold (optional float), epsilon (optional float), similarity/distance (bool, defaults to distance)
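A hypothetical signature sketch along those lines (none of these names exist in torchmetrics today; a top_k parameter is added here because the count-based mode described below implies it):

```python
from typing import Optional
from torch import Tensor

def retrieval_recall(
    indexes: Tensor,           # query id for each row
    true_preds: Tensor,        # ground-truth scores (e.g. exact KNN distances)
    true_targets: Tensor,      # ground-truth item ids
    actual_preds: Tensor,      # predicted scores (e.g. ANN distances)
    actual_targets: Tensor,    # predicted item ids
    top_k: Optional[int] = None,
    threshold: Optional[float] = None,
    epsilon: Optional[float] = None,
    is_distance: bool = True,  # True for distance scores, False for similarity scores
) -> Tensor:
    ...
```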
With this data, the score calculation could be based on either:
- Count: if the preds are distinct, then top k could be obtained by:
  a. sorting (true_preds, true_targets) by true_preds and keeping the top k items
  b. sorting (actual_preds, actual_targets) by actual_preds and keeping the top k items
  c. performing the computation between the two lists for each given query/index
- Threshold: if the preds are not distinct, then we would take the values selected by a threshold t, obtained by:
  a. sorting (true_preds, true_targets) by true_preds and keeping the items where true_preds <= t (if t is a distance) or >= t (if t is a similarity)
  b. sorting (actual_preds, actual_targets) by actual_preds and keeping the items where actual_preds <= (1 + epsilon) * t (if t is a distance) or >= (1 - epsilon) * t (if t is a similarity), where epsilon softens the filter on the actual side
  c. performing the computation between the two lists for each given query/index (both modes are sketched below)
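A minimal sketch of both selection modes for a single query, under the hypothetical signature above (distance scores assumed, so sorting is ascending):

```python
from typing import Optional
import torch

def select_top(preds: torch.Tensor, targets: torch.Tensor,
               top_k: Optional[int] = None, threshold: Optional[float] = None,
               epsilon: float = 0.0, is_distance: bool = True) -> torch.Tensor:
    """Return the target ids kept for one query, selected by count (top_k) or by threshold."""
    order = torch.argsort(preds, descending=not is_distance)
    preds, targets = preds[order], targets[order]
    if top_k is not None:                                   # count-based selection
        return targets[:top_k]
    t = threshold * (1 + epsilon) if is_distance else threshold * (1 - epsilon)
    keep = preds <= t if is_distance else preds >= t        # threshold-based selection
    return targets[keep]

# Toy data for one query: distances and item ids for the ground-truth (exact) and actual (ANN) results.
true_preds = torch.tensor([0.1, 0.2, 0.4, 0.9])
true_targets = torch.tensor([7, 3, 5, 1])
actual_preds = torch.tensor([0.1, 0.3, 0.4, 0.9])
actual_targets = torch.tensor([7, 2, 5, 1])

true_kept = select_top(true_preds, true_targets, threshold=0.4)                      # exact threshold
actual_kept = select_top(actual_preds, actual_targets, threshold=0.4, epsilon=0.1)   # softened threshold
recall = len(set(actual_kept.tolist()) & set(true_kept.tolist())) / len(true_kept)
print(recall)
```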
An implementation like this would enable torchmetrics to calculate Recall, MAP, AUROC, NDCG, MRR, etc., based on industry-accepted practices.
Additional context
The well-known ANN-Benchmarks paper describes in detail the recall calculation used in the example above.
(Aumüller, M., Bernhardsson, E., & Faithfull, A. (2020). ANN-Benchmarks: A benchmarking tool for approximate nearest neighbor algorithms. Information Systems, 87, 101374. Available at: https://arxiv.org/pdf/1807.05614.pdf)