
Bug in WhisperTokenizer batch_decode when `skip_special_tokens=True` is set for FlaxWhisper model output

Open · hannan72 opened this issue 6 months ago · 5 comments

System Info

  • transformers version: 4.43.0
  • Platform: Linux-5.15.0-1061-nvidia-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.4
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.8.5 (gpu)
  • Jax version: 0.4.29
  • JaxLib version: 0.4.29
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100-SXM4-40GB

Who can help?

@sanchit-gandhi

Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

I use the following code to transcribe a sample audio file with the Flax Whisper-large-v3 model using JAX:

from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer
from scipy.io import wavfile
import jax
import jax.numpy as jnp
import torch
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

model_path = "openai/whisper-large-v3"
audio_file_path = "./path/to/audio/audio_file.wav"

# The processor builds the log-mel input features; the tokenizer decodes the output ids
tokenizer = WhisperTokenizer.from_pretrained(model_path)
processor = WhisperProcessor.from_pretrained(model_path)
with torch.no_grad():
    model = FlaxWhisperForConditionalGeneration.from_pretrained(model_path, dtype=jnp.float16, from_pt=True)

# JIT-compile generate; the decoding options must be passed as static arguments
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])

# Read the 16 kHz wav file and scale int16 samples to [-1, 1]
samplerate, data_waveform = wavfile.read(audio_file_path)
data_waveform = data_waveform / 32768.0

input_features = processor(data_waveform, padding="max_length", sampling_rate=16000, return_tensors="np").input_features
input_features = jnp.array(input_features, dtype=jnp.float16)

pred_ids = jit_generate(input_features, max_length=128, language="<|de|>", task="transcribe")
print(tokenizer.batch_decode(pred_ids.sequences, skip_special_tokens=True))

This worked correctly up to transformers 4.42.4, but starting with version 4.43.0 the last line of the code (batch_decode) raises an error:

File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3994, in batch_decode
    return [
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3995, in <listcomp>
    self.decode(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 692, in decode
    filtered_ids = self._preprocess_token_ids(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 637, in _preprocess_token_ids
    token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 860, in _strip_prompt
    if not token_ids:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/array.py", line 258, in __bool__
    core.check_bool_conversion(self)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 654, in check_bool_conversion
    raise ValueError("The truth value of an array with more than one element"
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
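
The failure seems to come from the emptiness check at tokenization_whisper.py line 860 in the traceback: `if not token_ids:` calls `__bool__` on the ids, which JAX refuses for arrays with more than one element. A minimal sketch of just that check (the token id values below are arbitrary placeholders):

import jax.numpy as jnp

token_ids = jnp.array([50258, 50360, 50259])  # placeholder ids, the values don't matter
if not token_ids:  # raises ValueError: truth value of an array ... is ambiguous
    print("empty")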

However, if I set the `skip_special_tokens` arg to False in batch_decode, no error is raised, but the output contains lots of special tokens.
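
As a temporary workaround (my own, not an official fix), converting the JAX output to plain Python lists before decoding avoids the ambiguous truth-value check and keeps `skip_special_tokens=True` working:

import numpy as np

# Copy the ids off the device and turn them into nested Python lists,
# so _strip_prompt's `if not token_ids:` sees ordinary lists again
pred_list = np.asarray(pred_ids.sequences).tolist()
print(tokenizer.batch_decode(pred_list, skip_special_tokens=True))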

Expected behavior

batch_decode is expected to return a list of strings, as it did up to version 4.42.4 of transformers.
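
A possible fix on the transformers side (a sketch only, not a confirmed patch) would be for `_strip_prompt` to test emptiness with `len()` instead of truthiness, since `len()` is unambiguous for Python lists, NumPy arrays, and JAX arrays alike:

def _strip_prompt(self, token_ids, prompt_token_id, decoder_start_token_id):
    # `len(token_ids) == 0` works for lists, np.ndarray, and jax.Array,
    # whereas `not token_ids` calls __bool__ and fails on multi-element arrays
    if len(token_ids) == 0:
        return token_ids
    ...  # rest of the method unchanged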

hannan72 · Aug 22 '24 10:08