
TPU Colab fails with `AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'`

Open · andsteing opened this issue 2 years ago · 1 comment

On a fresh Colab TPU runtime, we have:

>>> import flax
AttributeError                            Traceback (most recent call last)
<ipython-input-2-f5b294e0faf0> in <cell line: 2>()
      1 # Verify we can import everything.
----> 2 import flax
      3 from flax.training import (checkpoints, dynamic_scale, early_stopping, lr_schedule,
      4                            orbax_utils, prefetch_iterator, train_state, common_utils)
      5 from flax.metrics import tensorboard

2 frames
/usr/local/lib/python3.10/dist-packages/flax/core/frozen_dict.py in <module>
     48 
     49 
---> 50 @jax.tree_util.register_pytree_with_keys_class
     51 class FrozenDict(Mapping[K, V]):
     52   """An immutable variant of the Python dict."""

AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
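To check in advance whether the installed jax provides the decorator that `flax/core/frozen_dict.py` uses, one option is plain feature detection. The sketch below is a hypothetical helper (`module_has_attr` is not part of jax or flax) that relies only on the standard library's `importlib`; it returns False instead of raising when the module itself is missing.

```python
import importlib


def module_has_attr(module_name: str, attr_name: str) -> bool:
    """Hypothetical helper: True if `module_name` imports and exposes `attr_name`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False
    return hasattr(module, attr_name)


# Standard-library sanity checks:
print(module_has_attr("json", "dumps"))         # True
print(module_has_attr("json", "no_such_name"))  # False

# On a runtime like the one above (jax==0.3.25) this is expected to be False,
# which is exactly the condition behind the AttributeError on `import flax`:
# module_has_attr("jax.tree_util", "register_pytree_with_keys_class")
```

Feature detection answers "does this symbol exist?" directly, so it keeps working even when you don't know which release introduced the symbol.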

Note the versions:

!pip freeze | egrep 'jax|flax'
flax==0.7.0
jax==0.3.25
jaxlib==0.3.25

andsteing · Jul 31 '23 07:07

This is because Colab TPU runtimes are pinned to an old jax==0.3.25 (the TPU runtime only supports the legacy "TPU Node" setup).

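Flax has required jax>0.3.25 for a while, but recent changes make the incompatibility visible on a plain `import flax`: `FrozenDict` is now registered with `jax.tree_util.register_pytree_with_keys_class`, which jax 0.3.25 does not have.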

Ideally, TPU runtimes should also pin flax==0.6.4 (the most recent version that supports jax<=0.3.25).
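The pinning rule stated above can be sketched as a small check, assuming the only data point is the one from this thread (flax 0.6.4 is the newest release supporting jax<=0.3.25). `flax_pin_for` and `is_known_incompatible` are hypothetical helpers, and the rule says nothing about newer jax releases, so `flax_pin_for` returns None there rather than guessing; for real pins, flax's own declared dependency metadata is the authority.

```python
from __future__ import annotations

# Threshold taken from the discussion above (assumption: no other data points).
LAST_OLD_JAX = (0, 3, 25)          # jax versions <= this count as "old"
LAST_FLAX_FOR_OLD_JAX = (0, 6, 4)  # newest flax that still supports old jax


def as_tuple(version: str) -> tuple[int, ...]:
    """Parse a plain dotted version such as '0.3.25' into (0, 3, 25)."""
    return tuple(int(part) for part in version.split("."))


def flax_pin_for(jax_version: str) -> str | None:
    """Return a pip requirement pinning flax, or None if this rule doesn't apply."""
    if as_tuple(jax_version) <= LAST_OLD_JAX:
        return "flax==" + ".".join(map(str, LAST_FLAX_FOR_OLD_JAX))
    return None


def is_known_incompatible(flax_version: str, jax_version: str) -> bool:
    """True for combinations like flax 0.7.0 + jax 0.3.25 from the report above."""
    return (as_tuple(jax_version) <= LAST_OLD_JAX
            and as_tuple(flax_version) > LAST_FLAX_FOR_OLD_JAX)


print(flax_pin_for("0.3.25"))                    # flax==0.6.4
print(is_known_incompatible("0.7.0", "0.3.25"))  # True
print(is_known_incompatible("0.6.4", "0.3.25"))  # False
```

In a notebook setup cell, the returned string could be passed to `pip install` before importing flax.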

andsteing · Jul 31 '23 07:07

This was resolved by also pinning the Flax version on TPU runtimes; we now have:

>>> import jax
>>> jax.__version__
'0.3.25'
>>> import flax
>>> flax.__version__
'0.6.4'

andsteing · Mar 24 '24 17:03