Distribute inference on a TPU

Open · zaidalyafeai opened this issue 11 months ago · 1 comment

I have the following code, which works on a TPU v3-8 with gemma-2b-it, but I am trying to use a larger LLM, i.e. gemma-7b-it, and the replicate call causes an OOM error. How can I modify the code to distribute inference?

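import os

import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from transformers import FlaxGemmaForCausalLM, AutoTokenizer

# Gemma is gated on the Hub; this assumes the access token is stored in the HF_TOKEN env variable
hf_token = os.environ["HF_TOKEN"]

model_name = "google/gemma-2b-it"
max_new_tokens = 4096
batch_size = 32  # must be divisible by jax.device_count() (8 on a v3-8)
prompt = "Write an article about AI"
dtype = jnp.bfloat16

# With _do_init=False, from_pretrained returns (model, params) instead of initializing random weights
model, params = FlaxGemmaForCausalLM.from_pretrained(model_name, revision="flax", _do_init=False, dtype=dtype, token=hf_token)

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

max_input_length = 32

inputs = tokenizer(
    [prompt] * batch_size,  # one copy of the prompt per batch element
    padding="max_length",
    max_length=max_input_length,
    return_attention_mask=True,
    return_tensors="np",
)

# Puts a full copy of the params on every device: this is the step that runs out of memory for gemma-7b-it
params = jax_utils.replicate(params)

# Reshape each input array from (batch, seq) to (num_devices, batch // num_devices, seq)
inputs = shard(inputs.data)

def generate(inputs, params, max_new_tokens):
    generated_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        params=params,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    return generated_ids.sequences

# Map over the leading device axis of inputs and params; max_new_tokens is a static Python int
p_generate = jax.pmap(
    generate, "inputs", in_axes=(0, 0, None,), out_axes=0, static_broadcasted_argnums=(2,)
)

# Run on all devices, copy the result to host, and merge the device axis back into the batch axis
generated_ids = jax.device_get(p_generate(inputs, params, max_new_tokens))
generated_ids = generated_ids.reshape(-1, generated_ids.shape[-1])
texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)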

zaidalyafeai, Mar 24 '25 09:03

You'll need to shard the model params across chips. replicate puts a full copy of the params on every chip, and a single chip doesn't have enough memory to hold the 7B model.
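A quick back-of-the-envelope check makes the OOM concrete. The sketch below assumes roughly 8.5 billion parameters for gemma-7b-it and 16 GiB of HBM per TPU v3 core (8 cores on a v3-8); both numbers are assumptions to check against the model card and your hardware. It counts weights only, so the KV cache and activations come on top. Note also that in Transformers' Flax models the dtype argument of from_pretrained only sets the computation dtype; it does not change the dtype of the parameters, which model.to_bf16(params) can cast.

```python
GIB = 1024 ** 3


def param_gib_per_device(num_params: float, bytes_per_param: int, num_shards: int) -> float:
    """GiB of weight memory each device holds (ignores KV cache and activations)."""
    return num_params * bytes_per_param / num_shards / GIB


# Assumptions, not measured values: ~8.5e9 params for gemma-7b-it,
# 8 TPU v3 cores with 16 GiB of HBM each.
NUM_PARAMS = 8.5e9
HBM_PER_CORE_GIB = 16
NUM_CORES = 8

for label, nbytes in [("float32", 4), ("bfloat16", 2)]:
    replicated = param_gib_per_device(NUM_PARAMS, nbytes, 1)  # what replicate does
    sharded = param_gib_per_device(NUM_PARAMS, nbytes, NUM_CORES)  # fully sharded
    print(
        f"{label}: replicated {replicated:.1f} GiB/core, "
        f"sharded {sharded:.1f} GiB/core (budget {HBM_PER_CORE_GIB} GiB)"
    )
```

Even in bfloat16, a replicated copy takes almost the entire core before any KV cache is allocated, while sharding over 8 cores leaves most of the memory free for generation.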

Instead of replicate, use jax.device_put to place the params with the specific sharding you want. Unfortunately, FlaxGemmaForCausalLM doesn't carry sharding annotations, so you might have to come up with your own, for example an FSDP-style scheme that shards the weights along the d_model dimension.
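A minimal sketch of an FSDP-style placement with jax.device_put, assuming a 1-D device mesh whose axis is given the hypothetical name "fsdp". Since FlaxGemmaForCausalLM has no sharding annotations, the helpers fsdp_spec and shard_params (both hypothetical) use a simple heuristic: shard the largest axis that divides evenly by the number of devices and replicate anything else. This only approximates sharding along d_model (it would, for instance, shard an embedding table along its vocabulary axis); a real setup would pick the d_model axis per parameter name.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def fsdp_spec(shape, num_shards):
    """Return a PartitionSpec that shards one axis over the "fsdp" mesh axis.

    Heuristic (an assumption, not Gemma-specific): pick the largest axis whose
    size divides evenly by num_shards; replicate arrays with no such axis.
    """
    candidates = [i for i, dim in enumerate(shape) if dim % num_shards == 0]
    if not candidates:
        return P()  # e.g. scalars or odd-sized vectors stay replicated
    axis = max(candidates, key=lambda i: shape[i])
    spec = [None] * len(shape)
    spec[axis] = "fsdp"
    return P(*spec)


def shard_params(params, mesh):
    """Place every leaf of a params pytree on the mesh according to fsdp_spec."""
    num_shards = mesh.shape["fsdp"]

    def place(x):
        sharding = NamedSharding(mesh, fsdp_spec(x.shape, num_shards))
        return jax.device_put(x, sharding)

    return jax.tree_util.tree_map(place, params)


# Toy params standing in for the Gemma params pytree (hypothetical shapes).
mesh = Mesh(np.array(jax.devices()), ("fsdp",))
toy_params = {
    "embedding": jnp.zeros((256, 64), dtype=jnp.bfloat16),
    "norm_scale": jnp.ones((64,), dtype=jnp.bfloat16),
}
sharded = shard_params(toy_params, mesh)
for name, leaf in sharded.items():
    print(name, leaf.sharding.spec)
```

On the TPU, params = shard_params(params, mesh) would take the place of jax_utils.replicate(params), and the batch dimension of the inputs would be sharded over the same mesh axis instead of going through shard(...).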

Once the params are sharded correctly, you might want to use jax.jit to compile your distributed generate function; jax.pmap is a bit outdated, and jit can work directly with arrays that are already sharded across devices.
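The toy example below shows the jit pattern with a stand-in one-layer "model" instead of Gemma; the weight, the sizes, and the run function are hypothetical. The weight is sharded FSDP-style along its first axis, the batch is sharded over the same mesh axis, and jax.jit compiles one program for which XLA inserts the needed cross-device communication, with no pmap and no replicate. num_steps is marked static via static_argnums, the role static_broadcasted_argnums plays for max_new_tokens in the pmap version.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices (8 on a TPU v3-8, 1 on a plain CPU).
mesh = Mesh(np.array(jax.devices()), ("fsdp",))
n = mesh.shape["fsdp"]

d_model = 8 * n  # toy sizes, chosen to divide evenly by the device count
batch = 4 * n

# Toy "params": one dense weight, sharded along its first (d_model) axis.
params = {
    "w": jax.device_put(
        jnp.full((d_model, d_model), 1.0 / d_model, dtype=jnp.float32),
        NamedSharding(mesh, P("fsdp", None)),
    )
}
# Inputs sharded along the batch axis over the same mesh axis.
x = jax.device_put(
    jnp.ones((batch, d_model), dtype=jnp.float32),
    NamedSharding(mesh, P("fsdp", None)),
)


@partial(jax.jit, static_argnums=(2,))
def run(params, x, num_steps):
    # Stand-in for model.generate(...); num_steps is static, like max_new_tokens.
    for _ in range(num_steps):
        x = jnp.tanh(x @ params["w"])
    return x


out = run(params, x, 3)
print(out.shape, out.sharding)
```

For the real model, the body of run would become the model.generate(...) call from the question, with max_new_tokens as the static argument. This has not been verified end to end with FlaxGemmaForCausalLM here, so treat it as a starting point rather than a tested recipe.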

IvyZX, Mar 26 '25 18:03