Update rank calculation
This small PR changes the calculation of rank values from an argsort-based implementation to a faster version based on counting >= / > comparisons between paired positive and negative scores.
If there are no duplicate scores in the array, the results are identical. For tied scores, the new implementation yields more realistic performance estimates, as described in https://arxiv.org/abs/2002.06914, whereas the sorting-based version's estimates depend on the inner workings of the sorting algorithm, i.e., on how it happens to break ties.
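As a quick illustration of the tie handling (a toy sketch of mine, not part of this PR), take one positive score that ties with two of four negatives: the optimistic rank places the positive above all tied negatives, the pessimistic rank below them, and the new implementation reports their average, while an argsort-based rank could land anywhere inside the tied group.

import torch

# Hypothetical toy input: the positive score 0.7 ties with two of four negatives.
y_pred_pos = torch.tensor([0.7]).view(-1, 1)
y_pred_neg = torch.tensor([[0.9, 0.7, 0.7, 0.1]])

optimistic_rank = (y_pred_neg > y_pred_pos).sum(dim=1)    # ties broken in our favour -> 1
pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(dim=1)  # ties broken against us    -> 3
print(0.5 * (optimistic_rank + pessimistic_rank) + 1)     # tensor([3.]) -- halfway between ranks 2 and 4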
I also ran a small benchmark (see Details below) on a Quadro RTX 8000 with torch 1.12; the speed-ups range from ~1.5x to ~50x across different batch-size / number-of-negative-samples combinations.
Code
# cf. https://pytorch.org/tutorials/recipes/recipes/benchmark.html
import torch
from torch.utils import benchmark


def old(y_pred_pos: torch.Tensor, y_pred_neg: torch.Tensor):
    # Prepend the positive score to its negatives, sort each row in
    # descending order, and find where the positive (column 0) ended up.
    y_pred = torch.cat([y_pred_pos.view(-1, 1), y_pred_neg], dim=1)
    argsort = torch.argsort(y_pred, dim=1, descending=True)
    ranking_list = torch.nonzero(argsort == 0, as_tuple=False)
    ranking_list = ranking_list[:, 1] + 1
    return ranking_list


def new(y_pred_pos: torch.Tensor, y_pred_neg: torch.Tensor):
    y_pred_pos = y_pred_pos.view(-1, 1)
    # Best case: ties are broken in favour of the positive score.
    optimistic_rank = (y_pred_neg > y_pred_pos).sum(dim=1)
    # Worst case: ties are broken against the positive score.
    pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(dim=1)
    ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1
    return ranking_list
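
# Sanity check (my addition, not part of the PR's benchmark): with continuous
# random scores, ties are almost surely absent, so old and new should agree.
_pos, _neg = torch.randn(100), torch.randn(100, 50)
assert bool((old(_pos, _neg) == new(_pos, _neg)).all())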
example_fuzzer = benchmark.Fuzzer(
    parameters=[
        benchmark.FuzzedParameter(
            "n", minval=1, maxval=1_000_000, distribution="loguniform"
        ),
        benchmark.FuzzedParameter(
            "k", minval=1, maxval=10_000, distribution="loguniform"
        ),
    ],
    tensors=[
        benchmark.FuzzedTensor(
            "y_pred_pos",
            size=("n", 1),
            min_elements=128,
            max_elements=10_000_000,
        ),
        benchmark.FuzzedTensor(
            "y_pred_neg",
            size=("n", "k"),
            min_elements=128,
            max_elements=10_000_000,
        ),
    ],
    seed=0,
)
results = []
for tensors, tensor_params, params in example_fuzzer.take(10):
    # description is the column label
    sub_label = f"{params['n']:<6} x {params['k']:<4}"
    results.append(
        benchmark.Timer(
            stmt="old(y_pred_pos, y_pred_neg)",
            setup="from __main__ import old",
            globals=tensors,
            label="rank",
            sub_label=sub_label,
            description="argsort/where",
        ).blocked_autorange(min_run_time=1)
    )
    results.append(
        benchmark.Timer(
            stmt="new(y_pred_pos, y_pred_neg)",
            setup="from __main__ import new",
            globals=tensors,
            label="rank",
            sub_label=sub_label,
            description="comp/sum",
        ).blocked_autorange(min_run_time=1)
    )

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()
Results
[--------------------- rank ---------------------]
                     |  argsort/where  |  comp/sum
1 threads: ----------------------------------------
      19549  x 257   |       4400      |     90
      199    x 1468  |        250      |     90
      103865 x 22    |        726      |     91
      694    x 1598  |        630      |     89
      1352   x 45    |        140      |     90
      2573   x 1     |        130      |     90
      208    x 4077  |        500      |     89
      10563  x 6     |        150      |     90
      152    x 190   |        156      |     89
      8255   x 346   |       2060      |     90

Times are in microseconds (us).