Implement `@as_jax_op` to wrap a JAX function for use in PyTensor
Description
Add a decorator that transforms a JAX function such that it can be used in PyTensor. Shape and dtype inference works automatically and input and output can be any nested python structure (e.g. Pytrees). Furthermore, using a transformed function as an argument for another transformed function should also work.
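A minimal usage sketch of what the decorator enables (the function `f` and its signature are hypothetical; the import path is the one currently used in this PR, see the discussion below):

```python
import pytensor.tensor as pt
from pytensor.link.jax.ops import as_jax_op  # current location in this PR


@as_jax_op
def f(x, y):
    # Arbitrary JAX code; inputs and outputs may be nested structures (pytrees)
    return {"sum": x + y["a"], "prod": x * y["b"]}


x = pt.vector("x")
y = {"a": pt.vector("a"), "b": pt.vector("b")}
out = f(x, y)  # a dict of PyTensor variables with inferred shapes and dtypes
```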
Related Issue
- [x] Closes #537
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
ToDos
- [ ] Implement `Op.__props__`
- [x] Let `make_node` be specified by the user, to support non-inferrable shapes - JAXOp is now directly usable by the user
- [x] Add tests for JAXOp
- [ ] Add some meaningful error messages to common runtime errors
📚 Documentation preview 📚: https://pytensor--1120.org.readthedocs.build/en/1120/
I have a question: where should I put `@as_jax_op`? Currently, it is in a new file, pytensor/link/jax/ops.py. Does that make sense? Also, how should one access it? Only by calling `pytensor.link.jax.ops.as_jax_op`? Or should it be included in an `__init__.py` such that `pytensor.as_jax_op` works?
We can put it in the `__init__` as long as imports work in a way that jax is still optional for PyTensor users (obviously, calling the decorator can raise if it's not installed, hopefully with an informative message).
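A minimal sketch of one way to keep jax optional (the wrapper function here is hypothetical, not the actual PyTensor import machinery):

```python
# pytensor/__init__.py (sketch): re-export without importing jax at import time
def as_jax_op(*args, **kwargs):
    try:
        import jax  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "as_jax_op requires jax; install it with `pip install jax`"
        ) from err
    from pytensor.link.jax.ops import as_jax_op as _as_jax_op

    return _as_jax_op(*args, **kwargs)
```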
Big level picture. What's going on with the flattening of inputs and why is it needed?
To be able to wrap JAX functions that accept pytrees as input. `pytensor_diffeqsolve = as_jax_op(diffrax.diffeqsolve)` works, and can be used in the same way as one would use the original `diffrax.diffeqsolve`.
And if I have a matrix-input function, will this work, or will it expect a vector instead?
It will work; it doesn't change PyTensor variables, so a matrix will stay a matrix. What it does is flatten the nested Python structure, e.g. `{"a": tensor_a, "b": [tensor_b, tensor_c]}` becomes `[tensor_a, tensor_b, tensor_c]` (plus a treedef object which saves the structure of the tree), where the `tensor_x` are three different tensors of potentially different shape and dtype. As PyTensor Ops accept a list of tensors as input, the flattened version can be used to define our Op. The shapes of the tensors aren't changed. This is also basically how operators in JAX are written, see the second code box in this paragraph: https://jax.readthedocs.io/en/latest/autodidax.html#pytrees-and-flattening-user-functions-inputs-and-outputs
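For concreteness, the flattening corresponds to `jax.tree_util.tree_flatten`/`tree_unflatten` (a self-contained sketch with NumPy arrays standing in for the tensors):

```python
import jax
import numpy as np

tensor_a = np.zeros((2, 3))  # leaves can have different shapes and dtypes
tensor_b = np.zeros(4)
tensor_c = np.zeros((5, 5), dtype="int64")

tree = {"a": tensor_a, "b": [tensor_b, tensor_c]}
leaves, treedef = jax.tree_util.tree_flatten(tree)
# leaves is [tensor_a, tensor_b, tensor_c]; treedef records the nesting
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
# rebuilt has the same nested structure as `tree`; leaf shapes are untouched
```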
I would begin writing an example notebook in parallel. I opened an issue here.
Regarding the functionality and whether one should remove some of it for the sake of simplicity: my goal was that `diffrax.diffeqsolve` can be easily wrapped, but I understand that this might not be the goal for inclusion in pytensor. The wrapping of `diffrax.diffeqsolve` requires three parts:
- Inferring shapes. That includes inferring the dimensionality of `(None,)` dimensions in the pytensor graph by walking up the predecessors until a non-None dimension is found in its parents, and inferring the output shapes from the jax function. I think we want to keep this, as it drastically facilitates the usage of the wrapper.
- Automatically dealing with non-numerical arguments and outputs. Some of the arguments of `diffrax.diffeqsolve` are not array-like objects, and some of the returned values are also non-array objects. The additional code for this functionality is not much, basically 4 lines: the `pt_vars`/`static_vars` partitioning with `eqx.partition` in line 99, the output partitioning in line 307, the input combination with `eqx.combine` in line 302, and the output combination in line 213 (see the sketch after this list). I would argue that this is quite useful: non-numerical arguments are quite common in jax functions (non-numerical outputs less so), and it eliminates a whole category of potential runtime errors, namely that the wrapper tries to transform either non-numerical variables to jax or non-numerical outputs to pytensor variables.
- Allowing wrapped jax functions as arguments to a wrapped function. The reason I programmed it is that if ODEs have time-dependent variables, one has to define the time-dependent variables as functions that interpolate between the variables at their definition timepoints. As this function has to be called from inside the system of differential equations defined in JAX, one cannot directly use a pytensor function; and one also cannot use a jax function if one wants to add the time-dependent variables to the pytensor graph. If we removed this functionality, the workaround for time-dependent ODEs would be to write everything (the definition of the interpolation function and the ODE solver) in one JAX function and wrap the whole function with `@as_jax_op`. This functionality does add quite a bit of complexity, namely the class `_WrappedFunc` and several `eqx.partition` and `eqx.combine` calls in the rest of the code. One could think about removing it; I don't have a strong opinion on it. I also don't think it is useful for other use cases besides differential equations.
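To illustrate point 2, this is roughly what the `eqx.partition`/`eqx.combine` split does (a standalone sketch with made-up arguments, not the PR's code):

```python
import equinox as eqx
import jax.numpy as jnp

kwargs = {"y0": jnp.ones(3), "solver": "Tsit5", "rtol": 1e-6}

# Split array-valued leaves (candidates for symbolic PyTensor inputs)
# from static, non-numerical leaves
pt_vars, static_vars = eqx.partition(kwargs, eqx.is_array)
# pt_vars:     {"y0": Array([1., 1., 1.]), "solver": None, "rtol": None}
# static_vars: {"y0": None, "solver": "Tsit5", "rtol": 1e-06}

# Later, the two halves are merged back into the original structure
recombined = eqx.combine(pt_vars, static_vars)
```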
One additional remark: removing the functionality of points 2 and 3 would also remove the additional dependency on equinox; I don't know how relevant that is for the decision. I will go through your other remarks later.
Inferring shapes. That includes inferring the dimensionality of `(None,)` dimensions in the pytensor graph by walking up the predecessors until a non-None dimension is found in its parents, and inferring the output shapes from the jax function. I think we want to keep this, as it drastically facilitates the usage of the wrapper.
There's already something like that: `infer_static_shape`, which is used in Ops like `Alloc` and `RandomVariable`, where the output types are functions of the input values, not just their types. However, note that something like `x = vector(shape=(None,))` is a fundamentally valid PyTensor type, and we shouldn't a priori prevent jaxifying Ops with these types of inputs. It's also very common: all PyMC models with dims actually look like None-shaped variables, because those are allowed to change size.
I suggested allowing the user to specify `make_node`, which is the Op API for specifying how input types translate into output types in PyTensor. The static-shape logic you're doing can be a default, but shouldn't be the only option because it's fundamentally limited.
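For reference, a user-supplied `make_node` might look roughly like this (a hedged sketch of the standard Op API; the class and shape choices are hypothetical, not this PR's code):

```python
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.type import TensorType


class MyJaxOp(Op):
    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        # The user declares how input types map to output types,
        # e.g. a matrix with the same dtype and unknown static shape
        out = TensorType(dtype=x.dtype, shape=(None, None))()
        return Apply(self, [x], [out])

    def perform(self, node, inputs, output_storage):
        raise NotImplementedError("sketch only")
```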
Automatically dealing with non-numerical arguments and outputs.
My issue with non-numerical outputs is that, from reading the tests, they seem to be arbitrarily truncated? In that test where a JAX function has a string output. PyTensor is rather flexible in what types of variables it can accommodate; for instance, we have string types implemented here: https://github.com/pymc-devs/pymc/blob/e0e751199319e68f376656e2477c1543606c49c7/pymc/pytensorf.py#L1101-L1116
PyTensor itself has sparse matrices, homogeneous lists, slices, None, scalars... As such, it seems odd to me to support some extra types only in this JAX wrapper Op helper. If those types are deemed useful enough for this wrapper to handle them, then the case would be made that we should add them as regular PyTensor types, and not special-case JAX.
I guess I'm just not clear as to what the wrapper is doing with these special inputs (I'm assuming outputs are just being ignored, as I wrote above). For the inputs, is it creating a partial function for the perform method? Then it sounds like they should also be implemented as `Op.__props__`, which is the PyTensor API for parametrizing Ops with non-symbolic inputs. PyTensor uses this for nice debugprint output and for reasoning such as: two Ops with the same props and inputs can be considered equivalent and merged.
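For illustration, `__props__` parametrization looks roughly like this (a minimal sketch with a hypothetical Op; `make_node`/`perform` omitted):

```python
from pytensor.graph.op import Op


class ScaledOp(Op):
    # Non-symbolic parameters declared as props are used for equality,
    # hashing, and nicer debugprint output
    __props__ = ("scale",)

    def __init__(self, scale):
        self.scale = scale


# Two instances with equal props compare equal and can be merged in rewrites
assert ScaledOp(2.0) == ScaledOp(2.0)
```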
Allowing wrapped jax functions as arguments to a wrapped function.
Also seems somewhat similar to PyTensor Ops with inner functions (ScalarLoop, OpFromGraph, Scan), which compile inner PyTensor functions (or dispatched equivalents on backends like JAX).
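For comparison, this is roughly how `OpFromGraph` wraps an inner PyTensor graph as a single reusable Op (a quick sketch):

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x, y = pt.vectors("x", "y")
# The inner graph is compiled as one Op; on the JAX backend it is
# dispatched to a JAX implementation like any other Op
inner_op = OpFromGraph([x, y], [x * y + pt.exp(x)])

a, b = pt.vectors("a", "b")
out = inner_op(a, b)
```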
I guess the common theme is that this PR may be reinventing several things that PyTensor already does (I could be wrong), and there may be room to reuse existing functionality, or to expand it so that it's not restricted to the JAX backend, and more specifically to this wrapper. Let me know if any of this makes sense.
I added a ToDo list in the first post, so you can check the progress. I refactored the code with the help of Cursor AI, and now JAXOp can also be called directly, which can be used to specify undetermined output shapes. I also think it would be useful to have a Zoom meeting to get a better idea of which direction to go. You could, for example, write to me via the PyMC Discourse. Mondays and Tuesdays are quite full for me, but otherwise I am generally available.
Thank you for the great work! I would love to see this feature implemented!
The reason I programmed it, is that if ODEs have time-dependent variables, one has to define the time-dependent variables as functions that interpolate between the variables between their definition timepoints [...] One could think to remove it, I don't have a strong opinion about it. I also don't think it is useful for other use cases besides differential equations.
I strongly support keeping time-dependent variables, since this is a feature that sunode does not support, to my knowledge. I'm currently exploring the use of PyMC instead of NumPyro for inference of ODEs, and having a module that translates existing ODE models directly to PyTensor would be fantastic.