LiveTalking icon indicating copy to clipboard operation
LiveTalking copied to clipboard

Piper TTS streaming integration

Open karayakar opened this issue 4 months ago • 2 comments

Hello everyone, is there any official Piper TTS integration? I have integrated it on my own, but because of the thread loop it is sometimes delayed.

Demo

class PiperTTS(BaseTTS):
    """
    WebSocket TTS istemcisi:
      - txt_to_audio(msg) çağrıldığında WS'e bağlanır (veya her istek için ayrı bağlanır),
      - metni gönderir,
      - gelen binary'leri generator ile stream_tts'e aktarır.
    """
    def __init__(self, opt, parent: BaseReal):
        """Read the WebSocket settings from ``opt`` and start the background loop.

        Every ``WS_*`` attribute is optional on ``opt``; the fallback used when
        it is missing is the second argument of each lookup below.
        """
        super().__init__(opt, parent)

        def setting(name, fallback):
            # Optional attribute lookup on the options object.
            return getattr(opt, name, fallback)

        self.ws_uri = setting("WS_URI", "wss://localhost:7022/ws")
        self.ws_insecure = setting("WS_INSECURE", False)    # for self-signed test certificates
        self.ws_send_json = setting("WS_SEND_JSON", False)  # send the text as JSON when True
        self.ws_text_key = setting("WS_TEXT_KEY", "text")   # field name used in JSON mode
        # Sample rate of the server's audio; falls back to self.sample_rate
        # (presumably set by the base-class initializer -- TODO confirm).
        self.ws_sample_rate = setting("WS_SAMPLE_RATE", self.sample_rate)
        self.ws_headers = setting("WS_HEADERS", None)       # dict or None

        # The asyncio loop runs in its own thread; both attributes are
        # assigned by _start_loop_thread().
        self._loop = None
        self._loop_thread = None
        self._start_loop_thread()

    # -------- Loop yönetimi --------

    def _start_loop_thread(self):
        """Start a daemon thread running a private asyncio event loop.

        The loop is created inside the thread, so this method blocks (for at
        most five seconds) until ``self._loop`` has been assigned. Without
        that wait a caller could reach ``_submit_coro`` before the thread had
        run and get "Async loop not started".

        The loop runs until ``self._stop_event`` is set; it then cancels and
        drains any remaining tasks and closes itself.
        """
        import contextlib
        import threading

        ready = threading.Event()

        def _runner():
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                stop_event = asyncio.Event()
                self._stop_event = stop_event
                self._loop = loop
            finally:
                # Release the starter even if loop creation failed; in that
                # case self._loop stays None and _ensure_loop() reports it.
                ready.set()
            try:
                # Use the local objects so a later reassignment of the
                # instance attributes cannot redirect this thread.
                loop.run_until_complete(stop_event.wait())
            finally:
                pending = asyncio.all_tasks(loop=loop)
                for t in pending:
                    t.cancel()
                with contextlib.suppress(Exception):
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                loop.close()

        self._loop_thread = threading.Thread(target=_runner, name="PiperTTS-Loop", daemon=True)
        self._loop_thread.start()
        ready.wait(timeout=5.0)

    def _ensure_loop(self):
        """Raise RuntimeError when the background event loop does not exist yet."""
        if self._loop is not None:
            return
        raise RuntimeError("Async loop not started")

    def _submit_coro(self, coro):
        """Schedule ``coro`` on the background loop from the calling thread.

        Returns the future produced by ``asyncio.run_coroutine_threadsafe``.
        Raises RuntimeError (via ``_ensure_loop``) when no loop exists.
        """
        self._ensure_loop()
        target_loop = self._loop
        future = asyncio.run_coroutine_threadsafe(coro, target_loop)
        return future

    def close(self):
        """Ask the background loop to stop and wait up to two seconds for its thread.

        Does nothing when no loop exists or when a previous call already
        marked the loop as closed. Nothing in the visible code calls this
        automatically; wire it into the owner's shutdown path if needed.
        """
        if not self._loop:
            return
        if getattr(self, "_loop_closed", False):
            return
        stop_request = self._signal_stop()
        self._submit_coro(stop_request)
        self._loop_thread.join(timeout=2)
        self._loop_closed = True

    async def _signal_stop(self):
        """Set the stop event; meant to be run on the loop thread itself."""
        stop_event = self._stop_event
        stop_event.set()

    # -------- Public API (BaseTTS ) --------
    @staticmethod
    def _resample_factor_linear(x: np.ndarray, factor: float) -> np.ndarray:
        """Resample ``x`` by ``factor`` using pure-NumPy linear interpolation.

        ``factor > 1`` shortens the signal (higher pitch); ``factor < 1``
        lengthens it (lower pitch). The output has
        ``max(1, round(len(x) / factor))`` samples and dtype float32.

        Declared as a static method because the signature has no ``self``;
        without the decorator an instance call would bind the instance to
        ``x``.

        Raises:
            ValueError: if ``factor`` is not strictly positive.
        """
        if factor <= 0:
            raise ValueError(f"factor must be > 0, got {factor!r}")
        if factor == 1.0 or x.size == 0:
            return x.astype(np.float32, copy=False)

        # Length after resampling.
        new_len = max(1, int(round(x.shape[0] / factor)))
        # Source and query grids, both normalised to [0, 1).
        xp = np.linspace(0.0, 1.0, num=x.shape[0], endpoint=False, dtype=np.float64)
        xq = np.linspace(0.0, 1.0, num=new_len, endpoint=False, dtype=np.float64)
        y = np.interp(xq, xp, x.astype(np.float64, copy=False))
        return y.astype(np.float32, copy=False)

    def txt_to_audio(self, msg):
        """Synthesize the text carried by ``msg`` and stream the audio.

        ``msg`` must unpack into exactly two items (text, text event). The
        text goes to ``ws_speech`` and the resulting chunk generator is passed
        to ``stream_tts`` together with the original ``msg``. Per the original
        author's note this is called from ``BaseTTS.process_tts``.
        """
        text, _textevent = msg
        self.stream_tts(self.ws_speech(text), msg)

    # -------- Streaming generator (sync) --------

    def ws_speech(self, text):
        """Yield the binary audio chunks produced for ``text``.

        This is a synchronous generator. It schedules ``_ws_stream`` on the
        background loop and relays whatever that coroutine places on a bounded
        thread-safe queue:

        * ``None`` ends the stream,
        * an ``Exception`` instance is re-raised in the caller's thread,
        * anything else is yielded as an audio chunk.

        The stream also ends silently after 30 seconds without a queue item.
        On exit the coroutine is cancelled if it is still running.
        """
        out_q: thq.Queue = thq.Queue(maxsize=64)  # queue of binary chunks
        fut = self._submit_coro(self._ws_stream(text, out_q))

        def next_item():
            # Give up after 30 s of silence; adjust here if needed.
            return out_q.get(timeout=30.0)

        try:
            while True:
                try:
                    item = next_item()
                except thq.Empty:
                    return  # nothing arrived for too long
                if item is None:
                    return  # end-of-stream marker
                if isinstance(item, Exception):
                    raise item  # failure reported by the producer
                yield item
        finally:
            # Safe shutdown: cancel the producer if it is still running.
            if not fut.done():
                fut.cancel()

   

    async def _ws_stream(self, text: str, out_q: thq.Queue):
        """Run one TTS request: connect, send ``text``, relay binary frames.

        A fresh WebSocket connection to ``self.ws_uri`` is opened for every
        request. Binary frames are placed on ``out_q`` as ``bytes``. A JSON
        text frame whose ``type`` is ``end``/``done``/``close``/``ok`` ends
        the stream; other text frames are ignored. A failure is placed on the
        queue as the exception object, and the stream always finishes with a
        ``None`` sentinel.

        Fixes relative to the previous revision:

        * The endpoint and TLS settings now come from ``WS_URI`` and
          ``WS_INSECURE``. Previously a placeholder URI and an unconditional
          "insecure" context overwrote them, so certificate checks were
          always disabled.
        * Queue writes never block the event loop. A blocking ``out_q.put``
          inside a coroutine stalls the loop and cannot be cancelled.
        * Errors while building the TLS context reach the consumer instead of
          being lost.

        NOTE(review): ``WS_SEND_JSON`` and ``WS_TEXT_KEY`` still have no
        effect. Both branches of the old code sent the same ``responsebin``
        payload, which is kept unchanged here. ``WS_HEADERS`` is not forwarded
        either, because the keyword for custom headers depends on the
        installed ``websockets`` version.
        """
        async def _enqueue(item):
            # Cooperative back-pressure: retry until the consumer frees a slot,
            # yielding to the loop in between so cancellation still works.
            while True:
                try:
                    out_q.put_nowait(item)
                    return
                except thq.Full:
                    await asyncio.sleep(0.005)

        failure = None
        cancelled = False
        try:
            # TLS is only needed for wss:// endpoints.
            ssl_ctx = None
            if self.ws_uri.startswith("wss://"):
                ssl_ctx = ssl.create_default_context()
                if self.ws_insecure:
                    # Explicit opt-in for self-signed test servers.
                    ssl_ctx.check_hostname = False
                    ssl_ctx.verify_mode = ssl.CERT_NONE

            request = json.dumps(
                {"type": "responsebin", "payload": {"msg": f"{text}", "pitch": 1.5}}
            )

            async with websockets.connect(self.ws_uri, ssl=ssl_ctx, max_size=None) as ws:
                await ws.send(request)

                async for msg in ws:
                    if isinstance(msg, (bytes, bytearray)):
                        await _enqueue(bytes(msg))
                        continue
                    # Text frame: only a JSON control message can end the stream.
                    try:
                        data = json.loads(msg)
                    except ValueError:
                        continue  # plain text, ignore
                    if isinstance(data, dict) and data.get("type") in ("end", "done", "close", "ok"):
                        break
        except asyncio.CancelledError:
            # The consumer gave up (or the loop is shutting down).
            cancelled = True
        except Exception as e:
            failure = e

        try:
            if cancelled:
                # Nobody is guaranteed to be reading any more: best effort only.
                try:
                    out_q.put_nowait(None)
                except thq.Full:
                    pass
            else:
                if failure is not None:
                    await _enqueue(failure)
                await _enqueue(None)
        except asyncio.CancelledError:
            # Cancelled while waiting for queue space; the consumer is gone.
            pass


