Colabs don't seem to work
I cannot get the Colabs to run on https://colab.research.google.com.
I had to replace
!pip install https://github.com/deepmind/gemma
with
!pip install "git+https://github.com/google-deepmind/gemma.git"
as the former repository does not exist.
I am still unable to get the package versions to match so that the code runs. Also, Google provides a free TPU tier for Colab, so it would be great if the code could be adapted (or some notes included) to run on TPU as well as GPU.
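As a sketch of what such a note could look like, one way is to build the pip install line from the target accelerator. This is only an illustration, not the repo's documented setup; the `jax[cuda12]` extra name is an assumption based on current JAX packaging and should be checked against the JAX install docs:

```python
# Hypothetical helper: return a pip command for the accelerator you plan to use.
# The TPU index URL is the one used later in this thread; the CUDA extra is an
# assumption and may differ across JAX versions.
def jax_install_command(accelerator: str) -> str:
    """accelerator: one of 'tpu', 'gpu', 'cpu'."""
    if accelerator == "tpu":
        return ('pip install -U "jax[tpu]" '
                '-f https://storage.googleapis.com/jax-releases/libtpu_releases.html')
    if accelerator == "gpu":
        return 'pip install -U "jax[cuda12]"'
    return 'pip install -U jax'

print(jax_install_command("tpu"))
```

In a Colab cell the returned string would be run with a leading `!`.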
After fixing the gemma install and updating the JAX installation with:
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
the code ends up failing with the following stack trace:
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
[<ipython-input-6-cb05cf1a7a98>](https://localhost:8080/#) in <cell line: 2>()
1 import re
----> 2 from gemma import params as params_lib
3 from gemma import sampler as sampler_lib
4 from gemma import transformer as transformer_lib
5
3 frames
[/usr/local/lib/python3.10/dist-packages/gemma/params.py](https://localhost:8080/#) in <module>
20 import jax
21 import jax.numpy as jnp
---> 22 import orbax.checkpoint
23
24 Params = Mapping[str, Any]
[/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/__init__.py](https://localhost:8080/#) in <module>
17 import functools
18
---> 19 from orbax.checkpoint import checkpoint_utils
20 from orbax.checkpoint import lazy_utils
21 from orbax.checkpoint import test_utils
[/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_utils.py](https://localhost:8080/#) in <module>
23 from jax.sharding import Mesh
24 import numpy as np
---> 25 from orbax.checkpoint import type_handlers
26 from orbax.checkpoint import utils
27
[/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py](https://localhost:8080/#) in <module>
22 from etils import epath
23 import jax
---> 24 from jax.experimental.gda_serialization import serialization
25 from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
26 import jax.numpy as jnp
ModuleNotFoundError: No module named 'jax.experimental.gda_serialization'
---------------------------------------------------------------------------
Adding:
!pip install -U orbax
!pip install -U chex
in addition to the import changes suggested in the issue seems to make the gsm8k_eval.ipynb imports work. I guess now we just need to wait for the parameter and vocab checkpoints to become available.
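As a quick sanity check after the installs above, one could verify that the previously failing modules are now resolvable before re-running the notebook. This is just a hypothetical smoke-test helper, not part of the notebook; in Colab you would pass it the names from the traceback (e.g. `"orbax.checkpoint"`, `"chex"`):

```python
import importlib.util

def missing_modules(names):
    """Return the subset of module names that cannot be found by the importer."""
    return [n for n in names if importlib.util.find_spec(n) is None]

# Stdlib names resolve fine, so this prints an empty list; in Colab, after the
# installs, missing_modules(["orbax.checkpoint", "chex"]) should also be empty.
print(missing_modules(["re", "json"]))  # -> []
```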
Hi @hamzamerzic,
I understand you're facing issues running Colab notebooks that require the Gemma model and potentially using TPUs for acceleration. Here's a breakdown of the steps to get you started:
- Install the latest Google Cloud TPU libraries:
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
This command installs the latest libraries needed to interact with Google Cloud TPUs through JAX.
- Verify TPU availability:
import jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()  # connect the Colab runtime to the TPU backend
print(jax.devices())             # should list the available TPU devices
Running this code snippet will attempt to set up the TPU runtime and then print the available TPU devices in your Colab environment.
- Install Gemma Model:
!pip install git+https://github.com/google-deepmind/gemma.git
Additional Resource:
- This Colab notebook gist demonstrating Gemma inference on TPUs might be helpful.
@selamw1 I liked your Colab! That would be a nice recipe for the Gemma cookbook: goo.gle/gemma-cookbook
For future users, we have some more tutorials using JAX + Gemma here: ai.google.dev/gemma
Great, thanks @gustheman!
This PR added two new tutorials:
- gemma_inference_on_tpu: demonstrates basic inference with Gemma on TPUs.
- gemma-data-parallel-inference-in-jax-tpu: showcases data-parallel inference for faster processing on TPUs using JAX.
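The data-parallel idea in the second tutorial boils down to splitting a batch of inputs evenly across devices and running the same model on each shard. A toy pure-Python sketch of just the sharding step (the real notebook uses JAX on TPU, e.g. via `jax.pmap`; this is only an illustration of the batch-splitting logic):

```python
# Toy sketch: split a batch evenly across num_devices shards, the way
# data-parallel inference distributes prompts across TPU cores.
def shard_batch(batch, num_devices):
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must divide evenly across devices")
    per_device = len(batch) // num_devices
    return [batch[i * per_device:(i + 1) * per_device]
            for i in range(num_devices)]

shards = shard_batch(["p0", "p1", "p2", "p3"], num_devices=2)
print(shards)  # [['p0', 'p1'], ['p2', 'p3']]
```

Each shard would then be handed to one device, and the per-device outputs concatenated back in order.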