Support for Python 3.14
Fixes https://github.com/google/flax/issues/5027
- Removed the contents of linen/kw_only_dataclasses.py, keeping its methods and attributes for backward compatibility, with field and dataclass linked to Python's built-in dataclasses
- Removed usage of kw_only_dataclasses in linen/module.py
- Adapted certain failing tests
Here is a repro of the problem we see with this change internally. It's kind of wild, and I don't know what's going on here:
import abc
from collections.abc import Iterator, Iterable
import io
from typing import Protocol, TypeVar

T = TypeVar("T")

class CheckpointableIterator(Iterator[T], Protocol[T]):
  pass

isinstance(io.TextIOBase, Iterable)

from flax import linen as nn

class Steppable(metaclass=abc.ABCMeta):
  path: str

class SequenceLayer(nn.Module, Steppable):
  pass
I ran this under Python 3.12.
With the changes in this CL, I get:
Traceback (most recent call last):
  File "/usr/local/google/home/phawkins/p/flax/t.py", line 21, in <module>
    class SequenceLayer(nn.Module, Steppable):
  File "<frozen abc>", line 106, in __new__
  File "/usr/local/google/home/phawkins/p/flax/flax/linen/module.py", line 1040, in __init_subclass__
    cls._customized_dataclass_transform(kw_only)
  File "/usr/local/google/home/phawkins/p/flax/flax/linen/module.py", line 1100, in _customized_dataclass_transform
    dataclasses.dataclass( # type: ignore[call-overload]
  File "/usr/local/google/home/phawkins/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/dataclasses.py", line 1265, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/phawkins/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/dataclasses.py", line 1018, in _process_class
    raise TypeError(f'{name!r} is a field but has no type annotation')
TypeError: 'parent' is a field but has no type annotation
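For context, the final TypeError is the standard-library check that fires when a class attribute holds a dataclasses.Field but the annotations dataclasses looks up for that class don't list the name. A minimal sketch that triggers the same message without Flax (the Broken class is made up for illustration):

```python
import dataclasses


class Broken:
  # A Field object with no matching entry in the class's annotations.
  parent = dataclasses.field(default=None)


try:
  dataclasses.dataclass(Broken)
except TypeError as err:
  message = str(err)
else:
  message = ""
```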
I still don't know why this works, but adding:
setattr(cls, '__annotations__', cls.__annotations__)
right before the:
dataclasses.dataclass(  # type: ignore[call-overload]
    unsafe_hash='__hash__' not in cls.__dict__,
    repr=False,
    kw_only=kw_only,
)(cls)
call seems to work around the problem.
For some reason, __annotations__ isn't in the class dict, and that makes inspect.get_annotations fail.
@hawkinsp thanks a lot for investigating! I agree it's very wild: every line in the reproducer matters. I see that it fails with 3.12 and 3.13 but passes with 3.11 and 3.14...
Another note: for praxis I found that adding from __future__ import annotations was necessary to get the same behavior with Python 3.14 as with previous versions.
Reported the bug to CPython: https://github.com/python/cpython/issues/141681
This PR is the only thing blocking me from upgrading to Python 3.14... Any chance it could get some love?
TODO:
- Something is failing in praxis on py3.12 (to check)
- Something is failing in paxml on py3.12 (to check)