karayakar avatar Aug 25 '25 08:08 karayakar

异步事件循环在独立线程运行 你通过 _start_loop_thread 在新线程中跑 asyncio loop,并用 asyncio.run_coroutine_threadsafe 提交协程。这种方式理论上可以并发,但会有跨线程上下文切换(线程切换、队列通信),如果消息量大或每次都新建连接,容易出现延迟。

WebSocket 每次请求新建连接 在 _ws_stream 里,每次 TTS 请求都新建 WebSocket 连接(async with websockets.connect(...) as ws:)。频繁连接与断开会导致显著的延迟,特别是在高并发或网络波动时。官方/高效实现通常会复用连接或引入连接池。

队列通信阻塞 out_q.put_nowait() 和 out_q.put() 用于跨线程传输音频 chunk。如果消费端处理较慢,队列写入会阻塞,导致延迟——比如队列满时,put 会阻塞主线程。

heyyyyou avatar Aug 27 '25 02:08 heyyyyou

The asynchronous event loop runs in a separate thread. You run the asyncio loop in a new thread using _start_loop_thread and submit the coroutine using asyncio.run_coroutine_threadsafe. While this approach theoretically allows for concurrency, it involves cross-thread context switching (thread switching and queue communication). This can lead to latency if the message volume is large or a new connection is created each time.

