TPU Colab fails with `AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'`
On a fresh Colab TPU runtime, we have:
>>> import flax
AttributeError Traceback (most recent call last)
<ipython-input-2-f5b294e0faf0> in <cell line: 2>()
1 # Verify we can import everything.
----> 2 import flax
3 from flax.training import (checkpoints, dynamic_scale, early_stopping, lr_schedule,
4 orbax_utils, prefetch_iterator, train_state, common_utils)
5 from flax.metrics import tensorboard
2 frames
/usr/local/lib/python3.10/dist-packages/flax/core/frozen_dict.py in <module>
48
49
---> 50 @jax.tree_util.register_pytree_with_keys_class
51 class FrozenDict(Mapping[K, V]):
52 """An immutable variant of the Python dict."""
AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
Note the versions:
!pip freeze | egrep 'jax|flax'
flax==0.7.0
jax==0.3.25
jaxlib==0.3.25
This is because Colab TPU runtimes are pinned to an old jax==0.3.25 (the TPU runtime only supports the legacy "TPU Node" setup).
Flax has required jax>0.3.25 for a while, but recent changes make the incompatibility surface on a plain `import flax`: `flax/core/frozen_dict.py` now uses `jax.tree_util.register_pytree_with_keys_class`, which jax 0.3.25 does not provide.
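To confirm this on a given runtime without triggering the failing import, one can probe for the attribute directly. The sketch below is only an illustration: `module_has_attr` is a hypothetical helper name, not part of jax or Flax, and it returns `None` when the module cannot be imported at all.

```python
import importlib


def module_has_attr(module_name, attr_name):
    """Return True/False if the module imports, or None if it is not importable.

    Hypothetical helper for illustration; not part of jax or flax.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr_name)


if __name__ == "__main__":
    # On the runtime described above (jax==0.3.25) this would print False,
    # matching the AttributeError in the traceback.
    print(module_has_attr("jax.tree_util", "register_pytree_with_keys_class"))
```

A `False` result means the installed jax predates the API that this Flax version relies on, so `import flax` will fail as shown above.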
Ideally we should also pin flax==0.6.4 (the most recent version that supports jax<=0.3.25) on TPU runtimes.
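The pinning rule can be written down as a small check. This is a minimal sketch, not how Colab actually implements the pin: `parse_version` and `flax_pin_for` are hypothetical names, the 0.3.25 / 0.6.4 pairing is taken from this report, and the parser assumes plain numeric `X.Y.Z` version strings.

```python
# Version pairing taken from this report: flax 0.6.4 is described as the most
# recent release that supports jax<=0.3.25. Names below are hypothetical.
LEGACY_JAX = (0, 3, 25)
LEGACY_FLAX_PIN = "flax==0.6.4"


def parse_version(version):
    """Turn a plain 'X.Y.Z' string into a tuple of ints (assumes numeric parts)."""
    return tuple(int(part) for part in version.split(".")[:3])


def flax_pin_for(jax_version):
    """Return a pip requirement for Flax on legacy jax, or None otherwise."""
    if parse_version(jax_version) <= LEGACY_JAX:
        return LEGACY_FLAX_PIN
    return None


print(flax_pin_for("0.3.25"))  # flax==0.6.4
print(flax_pin_for("0.4.13"))  # None
```

Returning `None` here only means this particular legacy pin does not apply; newer jax releases may still have their own Flax version constraints.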
This was resolved by also pinning the Flax version on TPU runtimes; we now have:
>>> import jax
>>> jax.__version__
'0.3.25'
>>> import flax
>>> flax.__version__
'0.6.4'