Array dispatching with __array_ufunc__ in JAX
Is it possible to dispatch JAX functions to custom array-like classes?
For instance, in the example below I have a class that represents an array whose only non-zero elements lie on the main diagonal (only those elements are stored). I would like JAX to use my custom _exp method, which only needs to compute the exponential of the diagonal elements.
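One detail worth keeping in mind: exponentiating only the stored diagonal is not the same as applying exp elementwise to the dense matrix, because exp(0) = 1 turns every off-diagonal zero into a one. For a diagonal matrix, diag(exp(d)) instead coincides with the matrix exponential. A small check, assuming a float diagonal [1, 2, 3] chosen only for illustration:

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

d = jnp.array([1.0, 2.0, 3.0])  # illustrative diagonal
dense = jnp.diag(d)

# Elementwise exp of the dense matrix: off-diagonal zeros become exp(0) == 1.
elementwise = jnp.exp(dense)

# Exponentiating only the stored diagonal keeps the off-diagonal zeros.
diag_only = jnp.diag(jnp.exp(d))

# For a diagonal matrix, this equals the matrix exponential.
matrix_exp = expm(dense)

print(elementwise)
print(diag_only)
print(jnp.allclose(diag_only, matrix_exp, rtol=1e-4, atol=1e-6))
```

So np.exp(x) with the override and a dense jnp.exp fallback can return different values, not just differ in cost.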
I can achieve this in NumPy using __array_ufunc__. Is there an equivalent mechanism in JAX?
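For reference, NumPy invokes __array_ufunc__ with the ufunc object, the name of the ufunc method being used, and the positional inputs (including the object itself). A minimal sketch; the Recorder class is hypothetical and only reports what it receives:

```python
import numpy as np


class Recorder:
    """Hypothetical class that reports how NumPy calls __array_ufunc__."""

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # ufunc:  the NumPy ufunc object, e.g. np.exp
        # method: "__call__", "reduce", "accumulate", "reduceat", "outer" or "at"
        # inputs: the positional arguments, including self
        return ufunc.__name__, method, len(inputs)


r = Recorder()
print(np.exp(r))         # ('exp', '__call__', 1)
print(np.add(r, 1))      # ('add', '__call__', 2)
print(np.add.reduce(r))  # method is 'reduce'
```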
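```python
import jax.numpy as jnp
import numpy as np


class DiagonalArray:
    """Array whose only non-zero entries lie on the main diagonal (only those are stored)."""

    def __init__(self, diagonal):
        self.diagonal = jnp.asarray(diagonal)

    def __jax_array__(self):
        # Dense representation used when this object is converted to a JAX array.
        return jnp.diag(self.diagonal)

    def _exp(self):
        # Exponentiate only the stored diagonal entries.
        print('_exp called')
        return DiagonalArray(jnp.exp(self.diagonal))

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy calls this hook for ufuncs such as np.exp; only the plain
        # call form (method == '__call__') of exp is overridden here.
        if ufunc in (jnp.exp, np.exp) and method == '__call__':
            return self._exp()
        return NotImplemented


x = DiagonalArray([1, 2, 3])
np.exp(x)   # prints "_exp called"
jnp.exp(x)  # _exp is not called
```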
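As the last line shows, jnp.exp does not consult __array_ufunc__ (it is a JAX function, not a NumPy ufunc), so there is no automatic hook to rely on. A minimal sketch of one possible workaround, assuming you are willing to call your own dispatching function instead of jnp.exp: register the class as a pytree so it can pass through jax.jit, and dispatch on the type explicitly. The exp wrapper and the todense method are hypothetical helpers for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class DiagonalArray:
    """Stores only the main diagonal of a square matrix."""

    def __init__(self, diagonal):
        self.diagonal = jnp.asarray(diagonal)

    def todense(self):
        # Hypothetical helper: explicit conversion to a dense matrix.
        return jnp.diag(self.diagonal)

    def _exp(self):
        # Exponentiate only the stored diagonal entries.
        return DiagonalArray(jnp.exp(self.diagonal))

    # Pytree protocol: the diagonal is the only array leaf.
    def tree_flatten(self):
        return (self.diagonal,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so unflattening never converts the leaves.
        obj = object.__new__(cls)
        obj.diagonal = children[0]
        return obj


def exp(x):
    """Hypothetical dispatching wrapper (not a JAX API)."""
    if isinstance(x, DiagonalArray):
        return x._exp()
    return jnp.exp(x)


x = DiagonalArray(jnp.array([1.0, 2.0, 3.0]))
y = jax.jit(exp)(x)  # exp is evaluated only on the 3 stored entries
print(type(y).__name__, y.diagonal)
```

This sketch deliberately omits __jax_array__: without it, an accidental jnp.exp(x) should raise a TypeError instead of silently building the dense matrix and computing a different result.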