Each WebSocket request creates a new connection. In _ws_stream , each TTS request creates a new WebSocket connection (async with websockets.connect(...) as ws:). Frequent connections and disconnections can cause significant latency, especially under high concurrency or network fluctuations. Official/efficient implementations typically reuse connections or introduce connection pooling.

Queue communication can block. out_q.put_nowait() and out_q.put() are used to transfer audio chunks across threads. If the consumer is slow, queue writes will block, causing delays. For example, if the queue is full, put() will block the main thread.

Thank you, improving the code

import asyncio
import json
import ssl
import contextlib
from threading import Thread
import queue as thq
import socket
import time
import logging

import numpy as np
import resampy
import websockets


# Fix Unicode logging issues for Turkish characters
def setup_unicode_logging():
    """Configure console streams and logging so non-ASCII text can be emitted.

    Steps, all best effort (any exception is reported with ``print`` and
    otherwise ignored):

    1. Switch ``sys.stdout`` and ``sys.stderr`` to UTF-8 where the stream
       offers ``reconfigure``.
    2. Call ``logging.basicConfig`` with an INFO level and a stdout handler.
    3. Install a formatter on every root-logger handler that retries with an
       ASCII-only message if formatting raises ``UnicodeEncodeError``.
    """
    import sys

    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    class SafeFormatter(logging.Formatter):
        """Formatter that strips non-ASCII characters as a last resort."""

        def format(self, record):
            try:
                return super().format(record)
            except UnicodeEncodeError:
                ascii_only = str(record.msg).encode('ascii', 'ignore').decode('ascii')
                record.msg = ascii_only
                return super().format(record)

    try:
        # Force UTF-8 on the console streams.
        for stream in (sys.stdout, sys.stderr):
            if hasattr(stream, 'reconfigure'):
                stream.reconfigure(encoding='utf-8')

        logging.basicConfig(
            level=logging.INFO,
            format=log_format,
            handlers=[logging.StreamHandler(sys.stdout)],
        )

        # Every handler on the root logger gets its own safe formatter.
        root_logger = logging.getLogger()
        for handler in root_logger.handlers:
            handler.setFormatter(SafeFormatter(log_format))
    except Exception as e:
        print(f"Logging setup warning: {e}")

