Implementing Numpy and JAX substrates using exoplanet-core
This is a big one and it'll be nice when it's finished!
To do:
- [ ] Implement JAX compat
- [ ] Update dev docs
- [ ] Update tutorials
- [ ] Write tests for substrates
- [ ] Write tutorial(s?) for substrates
There was some discussion on Twitter and elsewhere about the interface for this change. Here's my summary:
My original proposal:
```python
import exoplanet.pymc as xo
# and
import exoplanet.jax as xo
# and
import exoplanet.numpy as xo
```
was met with some criticism because it seems like exoplanet now exports, for example, numpy (which I certainly agree is a bit odd!). Some alternatives were proposed:
- From @barentsen: `xo.use("...")` for consistency with matplotlib syntax. This won't work for us here because the particular interface/backend chosen will depend on the user's code. For example, when using PyMC3 for inference, users will generally want to import the `theano`/`pymc` implementation, whereas users of TensorFlow Probability or numpyro will want to import the `jax` implementation. Even though the exoplanet interface (things like `xo.orbits.KeplerianOrbit`) will have the same syntax, users will really care about which backend they have imported. Furthermore, there are many good reasons why users might want to use both the numpy interface and another within the same program (in order to simulate data and then fit it, respectively).
- From @jedbrown: `from exoplanet.jax import xo` to have a less confusing export. This would certainly work, but I'm not sure that it's far better than the original or my proposal below. I would probably rephrase it as `from exoplanet.jax import exoplanet as xo`, since `xo` is more like the `np` for numpy and I don't want to start a `pl` vs `plt` flame war :D
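To make @barentsen's coexistence point concrete, here's a tiny self-contained toy (all names below are stand-ins, not real exoplanet modules): with explicit per-backend imports, two backends can live in one script, which a single global `use("...")` switch can't express cleanly.

```python
import math
import types

# Two toy "backends" exposing the same API surface, standing in for the
# numpy and jax substrates discussed above (names are illustrative).
numpy_like = types.SimpleNamespace(
    mean_anomaly=lambda t, period: 2.0 * math.pi * t / period
)
jax_like = types.SimpleNamespace(
    mean_anomaly=lambda t, period: 2.0 * math.pi * t / period
)

# Simulate with one backend, fit with the other -- a global
# xo.use("numpy") would have to be flipped mid-script to do this.
xo_sim = numpy_like  # stands in for `import exoplanet.numpy as xo_sim`
xo_fit = jax_like    # stands in for `import exoplanet.jax as xo_fit`
print(xo_sim.mean_anomaly(1.0, 10.0), xo_fit.mean_anomaly(1.0, 10.0))
# → 0.6283185307179586 0.6283185307179586
```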
Some other options include:
- Having completely separate packages `exoplanet_jax`, `exoplanet_pymc`, etc., but that adds some maintenance overhead and I'm still not convinced that `import exoplanet_jax as xo` is much less confusing than `import exoplanet.jax as xo`.
- Automatically detecting the context in which the library is being used. This would be slick, but it seems hard to do properly and it might be tricky to support multiple backends within the same script. I think explicit is better.
Finally, there were also some words of warning from @twiecki that supporting multiple backends might not be worth it, but this isn't quite the whole story here because of the general design of exoplanet. The key is that exoplanet is not a high-level interface for doing inference. Instead, exoplanet provides the building blocks to construct probabilistic models inside of higher-level frameworks. At its core, it's really just a couple of custom C++ ops and backpropagation rules that evaluate exoplanet/astrophysics-specific models. While I agree that it adds some maintenance and contribution overhead to provide interfaces to these ops that support multiple frameworks, I think it's worth it! It's not obvious which inference library (PyMC 3+, TensorFlow Probability, numpyro, emcee, dynesty, ...) is best suited for all exoplanet inference, and if a small amount of code is sufficient to expose support for all of these frameworks then I'm all for it.
It's also worth noting that there are plans to migrate PyMC to a JAX backend, which means that it was going to be necessary to implement JAX-compatible XLA ops for all of the exoplanet code anyway.
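To make the "custom op + backpropagation rule" pattern concrete, here's a tiny self-contained sketch (this is a toy, not exoplanet's actual solver): solve Kepler's equation M = E - e·sin(E) for the eccentric anomaly E with Newton's method, then get dE/dM analytically by implicit differentiation rather than differentiating through the iteration.

```python
import math

def kepler_solve(M, e, tol=1e-12, max_iter=50):
    # Newton iteration for E in M = E - e*sin(E).
    E = M  # reasonable starting guess for modest eccentricities
    for _ in range(max_iter):
        f = E - e * math.sin(E) - M
        if abs(f) < tol:
            break
        E -= f / (1.0 - e * math.cos(E))
    return E

def kepler_grad(E, e):
    # Implicitly differentiating M = E - e*sin(E) gives
    # dE/dM = 1 / (1 - e*cos(E)) -- no need to backprop through the loop.
    return 1.0 / (1.0 - e * math.cos(E))

E = kepler_solve(1.0, 0.3)
print(E, kepler_grad(E, 0.3))
```

In a JAX substrate, this is the kind of pairing you'd register with something like `jax.custom_jvp`/`jax.custom_vjp` so that the solver stays opaque to the autodiff machinery.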
My proposed API
I think that, in the short term, PyMC (perhaps implicitly with JAX used behind the scenes) will still be the primary interface, so I think that supporting import exoplanet as xo as an alias for the pymc interface would be good. Then I propose to expose the other interfaces using either the same syntax as TensorFlow Probability:
```python
import exoplanet.substrates.jax as xo
```
or the word `interfaces`:
```python
import exoplanet.interfaces.jax as xo
```
since the goal is not quite the same as TFP's substrates. This still has the issue that it's strange to export modules called numpy or jax, but it does have a precedent in a similar domain. If folks have thoughts about this, I'd always love to hear them and apologies in advance if I'm a little stubborn and over-committed to what I've done so far.
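For a concrete picture of what a substrates-style layout buys you, here's a minimal runnable sketch of the pattern: one shared implementation, parameterized by a backend's ops namespace. The names and structure here are illustrative, not exoplanet's actual internals.

```python
import types

def _build_substrate(ops):
    # The shared "physics" is written once against an ops namespace;
    # each substrate is just this code bound to a different backend.
    def phase(t, period, t0=0.0):
        return ops.mod(t - t0, period) / period
    return types.SimpleNamespace(phase=phase)

class _PyOps:
    # Stand-in backend; a real layout would bind numpy, jax.numpy,
    # or theano.tensor here instead.
    @staticmethod
    def mod(a, b):
        return a % b

substrates = types.SimpleNamespace(
    numpy=_build_substrate(_PyOps),
    jax=_build_substrate(_PyOps),
)

xo = substrates.jax  # plays the role of `import exoplanet.substrates.jax as xo`
print(xo.phase(12.5, 10.0))  # → 0.25
```

The appeal of this shape is that adding a backend is mostly a matter of supplying a new ops module, not reimplementing the API.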
I don't see import numpy as np as a pattern that needs repeating. It's used because import np would be claiming the non-descript np in the global namespace and yet people get tired of writing numpy.array and numpy.exp. You're already under the exoplanet namespace so you don't have to worry about collisions.
Re your "other option 2": do these interfaces only need to be accessed inside a with lib.Model() as model: block, or do you need a compatible set outside the model block? What about this, which is very explicit and self-contained:
```python
with lib.Model() as model:
    xo = exoplanet.api(model)
    ...
```
You could maybe offer a convenience wrapper to save a line:

```python
with exoplanet.Model(lib) as (xo, model):
    ...
```
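A minimal runnable sketch of how that wrapper might work, with a toy backend library standing in for something like PyMC (all names here are hypothetical):

```python
import contextlib
import types

def api(model):
    # Return an API namespace bound to whatever backend `model` came from.
    return types.SimpleNamespace(backend=model.backend)

@contextlib.contextmanager
def Model(lib):
    # Enter the backend library's own Model context and yield both the
    # backend-bound API and the model, as in the proposal above.
    with lib.Model() as model:
        yield api(model), model

# A fake backend library playing the role of `lib`:
class FakeModel:
    backend = "fake"
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False

fake_lib = types.SimpleNamespace(Model=FakeModel)

with Model(fake_lib) as (xo, model):
    print(xo.backend)  # → fake
```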
Looks like a great proposal, but I guess I would still ask what the benefit to the user is, or whether you've gotten requests from users for this. Bambi supported Stan and PyMC3 because the authors thought it would be good to let the user switch between them, but it really made no difference: you still specified your model in the same way and got the same answers. A lot of developer and code complexity was spent on it that could have been spent on improving UX, adding features, documentation, etc. There were now two areas of bugs and things to support, and all new features needed to be implemented for both backends, so dev velocity was severely negatively affected. So while it might be cool, I would still think carefully about the cost-benefit tradeoff.
@jedbrown: Good point about namespaces! I'll think a little more about that part. For PyMC, your proposed model context interface would work, but it's definitely trickier in other cases. For example, numpyro and emcee, two use cases that I want to support, both use different model syntax. But the exoplanet.api(...) syntax (perhaps with something else, like a string, in the function call) is definitely something to consider - thanks!
@twiecki: I totally agree that it's not obvious and that it's important to do the cost-benefit analysis! I do get lots of requests for the numpy API because there are lots of use cases where the Theano dependency and model compilation overhead are not beneficial. In fact, I'd already implemented a subset of the API bypassing Theano/PyMC entirely. As for JAX, that's just for me so far, but (in my opinion) I'm my most important customer so I think it's worth it for now :D. I think I'm going to push forward with this while making sure that I market the PyMC interface as the stable one. As I said, the JAX ops are necessary (and I'm already getting some nice benefits out of them, more soon!) but a full API implementation is just the cherry on top so I agree that I should be careful about promising too much!
Thanks, both, for the feedback!
@dfm Sounds good :+1:, excited for the JAX backend!