Flax doesn't work on google colab

Open necludov opened this issue 2 years ago • 15 comments

It seems like flax just stopped working on google colab. Simply running

import jax
!pip install --quiet flax
import flax

yields the error

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.9.0 requires jedi>=0.10, which is not installed.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
[<ipython-input-1-41b8697fa850>](https://localhost:8080/#) in <module>
      1 import jax
      2 get_ipython().system('pip install --quiet flax')
----> 3 import flax

4 frames
[/usr/local/lib/python3.8/dist-packages/flax/core/meta.py](https://localhost:8080/#) in Partitioned()
    263     return self.replace(names=tuple(names))
    264 
--> 265   def get_partition_spec(self) -> jax.sharding.PartitionSpec:
    266     """Returns the ``Partitionspec`` for this partitioned value."""
    267     return jax.sharding.PartitionSpec(*self.names)

AttributeError: module 'jax.sharding' has no attribute 'PartitionSpec'

Can be solved by downgrading Flax to 0.6.4.
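
The traceback points at a missing attribute rather than a broken install, so one way to confirm the mismatch before importing Flax is to probe for the attribute directly. The sketch below is only illustrative: the helper name has_partition_spec is made up here, and it assumes that the presence of jax.sharding.PartitionSpec is the only thing that matters, which may not hold for other Flax releases.

```python
import importlib
import importlib.util


def has_partition_spec() -> bool:
    """Return True if the installed JAX exposes jax.sharding.PartitionSpec.

    Hypothetical helper for illustration; not part of JAX or Flax.
    """
    # No JAX at all: nothing to probe.
    if importlib.util.find_spec("jax") is None:
        return False
    try:
        sharding = importlib.import_module("jax.sharding")
    except Exception:
        # Old JAX releases may lack the module or fail while importing it.
        return False
    return hasattr(sharding, "PartitionSpec")


if not has_partition_spec():
    print("This JAX lacks jax.sharding.PartitionSpec; "
          "upgrade JAX or pin an older Flax before importing flax.")
```

If this prints the warning, importing the Flax release from the traceback would be expected to fail in the same way.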

necludov avatar Feb 09 '23 23:02 necludov

@necludov - sorry for the trouble, what happens if you force upgrade jax via:

!pip install -U --quiet jax jaxlib

does that resolve the problem? I think colab jax might be pinned to very old versions.

levskaya avatar Feb 10 '23 07:02 levskaya

Indeed, as @levskaya said, but you should restart the runtime after installing the latest jax version; otherwise it will keep using the old one.
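
A rough way to see whether a restart is still pending is to compare the version of the module already loaded in the process with the version that pip has written to disk. This is a sketch under assumptions: needs_restart is a hypothetical helper, and it relies on the package exposing a __version__ attribute (jax and flax do, but not every package does).

```python
import sys
from importlib import metadata


def needs_restart(package: str) -> bool:
    """Return True if `package` is loaded with a version that differs from disk.

    Hypothetical helper for illustration only.
    """
    module = sys.modules.get(package)
    if module is None:
        # Not imported yet, so the next import picks up whatever is on disk.
        return False
    loaded = getattr(module, "__version__", None)
    try:
        on_disk = metadata.version(package)
    except metadata.PackageNotFoundError:
        return False
    return loaded is not None and loaded != on_disk


for name in ("jax", "jaxlib", "flax"):
    print(name, "restart needed:", needs_restart(name))
```

Running this in a cell after the pip upgrade should report True for jax if the old version is still loaded in the running kernel.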

marcvanzee avatar Feb 10 '23 07:02 marcvanzee

@levskaya this also works, but then I couldn't use it with a GPU.

necludov avatar Feb 10 '23 16:02 necludov

@necludov try installing the GPU version of JAX:

pip install -U "jax[cuda11_cudnn82]>=0.4.2" \
  --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

cgarciae avatar Feb 10 '23 16:02 cgarciae

I think I've tried this, but it seems that installing jax this way on colab is not compatible with GPU. I haven't looked deeper into it, to be honest. So far I'm just using the downgraded version of flax.

necludov avatar Feb 10 '23 17:02 necludov

Could you let us know what went wrong with the jax cuda install? My guess is that the jax-cuda version that was installed is incompatible with the CUDA/cuDNN libraries on the machine. There should be no fundamental problem installing the cuda version of jax/jaxlib via pip.

levskaya avatar Feb 10 '23 19:02 levskaya

nvcc --version will give you version info, see the instructions at https://github.com/google/jax#pip-installation-gpu-cuda for info on pip installs.
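
To read the CUDA release from Python instead of by eye, something like the following could work. It is a minimal sketch: both function names are hypothetical, and the parser assumes the usual "release X.Y" wording in the nvcc output.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple


def parse_cuda_release(nvcc_output: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from text containing 'release X.Y', else None."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def local_cuda_release() -> Optional[Tuple[int, int]]:
    """Run `nvcc --version` if nvcc is on PATH and parse its output."""
    if shutil.which("nvcc") is None:
        return None  # e.g. a CPU-only runtime
    result = subprocess.run(
        ["nvcc", "--version"], capture_output=True, text=True, check=False
    )
    return parse_cuda_release(result.stdout)


print(local_cuda_release())
```

The resulting major version can then be matched against the CUDA variants listed in the JAX installation instructions linked above.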

levskaya avatar Feb 10 '23 19:02 levskaya

Though this won't resolve any JAX cuda version issues - I've temporarily yanked the 0.6.5 Flax release from PyPI until the version pinning issues are resolved (hopefully soon).

levskaya avatar Feb 14 '23 17:02 levskaya

Installing jedi worked for me.

!pip install jedi

siddhi47 avatar Feb 21 '23 18:02 siddhi47

is there a recommendation for Flax version for TPU on Colab? i just discovered last night that this issue i reported to Jax a few weeks ago includes Flax version conflicts. my issue was that the newest Jax with TPU on Colab causes Optax import to just completely fail - but having the newest Jax was the fault of Flax requiring the New Hotness, where apparently Google Colab TPU is Olde and Busted...

so, is there a Ye Olde Busted Hotness version combo recommendation for us Colab junkies? i.e., what's the combination of compatible versions across all the Google libraries (Jax, Flax, Haiku, Optax, Pix, etc.) that will work with each other and Google Colab, TPU and the various GPU options?

krahnikblis avatar Mar 01 '23 02:03 krahnikblis

Copied from https://github.com/google/flax/issues/2950#issuecomment-1479258169

Yes, indeed, TPU Colab runtime does not support new JAX versions anymore.

So I would recommend one of the following:

  1. Use the CPU or GPU runtime (both work with the latest flax-0.6.7).
  2. On a TPU runtime, install Flax with the JAX version pinned: !pip install flax==0.6.4 jax==0.3.25 jaxlib==0.3.25

We're sorry about the inconvenience, but making the Colab TPU runtime infra compatible with new JAX versions is beyond what we can currently fix.

andsteing avatar Mar 22 '23 10:03 andsteing

As noted in #2950 it's better to install Flax via

!pip install jax==0.3.25 jaxlib==0.3.25 flax

That should keep the correct JAX version and install flax and its dependencies accordingly.

Also created discussion #2995 for this topic.

andsteing avatar Mar 30 '23 12:03 andsteing

Also, JAX CUDA version 0.4.19, installed via pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, conflicts with the flax 0.5.1 version in Google Colab. How can I solve this?

2000222 avatar Oct 23 '23 13:10 2000222

@2000222 I just checked on https://colab.sandbox.google.com/ with a T4 GPU runtime and got:

!pip freeze | egrep 'jax|flax'
import jax
jax.devices()
flax==0.7.4
jax==0.4.16
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.16+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl#sha256=78b3a9acfda4bfaae8a1dc112995d56454020f5c02dba4d24c40c906332efd4a
[gpu(id=0)]

so that seems to work as expected (note that both JAX and Flax are pre-installed with compatible versions in Google Colab).

andsteing avatar Nov 06 '23 14:11 andsteing

I can confirm that on Colab, running the same commands as in the original issue comment now imports flax without any error.

import jax
!pip install --quiet flax
import flax

SauravMaheshkar avatar Jan 05 '24 19:01 SauravMaheshkar