equinox
Problems with access to class variables in __init_subclass__ during instance creation for equinox Modules
Since Equinox version 0.11.6, running the following code raises a strange AttributeError:
```python
import jax
import equinox as eqx
class Parent(eqx.Module):
    abs_cls_var: eqx.AbstractClassVar[str]
    def __init__(self, **kwargs):
        pass
    def __init_subclass__(cls):
        """__init_subclass__
        tries to access cls.abs_cls_var
        """
        print(cls.abs_cls_var)


class Child(Parent):
    abs_cls_var = 'w0'

Child()  # raises AttributeError on Equinox 0.11.6
```
The AttributeError is raised on the last line (the creation of an instance of Child). The stack trace in Google Colab is:
```text
AttributeError Traceback (most recent call last)
<ipython-input-7-60f7e23927a5> in <cell line: 1>()
----> 1 Child()
[... skipping hidden 4 frame]
/usr/local/lib/python3.10/dist-packages/equinox/_module.py in _make_initable_wrapper(cls)
788 def _make_initable_wrapper(cls: _ActualModuleMeta) -> _ActualModuleMeta:
789 post_init = getattr(cls, "__post_init__", None)
--> 790 return _make_initable(cls, cls.__init__, post_init, wraps=False)
791
792
/usr/local/lib/python3.10/dist-packages/equinox/_module.py in _make_initable(***failed resolving arguments***)
806 field_names = {field.name for field in dataclasses.fields(cls)}
807
--> 808 class _InitableModule(cls, _Initable):
809 pass
810
/usr/local/lib/python3.10/dist-packages/equinox/_module.py in __new__(mcs, name, bases, dict_, strict, **kwargs)
200
201 # [Step 1] Create the class as normal.
--> 202 cls = super().__new__(mcs, name, bases, dict_, **kwargs)
203 # [Step 2] Arrange for bound methods to be treated as PyTrees as well. This
204 # ensures that
/usr/local/lib/python3.10/dist-packages/equinox/_better_abstract.py in __new__(mcs, name, bases, namespace, **kwargs)
176
177 def __new__(mcs, name, bases, namespace, /, **kwargs):
--> 178 cls = super().__new__(mcs, name, bases, namespace, **kwargs)
179
180 # We don't try and check that our AbstractVars and AbstractClassVars are
/usr/lib/python3.10/abc.py in __new__(mcls, name, bases, namespace, **kwargs)
104 """
105 def __new__(mcls, name, bases, namespace, **kwargs):
--> 106 cls = super().__new__(mcls, name, bases, namespace, **kwargs)
107 _abc_init(cls)
108 return cls
<ipython-input-6-b018fe5e563f> in __init_subclass__(cls)
9 tries to access cls.abs_cls_var
10 """
---> 11 print(cls.abs_cls_var)
12
13
/usr/local/lib/python3.10/dist-packages/equinox/_module.py in __getattribute__(cls, item)
610 # `module_update_wrapper`, but if `dataclass` sees it then it tries to follow it.
611 def __getattribute__(cls, item):
--> 612 value = super().__getattribute__(item)
613 if (
614 item == "__wrapped__"
AttributeError: type object '_InitableModule' has no attribute 'abs_cls_var'
```
This problem did not occur in Equinox version 0.11.5.
It also does not occur if the __init__ method in Parent is commented out.
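The traceback above shows the mechanism involved. When Child() is called, Equinox builds a new subclass on the fly (`class _InitableModule(cls, _Initable)` inside `_make_initable`), and creating any subclass runs the inherited `Parent.__init_subclass__` again, this time with `cls` set to the generated class. The plain-Python sketch below reproduces only that mechanism and does not use Equinox. The metaclass `Meta` and the generated class name `_Wrapper` are hypothetical stand-ins, not Equinox's actual implementation.

```python
# Plain-Python sketch (no Equinox). Meta is a hypothetical metaclass that, like
# the traceback suggests Equinox does, creates a throwaway subclass on every
# instantiation. Creating that subclass re-runs __init_subclass__.
calls = []


class Meta(type):
    def __call__(cls, *args, **kwargs):
        # Build a fresh subclass of `cls`; this triggers __init_subclass__.
        wrapper = type(cls)("_Wrapper", (cls,), {})
        # Instantiate the subclass directly, bypassing Meta.__call__ (no recursion).
        return type.__call__(wrapper, *args, **kwargs)


class Parent(metaclass=Meta):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Record every class this hook sees, including generated ones.
        calls.append(cls.__name__)


class Child(Parent):
    abs_cls_var = "w0"


print(calls)  # ['Child']
obj = Child()
print(calls)  # ['Child', '_Wrapper']
print(type(obj).__name__)  # _Wrapper
```

In this sketch a lookup of `cls.abs_cls_var` inside the hook would succeed, because `_Wrapper` inherits it from `Child`; in the report, the equivalent lookup on Equinox 0.11.6's generated `_InitableModule` raised the AttributeError instead. Either way, an `__init_subclass__` that assumes it runs exactly once per user-written subclass may also see library-generated classes.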
Thanks for the report! I've just pushed a commit, so this should be fixed on the latest HEAD.
Wow, that was quick! Thanks for fixing this so swiftly!