
added progress_callback in transcribe method

jhj0517 opened this issue on Mar 03 '23 · 7 comments

Hello, I wanted to thank you for your amazing work. I have added a new Callable argument progress_callback to the transcribe method. This will allow users to track the progress of the transcription process through other frameworks, such as gradio.

Here's an example of how this could be used with gradio:

import gradio as gr

def run_transcribe(progressbar=gr.Progress()):
    # `model` and `audio` are assumed to be defined elsewhere
    def progress_callback(progress_value):
        progressbar(progress_value, desc="Transcribing..")

    progressbar(0, desc="Transcribing")
    model.transcribe(audio=audio, verbose=False, progress_callback=progress_callback)
    return "Done!"

block = gr.Blocks().queue()
with block:
    btn = gr.Button("run_transcribe")
    tb = gr.Textbox()
    btn.click(fn=run_transcribe, inputs=[], outputs=[tb])
block.launch()
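
For reference, here is a rough sketch of where such a callback could be invoked inside transcribe itself; the seek / num_frames names and the exact placement in the decoding loop are assumptions about transcribe.py, not the actual change:

# Hypothetical excerpt from the main decoding loop in whisper/transcribe.py,
# with progress_callback as a new optional keyword argument
while seek < num_frames:
    ...  # decode the current audio window and advance `seek`

    if progress_callback is not None:
        # report progress as a fraction in [0, 1]
        progress_callback(min(seek / num_frames, 1.0))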

Thank you!

jhj0517 · Mar 03 '23

Hi! Thanks for the suggestion. I'm actually thinking of a generator version of the function transcribe(), so the caller can iterate over the transcribed segments as they come in. Would that work for your use case?
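
Purely as an illustration, caller-side usage of such a generator version might look roughly like this (a sketch of a possible API, not something that exists yet):

# Hypothetical: iterate over segments as they are decoded
for segment in model.transcribe(audio, verbose=False):
    print(segment["start"], segment["end"], segment["text"])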

jongwook · Mar 04 '23

I'm actually thinking of a generator version of the function transcribe(), so the caller can iterate over the transcribed segments as they come in. Would that work for your use case?

@jongwook Hey, sorry for jumping in, but I think this progress_callback approach can actually work as the generator version as well, because you can put any other sort of callback (for instance queue.put) inside.

Consider the following example:

import whisper
from multiprocessing import Queue, Process

def transcribe_producer(queue: Queue):
    model = whisper.load_model("small")
    audio = whisper.load_audio("javascript.wav")

    model.transcribe(
        audio,
        language="Russian",
        beam_size=5,
        best_of=5,
        segment_callback=queue.put,  # every finished segment is pushed onto the queue
    )


if __name__ == "__main__":
    q = Queue()
    p = Process(target=transcribe_producer, args=(q,))

    p.start()

    # consume segments as they arrive (runs until interrupted)
    while True:
        print(q.get())

I've changed the transcribe function signature as follows:

def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    segment_callback: Optional[Callable[[Tuple[float, float, str]], None]] = None,
    **decode_options,
):

And added the following lines to the add_segment function:

        if segment_callback is not None:
            segment_callback((start, end, text))
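
For context, a rough sketch of where that hook might sit inside the nested add_segment helper; the surrounding lines are assumptions about transcribe.py at the time and are abbreviated, not part of the actual change:

def add_segment(*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult):
    text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
    if len(text.strip()) == 0:
        return  # skip empty output

    all_segments.append(...)  # existing bookkeeping, elided here

    # new: notify the caller about the freshly completed segment
    if segment_callback is not None:
        segment_callback((start, end, text))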

For my use case it works flawlessly and is also pretty easy to implement. The cool thing is that you can pass any sort of callback, not only queue.put. Look at the attached screenshot to get a feeling for what it looks like when running.

(screenshot of the running example, attached in the original issue)

vsmaxim · Mar 04 '23

Thank you for the comments. I think the idea of making a generator version is also a good one: by yielding in the transcribe method instead of returning, the following code can be used.

For example, in transcribe:

            # yielded on each pass through the decoding loop instead of a single return at the end
            ...
            yield dict(
                num_frames=num_frames,
                seek=seek,
                text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
                segments=all_segments,
                language=language,
            )

it could be used in the following way in the gradio use case:

import gradio as gr

def run_transcribe(progressbar=gr.Progress()):
    progressbar(0, desc="Started Transcribing.")
    result = model.transcribe(audio=audio, verbose=False)
    for progress in result:
        progressbar(progress["seek"] / progress["num_frames"], desc="Transcribing..")
    result = progress  # the last yielded dict holds the final text and segments

    return f"Done! result is {result}"
    
block = gr.Blocks().queue() 
with block:
    btn = gr.Button("run_transcribe")
    tb = gr.Textbox()
    btn.click(fn=run_transcribe,inputs=[],outputs=[tb])
block.launch()

One issue I encountered during testing was that num_frames and seek needed to be included in the yield in order to track the progress. Also, calling len(list(result)) (for example to drive for i, progress in enumerate(result)) exhausts the generator, so the total number of yields can't be precomputed that way.
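
A tiny self-contained example of that exhaustion behaviour:

def count_up():
    yield from range(3)

gen = count_up()
print(len(list(gen)))  # 3, but this consumed every item
print(list(gen))       # [] -- the generator is already exhausted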

Thanks!

jhj0517 · Mar 04 '23

Thank you for comments. I think the idea of making a generator version is also a good one.

@jhj0517 You are welcome. In terms of progress tracking I like the yield solution more, because it's more general and you can, for instance, use it for streaming results. However, I think it's good to be aware of one small disadvantage of the generator over the callback approach suggested initially.

Consider this code sample:

import time
import whisper

def some_handler(chunk: dict):
    pass

def time_consumer():
    time.sleep(100000)

model = whisper.load_model("small")
audio = whisper.load_audio("sample.wav")

for result in model.transcribe(audio):
    some_handler(result)
    time_consumer()  # time-consuming work blocks the next iteration of the generator

In that case the model will just sit idle and not produce any new output while time_consumer runs, because everything runs in the same interpreter and is blocked by the current thread. I think with callbacks you can overcome this issue by using asynchronous iterators and running the generator in a separate process.

If you wish I can make a PR or a gist to show you the concept. I'm not a maintainer though...

vsmaxim · Mar 04 '23

Here's an example of how it could look if you combine callbacks with generators. You can do the same without async: just replace asyncio.sleep with time.sleep and remove the async / await keywords.

import asyncio
from multiprocessing import Queue, Process
import whisper


def transcribe_producer(queue: Queue):
    model = whisper.load_model("tiny")
    audio = whisper.load_audio("javascript.wav")
    model.transcribe(
        audio,
        language="Russian",
        beam_size=5,
        best_of=5,
        segment_callback=queue.put,
    )


async def transcribe_async(queue_check_interval: int = 1):
    q = Queue()
    p = Process(target=transcribe_producer, args=(q,))
    process_finished = False

    p.start()
    print(f"Started the process {p.pid}")

    while not process_finished or not q.empty():
        # non-blocking check whether the producer has finished
        if not process_finished:
            p.join(timeout=0)

            if not p.is_alive():
                print(f"Process with id = {p.pid} finished, yielding the rest of the queue")
                process_finished = True

        if q.empty():
            if process_finished:
                break  # producer is done and the queue is drained
            print(f"No results available, sleeping for {queue_check_interval} seconds")
            await asyncio.sleep(queue_check_interval)
            continue

        yield q.get(timeout=0)


async def io_blocking_task(processing_time: int = 1):
    print(f"IO blocking task for {processing_time} secs")
    await asyncio.sleep(processing_time)

async def main():
    async for item in transcribe_async(1):
        print(f"Got an element: {item}")
        await io_blocking_task(2)

    print("Async task finished")

# This guard is crucial to avoid spawning child processes recursively
if __name__ == "__main__":
    asyncio.run(main())

vsmaxim · Mar 04 '23

Hi, thank you for your great work. Any news on this?

samuelesabella · Jun 14 '23

Any update on this?

piccinnigius · Feb 06 '24