Distribute inference on a TPU
I have the following code that works on a TPU v3-8, but I am trying to use a larger LLM, i.e. gemma-7b-it, and the replicate call causes an OOM error. How can I modify the code to distribute inference?
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from transformers import FlaxGemmaForCausalLM, AutoTokenizer

model_name = "google/gemma-2b-it"
max_new_tokens = 4096
batch_size = 32
prompt = "Write an article about AI"
dtype = jnp.bfloat16

model, params = FlaxGemmaForCausalLM.from_pretrained(
    model_name, revision="flax", _do_init=False, dtype=dtype, token=hf_token
)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

max_input_length = 32
inputs = tokenizer(
    [prompt] * batch_size,
    padding="max_length",
    max_length=max_input_length,
    return_attention_mask=True,
    return_tensors="np",
)

params = jax_utils.replicate(params)
inputs = shard(inputs.data)

def generate(inputs, params, max_new_tokens):
    generated_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        params=params,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    return generated_ids.sequences

p_generate = jax.pmap(
    generate, "inputs",
    in_axes=(0, 0, None), out_axes=0, static_broadcasted_argnums=(2,)
)
You'll need to shard the model params across chips. replicate copies the full set of params onto every chip, and a single TPU v3 chip doesn't have enough memory to hold the 7B model.
Instead of replicate, use jax.device_put to place the params with a sharding of your choosing. Unfortunately, FlaxGemmaForCausalLM doesn't ship with sharding annotations, so you may have to write your own, e.g. FSDP-style sharding that splits each weight matrix along its d_model dimension.
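A minimal sketch of that approach, using a 1-D device mesh and a made-up sharding rule. The rule, the toy param pytree, and the axis name "fsdp" are all assumptions; the real Gemma pytree would need per-layer rules, not just "shard the first dimension":

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices; "fsdp" is the axis weights are sharded over.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
n_dev = len(jax.devices())

def fsdp_sharding(x):
    # Hypothetical rule: shard arrays along their first dimension when it
    # divides the device count evenly; replicate everything else
    # (biases, layer norms, ...).
    if x.ndim >= 2 and x.shape[0] % n_dev == 0:
        return NamedSharding(mesh, P("fsdp", *([None] * (x.ndim - 1))))
    return NamedSharding(mesh, P())

# Toy params standing in for the real Gemma param pytree.
params = {
    "embedding": jnp.ones((2048, 256), jnp.bfloat16),  # gets sharded
    "final_norm": jnp.ones((256,), jnp.bfloat16),      # stays replicated
}

shardings = jax.tree_util.tree_map(fsdp_sharding, params)
params = jax.device_put(params, shardings)  # instead of jax_utils.replicate
```

Each sharded leaf now lives split across devices; jax.debug.visualize_array_sharding(params["embedding"]) shows the resulting layout.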
Once the params are sharded correctly, compile your distributed generate function with jax.jit, which propagates the shardings of its arguments; jax.pmap is a bit outdated for this.
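A sketch of the jit version. The generate body below is a placeholder for model.generate (this snippet doesn't load the model); the point is the calling convention: unlike pmap, jit sees the full unsplit batch, and the distribution comes from where the arguments were placed with device_put, so shard()/replicate() and the leading per-device axis go away.

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

@partial(jax.jit, static_argnums=(1,))  # replaces static_broadcasted_argnums
def generate(input_ids, max_new_tokens):
    # Placeholder for model.generate(...): appends max_new_tokens dummy
    # tokens so the output shape matches a real generate call.
    new = jnp.zeros((input_ids.shape[0], max_new_tokens), input_ids.dtype)
    return jnp.concatenate([input_ids, new], axis=1)

# Shard the batch dimension across devices; jit compiles a distributed
# program from the input shardings, with no per-device leading axis.
input_ids = jax.device_put(
    jnp.ones((32, 32), jnp.int32),
    NamedSharding(mesh, P("data", None)),
)
out = generate(input_ids, 4)  # out keeps the full batch: shape (32, 36)
```

With the real model, you would pass the FSDP-sharded params as another argument; jit will insert the collectives needed to combine batch-sharded activations with d_model-sharded weights.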