Add JAX Op for wrapping user-defined JAX functions
@dfm has built this cool wrapper that allows JAX functions to be turned into Aesara Ops easily: https://gist.github.com/dfm/a2db466f46ab931947882b08b2f21558
This would be a great feature to include as PyMC3 users could add arbitrary deterministics coded in JAX to their PyMC3 model.
Proposing to add this to the library.
It might be better to just add something like this to the docs rather than trying to support a feature since I expect it would be hard to do shape and dtype inference in a general non-kludgey way. Also: it would probably be good to add a grad method using the vjp, but I didn't work that through yet.
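For reference, here is a minimal sketch of the kind of wrapper being discussed (not the Gist itself): a hand-written Op whose perform calls a user-defined JAX function eagerly and whose grad routes through a second Op built on jax.vjp. The Op and function names are made up, and the output shape/dtype are hard-coded, which is exactly the part that is hard to generalize.

```python
import jax
import jax.numpy as jnp
import numpy as np

import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op


def jax_logsumexp(x):
    # The user-defined JAX function we want to call from an Aesara graph.
    return jnp.log(jnp.sum(jnp.exp(x)))


class LogSumExpOp(Op):
    def make_node(self, x):
        x = at.as_tensor_variable(x)
        # Output shape/dtype are hard-coded (a scalar with x's dtype) rather
        # than inferred; inference is the part that's hard to generalize.
        out = at.scalar(dtype=x.dtype)
        return Apply(self, [x], [out])

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        # Call the JAX function eagerly and copy the result back to NumPy.
        outputs[0][0] = np.asarray(jax_logsumexp(x), dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        (x,) = inputs
        (g_out,) = output_gradients
        return [LogSumExpVJPOp()(x, g_out)]


class LogSumExpVJPOp(Op):
    """Gradient of LogSumExpOp, computed with jax.vjp."""

    def make_node(self, x, g_out):
        x = at.as_tensor_variable(x)
        g_out = at.as_tensor_variable(g_out)
        return Apply(self, [x, g_out], [x.type()])

    def perform(self, node, inputs, outputs):
        x, g_out = inputs
        primal_out, vjp_fn = jax.vjp(jax_logsumexp, x)
        # The cotangent must match the primal output's dtype (JAX may have
        # truncated float64 inputs to float32).
        (g_x,) = vjp_fn(jnp.asarray(g_out, dtype=primal_out.dtype))
        outputs[0][0] = np.asarray(g_x, dtype=node.outputs[0].dtype)
```

With something like this, `LogSumExpOp()(at.vector("x"))` can be used as a deterministic node, and `aesara.grad` would go through the vjp-based Op.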
> This would be a great feature to include as PyMC3 users could add arbitrary deterministics coded in JAX to their PyMC3 model.
I'm not aware of any JAX issues involving PyMC3 Deterministics. As far as I can tell, they should've always been supported via the standard Aesara-to-JAX conversion process.
Also, that example in the Gist isn't a single function that converts JAX functions to Ops, so it can't really serve as a new utility function of any sort.
More importantly, this isn't the kind of approach to Op creation that we want to promote, because, as @dfm mentioned, there are a lot of important things to consider and quite a few ways in which this could go wrong.
For instance, is every JAX function completely NumPy compatible (i.e. always takes and returns NumPy-compatible arrays)? If any return nested lists, a tuple, a dict, etc., those would likely result in some very cryptic errors. Do we need to be concerned about views on data: e.g. does JAX take and/or return array views in a consistent way, or does it always make copies? These sorts of things could subtly ruin performance or break things completely.
In general, we don't want to treat JAX as a NumPy/SciPy replacement, or some other low(ish)-level numerical library. Those are the kinds of things that Ops are intended to model.
Our use of jax.numpy and the like makes this a little deceiving, but what we're really trying to do is convert Aesara graphs to JAX graphs (i.e. high-level graph to high-level graph), so these NumPy-like JAX functions are not the target; we're just using them as a convenient means of constructing the equivalent JAX graphs. With that in mind, there's really no place in Aesara for the parts of JAX that perform numeric computations.
FYI: this is also the exact same approach described here, where it was given as a means of using shared variables alongside the NumPyro JAX sampler code.
I followed up somewhat recently to clarify that the Op.perform part of this approach is unimportant/undesirable, and that, in these rare cases, it's probably better to raise a NotImplementedError so that people won't mistakenly use these Ops when they're not compiling to JAX.
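As a rough illustration of that suggestion (assuming the `jax_funcify` single-dispatch function exposed under `aesara.link.jax.dispatch`, with a hypothetical Op name and wrapped function): perform raises NotImplementedError, and the JAX implementation is supplied only through the dispatch registration, so the Op fails loudly outside of the JAX backend.

```python
import jax.numpy as jnp

import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class JAXOnlySoftplus(Op):
    """Illustrative Op that only exists for the JAX backend."""

    def make_node(self, x):
        x = at.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        # No Python/C implementation on purpose: fail loudly instead of
        # silently falling back to something slow or subtly wrong.
        raise NotImplementedError(
            "JAXOnlySoftplus can only be used when compiling to JAX."
        )


@jax_funcify.register(JAXOnlySoftplus)
def jax_funcify_JAXOnlySoftplus(op, **kwargs):
    # Return the JAX callable that the JAX linker splices into the
    # converted graph in place of this Op.
    def softplus(x):
        return jnp.logaddexp(x, 0.0)

    return softplus
```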
Overall, this is a decent way to include arbitrary JAX code in the JAX conversion process, but, without putting special attention into the Op implementation, it could result in some really poor Ops that shouldn't be used in any other scenario.
@brandonwillard I agree that we need something like this so that JAX code can interact with Theano Ops. What do you mean, however, by "special attention"? I.e., what would need to be added here to make it better?
Also, as for grads: shouldn't these come automatically from JAX's autodiff? I don't think we need our own grads for JAX anyway, so this should already work for this Op.
> ...we need something like this so that JAX code can interact with Theano Ops.
I'm saying that we might need to occasionally use this approach as a one-off to include arbitrary JAX expressions into JAX-ified Aesara graphs, but, outside of that context, we really shouldn't create Ops for JAX functions.
I think this would be a super nice feature for users who have custom JAX code and want to use it with their PyMC3 model. The same goes for custom distributions. It seems like this Op would make that pretty straightforward.
Do we still want this feature? Are there settings where we want the user to write JAX directly vs. using Aesara?
Yeah, I think this could be very useful as a user-facing feature. Alternatively, we could place it in PyMC.
We write JAX-only custom Ops all the time at labs.
Just like we originally did in PyMC for Distributions, we can always automate the creation of dispatch functions during the construction of an Op (e.g. within __new__, based on method/function names), but doing this at the type level is not the best approach (i.e. via a compilation-target-specific type like a JaxOp).
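A rough sketch of what that kind of automation could look like; this is not existing Aesara/PyMC API, and the class names are hypothetical. The base class's __new__ registers a subclass's `jax_impl` staticmethod with `jax_funcify` the first time an instance is created:

```python
import jax.numpy as jnp

import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class AutoJAXOp(Op):
    """Registers a subclass's `jax_impl` as its JAX conversion on first use."""

    _jax_registered = False

    def __new__(cls, *args, **kwargs):
        if not cls._jax_registered and "jax_impl" in cls.__dict__:
            # Register the dispatch function the first time this Op class
            # is instantiated.
            jax_funcify.register(cls)(lambda op, **kw: cls.jax_impl)
            cls._jax_registered = True
        return super().__new__(cls)


class CumProdOp(AutoJAXOp):
    @staticmethod
    def jax_impl(x):
        return jnp.cumprod(x)

    def make_node(self, x):
        x = at.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("CumProdOp is JAX-only in this sketch.")
```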
Heads up: that snippet looks easy because it misses a lot of details and functionality, especially concerning input/output types, gradients, and shape inference, as @dfm mentioned already.
For instance, it fails if you pass NumPy inputs, because it does not convert them to TensorVariables as most Ops do in make_node.
For outputs, it's important for Aesara to know whether an output dimension will be broadcastable or not, and that can't be automated either.
Then there are also float precision issues to handle, or you'll get a lot of warnings from JAX about truncated float64.
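To illustrate the first two points (with a hypothetical Op and assumed output shapes): make_node should convert raw inputs via at.as_tensor_variable and has to declare the output's dtype and broadcast pattern explicitly, since neither can be read off an arbitrary JAX callable.

```python
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.tensor.type import TensorType


class OuterOp(Op):
    """Hypothetical wrapper around a JAX outer-product function."""

    def make_node(self, x, y):
        # Convert raw NumPy arrays / Python lists into TensorVariables,
        # which the snippet in question does not do.
        x = at.as_tensor_variable(x)
        y = at.as_tensor_variable(y)
        # The output dtype and broadcast pattern must be declared here;
        # Aesara cannot infer them from an arbitrary JAX callable.
        # (Depending on the Aesara version, this may be spelled via
        # `shape=` instead of `broadcastable=`.)
        out_type = TensorType(dtype=x.dtype, broadcastable=(False, False))
        return Apply(self, [x, y], [out_type()])

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("JAX-only in this sketch.")
```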
I am not convinced we can offer something that won't be too brittle or frustrating for users. I also fear people will just abuse it and write full Aesara graphs in JAX-wrapped Ops, often unaware that Aesara provides those same functions with extras like rewrites.
Can we close this as Not Planned?