Modify np.tri Op to use _iota instead
Description
Related Issue
- [ ] Closes #
- [ ] Related to #1265
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pytensor--1276.org.readthedocs.build/en/1276/
The circular import happens because you are importing _iota from tensor.einsum inside tensor.basic, while tensor.einsum itself imports tensor.basic.
The solution is to move the _iota function to tensor.basic. I also suggest removing the leading underscore, because there's no reason this function should be considered "hidden". I don't think it needs to be added to `__all__` (because it's not a function that exists in numpy), but I also wouldn't object to it.
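A toy reproduction of the cycle and of the suggested fix might look like the following. The package and module names (cyclic_pkg, fixed_pkg, basic.py, einsum.py) and the toy arange/iota bodies are hypothetical stand-ins, not PyTensor's actual code. The point is only the import order: once iota lives in basic, the dependency runs in one direction.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical toy packages mimicking the tensor.basic <-> tensor.einsum cycle.
PACKAGES = {
    "cyclic_pkg": {
        # basic imports iota from einsum *before* it defines arange ...
        "basic.py": """
            from .einsum import iota

            def arange(n):
                return list(range(n))
        """,
        # ... while einsum needs arange from the half-initialised basic.
        "einsum.py": """
            from .basic import arange

            def iota(n):
                return arange(n)
        """,
    },
    "fixed_pkg": {
        # The fix: iota lives in basic, so only einsum imports basic.
        "basic.py": """
            def arange(n):
                return list(range(n))

            def iota(n):
                return arange(n)
        """,
        "einsum.py": """
            from .basic import iota

            def uses_iota(n):
                return iota(n)
        """,
    },
}

# Write both toy packages into a fresh temporary directory.
root = Path(tempfile.mkdtemp())
for pkg_name, files in PACKAGES.items():
    pkg_dir = root / pkg_name
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text("")
    for filename, source in files.items():
        (pkg_dir / filename).write_text(textwrap.dedent(source))

sys.path.insert(0, str(root))
importlib.invalidate_caches()


def try_import(module_name: str) -> str:
    """Import a module and report "ok" or the ImportError message."""
    try:
        importlib.import_module(module_name)
        return "ok"
    except ImportError as exc:
        return f"ImportError: {exc}"


cyclic_result = try_import("cyclic_pkg.basic")  # fails: circular import
fixed_result = try_import("fixed_pkg.einsum")   # succeeds
print(cyclic_result)
print(fixed_result)
```

In the cyclic case Python reports that it cannot import arange from a partially initialized module, which is the same failure mode as importing _iota from tensor.einsum inside tensor.basic.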
Those Ops and a few more (arange, alloc, ...) should probably live in a tensor_creation file. basic is hosting too much.
@jessegrabowski My bad for missing it; I'll make the change. @ricardoV94 If that is within scope for me, I'd be happy to work on it.
That's up to you. It would be much appreciated!
@jessegrabowski Made the change. I have left the docstring of iota mostly the same, but I guess the example section may need to be removed now? Tests are still pending; I'll try adding some and see how it goes.
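For reference, the behaviour an iota docstring example would illustrate can be sketched in NumPy. This assumes iota follows the usual convention (as in jax.lax.broadcasted_iota): values count up along axis and are broadcast across every other axis. iota_np is a hypothetical name, not the PyTensor function.

```python
import numpy as np


def iota_np(shape: tuple, axis: int) -> np.ndarray:
    """NumPy sketch: whole numbers counting up along `axis`, broadcast over the rest."""
    ndim = len(shape)
    axis = axis % ndim  # accept negative axes, as NumPy does
    values = np.arange(shape[axis])
    # Reshape to (1, ..., n, ..., 1) so the values broadcast against `shape`.
    view_shape = [1] * ndim
    view_shape[axis] = shape[axis]
    return np.broadcast_to(values.reshape(view_shape), shape)


print(iota_np((2, 3), axis=1))
# [[0 1 2]
#  [0 1 2]]
print(iota_np((2, 3), axis=0))
# [[0 0 0]
#  [1 1 1]]
```

In 1D this reduces to arange; in higher dimensions it looks like many stacked aranges along the chosen axis.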
@jessegrabowski The tests were passing for tri, but now the tests for tril and triu are failing. The issue seems to be that the tests pass the matrix as a symbolic variable, so the call to Tri here also receives M and dtype as symbolic variables (because they are derived from the matrix?), which the current structure does not allow. Any suggestions on how to move ahead? I'm still figuring out symbolic variables.
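The reason M and the dtype get tied to the input matrix is that tril and triu are commonly expressed as masks built from tri, with the mask's shape and dtype taken from the matrix. A NumPy sketch of that relationship follows, assuming the symbolic tril/triu use the same formulas; tril_np and triu_np are hypothetical names. In a symbolic graph, m.shape[-2:] would be symbolic shape entries rather than Python ints, which is exactly what the Tri call then has to accept.

```python
import numpy as np


def tril_np(m: np.ndarray, k: int = 0) -> np.ndarray:
    """Lower triangle: keep entries on or below the k-th diagonal."""
    n_rows, n_cols = m.shape[-2:]
    # N, M and dtype all come from `m`; symbolically these would not be ints.
    return m * np.tri(n_rows, n_cols, k=k, dtype=m.dtype)


def triu_np(m: np.ndarray, k: int = 0) -> np.ndarray:
    """Upper triangle: keep entries on or above the k-th diagonal."""
    n_rows, n_cols = m.shape[-2:]
    # Complement of the mask covering diagonals strictly below k.
    return m * (1 - np.tri(n_rows, n_cols, k=k - 1, dtype=m.dtype))


m = np.arange(1, 13).reshape(3, 4)
print(tril_np(m, k=1))
print(triu_np(m, k=-1))
```

Because the mask broadcasts against m, the same formulas also work on stacks of matrices (leading batch dimensions).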
Tri doesn't need an `__init__` at all; just make it an empty OpFromGraph subclass, then write a function called tri that does the actual work. Have a look at how the Kronecker product is implemented for a template:
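The work such a tri function performs can be sketched in plain NumPy by comparing a row-index iota against a column-index iota. This only illustrates the masking logic, not PyTensor's code; the OpFromGraph wrapping from the Kronecker template is not reproduced, and tri_np is a hypothetical name checked against np.tri.

```python
from typing import Optional

import numpy as np


def tri_np(N: int, M: Optional[int] = None, k: int = 0, dtype=np.float64) -> np.ndarray:
    """Ones on and below the k-th diagonal, built from two iotas."""
    if M is None:
        M = N
    # Row index i at position (i, j): an iota along axis 0.
    rows = np.broadcast_to(np.arange(N)[:, None], (N, M))
    # Column index j at position (i, j): an iota along axis 1.
    cols = np.broadcast_to(np.arange(M)[None, :], (N, M))
    # Entry (i, j) is 1 exactly when j <= i + k.
    return (rows + k >= cols).astype(dtype)


for N, M, k in [(3, None, 0), (3, 5, 2), (4, 2, -1)]:
    assert np.array_equal(tri_np(N, M, k), np.tri(N, M, k))
```

Casting the boolean comparison at the very end keeps the mask logic independent of dtype, so integer, float, and complex outputs all come from the same graph.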
@jessegrabowski I made the change, and the tests are passing now. I have removed the 'complex64' dtype from the tests for now, because it triggered a compilation error that I could not resolve.
Well, what was the error? I don't want us removing tests.
@jessegrabowski ~~The stack trace looks like (removed some parts)~~
Put the test back and I'll trigger a CI run so I can see the full output.
@jessegrabowski I'm also stuck on this test case; it seems to be failing because we are passing concrete values instead of symbolic ones (I think). The same thing is tested in the test case below it.