Track and update changes to percentile in PyTorch and JAX backends
percentile was added to the tensorlib backends in PR #817, but there were outstanding issues with both PyTorch and JAX that required some nuance.
- PyTorch is missing interpolation methods as of v1.10.1:
I've checked in again on https://github.com/pytorch/pytorch/pull/59397#issuecomment-965856806 but I'm not really sure if we'll see an interpolation option in torch v1.10.1, so we might want to raise not-implemented errors for the time being so that PR #817 can finally move forward. :/
Originally posted by @matthewfeickert in https://github.com/scikit-hep/pyhf/issues/815#issuecomment-965858409
- JAX requires some additional support for dtype promotion in percentile when using the linear interpolation method.
- c.f. https://github.com/google/jax/issues/8513
Both of these issues should be monitored so that they can hopefully be resolved along the way to a patch release.
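For reference, the stopgap mentioned in the quoted comment (raising not-implemented errors for the PyTorch backend) could look roughly like the sketch below. The names `SUPPORTED_INTERPOLATION` and `check_interpolation` are hypothetical and are not part of pyhf's API. It assumes only linear interpolation is available in the backend.

```python
# Hypothetical sketch: these names are illustrative, not pyhf's actual API.
# Assumption: the backend can only do linear interpolation for now.
SUPPORTED_INTERPOLATION = {"linear"}


def check_interpolation(interpolation="linear", backend_name="pytorch"):
    """Return the method if supported, else raise NotImplementedError."""
    if interpolation not in SUPPORTED_INTERPOLATION:
        raise NotImplementedError(
            f"The {backend_name} backend does not yet support "
            f"interpolation='{interpolation}'."
        )
    return interpolation
```

A backend's percentile method could call a guard like this first and only then dispatch to the underlying library, so unsupported options fail loudly instead of silently falling back to linear.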
Tracking for PyTorch has moved from https://github.com/pytorch/pytorch/pull/59397 to https://github.com/pytorch/pytorch/pull/70637. :+1:
https://github.com/pytorch/pytorch/pull/70637 was merged on 2022-01-05, but was not included in torch v1.10.2, which was released on 2022-01-27. It is scheduled to be in the next minor release: torch v1.11.0.
The JAX issue was already resolved through Issue #1729 and PR #1730.
torch v1.11.0 was released today (2022-03-10) and now supports an interpolation keyword for torch.quantile:
torch.quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None):
...
We should now be able to unify the percentile API support across backends.
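As a rough illustration of what a unified API would need to agree on, here is a pure-Python sketch of the five interpolation options that `torch.quantile` documents (`linear`, `lower`, `higher`, `midpoint`, `nearest`) for a 1-D input. The function `percentile_1d` is hypothetical and is not how any backend implements it. It takes `q` in [0, 100], which is assumed here as the percentile convention, and tie-breaking for `nearest` may differ between libraries.

```python
import math


def percentile_1d(values, q, interpolation="linear"):
    """Hypothetical reference: q-th percentile (0 <= q <= 100) of a 1-D sequence."""
    data = sorted(values)
    if not data:
        raise ValueError("values must not be empty")
    if not 0 <= q <= 100:
        raise ValueError("q must be in [0, 100]")

    # Fractional index of the requested percentile in the sorted data
    pos = (q / 100) * (len(data) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    lower, upper = data[lo], data[hi]

    if interpolation == "linear":
        return lower + (upper - lower) * (pos - lo)
    if interpolation == "lower":
        return lower
    if interpolation == "higher":
        return upper
    if interpolation == "midpoint":
        return (lower + upper) / 2
    if interpolation == "nearest":
        # Python's round() resolves exact .5 ties to the even index;
        # real backends may break ties differently.
        return data[round(pos)]
    raise NotImplementedError(f"interpolation='{interpolation}'")
```

Comparing each backend's output against a small reference like this for a handful of inputs would be one way to check that the options behave the same everywhere.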