
Array dispatching with __array_ufunc__ in JAX

Open · gautierronan opened this issue · 8 comments

Is it possible to dispatch JAX functions to custom array-like classes?

For instance, the class below represents a matrix whose only non-zero elements lie on the main diagonal, and only those diagonal elements are stored. I would like exp to use my custom _exp method, which only needs to compute the exponential of the stored diagonal elements.

I can achieve this in NumPy using __array_ufunc__. Is there any equivalent way in JAX?

import jax.numpy as jnp
import numpy as np

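class DiagonalArray:
    """Square matrix whose only non-zero entries lie on the main diagonal."""

    def __init__(self, diagonal):
        # Only the diagonal entries are stored.
        self.diagonal = jnp.asarray(diagonal)

    def __jax_array__(self):
        # Used by JAX to convert this object into a dense array.
        return jnp.diag(self.diagonal)

    def _exp(self):
        print('_exp called')
        # Exponentiate the stored diagonal only.
        return DiagonalArray(jnp.exp(self.diagonal))

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy invokes this hook for its ufuncs; handle only a direct
        # call such as np.exp(x), not other ufunc methods.
        if ufunc in (jnp.exp, np.exp) and method == '__call__':
            return self._exp()
        return NotImplemented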

x = DiagonalArray([1, 2, 3])
np.exp(x)   # NumPy calls __array_ufunc__: prints "_exp called"
jnp.exp(x)  # JAX converts x via __jax_array__ instead: _exp is not called
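As the last line of the example shows, jnp.exp does not consult __array_ufunc__. One possible workaround, sketched here under assumptions rather than as a JAX feature, is to do the dispatch in user code: register the class as a pytree so it can pass through jax.jit, and route calls through a functools.singledispatch wrapper. The names `exp` and `to_dense` are hypothetical and not part of JAX; this minimal DiagonalArray is a stripped-down stand-in for the one above.

```python
from functools import singledispatch

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class DiagonalArray:
    """Diagonal matrix that stores only its main diagonal."""

    def __init__(self, diagonal):
        self.diagonal = jnp.asarray(diagonal)

    def to_dense(self):
        # Hypothetical helper: materialize the full square matrix.
        return jnp.diag(self.diagonal)

    # Pytree protocol: the diagonal is the only leaf, so instances can be
    # passed into jax.jit functions and mapped over with tree_map.
    def tree_flatten(self):
        return (self.diagonal,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so no conversion runs on tracers or placeholders.
        obj = object.__new__(cls)
        obj.diagonal = children[0]
        return obj


# Hypothetical user-level dispatcher: ordinary arrays fall through to jnp.exp.
@singledispatch
def exp(x):
    return jnp.exp(x)


@exp.register(DiagonalArray)
def _(x):
    # Exponentiate only the stored diagonal entries.
    return DiagonalArray(jnp.exp(x.diagonal))


@jax.jit
def f(x):
    # Dispatch happens at trace time, on the Python type of `x`.
    return exp(x)
```

Because dispatch is resolved while tracing, `f(DiagonalArray([1.0, 2.0, 3.0]))` returns a DiagonalArray without ever building the dense matrix. The catch is that callers must use `exp` instead of jnp.exp; for element-wise functions, `jax.tree_util.tree_map(jnp.exp, x)` gives the same result through the pytree registration alone.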

gautierronan, Jun 25 '24 15:06