# Runs once at import time: importing this module reconfigures stdout/stderr
# and the root logger as a side effect.
setup_unicode_logging()


class PiperTTS(BaseTTS):
    """
    WebSocket TTS istemcisi - Your EXACT original working code
    """

    def __init__(self, opt, parent: BaseReal):
        """Read the WebSocket settings from ``opt``, warm up the connection
        data and start the background asyncio loop.

        Every ``WS_*`` attribute is optional on ``opt``; the fallback used when
        it is missing is the second argument of each lookup below.
        """
        super().__init__(opt, parent)

        def setting(name, fallback):
            # Optional attribute lookup on the options object.
            return getattr(opt, name, fallback)

        # ---- configuration (taken from opt) ----
        self.ws_uri = setting("WS_URI", "wss://")
        self.ws_insecure = setting("WS_INSECURE", False)
        self.ws_send_json = setting("WS_SEND_JSON", False)
        self.ws_text_key = setting("WS_TEXT_KEY", "text")
        # Falls back to self.sample_rate (presumably set by the base-class
        # initializer -- TODO confirm).
        self.ws_sample_rate = setting("WS_SAMPLE_RATE", self.sample_rate)
        self.ws_headers = setting("WS_HEADERS", None)

        # ---- connection warm-up state ----
        self._uri = "wss://..............................................."
        self._ssl_ctx = self._sni = self._host_ip = None

        # DNS and the SSL context are prepared on a background thread.
        self._prepare_connection_async()

        # ---- run the asyncio loop in a separate thread ----
        self._loop = self._loop_thread = self._stop_event = None
        self._loop_closed = False
        self._start_loop_thread()

    def _prepare_connection_async(self):
        """Resolve DNS and build the SSL context on a daemon thread.

        On success the results are stored in ``self._host_ip``,
        ``self._ssl_ctx`` and ``self._sni``. On failure ``self._host_ip`` is
        reset to ``None`` and the problem is printed. The call returns
        immediately, so the attributes may still be unset when it returns.

        Fix: this method used to end with a pasted copy of the event-loop
        bootstrap (reset of ``_loop``/``_loop_thread``/``_stop_event`` plus
        ``_start_loop_thread()``). ``__init__`` performs the same bootstrap
        right after calling this method, so two loop threads were started and
        the first one was orphaned. Loop start-up is now left to ``__init__``.

        NOTE(review): the host name is hard-coded, and no method in the
        visible code reads the warmed-up values yet.
        """
        def _prepare():
            try:
                # DNS resolution
                host = "yapzek.ai"
                self._host_ip = socket.gethostbyname(host)
                print(f"Pre-resolved DNS: {host} -> {self._host_ip}")

                # SSL context
                self._ssl_ctx, self._sni = make_ssl_context(mode="insecure")
                print("SSL context prepared")

            except Exception as e:
                print(f"Connection prep failed: {e}")
                # Fall back to resolution at connect time.
                self._host_ip = None

        # Run in the background so initialization is not delayed.
        import threading
        prep_thread = threading.Thread(target=_prepare, daemon=True)
        prep_thread.start()

    # ----------------- Loop yönetimi -----------------

    def _start_loop_thread(self):
        """Start a daemon thread running a private asyncio event loop.

        The loop is created inside the thread, so this method blocks (for at
        most five seconds) until ``self._loop`` has been assigned. Without
        that wait a caller could reach ``_submit_coro`` before the thread had
        run and get "Async loop not started".

        The loop runs until ``self._stop_event`` is set; it then cancels and
        drains any remaining tasks and closes itself.
        """
        import threading

        ready = threading.Event()

        def _runner():
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                stop_event = asyncio.Event()
                self._stop_event = stop_event
                self._loop = loop
            finally:
                # Release the starter even if loop creation failed; in that
                # case self._loop stays None and _ensure_loop() reports it.
                ready.set()
            try:
                # Use the local objects so a later reassignment of the
                # instance attributes cannot redirect this thread.
                loop.run_until_complete(stop_event.wait())
            finally:
                pending = asyncio.all_tasks(loop=loop)
                for t in pending:
                    t.cancel()
                with contextlib.suppress(Exception):
                    loop.run_until_complete(
                        asyncio.gather(*pending, return_exceptions=True)
                    )
                loop.close()

        self._loop_thread = threading.Thread(target=_runner, name="PiperTTS-Loop", daemon=True)
        self._loop_thread.start()
        ready.wait(timeout=5.0)

    def _ensure_loop(self):
        """Raise RuntimeError when the background event loop does not exist yet."""
        if self._loop is not None:
            return
        raise RuntimeError("Async loop not started")

    def _submit_coro(self, coro):
        """Schedule ``coro`` on the background loop from the calling thread.

        Returns the future produced by ``asyncio.run_coroutine_threadsafe``.
        Raises RuntimeError (via ``_ensure_loop``) when no loop exists.
        """
        self._ensure_loop()
        target_loop = self._loop
        future = asyncio.run_coroutine_threadsafe(coro, target_loop)
        return future

    def close(self):
        """Ask the background loop to stop and wait up to two seconds for its thread.

        Does nothing when no loop exists or when the loop was already closed
        by a previous call.
        """
        if not self._loop or self._loop_closed:
            return
        stop_request = self._signal_stop()
        self._submit_coro(stop_request)
        self._loop_thread.join(timeout=2)
        self._loop_closed = True

    async def _signal_stop(self):
        """Set the stop event unless it is missing or already set."""
        stop_event = self._stop_event
        if not stop_event or stop_event.is_set():
            return
        stop_event.set()

    # ----------------- BaseTTS entegrasyonu -----------------

    def txt_to_audio(self, msg):
        """Synthesize the text carried by ``msg`` and stream the audio.

        ``msg`` must unpack into exactly two items (text, text event). The
        text goes to ``ws_speech`` and the resulting chunk generator is passed
        to ``stream_tts`` together with ``msg``. Per the original author's
        note this is called from ``BaseTTS.process_tts``.

        Timing information is printed for diagnostics. ``ws_speech`` is a
        generator function, so the "Stream setup" figure covers only the
        creation of the generator object; the network work happens while
        ``stream_tts`` iterates it.

        Fix: the log preview used a bare ``except:``, which also swallows
        ``KeyboardInterrupt``/``SystemExit``; it is narrowed to ``Exception``.
        """
        start_time = time.time()
        text, textevent = msg

        # Build a short, ASCII-only preview for the log line.
        try:
            if isinstance(text, str):
                safe_text = text.encode('ascii', 'ignore').decode('ascii')[:30]
            else:
                # dicts and any other payload type are previewed via str()
                safe_text = str(text)[:30]
        except Exception:
            safe_text = "unknown_format"

        print(f"\n=== TTS REQUEST START === {safe_text}...")

        audio_stream = self.ws_speech(text)  # sync generator

        setup_time = time.time() - start_time
        print(f"Stream setup: {setup_time*1000:.0f}ms")

        self.stream_tts(audio_stream, msg)

        total_time = time.time() - start_time
        print(f"=== TTS REQUEST COMPLETE === {total_time*1000:.0f}ms total\n")

    # ----------------- Sync generator -----------------

    def ws_speech(self, text):
        """Yield the binary audio chunks produced for ``text``.

        This is a synchronous generator. It schedules ``_ws_stream_isolated``
        on the background loop and relays whatever that coroutine places on a
        bounded thread-safe queue:

        * ``None`` ends the stream,
        * an ``Exception`` instance is re-raised in the caller's thread,
        * anything else is yielded as an audio chunk.

        The queue is polled every 100 ms; after 100 consecutive empty polls
        (about 10 s of silence) the stream ends. On exit the coroutine is
        cancelled if it is still running. Progress is printed for diagnostics.

        Fix: the cleanup used a bare ``except:``; it is narrowed to
        ``Exception``, which covers the cancellation and timeout errors that
        ``Future.result`` raises.
        """
        method_start = time.time()
        # Log-correlation id derived from the clock; NOT guaranteed unique.
        request_id = int(time.time() * 1000) % 100000
        out_q: thq.Queue = thq.Queue(maxsize=32)

        print(f"  → ws_speech called [REQ-{request_id}]")

        submit_start = time.time()
        fut = self._submit_coro(self._ws_stream_isolated(text, out_q, request_id))
        submit_time = time.time() - submit_start
        print(f"  → Coroutine submitted [REQ-{request_id}]: {submit_time*1000:.1f}ms")

        try:
            first_item_received = False
            item_count = 0
            timeout_count = 0
            max_timeouts = 100  # 100 * 100ms = 10 seconds max wait

            while True:
                try:
                    get_start = time.time()
                    item = out_q.get(timeout=0.1)  # 100ms timeout
                    get_time = time.time() - get_start
                    timeout_count = 0  # any item resets the silence counter

                    if not first_item_received and item is not None:
                        first_elapsed = time.time() - method_start
                        print(f"  → First item received [REQ-{request_id}]: {first_elapsed*1000:.1f}ms (queue wait: {get_time*1000:.1f}ms)")
                        first_item_received = True

                except thq.Empty:
                    timeout_count += 1
                    if timeout_count >= max_timeouts:
                        print(f"  → Request timeout [REQ-{request_id}] after {timeout_count * 0.1:.1f}s")
                        break
                    continue

                if item is None:
                    print(f"  → End signal [REQ-{request_id}], total items: {item_count}")
                    break

                if isinstance(item, Exception):
                    print(f"  → Exception [REQ-{request_id}]: {item}")
                    raise item

                item_count += 1
                yield item

        finally:
            # Cancel the coroutine so it stops using the loop.
            if not fut.done():
                print(f"  → Cancelling coroutine [REQ-{request_id}]")
                fut.cancel()
                try:
                    # NOTE(review): once cancel() has succeeded, result() is
                    # expected to raise CancelledError right away instead of
                    # waiting, so this is only a best-effort settle.
                    fut.result(timeout=0.1)
                except Exception:
                    pass

            total_time = time.time() - method_start
            print(f"  → ws_speech finished [REQ-{request_id}]: {total_time*1000:.1f}ms")

    # ----------------- YOUR EXACT ORIGINAL WORKING CODE -----------------

    async def _ws_stream_isolated(self, text: str, out_q: thq.Queue, request_id: int):
        """Run one TTS request on its own WebSocket connection.

        ``text`` is JSON-encoded and sent as the single request message.
        Binary frames are placed on ``out_q`` as ``bytes``. A JSON text frame
        whose ``type`` is ``end``/``done``/``close``/``ok`` ends the stream; a
        frame of type ``error`` is printed and reading continues. A failure is
        placed on the queue as the exception object, followed by a ``None``
        sentinel. Progress is printed with ``request_id`` for correlation.

        Fixes relative to the previous revision:

        * A full queue no longer aborts the stream (which truncated speech
          whenever the server outpaced playback). The coroutine now waits
          cooperatively for a free slot.
        * The end sentinel and error objects are no longer dropped when the
          queue is full, so the consumer does not sit in its idle timeout.
        * The bare ``except:`` around control-message parsing is narrowed.
        * The two identical ``ws_send_json`` branches are merged.

        NOTE(review): the endpoint (including user/room ids) is hard-coded,
        certificate verification is requested as "insecure" for every
        request, and ``WS_URI``/``WS_INSECURE``/``WS_SEND_JSON`` are not
        consulted. Those belong in configuration.
        """
        async def _enqueue(item):
            # Cooperative back-pressure: retry until the consumer frees a slot,
            # yielding to the loop in between so cancellation still works.
            while True:
                try:
                    out_q.put_nowait(item)
                    return
                except thq.Full:
                    await asyncio.sleep(0.005)

        stream_start = time.time()
        print(f"    → _ws_stream started [REQ-{request_id}]")

        # Each request builds its own connection parameters.
        uri = "wss://yapzek.ai/ws?role=client&user=d9f5f61a-37cc-4f8a-b958-6db1233d5f2a&room=d9f5f61a-37cc-4f8a-b958-6db1233d5f2a"

        # A fresh SSL context per request avoids shared state.
        try:
            ssl_ctx, sni = make_ssl_context(mode="insecure")
        except Exception as e:
            print(f"    → SSL context creation failed [REQ-{request_id}]: {e}")
            try:
                out_q.put_nowait(e)
            except thq.Full:
                pass
            return

        failure = None
        cancelled = False
        try:
            connect_start = time.time()
            print(f"    → Starting WebSocket connection [REQ-{request_id}]...")

            async with websockets.connect(
                uri,
                ssl=ssl_ctx,
                server_hostname=sni,
                max_size=None,
                close_timeout=5,  # faster cleanup
                ping_interval=None  # no keep-alive pings for short requests
            ) as ws:
                connect_time = time.time() - connect_start
                print(f"    → Connected [REQ-{request_id}]: {connect_time*1000:.0f}ms")

                send_start = time.time()
                await ws.send(json.dumps(text))
                send_time = time.time() - send_start
                print(f"    → Message sent [REQ-{request_id}]: {send_time*1000:.1f}ms")

                receive_start = time.time()
                response_count = 0
                first_response_time = None

                async for msg in ws:
                    if first_response_time is None:
                        first_response_time = time.time() - receive_start
                        print(f"    → First response [REQ-{request_id}]: {first_response_time*1000:.1f}ms")

                    if isinstance(msg, (bytes, bytearray)):
                        response_count += 1
                        await _enqueue(bytes(msg))
                        continue

                    # Text frame: look for a JSON control message.
                    try:
                        data = json.loads(msg)
                    except ValueError:
                        continue  # not JSON, ignore
                    if not isinstance(data, dict):
                        continue
                    msg_type = data.get("type")
                    if msg_type in ("end", "done", "close", "ok"):
                        print(f"    → End signal [REQ-{request_id}], {response_count} chunks")
                        break
                    if msg_type == "error":
                        error_text = data.get("text", "unknown_error")
                        print(f"    → Server error [REQ-{request_id}]: {error_text}")
                        # Keep reading: a server error does not end the stream.

                print(f"    → Response processing complete [REQ-{request_id}]")

        except asyncio.CancelledError:
            # Deliberately not re-raised: the coroutine finishes right below.
            cancelled = True
            print(f"    → Request cancelled [REQ-{request_id}]")
        except Exception as e:
            failure = e
            error_time = time.time() - stream_start
            print(f"    → Error [REQ-{request_id}] after {error_time*1000:.0f}ms: {e}")

        try:
            if cancelled:
                # Nobody is guaranteed to be reading any more: best effort only.
                try:
                    out_q.put_nowait(None)
                except thq.Full:
                    pass
            else:
                if failure is not None:
                    await _enqueue(failure)
                await _enqueue(None)
        except asyncio.CancelledError:
            # Cancelled while waiting for queue space; the consumer is gone.
            pass
        finally:
            total_stream_time = time.time() - stream_start
            print(f"    → _ws_stream finished [REQ-{request_id}]: {total_stream_time*1000:.0f}ms")

    # ----------------- FishTTS ile uyumlu streamer -----------------

    def stream_tts(self, audio_stream, msg):
        """Cut a stream of PCM16 byte chunks into fixed-size frames for the parent.

        Args:
            audio_stream: iterable of byte chunks. They are assumed to be raw
                little-endian PCM16 at ``self.ws_sample_rate``; a WAV stream
                would need its header stripped first (protocol dependent).
            msg: the (text, text event) pair. Both values are attached to the
                ``start`` event on the first frame and to the ``end`` event on
                the closing silent frame.

        Each chunk is converted to float32 in roughly [-1, 1], resampled to
        ``self.sample_rate`` when the rates differ, and handed to
        ``self.parent.put_audio_frame`` in frames of ``self.chunk`` samples.

        Fixes relative to the previous revision:

        * Samples left over after the last full frame of a chunk are carried
          into the next chunk instead of being discarded, which dropped audio
          at every chunk boundary. The final partial frame is zero-padded and
          emitted.
        * A chunk with an odd byte count no longer makes ``np.frombuffer``
          raise; the dangling byte is prepended to the next chunk.
        """
        text, textevent = msg
        frame_len = self.chunk
        needs_resample = self.ws_sample_rate != self.sample_rate
        first_evt_sent = False

        byte_tail = b""                           # dangling half-sample, if any
        pending = np.zeros(0, dtype=np.float32)   # samples short of a full frame

        def _emit(frame):
            # The very first frame carries the 'start' event.
            nonlocal first_evt_sent
            eventpoint = None
            if not first_evt_sent:
                eventpoint = {'status': 'start', 'text': text, 'msgenvent': textevent}
                first_evt_sent = True
            self.parent.put_audio_frame(frame, eventpoint)

        for chunk in audio_stream:
            if not chunk:
                continue

            raw = byte_tail + bytes(chunk) if byte_tail else chunk
            usable = len(raw) - (len(raw) % 2)    # whole 16-bit samples only
            byte_tail = bytes(raw[usable:])
            if usable == 0:
                continue

            buf = np.frombuffer(raw[:usable], dtype=np.dtype("<i2")).astype(np.float32) / 32767.0

            # Resample when the server rate differs from ours.
            if needs_resample:
                buf = resampy.resample(x=buf, sr_orig=self.ws_sample_rate, sr_new=self.sample_rate)

            buf = np.concatenate((pending, buf))
            idx = 0
            while buf.shape[0] - idx >= frame_len:
                _emit(buf[idx:idx + frame_len])
                idx += frame_len
            pending = buf[idx:]

        # Flush what is left, padded with silence to a full frame.
        if pending.size:
            last = np.zeros(frame_len, dtype=pending.dtype)
            last[:pending.size] = pending
            _emit(last)

        # Stream finished: 'end' event on one silent frame.
        eventpoint = {'status': 'end', 'text': text, 'msgenvent': textevent}
        self.parent.put_audio_frame(np.zeros(frame_len, np.float32), eventpoint)

karayakar avatar Aug 28 '25 00:08 karayakar