
Remove `tree_map` deprecation filter after Flax upgrades minimum Python version to 3.10

Status: Open • chiamp opened this issue 10 months ago • 0 comments

Context:

  • As of JAX 0.4.26, `jax.tree_map` is deprecated
  • #3823 renames all `jax.tree_map` usages in Flax to `jax.tree_util.tree_map`; however, CI still fails because the CLU dependency also calls `jax.tree_map`
  • Even after CLU was fixed and a new CLU release was pushed, the error persists in the CI tests for Python 3.9
    • this is because, on Python versions below 3.10, Flax pins an earlier CLU release that predates the `tree_map` fix; the fixed CLU release uses match-case syntax (structural pattern matching), which is only available in Python 3.10 and later
    • i.e. as long as Flax supports Python 3.9, it cannot use the fixed CLU release on that version
  • Our current workaround is the deprecation warning filter added in #3828
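To illustrate the syntax constraint described above, the sketch below checks whether the running interpreter can compile a `match` statement. The source is compiled from a string so that this file itself still loads on Python 3.9. `MATCH_SOURCE` and `supports_match_case` are hypothetical names for this example only, not part of Flax or CLU.

```python
import sys

# A minimal structural pattern matching statement (PEP 634), the kind of
# syntax the issue says newer CLU releases use.
MATCH_SOURCE = """
def describe(x):
    match x:
        case 0:
            return "zero"
        case _:
            return "other"
"""


def supports_match_case() -> bool:
    """Return True if this interpreter can compile match-case syntax."""
    try:
        compile(MATCH_SOURCE, "<match-check>", "exec")
    except SyntaxError:
        # Python 3.9 and earlier do not recognize `match` as a statement.
        return False
    return True


if __name__ == "__main__":
    print(f"Python {sys.version_info[:2]}: match-case supported -> {supports_match_case()}")
```

On Python 3.9 `supports_match_case()` returns False, which is why a module written with this syntax cannot even be imported there, and why an older CLU release has to be pinned for that version.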

Once Flax raises its minimum supported Python version to 3.10, we should remove the deprecation warning filter and stop pinning an earlier CLU version.
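A minimal sketch of a version-gated warning filter, assuming the JAX deprecation message begins with `jax.tree_map is deprecated` (the exact JAX wording, and the actual filter added in #3828, may differ). Gating the filter on `sys.version_info` keeps it active only on Python 3.9, where the older CLU release is used, so the cleanup described above amounts to deleting one function. `install_tree_map_filter`, `legacy_call_raises`, and `_legacy_tree_map_call` are hypothetical names; `_legacy_tree_map_call` merely stands in for an old CLU call into `jax.tree_map` so the example runs without JAX installed.

```python
import sys
import warnings

# Assumption: JAX's deprecation message starts with this text. The `message`
# argument is a regex matched against the start of the warning message.
TREE_MAP_MESSAGE = r"jax\.tree_map is deprecated"


def install_tree_map_filter() -> bool:
    """Ignore the jax.tree_map DeprecationWarning on Python < 3.10 only.

    Returns True if the filter was installed. Delete this function once the
    minimum supported Python version is 3.10.
    """
    if sys.version_info < (3, 10):
        warnings.filterwarnings(
            "ignore", message=TREE_MAP_MESSAGE, category=DeprecationWarning
        )
        return True
    return False


def _legacy_tree_map_call() -> None:
    # Hypothetical stand-in for an old CLU release calling jax.tree_map;
    # the message text here is illustrative, not JAX's exact wording.
    warnings.warn(
        "jax.tree_map is deprecated: use jax.tree_util.tree_map instead.",
        DeprecationWarning,
        stacklevel=2,
    )


def legacy_call_raises() -> bool:
    """Simulate CI, where warnings are errors; report whether the call fails."""
    with warnings.catch_warnings():  # restores the filter list on exit
        warnings.simplefilter("error")  # CI-style: every warning raises
        install_tree_map_filter()  # inserted at the front, so it takes precedence
        try:
            _legacy_tree_map_call()
        except DeprecationWarning:
            return True
        return False


if __name__ == "__main__":
    print(f"Python {sys.version_info[:2]}: legacy call fails CI -> {legacy_call_raises()}")
```

In a pytest-based setup, the same rule could instead be expressed as an `ignore:<message regex>:DeprecationWarning` entry in the `filterwarnings` configuration. Either way, the message pattern should stay narrow so that unrelated deprecations still fail CI.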

chiamp, Apr 10 '24 23:04