whisper
added progress_callback in transcribe method
Hello, I wanted to thank you for your amazing work.
I have added a new Callable argument progress_callback to the transcribe method. This will allow users to track the progress of the transcription process through other frameworks, such as gradio.
Here's an example of how this could be used with gradio:
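```python
import gradio as gr
import whisper

# Assumes the modified transcribe() from this PR, which accepts progress_callback
model = whisper.load_model("base")
audio = whisper.load_audio("audio.wav")  # placeholder path


def run_transcribe(progressbar=gr.Progress()):
    def progress_callback(progress_value):
        # Forward the reported progress value to the Gradio progress bar
        progressbar(progress_value, desc="Transcribing..")

    progressbar(0, desc="Transcribing")
    model.transcribe(audio=audio, verbose=False, progress_callback=progress_callback)
    return "Done!"


block = gr.Blocks().queue()
with block:
    btn = gr.Button("run_transcribe")
    tb = gr.Textbox()
    # inputs/outputs must be passed by keyword: a positional argument cannot follow fn=...
    btn.click(fn=run_transcribe, inputs=[], outputs=[tb])
block.launch()
```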
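To make the mechanism concrete without whisper or gradio installed, here is a minimal sketch of how a transcription loop could report progress through such a callback. fake_transcribe, chunk_frames and the fixed-size windows are hypothetical simplifications; whisper's own loop advances its seek position differently.

```python
from typing import Callable, List, Optional


def fake_transcribe(
    num_frames: int,
    chunk_frames: int = 3000,
    progress_callback: Optional[Callable[[float], None]] = None,
) -> str:
    """Hypothetical stand-in for transcribe(): decodes fixed-size windows and
    reports the fraction of frames processed after each one."""
    seek = 0
    pieces: List[str] = []
    while seek < num_frames:
        pieces.append(f"<segment@{seek}>")  # placeholder for decoding one window
        seek = min(seek + chunk_frames, num_frames)
        if progress_callback is not None:
            progress_callback(seek / num_frames)
    return " ".join(pieces)


reported: List[float] = []
fake_transcribe(7500, progress_callback=reported.append)
print(reported)  # [0.4, 0.8, 1.0]
```

Any callable that accepts a float works as the callback, which is why the gradio example only needs a thin wrapper around gr.Progress.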
Thank you!
Hi! Thanks for the suggestion. I'm actually thinking of a generator version of the function transcribe(), so the caller can iterate over the transcribed segments as they come in. Would that work for your use case?
@jongwook Hey, sorry for jumping in, but I think the progress_callback approach can also work as a generator version, because you can pass any other callback (for instance queue.put) and consume the results on the other side.
Consider the following example:
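```python
import whisper
from multiprocessing import Queue, Process


def transcribe_producer(queue: Queue):
    model = whisper.load_model("small")
    audio = whisper.load_audio("javascript.wav")
    # segment_callback is the argument proposed in this comment, not part of released whisper
    model.transcribe(
        audio,
        language="Russian",
        beam_size=5,
        best_of=5,
        segment_callback=queue.put,
    )
    queue.put(None)  # sentinel: tells the consumer there are no more segments


if __name__ == "__main__":
    q = Queue()
    p = Process(target=transcribe_producer, args=(q,))
    p.start()
    # Print segments as they arrive and stop at the sentinel instead of looping forever
    for segment in iter(q.get, None):
        print(segment)
    p.join()
```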
I've changed the transcribe function signature as follows:
```python
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch


def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    # Optional with a None default, so existing callers are unaffected
    segment_callback: Optional[Callable[[Tuple[float, float, str]], None]] = None,
    **decode_options,
):
    ...  # body unchanged apart from add_segment
```
And added the following lines to the add_segment function:
```python
if segment_callback is not None:
    segment_callback((start, end, text))
```
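As a self-contained illustration of those two lines, here is a minimal sketch of a helper that stores a segment and then forwards (start, end, text) to an optional callback. make_add_segment and the stored dict layout are hypothetical; this is not whisper's actual add_segment.

```python
from typing import Callable, List, Optional, Tuple

Segment = Tuple[float, float, str]


def make_add_segment(
    all_segments: List[dict],
    segment_callback: Optional[Callable[[Segment], None]] = None,
) -> Callable[..., None]:
    """Hypothetical sketch: returns an add_segment closure that records each
    segment and then notifies the optional callback."""

    def add_segment(*, start: float, end: float, text: str) -> None:
        all_segments.append({"id": len(all_segments), "start": start, "end": end, "text": text})
        if segment_callback is not None:
            segment_callback((start, end, text))

    return add_segment


segments: List[dict] = []
received: List[Segment] = []
add_segment = make_add_segment(segments, segment_callback=received.append)
add_segment(start=0.0, end=2.5, text="Hello")
add_segment(start=2.5, end=4.0, text="world")
print(received)  # [(0.0, 2.5, 'Hello'), (2.5, 4.0, 'world')]
```

Because the callback is just a callable taking one tuple, list.append, queue.put or print can all be passed in unchanged.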
For my use case it works flawlessly, and it is also pretty easy to implement. The nice thing is that you can pass any callback, not only queue.put. See the attached screenshot to get a feel for what it looks like when running.
Thank you for the comments. I think the idea of making a generator version is also a good one.
If the transcribe method yields instead of returning, the following becomes possible.
For example, in transcribe():
```python
...
yield dict(
    num_frames=num_frames,
    seek=seek,
    text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
    segments=all_segments,
    language=language,
)
```
It could then be used like this in the gradio use case:
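```python
import gradio as gr
import whisper

# Assumes the generator version of transcribe() described above
model = whisper.load_model("base")
audio = whisper.load_audio("audio.wav")  # placeholder path


def run_transcribe(progressbar=gr.Progress()):
    progressbar(0, desc="Started Transcribing.")
    result = None
    for progress in model.transcribe(audio=audio, verbose=False):
        # Each yielded dict carries seek and num_frames, so the fraction done is known
        progressbar(progress["seek"] / progress["num_frames"], desc="Transcribing..")
        result = progress  # the last snapshot holds the complete result
    return f"Done! result is {result}"


block = gr.Blocks().queue()
with block:
    btn = gr.Button("run_transcribe")
    tb = gr.Textbox()
    btn.click(fn=run_transcribe, inputs=[], outputs=[tb])
block.launch()
```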
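To try the generator idea without loading a model, here is a minimal sketch with a hypothetical fake_transcribe_iter that walks fixed-size windows and yields a dict with num_frames, seek, text and segments after each one, mirroring the keys proposed above. The window size and placeholder text are assumptions for illustration only.

```python
from typing import Dict, Iterator, List


def fake_transcribe_iter(num_frames: int, chunk_frames: int = 3000) -> Iterator[Dict]:
    """Hypothetical generator variant: yields a progress snapshot after each window."""
    seek = 0
    segments: List[str] = []
    while seek < num_frames:
        segments.append(f"<segment@{seek}>")  # placeholder for decoding one window
        seek = min(seek + chunk_frames, num_frames)
        # Copy the list so earlier snapshots are not mutated by later windows
        yield dict(num_frames=num_frames, seek=seek, text=" ".join(segments), segments=list(segments))


percentages: List[str] = []
final = None
for snapshot in fake_transcribe_iter(7500):
    percentages.append(f"{snapshot['seek'] / snapshot['num_frames']:.0%}")
    final = snapshot  # the last snapshot holds the complete result

print(percentages)  # ['40%', '80%', '100%']
print(final["text"])  # <segment@0> <segment@3000> <segment@6000>
```

Copying the segment list in each snapshot is a design choice of this sketch; yielding the same mutable list (as all_segments above would be) is cheaper but means older snapshots change as decoding continues.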
One issue I encountered during testing was that num_frames and seek had to be included in the yielded dict to track progress, because the total number of steps is not known in advance: calling len(list(result)) to get a total for "for i, progress in enumerate(result)" exhausts the generator.
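The exhaustion problem is easy to reproduce with any generator. In this minimal sketch, snapshots is a hypothetical stand-in for the generator returned by transcribe(): counting its items with len(list(...)) consumes them, so nothing is left for the loop that was supposed to report progress.

```python
from typing import Iterator


def snapshots() -> Iterator[int]:
    """Hypothetical stand-in for the generator returned by transcribe()."""
    yield from range(3)


result = snapshots()
total = len(list(result))  # counting consumes every item
leftover = list(result)    # so the progress loop has nothing left to iterate
print(total, leftover)     # 3 []
```

This is why each yielded item has to carry its own position (seek) and total (num_frames).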
Thanks!
@jhj0517 You're welcome. For progress tracking I prefer the yield solution, because it's more general and you can, for instance, use it to stream results. However, it's good to be aware of one small disadvantage of a generator compared with the callback approach suggested initially.
Consider this code sample:
```python
import time

import whisper


def some_handler(chunk: dict):
    pass


def time_consumer():
    time.sleep(100000)  # stands in for any long-running work


model = whisper.load_model("small")
audio = whisper.load_audio("sample.wav")
# Assumes the proposed generator version of transcribe()
for chunk in model.transcribe(audio):
    some_handler(chunk)
    time_consumer()  # time consuming process
```
In that case the model sits idle and produces no output while time_consumer() runs, because the generator executes in the same thread as the loop and only resumes when the next item is requested. I think with callbacks you can overcome this by running the transcription in a separate process and consuming its results through an asynchronous iterator.
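A small standard-library sketch of why this happens: a generator only runs up to its next yield when the loop asks for another item, so its work and the loop body strictly alternate in one thread. producer and the log list are illustrative names, and the consume step stands in for some_handler() and time_consumer().

```python
from typing import Iterator, List

log: List[str] = []


def producer() -> Iterator[int]:
    for i in range(3):
        log.append(f"produce {i}")  # work done inside the generator
        yield i


for item in producer():
    log.append(f"consume {item}")  # stands in for some_handler() and time_consumer()

print(log)
# ['produce 0', 'consume 0', 'produce 1', 'consume 1', 'produce 2', 'consume 2']
```

No "produce" entry ever appears while a "consume" step is in progress, which is exactly the idle time described above.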
If you wish, I can make a PR or a gist to show you the concept. I'm not a maintainer, though...
Here's an example of how it could look if you combine callbacks with generators. You could do the same without async: just replace asyncio.sleep with time.sleep and remove the async/await keywords.
```python
import asyncio
from multiprocessing import Queue, Process

import whisper


def transcribe_producer(queue: Queue):
    model = whisper.load_model("tiny")
    audio = whisper.load_audio("javascript.wav")
    # segment_callback is the argument proposed earlier in this thread
    model.transcribe(
        audio,
        language="Russian",
        beam_size=5,
        best_of=5,
        segment_callback=queue.put,
    )


async def transcribe_async(queue_check_interval: float = 1):
    q = Queue()
    p = Process(target=transcribe_producer, args=(q,))
    process_finished = False
    p.start()
    print(f"Started the process {p.pid}")
    while not process_finished or not q.empty():
        # Check the process on every pass; if this only ran after a yield, an empty
        # queue after the last segment would keep the loop sleeping forever
        if not process_finished and not p.is_alive():
            print(f"Process with id = {p.pid} was finished, yielding the rest of the queue")
            process_finished = True
        if q.empty():
            if not process_finished:
                print(f"No results available, sleeping for {queue_check_interval} seconds")
                await asyncio.sleep(queue_check_interval)
            continue
        yield q.get_nowait()
    p.join()


async def io_blocking_task(processing_time: float = 1):
    print(f"IO blocking task for {processing_time} secs")
    await asyncio.sleep(processing_time)


async def main():
    async for item in transcribe_async(1):
        print(f"Got an element: {item}")
        await io_blocking_task(2)
    print("Async task finished")


# The __main__ guard is crucial: with the "spawn" start method, child processes
# re-import this module and would otherwise try to start processes recursively
if __name__ == "__main__":
    asyncio.run(main())
```
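Following the suggestion to drop async, here is a minimal sketch of the same callback-to-iterator idea using only the standard library. It swaps the process for a background thread and replaces polling with a blocking queue.get() plus an end-of-stream sentinel. iterate_callbacks and fake_transcribe_with_callback are hypothetical names; the latter stands in for model.transcribe(..., segment_callback=...). The worker keeps producing while the consumer sleeps or waits on I/O; for CPU-heavy consumers, a separate process as in the example above avoids contention for the GIL.

```python
import queue
import threading
import time
from typing import Any, Callable, Iterator, Tuple

_DONE = object()  # sentinel that marks the end of the stream


def iterate_callbacks(run: Callable[[Callable[[Any], None]], None]) -> Iterator[Any]:
    """Run a callback-based function in a background thread and yield every
    value it passes to its callback, in order, as soon as it arrives."""
    q: queue.Queue = queue.Queue()

    def worker() -> None:
        try:
            run(q.put)
        finally:
            q.put(_DONE)  # unblock the consumer even if run() raises

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    while True:
        item = q.get()  # blocks until the next value; no polling interval needed
        if item is _DONE:
            break
        yield item
    t.join()


def fake_transcribe_with_callback(segment_callback: Callable[[Tuple[float, float, str]], None]) -> None:
    """Hypothetical stand-in for model.transcribe(..., segment_callback=...)."""
    for i, word in enumerate(["hello", "world"]):
        time.sleep(0.01)  # pretend that decoding takes a while
        segment_callback((float(i), float(i + 1), word))


for segment in iterate_callbacks(fake_transcribe_with_callback):
    print(segment)
```

The sentinel is pushed in a finally block so that a crash in the producer ends the iteration instead of leaving the consumer blocked on q.get() forever.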
Hi, thank you for your great work. Any news on this?
Any update on this?