Remove `tree_map` deprecation filter after Flax upgrades minimum Python version to 3.10
Context:
- As of JAX 0.4.26, `jax.tree_map` is deprecated. #3823 renames all `jax.tree_map` usages to `jax.tree_util.tree_map` in Flax; however, we still get an error in CI because of the CLU dependency.
- After fixing CLU and pushing a new release, the error remains on CI for the Python 3.9 tests.
  - This is because Flax pins an earlier version of CLU (from before the `tree_map` fix) on Python versions below 3.10. The current CLU release, which contains the fix, uses match-case syntax, which is only available in Python 3.10 or later. Since Flax's minimum supported Python version is 3.9, Flax must use the older CLU version there, where the fix has not landed.
- Our current solution is to add a deprecation warning filter in #3828.
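The actual filter from #3828 is not reproduced here. As a minimal sketch of how a targeted deprecation-warning filter behaves, the example below uses only the standard-library `warnings` module and assumes a hypothetical `legacy_tree_map` function standing in for the deprecated `jax.tree_map` call made inside the older CLU version. In a pytest suite, the equivalent is usually a `filterwarnings` entry in the test configuration.

```python
import warnings
from typing import Any, Callable


def legacy_tree_map(f: Callable[[Any], Any], tree: dict) -> dict:
    """Hypothetical stand-in for a deprecated API such as `jax.tree_map`.

    Maps `f` over the values of a flat dict and emits a DeprecationWarning,
    mimicking what an outdated dependency would trigger at call time.
    """
    warnings.warn(
        "legacy_tree_map is deprecated; use tree_util.tree_map instead",
        DeprecationWarning,
        stacklevel=2,
    )
    return {k: f(v) for k, v in tree.items()}


def call_without_filter(tree: dict) -> tuple[dict, int]:
    """Call the deprecated API and count how many warnings were emitted."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning
        result = legacy_tree_map(lambda x: x * 2, tree)
    return result, len(caught)


def call_with_filter(tree: dict) -> tuple[dict, int]:
    """Call the deprecated API with a targeted ignore filter installed.

    Returns the result and the number of warnings that escaped the filter.
    """
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record everything else
        # filterwarnings inserts at the front of the filter list, so this
        # narrower rule wins. `message` is a regex matched against the start
        # of the warning text; only this specific deprecation is silenced.
        warnings.filterwarnings(
            "ignore",
            message=r"legacy_tree_map is deprecated",
            category=DeprecationWarning,
        )
        result = legacy_tree_map(lambda x: x * 2, tree)
    return result, len(caught)
```

Keeping the filter narrow (matching on both category and message prefix) means other deprecations still surface in CI. Removing it later amounts to deleting the single `filterwarnings` rule.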
Once Flax upgrades its minimum Python version to 3.10, we should remove the deprecation warning filter and stop pinning the earlier version of CLU.