
Colabs don't seem to work

Open hamzamerzic opened this issue 1 year ago • 4 comments

I cannot get the Colabs to run on https://colab.research.google.com.

I had to replace

!pip install https://github.com/deepmind/gemma

with

!pip install "git+https://github.com/google-deepmind/gemma.git"

as the former repository does not exist.

I am still unable to get the versions to match for the code to run. Also, Google provides a free TPU tier for Colab, so it would be great if the code could be adapted (or some notes included) to run on TPU as well as GPU.

After fixing the gemma install and updating the JAX install to:

!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

the code ends up failing with the following stack trace:

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-6-cb05cf1a7a98> in <cell line: 2>()
      1 import re
----> 2 from gemma import params as params_lib
      3 from gemma import sampler as sampler_lib
      4 from gemma import transformer as transformer_lib
      5 

/usr/local/lib/python3.10/dist-packages/gemma/params.py in <module>
     20 import jax
     21 import jax.numpy as jnp
---> 22 import orbax.checkpoint
     23 
     24 Params = Mapping[str, Any]

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/__init__.py in <module>
     17 import functools
     18 
---> 19 from orbax.checkpoint import checkpoint_utils
     20 from orbax.checkpoint import lazy_utils
     21 from orbax.checkpoint import test_utils

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_utils.py in <module>
     23 from jax.sharding import Mesh
     24 import numpy as np
---> 25 from orbax.checkpoint import type_handlers
     26 from orbax.checkpoint import utils
     27 

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py in <module>
     22 from etils import epath
     23 import jax
---> 24 from jax.experimental.gda_serialization import serialization
     25 from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
     26 import jax.numpy as jnp

ModuleNotFoundError: No module named 'jax.experimental.gda_serialization'

---------------------------------------------------------------------------
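The missing module points to a JAX/orbax version skew: recent JAX releases removed `jax.experimental.gda_serialization` (its functionality moved under `jax.experimental.array_serialization`), so an older orbax that still imports the old path fails at import time. A small stdlib-only probe can show which serialization path the installed JAX actually provides (`find_serialization_module` is an illustrative helper name, not part of gemma or orbax):

```python
import importlib.util


def find_serialization_module():
    """Return whichever JAX serialization module path is importable, if any.

    Probes the new and old module paths without importing them fully, and
    tolerates JAX being absent entirely, so it is safe to run anywhere.
    """
    candidates = (
        "jax.experimental.array_serialization",  # newer JAX releases
        "jax.experimental.gda_serialization",    # removed in newer JAX
    )
    for name in candidates:
        try:
            if importlib.util.find_spec(name) is not None:
                return name
        except ModuleNotFoundError:
            # The parent package (jax itself) is not installed.
            continue
    return None


print(find_serialization_module())
```

If this prints the `array_serialization` path while orbax still imports `gda_serialization`, upgrading orbax is the fix.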

hamzamerzic · Feb 24 '24 17:02

Adding:

!pip install -U orbax
!pip install -U chex

in addition to the install changes suggested in the issue seems to make the gsm8k_eval.ipynb imports work. I guess now we just need to wait for the parameter and vocab checkpoints to become available.
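After the upgrades, a quick check that the expected distributions are actually installed can save a notebook restart. This is a sketch using only the standard library; `installed_versions` is an illustrative helper name, and the distribution names queried are the PyPI names assumed for the packages above:

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version string,
    or to None when the distribution is not installed."""
    versions = {}
    for pkg in packages:
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None
    return versions


print(installed_versions(["jax", "orbax", "chex"]))
```

Any `None` in the output means the corresponding `!pip install -U` step still needs to run (or the runtime needs a restart to pick it up).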

hamzamerzic · Feb 24 '24 18:02

Hi @hamzamerzic,

I understand you're facing issues running Colab notebooks that require the Gemma model and potentially using TPUs for acceleration. Here's a breakdown of the steps to get you started:

  1. Install the latest Google Cloud TPU libraries:

!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

This command installs the latest libraries needed to interact with Google Cloud TPUs through JAX.

  2. Verify TPU Availability:

import jax.tools.colab_tpu
import jax

jax.tools.colab_tpu.setup_tpu()
print(jax.devices())

Running this code snippet will attempt to set up the TPU runtime and then print the available TPU devices in your Colab environment.

  3. Install the Gemma package:

!pip install git+https://github.com/google-deepmind/gemma.git
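One caveat on step 2: `jax.tools.colab_tpu` has been removed from newer JAX releases, which detect the TPU runtime automatically. If that import fails, a backend-agnostic probe like the following still reports what JAX sees (a sketch: `backend_summary` is an illustrative name, and the guard lets it run even where JAX is not installed):

```python
# Probe JAX without failing on machines where it is not installed.
try:
    import jax
except ModuleNotFoundError:
    jax = None


def backend_summary():
    """Return (backend_name, device_count), or None when JAX is absent.

    On recent JAX there is no explicit Colab TPU setup call: the TPU
    backend is picked up automatically when the runtime provides one.
    """
    if jax is None:
        return None
    return jax.default_backend(), jax.device_count()


print(backend_summary())  # e.g. ("tpu", 8) on a Colab TPU runtime
```

A backend of "cpu" on a TPU runtime usually means the `jax[tpu]` wheel from step 1 was not installed.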

Additional Resource:

  • This Colab notebook gist demonstrating Gemma inference on TPUs might be helpful.

selamw1 · Apr 22 '24 22:04

@selamw1 I liked your Colab! That would be a nice recipe for the Gemma cookbook: goo.gle/gemma-cookbook

For future users, we have some more tutorials using JAX + Gemma here: ai.google.dev/gemma

gustheman · Jul 16 '24 11:07

Great, thanks @gustheman!

This PR added two new tutorials:

  • gemma_inference_on_tpu: Demonstrates basic inference with Gemma on TPUs.

  • gemma-data-parallel-inference-in-jax-tpu: Showcases data-parallel inference for faster processing on TPUs using JAX.

selamw1 · Jul 22 '24 22:07