Words repeated in whisper transcription with initial_prompt
Hello all,
Thank you for doing this great work! I just updated this code to use faster-whisper, and I'm facing a repeated-words issue when I use the initial_prompt param in the transcription method. The issue happens when I end what I'm saying with a specific word, something like "Okay".
The issue:
Okay. Okay. Okay. Okay. Okay.
Okay. Okay.
Okay. Okay. Okay. Okay. Okay.
Okay. Okay. Okay. Okay. Okay.
Yeah.
Okay. Okay. Okay. Okay. Okay.
Okay. Okay. Okay. Okay. Okay.
Okay. Okay. Okay. Okay. Okay.
Okay. Okay. Okay. Okay. Okay.
Okay. Okay. Okay. Okay. Okay.
The code:
```python
import logging
import os
import sys
import traceback
from contextlib import contextmanager
import torch
import whisper
import diart.operators as dops
import numpy as np
import rich
import rx.operators as ops
# import whisper_timestamped as whisper
from faster_whisper import WhisperModel
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart.sources import WebSocketAudioSource
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
def concat(chunks, collar=0.05):
    """
    Concatenate predictions and audio
    given a list of `(diarization, waveform)` pairs
    and merge contiguous single-speaker regions
    with pauses shorter than `collar` seconds.
    """
    first_annotation = chunks[0][0]
    first_waveform = chunks[0][1]
    annotation = Annotation(uri=first_annotation.uri)
    data = []
    for ann, wav in chunks:
        annotation.update(ann)
        data.append(wav.data)
    annotation = annotation.support(collar)
    window = SlidingWindow(
        first_waveform.sliding_window.duration,
        first_waveform.sliding_window.step,
        first_waveform.sliding_window.start,
    )
    data = np.concatenate(data, axis=0)
    return annotation, SlidingWindowFeature(data, window)

def colorize_transcription(transcription):
    """
    Unify a speaker-aware transcription represented as
    a list of `(speaker: int, text: str)` pairs
    into a single text colored by speakers.
    """
    colors = 2 * [
        "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1",
        "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
    ]
    result = []
    for speaker, text in transcription:
        if speaker == -1:
            # No speaker found for this text, use default terminal color
            result.append(text)
        else:
            result.append(f"[{colors[speaker]}]{text}")
    return "\n".join(result)

@contextmanager
def suppress_stdout():
    # Auxiliary function to suppress Whisper logs (it is quite verbose)
    # All credit goes to: https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

class WhisperTranscriber:
    def __init__(self, model="small", device=None):
        self.model = WhisperModel(model, device="cuda", compute_type="float16")
        # self.model = whisper.load_model(model, device=device)
        self._buffer = ""

    def transcribe(self, waveform):
        """Transcribe audio using Whisper"""
        # Pad/trim audio to fit 30 seconds as required by Whisper
        audio = waveform.data.astype("float32").reshape(-1)
        audio = whisper.pad_or_trim(audio)
        # Transcribe the given audio while suppressing logs
        with suppress_stdout():
            segments, _ = self.model.transcribe(
                audio=audio,
                language='en',
                # We use past transcriptions to condition the model
                # initial_prompt=self._buffer,
                word_timestamps=True
            )
        return segments

    def identify_speakers(self, segments, diarization, time_shift):
        """Iterate over transcription segments to assign speakers"""
        speaker_captions = []
        text = ""
        for segment in segments:
            text += segment.text
            # Crop diarization to the segment timestamps
            start = time_shift + segment.words[0].start
            end = time_shift + segment.words[-1].end
            dia = diarization.crop(Segment(start, end))
            # Assign a speaker to the segment based on diarization
            speakers = dia.labels()
            num_speakers = len(speakers)
            if num_speakers == 0:
                # No speakers were detected
                caption = (-1, segment.text)
            elif num_speakers == 1:
                # Only one speaker is active in this segment
                spk_id = int(speakers[0].split("speaker")[1])
                caption = (spk_id, segment.text)
            else:
                # Multiple speakers, select the one that speaks the most
                max_speaker = int(np.argmax([
                    dia.label_duration(spk) for spk in speakers
                ]))
                caption = (max_speaker, segment.text)
            speaker_captions.append(caption)
        return speaker_captions, text

    def __call__(self, diarization, waveform):
        # Step 1: Transcribe
        segments = self.transcribe(waveform)
        # The audio may not be the beginning of the conversation
        time_shift = waveform.sliding_window.start
        # Step 2: Assign speakers
        speaker_transcriptions, text = self.identify_speakers(segments, diarization, time_shift)
        # Update transcription buffer
        self._buffer += text
        return speaker_transcriptions

# Suppress whisper-timestamped warnings for a clean output
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)

# If you have a GPU, you can also set device=torch.device("cuda")
config = PipelineConfig(
    duration=5,
    step=0.5,
    latency="min",
    tau_active=0.5,
    rho_update=0.1,
    delta_new=0.57,
    device=torch.device("cpu")
)
dia = OnlineSpeakerDiarization(config)
source = WebSocketAudioSource(config.sample_rate, "0.0.0.0", 8081)

# If you have a GPU, you can also set device="cuda"
asr = WhisperTranscriber(model="base")

# Split the stream into 2s chunks for transcription
transcription_duration = 2
# Apply models in batches for better efficiency
batch_size = int(transcription_duration // config.step)

# Chain of operations to apply on the stream of microphone audio
source.stream.pipe(
    # Format audio stream to sliding windows of 5s with a step of 500ms
    dops.rearrange_audio_stream(
        config.duration, config.step, config.sample_rate
    ),
    # Wait until a batch is full
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction
    # The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate 500ms predictions/chunks to form a single 2s chunk
    ops.map(concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    # Color transcriptions according to the speaker
    # The output is plain text with color references for rich
    ops.map(colorize_transcription),
).subscribe(
    on_next=rich.print,  # print colored text
    on_error=lambda _: traceback.print_exc()  # print stacktrace if error
)

print("Listening...")
source.read()
```
Any help?
Great code!!! My guess is the CPU is too slow. I ran into that issue with whisper-realtime before switching over to CUDA. It works fine once I switched to CUDA and used my mic as the audio source.
Do we have any solution for this problem when using CPU?
Hi, I tried the code provided by you, but it does not print anything other than "Listening..." in the terminal. I also think my mic is not getting enabled when I run the code.
Is there any possible solution for this?
Hi, do we have any solution for this issue? After a few seconds of recording I also get the repeated results.
I am using "cpu" as the device.
Hello, the issue with repeating words could have several causes:
- The CPU is taking too long, so reading from the microphone is being interrupted by Whisper, as pointed out by @brian-j-connolly-aero. If you don't have a GPU, this could be solved by using an optimized Whisper version (whisper.cpp, faster-whisper, etc.) and/or running Whisper in a separate process.
- Whisper hallucinates because of short chunks, noise, music, etc. You could try playing with the decoding configuration (see for example the compression threshold, and the sketch after this list). Another option is to buffer audio to give Whisper a bigger context and then merge/update transcriptions. This should reduce the chances of weird behavior.
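For illustration only, here is a minimal sketch of those two ideas applied to the transcribe call from the original post, using faster-whisper. The model size, compute type, and threshold values below are assumptions to tune, not recommendations:

```python
from faster_whisper import WhisperModel

# On CPU, int8 quantization is usually much faster than float16
model = WhisperModel("base", device="cpu", compute_type="int8")

def transcribe(audio):
    segments, _ = model.transcribe(
        audio=audio,
        language="en",
        word_timestamps=True,
        # Do not feed previous (possibly hallucinated) text back into the decoder
        condition_on_previous_text=False,
        # Reject decodings that look like degenerate repetitions or silence
        compression_ratio_threshold=2.4,  # assumed starting point, tune as needed
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
    )
    # faster-whisper returns a lazy generator, so materialize it here
    return list(segments)
```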
I am currently passing the audio file directly and not using the microphone as the source. Does this account in any way for this strange behaviour of repeated words?
@shanky100 most probably not. Make sure to check whether Whisper hallucinates when transcribing the entire file at once (instead of streaming it). It's possible that the chunks are too short. In that case I suggest you try increasing the ASR chunk size or buffering the audio, as shown below.
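Concretely, in the script posted above the ASR chunk size is controlled by transcription_duration (2 s). A larger value gives Whisper more context per call at the cost of extra latency; the value 8 below is just an assumption to experiment with:

```python
# Give Whisper more context per transcription call
# (fewer hallucinations on short chunks, but higher latency)
transcription_duration = 8  # seconds; example value, tune to taste
# With config.step = 0.5 this buffers 16 diarization chunks per ASR call
batch_size = int(transcription_duration // config.step)
```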
You can also try removing the text conditioning. If Whisper hallucinates at the beginning and you keep conditioning it on the hallucination, you may be seeing a snowballing effect.
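If you want to keep some conditioning instead of removing it entirely, one possible compromise (an untested sketch; MAX_PROMPT_CHARS is a hypothetical cap) is to pass only the tail of the buffer, so an early hallucination eventually drops out of the prompt:

```python
MAX_PROMPT_CHARS = 200  # hypothetical cap, tune as needed

# Inside WhisperTranscriber.transcribe, instead of initial_prompt=self._buffer:
segments, _ = self.model.transcribe(
    audio=audio,
    language="en",
    word_timestamps=True,
    initial_prompt=self._buffer[-MAX_PROMPT_CHARS:] or None,
)
```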