
Numpy `lexsort` equivalent

Open vandalt opened this issue 4 years ago • 2 comments

Hello, I am currently looking for a way to replicate the behaviour of numpy.lexsort() with Aesara. From my understanding, there is no straightforward way to do this currently. Here is a short example of what I'm trying to do, i.e. get the indices that sort by one array (a), with ties broken by another array (b):

import numpy as np

a = np.array([1, 1, 2, 3, 4])
b = np.array([2, 1, 1, 1, 1])

inds = np.lexsort((b, a))
print(inds)

For np.argsort(), I could convert the array to an Aesara tensor and use tt.argsort() instead. Would it be possible to implement something similar for lexsort? I looked at the sort.py file and, from my understanding, the implementation of lexsort would look a lot like the one for argsort. If that is the case, I would be happy to try to implement it, but I have very little experience with Aesara and the concepts involved.

Otherwise, is there a way to replicate this behaviour with what is currently available in Aesara?
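One possible workaround, sketched here in plain NumPy only: a lexsort over 1-D keys can be emulated by chaining stable argsorts, starting from the least significant key. The helper name lexsort_via_argsort is hypothetical, and whether the same composition carries over to tt.argsort() depends on a stable sort kind being available there, which is an assumption and not something confirmed in this thread.

```python
import numpy as np


def lexsort_via_argsort(keys):
    """Emulate np.lexsort for 1-D keys using repeated stable argsorts.

    As with np.lexsort, the last key in `keys` is the primary sort key.
    (Hypothetical helper, for illustration only.)
    """
    keys = [np.asarray(k) for k in keys]
    order = np.arange(keys[0].shape[0])
    # Sort by the least significant key first. Each later stable sort
    # preserves the relative order produced by the earlier ones.
    for key in keys:
        order = order[np.argsort(key[order], kind="stable")]
    return order


a = np.array([1, 1, 2, 3, 4])
b = np.array([2, 1, 1, 1, 1])

inds = lexsort_via_argsort((b, a))
print(inds)  # same ordering as np.lexsort((b, a))
```

The design relies on stability: once the indices are ordered by b, a stable sort on a leaves equal values of a in their b-sorted order, which is the tie-breaking that lexsort performs.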

Thank you!

Sep 14 '21 16:09 vandalt

I tried to implement this yesterday. As I said above, I know very little about Aesara, so I just copied the argsort function and class and started from there. It seems to work, at least for my simple use case of two 2D arrays. As noted in the code below, I was not able to make the function return a different shape than the input: the input is a tuple of k NxN arrays, so the tensor variable has shape (k, N, N). The output has shape (1, N, N), so I just took inds[0] to use the sorted indices in my code (using pymc3.aesaraf.take_along_axis).
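To make the shapes concrete, here is a small NumPy-only sketch of what happens inside perform: np.lexsort on a (k, N, N) stack of keys returns an (N, N) index array, and wrapping it in np.array([...]) adds the leading axis of length one. The sizes and random keys are made up for illustration, and np.take_along_axis stands in for pymc3.aesaraf.take_along_axis.

```python
import numpy as np

rng = np.random.default_rng(0)
k, N = 2, 4  # illustrative sizes, not taken from the original use case

# A stack of k keys, each N x N; the last key (keys[-1]) is the primary one.
keys = rng.integers(0, 3, size=(k, N, N))

# lexsort along the last axis drops the leading key axis: shape (N, N).
inds = np.lexsort(keys, axis=-1)

# Wrapping in a list, as the Op does, restores a leading axis of length 1.
wrapped = np.array([inds])

# Taking [0] recovers the (N, N) indices, which can then reorder any key.
primary_sorted = np.take_along_axis(keys[-1], wrapped[0], axis=-1)

print(inds.shape, wrapped.shape)
```

Each row of primary_sorted is the corresponding row of the primary key in non-decreasing order, with ties ordered by the earlier key.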

I'm posting my attempt here in case someone wants to use it as a starting point.

The lexsort code

import numpy as np
import theano
from theano.gradient import grad_undefined
from theano.graph.basic import Apply
from theano.graph.op import Op
from theano.misc.safe_asarray import _asarray
from theano.tensor.basic import mul
from theano.tensor.sort import _variable_is_none


def lexsort(a, axis=-1):
    """
    Returns the indices that would sort arrays
    """
    if axis is None:
        a = a.flatten()
        axis = 0
    # To preserve op shape, it returns extra dim 0 of length one -> take [0]
    # NOTE: Taking [0] here does not seem to do anything
    output = LexSortOp()(a, axis)[0]

    return output


class LexSortOp(Op):
    """
    This class is a wrapper for numpy lexsort function.
    """

    def __str__(self):
        return self.__class__.__name__

    def make_node(self, input, axis=-1):
        input = theano.tensor.as_tensor_variable(input)
        axis = theano.tensor.as_tensor_variable(axis)
        bcast = input.type.broadcastable
        return Apply(
            self,
            [input, axis],
            [theano.tensor.TensorType(dtype="int64", broadcastable=bcast)()],
        )

    def perform(self, node, inputs, output_storage):
        a = inputs[0]
        axis = inputs[1]
        if axis is not None:
            if axis != int(axis):
                raise ValueError("sort axis must be an integer or None")
            axis = int(axis)
        z = output_storage[0]
        z[0] = _asarray(
            np.array([np.lexsort(a, axis=axis)]),
            dtype=node.outputs[0].dtype,
        )

    def infer_shape(self, fgraph, node, inputs_shapes):
        if _variable_is_none(node.inputs[1]):
            return [(mul(*inputs_shapes[0]),)]
        # axis should not be None, so there should be the same number of
        # dimensions in the input and output
        assert node.inputs[0].ndim == node.outputs[0].ndim
        assert inputs_shapes[1] == ()
        return [inputs_shapes[0]]

    def grad(self, inputs, output_grads):
        # No grad defined for integers.
        inp, axis = inputs
        inp_grad = inp.zeros_like()
        axis_grad = grad_undefined(
            self,
            1,
            axis,
            "lexsort is not defined for non-integer axes so"
            " lexsort(x, axis+eps) is undefined",
        )
        return [inp_grad, axis_grad]

    """
    def R_op(self, inputs, eval_points):
        # R_op can receive None as eval_points.
        # That means there is no differentiable path through that input.
        # If this implies that you cannot compute some outputs,
        # return None for those.
        if eval_points[0] is None:
            return eval_points
        return self.grad(inputs, eval_points)
    """

Sep 15 '21 13:09 vandalt

It looks like you're on the right path. For more information, see the documentation on creating an Op. You can ignore all the C-related parts, though.

Also, feel free to create a draft PR so that we can more easily provide explicit feedback.

Sep 18 '21 19:09 brandonwillard