
Question - Using BAAL for automatic speech recognition (speech to text)

ognjenkundacina opened this issue 1 year ago · 5 comments

Hello!

Can BAAL be used out of the box for batch active learning for automatic speech recognition models like wav2vec2? If not, can you please give me suggestions on how to implement that?

Thank you!

ognjenkundacina avatar Feb 08 '24 08:02 ognjenkundacina

Hello! I'm not super familiar with s2t, but it should be similar to text generation, which we support to some extent.

Do you use huggingface or another library? I can probably code something up this week to showcase this capability.

Dref360 avatar Feb 08 '24 22:02 Dref360

Thanks for the help! Yes, it should be similar to text generation in terms of active learning criteria, since it is the same type of output.

I am using wav2vec2 (without the ngram language model - a simpler case) from huggingface: https://huggingface.co/docs/transformers/model_doc/wav2vec2

ognjenkundacina avatar Feb 08 '24 22:02 ognjenkundacina

Cool! So as expected it's pretty similar to text generation. I recorded a Loom for you and there is also a gist.

Loom Gist

Now that the uncertainty estimation seems to work, the actual active learning loop should work using this tutorial (the part where ActiveLearningLoop is used at the end).
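
For reference, here is a minimal sketch of such a loop, assuming the standard Baal pieces (ActiveLearningDataset, ActiveLearningLoop, and the BALD heuristic); train_dataset, wrapper, num_al_steps, and the sizes below are placeholders, not values from the tutorial:

from baal.active import ActiveLearningDataset
from baal.active.active_loop import ActiveLearningLoop
from baal.active.heuristics import BALD

al_dataset = ActiveLearningDataset(train_dataset)  # wraps the unlabelled pool
al_dataset.label_randomly(100)                     # seed labelled set

loop = ActiveLearningLoop(
    dataset=al_dataset,
    get_probabilities=wrapper.predict_on_dataset,  # MC-Dropout predictions
    heuristic=BALD(),                              # rank samples by disagreement
    query_size=20,                                 # samples to label per step
    iterations=20,                                 # forwarded to predict_on_dataset
)

for _ in range(num_al_steps):
    wrapper.train_on_dataset(al_dataset, ...)      # retrain on the labelled set
    if not loop.step():                            # label the most uncertain samples
        break                                      # stop when the pool is exhausted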

Let me know if this works; I'm excited to learn more about your use case. We're definitely prepared to make changes to the library to fit it.

Cheers!

Dref360 avatar Feb 10 '24 16:02 Dref360

Thank you for the detailed answer! Over the next month I will be working on this in more depth and will let you know about my findings.

ognjenkundacina avatar Feb 15 '24 14:02 ognjenkundacina

Thanks a lot again for the guidance and the resources you've shared; they're incredibly helpful!

I'm actually planning to research the uncertainty estimation of sequences for s2t. I'm thinking of developing a function in BAAL that processes all the sequences (transcriptions) generated through MC Dropout iterations and compares them to derive new uncertainty measures. For example, given the parameters in the code you provided, the function should have access to 20 transcriptions from MC Dropout iterations alongside a single transcription from the model without dropout for each audio sample. Could you provide any advice on how to implement this function and access the transcriptions as mentioned?

ognjenkundacina avatar Feb 23 '24 08:02 ognjenkundacina

It is simple to switch between MC-Dropout and deterministic inference with Baal. That said, we typically use the mean prediction (we call this the Bayesian Model Average, but in practice we just average the logits). BALD builds on this by computing the variance in entropy across the stochastic predictions.

Alternating between MC-Dropout and deterministic inference:

from baal.bayesian.dropout import MCDropoutModule

your_model = ...            # your S2T model (e.g. wav2vec2)
wrapper = BaalTrainer(...)  # Baal wrapper around the model

with MCDropoutModule(your_model) as stochastic_model:
    # Inside the context, dropout stays active at inference time: this is stochastic.
    predictions = [stochastic_model(inputs) for _ in range(ITERATIONS)]
    wrapper.predict_on_dataset(..., iterations=ITERATIONS)

# Outside the context, dropout is disabled again: this is deterministic.
output = your_model(inputs)
wrapper.predict_on_dataset(..., iterations=1)

You can then compute the uncertainty however you wish, using both the deterministic and stochastic predictions.
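
As one possible example (a sketch, not something from the gist above), the stochastic predictions returned by predict_on_dataset can be scored directly with Baal's BALD heuristic; `predictions` below is assumed to have the iterations axis last, as described underneath:

import numpy as np
from baal.active.heuristics import BALD

scores = BALD().get_uncertainties(predictions)  # one uncertainty score per sample
ranked = np.argsort(scores)[::-1]               # indices, most uncertain first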

To compute the average prediction, note that the iterations axis is at the end, so you can do:

predictions = wrapper.predict_on_dataset(...)
average_pred = predictions.mean(axis=-1)
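
If you want the actual transcriptions for each MC-Dropout iteration (as in your earlier message), one possible way, sketched here under the assumption that `predictions` has shape [n_samples, seq_len, vocab_size, n_iterations] and that `deterministic_predictions` holds the iterations=1 output, is to greedily decode each iteration with the wav2vec2 processor (checkpoint name is only an example):

import numpy as np
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# One list of transcriptions per MC-Dropout iteration.
mc_transcriptions = []
for it in range(predictions.shape[-1]):
    ids = np.argmax(predictions[..., it], axis=-1)        # greedy CTC decoding
    mc_transcriptions.append(processor.batch_decode(ids))

# Single transcription per sample from the deterministic pass.
det_ids = np.argmax(deterministic_predictions[..., 0], axis=-1)  # drop the iterations=1 axis
det_transcriptions = processor.batch_decode(det_ids)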

I hope I understood your message correctly. I'm happy to hop on a call if it helps.

Dref360 avatar Feb 26 '24 21:02 Dref360

Thank you very much, I've managed to implement what I wanted using these instructions!

ognjenkundacina avatar Mar 01 '24 13:03 ognjenkundacina