flax
Flax doesn't work on Google Colab
It seems like Flax just stopped working on Google Colab. Simply running
import jax
!pip install --quiet flax
import flax
yields the error
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.9.0 requires jedi>=0.10, which is not installed.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-1-41b8697fa850> in <module>
1 import jax
2 get_ipython().system('pip install --quiet flax')
----> 3 import flax
/usr/local/lib/python3.8/dist-packages/flax/core/meta.py in Partitioned()
263 return self.replace(names=tuple(names))
264
--> 265 def get_partition_spec(self) -> jax.sharding.PartitionSpec:
266 """Returns the ``Partitionspec`` for this partitioned value."""
267 return jax.sharding.PartitionSpec(*self.names)
AttributeError: module 'jax.sharding' has no attribute 'PartitionSpec'
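The traceback shows Flax referencing jax.sharding.PartitionSpec, which the JAX preinstalled on Colab at the time did not provide. Below is a minimal sketch for checking this before importing Flax. It assumes only the standard library plus whatever JAX happens to be installed, and the helper name jax_has_sharding_partition_spec is hypothetical, not part of JAX or Flax:

```python
import importlib.util


def jax_has_sharding_partition_spec():
    """Hypothetical helper: True/False for whether jax.sharding.PartitionSpec
    exists, or None if JAX is not installed at all."""
    if importlib.util.find_spec("jax") is None:
        return None  # JAX is not installed in this environment
    try:
        import jax.sharding  # older JAX releases may lack this submodule
    except (ImportError, RuntimeError):
        # RuntimeError covers e.g. a jax/jaxlib version mismatch at import time
        return False
    return hasattr(jax.sharding, "PartitionSpec")


if __name__ == "__main__":
    print("jax.sharding.PartitionSpec available:", jax_has_sharding_partition_spec())
```

A False result suggests the installed JAX predates that attribute. In that case, either upgrade JAX (and restart the runtime) or pin Flax to an older release such as 0.6.4.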
This can be worked around by downgrading Flax to 0.6.4.
@necludov Sorry for the trouble. What happens if you force-upgrade JAX via:
!pip install -U --quiet jax jaxlib
Does that resolve the problem? I think Colab's JAX might be pinned to a very old version.
Indeed, what @levskaya suggested works, but you should restart the runtime after installing the latest JAX version; otherwise it will keep using the old one.
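Restarting matters because Python caches imported modules in sys.modules, so a pip upgrade does not change a JAX that is already loaded. A rough way to detect this is to compare the loaded module's __version__ with the version pip reports on disk. The sketch below does that with a hypothetical helper, needs_restart. It assumes the package defines __version__, and it ignores local version suffixes such as +cuda11.cudnn86:

```python
import sys
from importlib import metadata


def needs_restart(module_name, dist_name=None):
    """Hypothetical helper: True if module_name is already imported but its
    __version__ differs from the installed distribution's version.

    dist_name is the PyPI distribution name; it defaults to module_name.
    """
    module = sys.modules.get(module_name)
    if module is None:
        return False  # not imported yet: the next import loads the new install
    loaded = getattr(module, "__version__", None)
    if loaded is None:
        return False  # cannot tell without a __version__ attribute
    try:
        installed = metadata.version(dist_name or module_name)
    except metadata.PackageNotFoundError:
        return False  # no installed metadata to compare against
    # Drop local suffixes such as "+cuda11.cudnn86" before comparing.
    return loaded.split("+")[0] != installed.split("+")[0]


if __name__ == "__main__":
    print("jax needs restart:", needs_restart("jax"))
```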
@levskaya this also works, but then I couldn't use it with a GPU.
@necludov try installing the GPU version of JAX:
pip install -U "jax[cuda11_cudnn82]>=0.4.2" \
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
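In a shell, an unquoted >= is parsed as an output redirection, and the square brackets may be treated as a glob, so the requirement string needs quoting. One way to avoid shell parsing entirely is to call pip through subprocess with an argument list. The sketch below is illustrative only: pip_install_command and install are hypothetical helper names, and the cuda11_cudnn82 extra is copied from the comment above and may not exist for newer JAX releases:

```python
import subprocess
import sys

# Wheel index URL from the comment above.
JAX_CUDA_INDEX = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def pip_install_command(requirements, find_links=None, upgrade=True):
    """Hypothetical helper: build a pip argument list. No shell parses it,
    so '>=' and '[extras]' need no quoting."""
    cmd = [sys.executable, "-m", "pip", "install"]
    if upgrade:
        cmd.append("-U")
    cmd.extend(requirements)
    if find_links:
        cmd.extend(["--find-links", find_links])
    return cmd


def install(requirements, find_links=None):
    """Run pip; raises CalledProcessError if the install fails."""
    subprocess.run(pip_install_command(requirements, find_links), check=True)


if __name__ == "__main__":
    # Print rather than run, so nothing is installed by accident.
    print(pip_install_command(["jax[cuda11_cudnn82]>=0.4.2"], JAX_CUDA_INDEX))
```

On Colab, the runtime still has to be restarted after such an upgrade.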
I think I've tried this, but it seems that installing JAX this way on Colab is not compatible with the GPU. I haven't looked deeper into it, to be honest. So far I'm just using the downgraded version of Flax.
Could you let us know what went wrong with the JAX CUDA install? My guess is that a CUDA build of JAX was installed that is incompatible with the CUDA/cuDNN libraries on the machine. There should be no fundamental problem installing the CUDA version of jax/jaxlib via pip.
nvcc --version
will give you the version info; see the instructions at https://github.com/google/jax#pip-installation-gpu-cuda for details on pip installs.
This won't resolve any JAX CUDA version issues, but I've temporarily yanked the 0.6.5 Flax release from PyPI until the version-pinning issues are resolved (hopefully soon).
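nvcc reports the installed CUDA toolkit release, which is one of the versions that a CUDA build of jaxlib has to be compatible with. The sketch below extracts that release from Python. It assumes output in the usual "Cuda compilation tools, release X.Y, VX.Y.Z" form, and the helper names are hypothetical. Mapping the result to a specific JAX wheel is left to the official installation instructions:

```python
import re
import shutil
import subprocess


def parse_nvcc_release(output):
    """Hypothetical helper: extract (major, minor) from `nvcc --version` text,
    e.g. 'Cuda compilation tools, release 11.8, V11.8.89'. None if absent."""
    match = re.search(r"release (\d+)\.(\d+)", output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def local_cuda_release():
    """Run nvcc if it is on PATH and return its CUDA release, else None."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    print("CUDA toolkit release:", local_cuda_release())
```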
Installing jedi worked for me.
!pip install jedi
Is there a recommended Flax version for TPU on Colab? I just discovered last night that this issue I reported to JAX a few weeks ago involves Flax version conflicts. My issue was that the newest JAX with TPU on Colab causes the Optax import to fail completely, but having the newest JAX was the fault of Flax requiring the New Hotness, whereas Google Colab TPU is apparently Olde and Busted...
So, is there a Ye Olde Busted Hotness version combo recommendation for us Colab junkies? I.e., what combination of versions across all the Google libraries (JAX, Flax, Haiku, Optax, Pix, etc.) works together and with Google Colab, TPU, and the various GPU options?
Copied from https://github.com/google/flax/issues/2950#issuecomment-1479258169
Yes, indeed, the Colab TPU runtime does not support new JAX versions anymore.
So I would recommend either to:
- use the CPU or GPU runtime (both work with the latest flax-0.6.7), or
- install Flax like this on a TPU runtime while keeping the JAX version fixed:
!pip install flax==0.6.4 jax==0.3.25 jaxlib==0.3.25
We're sorry for the inconvenience, but making the Colab TPU runtime infrastructure compatible with new JAX versions is beyond what we can currently fix.
As noted in #2950, it's better to install Flax via
!pip install jax==0.3.25 jaxlib==0.3.25 flax
That should keep the correct JAX version and install Flax and its dependencies accordingly.
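To confirm that the pin held after installation, a sketch like the following compares installed versions with the jax==0.3.25 / jaxlib==0.3.25 combination from this thread. TPU_COLAB_PINS and pin_mismatches are hypothetical names. The pins reflect the Colab TPU runtime at the time of this discussion and are not a current recommendation:

```python
from importlib import metadata

# Pins suggested in this thread for the Colab TPU runtime at the time.
TPU_COLAB_PINS = {"jax": "0.3.25", "jaxlib": "0.3.25"}


def pin_mismatches(pins, installed=None):
    """Hypothetical helper: return {dist: (expected, found)} for every pinned
    distribution whose version differs; found is None if not installed.

    Pass `installed` as a dict to check without querying the environment."""
    mismatches = {}
    for dist, expected in pins.items():
        if installed is not None:
            found = installed.get(dist)
        else:
            try:
                found = metadata.version(dist)
            except metadata.PackageNotFoundError:
                found = None
        if found != expected:
            mismatches[dist] = (expected, found)
    return mismatches


if __name__ == "__main__":
    print(pin_mismatches(TPU_COLAB_PINS) or "pins OK")
```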
Also created discussion #2995 for this topic.
Also, the CUDA build of JAX 0.4.19 installed with
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
conflicts with the Flax 0.5.1 version in Google Colab... So how can this be solved?
@2000222 I just checked on https://colab.sandbox.google.com/ with a T4 GPU runtime and got:
!pip freeze | egrep 'jax|flax'
import jax
jax.devices()
flax==0.7.4
jax==0.4.16
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.16+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl#sha256=78b3a9acfda4bfaae8a1dc112995d56454020f5c02dba4d24c40c906332efd4a
[gpu(id=0)]
so that seems to work as expected (note that both JAX and Flax are pre-installed with compatible versions in Google Colab).
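The pip freeze check above can also be done from Python with the standard library's importlib.metadata. Here is a minimal sketch using a hypothetical helper, installed_versions. Unlike egrep 'jax|flax', it looks up exact distribution names only, and a CUDA jaxlib may report a local suffix such as +cuda11.cudnn86:

```python
from importlib import metadata


def installed_versions(names=("jax", "jaxlib", "flax")):
    """Hypothetical helper: map each distribution name to its installed
    version, or None if it is not installed."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version}")
```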
I can confirm that on Colab, running the same commands as in the original issue imports Flax without any error.
import jax
!pip install --quiet flax
import flax