
Support for Python 3.14

Open • vfdev-5 opened this issue 3 months ago • 7 comments

Fixes https://github.com/google/flax/issues/5027

  • Removed the contents of linen/kw_only_dataclasses.py, keeping its methods and attributes for backward compatibility (BC), with field and dataclass now linked to Python's built-in dataclasses
  • Removed usage of kw_only_dataclasses in linen/module.py
  • Adapted some failing tests

vfdev-5, Nov 13 '25

Here is a repro for the problem we see internally with this change. It's kind of wild, and I don't know what's going on here:

    import abc
    from collections.abc import Iterator, Iterable
    import io
    from typing import Protocol, TypeVar

    T = TypeVar("T")

    class CheckpointableIterator(Iterator[T], Protocol[T]):
      pass

    isinstance(io.TextIOBase, Iterable)

    from flax import linen as nn

    class Steppable(metaclass=abc.ABCMeta):
      path: str

    class SequenceLayer(nn.Module, Steppable):
      pass

I ran this under Python 3.12.

With the changes in this CL, I get:

Traceback (most recent call last):
  File "/usr/local/google/home/phawkins/p/flax/t.py", line 21, in <module>
    class SequenceLayer(nn.Module, Steppable):
  File "<frozen abc>", line 106, in __new__
  File "/usr/local/google/home/phawkins/p/flax/flax/linen/module.py", line 1040, in __init_subclass__
    cls._customized_dataclass_transform(kw_only)
  File "/usr/local/google/home/phawkins/p/flax/flax/linen/module.py", line 1100, in _customized_dataclass_transform
    dataclasses.dataclass(  # type: ignore[call-overload]
  File "/usr/local/google/home/phawkins/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/dataclasses.py", line 1265, in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/phawkins/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/lib/python3.12/dataclasses.py", line 1018, in _process_class
    raise TypeError(f'{name!r} is a field but has no type annotation')
TypeError: 'parent' is a field but has no type annotation
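For context, dataclasses raises this TypeError when the class's own __dict__ holds a dataclasses.field(...) object whose name is missing from the annotations it collected for that class. A minimal sketch that triggers the same message outside flax (the class Example is hypothetical):

```python
import dataclasses


class Example:  # hypothetical class, not taken from flax
    size: int = 3                             # annotated: becomes a field
    parent = dataclasses.field(default=None)  # a Field object, but not annotated


try:
    dataclasses.dataclass(Example)
    message = None
except TypeError as exc:
    message = str(exc)

print(message)  # 'parent' is a field but has no type annotation
```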

hawkinsp, Nov 14 '25

I still don't know why this works, but adding:

    setattr(cls, '__annotations__', cls.__annotations__)

right before this call in _customized_dataclass_transform:

    dataclasses.dataclass(  # type: ignore[call-overload]
      unsafe_hash='__hash__' not in cls.__dict__,
      repr=False,
      kw_only=kw_only,
    )(cls)

seems to work around the problem.

For some reason, __annotations__ isn't in the class __dict__, so inspect.get_annotations (which the Python 3.12 dataclasses module calls to collect the class's annotations) finds no annotations for it.

hawkinsp, Nov 15 '25

@hawkinsp thanks a lot for the investigation! I agree it's very wild: every line in the reproducer matters. I see that it fails with 3.12 and 3.13 but passes with 3.11 and 3.14.

vfdev-5, Nov 15 '25

Another note: for praxis I found that adding from __future__ import annotations was necessary to get the same behavior with Python 3.14 as with previous versions.

hawkinsp, Nov 17 '25

Reported the bug on cpython: https://github.com/python/cpython/issues/141681

vfdev-5, Nov 17 '25

This PR is the only thing blocking me from upgrading to Python 3.14... Any chance it could get some love?

cool-RR, Nov 26 '25

TODO:

  • Something is failing in praxis on py3.12 (to check)
  • Something is failing in paxml on py3.12 (to check)

vfdev-5, Dec 01 '25