diart
diart copied to clipboard
How to add ASR
I've been trying a while to get streaming diarization + recognition working together. I've failed trying many different strategies for streaming wav to pyannote. I started with an example that did this for wav files and wound up here. Without yet knowing/being accustomed to rx, my latest attempt:
import random
import time
from queue import SimpleQueue
from typing import Tuple, Text, Optional, Iterable
import sounddevice as sd
import torch
from einops import rearrange
from pyannote.audio.core.io import Audio, AudioFile
from pyannote.core import SlidingWindowFeature, SlidingWindow
from rx.subject import Subject
from pyannote.audio import Pipeline
from transformers import pipeline
class AudioSource:
"""Represents a source of audio that can start streaming via the `stream` property.
Parameters
----------
uri: Text
Unique identifier of the audio source.
sample_rate: int
Sample rate of the audio source.
"""
def __init__(self, uri: Text, sample_rate: int):
self.uri = uri
self.sample_rate = sample_rate
self.stream = Subject()
@property
def is_regular(self) -> bool:
"""Whether the stream is regular. Defaults to False.
A regular stream always yields the same amount of samples per event.
"""
return False
@property
def duration(self) -> Optional[float]:
"""The duration of the stream if known. Defaults to None (unknown duration)"""
return None
def read(self):
"""Start reading the source and yielding samples through the stream"""
raise NotImplementedError
class AudioFileReader:
"""Represents a method for reading an audio file.
Parameters
----------
sample_rate: int
Sample rate of the audio file.
"""
def __init__(self, sample_rate: int):
self.audio = Audio(sample_rate=sample_rate, mono=True)
self.resolution = 1 / sample_rate
@property
def sample_rate(self) -> int:
return self.audio.sample_rate
@property
def is_regular(self) -> bool:
"""Whether the reading is regular. Defaults to False.
A regular reading method always yields the same amount of samples."""
return False
def get_duration(self, file: AudioFile) -> float:
return self.audio.get_duration(file)
def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]:
"""Return an iterable over the file's samples"""
raise NotImplementedError
class RegularAudioFileReader(AudioFileReader):
"""Reads a file always yielding the same number of samples with a given step.
Parameters
----------
sample_rate: int
Sample rate of the audio file.
window_duration: float
Duration of each chunk of samples (window) in seconds.
step_duration: float
Step duration between chunks in seconds.
"""
def __init__(
self,
sample_rate: int,
window_duration: float,
step_duration: float,
):
super().__init__(sample_rate)
self.window_duration = window_duration
self.step_duration = step_duration
self.window_samples = int(round(self.window_duration * self.sample_rate))
self.step_samples = int(round(self.step_duration * self.sample_rate))
@property
def is_regular(self) -> bool:
return True
def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]:
waveform, _ = self.audio(file)
chunks = rearrange(
waveform.unfold(1, self.window_samples, self.step_samples),
"channel chunk frame -> chunk channel frame",
).numpy()
for i, chunk in enumerate(chunks):
w = SlidingWindow(
start=i * self.step_duration,
duration=self.resolution,
step=self.resolution
)
yield SlidingWindowFeature(chunk.T, w)
class IrregularAudioFileReader(AudioFileReader):
"""Reads an audio file yielding a different number of non-overlapping samples in each event.
This class is useful to simulate how a system would work in unreliable reading conditions.
Parameters
----------
sample_rate: int
Sample rate of the audio file.
refresh_rate_range: (float, float)
Duration range within which to determine the number of samples to yield (in seconds).
simulate_delay: bool
Whether to simulate that the samples are being read in real time before they are yielded.
Defaults to False (no delay).
"""
def __init__(
self,
sample_rate: int,
refresh_rate_range: Tuple[float, float],
simulate_delay: bool = False,
):
super().__init__(sample_rate)
self.start, self.end = refresh_rate_range
self.delay = simulate_delay
def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]:
waveform, _ = self.audio(file)
total_samples = waveform.shape[1]
i = 0
while i < total_samples:
rnd_duration = random.uniform(self.start, self.end)
if self.delay:
time.sleep(rnd_duration)
num_samples = int(round(rnd_duration * self.sample_rate))
last_i = i
i += num_samples
yield waveform[:, last_i:i]
class FileAudioSource(AudioSource):
"""Represents an audio source tied to a file.
Parameters
----------
file: AudioFile
The file to stream.
uri: Text
Unique identifier of the audio source.
reader: AudioFileReader
Determines how the file will be read.
"""
def __init__(
self,
file: AudioFile,
uri: Text,
reader: AudioFileReader
):
super().__init__(uri, reader.sample_rate)
self.reader = reader
self._duration = self.reader.get_duration(file)
self.file = file
@property
def is_regular(self) -> bool:
"""The regularity depends on the reader"""
return self.reader.is_regular
@property
def duration(self) -> Optional[float]:
"""The duration of a file is known"""
return self._duration
def read(self):
"""Send each chunk of samples through the stream"""
for waveform in self.reader.iterate(self.file):
try:
self.stream.on_next(waveform)
except Exception as e:
self.stream.on_error(e)
self.stream.on_completed()
class MicrophoneAudioSource(AudioSource):
"""Represents an audio source tied to the default microphone available"""
def __init__(self, sample_rate: int):
super().__init__("live_recording", sample_rate)
self.block_size = 1024
self.mic_stream = sd.InputStream(
channels=1,
samplerate=sample_rate,
latency=0,
blocksize=self.block_size,
callback=self._read_callback
)
self.queue = SimpleQueue()
def _read_callback(self, samples, *args):
self.queue.put_nowait(samples[:, [0]].T)
def read(self):
sttapi = STTAPI()
self.mic_stream.start()
while self.mic_stream:
try:
mic_next = self.queue.get()
print(mic_next)
mic_data = torch.from_numpy(mic_next)
diarized, full_text = sttapi.segmentation({
"waveform": mic_data,
"sample_rate": 16000
})
print(diarized, full_text)
except Exception as e:
print(e) # FIXME: Expected more than 1 spatial element when training, got input size torch.Size([1, 60, 1])
self.stream.on_error(e)
break
self.stream.on_completed()
class STTAPI:
def __init__(self,):
self.asr = pipeline(
"automatic-speech-recognition",
model="facebook/wav2vec2-large-960h-lv60-self",
feature_extractor="facebook/wav2vec2-large-960h-lv60-self",
)
self.speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-segmentation")
def segmentation(self, audio):
speaker_output = self.speaker_segmentation(audio)
text_output = self.asr(audio, return_timestamps="word")
full_text = text_output['text'].lower()
chunks = text_output['chunks']
diarized_output = ""
i = 0
for turn, _, speaker in speaker_output.itertracks(yield_label=True):
diarized = ""
while i < len(chunks) and chunks[i]['timestamp'][1] <= turn.end:
diarized += chunks[i]['text'].lower() + ' '
i += 1
if diarized != "":
diarized_output += "{}: ''{}'' from {:.3f}-{:.3f}\n".format(speaker, diarized, turn.start, turn.end)
return diarized_output, full_text
def sound_device_info(device_id=None, in_or_out: str = 'input'):
return sd.query_devices(device_id, in_or_out)
def main():
input_device = sound_device_info()
samplerate = int(input_device['default_samplerate'])
mic = MicrophoneAudioSource(samplerate)
mic.read()
if __name__ == '__main__':
main()
for which I get Expected more than 1 spatial element when training, got input size torch.Size([1, 60, 1])
. It maybe out of scope here but asking anywhere else will be useless. If anyone is so inclined, it would be much appreciated.
Hi @bitnom, can you post the full stacktrace of the error? Are you sure that mic_data
has the correct shape that the segmentation and/or ASR model is expecting?
@bitnom hi have you solved this, I am trying same thing. Any suggestion will be useful. I am trying to using Azure Recognise along with it