Replace our `Tri` op with an `OpFromGraph`
Description
Currently we have an Op that calls `np.tri`, but we can easily build lower-triangular mask matrices with `_iota`:
```python
from pytensor.tensor.einsum import _iota

def tri(N, M, k):
    # Entry (i, j) is 1 where j <= i + k, matching np.tri(N, M, k)
    return ((_iota((N, M), 0) + k) >= _iota((N, M), 1)).astype(int)
```
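For intuition, the same construction can be checked against `np.tri` in plain NumPy, with broadcast `arange` calls standing in for `_iota` (`tri_mask` is just an illustrative name):

```python
import numpy as np

def tri_mask(N, M, k=0):
    # Broadcast row indices against column indices:
    # entry (i, j) is 1 where i + k >= j, i.e. j <= i + k
    rows = np.arange(N)[:, None]  # analogue of _iota((N, M), 0)
    cols = np.arange(M)[None, :]  # analogue of _iota((N, M), 1)
    return ((rows + k) >= cols).astype(int)

print(np.array_equal(tri_mask(4, 5, 1), np.tri(4, 5, 1)))  # → True
```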
This is what JAX does. The benefit of doing things this way is that we automatically get a dispatchable Op for Numba (Numba supports `np.tri`, but only under specific circumstances -- I tried a naive dispatch and it didn't work) and PyTorch (#821 asks for `Tri`, so this would check off that box).
I suggest we wrap this in a dummy `OpFromGraph`, as we do for `Kron` and `AllocDiag`, so that the `dprint` output is nicer. We could also overload the `L_op` if we want. The current `Tri` has `grad_undefined`, which we could keep if that's correct, or we could just rely on the autodiff solution -- the proposed `_iota`-based graph should be differentiable.
@jessegrabowski New contributor here. I created a PR based on my understanding of the issue. It probably has a few errors. I was unable to run pytest locally due to some circular import issue. Waiting for feedback!