Bug in WhisperTokenizer `batch_decode` when `skip_special_tokens=True` is set for FlaxWhisper model output
System Info
- `transformers` version: 4.43.0
- Platform: Linux-5.15.0-1061-nvidia-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.24.6
- Safetensors version: 0.4.4
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.4.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (gpu)
- Jax version: 0.4.29
- JaxLib version: 0.4.29
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA A100-SXM4-40GB
Who can help?
@sanchit-gandhi
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
I use the following code to transcribe a sample audio file with the Flax Whisper-large-v3 model on JAX:
```python
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer
from scipy.io import wavfile
import jax
import jax.numpy as jnp
import numpy as np
import torch
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

model_path = "openai/whisper-large-v3"
audio_file_path = "path/to/audio/audio_file.wav"

tokenizer = WhisperTokenizer.from_pretrained(model_path)
processor = WhisperProcessor.from_pretrained(model_path)

# Load PyTorch weights and convert them to Flax
with torch.no_grad():
    model = FlaxWhisperForConditionalGeneration.from_pretrained(model_path, dtype=jnp.float16, from_pt=True)

jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])

samplerate, data_waveform = wavfile.read(audio_file_path)
data_waveform = data_waveform / 32768.0  # normalize 16-bit PCM to [-1, 1]

input_features = processor(data_waveform, padding="max_length", sampling_rate=16000, return_tensors="pt").input_features
input_features = jnp.array(input_features, dtype=jnp.float16)

pred_ids = jit_generate(input_features, max_length=128, language="<|de|>", task="transcribe")
print(tokenizer.batch_decode(pred_ids.sequences, skip_special_tokens=True))
```
This worked properly up to transformers version 4.42.4, but starting with 4.43.0 it raises an error on the last line of the code (`batch_decode`):
File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3994, in batch_decode
return [
File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3995, in <listcomp>
self.decode(
File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 692, in decode
filtered_ids = self._preprocess_token_ids(
File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 637, in _preprocess_token_ids
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 860, in _strip_prompt
if not token_ids:
File "/usr/local/lib/python3.10/dist-packages/jax/_src/array.py", line 258, in __bool__
core.check_bool_conversion(self)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 654, in check_bool_conversion
raise ValueError("The truth value of an array with more than one element"
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
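From the traceback, the failure appears to come from `_strip_prompt` evaluating `if not token_ids:` on the JAX array returned by `generate`; boolean conversion is ambiguous for arrays with more than one element. A minimal sketch reproducing the same error (the token ids are placeholders):

```python
import jax.numpy as jnp

token_ids = jnp.array([50258, 50261, 50360])  # e.g. one row of generated token ids
if not token_ids:  # raises ValueError: truth value of an array ... is ambiguous
    pass
```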
However, if I set `skip_special_tokens=False` in `batch_decode`, no error is raised, but the output contains lots of special tokens.
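As a temporary workaround (a sketch, assuming only that the generated sequences fit in host memory), converting the JAX device array to plain Python lists before decoding avoids the ambiguous truth-value check:

```python
import numpy as np

# Decode from nested Python lists rather than a JAX array;
# `if not token_ids:` is unambiguous on a list.
token_id_lists = np.asarray(pred_ids.sequences).tolist()
print(tokenizer.batch_decode(token_id_lists, skip_special_tokens=True))
```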
Expected behavior
`batch_decode` should return a list of strings, as it does up to transformers version 4.42.4.