Prompting TFWhisperForConditionalGeneration leads to runtime crashes
System Info
- `transformers` version: 4.42.3
- Platform: Linux-6.6.37-x86_64-with-glibc2.38
- Python version: 3.10.13
- Huggingface_hub version: 0.23.4
- Safetensors version: 0.4.1
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.2+cu121 (False)
- Tensorflow version (GPU?): 2.9.1 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: No
Who can help?
@sanchit-gandhi @gante
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
import librosa
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
# get some dummy data
audio, sr = librosa.load("audio/samples_jfk.wav", sr=16000, mono=True)
inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="tf")
input_features = inputs.input_features
prompts = processor.get_prompt_ids("some random words", return_tensors="tf")
outputs = model.generate(
    input_features,
    return_dict_in_generate=True,
    prompt_ids=prompts,
)
Expected behavior
Running the code above leads to:
AttributeError: EagerTensor object has no attribute 'tolist'. If you are looking for numpy-related methods, please run the following: from tensorflow.python.ops.numpy_ops import np_config np_config.enable_numpy_behavior()
It seems that prompts cannot be used in graph mode (which I need). Also, a workaround using decoder_input_ids leads to other issues.
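For context, the failure mode can be reproduced without the full model. The sketch below uses a hypothetical `FakeEagerTensor` class as a stand-in for a TF EagerTensor (only numpy is assumed installed): like an EagerTensor, it exposes `.numpy()` but has no `.tolist()` method, which is exactly what `generate` tries to call on the prompt ids.

```python
import numpy as np

class FakeEagerTensor:
    """Minimal stand-in for a tf EagerTensor: exposes .numpy() but no .tolist()."""
    def __init__(self, values):
        self._values = np.asarray(values)

    def numpy(self):
        return self._values

prompt_ids = FakeEagerTensor([50361, 400, 1500])
try:
    prompt_ids.tolist()  # mirrors the call made inside modeling_tf_whisper.py
except AttributeError as err:
    print(f"AttributeError: {err}")  # .tolist() exists on np.ndarray, not here
```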
cc @gante @Rocketknight1
@Manuel030 have you tried return_tensors="np" instead?
Thanks @Rocketknight1, but numpy is not available when executing in graph mode. I would expect the generate pass to be compatible with TensorFlow's graph execution mode.
Also, the docstring states it should be a tf.Tensor.
Got it - can you paste the entire traceback so I can figure out where it's happening?
Sure:
File "/home/manuel/Projects/whisper-finetune/issue.py", line 14, in <module>
outputs = model.generate(
File "/home/manuel/Projects/whisper-finetune/venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_tf_whisper.py", line 1646, in generate
prompt_ids = prompt_ids.tolist()
File "/home/manuel/Projects/whisper-finetune/venv/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 440, in __getattr__
raise AttributeError(
AttributeError: EagerTensor object has no attribute 'tolist'.
If you are looking for numpy-related methods, please run the following:
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
cc @gante @sanchit-gandhi, the relevant code is
prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids
which is failing in graph mode when the input is tf.Tensor and not np.ndarray. A simple workaround would be something like this:
if isinstance(prompt_ids, np.ndarray):
    prompt_ids = prompt_ids.tolist()
else:
    prompt_ids = prompt_ids.numpy().tolist()
but I'll wait for @gante's feedback here. The reason this is needed in the first place is because @Manuel030 wants to compile the generation loop, presumably for XLA/export. Is that possible for TFWhisper, or will we just run into other problems if we fix this line?
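As a sanity check, that guard can be exercised outside of transformers. The sketch below wraps it in a hypothetical `normalize_prompt_ids` helper and uses a stand-in `FakeEagerTensor` class for the TF tensor case (only numpy is assumed installed); both input types should come out as the same plain Python list.

```python
import numpy as np

class FakeEagerTensor:
    """Stand-in for a tf EagerTensor: has .numpy() but no .tolist()."""
    def __init__(self, values):
        self._values = np.asarray(values)

    def numpy(self):
        return self._values

def normalize_prompt_ids(prompt_ids):
    # Mirrors the suggested fix: accept either an np.ndarray or a TF tensor.
    if isinstance(prompt_ids, np.ndarray):
        return prompt_ids.tolist()
    return prompt_ids.numpy().tolist()

print(normalize_prompt_ids(np.array([50361, 400])))         # [50361, 400]
print(normalize_prompt_ids(FakeEagerTensor([50361, 400])))  # [50361, 400]
```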
👋
Under the hood, TFWhisper's generate calls the OG generate, so it should be compilable! However, I'm not sure it can be compiled when prompt_ids is set (that code path has things like enumerate, which is often incompatible).
Regardless of whether it fixes the XLA use case, the suggested change LGTM @Rocketknight1 👍
Unfortunately, patching the high-level generate as suggested by @Rocketknight1 does not succeed. My use case is exporting to the TFLite format.
@Manuel030 I'm not sure we have a good solution in that case - you might have to make some changes to the generate() function in TFWhisper, like the one I made above, until you can get it to compile successfully. If you do, please open a PR to add the changes to the codebase!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.