ndarray
Vectorisation like `np.vectorize` or `jax.vmap`
This is a feature suggestion.
One of the important features of `numpy` and `jax` in Python is that they provide a vectorisation scheme for functions.
Here is a minimal example in Python:
```python
import jax
import jax.numpy as jnp

# This function takes two float arguments and outputs a float
def func(x: float, y: float) -> float:
    return x - y

# jax.vmap gives a vectorised version of func: vectorised_func
# takes an array as input and produces an array as output.
# Here in_axes=(0, None) means that we vectorise func's first argument only.
vectorised_func = jax.vmap(func, in_axes=(0, None))

# Test arrays
a = jnp.array([1., 2., 3.])
b = 2.
vectorised_func(a, b)  # This outputs [-1., 0., 1.]

# You can go further and also vectorise func's y argument
even_more_vectorised_func = jax.vmap(vectorised_func, in_axes=(None, 0))

c = jnp.array([3., 2., 1.])
# Row j is a - c[j], so this outputs the 3x3 matrix
# [[-2., -1., 0.], [-1., 0., 1.], [0., 1., 2.]]
even_more_vectorised_func(a, c)
```
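In NumPy, this is similar to `np.vectorize`, but unlike `jax.vmap`, `np.vectorize` does not use any parallelisation scheme: the NumPy documentation says it is provided for convenience rather than performance, and is essentially a `for` loop.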
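For comparison, here is a minimal NumPy sketch of the same two computations. It is only an illustration: `np.vectorize` broadcasts its inputs, so inserting singleton axes plays roughly the role of `in_axes` here, and the last line shows the plain broadcasting expression that gives the same matrix without a Python-level loop (possible only because `func` happens to be written with array-friendly operations).

```python
import numpy as np

def func(x: float, y: float) -> float:
    return x - y

# np.vectorize wraps a scalar function so it accepts arrays and
# broadcasts them; internally it is essentially a Python for loop.
vectorised_func = np.vectorize(func)

a = np.array([1., 2., 3.])
b = 2.
c = np.array([3., 2., 1.])

# Map over `a` only; `b` stays a scalar, like in_axes=(0, None).
out1 = vectorised_func(a, b)

# Map over both arguments: entry [j, i] = a[i] - c[j], matching the
# nested jax.vmap above (rows follow c, columns follow a).
out2 = vectorised_func(a[np.newaxis, :], c[:, np.newaxis])

# Plain broadcasting computes the same matrix without a Python loop.
out3 = a[np.newaxis, :] - c[:, np.newaxis]
```

The gap between `out2` and `out3` is the point of this request: the first works for any scalar function but loops in Python, while the second is fast but only applies when the function body is already written in array operations.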
This feature, in my opinion, is very important for scientific computing. For example, given two arrays of points, you can use it to compute all their pairwise distances. As another example, in multidimensional numerical quadrature, computing $\int f(x)\,dx$ requires evaluating the function $f$ at a number of nodes, which could be significantly sped up by vectorisation.
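A minimal NumPy sketch of both examples, under stated assumptions: `distance`, `pairwise`, and `f` are hypothetical names chosen for illustration, and the quadrature uses a tensor-product Gauss-Legendre rule on $[-1, 1]^2$, which is one possible choice rather than anything specified here. Since `np.vectorize` still loops in Python, this shows the interface being asked for, not its performance.

```python
import numpy as np

# Pairwise distances: write the function for ONE pair of points, then vectorise.
def distance(p: np.ndarray, q: np.ndarray) -> float:
    return float(np.sqrt(np.sum((p - q) ** 2)))

# signature "(n),(n)->()": each call takes two length-n vectors and returns
# a scalar; all leading axes are broadcast against each other.
pairwise = np.vectorize(distance, signature="(n),(n)->()")

X = np.array([[0., 0.], [3., 4.]])            # 2 points in 2-D
Y = np.array([[0., 0.], [1., 0.], [3., 0.]])  # 3 points in 2-D
# Singleton axes make every point of X meet every point of Y.
D = pairwise(X[:, np.newaxis, :], Y[np.newaxis, :, :])  # shape (2, 3)

# Quadrature: integral of f over [-1, 1]^2 with a tensor-product
# Gauss-Legendre rule (an assumed choice of rule for illustration).
def f(x: float, y: float) -> float:
    return x**2 * y**2

nodes, weights = np.polynomial.legendre.leggauss(3)  # 3 nodes per dimension
xs, ys = np.meshgrid(nodes, nodes, indexing="ij")     # all 9 node pairs
wx, wy = np.meshgrid(weights, weights, indexing="ij")
# Evaluate f at every node in a single vectorised call, then take the weighted sum.
values = np.vectorize(f)(xs, ys)
integral = np.sum(wx * wy * values)  # 3-point rule is exact here: (2/3)**2
```

In both cases the user only writes the scalar (or single-pair) function; the vectorisation wrapper handles the loop over points or nodes, which is the capability this issue asks for in `ndarray`.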
I am a researcher in signal processing and machine learning, and I can definitely say this feature is valuable.