
Automatic conversion of operator-based array updates

Open · dime10 opened this issue

Context

Catalyst uses a source-to-source transformation package called AutoGraph, which allows users to write regular Python code that is automagically transformed into JAX-style traceable code. For example, with the autograph=True option, Catalyst can compile the following for loop to execute at run-time, whereas without AutoGraph we would need to write the loop explicitly in its functional form (catalyst.for_loop):

from catalyst import debug, qjit

@qjit(autograph=True)
def f(n: int):
    for i in range(n):
        debug.print(i)

f(5)
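To illustrate what "functional form" means, here is a rough plain-JAX counterpart of the loop above. It uses jax.lax.fori_loop and jax.debug.print as stand-ins for Catalyst's catalyst.for_loop and debug.print; this is an analogy, not the code Catalyst actually generates. A loop-carried sum is added only so the function has a result to check, and the index is cast to int32 so the carry type stays fixed whether or not 64-bit mode is enabled.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(n):
    # Loop body: receives the iteration index and the loop-carried value.
    def body(i, total):
        jax.debug.print("{i}", i=i)
        return total + i.astype(jnp.int32)

    # fori_loop(lower, upper, body_fun, init_val); n is traced, so the bound is dynamic.
    return jax.lax.fori_loop(0, n, body, jnp.int32(0))


result = f(5)  # prints each index from 0 to 4
print(result)  # 10
```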

Recently, Catalyst got support for automatically converting "in-place" array updates to their JAX compatible form. That is, where you would normally do arr[i] = 5 in NumPy, JAX requires arr = arr.at[i].set(5) because JAX does not allow mutation of arrays. You can read more about this peculiarity here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates
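A small side-by-side sketch of the difference: NumPy mutates the existing array, while JAX's .at[i].set(...) leaves the original untouched and returns a new array.

```python
import jax.numpy as jnp
import numpy as np

np_arr = np.zeros(3)
np_arr[1] = 5.0  # NumPy: mutates the existing buffer

jax_arr = jnp.zeros(3)
# jax_arr[1] = 5.0 would raise a TypeError, because JAX arrays are immutable.
updated = jax_arr.at[1].set(5.0)  # JAX: returns a new, updated array

print(np_arr)   # [0. 5. 0.]
print(jax_arr)  # [0. 0. 0.]  (unchanged)
print(updated)  # [0. 5. 0.]
```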

The feature currently only supports setting a single array element, so we would like to improve it by also allowing "in-place" operators such as +=, -=, *=, /=, etc. JAX provides equivalent functions for these via the .at property; the complete list is available here: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
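As a quick reference, this sketch shows how the common in-place operators map onto .at methods for a single element. -= is written as a negated .add here so that it does not depend on a .subtract method being present in the installed JAX version.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)  # [0. 1. 2. 3. 4.]

added = x.at[2].add(3)            # x[2] += 3
multiplied = x.at[2].multiply(3)  # x[2] *= 3
divided = x.at[2].divide(2)       # x[2] /= 2
powered = x.at[2].power(2)        # x[2] **= 2
subtracted = x.at[2].add(-3)      # x[2] -= 3 (negated add)

print(multiplied)  # [0. 1. 6. 3. 4.]
```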

Goal

We would like to be able to run the following program in Catalyst, which currently raises an error because AutoGraph leaves the statement unconverted:

import jax.numpy as jnp
from catalyst import qjit

@qjit(autograph=True)
def f(x, i):
    x[i] *= 3
    return x

f(jnp.ones(5), 2)
TypeError: '<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method

whereas we would like the output to be:

Array([1., 1., 3., 1., 1.], dtype=float64)
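For comparison, the target conversion can be written by hand in plain JAX. This sketch uses jax.jit instead of qjit so it runs without Catalyst, and covers both a dynamic index (a traced function argument) and a static one (a constant). Plain JAX defaults to float32 unless 64-bit mode (jax_enable_x64) is enabled, so the values match the expected output above but the dtype may differ.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x, i):
    # Hand-written functional form of `x[i] *= 3` with a dynamic (traced) index.
    x = x.at[i].multiply(3)
    return x


@jax.jit
def g(x):
    # The same update with a static (constant) index.
    x = x.at[2].multiply(3)
    return x


print(f(jnp.ones(5), 2))  # [1. 1. 3. 1. 1.]
```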

Requirements:

  • Enable conversion of at least one in-place operator, for example *=.
  • The feature only needs to work with a single index, but it must support a static as well as a dynamic index (think constant value vs function argument variable).
  • If the assigned-to object is not a JAX array, the implementation should fall back to the regular Python operator.
  • Implement unit tests for this feature.
  • Update relevant AutoGraph documentation in Catalyst.

Technical Details

A good starting point is to look at the implementation of the existing feature. Catalyst uses the AutoGraph implementation from the malt repository, which provides a set of AST transformers and operator implementations. The AST transformers replace pieces of Python code, like a for loop, with a call to a function/operator that can easily be overloaded (by default, these operators simply execute the equivalent Python code). Catalyst then provides its own operator implementations that execute JAX-traceable code instead.
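A toy sketch of the overloading idea: the converted code only ever calls ag__.set_item, so swapping the object bound to ag__ changes the semantics without touching the converted code. Passing ag__ as a parameter and the two implementations below are simplifications for illustration (in malt, ag__ is injected into the function's namespace, and Catalyst's real operator works on JAX arrays rather than lists).

```python
from types import SimpleNamespace


def py_set_item(target, index, value):
    # Default semantics: ordinary, mutating Python item assignment.
    target[index] = value
    return target


def functional_set_item(target, index, value):
    # Toy "traceable" semantics: build a new container instead of mutating.
    updated = list(target)
    updated[index] = value
    return updated


def converted(x, ag__):
    # Simplified shape of AutoGraph output for `x[1] = 9` (the real output also wraps loads in ag__.ld).
    x = ag__.set_item(x, 1, 9)
    return x


python_ops = SimpleNamespace(set_item=py_set_item)
toy_ops = SimpleNamespace(set_item=functional_set_item)

data = [0, 0, 0]
print(converted(data, toy_ops), data)  # [0, 9, 0] [0, 0, 0]
```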

The existing in-place assignment implementation can be found here: https://github.com/PennyLaneAI/catalyst/blob/5cc12399d30bd42ae15b11ab12878030ad5fb9ee/frontend/catalyst/autograph/ag_primitives.py#L578. It relies on the SliceTransformer to replace the indexed assignment with the set_item operator.

Unfortunately, statements like arr[i] **= 2 cannot be converted via AutoGraph with the existing converter patterns. While the experimental lists feature (https://github.com/PennyLaneAI/diastatic-malt/blob/8a09938e9054e352b749d20522bc9aeb173307fd/malt/core/converter.py#L113) will convert x[5] = 3 into x = ag__.set_item(ag__.ld(x), 5, 3), there is currently no equivalent conversion for x[5] += 3, which remains as ag__.ld(x)[5] += 3.
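The reason is visible in the syntax tree itself: plain indexed assignment and augmented assignment are different node types, so a converter matching Assign nodes never sees the augmented form. This sketch uses the standard-library ast module; gast, which malt builds on, uses the same node names.

```python
import ast

assign = ast.parse("x[5] = 3").body[0]
aug_assign = ast.parse("x[5] += 3").body[0]

# Plain indexed assignment: an Assign whose target is a Subscript.
print(type(assign).__name__, type(assign.targets[0]).__name__)  # Assign Subscript

# Augmented assignment: a separate AugAssign node that carries the operator.
print(type(aug_assign).__name__, type(aug_assign.op).__name__)  # AugAssign Add
```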

An implementation would thus require a new AST converter pattern for each of the desired in-place operators.
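A minimal sketch of such a converter pattern, written against the standard-library ast module instead of gast so that it runs on its own (Python 3.9+ for ast.unparse and the unwrapped Subscript.slice). The class name InPlaceSubscriptTransformer, the target call ag__.update_item_with_op, and the string operator names ('mult', 'add', ...) are hypothetical choices for illustration; a real converter would also follow malt conventions such as wrapping loads in ag__.ld.

```python
import ast

# Map AST operator node types to the (hypothetical) names passed to the new operator.
_OP_NAMES = {
    ast.Add: "add",
    ast.Sub: "sub",
    ast.Mult: "mult",
    ast.Div: "div",
    ast.Pow: "pow",
}


class InPlaceSubscriptTransformer(ast.NodeTransformer):
    """Rewrite `x[i] op= v` into `x = ag__.update_item_with_op(x, i, 'op', v)`."""

    def visit_AugAssign(self, node):
        self.generic_visit(node)
        target = node.target
        op_name = _OP_NAMES.get(type(node.op))
        # Only handle the simple `name[index] op= value` form; leave everything else alone.
        if not (isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name) and op_name):
            return node
        name = target.value.id
        call = ast.Call(
            func=ast.Attribute(
                value=ast.Name(id="ag__", ctx=ast.Load()),
                attr="update_item_with_op",
                ctx=ast.Load(),
            ),
            args=[ast.Name(id=name, ctx=ast.Load()), target.slice, ast.Constant(value=op_name), node.value],
            keywords=[],
        )
        new_node = ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=call)
        return ast.copy_location(new_node, node)


tree = ast.parse("x[i] *= 3\ny += 1")
tree = ast.fix_missing_locations(InPlaceSubscriptTransformer().visit(tree))
print(ast.unparse(tree))
# x = ag__.update_item_with_op(x, i, 'mult', 3)
# y += 1
```

Passing the operator as a string keeps the generated code readable; passing an operator function instead would work equally well.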

A rough outline of the feature might look as follows:

  • Create a new AST Transformer class that converts an AST node from a statement like x[i] *= 3 into a call to a new operator (say update_item_with_op).
    • To keep things simple, this new class can be placed somewhere inside the autograph module in Catalyst, without modifying the malt package.
    • This will then require overriding the inherited transform_ast method on the Catalyst side (in the CatalystTransformer class) to invoke the new transformer.
  • Create an implementation of update_item_with_op in the Catalyst ag_primitives.py module.
    • The function should ensure the input is a JAX array, or else fall back to the Python version of the operator.
    • The function should dispatch the implementation based on the operator flavor (+=, *=, etc) used by the user (but they don't all need to be implemented).
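A minimal sketch of the operator side, assuming the call shape update_item_with_op(target, index, op, value) with string operator names such as 'mult' (both hypothetical, not an existing Catalyst API). It detects JAX arrays with isinstance(target, jax.Array), which also holds for tracers during tracing, and otherwise falls back to Python's in-place operators from the operator module. -= is expressed as a negated .at[].add so it does not rely on a .subtract method.

```python
import operator

import jax
import jax.numpy as jnp

# Python fallbacks for each supported operator name (names are an assumed convention).
_PY_OPS = {
    "add": operator.iadd,
    "sub": operator.isub,
    "mult": operator.imul,
    "div": operator.itruediv,
    "pow": operator.ipow,
}

# Functional JAX equivalents built on the .at[] index-update helpers.
_JAX_OPS = {
    "add": lambda ref, v: ref.add(v),
    "sub": lambda ref, v: ref.add(-v),  # negated add instead of .subtract
    "mult": lambda ref, v: ref.multiply(v),
    "div": lambda ref, v: ref.divide(v),
    "pow": lambda ref, v: ref.power(v),
}


def update_item_with_op(target, index, op, value):
    """Hypothetical operator implementing `target[index] <op>= value`."""
    if op not in _PY_OPS:
        raise NotImplementedError(f"unsupported in-place operator: {op!r}")
    if isinstance(target, jax.Array):
        # JAX arrays (including tracers) are immutable: return an updated copy.
        return _JAX_OPS[op](target.at[index], value)
    # Anything else: fall back to the regular Python operator, which may mutate target.
    target[index] = _PY_OPS[op](target[index], value)
    return target


print(update_item_with_op(jnp.ones(5), 2, "mult", 3))  # [1. 1. 3. 1. 1.]
```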

Tips:

  • To inspect the AST of a function, you can use the ast/gast module in conjunction with inspect: print(gast.dump(gast.parse(inspect.getsource(f)), indent=2)). The Python debugger will also be helpful.
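A self-contained variant of the first tip, using the standard-library ast module and a source string in place of inspect.getsource(f) (which needs the function to live in a source file). It dumps the tree and then locates the AugAssign node that a new converter would have to handle.

```python
import ast
import textwrap

# For a live function, inspect.getsource(f) would supply this string.
source = textwrap.dedent(
    """
    def f(x, i):
        x[i] *= 3
        return x
    """
)
tree = ast.parse(source)
print(ast.dump(tree, indent=2))

# Locate the augmented assignment that the new converter must rewrite.
aug_nodes = [node for node in ast.walk(tree) if isinstance(node, ast.AugAssign)]
print(type(aug_nodes[0].op).__name__)  # Mult
```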
  • Additional information on the AutoGraph system can be found in the official docs hosted by TensorFlow: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md

Installation Help

Since this issue is Python-only, the time to set up a development environment for Catalyst can be cut short by installing an editable version of the latest nightly build:

  • download the latest nightly build from the main branch, such as https://github.com/PennyLaneAI/catalyst/actions/runs/9131705639 (see artifacts at the bottom)
  • clone the catalyst repo and check out the commit corresponding to the build, which for the example build above is git checkout 5cc1239
  • after extracting the wheel from the downloaded zip file, unpack the wheel itself into the frontend directory of the repo, for example via unzip PennyLane_Catalyst*.whl -d catalyst/frontend (overwrite any existing files)

The following steps assume you are in the root of the repo:

  • install Python developer requirements, pip install -r requirements.txt
  • create an editable installation with make frontend
  • test your installation with make pytest

Refer to the Catalyst installation guide for general system and environment requirements.

dime10, May 21 '24 16:05