flax
`__getattr__` and pytype checking
Currently `nn.Module` implements `__getattr__`. This disables type checking of attribute access, so a typo in an attribute name goes undetected:
my_module = MyModule()
my_module.non_existing_attribute  # << pytype does not detect this
It would be nice to wrap `__getattr__` in an `if not typing.TYPE_CHECKING:` block so that pytype actually checks the attributes.
Note that `class Module` would have to be updated so that the `Module` fields are declared (doing this today would also allow type checking on `module.scope`, ...):
class Module:
    scope: Optional[Scope]
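To make the proposal concrete, here is a minimal sketch on a toy class. `DynamicModule` and its `_dynamic` dict are hypothetical stand-ins, not flax's actual `Module`; the only point is how the guard and the value-less annotation combine.

```python
import typing
from typing import Any, Dict


class DynamicModule:
    """Toy class with a catch-all __getattr__ (hypothetical, not flax's Module)."""

    # Value-less annotation: tells type checkers the attribute exists,
    # but creates no class attribute at runtime.
    scope: Any

    def __init__(self) -> None:
        # Dynamic attributes live in a private dict and are served by __getattr__.
        self._dynamic: Dict[str, Any] = {"scope": None}

    # Type checkers treat TYPE_CHECKING as True, so they never see this method
    # and can report unknown attributes. At runtime TYPE_CHECKING is False and
    # the method is defined as usual.
    if not typing.TYPE_CHECKING:

        def __getattr__(self, name: str) -> Any:
            dynamic = self.__dict__["_dynamic"]
            if name in dynamic:
                return dynamic[name]
            raise AttributeError(name)


module = DynamicModule()
scope_value = module.scope  # resolved through __getattr__ at runtime
```

With this layout a checker should accept `module.scope` because of the annotation and flag `module.non_existing_attribute`, while runtime behaviour is unchanged.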
I like the idea! Want to send a PR?
> Note that `class Module` would have to be updated so that the `Module` fields are defined (note doing this today would also allow type checking on `module.scope`, ...):
There is a type hint that enables static analysis for `Module.scope` inside `Module.__init_subclass__`:
https://github.com/google/flax/blob/90c08ca5fcb08b6e7c8cfb256f21df8d3d4ab858/flax/linen/module.py#L545
I don't think we want to add type annotations at the class level, as this would interfere with the dataclass behaviour.
> I don't think we want to add type annotations at the class level, as this would interfere with the dataclass behaviour.
When trying the naive fix, pytype complains:
module = MyModule()
module.scope # pytype: attribute-error
So it looks like `scope` is somehow not correctly inferred inside `__init_subclass__`, which is why I suggested making the annotation explicit.
If pytype is correctly implemented, I don't think the annotation should interfere with `dataclasses`, as `Module` is a `@dataclass_transform` but not a `@dataclass`.
Unfortunately I won't have time for a PR. Also, submitting this might break existing users, so I think the migration should be taken care of by the flax team (note that pytype has a tool to automatically silence all existing pytype breakages of a CL by adding `# pytype: disable=` comments in the relevant code).
We currently aren't enforcing pytype, which is something we should probably do!
As for this change, I think adding explicit dataclass attributes to Module could lead to problems, since the way we set up Module attributes is a bit complex.
@levskaya: you probably have the most context here, do you think it is safe to continue with this change?
Just for clarification, I'm not suggesting making `nn.Module` a dataclass, just annotating the dynamic attributes:
class A:  # < Notice there is NO `@dataclass`
    x: int
This has NO runtime effect (`A().x` will raise an `AttributeError` unless `x` is initialised in `__init__` or somewhere else). However, it is detected by type checkers. This is the official syntax to annotate attributes which are dynamically defined (e.g. through `setattr(self, 'x', 123)`).
And because `A` is not a dataclass, this has no effect on subclasses either, even if the subclasses are dataclasses.
See: https://peps.python.org/pep-0526/#class-and-instance-variable-annotations
> the value-less notation `a: int` allows one to annotate instance variables that should be initialized in `__init__` or `__new__`.
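The claims above can be checked at runtime with plain `dataclasses`. This is only a small sketch; the names `has_x_before` and `field_names` are introduced here for illustration.

```python
import dataclasses


class A:  # plain class, no @dataclass
    x: int  # annotation only: no class attribute is created


@dataclasses.dataclass
class B(A):
    y: int


a = A()
has_x_before = hasattr(a, "x")  # the annotation alone does not define `x`
setattr(a, "x", 123)            # dynamic definition, as in the comment above

b = B(y=1)  # `x` is not part of B.__init__ because A is not a dataclass
field_names = [f.name for f in dataclasses.fields(B)]
```

The annotation is recorded in `A.__annotations__` for type checkers and other tooling, but it does not become a dataclass field of `B`.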
Under the hood, `nn.Module` is already a kind of dataclass (which is why you can define attributes on it like in a dataclass). Annotating the dynamic attributes as you suggest will likely have an effect on when these attributes are set on the Module, and that may or may not cause problems in our Module initialization logic. I am just not sure if this is safe to do.
My suggestion would be to just try adding these attributes and see if it works, but I suppose that @levskaya can probably tell us already up front whether this is a good idea or not.
Related issue: #1816
Tested adding annotations for `scope` and `name` in #2447, see diff.
The good news is that "it works" for Python: Flax successfully reorders / removes these from the constructor, and all tests pass. The bad news is that static analyzers are not happy about this, as they now expect additional constructor arguments from the parent class, e.g. `pylance`:
Curiously, there are no errors from our `pytype` checks; not sure if it's our configuration or if pytype behaves differently.
This seems to be a pylance bug I think.
Does the following also raise an error?
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
class A():
    x: int

class B(A):
    y: int

B(y=1)
Edit: I just tried it, and with `(kw_only_default=True)` it seems `B.__init__` expects `x`, which is a bug.
Only `B` should be dataclass-like, not `A`, so only `B.__annotations__` should be added to `__init__`. This is similar to:
import dataclasses

class A():
    x: int

@dataclasses.dataclass
class B(A):
    y: int

B(y=1)  # A is not a `dataclass`, so `x` not in `__init__`
A hack which seems to work with pylance is:
import typing_extensions

class _A:
    x: int

@typing_extensions.dataclass_transform(kw_only_default=True)
class A(_A):
    pass

class B(A):
    y: int

B(y=1)  # pylance detects `B(y=)` (no more `x=`)
Thanks @Conchylicultor for posting this in the pylance repo! I don't have any insight into how pylance operates, but I wonder if it can handle Flax-like use cases where the class annotations are programmatically modified before creating the dataclass. Take a look at the `_customized_dataclass_transform` classmethod that is called inside `__init_subclass__`:
https://github.com/google/flax/blob/fdd1d6fef0dfea785a10b1f5ebd1635cc2509c2e/flax/linen/module.py#L588-L638
@levskaya might be able to give more insights into this when he gets back.
Current situation is:
- If we add the `typing.TYPE_CHECKING` guard around `__getattr__`, then pytype will be unhappy about `scope`, `name`, and `parent`.
- If we add `scope`, `name`, and `parent` as type annotations to make pytype happy, then IDEs (pylance) will be unhappy (they think constructors expect these) and the user experience will be a nightmare.
- If we add `scope`, `name`, and `parent` with default values to fix the previous point, then non-default fields on any subclass will result in pylance complaining with `Fields without default values cannot appear after fields with default values`.
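The last point has a runtime counterpart in the standard `dataclasses` module, which may help explain why pylance reports it. The sketch below uses hypothetical `Parent` and `Child` classes, not Flax's `Module`.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Parent:
    name: Optional[str] = None  # field with a default value


def define_child() -> type:
    @dataclasses.dataclass
    class Child(Parent):
        # No default, but it would follow `name` in the generated __init__.
        features: int

    return Child


try:
    define_child()
    error_message = ""
except TypeError as exc:
    # dataclasses refuses to build an __init__ with this field ordering.
    error_message = str(exc)
```

Here the class definition itself fails with a `TypeError`, which is the same ordering rule pylance enforces statically.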
Maybe we just keep it as is? Not ideal but the fix sounds worse.
Solution (1): What about my suggested solution from: https://github.com/google/flax/issues/2416#issuecomment-1240489780
import typing
from typing import Any

import typing_extensions

class _Module:
    scope: Any
    name: Any
    parent: Any

@typing_extensions.dataclass_transform()
class Module(_Module):
    if not typing.TYPE_CHECKING:
        def __getattr__(self, name):
            ...
This will work with both `pylance` and `pytype`.
> If we add `scope`, `name`, and `parent` with default params to fix the previous then non-default fields on any subclass will result in pylance complaining with `Fields without default values cannot appear after fields with default values`.
Solution (2): This can be fixed if we make the dataclass keyword-only: `dataclass_transform(kw_only_default=True)`.
But otherwise, solution (1) would work without this change, and pylance would correctly infer `Module.__init__`.
As an update, it seems this has been fixed on the pylance side: https://github.com/microsoft/pylance-release/issues/3304#issuecomment-1270549408
So when the new version is released and enough people have upgraded to the new version, the hack from https://github.com/google/flax/issues/2416#issuecomment-1252042926 could be removed.