Move register_canonicalize to graph.rewriting.utils, adjust function signature, and enhance AttributeError handling

Open OmGhadge opened this issue 1 year ago • 1 comment

Description

1. Function Relocation:

The register_canonicalize function has been moved from tensor.rewriting.basic to graph.rewriting.utils, together with the imports it needs to work in the new location. During the move, two issues in the implementation surfaced. I have tried to fix both, and they are described below.

2. Type Mismatch Error Resolution:

The signature of the inner register function in register_canonicalize caused a type mismatch error. Its parameter was annotated as Union[RewriteDatabase, Rewriter], which is broader than the Union[RewriteDatabase, NodeRewriter, str] that register_canonicalize itself accepts. This was addressed by changing the annotation to Union[RewriteDatabase, NodeRewriter].

Before:

    if isinstance(node_rewriter, str):

        def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
            return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register

Error:

pytensor\graph\rewriting\utils.py:251: error: Argument 1 to "register_canonicalize" has incompatible type "RewriteDatabase | Rewriter"; expected "RewriteDatabase | NodeRewriter | str"

After:

    if isinstance(node_rewriter, str):

        def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
            return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register

3. AttributeError Fix:

In register_canonicalize, there were errors about a missing __name__ attribute on node_rewriter, which is typed as Union[RewriteDatabase, NodeRewriter, str]. This was resolved by using getattr() to handle cases where __name__ is not present. When the attribute is not available, name is None. (We could consider implementing a default name.)

Before:

    name = kwargs.pop("name", None) or node_rewriter.__name__

Error:

pytensor\graph\rewriting\utils.py:255: error: Item "RewriteDatabase" of "RewriteDatabase | NodeRewriter" has no attribute "__name__"
pytensor\graph\rewriting\utils.py:255: error: Item "NodeRewriter" of "RewriteDatabase | NodeRewriter" has no attribute "__name__"

After:

    name = kwargs.pop("name", None) or getattr(node_rewriter, "__name__", None)
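
The behaviour of the getattr() fallback can be checked in isolation. The sketch below uses hypothetical helper names (`resolve_name`, `resolve_name_with_default`) and stand-in classes that are not part of PyTensor. The second helper shows one possible default name, the class name of the rewriter, purely as an assumption about what such a default could look like.

```python
from typing import Any, Optional


class NamedRewriter:
    """Stand-in for a rewriter that defines __name__ (hypothetical)."""

    def __init__(self, name: str) -> None:
        self.__name__ = name


class UnnamedRewriter:
    """Stand-in for a rewriter whose instances have no __name__ (hypothetical)."""


def resolve_name(rewriter: Any, **kwargs: Any) -> Optional[str]:
    # Same expression as in the fix: an explicit name keyword wins,
    # otherwise fall back to __name__ if present, else None.
    return kwargs.pop("name", None) or getattr(rewriter, "__name__", None)


def resolve_name_with_default(rewriter: Any, **kwargs: Any) -> str:
    # One possible default (an assumption, not part of this PR):
    # use the class name when nothing else is available.
    return resolve_name(rewriter, **kwargs) or type(rewriter).__name__


print(resolve_name(NamedRewriter("local_foo")))
print(resolve_name(UnnamedRewriter()))
print(resolve_name(UnnamedRewriter(), name="explicit"))
print(resolve_name_with_default(UnnamedRewriter()))
```

With the plain getattr() version, a rewriter without __name__ and without an explicit name keyword is registered under None, so whichever code consumes the name has to be able to handle that case.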

Related Issue

  • [x] Closes #323
  • [ ] Related to #

Checklist

Type of change

  • [ ] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [x] Refactor

OmGhadge • Jan 23 '24 17:01

Hi @ricardoV94 ,

I've moved register_canonicalize() from tensor.rewriting.basic to graph.rewriting.utils. While doing so, I ran into some errors in the function's implementation, which I've tried to address in this PR. Could you please review the changes and confirm whether they are valid?

Thanks!

OmGhadge • Jan 31 '24 13:01