pytensor
Implement equivalent to `np.diag_indices` and `np.diag_indices_from`
Description
Should be simple: just use `pt.arange` so that symbolic inputs are supported: https://github.com/numpy/numpy/blob/e7a123b2d3eca9897843791dd698c1803d9a39c2/numpy/lib/_index_tricks_impl.py#L1010-L1011
And for `diag_indices_from`:
https://github.com/numpy/numpy/blob/e7a123b2d3eca9897843791dd698c1803d9a39c2/numpy/lib/_index_tricks_impl.py#L1062-L1069
I would perhaps use `pt.max(arr.shape)` instead of NumPy's check that all dimensions have equal length, and let the indexing fail at runtime if the array is not square?