Add support to AutoGraph for

Open • Spencer-Comin opened this issue 1 year ago • 1 comment

Context: https://github.com/PennyLaneAI/catalyst/pull/717 added support for converting in-place array updates (arr[i] = x) into the equivalent JAX-traceable code (arr.at[i].set(x)). This change extends that support to operator-assignment (augmented assignment) array updates such as arr[i] += x.

Description of the Change:

  • Add a new AutoGraph converter that maps AugAssign AST nodes whose target is a single-index subscript to calls to update_item_with_{add|sub|mult|div|pow}
  • Implement the update_item_with_{add|sub|mult|div|pow} helpers, which dispatch to the corresponding jax.numpy.ndarray.at methods for JAX arrays and fall back to the normal Python operator assignment otherwise
  • Override transform_ast in CatalystTransformer to invoke the new converter
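The converter step in the first bullet can be sketched with the standard library's `ast.NodeTransformer` (Python 3.9+), without using AutoGraph's own converter base classes. This is a minimal illustration under stated assumptions, not the PR's code. It assumes the rewrite has the shape `arr = update_item_with_<op>(arr, index, value)`; the argument order is not taken from the PR. It also reads "single-index subscript" as a plain name indexed by one expression that is neither a slice nor a tuple. `AugAssignSubscriptConverter` and `convert` are hypothetical names.

```python
import ast

# AST operator types handled by the converter, mapped to helper-name suffixes.
_OP_SUFFIX = {ast.Add: "add", ast.Sub: "sub", ast.Mult: "mult", ast.Div: "div", ast.Pow: "pow"}


class AugAssignSubscriptConverter(ast.NodeTransformer):
    """Hypothetical converter: rewrite ``name[index] <op>= value`` as
    ``name = update_item_with_<op>(name, index, value)``."""

    def visit_AugAssign(self, node):
        self.generic_visit(node)
        target = node.target
        suffix = _OP_SUFFIX.get(type(node.op))
        # Assumed meaning of "single-index subscript": a plain name indexed by
        # one expression that is neither a slice nor a tuple of indices.
        is_single_index = (
            isinstance(target, ast.Subscript)
            and isinstance(target.value, ast.Name)
            and not isinstance(target.slice, (ast.Slice, ast.Tuple))
        )
        if suffix is None or not is_single_index:
            # Leave every other augmented assignment unchanged.
            return node
        name = target.value.id
        call = ast.Call(
            func=ast.Name(id=f"update_item_with_{suffix}", ctx=ast.Load()),
            args=[ast.Name(id=name, ctx=ast.Load()), target.slice, node.value],
            keywords=[],
        )
        # Reassign the result: JAX arrays are immutable, so the helper returns a new array.
        new_node = ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=call)
        return ast.copy_location(new_node, node)


def convert(source: str) -> str:
    """Parse ``source``, apply the converter and return the rewritten code."""
    tree = AugAssignSubscriptConverter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

For example, `convert("arr[i] += x")` gives `arr = update_item_with_add(arr, i, x)`, while `arr[1:3] += x`, `arr[i][j] *= 3` and `arr[i] //= 2` come back unchanged. In Catalyst the rewrite is hooked in through `transform_ast` on `CatalystTransformer`; the sketch only shows the source-to-source transformation.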

Benefits: In AutoGraph-converted code we can write arr[i] += x instead of arr = arr.at[i].add(x).
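The runtime side of `arr[i] += x` can be sketched as a hypothetical, stdlib-only version of the update_item_with_{add|sub|mult|div|pow} helpers. To avoid depending on JAX, it treats any object with an `.at` attribute as a JAX array. The real helpers likely check for JAX array or tracer types instead. Anything else falls back to the ordinary in-place Python operator. The `(target, index, value)` signature is an assumption. Subtraction goes through `.at[index].add(-value)`, so the JAX path relies only on `add`, `multiply`, `divide` and `power`.

```python
import operator


def _update_item(target, index, value, jax_update, py_op):
    """Apply ``target[index] <op>= value`` and return the updated container."""
    # Assumption: an object exposing ``.at`` is treated as an immutable JAX array.
    if hasattr(target, "at"):
        # JAX path: .at[index].<method>(value) returns a new array; target is unchanged.
        return jax_update(target.at[index], value)
    # Python path: same semantics as ``target[index] <op>= value`` (mutates target).
    target[index] = py_op(target[index], value)
    return target


def update_item_with_add(target, index, value):
    return _update_item(target, index, value, lambda ref, v: ref.add(v), operator.iadd)


def update_item_with_sub(target, index, value):
    # The JAX path adds the negated value, so only ``.at[...].add`` is required.
    return _update_item(target, index, value, lambda ref, v: ref.add(-v), operator.isub)


def update_item_with_mult(target, index, value):
    return _update_item(target, index, value, lambda ref, v: ref.multiply(v), operator.imul)


def update_item_with_div(target, index, value):
    return _update_item(target, index, value, lambda ref, v: ref.divide(v), operator.itruediv)


def update_item_with_pow(target, index, value):
    return _update_item(target, index, value, lambda ref, v: ref.power(v), operator.ipow)
```

On a Python list, `update_item_with_add([1.0, 2.0, 3.0], 1, 10.0)` mutates the list and returns it as `[1.0, 12.0, 3.0]`. On a JAX array the helper returns a new array instead, which is why the converted code reassigns the result to `arr`.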

Possible Drawbacks: It would be cleaner for the new converter to live in the DiastaticMalt project rather than in Catalyst.

Related GitHub Issues: https://github.com/PennyLaneAI/catalyst/issues/757

Based on the solution presented in this PR: https://github.com/PennyLaneAI/catalyst/pull/717

Spencer-Comin, May 28 '24 03:05