Move register_canonicalize to graph.rewriting.utils, adjust function signature, and enhance AttributeError handling
Description
1. Function Relocation:
The register_canonicalize function has been relocated from tensor.rewriting.basic to graph.rewriting.utils, together with all the imports it needs in its new location. During the move, the type checker flagged two issues in the implementation. I tried to fix both, as described below.
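2. Type Mismatch Error Resolution:
The signature of the inner register function defined in register_canonicalize caused a type-mismatch error. The function accepted Union[RewriteDatabase, Rewriter], but the recursive call to register_canonicalize only accepts RewriteDatabase, NodeRewriter or str. This was fixed by narrowing the parameter type from Union[RewriteDatabase, Rewriter] to Union[RewriteDatabase, NodeRewriter].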
Before:
if isinstance(node_rewriter, str):
    def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
        return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)
    return register
Error:
pytensor\graph\rewriting\utils.py:251: error: Argument 1 to "register_canonicalize" has incompatible type "RewriteDatabase | Rewriter"; expected "RewriteDatabase | NodeRewriter | str"
After:
if isinstance(node_rewriter, str):
    def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
        return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)
    return register
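To make the two call forms and the typing relationship concrete, here is a minimal, self-contained sketch. It is not the actual PyTensor code. Rewriter, NodeRewriter and RewriteDatabase are empty stand-in classes, assuming (as the error message implies) that NodeRewriter is a subclass of Rewriter. REGISTRY is a hypothetical list standing in for the real canonicalize database.

```python
from typing import List, Optional, Tuple, Union


# Hypothetical stand-ins for PyTensor's classes. Only the subclass
# relationship matters here: a NodeRewriter is one kind of Rewriter.
class Rewriter:
    pass


class NodeRewriter(Rewriter):
    pass


class RewriteDatabase:
    pass


# Hypothetical registry of (name, rewriter, tags) entries, standing in
# for the real canonicalize database.
REGISTRY: List[
    Tuple[Optional[str], Union[RewriteDatabase, NodeRewriter], Tuple[str, ...]]
] = []


def register_canonicalize(
    node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
    if isinstance(node_rewriter, str):
        # Tag-first form, e.g. @register_canonicalize("stabilize"): return a
        # decorator that re-enters with the real rewriter, passing the string
        # on as the first tag. Annotating inner_rewriter as Rewriter would be
        # rejected by the type checker, since the recursive call does not
        # accept an arbitrary Rewriter.
        def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
            return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register

    # Direct form: record the rewriter and hand it back unchanged.
    name = kwargs.pop("name", None) or getattr(node_rewriter, "__name__", None)
    REGISTRY.append((name, node_rewriter, tags))
    return node_rewriter


class MyRewriter(NodeRewriter):
    pass


rw = MyRewriter()
register_canonicalize(rw, "fast_run", name="my_rewrite")  # direct form
register_canonicalize("stabilize", "fast_run")(rw)  # tag-first form
print(REGISTRY[0][0], REGISTRY[0][2])  # my_rewrite ('fast_run',)
print(REGISTRY[1][0], REGISTRY[1][2])  # None ('stabilize', 'fast_run')
```

In this sketch the second entry has no name because an instance of an ordinary class has no __name__ attribute. Without the getattr() fallback, the tag-first call would raise AttributeError.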
3. AttributeError Fix:
In register_canonicalize, the type checker reported errors because node_rewriter (typed Union[RewriteDatabase, NodeRewriter, str]) is not guaranteed to have a __name__ attribute. Accessing it directly on such an object would raise AttributeError at runtime.
This was resolved by using getattr() with a default of None, so that name falls back to None when __name__ is not present. (Implementing a more descriptive default name could be considered.)
Before:
name = kwargs.pop("name", None) or node_rewriter.__name__
Error:
pytensor\graph\rewriting\utils.py:255: error: Item "RewriteDatabase" of "RewriteDatabase | NodeRewriter" has no attribute "__name__"
pytensor\graph\rewriting\utils.py:255: error: Item "NodeRewriter" of "RewriteDatabase | NodeRewriter" has no attribute "__name__"
After:
name = kwargs.pop("name", None) or getattr(node_rewriter, "__name__", None)
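The effect of the getattr() change can be checked in isolation with the small sketch below. rewriter_name is a hypothetical helper that mirrors the new line. rewriter_name_with_default is one possible take on the default-name idea: it falls back to the class name. It is an assumption for illustration, not part of this PR.

```python
from typing import Optional


def rewriter_name(node_rewriter: object, **kwargs) -> Optional[str]:
    # Mirrors the new line: an explicit name wins, then __name__ if present,
    # otherwise None instead of raising AttributeError.
    return kwargs.pop("name", None) or getattr(node_rewriter, "__name__", None)


def rewriter_name_with_default(node_rewriter: object, **kwargs) -> str:
    # Hypothetical "default name" variant: fall back to the class name,
    # which every object has via type(obj).__name__.
    return rewriter_name(node_rewriter, **kwargs) or type(node_rewriter).__name__


def local_example_rewrite(fgraph, node):  # plain function: has __name__
    return None


class ExampleDatabase:  # instances of ordinary classes have no __name__
    pass


db = ExampleDatabase()
print(rewriter_name(local_example_rewrite))  # local_example_rewrite
print(rewriter_name(db))  # None
print(rewriter_name(db, name="my_db"))  # my_db
print(rewriter_name_with_default(db))  # ExampleDatabase
```

One trade-off of the class-name fallback is that every unnamed instance of the same class would get the same name. That may not be unique enough for a database that keys its entries by name.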
Related Issue
- [x] Closes #323
- [ ] Related to #
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [x] Refactor
Hi @ricardoV94,
I've moved register_canonicalize() from tensor.rewriting.basic to graph.rewriting.utils. While doing so, I ran into a couple of type-checking errors in the function, which I've attempted to fix in this PR. Could you please review and confirm whether the changes are valid?
Thanks!