Automatic conversion of operator-based array updates
Context
Catalyst uses a source-to-source transformation package called AutoGraph, which allows users to write regular Python code that is automagically transformed into JAX-style traceable code. For example, with the `autograph=True` option, Catalyst can compile the following `for` loop so that it executes at run time. Without AutoGraph, we would need to explicitly use the functional form (`catalyst.for_loop`):
```python
from catalyst import *

@qjit(autograph=True)
def f(n: int):
    for i in range(n):
        debug.print(i)

f(5)
```
Catalyst recently gained support for automatically converting "in-place" array updates to their JAX-compatible form. In NumPy you would normally write `arr[i] = 5`. JAX does not allow arrays to be mutated, so it requires `arr = arr.at[i].set(5)` instead. You can read more about this peculiarity here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates
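The following minimal comparison contrasts the two update styles. It assumes NumPy and JAX are installed and runs eagerly (outside of `qjit`), so it only illustrates the array semantics and not Catalyst's conversion:

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays are mutable, so indexed assignment updates the array in place.
a = np.ones(5)
a[2] = 5

# JAX arrays are immutable: the same statement raises a TypeError.
b = jnp.ones(5)
try:
    b[2] = 5
    raised = False
except TypeError:
    raised = True

# The functional form returns an updated copy and leaves `b` untouched.
c = b.at[2].set(5)

print(a, b, c, raised)
```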
The feature currently only supports setting a single array element. We would like to extend it to "in-place" operators such as `+=`, `-=`, `*=` and `/=`. JAX documents the complete list of equivalent functions available through the `.at` property here: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
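For reference, the snippet below shows several in-place operators together with their `.at` counterparts. It is plain eager JAX and assumes only that `jax` is installed. The method names (`add`, `multiply`, `divide`, `power`) come from the JAX page linked above. `-=` is written with `add` and a negated operand:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 3.0, 5.0])

# Each call is the functional equivalent of the NumPy-style statement in the comment.
# Every call returns a new array; `x` itself is never modified.
added = x.at[1].add(2.0)             # x[1] += 2.0  -> [1., 5., 5.]
subtracted = x.at[1].add(-2.0)       # x[1] -= 2.0  -> [1., 1., 5.]
multiplied = x.at[1].multiply(2.0)   # x[1] *= 2.0  -> [1., 6., 5.]
divided = x.at[1].divide(2.0)        # x[1] /= 2.0  -> [1., 1.5, 5.]
powered = x.at[1].power(2.0)         # x[1] **= 2.0
```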
Goal
We would like to be able to run the following program in Catalyst. It currently raises an error because AutoGraph does not convert the augmented assignment `x[i] *= 3` at all:
```python
import jax.numpy as jnp
from catalyst import qjit

@qjit(autograph=True)
def f(x, i):
    x[i] *= 3
    return x

f(jnp.ones(5), 2)
```

```
TypeError: '<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method
```

We would like the output to be:

```
Array([1., 1., 3., 1., 1.], dtype=float64)
```
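Until the conversion exists, you can reproduce the expected result by writing the functional form by hand. To keep the sketch independent of Catalyst, it uses plain `jax.jit` rather than `qjit`. It covers both a dynamic index (a traced function argument) and a static one (a constant). JAX defaults to 32-bit floats unless 64-bit mode (`jax_enable_x64`) is enabled, so the dtype may differ from the `float64` shown above:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_dynamic(x, i):
    # Hand-written equivalent of `x[i] *= 3`, where `i` is a traced (dynamic) index.
    return x.at[i].multiply(3)

@jax.jit
def f_static(x):
    # The same update with a constant (static) index.
    return x.at[2].multiply(3)

print(f_dynamic(jnp.ones(5), 2))
print(f_static(jnp.ones(5)))
```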
Requirements:
- Enable conversion of at least one in-place operator, for example `*=`.
- The feature only needs to work with a single index, but it must support both a static and a dynamic index (think constant value vs. function argument variable).
- If the assigned-to object is not a JAX array, the implementation should fall back to the regular Python operator.
- Implement unit tests for this feature.
- Update relevant AutoGraph documentation in Catalyst.
Technical Details
A good starting point is to look at the implementation of the existing feature. Catalyst uses the AutoGraph implementation from the `malt` package (the diastatic-malt repository), which provides a set of AST transformers and operator implementations. The AST transformers replace pieces of Python code, like a `for` loop, with a call to a function/operator that can easily be overloaded. By default, these operators simply execute the Python version of the construct. Catalyst then provides its own operator implementations that execute JAX-traceable code instead.
The existing in-place assignment implementation can be found here:
https://github.com/PennyLaneAI/catalyst/blob/5cc12399d30bd42ae15b11ab12878030ad5fb9ee/frontend/catalyst/autograph/ag_primitives.py#L578
It relies on the `SliceTransformer` to replace the indexed assignment with the `set_item` operator.
Unfortunately, the existing converter patterns cannot convert statements like `arr[i] **= 2` via AutoGraph. The experimental [lists](https://github.com/PennyLaneAI/diastatic-malt/blob/8a09938e9054e352b749d20522bc9aeb173307fd/malt/core/converter.py#L113) feature converts `x[5] = 3` into `x = ag__.set_item(ag__.ld(x), 5, 3)`. However, there is currently no equivalent conversion for `x[5] += 3`, which remains as `ag__.ld(x)[5] += 3`.
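The difference is visible directly in the Python AST. The snippet below uses the standard-library `ast` module. The `gast` module used by AutoGraph exposes the same node names for these statements. A plain indexed assignment is an `Assign` node with a `Subscript` target, whereas `x[5] += 3` is an `AugAssign` node. This is presumably why the existing list conversion never handles it:

```python
import ast

assign = ast.parse("x[5] = 3").body[0]
aug_assign = ast.parse("x[5] += 3").body[0]

# A plain indexed assignment: Assign(targets=[Subscript(...)], value=...)
print(type(assign).__name__, type(assign.targets[0]).__name__)
# -> Assign Subscript

# An augmented assignment: AugAssign(target=Subscript(...), op=Add(), value=...)
print(type(aug_assign).__name__, type(aug_assign.target).__name__, type(aug_assign.op).__name__)
# -> AugAssign Subscript Add
```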
An implementation would thus require a new AST converter pattern for each of the desired in-place operators.
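Below is a minimal sketch of such a converter. For self-containment, it is written against the standard-library `ast.NodeTransformer`. A Catalyst version would presumably operate on `gast` nodes and plug into the AutoGraph converter machinery instead. The operator name `update_item_with_op` and the string tags such as `"mult"` are hypothetical conventions, not an existing API. The transformer rewrites `x[i] op= y` into `x = update_item_with_op(x, i, y, "op")` and leaves every other augmented assignment untouched. `ast.unparse` requires Python 3.9+:

```python
import ast

# Hypothetical tags passed to the (equally hypothetical) runtime operator.
OP_TAGS = {
    ast.Add: "add",
    ast.Sub: "sub",
    ast.Mult: "mult",
    ast.Div: "div",
    ast.Pow: "pow",
}


class InPlaceItemUpdateTransformer(ast.NodeTransformer):
    """Rewrite `x[i] op= y` into `x = update_item_with_op(x, i, y, "op")`."""

    def visit_AugAssign(self, node):
        self.generic_visit(node)
        target = node.target
        tag = OP_TAGS.get(type(node.op))
        # Only handle a single index on a plain variable, e.g. `x[i]` or `x[2]`;
        # unsupported operators, slices, tuple indices and attribute targets are kept.
        if (
            tag is None
            or not isinstance(target, ast.Subscript)
            or not isinstance(target.value, ast.Name)
            or isinstance(target.slice, (ast.Slice, ast.Tuple))
        ):
            return node

        name = target.value.id
        call = ast.Call(
            func=ast.Name(id="update_item_with_op", ctx=ast.Load()),
            args=[
                ast.Name(id=name, ctx=ast.Load()),
                target.slice,
                node.value,
                ast.Constant(value=tag),
            ],
            keywords=[],
        )
        # Assign the result back, so a JAX branch can return a fresh array.
        new_node = ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=call)
        return ast.copy_location(new_node, node)


def convert(source):
    tree = InPlaceItemUpdateTransformer().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(convert("x[i] *= 3"))    # x = update_item_with_op(x, i, 3, 'mult')
print(convert("x[1:3] += 1"))  # x[1:3] += 1  (left unchanged)
```

Assigning the call's result back to the same name keeps both runtime paths uniform. A JAX implementation can return a new array, while a Python fallback can simply return the mutated container. Evaluation order (`x`, then `i`, then `y`) matches that of the original statement.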
A rough outline of the feature might look as follows:
- Create a new AST transformer class that converts an AST node from a statement like `x[i] *= 3` into a call to a new operator (say `update_item_with_op`).
  - To keep things simple, this new class can be placed somewhere inside the `autograph` module in Catalyst, without modifying the `malt` package.
  - This will then require overriding the inherited `transform_ast` function on the Catalyst side (in the `CatalystTransformer` class) to invoke the new transformer.
- Create an implementation of `update_item_with_op` in the Catalyst `ag_primitives.py` module.
  - The function should ensure the input is a JAX array, or else fall back to the Python version of the operator.
  - The function should dispatch the implementation based on the operator flavor (`+=`, `*=`, etc.) used by the user (not all of them need to be implemented).
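The runtime side could look roughly like the following sketch. The name `update_item_with_op`, its argument order and the string tags are hypothetical conventions, not an existing Catalyst API. The check `isinstance(target, jax.Array)` is an assumption about how a JAX array, including a tracer, would be detected. The real `ag_primitives.py` may perform this check differently:

```python
import operator

import jax
import jax.numpy as jnp

# Python in-place operators used by the fallback path.
PYTHON_OPS = {
    "add": operator.iadd,
    "sub": operator.isub,
    "mult": operator.imul,
    "div": operator.itruediv,
    "pow": operator.ipow,
}


def update_item_with_op(target, index, value, op):
    """Hypothetical equivalent of `target[index] <op>= value` that returns the result."""
    if op not in PYTHON_OPS:
        raise NotImplementedError(f"unsupported in-place operator: {op!r}")

    if isinstance(target, jax.Array):
        at = target.at[index]
        # Dispatch on the operator flavor to the functional .at[] methods.
        if op == "add":
            return at.add(value)
        if op == "sub":
            return at.add(-value)  # x[i] -= v is the same as x[i] += -v
        if op == "mult":
            return at.multiply(value)
        if op == "div":
            return at.divide(value)
        return at.power(value)

    # Fallback for non-JAX containers (lists, NumPy arrays, ...): mirror what Python
    # does for `target[index] op= value`, i.e. get the item, apply the op, set it back.
    target[index] = PYTHON_OPS[op](target[index], value)
    return target


print(update_item_with_op([1, 1, 1], 1, 3, "mult"))  # [1, 3, 1]
print(update_item_with_op(jnp.ones(3), 1, 3.0, "mult"))
print(jax.jit(lambda x, i: update_item_with_op(x, i, 2.0, "pow"))(jnp.full(3, 3.0), 0))
```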
Tips:
- To inspect the AST of a function, you can use the `ast`/`gast` module in conjunction with `inspect`: `print(gast.dump(gast.parse(inspect.getsource(f)), indent=2))`. The Python debugger will also be helpful.
- Additional information on the AutoGraph system can be found in the official docs hosted by TensorFlow: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md
Installation Help
Since this issue is Python-only, the time to set up a development environment for Catalyst can be cut short by installing an editable version of the latest nightly build:
- download the latest nightly build from the `main` branch, such as https://github.com/PennyLaneAI/catalyst/actions/runs/9131705639 (see the artifacts at the bottom)
- clone the catalyst repo and check out the commit corresponding to the build, in the given example `git checkout 5cc1239`
- extract the wheel into the `frontend` directory of the repo (after extracting the wheel from the zip file), for example via `unzip PennyLane_Catalyst*.whl -d catalyst/frontend` (overwrite any existing files)
The following steps assume you are in the root of the repo:
- install the Python developer requirements with `pip install -r requirements.txt`
- create an editable installation with `make frontend`
- test your installation with `make pytest`
Refer to the Catalyst installation guide for general system and environment requirements.