Add support to AutoGraph for operator assignment array updates
Context: https://github.com/PennyLaneAI/catalyst/pull/717 added support for converting in-place array updates (`arr[i] = x`) into the equivalent JAX-traceable code (`arr.at[i].set(x)`). This change extends that support to operator assignment array updates such as `arr[i] += x`.
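As background, JAX arrays are immutable, so item assignment raises an error and updates have to go through the `.at` property, which returns a new array. The snippet below is plain JAX (no Catalyst involved) and shows the `.at` method that corresponds to each operator assignment covered by this change; expressing `-=` as adding the negated value is a choice made for this illustration.

```python
import jax.numpy as jnp

arr = jnp.array([1.0, 2.0, 3.0])

# JAX arrays are immutable: item assignment raises a TypeError.
try:
    arr[1] = 5.0
    supports_item_assignment = True
except TypeError:
    supports_item_assignment = False

# Each `.at[...]` method returns a new array and leaves `arr` untouched.
set_arr = arr.at[1].set(5.0)       # like arr[1] = 5.0
add_arr = arr.at[1].add(5.0)       # like arr[1] += 5.0
sub_arr = arr.at[1].add(-5.0)      # like arr[1] -= 5.0 (adding the negation)
mul_arr = arr.at[1].multiply(5.0)  # like arr[1] *= 5.0
div_arr = arr.at[1].divide(2.0)    # like arr[1] /= 2.0
pow_arr = arr.at[1].power(3.0)     # like arr[1] **= 3.0

print(set_arr, add_arr, sub_arr, mul_arr, div_arr, pow_arr)
```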
Description of the Change:
- Add a new AutoGraph converter to map `AugAssign` AST nodes assigning to a single index subscript to calls to `update_item_with_{add|sub|mult|div|pow}`.
- Implement `update_item_with_{add|sub|mult|div|pow}` methods that map to the corresponding `jax.numpy.ndarray.at` methods for JAX arrays and to the normal Python operator assignment otherwise.
- Overload `transform_ast` in `CatalystTransformer` to invoke the new converter.
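A minimal, self-contained sketch of the mapping described above, using Python's standard `ast` module and JAX instead of the DiastaticMalt converter infrastructure the change actually builds on. All names other than `update_item_with_*` (`_make_update`, `_OP_TO_HELPER`, `AugAssignSubscriptRewriter`) are hypothetical, and the helper bodies are assumptions rather than Catalyst's code. The sketch also assumes the converter rebinds the subscripted name to the helper's return value (`arr = update_item_with_add(arr, i, x)`). It requires Python 3.9+ for `ast.unparse` and the simplified subscript AST.

```python
import ast
import operator

import jax
import jax.numpy as jnp


def _make_update(jax_update, py_op):
    """Build an update_item_with_* helper (hypothetical sketch, not Catalyst's code)."""

    def update(array, index, value):
        if isinstance(array, jax.Array):  # also True for tracers under jax.jit
            # Functional update: returns a new array via the `.at` property.
            return jax_update(array.at[index], value)
        # Anything else: plain Python operator-assignment semantics.
        array[index] = py_op(array[index], value)
        return array

    return update


update_item_with_add = _make_update(lambda ref, v: ref.add(v), operator.iadd)
# Assumption: subtraction is expressed as adding the negated value.
update_item_with_sub = _make_update(lambda ref, v: ref.add(-v), operator.isub)
update_item_with_mult = _make_update(lambda ref, v: ref.multiply(v), operator.imul)
update_item_with_div = _make_update(lambda ref, v: ref.divide(v), operator.itruediv)
update_item_with_pow = _make_update(lambda ref, v: ref.power(v), operator.ipow)

_OP_TO_HELPER = {
    ast.Add: "update_item_with_add",
    ast.Sub: "update_item_with_sub",
    ast.Mult: "update_item_with_mult",
    ast.Div: "update_item_with_div",
    ast.Pow: "update_item_with_pow",
}


class AugAssignSubscriptRewriter(ast.NodeTransformer):
    """Rewrite `name[index] op= value` as `name = helper(name, index, value)`."""

    def visit_AugAssign(self, node):
        self.generic_visit(node)
        helper = _OP_TO_HELPER.get(type(node.op))
        target = node.target
        if (
            helper is None
            or not isinstance(target, ast.Subscript)
            or not isinstance(target.value, ast.Name)
            # Slices cannot be passed as call arguments, so leave them untouched.
            or any(isinstance(n, ast.Slice) for n in ast.walk(target.slice))
        ):
            return node
        call = ast.Call(
            func=ast.Name(id=helper, ctx=ast.Load()),
            args=[ast.Name(id=target.value.id, ctx=ast.Load()), target.slice, node.value],
            keywords=[],
        )
        assign = ast.Assign(targets=[ast.Name(id=target.value.id, ctx=ast.Store())], value=call)
        return ast.copy_location(assign, node)


# Demo: rewrite a small function, then run it on a JAX array under jit and on a list.
SOURCE = """
def f(arr, i, x):
    arr[i] += x
    arr[i] *= 2.0
    return arr
"""
tree = ast.fix_missing_locations(AugAssignSubscriptRewriter().visit(ast.parse(SOURCE)))
namespace = {k: v for k, v in globals().items() if k.startswith("update_item_with_")}
exec(compile(tree, "<rewritten>", "exec"), namespace)
f = namespace["f"]

print(jax.jit(f)(jnp.array([1.0, 2.0, 3.0]), 1, 5.0))
print(f([1.0, 2.0, 3.0], 1, 5.0))
```

Rebinding the name, rather than writing back through the subscript, is what lets one rewritten function serve both cases in this sketch: JAX returns a new array, while lists and NumPy arrays are updated in place and handed back as the same object.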
Benefits: In AutoGraph-converted code, users can write `arr[i] += x` instead of `arr.at[i].add(x)`.
Possible Drawbacks: It would be cleaner for the new converter to live in the DiastaticMalt project.
Related GitHub Issues: https://github.com/PennyLaneAI/catalyst/issues/757
Based on the solution presented in this PR: https://github.com/PennyLaneAI/catalyst/pull/717