Cache TTS to Reduce Costs

Open matheusbento opened this issue 4 months ago • 5 comments

Problem Statement

I’m always frustrated when the same TTS output is regenerated multiple times for the same input text, resulting in unnecessary API calls and increased cost — especially when using providers like ElevenLabs/Cartesia that charge per character. This happens most notably in introductory phrases or common prompts that are repeated across calls or agents. There’s currently no built-in way to intercept and reuse already-synthesized TTS chunks across sessions or calls.

Proposed Solution

Implement a TTS caching layer that stores synthesized audio blobs (e.g., mp3, pcm, etc.) indexed by a deterministic hash of the prompt, voice ID, and synthesis parameters (e.g., speed, style, stability). Before synthesizing new audio, check the cache (e.g., Redis, file-based, or cloud store like S3). If a match exists, skip the TTS provider request and return the cached audio.
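
As a concrete illustration of the key scheme described above, here is a minimal sketch; the cache_key name and exact parameter set are hypothetical, not an existing pipecat API:

import hashlib
import json
from typing import Any, Dict


def cache_key(text: str, voice_id: str, params: Dict[str, Any]) -> str:
    """Deterministic cache key over everything that affects the audio output."""
    # sort_keys gives a stable serialization, so identical inputs always
    # hash to the same key regardless of dict insertion order.
    payload = json.dumps(
        {"text": text, "voice_id": voice_id, "params": params},
        sort_keys=True,
    )
    return "tts:" + hashlib.sha256(payload.encode("utf-8")).hexdigest()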

Alternative Solutions

• I’ve experimented with manually storing TTS results in Redis or the filesystem using my own wrapper around the TTS service (roughly along the lines of the sketch below), but this creates fragmentation and duplication of effort.
• An external proxy/caching service could be built in front of the TTS provider, but ideally this should be part of the core TTSService pipeline.
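
For illustration, a minimal sketch of that kind of standalone wrapper, reusing the cache_key helper from the sketch above; SimpleTTSCache and the synthesize callable are hypothetical, not part of pipecat:

from typing import Awaitable, Callable, Optional

import redis.asyncio as redis


class SimpleTTSCache:
    """Hypothetical standalone cache wrapped around any async TTS call."""

    def __init__(self, url: str = "redis://localhost:6379", ttl_s: int = 86400):
        self._redis = redis.from_url(url)
        self._ttl_s = ttl_s

    async def get_or_synthesize(
        self,
        text: str,
        voice_id: str,
        params: dict,
        synthesize: Callable[[str], Awaitable[bytes]],
    ) -> bytes:
        key = cache_key(text, voice_id, params)  # helper from the sketch above
        cached: Optional[bytes] = await self._redis.get(key)
        if cached is not None:
            return cached  # Cache hit: skip the provider request entirely.
        audio = await synthesize(text)
        await self._redis.set(key, audio, ex=self._ttl_s)
        return audio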

Additional Context

Use case: In my application, the bot frequently uses the same greetings, confirmations, and fallback messages. Caching even just these repeated phrases would significantly reduce usage cost at scale. I’m also using ElevenLabs with custom voices, which makes caching even more valuable.

"""
Cached Cartesia TTS Service

This module provides a cached version of CartesiaTTSService that integrates
Redis caching directly into the WebSocket TTS flow.
"""

import base64
import json
from typing import Any, AsyncGenerator, Dict, Optional

from loguru import logger

from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    TTSAudioRawFrame,
    TTSStartedFrame,
    TTSStoppedFrame,
)
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.utils.tracing.service_decorators import traced_tts

from .cartesia_cache import CartesiaCacheService, CachedAudioChunk
from .usage_tracker import get_usage_tracker
from ..models.usage import ComponentType, UsageUnit


class CachedCartesiaTTSService(CartesiaTTSService):
    """
    Cartesia TTS Service with integrated Redis caching.
    
    This service extends the standard CartesiaTTSService to add caching
    of WebSocket responses based on text and voice parameters.
    
    Cache hits return audio instantly (~1ms) while cache misses use the
    normal WebSocket flow and store results for future use.
    """
    
    def __init__(
        self,
        agent_cache_key: Optional[str] = None,
        enable_cache: bool = True,
        call_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the cached Cartesia TTS service.
        
        Args:
            agent_cache_key: The agent's cache key for cache isolation
            enable_cache: Whether to enable caching (default: True)
            call_id: Call ID for usage tracking
            **kwargs: All other arguments passed to CartesiaTTSService
        """
        super().__init__(**kwargs)
        
        self.enable_cache = enable_cache
        self.cache_service = None
        self.call_id = call_id
        self.usage_tracker = get_usage_tracker()
        
        # Cache collection state for WebSocket messages
        self._current_text = None
        self._current_cache_data = None
        
        if self.enable_cache:
            self.cache_service = CartesiaCacheService(agent_cache_key)
            logger.info(f"✅ Initialized cached Cartesia TTS service for agent: {agent_cache_key}")
        else:
            logger.info("🔧 Cartesia TTS caching disabled")
    
    def set_call_id(self, call_id: str):
        """Set the call ID for usage tracking."""
        self.call_id = call_id
    
    def _get_current_voice_settings(self) -> Dict[str, Any]:
        """Get current voice settings for cache key generation."""
        return {
            "speed": self._settings.get("speed"),
            "emotion": self._settings.get("emotion", []),
            "output_format": self._settings.get("output_format", {}),
        }
    
    async def close_cache(self):
        """Close the cache service connection."""
        if self.cache_service:
            await self.cache_service.close()
    
    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """
        Generate speech from text with caching support.
        
        First checks cache for existing response. If found, returns cached audio
        and word timestamps instantly. If not found, calls parent WebSocket 
        implementation and caches the result.
        
        Args:
            text: The text to synthesize into speech.

        Yields:
            Frame: Audio frames containing the synthesized speech.
        """
        logger.debug(f"CachedCartesiaTTS: Processing text [{text}]")
        
        # Record the text up front so the cache-hit path can log it and
        # track usage (otherwise it is only set on the WebSocket path).
        self._current_text = text

        # Check cache first if enabled
        if self.enable_cache and self.cache_service:
            voice_settings = self._get_current_voice_settings()
            language = self._settings.get("language")
            
            cached_response = await self.cache_service.get_cached_response(
                text=text,
                voice_id=self._voice_id,
                model=self.model_name,
                voice_settings=voice_settings,
                language=language
            )
            
            if cached_response:
                # Cache hit - delegate to untraced method (no timing metrics)
                async for frame in self._run_cached_tts(cached_response):
                    yield frame
                return
        
        # Cache miss - delegate to traced WebSocket method
        async for frame in self._run_websocket_tts(text):
            yield frame
    
    async def _run_cached_tts(self, cached_response) -> AsyncGenerator[Frame, None]:
        """
        Return cached TTS response without timing metrics.
        
        Args:
            cached_response: The cached response to return
            
        Yields:
            Frame: Audio frames from cache
        """
        logger.info(f"🎯 Cartesia CACHE HIT: Using cached response for text: '{self._current_text or 'unknown'}' ({len(cached_response.audio_chunks)} chunks)")
        
        # Track cache hit usage for transparency
        if self.call_id and self.usage_tracker and self._current_text:
            try:
                characters = len(self._current_text)
                await self.usage_tracker.track_cached_component_usage(
                    call_id=self.call_id,
                    component_type=ComponentType.TTS,
                    provider="cartesia",
                    quantity=characters,
                    unit=UsageUnit.CHARACTERS,
                    model=self.model_name,
                    metadata={
                        "voice_id": self._voice_id,
                        "cache_hit": True,
                        "cached_duration": getattr(cached_response, 'total_duration', 0.0)
                    }
                )
                logger.debug(f"✅ Tracked cached TTS usage: {characters} characters from cache")
            except Exception as e:
                logger.warning(f"Failed to track cached TTS usage: {e}")
        
        # Yield TTS frames with cached data
        yield TTSStartedFrame()
        
        # Set up word timestamps if available
        if cached_response.word_timestamps:
            self.start_word_timestamps()
            # Add word timestamps using the parent class method
            await self.add_word_timestamps(cached_response.word_timestamps)
        
        # Yield cached audio chunks
        for chunk in cached_response.audio_chunks:
            yield TTSAudioRawFrame(
                audio=chunk.audio,
                sample_rate=chunk.sample_rate,
                num_channels=chunk.num_channels
            )
        
        yield TTSStoppedFrame()
    
    @traced_tts
    async def _run_websocket_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """
        Run WebSocket TTS with caching collection (WITH timing metrics).
        
        Args:
            text: The text to synthesize
            
        Yields:
            Frame: Audio frames from WebSocket
        """
        logger.info(f"❌ Cartesia CACHE MISS: Generating new TTS for text: '{text[:50]}...'")
        
        # Set up cache collection if enabled
        if self.enable_cache and self.cache_service:
            self._current_text = text
            self._current_cache_data = {
                'audio_chunks': [],
                'word_timestamps': [],
                'voice_settings': self._get_current_voice_settings(),
                'language': self._settings.get("language")
            }
            logger.debug(f"🔄 Set up cache collection for text: '{text[:50]}...'")
        
        # Use parent implementation (which will trigger our overridden _process_messages)
        async for frame in super().run_tts(text):
            yield frame
    
    async def _process_messages(self):
        """
        Override parent's message processing to collect cache data.
        
        This method intercepts WebSocket messages to collect audio chunks
        and word timestamps for caching while maintaining normal flow.
        """
        async for message in self._get_websocket():
            msg = json.loads(message)
            if not msg or not self.audio_context_available(msg["context_id"]):
                continue
                
            if msg["type"] == "done":
                await self.stop_ttfb_metrics()
                await self.add_word_timestamps([("TTSStoppedFrame", 0), ("Reset", 0)])
                await self.remove_audio_context(msg["context_id"])
                
                # Store in cache if we have collected data
                await self._store_cache_if_ready()
                
            elif msg["type"] == "timestamps":
                # Process the timestamps based on language before adding them
                processed_timestamps = self._process_word_timestamps_for_language(
                    msg["word_timestamps"]["words"], msg["word_timestamps"]["start"]
                )
                await self.add_word_timestamps(processed_timestamps)
                
                # Collect for caching
                if self._current_cache_data is not None:
                    self._current_cache_data['word_timestamps'].extend(processed_timestamps)
                    logger.debug(f"🔄 Collected {len(processed_timestamps)} word timestamps for caching")
                
            elif msg["type"] == "chunk":
                await self.stop_ttfb_metrics()
                self.start_word_timestamps()
                
                # Decode audio data
                audio_data = base64.b64decode(msg["data"])
                
                frame = TTSAudioRawFrame(
                    audio=audio_data,
                    sample_rate=self.sample_rate,
                    num_channels=1,
                )
                await self.append_to_audio_context(msg["context_id"], frame)
                
                # Collect for caching
                if self._current_cache_data is not None:
                    self._current_cache_data['audio_chunks'].append(CachedAudioChunk(
                        audio=audio_data,
                        sample_rate=self.sample_rate,
                        num_channels=1
                    ))
                    logger.debug(f"🔄 Collected audio chunk for caching ({len(audio_data)} bytes)")
                
            elif msg["type"] == "error":
                logger.error(f"{self} error: {msg}")
                await self.push_frame(TTSStoppedFrame())
                await self.stop_all_metrics()
                await self.push_error(ErrorFrame(f"{self} error: {msg['error']}"))
                self._context_id = None
                
                # Clear cache collection on error
                self._current_cache_data = None
                self._current_text = None
                
            else:
                logger.error(f"{self} error, unknown message type: {msg}")
    
    async def _store_cache_if_ready(self):
        """Store collected data in cache if ready."""
        if (self.enable_cache and 
            self.cache_service and 
            self._current_cache_data and 
            self._current_text and
            self._current_cache_data['audio_chunks']):
            
            try:
                success = await self.cache_service.store_response(
                    text=self._current_text,
                    voice_id=self._voice_id,
                    model=self.model_name,
                    voice_settings=self._current_cache_data['voice_settings'],
                    audio_chunks=self._current_cache_data['audio_chunks'],
                    word_timestamps=self._current_cache_data['word_timestamps'],
                    language=self._current_cache_data['language']
                )
                
                if success:
                    logger.debug(f"✅ Successfully cached Cartesia response for text: '{self._current_text[:50]}...'")
                else:
                    logger.warning(f"⚠️ Failed to cache Cartesia response for text: '{self._current_text[:50]}...'")
                    
            except Exception as e:
                logger.error(f"❌ Error caching Cartesia response: {e}")
            
            finally:
                # Clear collection state
                self._current_cache_data = None
                self._current_text = None
    
    async def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        if self.cache_service:
            return await self.cache_service.get_cache_stats()
        return {"enabled": False, "reason": "Cache service not initialized"}
    
    async def clear_cache(self, pattern: Optional[str] = None):
        """Clear cache entries."""
        if self.cache_service:
            await self.cache_service.clear_cache(pattern)
        else:
            logger.warning("🚫 Cannot clear cache - cache service not initialized")
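
For reference, a hypothetical usage sketch of the class above. The api_key and voice_id arguments follow pipecat's standard CartesiaTTSService constructor; the agent_cache_key and call_id values are made up for illustration:

import os

tts = CachedCartesiaTTSService(
    api_key=os.environ["CARTESIA_API_KEY"],
    voice_id="your-cartesia-voice-id",
    agent_cache_key="agent-123",  # isolates this agent's cache entries
    enable_cache=True,
    call_id="call-abc",           # used for usage tracking
)

# Drop the service into a pipeline like any other TTS service; repeated
# phrases are then served from Redis instead of a new WebSocket request.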

Would you be willing to help implement this feature?

  • [x] Yes, I'd like to contribute
  • [ ] No, I'm just suggesting

matheusbento avatar Sep 10 '25 15:09 matheusbento

Hi @matheusbento, are you actively working on this issue?

mahimairaja avatar Sep 18 '25 22:09 mahimairaja

This is nice

steinathan avatar Sep 29 '25 23:09 steinathan

Hey folks, I worked on it but I’m still having issues.

matheusbento avatar Oct 01 '25 12:10 matheusbento

It would be immensely helpful if this were implemented; I am having the same issue and am trying to work around it.

500lbbicepcurl avatar Oct 04 '25 21:10 500lbbicepcurl

It would be nice if this could be implemented for LLMs as well, to cache tokens.

enterux avatar Nov 08 '25 06:11 enterux

@matheusbento - I am not sure how this would work when the LLM is streaming text and we are adding that text to an existing ElevenLabs context. For instance, let's say the LLM streams two full sentences.

Sentence 1 - I am an LLM. Sentence 2 - I am pretty smart.

We will create an audio context and send sentence 1 for generation. We will reuse the same audio context to send sentence 2 for generation.

How would we cache sentence 1 separately, given that we don't know whether ElevenLabs provides any guarantees around how different sentences in the same context are generated?

Maybe I am missing something; it would be great to get your input on this. Thanks!

a6kme avatar Dec 17 '25 07:12 a6kme