
Vectorisation like `np.vectorize` or `jax.vmap`

Open · zgbkdlm opened this issue 2 years ago · 0 comments

This is a feature suggestion.

One of the important features of NumPy and JAX in Python is that they provide a way to vectorise functions.

Here is a minimal example in Python:

```python
import jax
import jax.numpy as jnp

# This function takes two float arguments and outputs a float
def func(x: float, y: float) -> float:
    return x - y

# jax.vmap returns a vectorised version of func: vectorised_func
# takes an array as input and produces an array as output.
# in_axes=(0, None) means that only func's first argument is mapped over.
vectorised_func = jax.vmap(func, in_axes=(0, None))

# Test arrays
a = jnp.array([1., 2., 3.])
b = 2.

vectorised_func(a, b)  # returns [-1., 0., 1.]

# Applying vmap again also vectorises the y argument of func
even_more_vectorised_func = jax.vmap(vectorised_func, in_axes=(None, 0))

# Row j of the result is a - c[j]
c = jnp.array([3., 2., 1.])
even_more_vectorised_func(a, c)  # returns the 3x3 matrix [[-2., -1., 0.], [-1., 0., 1.], [0., 1., 2.]]
```
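Because `func` here is a single elementwise subtraction, both results can also be reproduced in plain NumPy with broadcasting. The sketch below uses NumPy rather than JAX (so it runs without JAX installed) and is only meant to make the layout of the nested `vmap` output concrete; unlike `vmap`, broadcasting only works when the function body is already written in terms of array operations.

```python
import numpy as np

def func(x, y):
    return x - y

a = np.array([1.0, 2.0, 3.0])
b = 2.0
c = np.array([3.0, 2.0, 1.0])

# Map over x only: the scalar y is broadcast against every element of a.
out_vec = func(a, b)

# Outer map over y (rows), inner map over x (columns):
# out_mat[j, i] == a[i] - c[j], the same layout as the nested vmap example.
out_mat = func(a[np.newaxis, :], c[:, np.newaxis])

print(out_vec)  # [-1.  0.  1.]
print(out_mat)
# [[-2. -1.  0.]
#  [-1.  0.  1.]
#  [ 0.  1.  2.]]
```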

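In NumPy, the closest equivalent is `np.vectorize`, but unlike `jax.vmap` it does not use any parallelisation: the NumPy documentation states that it is provided primarily for convenience, not for performance, and that its implementation is essentially a `for` loop.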
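To illustrate, here is a small sketch of `np.vectorize` wrapping a scalar function that uses Python branching (the name `positive_part_diff` is hypothetical, chosen for illustration). `np.vectorize` broadcasts its inputs like a ufunc, but it still calls the Python function once per output element, which is why it offers convenience rather than speed.

```python
import numpy as np

def positive_part_diff(x: float, y: float) -> float:
    # Scalar-only logic: a Python `if` cannot be applied to a whole array directly.
    d = x - y
    return d if d > 0.0 else 0.0

# otypes fixes the output dtype, so NumPy does not need a trial call to infer it.
vec_diff = np.vectorize(positive_part_diff, otypes=[float])

a = np.array([1.0, 2.0, 3.0])
c = np.array([3.0, 2.0, 1.0])

row = vec_diff(a, 2.0)                               # scalar y is broadcast
grid = vec_diff(a[np.newaxis, :], c[:, np.newaxis])  # grid[j, i] = max(a[i] - c[j], 0)

print(row)  # [0. 0. 1.]
print(grid)
```

`jax.vmap`, by contrast, traces the function and rewrites it into batched array operations; under `vmap` the branch above would have to be expressed with an array operation such as `jnp.maximum` or `jnp.where` instead of a Python `if`.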

This feature is, in my opinion, very important for scientific computing. For example, given two arrays of points, you can use it to compute their pairwise distances. As another example, in multidimensional numerical quadrature, computing $\int f(x)\,dx$ requires evaluating the function $f$ at a number of nodes, which could be significantly sped up by vectorisation.
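Both examples reduce to evaluating one function over many inputs at once. Below is a minimal NumPy sketch that uses broadcasting in place of a `vmap`-style transform. The helper `pairwise_distances` is hypothetical, and the quadrature part assumes a Gauss–Hermite rule from `numpy.polynomial.hermite_e.hermegauss` to approximate $\mathbb{E}[f(Z)]$ for a standard normal $Z$, in one and two dimensions; in 2-D the integrand is evaluated at all tensor-product nodes in a single vectorised call.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def pairwise_distances(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Euclidean distances between rows of X (n, d) and rows of Y (m, d) -> (n, m)."""
    diff = X[:, np.newaxis, :] - Y[np.newaxis, :, :]  # shape (n, m, d)
    return np.sqrt(np.sum(diff ** 2, axis=-1))

X = np.array([[0.0, 0.0], [1.0, 0.0]])
Y = np.array([[0.0, 0.0], [0.0, 1.0], [3.0, 4.0]])
D = pairwise_distances(X, Y)  # shape (2, 3)

# Probabilists' Gauss-Hermite rule: sum_i w_i f(x_i) approximates the integral
# of f(x) exp(-x^2 / 2) dx, so dividing by sqrt(2*pi) gives E[f(Z)], Z ~ N(0, 1).
nodes, weights = hermegauss(10)

def f1(x):
    return x ** 4  # E[Z^4] = 3

# f1 is evaluated at all 10 nodes in a single vectorised call.
mean_1d = weights @ f1(nodes) / np.sqrt(2.0 * np.pi)

# 2-D tensor-product rule for Z = (Z1, Z2) with independent N(0, 1) components.
def f2(x, y):
    return x ** 2 + y ** 2  # E[Z1^2 + Z2^2] = 2

xx, yy = np.meshgrid(nodes, nodes, indexing="ij")  # all 100 node pairs
ww = np.outer(weights, weights)                    # matching product weights
mean_2d = np.sum(ww * f2(xx, yy)) / (2.0 * np.pi)
```

With 10 nodes per dimension the rule is exact for polynomials up to degree 19 in each variable, so both means above are reproduced up to floating-point error.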

I am a researcher in signal processing and machine learning, and I can definitely say this feature is valuable.

zgbkdlm, Nov 16 '21 18:11