
Mistral in Flax: generation is slow, JIT fails

Open dfdx opened this issue 4 months ago • 4 comments

System Info

  • transformers version: 4.38.1
  • Platform: Linux-6.2.0-1019-azure-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.21.1
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.8.1 (cpu)
  • Jax version: 0.4.25
  • JaxLib version: 0.4.25
  • Using GPU in script?: Yes (JAX default behavior)
  • Using distributed or parallel set-up in script?: No

Who can help?

@sanchit-gandhi

Information

  • [ ] The official example scripts
  • [x] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [x] My own task or dataset (give details below)

Reproduction

On a VM/Docker with NVIDIA A100 run:

import jax
import jax.numpy as jnp
from transformers import FlaxAutoModelForCausalLM, AutoTokenizer

MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
model = FlaxAutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    from_pt=True,
    dtype=jnp.bfloat16,
    max_position_embeddings=4096,  # much smaller than the default value
    sliding_window=4096,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
texts = ["<s>[INST]Write a poem[/INST]"]
input_ids = tokenizer(texts, return_tensors="np")["input_ids"]

With this setup, I'm able to generate some output using:

model.generate(input_ids, max_new_tokens=32)

But it takes 8.91s (after warmup), which is longer than I'd expect for a total of 45 tokens. The obvious next step is to JIT-compile it:

jax.jit(model.generate, static_argnames=("max_new_tokens",))(input_ids, max_new_tokens=32)

But it fails with:

0302 22:21:58.956701   19411 pjrt_stream_executor_client.cc:2804] Execution of replica 0 failed: INTERNAL: Failed to allocate 117440512 bytes for new constant
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[20], line 1
----> 1 jax.jit(model.generate, static_argnames=("max_new_tokens",))(input_ids, max_new_tokens=32)

    [... skipping hidden 10 frame]

File ~/.pyenv/versions/3.10.6/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1209, in ExecuteReplicated.__call__(self, *args)
   1207   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1208 else:
-> 1209   results = self.xla_executable.execute_sharded(input_bufs)
   1210 if dispatch.needs_check_special():
   1211   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: INTERNAL: Failed to allocate 117440512 bytes for new constant

Expected behavior

I'd expect at least one of two things:

  1. Generation without explicit JIT is already implemented with jax.lax.while_loop, whose body function gets compiled, so generation is fast (a minimal sketch of such a loop follows this list).
  2. Generation with JIT and minimal settings doesn't fail with an out-of-memory error.
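
For reference, here is a minimal sketch of the kind of fixed-shape decoding loop I mean in point 1. This is not how transformers implements generation; model_fn (a wrapper returning logits of shape (batch, total_len, vocab)), pad_id, and the pre-allocated token buffer are assumptions of mine, and it recomputes the full sequence every step (no KV cache) purely to illustrate that jax.lax.while_loop compiles its body once:

import jax
import jax.numpy as jnp

def greedy_decode(params, input_ids, max_new_tokens, model_fn, pad_id=0):
    # Fixed-size token buffer so every loop iteration has the same shapes
    # and jax.lax.while_loop compiles the body exactly once.
    batch, prompt_len = input_ids.shape
    total_len = prompt_len + max_new_tokens
    tokens = jnp.full((batch, total_len), pad_id, dtype=input_ids.dtype)
    tokens = tokens.at[:, :prompt_len].set(input_ids)

    def cond(state):
        _, cur_len = state
        return cur_len < total_len

    def body(state):
        tokens, cur_len = state
        # model_fn(params, tokens) -> (batch, total_len, vocab) logits.
        # With causal attention the padding after cur_len does not affect
        # the logits at position cur_len - 1, which is all we read here.
        logits = model_fn(params, tokens)
        next_token = jnp.argmax(logits[:, cur_len - 1, :], axis=-1)
        tokens = jax.lax.dynamic_update_slice(
            tokens, next_token[:, None].astype(tokens.dtype), (0, cur_len)
        )
        return tokens, cur_len + 1

    init = (tokens, jnp.asarray(prompt_len, dtype=jnp.int32))
    tokens, _ = jax.lax.while_loop(cond, body, init)
    return tokens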

Some notes:

  • jax.jit(model)(input_ids), i.e. a single forward pass, works fine and takes only ~2.5 ms, so in theory 32 new tokens should be generated in 80-200 ms
  • Setting XLA_PYTHON_CLIENT_PREALLOCATE=false and XLA_PYTHON_CLIENT_MEM_FRACTION=0.99 doesn't help; all 80 GB of VRAM are actually consumed (a jit pattern I'd try to work around this is sketched after these notes)
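
The "Failed to allocate ... for new constant" message makes me suspect the weights get baked into the compiled program as constants when jax.jit closes over model.params. The usual JAX pattern to avoid that is to pass the params as an explicit argument; a sketch, assuming the Flax generate accepts a params keyword and reusing model and input_ids from the snippet above (I have not verified that this avoids the failure here):

from functools import partial
import jax

# Pass the weights as a jit argument so they are threaded through as device
# arrays rather than captured as compile-time constants.
@partial(jax.jit, static_argnames=("max_new_tokens",))
def jitted_generate(params, input_ids, max_new_tokens):
    out = model.generate(input_ids, params=params, max_new_tokens=max_new_tokens)
    return out.sequences

sequences = jitted_generate(model.params, input_ids, max_new_tokens=32)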

dfdx avatar Mar 02 '24 22:03 dfdx

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 02 '24 08:04 github-actions[bot]

@sanchit-gandhi I think you tested it for Gemma, which is based on the Llama implementation, and we had correct performance there, no?

ArthurZucker avatar Apr 02 '24 08:04 ArthurZucker

Just checked it with the latest transformers, latest JAX, and latest CUDA on an NVIDIA H100, but the problem persists. Can somebody try it and confirm/reject that it is reproducible?

On the other hand, if this behavior is expected, is there a recommended way to use transformers in JAX/Flax? Maybe some specific models?

dfdx avatar Apr 13 '24 08:04 dfdx

Hey @dfdx, we are not actively working on this; opening it up to the community in case one of the community magicians figures it out! 🤗

ArthurZucker avatar May 10 '24 11:05 ArthurZucker