Support multiple tensor backends via NEP-47 (Python array API standard)
The recently proposed NEP-47 attempts to unify the APIs of various tensor frameworks (NumPy, Tensorflow, PyTorch, Dask, JAX, CuPy, MXNet, etc.), via the Python array API standard.
It is a much more compact version of the original NumPy API, removing functions that are unfriendly to heterogeneous hardware like GPUs.
Since NumPyro uses JAX as its backend, whose APIs closely match NumPy's, shouldn't it be quite doable to adopt NEP-47 for multi-backend support?
Related discussion:
- https://github.com/data-apis/array-api/issues/1
- https://discourse.pymc.io/t/support-multiple-tensor-backends-via-nep-47-python-array-api-standard/7686
I quickly browsed through the NumPyro source code to locate JAX-heavy code. For example, the most frequently used APIs include `numpyro.sample`, `numpyro.infer.{MCMC, NUTS}`, and `numpyro.distributions.{Normal, Exponential}`:

- `sample` is defined in `numpyro/primitives.py`, relying on `jax.random.{randint, choice, PRNGKey}` and other small ops.

  https://github.com/pyro-ppl/numpyro/blob/f8f482ad7b2f3f658d198c99f76fdda12cff84ca/numpyro/primitives.py#L104-L110

- `infer.{MCMC, NUTS}` are defined in `numpyro/infer/{mcmc.py, hmc.py}`, relying on JAX-specific utils like `jit`, `vmap`, `pmap`, `tree_util.tree_map`, and probably `grad`.

  https://github.com/pyro-ppl/numpyro/blob/f8f482ad7b2f3f658d198c99f76fdda12cff84ca/numpyro/infer/mcmc.py#L198-L200
  https://github.com/pyro-ppl/numpyro/blob/f8f482ad7b2f3f658d198c99f76fdda12cff84ca/numpyro/infer/hmc.py#L764-L767

- `distributions.{Normal, Exponential}` are defined in `numpyro/distributions/continuous.py`, relying on JAX's linear solvers and special functions:

  ```python
  from jax.scipy.linalg import cho_solve, solve_triangular
  from jax.scipy.special import betainc, expit, gammaln, logit, multigammaln, ndtr, ndtri
  ```
All of the above should be relatively easy to implement using NEP-47 core APIs plus a few framework-specific utilities (MXNet, CuPy, ChainerX, etc.). The NumPy frontend can then be lowered to a variety of IRs like `tvm.relay` and MLIR, which support a diverse set of hardware.
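To make the idea concrete, here is a minimal sketch (function name and fallback logic are hypothetical, not NumPyro code) of what a backend-agnostic helper could look like under the array API standard's `__array_namespace__` protocol, falling back to plain NumPy for array types that don't implement the protocol yet:

```python
import numpy as np

def standardize(x):
    # Resolve the array's namespace via the array API standard's
    # __array_namespace__ protocol (NEP-47). Any compliant backend
    # (NumPy >= 2.0, CuPy, etc.) supplies its own namespace; older
    # NumPy arrays fall back to the plain numpy module.
    xp = x.__array_namespace__() if hasattr(x, "__array_namespace__") else np
    # mean() and std() are both part of the standard's statistical functions,
    # so the same line runs unchanged on any compliant backend.
    return (x - xp.mean(x)) / xp.std(x)
```

The point is that code written only against the standard's operations needs no per-backend branches beyond namespace resolution.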
Hi @learning-chip, thanks for the suggestion. This seems like an issue best resolved at the level of PyTorch, JAX, and TensorFlow, none of which are fully compatible with NEP-47 yet (to the best of my knowledge, and as far as I can tell from the discussion in https://github.com/data-apis/array-api/issues/1). An additional complication is that only these three frameworks (again, to the best of my knowledge) have both full-featured automatic differentiation and a distributions library with full broadcasting semantics and reparameterized samplers. There are also already large bodies of functionally similar idiomatic PyTorch and TensorFlow code in Pyro and TFP respectively, so it's not clear how much users would benefit from any attempts to support other backends in NumPyro.
There was some discussion about `__array_function__` in JAX in https://github.com/google/jax/issues/1565. You may also be interested in our less comprehensive attempts at sharing code/interfaces across backends in Funsor (mostly for inference code) and pyro-api (mostly for model code).
> none of which are fully compatible with NEP-47 yet
This is correct at this moment. My question is: once those frameworks become fully compatible with NEP-47, what would be the amount of effort to add them as new backends for NumPyro, in terms of missing features to support, lines of code, or person-months, for example? This will affect whether a framework's developer team should design their own Bayesian learning library or simply reuse (Num)Pyro.
There seem to be more and more HPC/AI frameworks providing a NumPy-like API. Nvidia has recently open-sourced Legate NumPy, and DaCe is another framework that excels at optimizing HPC kernels and runs on FPGAs.
> only these three frameworks (again, to the best of my knowledge) have both full-featured automatic differentiation
I recall that both MXNet and ONNX are interested in the Python Data API. Maybe @rgommers and @szha are the right people to ask.
> full broadcasting semantics and reparameterized samplers

Broadcasting is a key feature of the Python array API standard. As for "reparameterized samplers", could you elaborate on the exact functionality?
> how much users might benefit from any attempts to support other backends in NumPyro.
From my hardware systems background, I think a big difference between those AI frameworks is their compiler & hardware support. Their software functionalities are indeed getting similar -- all provide autodiff, an AI model zoo, distributed training, etc. But the compile chains are quite different: TensorFlow -> XLA -> TPU and ONNX -> MicroTVM -> IoT/edge devices are some unique examples. Say Bayesian AI models get popular in autonomous vehicles (I think they already are); then you might need MicroTVM for edge deployment, a case that NumPyro+JAX does not currently support.
@learning-chip thanks for the ping and for bringing up the array API standard. Work on adopting the standard is in progress for MXNet and ONNX. Once array libraries finish implementing compatibility with the standard, it should indeed be straightforward to switch array backends, as long as the implementation relies only on the operations defined in the standard.
@learning-chip to be clear, Pyro and NumPyro rely crucially on advanced features of PyTorch and JAX (especially higher-order forward- and reverse-mode AD and a large library of probability distributions with Tensorflow-style shape semantics and reparameterized samplers) that are outside the scope of NEP-47.
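(For readers unfamiliar with the term: a reparameterized sampler expresses a random draw as a deterministic transform of parameter-free noise, so that under an autodiff framework gradients can flow through the sampling step. A minimal NumPy sketch with a hypothetical helper name, showing only the sampling path and broadcasting, not the gradient machinery:)

```python
import numpy as np

def sample_normal_reparam(loc, scale, rng, sample_shape=()):
    # Reparameterization trick for Normal(loc, scale):
    # draw parameter-free noise eps ~ N(0, 1), then apply the deterministic
    # transform loc + scale * eps. In JAX/PyTorch/TF, gradients w.r.t. loc
    # and scale flow through this transform; plain NumPy is used here only
    # to illustrate the idea.
    loc, scale = np.asarray(loc), np.asarray(scale)
    # Broadcast the parameter shapes to get the batch shape of the draw.
    batch_shape = np.broadcast(loc, scale).shape
    eps = rng.standard_normal(sample_shape + batch_shape)
    return loc + scale * eps
```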
As far as I know, only PyTorch, JAX and TensorFlow implement all of the relevant features, and there are many small differences across each framework's implementation and API that would make unifying them difficult and unpleasant. Unless the maintainers of these frameworks plan on standardizing all of these features as well, it's unlikely that we will be able to make our existing codebases in Pyro and NumPyro backend-independent via NEP-47.
Like most open source projects, we are a small team and do not have the developer bandwidth to reimplement and standardize these features ourselves or refactor thousands of lines of backend-specific test code in Pyro and NumPyro, although we certainly support the high-level goal of a backend-independent array programming ecosystem. Barring significant new external contributions, our efforts at backend independence will probably remain restricted for the time being to the `pyro-api` and `funsor` libraries I pointed out above, which are also better targets for future NEP-47 support.
> I think a big difference between those AI frameworks is their compiler & hardware support ... you might need MicroTVM for edge deployment, a case that NumPyro+JAX does not currently support.
I would bet on XLA and TVM adding more backend support, and even achieving some degree of interoperability, as a near-term path to this sort of thing before the higher-level software ecosystem adopts NEP-47 en masse. But if you or anyone reading this thread knows of users who want to deploy existing Pyro/NumPyro models but can't because of missing hardware backend support, please tell them to contact us!
> especially higher-order forward- and reverse-mode AD and a large library of probability distributions with Tensorflow-style shape semantics and reparameterized samplers
Thanks, this is very useful information. I will read that paper carefully, including this doc: https://www.tensorflow.org/probability/examples/TensorFlow_Distributions_Tutorial
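For reference, the TFP-style convention mentioned above splits the shape of a draw into `sample_shape + batch_shape + event_shape`. A toy NumPy illustration (shapes chosen arbitrarily, not tied to any particular distribution class):

```python
import numpy as np

# TFP-style shape semantics: a draw from a distribution concatenates
# three shapes, in this order: sample_shape + batch_shape + event_shape.
sample_shape = (5,)   # number of i.i.d. draws requested by the caller
batch_shape = (2,)    # independent distributions, vectorized over parameters
event_shape = (3,)    # dimensionality of one event (e.g. a 3-d MVN draw)

draws = np.zeros(sample_shape + batch_shape + event_shape)
assert draws.shape == (5, 2, 3)

# log_prob reduces over event_shape only, leaving sample and batch dims:
log_prob_shape = sample_shape + batch_shape
assert log_prob_shape == (5, 2)
```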
> Unless the maintainers of these frameworks plan on standardizing all of these features as well, it's unlikely that we will be able to make our existing codebases in Pyro and NumPyro backend-independent via NEP-47.
I totally understand and agree -- I just need a list of "missing features" (beyond NEP-47) that other frameworks should support, if they want to be plugged into NumPyro's backend.
> Barring significant new external contributions, our efforts at backend-independence will probably remain restricted for the time being to the `pyro-api` and `funsor` libraries I pointed out above, which are also better targets for future NEP-47 support.
This sounds like a very reasonable target to me.
> I would bet on XLA and TVM adding more backend support and even achieving some degree of interoperability as a near-term path to this sort of thing before the higher-level software ecosystem adopts NEP-47 en masse
XLA and TVM could indeed be a thinner layer than NEP-47 & the Python array API. The path to interoperability is unclear to me, though. See the Relay MLIR Frontend discussions.
XLA is moving towards MLIR-HLO, which is a significant & long-term change in my opinion (MLIR is a huge beast). Thus I would not expect near-term improvements 😂