aesara
Numpy `lexsort` equivalent
Hello, I am currently looking for a way to replicate the behaviour of numpy.lexsort() with Aesara. As far as I can tell, there is currently no straightforward way to do this. Here is a short example of what I'm trying to do, i.e. sort indices by one array and break ties using another array:
```python
import numpy as np

a = np.array([1, 1, 2, 3, 4])
b = np.array([2, 1, 1, 1, 1])
# The last key (a) is the primary key; b only breaks ties in a.
inds = np.lexsort((b, a))
print(inds)  # [1 0 2 3 4]
```
For np.argsort(), it is already possible to convert the array to an Aesara tensor and use tt.argsort(). Would it be possible to implement something similar for lexsort? I looked at the sort.py file, and from my understanding, an implementation of lexsort would look a lot like the one for argsort. If that is the case, I would be happy to try implementing it, but I have very little experience with Aesara and the concepts involved.
Otherwise, is there a way to replicate this behaviour with what is currently available in Aesara?
Thank you!
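As a possible answer to the "existing building blocks" question, lexsort can be emulated with one stable argsort per key, applied from the least significant key to the primary one (like an LSD radix sort). The sketch below uses NumPy only; `lexsort_via_argsort` is a hypothetical helper, not an Aesara or NumPy API. Porting it to Aesara would rely on its `argsort` (which takes a `kind` argument) actually performing a stable sort such as `"mergesort"`, which is an assumption worth checking.

```python
import numpy as np


def lexsort_via_argsort(keys):
    """Hypothetical helper: emulate np.lexsort for 1-D keys using only
    stable argsort and integer indexing.

    As with np.lexsort, the LAST key is the primary sort key.
    """
    keys = [np.asarray(k) for k in keys]
    idx = np.arange(keys[0].shape[0])
    # Sort by the least significant key first; each later stable sort
    # keeps the order established by the earlier keys among its ties.
    for key in keys:
        order = np.argsort(key[idx], kind="stable")
        idx = idx[order]
    return idx


a = np.array([1, 1, 2, 3, 4])
b = np.array([2, 1, 1, 1, 1])
print(lexsort_via_argsort((b, a)))  # [1 0 2 3 4]
print(np.lexsort((b, a)))           # [1 0 2 3 4]
```

The stability requirement is what makes this work: a non-stable sort could reorder elements that tie on the current key and undo the ordering produced by the earlier, less significant keys.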
I tried to implement this yesterday. As I said above, I know very little about Aesara, so I copied the argsort function and class and started from there. It seems to work, at least on my simple use case of two 2D arrays. As noted in the code comments below, I was not able to make the Op return an output with a different number of dimensions than its input. So the input is a tuple of k NxN arrays, which becomes a tensor variable of shape (k, N, N), and the output has shape (1, N, N). I then take inds[0] to get the sorted indices and use them in my code (via pymc3.aesaraf.take_along_axis).
I'm posting the code below in case someone wants to use it as a starting point.
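The shape bookkeeping described above can be reproduced with NumPy alone. In the sketch below, `np.take_along_axis` stands in for `pymc3.aesaraf.take_along_axis`, which I assume mirrors NumPy's semantics; the random keys are purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(42)
N = 4
# Two N x N key arrays; as with np.lexsort, the LAST key is the primary one.
secondary = rng.integers(0, 3, size=(N, N))
primary = rng.integers(0, 3, size=(N, N))

# Stack the k = 2 keys into one (k, N, N) array, as the Op's input does.
stacked = np.stack([secondary, primary])

# Mimic what LexSortOp.perform stores: np.lexsort sorts along the last
# axis of each key, and the result is wrapped in a length-one leading axis.
out = np.array([np.lexsort(stacked, axis=-1)])
print(stacked.shape, out.shape)  # (2, 4, 4) (1, 4, 4)

# Drop the extra axis, then reorder each row of both keys with the indices.
inds = out[0]
primary_sorted = np.take_along_axis(primary, inds, axis=-1)
secondary_sorted = np.take_along_axis(secondary, inds, axis=-1)
```

After the reordering, every row is sorted by `primary`, with ties within a row broken by `secondary`.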
The lexsort code
```python
import numpy as np
import theano
from theano.gradient import grad_undefined
from theano.graph.basic import Apply
from theano.graph.op import Op
from theano.misc.safe_asarray import _asarray
from theano.tensor.basic import mul
from theano.tensor.sort import _variable_is_none


def lexsort(a, axis=-1):
    """
    Returns the indices that would sort arrays
    """
    if axis is None:
        a = a.flatten()
        axis = 0
    # To preserve op shape, it returns extra dim 0 of length one -> take [0]
    # NOTE: Taking [0] here does not seem to do anything
    output = LexSortOp()(a, axis)[0]
    return output


class LexSortOp(Op):
    """
    This class is a wrapper for numpy lexsort function.
    """

    def __str__(self):
        return self.__class__.__name__

    def make_node(self, input, axis=-1):
        input = theano.tensor.as_tensor_variable(input)
        axis = theano.tensor.as_tensor_variable(axis)
        bcast = input.type.broadcastable
        return Apply(
            self,
            [input, axis],
            [theano.tensor.TensorType(dtype="int64", broadcastable=bcast)()],
        )

    def perform(self, node, inputs, output_storage):
        a = inputs[0]
        axis = inputs[1]
        if axis is not None:
            if axis != int(axis):
                raise ValueError("sort axis must be an integer or None")
            axis = int(axis)
        z = output_storage[0]
        # np.lexsort returns an (N, N) result for (k, N, N) keys; wrap it in
        # a length-one leading axis so its ndim matches the declared output.
        z[0] = _asarray(
            np.array([np.lexsort(a, axis=axis)]),
            dtype=node.outputs[0].dtype,
        )

    def infer_shape(self, fgraph, node, inputs_shapes):
        if _variable_is_none(node.inputs[1]):
            return [(mul(*inputs_shapes[0]),)]
        # axis should not be None, so there should be the same number of
        # dimensions in the input and output
        assert node.inputs[0].ndim == node.outputs[0].ndim
        assert inputs_shapes[1] == ()
        # perform() replaces the leading axis of length k with one of length 1
        return [(1,) + tuple(inputs_shapes[0][1:])]

    def grad(self, inputs, output_grads):
        # No grad defined for integers.
        inp, axis = inputs
        inp_grad = inp.zeros_like()
        axis_grad = grad_undefined(
            self,
            1,
            axis,
            "lexsort is not defined for non-integer axes so"
            " lexsort(x, axis+eps) is undefined",
        )
        return [inp_grad, axis_grad]

    """
    def R_op(self, inputs, eval_points):
        # R_op can receive None as eval_points.
        # That means there is no differentiable path through that input.
        # If this implies that you cannot compute some outputs,
        # return None for those.
        if eval_points[0] is None:
            return eval_points
        return self.grad(inputs, eval_points)
    """
```
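One way around the (1, N, N) workaround, untested with Theano or Aesara: declare the output in `make_node` with one dimension fewer than the input (for example `broadcastable=bcast[1:]`), let `perform` store `np.lexsort`'s result directly, and have `infer_shape` return `inputs_shapes[0][1:]`. The NumPy-only sketch below checks just the shape relationship those two methods would need to agree on; both function names are hypothetical stand-ins for the Op methods.

```python
import numpy as np


def lexsort_perform_without_wrapper(keys, axis=-1):
    """Hypothetical variant of LexSortOp.perform's computation: return
    np.lexsort's result directly instead of wrapping it in an extra axis."""
    return np.lexsort(keys, axis=axis)


def lexsort_infer_shape(keys_shape):
    """Hypothetical matching infer_shape rule: the leading (k) axis of the
    stacked keys disappears and the remaining axes are kept."""
    return tuple(keys_shape[1:])


keys = np.random.default_rng(1).integers(0, 5, size=(3, 4, 6))
result = lexsort_perform_without_wrapper(keys)
print(result.shape, lexsort_infer_shape(keys.shape))  # (4, 6) (4, 6)
```

If the output were declared this way, the `[0]` in `lexsort()` and the `inds[0]` in user code should no longer be necessary.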
It looks like you're on the right path. For more information, see the documentation on creating an Op. You can ignore all the C-related parts, though.
Also, feel free to create a draft PR so that we can more easily provide explicit feedback.