wav2vec processor batching logic is too restrictive
System Info
transformers version at the time of writing is 4.26.1
Who can help?
No response
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
```python
# !pip install transformers torch # in jupyter notebook
from transformers import Wav2Vec2Processor
import torch
import numpy as np

batch = 4
# create Wav2Vec2Processor
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
# generate random input tensor
input_tensor = torch.tensor(np.random.randn(batch, 10, 10))
# pass input tensor through processor
output = processor(input_tensor, return_tensors="pt")
print(output["input_values"].shape) # 1 x 4 x 10 x 10
```
Expected behavior
It seems reasonable that an input could have shape `batch x d_1 x d_2 x ...`, and I'd expect the output to keep that shape. However, the `is_batched` check in the feature extractor only treats a `list` or `tuple` as a batch, so the tensor is misinterpreted as a single example and gains an extra leading dimension.
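For illustration, here is a minimal standalone sketch of why the stacked input ends up with shape `1 x 4 x 10 x 10`. The name `is_batched_current` is hypothetical, and its body is my paraphrase of the list/tuple check in 4.26.1, not a verbatim copy of the library source. It uses a numpy array instead of a `torch.Tensor` to stay dependency-free; a tensor is also neither a `list` nor a `tuple`, so it takes the same branch.

```python
import numpy as np


def is_batched_current(raw_speech) -> bool:
    """Hypothetical paraphrase of the pre-fix check: only a list/tuple of arrays/lists is a batch."""
    return bool(
        isinstance(raw_speech, (list, tuple))
        and isinstance(raw_speech[0], (np.ndarray, list, tuple))
    )


batch = np.random.randn(4, 10, 10)  # same shape as input_tensor in the reproduction

# A stacked array is not a list/tuple, so it is treated as ONE example...
print(is_batched_current(batch))  # False
# ...and the single example is then wrapped in a list, adding a leading dim of 1.
wrapped = np.asarray([batch])
print(wrapped.shape)  # (1, 4, 10, 10)

# Only an explicit list of per-example arrays is recognized as a batch.
print(is_batched_current(list(batch)))  # True
```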
Side note: I'm unsure what to infer from the type-checking logic, because it doesn't match the type hints: according to the `__call__` type hint, a `tuple` isn't supposed to be possible here anyway. I also checked some other occurrences of `is_batched` in the `src/transformers/models` directory; they use similar logic, which looks equally unexpected.
cc @sanchit-gandhi @ArthurZucker
Hey @LWprogramming! Thanks for the comprehensive issue description. I agree that the logic for setting `is_batched` is broken when the input is a batched numpy array: the feature extractor should set `is_batched=True` when the numpy array is 2-d, but currently does not:
https://github.com/huggingface/transformers/blob/57f25f4b7fb85ff069f8701372710b2a3207bf2d/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L184-L187
Would you like to open a PR to fix this? 🤗 We can just do one additional check to set `is_batched = True` if the input is a 2-d numpy array. Note that it should be 2-d with dims `[batch, audio_input]` and not 3-d, since we only expect mono-channel input to the feature extractor.
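A minimal sketch of the suggested extra check, assuming it sits alongside the existing list/tuple test. The function name `is_batched_proposed` and the error message are hypothetical, and this is not necessarily the code that ends up in the PR.

```python
import numpy as np


def is_batched_proposed(raw_speech) -> bool:
    """Hypothetical fixed check: also treat a 2-d numpy array [batch, audio_input] as a batch."""
    is_batched_numpy = isinstance(raw_speech, np.ndarray) and raw_speech.ndim > 1
    if is_batched_numpy and raw_speech.ndim > 2:
        # Only mono-channel audio is expected, so 3-d (or higher) arrays are rejected.
        raise ValueError(
            f"Expected mono-channel input of shape [batch, audio_input], got {raw_speech.shape}"
        )
    # Existing behaviour: a list/tuple whose first element is an array or sequence.
    is_batched_sequence = isinstance(raw_speech, (list, tuple)) and isinstance(
        raw_speech[0], (np.ndarray, list, tuple)
    )
    return bool(is_batched_numpy or is_batched_sequence)


print(is_batched_proposed(np.zeros(16000)))        # False: single 1-d waveform
print(is_batched_proposed(np.zeros((4, 16000))))   # True: 2-d batch
print(is_batched_proposed([np.zeros(16000)] * 4))  # True: list of waveforms

try:
    is_batched_proposed(np.zeros((4, 10, 10)))
except ValueError as err:
    print(err)
```

With this check, a 3-d input like the one in the reproduction raises an error instead of being silently wrapped as one example. Like the suggestion it sketches, it only inspects numpy arrays, so a `torch.Tensor` would still need converting (for example with `.numpy()`) before it is recognized as a batch.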
Hey @LWprogramming! Just checking-in to see whether you'd like to open a PR to fix the issue you uncovered? Think you're in a good position to submit a clean fix! 🤗
Hi! I'll take care of it, got preoccupied with some irl stuff that came up the past few weeks but things should be settling down soon :)
That's awesome @LWprogramming! Excited for the PR 🤗 Feel free to tag me as soon as it's ready and I'll get you a review
marking as still active, just fixing up the PR