Convert `Eye` to a `COp` or implement it in terms of existing `COp`s

Open brandonwillard opened this issue 3 years ago • 2 comments

Eye doesn't have a C implementation, but adding one should be straightforward.

brandonwillard avatar Sep 13 '22 18:09 brandonwillard

We can also try to use an OpFromGraph instead, it's just zeros + set_subtensor, no?

ricardoV94 avatar Sep 13 '22 18:09 ricardoV94

> We can also try to use an OpFromGraph instead, it's just zeros + set_subtensor, no?

That might not be as performant as a COp, but I like it a lot more. Aside from being easier to implement, it could work well with our other transpilation targets (e.g. the result could be about as fast as direct use of np.eye in Numba and/or JAX).

brandonwillard avatar Sep 13 '22 19:09 brandonwillard

I believe the following is an implementation using OpFromGraph:

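import aesara
import aesara.tensor as at

n = at.iscalar('n')
m = at.iscalar('m')
k = at.iscalar('k')

# Flat index of the first one: column k of row 0, or column 0 of row -k.
i = at.switch(k >= 0, k, -k * m)
eye = at.zeros(n * m)
# A stride of m + 1 in the flattened array moves one row down and one column right.
eye = at.set_subtensor(eye[i::m+1], 1).reshape((n, m))
# Zero the rows past the end of the diagonal, where the stride has wrapped around.
eye = at.set_subtensor(eye[m-k:, :], 0)

Eye = aesara.compile.builders.OpFromGraph([n, m, k], [eye])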

The existing eye function in aesara.tensor could then be modified as follows:

def eye(n, m=None, k=0, dtype=None):
    """Return a 2-D array with ones on the diagonal and zeros elsewhere.

    Parameters
    ----------
    n : int
        Number of rows in the output.
    m : int, optional
        Number of columns in the output. If None, defaults to `n`.
    k : int, optional
        Index of the diagonal: 0 (the default) refers to the main diagonal,
        a positive value refers to an upper diagonal, and a negative value
        to a lower diagonal.
    dtype : data-type, optional
        Data-type of the returned array. If None, defaults to
        `aesara.config.floatX`.

    Returns
    -------
    TensorVariable of shape (n, m)
        An array where all elements are equal to zero, except for the `k`-th
        diagonal, whose values are equal to one.
    """
    if dtype is None:
        dtype = aesara.config.floatX
    if m is None:
        m = n
    return Eye(n, m, k).astype(dtype)

I can add some tests to check corner cases and submit this as a pull request if it looks like I'm barking up the right tree? Where would I add the code to create the OpFromGraph, just floating inside aesara.tensor above def eye, or is there a more organized place to put it?

jessegrabowski avatar Sep 27 '22 15:09 jessegrabowski

Looks about right. About where to put it... good question. Floating sounds about right, maybe add an underscore prefix to those intermediate variables?

The other thing worth checking is whether any rewrites currently target the Eye Op and, if so, whether they still work.

ricardoV94 avatar Sep 27 '22 15:09 ricardoV94

How would I check for rewrites that target Eye? I did ctrl+f on all the files in aesara.tensor.rewriting for Eye (and @node_rewriter([Eye]) ) and came up with nothing. This doesn't strike me as a very sophisticated way to check, though.

jessegrabowski avatar Sep 27 '22 15:09 jessegrabowski

> How would I check for rewrites that target Eye? I did ctrl+f on all the files in aesara.tensor.rewriting for Eye (and @node_rewriter([Eye]) ) and came up with nothing. This doesn't strike me as a very sophisticated way to check, though.

Sounds about right. There might not be any in which case you are in luck ;)

Edit: I didn't find anything either
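
For anyone who wants to repeat the search without an editor, here is a minimal sketch of the same ctrl+f check as a script. The function name `find_op_mentions` and its parameters are hypothetical, made up for illustration and not part of Aesara. It only does a whole-word text search over `.py` files, so it would still miss a rewrite that reaches the Op through an alias.

```python
import re
from pathlib import Path


def find_op_mentions(root, op_name="Eye"):
    """Return (path, line number, stripped line) for each whole-word mention.

    Hypothetical helper: a plain text search, not an inspection of the
    rewrite database, so aliased imports of the Op will not be found.
    """
    pattern = re.compile(rf"\b{re.escape(op_name)}\b")
    hits = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="ignore")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if pattern.search(line):
                hits.append((str(path), lineno, line.strip()))
    return hits
```

Pointing it at a local checkout's `aesara/tensor/rewriting` directory and getting an empty list back would match what both of us saw by hand. The word-boundary match is case-sensitive, so the lowercase `eye` helper is not reported.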

ricardoV94 avatar Sep 27 '22 15:09 ricardoV94

You can also get rid of the Numba and JAX dispatch (I assume a dispatch for OpFromGraph has already been implemented)

Edit: It seems to be only for Numba...

ricardoV94 avatar Sep 27 '22 16:09 ricardoV94

Actually we might not even need an OpFromGraph, if we don't need to easily target it in rewrites and if we are not overriding the grad. You can just make eye a function that returns the correct Aesara symbolic expression. This also makes the dtype of the inputs more flexible, instead of constraining them to int32 in your example.

ricardoV94 avatar Sep 27 '22 16:09 ricardoV94

import aesara
import aesara.tensor as at


def eye_new(n, m=None, k=0, dtype=None):
    if m is None:
        m = n
    if dtype is None:
        dtype = aesara.config.floatX
        
    n = at.as_tensor_variable(n)
    m = at.as_tensor_variable(m)
    k = at.as_tensor_variable(k)
    
    i = at.switch(k >= 0, k, -k * m)
    eye = at.zeros(n * m, dtype=dtype)
    eye = at.set_subtensor(eye[i::m + 1], 1).reshape((n, m))
    eye = at.set_subtensor(eye[m - k:, :], 0)
    return eye    

Seems to do alright
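
On corner cases: below is a dependency-free sketch that mirrors the same index arithmetic with plain Python lists, so it can be checked by hand. The names `eye_flat` and `eye_reference` are made up for this illustration. It is not the Aesara graph itself, and it assumes Aesara's basic slicing follows the usual Python/NumPy rules, including negative slice starts counting from the end.

```python
def eye_flat(n, m, k=0):
    """Pure-Python mirror of the zeros + strided set_subtensor construction."""
    # Same role as at.switch(k >= 0, k, -k * m): flat index of the first one.
    start = k if k >= 0 else -k * m
    flat = [0] * (n * m)
    # Mirrors eye[i::m + 1] = 1 on the flattened array.
    for pos in range(start, n * m, m + 1):
        flat[pos] = 1
    # Mirrors reshape((n, m)).
    rows = [flat[r * m:(r + 1) * m] for r in range(n)]
    # Mirrors eye[m - k:, :] = 0; a negative m - k counts from the end.
    for r in range(n)[m - k:]:
        rows[r] = [0] * m
    return rows


def eye_reference(n, m, k=0):
    """Definition of the k-th diagonal: ones where column - row == k."""
    return [[1 if c - r == k else 0 for c in range(m)] for r in range(n)]


# The two constructions agree for every k <= m on these small shapes...
for n in range(1, 6):
    for m in range(1, 6):
        for k in range(-n - 1, m + 1):
            assert eye_flat(n, m, k) == eye_reference(n, m, k)

# ...but differ for k > m, where m - k becomes a from-the-end slice start.
assert eye_flat(3, 2, 3) != eye_reference(3, 2, 3)
```

If the symbolic version behaves the same way, `k > m` may need an extra guard, so that case is probably worth including in the tests against `np.eye`.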

ricardoV94 avatar Sep 27 '22 16:09 ricardoV94

I'll make a pull request for this in a minute, I'm just fumbling around with git at the moment.

jessegrabowski avatar Sep 27 '22 16:09 jessegrabowski