aesara
Convert `Eye` to a `COp` or implement it in terms of existing `COp`s
Eye doesn't have a C implementation, but adding one should be straightforward.
We can also try to use an OpFromGraph instead, it's just zeros + set_subtensor, no?
That might not be as performant as a COp, but I like it a lot more. Aside from being easier to implement, it could work well with our other transpilation targets (e.g. the result could be about as fast as direct use of np.eye in Numba and/or JAX).
I believe the following is an implementation using OpFromGraph:
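    import aesara
    import aesara.tensor as at

    n = at.iscalar('n')
    m = at.iscalar('m')
    k = at.iscalar('k')
    # Flat index of the first one on the k-th diagonal.
    i = at.switch(k >= 0, k, -k * m)
    eye = at.zeros(n * m)
    # In the flattened array, consecutive diagonal entries are m + 1 apart.
    eye = at.set_subtensor(eye[i::m+1], 1).reshape((n, m))
    # Clear the entries that wrapped around past the last column.
    eye = at.set_subtensor(eye[m-k:, :], 0)
    Eye = aesara.compile.builders.OpFromGraph([n, m, k], [eye])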
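To see why a stride of `m + 1` walks the diagonal, and why the second `set_subtensor` is there, it may help to list which cells the strided slice actually touches. The sketch below is plain Python rather than Aesara. `flat_diagonal_positions` is a hypothetical helper written only for illustration, and it assumes row-major flattening, so that flat index `idx` maps to `(idx // m, idx % m)`.

```python
def flat_diagonal_positions(n, m, k):
    """Return the (row, col) cells hit by flat[i::m + 1] on an n x m array.

    Hypothetical helper for illustration; assumes row-major layout.
    """
    # Same starting index as the at.switch(...) expression in the snippet above.
    i = k if k >= 0 else -k * m
    # divmod(idx, m) converts a flat index back into (row, col).
    return [divmod(idx, m) for idx in range(i, n * m, m + 1)]


# Wide case: the stride stays on the k = 1 diagonal.
print(flat_diagonal_positions(3, 4, 1))  # [(0, 1), (1, 2), (2, 3)]

# Tall case: after the last column the stride wraps around to (4, 0).
print(flat_diagonal_positions(5, 3, 0))  # [(0, 0), (1, 1), (2, 2), (4, 0)]
```

In the tall case the stray `(4, 0)` entry sits in a row at or below `m - k`, which is the region the second `set_subtensor` resets to zero.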
The existing eye function in aesara.tensor could then be modified as follows:
    def eye(n, m=None, k=0, dtype=None):
        """Return a 2-D array with ones on the diagonal and zeros elsewhere.

        Parameters
        ----------
        n : int
            Number of rows in the output.
        m : int, optional
            Number of columns in the output. If None, defaults to `n`.
        k : int, optional
            Index of the diagonal: 0 (the default) refers to the main diagonal,
            a positive value refers to an upper diagonal, and a negative value
            to a lower diagonal.
        dtype : data-type, optional
            Data-type of the returned array.

        Returns
        -------
        ndarray of shape (n, m)
            An array where all elements are equal to zero, except for the `k`-th
            diagonal, whose values are equal to one.
        """
        if dtype is None:
            dtype = aesara.config.floatX
        if m is None:
            m = n
        return Eye(n, m, k).astype(dtype)
I can add some tests to check corner cases and submit this as a pull request if it looks like I'm barking up the right tree? Where would I add the code to create the OpFromGraph, just floating inside aesara.tensor above def eye, or is there a more organized place to put it?
Looks about right. About where to put it... good question. Floating sounds about right, maybe add an underscore prefix to those intermediate variables?
The other thing worth checking is whether any rewrites currently target the Eye Op and, if so, whether they still work.
How would I check for rewrites that target Eye? I did ctrl+f on all the files in aesara.tensor.rewriting for Eye (and @node_rewriter([Eye]) ) and came up with nothing. This doesn't strike me as a very sophisticated way to check, though.
Sounds about right. There might not be any in which case you are in luck ;)
Edit: I didn't find anything either
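For anyone who wants to repeat that search without clicking through files, a small script can do the same text scan. This is only a sketch using the standard library: `find_mentions` is a hypothetical helper, and the directory passed to it is an assumption about where a local checkout keeps the rewriting modules. Like ctrl+f, it only finds literal mentions, so a rewrite that matches the Op indirectly (for example through a base class) would not show up.

```python
import re
from pathlib import Path


def find_mentions(root, pattern=r"\bEye\b"):
    """Return (path, line_number, stripped_line) for matching lines in *.py files.

    Hypothetical helper for illustration; the search is case-sensitive,
    so the lowercase `eye` function is not reported.
    """
    regex = re.compile(pattern)
    hits = []
    for path in sorted(Path(root).rglob("*.py")):
        lines = path.read_text(encoding="utf-8").splitlines()
        for lineno, line in enumerate(lines, start=1):
            if regex.search(line):
                hits.append((str(path), lineno, line.strip()))
    return hits


# Example call, assuming the current directory is the root of a checkout:
# for path, lineno, line in find_mentions("aesara/tensor/rewriting"):
#     print(f"{path}:{lineno}: {line}")
```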
You can also get rid of the Numba and JAX dispatch (I assume a dispatch for OpFromGraph has already been implemented)
Edit: It seems to be only for Numba...
Actually we might not even need an OpFromGraph, if we don't need to easily target it in rewrites and if we are not overriding the grad. You can just make eye a function that returns the correct Aesara symbolic expression. This also makes the dtype of the inputs more flexible, instead of constraining them to int32 in your example.
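    import aesara
    import aesara.tensor as at

    def eye_new(n, m=None, k=0, dtype=None):
        if m is None:
            m = n
        if dtype is None:
            dtype = aesara.config.floatX

        # Accept Python ints, NumPy scalars, or symbolic scalars of any integer dtype.
        n = at.as_tensor_variable(n)
        m = at.as_tensor_variable(m)
        k = at.as_tensor_variable(k)

        # Flat index of the first one on the k-th diagonal.
        i = at.switch(k >= 0, k, -k * m)
        eye = at.zeros(n * m, dtype=dtype)
        # Diagonal entries are m + 1 apart in the flattened array.
        eye = at.set_subtensor(eye[i::m + 1], 1).reshape((n, m))
        # Clear the entries that wrapped around past the last column.
        eye = at.set_subtensor(eye[m - k:, :], 0)
        return eye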
Seems to do alright
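Since corner cases came up earlier in the thread, here is one way they could be explored before writing the real tests. The sketch models the zeros + `set_subtensor` construction with plain Python lists and compares it against the direct definition of the `k`-th diagonal. All three function names are hypothetical, and the model assumes that Aesara slices treat a negative start the way Python and NumPy do, which is not verified here.

```python
def eye_reference(n, m, k):
    """Direct definition: one where col - row == k, zero elsewhere."""
    return [[1 if c - r == k else 0 for c in range(m)] for r in range(n)]


def eye_flat_model(n, m, k):
    """Plain-Python model of the zeros + set_subtensor construction."""
    i = k if k >= 0 else -k * m
    flat = [0] * (n * m)
    for idx in range(i, n * m, m + 1):  # models eye[i::m + 1] = 1
        flat[idx] = 1
    rows = [flat[r * m:(r + 1) * m] for r in range(n)]
    for r in range(n)[m - k:]:  # models eye[m - k:, :] = 0
        rows[r] = [0] * m
    return rows


def find_mismatches(max_size=4, max_offset=6):
    """Return every (n, m, k) where the model disagrees with the reference."""
    return [
        (n, m, k)
        for n in range(1, max_size + 1)
        for m in range(1, max_size + 1)
        for k in range(-max_offset, max_offset + 1)
        if eye_flat_model(n, m, k) != eye_reference(n, m, k)
    ]


mismatches = find_mismatches()
print(mismatches)
```

In this model every disagreement has `k > m`: there `m - k` is negative, so the slice counts from the end and can leave a stray one in an earlier row, for example with `n=4, m=3, k=5`. If the symbolic graph behaves the same way, clamping the slice start at zero (perhaps with something like `at.maximum(m - k, 0)`) might be one possible guard, though that is untested here.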
I'll make a pull request for this in a minute, I'm just fumbling around with git at the moment.