diart
Is it possible to pass raw byte data into the pipeline?
I'm currently working on an implementation where I occasionally receive raw byte data from a TCP socket. I want to pass this data into the pipeline, but the current AudioSource implementations seem to be limited to microphone input and audio files. Does the current version support what I'm trying to implement, or do I have to write it myself?
Hi @chanleii,
Your question is very similar to #67. I think the same answer applies.
To summarize: the audio source you need is not implemented, you should implement your own by subclassing AudioSource (you can imitate MicrophoneAudioSource, which is very similar to what you want to achieve).
As I mentioned in #67, I would be glad to merge a PR with this feature if you want to contribute :)
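To give an idea, a very rough sketch of such a source for raw bytes coming from a TCP socket could look like the following. This is only an illustration: the class name and the chunking logic are made up, and it assumes the base class exposes a stream subject the way MicrophoneAudioSource does, and that the sender streams float32 PCM.

```python
import socket

import numpy as np
from diart.sources import AudioSource


class RawSocketAudioSource(AudioSource):
    """Hypothetical source that reads raw bytes from a TCP socket
    and emits them to the pipeline as (1, samples) chunks."""

    def __init__(self, host: str, port: int, sample_rate: int, chunk_bytes: int = 4096):
        super().__init__(f"tcp://{host}:{port}", sample_rate)
        self.host = host
        self.port = port
        self.chunk_bytes = chunk_bytes

    def read(self):
        # Simplified: ignores partial frames and byte alignment issues
        with socket.create_connection((self.host, self.port)) as conn:
            while True:
                data = conn.recv(self.chunk_bytes)
                if not data:
                    break
                # Assumes the sender encodes audio as float32 PCM
                samples = np.frombuffer(data, dtype=np.float32).reshape(1, -1)
                self.stream.on_next(samples)
        self.stream.on_completed()
```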
Hi, I'm also working on a similar implementation. So far I have managed to receive audio data from a websocket and push it into the pipeline, but I can't get the rest of the diarization to work. My code takes a bytes object (recorded with PCM encoding on a website) received from the websocket and turns it into a numpy array that diart accepts. I'm thinking maybe the numpy array is wrong, but I'm not sure how to do it correctly. Would you take a look at my code and give some advice?
The pipeline I built:
import diart.operators as dops
from diart.sinks import RTTMWriter
from diart.sources import AudioSource
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
config = PipelineConfig()
pipeline = OnlineSpeakerDiarization(config)
rttm_writer = RTTMWriter(path="testing.rttm")
source = AudioSource("live_streaming", sample_rate)
observable = pipeline.from_audio_source(source)
observable.pipe(
dops.progress(f"Streaming {source.uri}", total=source.length, leave=True),
).subscribe(rttm_writer)
The data transfer function:
def bytes2nd(data: bytes):
nd = pcm2float(np.frombuffer(data, dtype=np.int8), dtype='float64')
nd = np.reshape(nd, (1, -1))
return nd
The pcm2float function is taken from here
Hi @ckliao-nccu,
I see you're instantiating AudioSource, which is an abstract class. You should always use a concrete audio source. The ones that diart provides (as of today) are: FileAudioSource, PrecalculatedFeaturesAudioSource and MicrophoneAudioSource.
There are 2 key things missing in your code:
- AudioSource should be replaced by a WebsocketAudioSource that you should implement and that inherits from AudioSource.
- You should call source.read() in order for chunks to start being emitted and passed to the pipeline.
In the current state of your code, calling source.read() will throw an error because read() is not implemented in AudioSource (see here).
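Just to illustrate both points, a hypothetical WebsocketAudioSource could look roughly like this (a sketch only, not diart API: the queue-based hand-off and all method names are assumptions):

```python
import queue

import numpy as np
from diart.sources import AudioSource


class WebsocketAudioSource(AudioSource):
    """Hypothetical source: the websocket handler pushes chunks into a queue
    and read() emits them to the pipeline as they arrive."""

    def __init__(self, uri: str, sample_rate: int):
        super().__init__(uri, sample_rate)
        self._chunks = queue.Queue()

    def push(self, chunk: np.ndarray):
        # To be called from the websocket handler with a (1, samples) float array
        self._chunks.put(chunk)

    def close(self):
        # Signals the end of the stream
        self._chunks.put(None)

    def read(self):
        # Blocks until close() is called, forwarding chunks to the pipeline
        while True:
            chunk = self._chunks.get()
            if chunk is None:
                break
            self.stream.on_next(chunk)
        self.stream.on_completed()
```

Since read() blocks, it would typically be run on a background thread so it doesn't block the websocket server loop.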
Hi @juanmc2005 ,
Thank you for your quick response.
I'm not really familiar with rxpy, but as far as I know calling source.read() runs a blocking process, and that process would block the websocket server process. So instead of using source.read(), I'm trying to emit the chunks myself when receiving a websocket message.
Please correct me if I have any misunderstanding.
Here is my full code:
import os, io
import numpy as np
import soundfile as sf
from tornado import ioloop
from tornado.escape import json_decode
from tornado.web import Application, RequestHandler, url
from tornado.websocket import WebSocketHandler
import diart.operators as dops
from diart.sinks import RTTMWriter
from diart.sources import AudioSource
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
config = PipelineConfig()
pipeline = OnlineSpeakerDiarization(config)
rttm_writer = RTTMWriter(path="testing.rttm")
segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
sample_rate = segmentation.model.get_sample_rate()
class WSHandler(WebSocketHandler):
def open(self):
print("WebSocket opened")
self.source = AudioSource("live_streaming", sample_rate)
self.observable = pipeline.from_audio_source(self.source).pipe(
dops.progress(f"Streaming {self.source.uri}", total=self.source.length, leave=True),
).subscribe(rttm_writer)
# Message received from websocket.
def on_message(self, message):
if message == "complete":
self.source.stream.on_completed()
else :
data, samplerate = sf.read(io.BytesIO(message),
format='RAW',
samplerate=sample_rate,
channels=1,
subtype='FLOAT'
)
data = np.asarray(data)
data = np.reshape(data, (1, -1))
# Emit chunks
self.source.stream.on_next(data)
def on_close(self):
print("WebSocket closed")
class MainHandler(RequestHandler):
def get(self):
self.render("index.html")
def main():
port = os.environ.get("PORT", 8888)
app = Application(
[
url(r"/", MainHandler),
(r"/ws", WSHandler),
]
)
print("Starting server at port: %s" % port)
app.listen(int(port))
ioloop.IOLoop.current().start()
if __name__ == "__main__":
main()
@ckliao-nccu ok it looks like your implementation should be working.
- How is the code failing? Is dops.regularize_stream() getting executed? Does the pipeline hang or throw an error? Where does it hang or what's the stacktrace?
- Does on_message() run concurrently? If so, does it run on a different thread than the main thread? (I've had issues with rx in this scenario before)
- Are you sure that stream.on_completed() is not called before stream.on_next()?
- It didn't fail, and dops.regularize_stream() is also working. When I pass the audio signal from the websocket, the pipeline just prints out and updates Streaming live_streaming: 1it [00:05, 5.95s/it] like a normal microphone stream. It produces the output rttm file, but there is nothing in it.
- I'm not sure how to check this. I will try digging into the tornado documentation to see if I can figure it out. Also, this implementation is rewritten from the rxpy example.
- stream.on_completed() gets called when I pass the complete message over the websocket. When I call it, it throws the error and stacktrace below. I believe it's because there is nothing in the rttm file.
Uncaught exception GET /ws (::1)
HTTPServerRequest(protocol='http', host='localhost:8888', method='GET', uri='/ws', version='HTTP/1.1', remote_ip='::1')
Traceback (most recent call last):
File "C:\path\to\conda\env\lib\site-packages\tornado\websocket.py", line 635, in _run_callback
result = callback(*args, **kwargs)
File "ws.py", line 60, in on_message
if message == "complete":
File "C:\path\to\conda\env\lib\site-packages\rx\subject\subject.py", line 89, in on_completed
super().on_completed()
File "C:\path\to\conda\env\lib\site-packages\rx\core\observer\observer.py", line 56, in on_completed
self._on_completed_core()
File "C:\path\to\conda\env\lib\site-packages\rx\subject\subject.py", line 97, in _on_completed_core
observer.on_completed()
File "C:\path\to\conda\env\lib\site-packages\rx\core\observer\autodetachobserver.py", line 44, in on_completed
self._on_completed()
File "C:\path\to\conda\env\lib\site-packages\rx\core\observer\autodetachobserver.py", line 44, in on_completed
self._on_completed()
File "C:\path\to\conda\env\lib\site-packages\rx\core\observer\autodetachobserver.py", line 44, in on_completed
self._on_completed()
[Previous line repeated 11 more times]
File "C:\path\to\conda\env\lib\site-packages\rx\core\operators\do.py", line 70, in _on_completed
observer.on_completed()
File "C:\path\to\conda\env\lib\site-packages\rx\core\observer\autodetachobserver.py", line 44, in on_completed
self._on_completed()
File "C:\path\to\conda\env\lib\site-packages\rx\core\operators\do.py", line 70, in _on_completed
observer.on_completed()
File "C:\path\to\conda\env\lib\site-packages\rx\core\observer\autodetachobserver.py", line 44, in on_completed
self._on_completed()
File "C:\path\to\conda\env\lib\site-packages\diart\sinks.py", line 44, in on_completed
self.patch_rttm()
File "C:\path\to\conda\env\lib\site-packages\diart\sinks.py", line 27, in patch_rttm
annotation = list(load_rttm(self.path).values())[0]
IndexError: list index out of range
WebSocket closed
Ok so the progress bar is showing and it's being updated (if I understand correctly). Have you inspected the chunks you're giving to the pipeline? Try using this:
self.observable = pipeline.from_audio_source(self.source).pipe(
dops.progress(f"Streaming {self.source.uri}", total=self.source.length, leave=True),
ops.starmap(lambda annotation, chunk: ann),
ops.do_action(utils.visualize_annotation(config.duration)),
).subscribe(rttm_writer)
and this:
self.observable = pipeline.from_audio_source(self.source).pipe(
dops.progress(f"Streaming {self.source.uri}", total=self.source.length, leave=True),
ops.starmap(lambda annotation, chunk: chunk),
ops.do_action(utils.visualize_feature(config.duration)),
).subscribe(rttm_writer)
visualize_feature and visualize_annotation will plot each chunk and annotation (respectively) before forwarding them to RTTMWriter. This way you can check why the RTTM is empty.
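For completeness, those snippets assume imports along these lines (ops being rx.operators and utils being diart's utility module; adjust to your installed version):

```python
import rx.operators as ops
import diart.operators as dops
from diart import utils
```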
My guess is that the input chunks do not contain speech, either because there's no speech or because of some error. Could you verify this?
@ckliao-nccu would it help your use case if diart provided a WebsocketAudioSource that handles the websocket server loop? or would that be incompatible with what you're trying to achieve?
visualize_annotation doesn't plot anything.
visualize_feature generates this
Both of them freeze the progress bar.
I think you are right about my input chunks having some problem. I will try another way to record them.
would it help your use case if diart provided a WebsocketAudioSource that handles the websocket server loop?
Of course it will help!! I have already tried writing a WebsocketAudioSource as you mentioned in #67, but couldn't get it to work, so I'm trying to build this implementation.
visualize_annotation doesn't plot anything.
I just realized there was a typo in the code snippet I sent, it should be "annotation" and not "ann".
Of course it will help!! I have already tried writing a WebsocketAudioSource as you mentioned in https://github.com/juanmc2005/StreamingSpeakerDiarization/issues/67, but couldn't get it to work, so I'm trying to build this implementation.
I'm thinking of adding this to the next release, would you mind opening a PR with your implementation? Even if it's incomplete that would save me quite some time, then we can mark it as a draft and I can modify the branch or make my own.
I just realized there was a typo in the code snippet I sent, it should be "annotation" and not "ann".
I noticed that and fixed it. I also found out there was something wrong with my mic the last time I tried. After I fixed it, visualize_feature throws this error:
...
File "/path/to/conda/env/lib/python3.8/site-packages/rx/core/observer/autodetachobserver.py", line 26, in on_next
self._on_next(value)
File "/path/to/conda/env/lib/python3.8/site-packages/diart/sinks.py", line 91, in on_next
value[0].write_rttm(file)
AttributeError: 'numpy.ndarray' object has no attribute 'write_rttm'
visualize_annotation throws this error:
...
File "/path/to/conda/env/lib/python3.8/site-packages/diart/sinks.py", line 91, in on_next
value[0].write_rttm(file)
File "/path/to/conda/env/lib/python3.8/site-packages/pyannote/core/annotation.py", line 790, in __getitem__
return self._tracks[key[0]][key[1]]
TypeError: 'int' object is not subscriptable
I tried to stitch the chunks into a file with the code below. The output test.wav file can be played normally, it's just missing some headers.
class WSHandler(WebSocketHandler):
def open(self):
self.f = open('test.wav', 'ab')
def on_message(self, message):
self.f.write(message)
def on_close(self):
self.f.close()
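To double-check the audio itself, I could probably also write the received bytes with proper WAV headers using soundfile, something like this (just a sketch, assuming the client really sends float32 mono at the pipeline sample rate):

```python
import numpy as np
import soundfile as sf

def save_debug_wav(raw: bytes, sample_rate: int, path: str = "debug.wav"):
    # Assumes the websocket client sends raw float32 mono samples
    samples = np.frombuffer(raw, dtype=np.float32)
    sf.write(path, samples, sample_rate, subtype="FLOAT")
```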
My new question is: are there any requirements that data chunks need to meet for diart to process them?
The numpy array I'm passing to on_next looks like this:
[[ 1.37131149e-20 -1.64474758e-31 -8.04027892e+11 ... 2.40064952e-32 6.64617913e-28 1.06596241e-36]]
And the bytes data received from the websocket looks like this:
b'@\xb4b\x0c\x8e ... \xb5\x036\xa3'
Yes the errors are normal. It's because we've removed the chunk or annotation with the starmap operator and RTTMWriter is expecting both.
My new question is: are there any requirements that data chunks need to meet for diart to process them?
Yes, you should make sure that the data is a numpy array with shape (1, samples).
If you receive bytes from the websocket you need to turn them into a numpy array and then reshape it:
chunk = np.frombuffer(message, dtype="float").reshape(1, -1)
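If the client sends integer PCM instead (for example 16-bit), the bytes need to be decoded with the matching dtype and normalized first. A small sketch, assuming int16 input:

```python
import numpy as np

def int16_bytes_to_chunk(message: bytes) -> np.ndarray:
    # Decode 16-bit PCM and scale to [-1, 1] floats with shape (1, samples)
    samples = np.frombuffer(message, dtype=np.int16).astype(np.float32) / 32768.0
    return samples.reshape(1, -1)
```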
@ckliao-nccu I just added an experimental WebSocketAudioSource in the feat/ws branch. You can take a look at it here.
Would you mind experimenting with it and telling me if you find any problems? I ran some tests locally with a very simple client and it seems to work as expected. Keep in mind that the server expects float32 data encoded in base64 as a UTF-8 string, so make sure to encode each chunk like this:
message = base64.b64encode(chunk.astype(np.float32).tobytes()).decode("utf-8")
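For example, a minimal client-side sketch using the websockets package (not part of diart; the URI is just a placeholder) could look like this:

```python
import asyncio
import base64

import numpy as np
import websockets  # third-party client library, assumed for this example


async def send_chunk(chunk: np.ndarray, uri: str = "ws://localhost:7007"):
    # Encode the float32 samples as base64 in a UTF-8 string, as the server expects
    message = base64.b64encode(chunk.astype(np.float32).tobytes()).decode("utf-8")
    async with websockets.connect(uri) as ws:
        await ws.send(message)


if __name__ == "__main__":
    fake_chunk = np.zeros((1, 16000), dtype=np.float32)  # 1 second of silence at 16kHz
    asyncio.run(send_chunk(fake_chunk))
```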
@juanmc2005 yes, it works like a charm!
But I had to edit uri = f"ws://{name}:{port}" in WebSocketAudioSource because the // causes a path error when using RTTMWriter.
Happy to hear it works! Let me know if you run into any troubles with it. I'll make sure to include it in the next release.
But I had to edit uri = f"ws://{name}:{port}" in WebSocketAudioSource because the // causes a path error when using RTTMWriter.
Would you mind changing the uri and opening a pull request?
Closing this because websockets are now implemented in #77 and recently merged to develop. TCP/UDP compatibility (already mentioned in #67) has been pushed to the next release (v0.6).