Andreas Steiner

36 comments by Andreas Steiner

This is due to Colab TPU runtimes being pinned to an old `jax==0.3.25` (because the TPU runtime only supports the legacy "TPU Node" setup). Flax has been requiring `jax>0.3.25` for a...
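A requirement like `jax>0.3.25` is a strict comparison, so a runtime pinned to exactly `0.3.25` can never satisfy it. Below is a minimal stdlib-only sketch of that check, assuming plain numeric release versions. The helper names `release_tuple` and `is_newer_than` are hypothetical, and real tooling should prefer `packaging.version.Version`, which implements the full PEP 440 rules.

```python
import re
from typing import Tuple


# Hypothetical helpers for illustration; not part of JAX or Flax.
def release_tuple(version: str) -> Tuple[int, ...]:
    """Numeric release part of a version, e.g. "0.4.16+cuda11" -> (0, 4, 16)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version: {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    # Drop trailing zeros so that "0.4" and "0.4.0" compare equal.
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def is_newer_than(installed: str, minimum: str) -> bool:
    """Strict ">" check, as in a requirement like jax>0.3.25."""
    return release_tuple(installed) > release_tuple(minimum)


print(is_newer_than("0.3.25", "0.3.25"))  # False: the pinned TPU version is too old
print(is_newer_than("0.4.16", "0.3.25"))  # True
```

Pre-release markers and local labels (such as `+cuda11`) are simply ignored here, which is a simplification of the real rules.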

This was resolved by pinning the Flax version on TPU runtimes as well; we now have:

```
>>> import jax
>>> jax.__version__
'0.3.25'
>>> import flax
>>> flax.__version__
'0.6.4'
```

Copied from https://github.com/google/flax/issues/2950#issuecomment-1479258169

> Yes, indeed, TPU Colab runtime does not support new JAX versions anymore.
>
> So I would recommend to
>
> 1. either use the CPU...

As noted in #2950, it's better to install Flax via

```
!pip install jax==0.3.25 jaxlib==0.3.25 flax
```

That should keep the correct JAX version and install `flax` and its dependencies...
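After a pinned install like this one, it can be worth confirming that pip did not pull in a newer JAX while resolving `flax`. Below is a minimal sketch using the standard library's `importlib.metadata`. `find_pin_mismatches` is a hypothetical helper, and it assumes exact pins compared as strings, mimicking pip's `==` behaviour of ignoring a local label such as `+cuda11` on the installed version.

```python
from importlib import metadata
from typing import Callable, Dict, Optional


# Hypothetical helper for illustration; not part of pip, JAX or Flax.
def find_pin_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, Optional[str]]:
    """Map each unsatisfied pin to its installed version (None if not installed)."""
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in pins.items():
        try:
            installed: Optional[str] = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Like pip's "==", ignore a local label such as "+cuda11" on the installed version.
        public = installed.split("+", 1)[0] if installed is not None else None
        if public != wanted:
            mismatches[name] = installed
    return mismatches


# The same pins as `!pip install jax==0.3.25 jaxlib==0.3.25 flax`.
print(find_pin_mismatches({"jax": "0.3.25", "jaxlib": "0.3.25"}))
```

An empty dict means every pin held; `None` marks a package that is not installed at all.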

@2000222 I just checked on https://colab.sandbox.google.com/ with a T4 GPU runtime and got:

```
!pip freeze | egrep 'jax|flax'
import jax
jax.devices()

flax==0.7.4
jax==0.4.16
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.16+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl#sha256=78b3a9acfda4bfaae8a1dc112995d56454020f5c02dba4d24c40c906332efd4a
[gpu(id=0)]
```

so...
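In `pip freeze` output, a package installed from a direct URL (like `jaxlib` here) appears as `name @ url` rather than `name==version`, so its version has to be read from the wheel filename, which starts with `{distribution}-{version}-`. Below is a minimal sketch that mimics `egrep 'jax|flax'` and extracts versions from both forms. The function names are hypothetical, and the parsing assumes only these two line shapes.

```python
import posixpath
import re
from typing import Dict, Optional, Tuple
from urllib.parse import unquote, urlparse

# Hypothetical helpers for illustration; not part of pip, JAX or Flax.
WHEEL_VERSION = re.compile(r"[^-]+-([^-]+)-.*\.whl$")


def parse_freeze_line(line: str) -> Tuple[str, Optional[str]]:
    """Return (name, version) for one `pip freeze` line; version may be None."""
    line = line.strip()
    if " @ " in line:
        # Direct-URL form: "name @ https://.../name-VERSION-tags.whl#sha256=..."
        name, url = (part.strip() for part in line.split(" @ ", 1))
        # urlparse drops the "#sha256=..." fragment; unquote turns %2B back into "+".
        filename = unquote(posixpath.basename(urlparse(url).path))
        match = WHEEL_VERSION.match(filename)
        return name, match.group(1) if match else None
    if "==" in line:
        name, version = line.split("==", 1)
        return name.strip(), version.strip()
    return line, None


def jax_related(freeze_output: str) -> Dict[str, Optional[str]]:
    """Rough equivalent of `pip freeze | egrep 'jax|flax'`, parsed into a dict."""
    result: Dict[str, Optional[str]] = {}
    for line in freeze_output.splitlines():
        if re.search(r"jax|flax", line):
            name, version = parse_freeze_line(line)
            result[name] = version
    return result


freeze_output = """\
flax==0.7.4
jax==0.4.16
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.16+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl#sha256=78b3a9acfda4bfaae8a1dc112995d56454020f5c02dba4d24c40c906332efd4a
"""
print(jax_related(freeze_output))
# {'flax': '0.7.4', 'jax': '0.4.16', 'jaxlib': '0.4.16+cuda11.cudnn86'}
```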

Do you mean that you have a 3x slower step time? Or is it 3x slower to reach the target accuracy? The first would be unexpected, but I wouldn't know why that is the...
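Step time and time to target accuracy are different measurements: time to target is roughly the number of steps needed multiplied by the per-step time, so the two can diverge. When timing steps in JAX, dispatch is asynchronous, so the clock should only stop once the result is ready (`block_until_ready()`), and the first steps include jit compilation. Below is a minimal sketch of a timing helper under those assumptions; `median_step_time` and its parameters are hypothetical, not a JAX or Flax API.

```python
import statistics
import time
from typing import Any, Callable, Tuple


def _default_sync(x: Any) -> Any:
    # JAX dispatches work asynchronously; block_until_ready() waits until the
    # device has actually finished computing x. Other values pass through.
    if hasattr(x, "block_until_ready"):
        x.block_until_ready()
    return x


# Hypothetical helper for illustration; not part of JAX or Flax.
def median_step_time(
    step_fn: Callable[[Any, Any], Any],
    state: Any,
    batch: Any,
    num_warmup: int = 2,
    num_steps: int = 10,
    sync: Callable[[Any], Any] = _default_sync,
) -> Tuple[float, Any]:
    """Return (median seconds per step, final state) over repeated step_fn calls."""
    # Warm-up steps absorb one-time costs such as jit compilation.
    for _ in range(num_warmup):
        state = sync(step_fn(state, batch))
    durations = []
    for _ in range(num_steps):
        start = time.perf_counter()
        state = sync(step_fn(state, batch))
        durations.append(time.perf_counter() - start)
    return statistics.median(durations), state


# Toy usage with a plain Python "step": 2 warm-up + 10 timed steps of +1.
seconds, final_state = median_step_time(lambda s, b: s + b, 0, 1)
print(final_state)  # 12
```

For a train state that is a pytree rather than a single array, one option (assuming every leaf is a JAX array) is `sync=lambda s: jax.tree_util.tree_map(lambda a: a.block_until_ready(), s)`. If the median step time matches between two setups but time to accuracy does not, the difference lies in the number of steps needed.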