Add JAX Op for wrapping user-defined JAX functions

Open twiecki opened this issue 3 years ago • 10 comments

@dfm has built this cool wrapper that allows JAX functions to be turned into Aesara Ops easily: https://gist.github.com/dfm/a2db466f46ab931947882b08b2f21558
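For reference, the general idea in that Gist boils down to something like the following. This is a rough sketch rather than the Gist verbatim; `WrappedJAXOp` is a made-up name, and the output-type assumption in `make_node` is a deliberate simplification:

```python
import jax
import jax.numpy as jnp
import numpy as np
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op


class WrappedJAXOp(Op):
    """Hypothetical Op that evaluates a user-supplied JAX function."""

    def __init__(self, jax_fn):
        self.jax_fn = jax.jit(jax_fn)

    def make_node(self, *inputs):
        inputs = [at.as_tensor_variable(x) for x in inputs]
        # Naive assumption: the output has the same type as the first input.
        # This is exactly the shape/dtype inference that is hard in general.
        return Apply(self, inputs, [inputs[0].type()])

    def perform(self, node, inputs, output_storage):
        result = self.jax_fn(*inputs)
        # JAX returns device arrays; convert back to NumPy for Aesara.
        output_storage[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)


# Usage: wrap the JAX function once, then apply it to Aesara variables.
x = at.vector("x")
y = WrappedJAXOp(lambda v: jnp.tanh(v) ** 2)(x)
```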

This would be a great feature to include as PyMC3 users could add arbitrary deterministics coded in JAX to their PyMC3 model.

Proposing to add this to the library.

twiecki avatar Mar 03 '21 09:03 twiecki

It might be better to just add something like this to the docs rather than trying to support it as a feature, since I expect it would be hard to do shape and dtype inference in a general, non-kludgey way. Also: it would probably be good to add a grad method using the VJP, but I haven't worked that through yet.
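For what it's worth, a grad along those lines could look roughly like this, reusing the hypothetical `WrappedJAXOp` from the sketch above and assuming a single input and a single output; `jax.vjp` handles the reverse-mode part:

```python
import jax


class WrappedJAXOpWithGrad(WrappedJAXOp):
    """Hypothetical extension adding a gradient via jax.vjp
    (assumes a single input and a single output)."""

    def grad(self, inputs, output_gradients):
        def vjp_fn(x, g):
            # jax.vjp returns the primal output and a pullback function;
            # applying the pullback to the output cotangent gives dL/dx.
            _, pullback = jax.vjp(self.jax_fn, x)
            return pullback(g)[0]

        # Wrap the VJP as a second Op so it can live in the Aesara graph.
        return [WrappedJAXOp(vjp_fn)(inputs[0], output_gradients[0])]
```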

dfm avatar Mar 03 '21 10:03 dfm

This would be a great feature to include as PyMC3 users could add arbitrary deterministics coded in JAX to their PyMC3 model.

I'm not aware of any JAX issues involving PyMC3 Deterministics. As far as I can tell, they should've always been supported via the standard Aesara-to-JAX conversion process.

Also, that example in the Gist isn't a single function that converts JAX functions to Ops, so it can't really serve as a new utility function of any sort.

More importantly, this isn't the kind of approach to Op creation that we want to promote, because—as @dfm mentioned—there are a lot of important things to consider and quite a few ways in which this could go wrong.

For instance, is every JAX function completely NumPy compatible (i.e. always takes and returns NumPy-compatible arrays)? If any return nested lists, a tuple, dict, etc., those would likely result in some very cryptic errors. Do we need to be concerned about views on data: e.g. does JAX take and/or return array views in a consistent way, or does it always make copies? These sorts of things could subtly ruin performance or break things completely.

In general, we don't want to treat JAX as a NumPy/SciPy replacement, or some other low(ish)-level numerical library. Those are the kinds of things that Ops are intended to model.

Our use of jax.numpy and the like makes this a little deceiving, but what we're really trying to do is convert Aesara graphs to JAX graphs (i.e. high-level graph to high-level graph), so these NumPy-like JAX functions are not the target; we're just using them as a convenient means of constructing the equivalent JAX graphs. With that in mind, there's really no place in Aesara for the parts of JAX that perform numeric computations.

brandonwillard avatar Mar 04 '21 01:03 brandonwillard

FYI: this is also the exact same approach described here, where it was given as a means of using shared variables alongside the NumPyro JAX sampler code.

I followed up somewhat recently to clarify that the Op.perform part of this approach is unimportant/undesirable, and that—in these rare cases—it's probably better to raise a NotImplementedError so that people won't mistakenly use these Ops when they're not compiling to JAX.
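In other words, something like the following sketch, where the Op only exists to be replaced during JAX compilation. `CustomJAXExpr` and its JAX body are made up for illustration, and the `jax_funcify` import path may differ between Aesara versions:

```python
import jax.numpy as jnp
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class CustomJAXExpr(Op):
    """Hypothetical Op that is only usable through the JAX backend."""

    def make_node(self, x):
        x = at.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        # Deliberately unusable outside JAX compilation, so the Op is not
        # mistakenly evaluated with the default C/Python linkers.
        raise NotImplementedError("CustomJAXExpr is only available in JAX mode")


@jax_funcify.register(CustomJAXExpr)
def jax_funcify_CustomJAXExpr(op, **kwargs):
    # The callable returned here is spliced into the JAX graph in place
    # of the Op during the Aesara-to-JAX conversion.
    def custom_expr(x):
        return jnp.logaddexp(x, 0.0)

    return custom_expr
```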

Overall, this is a decent way to include arbitrary JAX code into the JAX conversion process, but, without putting special attention into the Op implementation, it could result in some really poor Ops that shouldn't be used in any other scenario.

brandonwillard avatar Mar 04 '21 02:03 brandonwillard

@brandonwillard I agree that we need something like this so that JAX code can interact with Theano Ops. What do you mean, though, by special attention? I.e., what would need to be added here to make it better?

Also, as for grads: shouldn't these come automatically from JAX's autodiff? I don't think we need our own grads for JAX anyway, so this should already work for this Op.

twiecki avatar Mar 04 '21 08:03 twiecki

...we need something like this so that JAX code can interact with Theano Ops.

I'm saying that we might need to occasionally use this approach as a one-off to include arbitrary JAX expressions into JAX-ified Aesara graphs, but, outside of that context, we really shouldn't create Ops for JAX functions.

brandonwillard avatar Mar 07 '21 01:03 brandonwillard

I think this would be a super nice feature for users who have custom JAX code and want to use it with their PyMC3 model. The same goes for custom distributions. It seems like this Op would make that pretty straightforward.

twiecki avatar Mar 07 '21 18:03 twiecki

Do we still want this feature? Are there settings where we want the user writing JAX directly vs using aesara?

zaxtax avatar Mar 30 '22 16:03 zaxtax

Yeah, I think as a user-facing feature this could be very useful. Alternatively, we could place it in PyMC.

We write JAX-only custom Ops all the time at labs.

twiecki avatar Mar 30 '22 18:03 twiecki

Just like we originally did in PyMC for Distributions, we can always automate the creation of dispatch functions during the construction of an Op (e.g. within __new__, based on method/function names), but doing this at the type level (i.e. via a compilation-target-specific type like a JaxOp) is not the best approach.
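One possible reading of that suggestion, purely as a sketch and not an existing API (`AutoDispatchOp` and `jax_impl` are made-up names, and the dispatch import path may vary between versions):

```python
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class AutoDispatchOp(Op):
    """Hypothetical base class: subclasses that define a `jax_impl` method
    get a JAX dispatch function registered automatically on first use."""

    def __new__(cls, *args, **kwargs):
        if "jax_impl" in vars(cls) and not vars(cls).get("_jax_registered", False):

            @jax_funcify.register(cls)
            def _jax_funcify_auto(op, **kwargs):
                # Hand the subclass's JAX implementation to the converter.
                return op.jax_impl

            cls._jax_registered = True
        return super().__new__(cls)
```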

brandonwillard avatar Mar 30 '22 18:03 brandonwillard

Heads up: that snippet only looks easy because it skips a lot of details and functionality.

Especially concerning input/output types, gradients, and shape inference, as @dfm already mentioned.

For instance, it fails if you pass NumPy inputs, because it does not convert them to TensorVariables the way most Ops do in make_node.

For outputs, it's important for Aesara to know whether an output dimension will be broadcastable or not, and that can't be automated either.
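Concretely, a more careful make_node has to convert the inputs and spell out the output type by hand, e.g. as in this sketch (`CarefulJAXOp` is a made-up name, the output dtype/shape are arbitrary, and the exact TensorType signature depends on the Aesara version):

```python
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op


class CarefulJAXOp(Op):
    """Hypothetical wrapper with the input/output handling made explicit."""

    def make_node(self, x, y):
        # Convert plain NumPy/Python inputs into TensorVariables,
        # as most Ops do.
        x = at.as_tensor_variable(x)
        y = at.as_tensor_variable(y)
        # The output dtype and (broadcastable/static) shape must be spelled
        # out; they can't be inferred from the wrapped JAX function in general.
        out_type = at.TensorType(dtype="float64", shape=(None, None))
        return Apply(self, [x, y], [out_type()])
```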

Then there are also issues of float precision to deal with, or you'll get a lot of warnings from JAX about truncated float64.
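(For reference, 64-bit support in JAX is a global config switch, while Aesara's floatX defaults to float64, so one of the two typically has to be changed, e.g.:)

```python
from jax import config

# JAX computes in float32 by default and warns when float64 inputs are
# truncated; this enables 64-bit mode to match Aesara's default floatX.
config.update("jax_enable_x64", True)
```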

I am not convinced we can offer something that won't be too brittle or frustrating for users. I also fear people will just abuse it and write full Aesara graphs inside JAX-wrapped Ops, often unaware that Aesara provides those same functions with extras like rewrites.

ricardoV94 avatar Mar 30 '22 18:03 ricardoV94

Can we close this as Not Planned?

rlouf avatar Sep 15 '22 08:09 rlouf