
wav2vec processor batching logic is too restrictive

[Open] LWprogramming opened this issue 2 years ago • 5 comments

System Info

transformers version at the time of writing is 4.26.1

Who can help?

No response

Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

# !pip install transformers torch # in jupyter notebook
from transformers import Wav2Vec2Processor
import torch
import numpy as np

batch = 4

# create Wav2Vec2Processor
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
# generate random input tensor
input_tensor = torch.tensor(np.random.randn(batch, 10, 10))
# pass input tensor through processor
output = processor(input_tensor, return_tensors="pt")
print(output["input_values"].shape) # 1 x 4 x 10 x 10

Expected behavior

It seems reasonable that an input could have shape batch x d_1 x d_2 ..., and I'd expect the output to have the same shape. However, the code has an extra check that only treats a list or tuple as a batch, so it misinterprets the input as a single example.

Side note: I'm unsure what to infer from the type-checking logic because it doesn't match the type hints, i.e., according to the __call__ type hint a tuple isn't supposed to be possible here anyway. I also checked some other occurrences of is_batched in the src/transformers/models directory; they look similar, and are just as unexpected.
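To reproduce the behaviour without downloading a checkpoint, here is a minimal sketch that paraphrases the batching check in feature_extraction_wav2vec2.py as I understand it (a reconstruction, not the verbatim library code; current_is_batched is a hypothetical name). Only a list or tuple whose first element is an array, list, or tuple counts as a batch, so a stacked numpy array or torch tensor falls through and is wrapped as one example. It also shows that a tuple passes the check despite the type hint.

```python
import numpy as np


def current_is_batched(raw_speech):
    """Hypothetical paraphrase of the is_batched check in transformers 4.26.1.

    Only a list/tuple whose first element is an ndarray, list, or tuple is
    treated as a batch; an np.ndarray (or torch tensor) of any rank is not.
    """
    return bool(
        isinstance(raw_speech, (list, tuple))
        and isinstance(raw_speech[0], (np.ndarray, list, tuple))
    )


batch = np.random.randn(4, 16000)  # 4 mono clips of 16000 samples each

print(current_is_batched(batch))         # False: treated as a single example
print(current_is_batched(list(batch)))   # True: a list of 1-d arrays is a batch
print(current_is_batched(tuple(batch)))  # True: tuples pass too, despite the type hint
```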

LWprogramming, Mar 15 '23 08:03

cc @sanchit-gandhi @ArthurZucker

amyeroberts, Mar 15 '23 10:03

Hey @LWprogramming! Thanks for the comprehensive issue description. I agree that the logic that sets is_batched is broken when the input is a batched numpy array. For example, the feature extractor should set is_batched=True when the numpy array is 2-d, but currently does not: https://github.com/huggingface/transformers/blob/57f25f4b7fb85ff069f8701372710b2a3207bf2d/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L184-L187

Would you like to open a PR to fix this? 🤗 We can just do one additional check to set is_batched = True if the input is a 2-d numpy array. Note that it should be 2-d with dims [batch, audio_input] and not 3-d since we only expect mono channel input to the feature extractor.
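A minimal sketch of the additional check suggested above, assuming the only change is to also treat a numpy array with more than one dimension as [batch, audio_input] and to reject 3-d input, since only mono audio is expected. The names fixed_is_batched and to_batch_list are hypothetical, and this is an illustration of the idea rather than the code in any actual PR.

```python
import numpy as np


def fixed_is_batched(raw_speech):
    """Hypothetical is_batched check that also accepts 2-d numpy arrays."""
    # A 2-d array is read as [batch, audio_input].
    is_batched_numpy = isinstance(raw_speech, np.ndarray) and raw_speech.ndim > 1
    if is_batched_numpy and raw_speech.ndim > 2:
        # 3-d input (e.g. [batch, channels, samples]) implies multi-channel audio.
        raise ValueError(
            f"Only mono-channel audio is supported, got shape {raw_speech.shape}"
        )
    # Fall back to the existing list/tuple check.
    return is_batched_numpy or bool(
        isinstance(raw_speech, (list, tuple))
        and isinstance(raw_speech[0], (np.ndarray, list, tuple))
    )


def to_batch_list(raw_speech):
    """Hypothetical helper: normalize input to a list of 1-d float32 arrays."""
    if fixed_is_batched(raw_speech):
        # Iterating a 2-d array (or a list of clips) yields one clip at a time.
        return [np.asarray(x, dtype=np.float32) for x in raw_speech]
    # Unbatched input: wrap the single clip in a list.
    return [np.asarray(raw_speech, dtype=np.float32)]


batch = np.random.randn(4, 16000)
clips = to_batch_list(batch)
print(len(clips), clips[0].shape)                  # 4 (16000,)
print(len(to_batch_list(np.random.randn(16000))))  # 1
```

Raising on 3-d input, rather than silently flattening it, keeps the mono-channel assumption explicit.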

sanchit-gandhi, Mar 24 '23 12:03

Hey @LWprogramming! Just checking in to see whether you'd like to open a PR to fix the issue you uncovered? Think you're in a good position to submit a clean fix! 🤗

sanchit-gandhi, Apr 21 '23 15:04

Hi! I'll take care of it. I got preoccupied with some IRL stuff that came up over the past few weeks, but things should be settling down soon :)

LWprogramming, Apr 22 '23 06:04

That's awesome @LWprogramming! Excited for the PR 🤗 Feel free to tag me as soon as it's ready and I'll get you a review

sanchit-gandhi, Apr 25 '23 09:04

marking as still active, just fixing up the PR

LWprogramming, May 19 '23 18:05