Bug in WhisperTokenizer batch_decode when `skip_special_tokens=True` is set for FlaxWhisper model output
System Info
- `transformers` version: 4.43.0
- Platform: Linux-5.15.0-1061-nvidia-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.24.6
- Safetensors version: 0.4.4
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.4.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (gpu)
- Jax version: 0.4.29
- JaxLib version: 0.4.29
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA A100-SXM4-40GB
Who can help?
@sanchit-gandhi
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
I use the following code to transcribe a sample audio file with the Flax Whisper-large-v3 model in JAX.
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer
from scipy.io import wavfile
import jax
import jax.numpy as jnp
import numpy as np
import torch
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
model_path = "openai/whisper-large-v3"
audio_file_path = ".path/to/audio/audio_file.wav"
# NOTE: `processor` was used but never defined in the original snippet;
# loading a WhisperProcessor from the same checkpoint is an assumption.
processor = WhisperProcessor.from_pretrained(model_path)
tokenizer = WhisperTokenizer.from_pretrained(model_path)
with torch.no_grad():
    model = FlaxWhisperForConditionalGeneration.from_pretrained(model_path, dtype=jnp.float16, from_pt=True)
jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])
samplerate, data_waveform = wavfile.read(audio_file_path)
# Scale 16-bit PCM samples to the range [-1, 1)
data_waveform = data_waveform / 32768.0
input_features = processor(data_waveform, padding="max_length", sampling_rate=16000, return_tensors="pt").input_features
input_features = jnp.array(input_features, dtype=jnp.float16)
pred_ids = jit_generate(input_features, max_length=128, language="<|de|>", task="transcribe")
# This is the line that raises the error on transformers 4.43.0
print(tokenizer.batch_decode(pred_ids.sequences, skip_special_tokens=True))
This worked through transformers 4.42.4, but starting with 4.43.0 the last line of the code (`batch_decode`) raises an error:
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3994, in batch_decode
    return [
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3995, in <listcomp>
    self.decode(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 692, in decode
    filtered_ids = self._preprocess_token_ids(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 637, in _preprocess_token_ids
    token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 860, in _strip_prompt
    if not token_ids:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/array.py", line 258, in __bool__
    core.check_bool_conversion(self)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 654, in check_bool_conversion
    raise ValueError("The truth value of an array with more than one element"
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
However, if I set `skip_special_tokens=False` in `batch_decode`, no error is raised, but the output contains many special tokens.
Expected behavior
`batch_decode` should return a list of strings, as it did through transformers 4.42.4.
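Until a fixed release is available, one possible workaround is to convert the generated ids to plain Python lists before decoding, since the failing `if not token_ids:` check is well defined for lists. The following is only a sketch and has not been verified against every transformers version. `to_nested_list` is a hypothetical helper name, and the stdlib `array` objects merely stand in for model output so that the snippet runs without JAX installed.

```python
from array import array


def to_nested_list(sequences):
    """Convert a batch of token ids to list[list[int]].

    Relies only on `.tolist()`, which jax.Array, numpy.ndarray and
    torch.Tensor all provide. `to_nested_list` is a hypothetical helper,
    not part of transformers.
    """
    # A 2-D array-like converts to a nested list in one call.
    if hasattr(sequences, "tolist"):
        sequences = sequences.tolist()
    rows = []
    for row in sequences:
        # A Python list of 1-D arrays is converted row by row.
        if hasattr(row, "tolist"):
            row = row.tolist()
        rows.append([int(token) for token in row])
    return rows


# Arbitrary ids, used only to exercise the helper (not real Whisper tokens).
batch = [array("i", [1, 2, 3]), array("i", [4])]
print(to_nested_list(batch))
```

With the reproduction above, the last line would then become `tokenizer.batch_decode(to_nested_list(pred_ids.sequences), skip_special_tokens=True)`.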
Any updates @sanchit-gandhi @ArthurZucker ?
@eustlb, would you mind looking at this issue if you have some bandwidth?
I think the error comes from checking a JAX array with `not` in tokenization_whisper.py:
https://github.com/huggingface/transformers/blob/d1f39c484d8347aa7b3170ea250a1e8f3bdfdf31/src/transformers/models/whisper/tokenization_whisper.py#L852
This check is fine when token_ids is a torch tensor or NumPy array, but a JAX array cannot be used directly in a boolean context (e.g., `if not jax_array:`), so JAX raises an error in that case:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/array.py", line 258, in __bool__
    core.check_bool_conversion(self)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 654, in check_bool_conversion
    raise ValueError("The truth value of an array with more than one element"
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
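To illustrate the failure mode, here is a minimal sketch of why a truthiness check breaks on a multi-element array and how a `len()`-based emptiness check avoids it. `AmbiguousArray` is a hypothetical stand-in that mimics the `__bool__` behaviour shown in the traceback, so the snippet runs without JAX. `is_empty` is likewise an illustrative name, and this is not necessarily how the library itself should be fixed.

```python
class AmbiguousArray:
    """Hypothetical stand-in for a multi-element array (not a real JAX class)."""

    def __init__(self, values):
        self._values = list(values)

    def __len__(self):
        return len(self._values)

    def __bool__(self):
        # Mirror the behaviour in the traceback: more than one element
        # has no single truth value.
        if len(self._values) > 1:
            raise ValueError(
                "The truth value of an array with more than one element is ambiguous."
            )
        return bool(self._values and self._values[0])


def is_empty(token_ids):
    """Emptiness check based on len(), which is defined for lists and arrays alike."""
    return len(token_ids) == 0


ids = AmbiguousArray([1, 2, 3])

try:
    if not ids:  # same pattern as the failing check
        pass
except ValueError as err:
    print("truthiness check failed:", err)

print(is_empty(ids), is_empty([]), is_empty([1, 2]))
```

The design point is that `not x` asks for a truth value, which is ambiguous for arrays, while `len(x) == 0` asks only about size.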
feel free to open a PR for a fix then!
I created a PR: https://github.com/huggingface/transformers/pull/33151
@ArthurZucker Please review and merge it
Closing as it was merged!