Refactor perplexity so compute() does not run inference by default
This is a proposed refactor of the perplexity metric that brings it closer to the other metrics in `evaluate`, which generally do not run inference inside their `compute()` functions and instead take only the components needed for the metric computation itself (in this case, logits, labels, and attention mask tensors).
The measurement version of perplexity retains the functionality of passing in a pretrained model by name to calculate perplexity against some data.
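To illustrate the proposed `compute()` signature, here is a minimal sketch of how perplexity can be derived directly from logits, labels, and an attention mask. The function name `perplexity_from_logits` is hypothetical and only meant to show the computation; it is not the actual implementation in this PR.

```python
# Hedged sketch (not the PR's actual code): perplexity from raw model
# outputs, assuming predictions are logits of shape (batch, seq, vocab),
# references are token ids, and attention_mask marks valid positions.
import numpy as np
import torch
import torch.nn.functional as F


def perplexity_from_logits(logits, labels, attention_mask):
    # Shift so each position is scored on predicting the *next* token.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    shift_mask = attention_mask[..., 1:]
    # Per-token cross-entropy; cross_entropy expects (N, C, ...) input.
    loss = F.cross_entropy(
        shift_logits.transpose(1, 2), shift_labels, reduction="none"
    )
    # Mask out padding, average per example, then exponentiate.
    per_example = (loss * shift_mask).sum(1) / shift_mask.sum(1)
    return torch.exp(per_example).mean().item()
```

Because no model forward pass happens here, the metric only depends on tensors the caller already has, which is what allows `compute()` to skip inference by default.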
Test with:

```python
import time

import numpy as np
import torch

# Perplexity is the metric class defined in this PR's perplexity module.
ppl = Perplexity()

st = time.time()
results = ppl._compute(
    predictions=torch.tensor(np.random.uniform(-100, 0, (4, 12, 50257))),
    references=torch.tensor(np.random.randint(1, 10000, (4, 12))),
    attention_mask=torch.ones((4, 12)),
)
print(results)
print(f"time taken: {time.time() - st}")

st = time.time()
results2 = ppl.compute(
    predictions=torch.tensor(np.random.uniform(-100, 0, (4, 12, 50257))),
    references=torch.tensor(np.random.randint(1, 10000, (4, 12))),
    attention_mask=torch.ones((4, 12)),
)
print(results2)
print(f"time taken: {time.time() - st}")
```
One outstanding issue: upstream slowness from Arrow data not being kept in memory accounts for the performance gap between `_compute` and `compute`: 0.19296860694885254 s vs. 6.994134187698364 s, respectively.
This is a breaking change, so we should alert downstream users.
Closes #240