Caching does not work or does not return cache info when streaming is enabled

Open gulbaki opened this issue 5 months ago • 3 comments

After investigating the issue in more detail, I found that caching does not work when streaming: true is set. However, when I disable streaming (streaming: false), caching works correctly and I receive the expected cache-related usage fields in the response. With streaming disabled (streaming: false):

  "cache_creation_input_tokens": 0,
  "cache_read_input_tokens": 1454,

With streaming enabled (streaming: true):

"cache_creation_input_tokens": null,
 "cache_read_input_tokens": null,

As you can see, the cache usage fields are null when streaming is turned on, which suggests caching is not working in that mode, or at least the information is not returned.
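
For reference, here is a minimal sketch of the two calls I am comparing (the model name and document text below are placeholders; the real document is long enough to exceed the minimum cacheable prompt length):

import anthropic

client = anthropic.Anthropic()

# Placeholder: the real document is a few thousand tokens, above the minimum cacheable length.
long_document = "<contents of a long document>"

system_blocks = [
    {
        "type": "text",
        "text": long_document,
        "cache_control": {"type": "ephemeral"},
    }
]
messages = [{"role": "user", "content": "Summarize the document."}]

# Non-streaming call: usage reports the cache fields.
response = client.messages.create(
    model="claude-3-5-haiku-latest",
    max_tokens=300,
    system=system_blocks,
    messages=messages,
)
print(response.usage.cache_creation_input_tokens, response.usage.cache_read_input_tokens)

# Same request with stream=True, printing the usage carried on the stream events.
stream = client.messages.create(
    model="claude-3-5-haiku-latest",
    max_tokens=300,
    system=system_blocks,
    messages=messages,
    stream=True,
)
for event in stream:
    if event.type == "message_start":
        print("message_start usage:", event.message.usage)
    elif event.type == "message_delta":
        print("message_delta usage:", event.usage)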

gulbaki avatar Jul 05 '25 12:07 gulbaki

Hi! Thanks for raising the issue -- do you have a way to reproduce?

kevinc13 avatar Jul 23 '25 20:07 kevinc13

Bumping this issue. I am facing the exact same thing. Here is a dump of my code in case it helps reproduce; let me know if you need anything else.

import anthropic
import requests
from bs4 import BeautifulSoup
import time
from typing import Dict, List, Optional, Union
from datetime import datetime
from provider_models import AnthropicModel, CLAUDE_HAIKU_35, CLAUDE_SONNET_4
import uuid
from provider_models import (
    price_anthropic_call_with_input_caching,
    price_anthropic_call_no_caching,
)
import logging
from anthropic import Stream
logger = logging.getLogger(__name__)
logging.basicConfig(filename="myapp.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("httpx").setLevel(logging.WARNING)
online_article = """
# Prompt caching

Prompt caching is a powerful feature that optimizes your API usage by allowing resuming from specific prefixes in your prompts. This approach significantly reduces processing time and costs for repetitive tasks or prompts with consistent elements.

Here's an example of how to implement prompt caching with the Messages API using a `cache_control` block:

........ (article longer than 2048 tokens) 
"""


class ConversationHistory:
    """
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "<book>" + book_content + "</book>",
                    "cache_control": {"type": "ephemeral"}
                },
                {
                    "type": "text",
                    "text": "What is the title of this book? Only output the title."
                }
            ]
        }
    ]"""

    def __init__(self):
        self.turns = []

    def add_turn_assistant(
        self,
        content_raw: str,
        msg_id: str,
        timestamp: datetime,
        add_cache_control: bool = False,
    ):
        self.turns.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": content_raw}],
                "id": msg_id,  # Add an ID for tracking
                "timestamp": timestamp,
            }  # Add a timestamp for tracking
        )
        if add_cache_control:
            self.turns[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}

    def add_turn_user(
        self,
        content_raw,
        msg_id: str,
        timestamp: datetime,
        add_cache_control: bool = False,
    ):
        self.turns.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": content_raw}],
                "id": msg_id,  # Add an ID for tracking
                "timestamp": timestamp,
            }  # Add a timestamp for tracking
        )
        if add_cache_control: # not very important when saving the data
            self.turns[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}

    def get_turns(self, add_cache_control: bool = True):
        message_list = []
        for idx, turn in enumerate(self.turns):
            if turn.get("content"):
                # Create a clean copy without any cache control
                content_clean = []
                for content_item in turn.get("content"):
                    clean_item = {"type": content_item["type"], "text": content_item["text"]}
                    content_clean.append(clean_item)
                
                message_turn = {
                    "content": content_clean,
                    "role": turn.get("role"),
                }
                message_list.append(message_turn)
            else:
                raise ValueError(f"weird turn received {turn}")
                
        # Only add cache control to the LAST user message (most recent)
        if add_cache_control and message_list:
            for i in range(len(message_list) - 1, -1, -1):
                if message_list[i]["role"] == "user":
                    message_list[i]["content"][0]["cache_control"] = {"type": "ephemeral"}
                    break
        return message_list

    def get_turns_before(self, timestamp: datetime, include: bool = True):
        result = []
        for turn in self.turns:
            if turn["timestamp"] < timestamp:
                result.append(
                    {
                        "role": turn["role"],
                        "content": turn["content"],
                    }
                )
            elif include and turn["timestamp"] == timestamp:
                result.append(
                    {
                        "role": turn["role"],
                        "content": turn["content"],
                    }
                )
        return result


class ProfilerChunk:
    def __init__(self, text: str, timestamp: datetime, msg_id: str):
        self.text: str = text
        self.timestamp: datetime = timestamp
        self.msg_id: str = msg_id
        self._token_count = None
        self._baseline_tokens = None

    def _get_baseline_tokens(self, model: AnthropicModel) -> int:
        """Get baseline token count for empty message"""
        if self._baseline_tokens is not None:
            return self._baseline_tokens

        client = anthropic.Anthropic()
        response = client.messages.count_tokens(
            model=model.name, messages=[{"role": "user", "content": "yo"}]
        )
        self._baseline_tokens = response.input_tokens
        return self._baseline_tokens

    def count_tokens(self, model: AnthropicModel) -> int:
        """Count tokens using Anthropic's count_tokens method with baseline subtraction"""
        if self._token_count is not None:
            return self._token_count

        client = anthropic.Anthropic()
        response = client.messages.count_tokens(
            model=model.name, messages=[{"role": "user", "content": self.text}]
        )

        baseline = self._get_baseline_tokens(model)
        self._token_count = max(0, response.input_tokens - baseline + 1)
        return self._token_count

    @property
    def num_tokens(self):
        # Fallback to simple word count if no model provided
        return len(self.text.split())


class ProfilerMessage:
    def __init__(self, msg_id: str):
        self.msg_id = msg_id
        self.chunks: List[ProfilerChunk] = []
        self.input_tokens: Optional[int] = None
        self.cache_read_input_tokens: Optional[int] = None
        self.cache_creation_input_tokens: Optional[int] = None
        self.output_tokens: Optional[int] = None
        self.message: Optional[str] = None

    def add_chunk(self, text: str, timestamp: datetime):
        chunk = ProfilerChunk(text, timestamp, self.msg_id)
        self.chunks.append(chunk)

    def update_message(self, message_str: str, usage_data):
        """Update message-level metrics from message_delta usage data"""
        self.message = message_str
        if hasattr(usage_data, "input_tokens"):
            self.input_tokens = usage_data.input_tokens
        else:
            raise ValueError(f"no input_tokens in usage_data {usage_data}")
        if hasattr(usage_data, "cache_read_input_tokens"):
            self.cache_read_input_tokens = usage_data.cache_read_input_tokens
        else:
            raise ValueError(f"no cache_read_input_tokens in {usage_data}")
        if hasattr(usage_data, "cache_creation_input_tokens"):
            self.cache_creation_input_tokens = usage_data.cache_creation_input_tokens
        if hasattr(usage_data, "output_tokens"):
            self.output_tokens = usage_data.output_tokens


class ProfilerStreaming:
    def __init__(self, system_message: str, history: ConversationHistory):
        self.start = None
        self.messages: Dict[str, ProfilerMessage] = {}
        self.system_message: str = system_message
        self.history = history

    def start_profiling(self):
        self.start = datetime.now()

    def add_chunk(self, text: str, timestamp: datetime, msg_id: str):
        if msg_id not in self.messages:
            self.messages[msg_id] = ProfilerMessage(msg_id)
        if self.start is None:
            self.start_profiling()
        self.messages[msg_id].add_chunk(text, timestamp)

    def update_message(self, msg_id: str, message_str: str, usage_data):
        """Update message-level metrics from message_delta event"""
        if msg_id not in self.messages:
            self.messages[msg_id] = ProfilerMessage(msg_id)
        self.messages[msg_id].update_message(message_str, usage_data)

    def cumulative_tokens_received_per_second(self, msg_id: str, model):
        tokens_received = [0]
        time_elapsed_list = [self.start]
        if msg_id in self.messages:
            for chunk in self.messages[msg_id].chunks:
                tokens_received += [tokens_received[-1] + chunk.count_tokens(model)]
                time_elapsed_list += [chunk.timestamp]
        return tokens_received, time_elapsed_list

    def cumulative_cost_per_second(
        self,
        msg_id: str,
        cache_used: bool,
        model: AnthropicModel,
        history: ConversationHistory,
    ):
        # graph with cumulative cost as a function of time elapsed
        prices_dpoints = [0.0]
        time_elapsed_list = [self.start]
        tokens_received = 0
        if msg_id in self.messages:
            for chunk in self.messages[msg_id].chunks:
                tokens_received = tokens_received + chunk.num_tokens
                # lets consider that the tokens inputted are all cache read tokens in the case of streaming using cache.
                if cache_used:
                    prices_dpoints += [
                        price_anthropic_call_with_input_caching(
                            model=model,
                            output_tokens=tokens_received,
                            system_message_str=self.system_message,
                            messages=history.get_turns_before(
                                chunk.timestamp, include=True
                            ),
                        )
                    ]
                else:
                    prices_dpoints += [
                        price_anthropic_call_no_caching(
                            model=model,
                            output_tokens=tokens_received,
                            system_message_str=self.system_message,
                            messages=history.get_turns_before(
                                chunk.timestamp, include=True
                            ),
                        )
                    ]
                time_elapsed_list += [chunk.timestamp]
        return prices_dpoints, time_elapsed_list

    def display_stats(
        self, message_id: str, model: AnthropicModel
    ):
        """Display comprehensive stats for a specific message ID"""
        if message_id not in self.messages:
            print(f"āŒ No data found for message ID: {message_id}")
            return

        message = self.messages[message_id]
        chunks = message.chunks
        if not chunks:
            print(f"āŒ No chunks found for message ID: {message_id}")
            return

        print(f"\nšŸ“Š Streaming Stats for Message ID: {message_id}")
        print("=" * 60)

        # Basic stats
        total_chunks = len(chunks)
        total_text = "".join([chunk.text for chunk in chunks])
        total_characters = len(total_text)

        # Time stats
        first_chunk_time = chunks[0].timestamp
        last_chunk_time = chunks[-1].timestamp
        total_duration = (last_chunk_time - first_chunk_time).total_seconds()

        # Token stats
        total_tokens = sum([chunk.count_tokens(model) for chunk in chunks])

        print(f"šŸ“ Content Stats:")
        print(f"   • Total chunks: {total_chunks}")
        print(f"   • Total characters: {total_characters}")
        print(f"   • Total tokens: {total_tokens}")
        print(f"   • Characters per chunk (avg): {total_characters/total_chunks:.1f}")
        print(f"   • Tokens per chunk (avg): {total_tokens/total_chunks:.1f}")

        # Message-level metrics from API
        print(f"\nšŸ”¢ API Token Metrics:")
        print(f"   • Input tokens: {message.input_tokens}")
        print(f"   • Cache read tokens: {message.cache_read_input_tokens}")
        print(f"   • Cache write tokens: {message.cache_creation_input_tokens}")
        print(f"   • Output tokens: {message.output_tokens}")

        print(f"\nā±ļø  Timing Stats:")
        print(f"   • First chunk: {first_chunk_time.strftime('%H:%M:%S.%f')[:-3]}")
        print(f"   • Last chunk: {last_chunk_time.strftime('%H:%M:%S.%f')[:-3]}")
        print(f"   • Total duration: {total_duration:.3f} seconds")
        if total_duration > 0:
            print(f"   • Chunks per second: {total_chunks/total_duration:.1f}")
            print(f"   • Tokens per second: {total_tokens/total_duration:.1f}")
            print(f"   • Characters per second: {total_characters/total_duration:.1f}")

        # # Cost estimation
        # if add_cache_control:
        #     cache_used = True
            # total_cost = price_anthropic_call_with_input_caching(
            #     model=model,
            #     output_tokens=message.output_tokens or total_tokens,
            #     system_message_str=self.system_message,
            #     messages=self.history.get_turns_before(
            #         last_chunk_time, include=True
            #     ),
            # )
        # else:
        #     total_cost = price_anthropic_call_no_caching(
        #         model=model,
        #         output_tokens=message.output_tokens or total_tokens,
        #         system_message_str=self.system_message,
        #         messages=self.history.get_turns_before(
        #             last_chunk_time, include=True
        #         ),
        #     )

            # print(f"\nšŸ’° Cost Stats:")
            # print(f"   • Cache used: {'Yes' if cache_used else 'No'}")
            # print(f"   • Estimated total cost: ${total_cost:.6f}")
            # if total_tokens > 0:
            #     print(f"   • Cost per token: ${total_cost/total_tokens:.8f}")
            #     print(f"   • Cost per character: ${total_cost/total_characters:.8f}")

        print(f"\nšŸ“ˆ Chunk Details (first 5 and last 5):")
        display_chunks = chunks[:5] + (chunks[-5:] if len(chunks) > 10 else chunks[5:])
        for i, chunk in enumerate(display_chunks):
            elapsed = (chunk.timestamp - first_chunk_time).total_seconds()
            chunk_tokens = chunk.count_tokens(model)
            print(
                f"   [{i+1:2d}] {elapsed:6.3f}s | {chunk_tokens:2d} tokens | '{chunk.text[:50]}'{'...' if len(chunk.text) > 50 else ''}"
            )
            if i == 4 and len(chunks) > 10:
                print(f"   ... ({len(chunks)-10} chunks omitted)")

        print("=" * 60)


def run_message_with_streaming(
    system_prompt: str,
    user_query: str,
    model: AnthropicModel = CLAUDE_HAIKU_35,
    add_cache_control: bool = True,
):
    """
    - system :str. eg. the whole paul graham essay
    - user_query :str. eg. the question to ask about the essay. We will ask for a deterministic answer, if possible, so something like "Repeat back to me please!"
    """
    client = anthropic.Anthropic()
    conversation_history = ConversationHistory()

    system_message = f"<file_contents> {system_prompt} </file_contents>"
    print("šŸš€ Terminal Streaming Animation with Anthropic API")
    print("=" * 50)
    profiler_stream = ProfilerStreaming(
        system_message=system_message, history=conversation_history
    )
    for i, question in enumerate([user_query, user_query]):
        if i == 0:
            add_cache_control = True
        print(f"\nšŸ“ Turn {i}:")
        print(f"User: {question}")
        print(f"Assistant: ", end="", flush=True)

        timestamp_turn = datetime.now()
        conversation_history.add_turn_user(
            question,
            uuid.uuid4().hex,
            timestamp_turn,
        )

        start_time = time.time()
        full_response = ""

        system_message_dict: Dict[str, Union[str, Dict[str, str]]] = {
            "type": "text",
            "text": system_message,
            "cache_control": {"type": "ephemeral"} if add_cache_control else {}
        }
        # messages_turn = conversation_history.get_turns(add_cache_control)
        messages_turn = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": question,
                        "cache_control": {"type": "ephemeral"} if add_cache_control else {}
                    }
                ]
            }
        ]

        message_id = None
        message_usage = None
        with client.messages.create(
            model=model.name,
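            # Prompt caching is GA on current API versions, so this beta header may no longer be required.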
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
            max_tokens=300,
            system=[system_message_dict],
            messages=messages_turn,
            stream=True,

        ) as stream:
            for event in stream:
                if event.type == "message_start":
                    message_id = event.message.id
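                    # Note: event.message.usage here may already carry the cache token counts,
                    # so it could be worth logging it alongside the message_delta usage below.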
                elif event.type == "content_block_delta":
                    if event.delta.type == "text_delta":
                        text_chunk = event.delta.text
                        print(text_chunk, end="", flush=True)
                        full_response += text_chunk

                        profiler_stream.add_chunk(
                            text=text_chunk,
                            timestamp=datetime.now(),
                            msg_id=message_id,
                        )
                elif event.type == "message_delta":
                    # Capture message-level metrics like cache read tokens
                    if hasattr(event, "usage"):
                        message_usage = event.usage
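                        # Depending on the SDK/API version, message_delta usage may only
                        # populate output_tokens; the cache_* fields can come back as None here.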
                    logger.info(event)
                    logger.info(event.usage)
                else:
                    logger.info(event)
                time.sleep(0.01)  # Small delay for visual effect

        profiler_stream.update_message(message_id, full_response, message_usage)
        conversation_history.add_turn_assistant(
            content_raw=full_response,
            msg_id=message_id,
            timestamp=datetime.now(),
        )
        end_time = time.time()
        print(f"\nā±ļø Turn {i} completed in {end_time - start_time:.2f} seconds")

        # Display stats for this message
        profiler_stream.display_stats(message_id, model)


def run_message_without_streaming(
    system_prompt: str,
    user_query: str,
    model: AnthropicModel = CLAUDE_HAIKU_35,
    add_cache_control: bool = True,
):
    """
    Non-streaming version that uses the create function to make API calls to Anthropic
    - system_prompt: str. The system message content
    - user_query: str. The question to ask
    """
    client = anthropic.Anthropic()
    conversation_history = ConversationHistory()
    
    system_message = f"<file_contents> {system_prompt} </file_contents>"
    print("šŸš€ Non-Streaming API Call with Anthropic")
    print("=" * 50)
    
    for i, question in enumerate([user_query, user_query]):
        print(f"\nTurn {i+1}:")
        logger.info(f"Turn {i+1}:")
        print(f"User: {question}")
        
        # Add user input to conversation history
        timestamp_turn = datetime.now()
        conversation_history.add_turn_user(
            question,
            uuid.uuid4().hex,
            timestamp_turn,
        )

        start_time = time.time()

        if add_cache_control:
            system_message_list = [{"type": "text", "text": system_message, "cache_control": {"type": "ephemeral"}}]
        else:
            system_message_list = [{"type": "text", "text": system_message}]

        turns = conversation_history.get_turns(add_cache_control)
        for turn in turns:
            logger.info(f"Turn: {turn['role']} - {turn['content']}")
        response = client.messages.create(
            model=model.name,
            extra_headers={
                "anthropic-beta": "prompt-caching-2024-07-31"
            },
            max_tokens=300,
            system=system_message_list,
            messages=turns,
        )

        # Record the end time
        end_time = time.time()

        # Extract the assistant's reply
        assistant_reply = response.content[0].text
        print(f"Assistant: {assistant_reply}")

        # Print token usage information
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        input_tokens_cache_read = getattr(response.usage, 'cache_read_input_tokens', '---')
        input_tokens_cache_create = getattr(response.usage, 'cache_creation_input_tokens', '---')
        print(f"User input tokens: {input_tokens}")
        print(f"Output tokens: {output_tokens}")
        print(f"Input tokens (cache read): {input_tokens_cache_read}")
        print(f"Input tokens (cache write): {input_tokens_cache_create}")

        # Calculate and print the elapsed time
        elapsed_time = end_time - start_time

        # Calculate the percentage of input prompt cached
        total_input_tokens = input_tokens + (int(input_tokens_cache_read) if input_tokens_cache_read != '---' else 0)
        percentage_cached = (int(input_tokens_cache_read) / total_input_tokens * 100 if input_tokens_cache_read != '---' and total_input_tokens > 0 else 0)

        print(f"{percentage_cached:.1f}% of input prompt cached ({total_input_tokens} tokens)")
        print(f"Time taken: {elapsed_time:.2f} seconds")

        # Add assistant's reply to conversation history
        conversation_history.add_turn_assistant(
            content_raw=assistant_reply,
            msg_id=response.id,
            timestamp=datetime.now(),
            add_cache_control=add_cache_control,
        )


def display_stats(profiler_streaming: ProfilerStreaming):
    print(f"")


def main():
    # run_message_without_streaming(
    #     system_prompt=online_article,
    #     user_query="Repeat back to me the paragraph \"when to use the 1-hour cache\", ignoring the rest! You must not change any word. And you must directly provide the words as output, without any introductory statement to acknowledge this task",
    #     model=CLAUDE_SONNET_4,
    #     add_cache_control=True,
    # )
    run_message_with_streaming(
        system_prompt=online_article,
        user_query="Repeat back to me the paragraph \"when to use the 1-hour cache\", ignoring the rest! You must not change any word. And you must directly provide the words as output, without any introductory statement to acknowledge this task",
        model=CLAUDE_SONNET_4,  
        add_cache_control=True,
    )


if __name__ == "__main__":
    main()

tanguyRenaudieDatadog avatar Jul 31 '25 17:07 tanguyRenaudieDatadog

The prompt caching only works with the run_message_without_streaming code.
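
For what it's worth, the SDK's streaming helper also accumulates a final message, and comparing its usage might be another useful data point (rough sketch, reusing client, system_message_dict and messages_turn from the script above):

with client.messages.stream(
    model=model.name,
    max_tokens=300,
    system=[system_message_dict],
    messages=messages_turn,
) as stream:
    for _ in stream.text_stream:
        pass  # drain the streamed text
    final_message = stream.get_final_message()
    # usage here is accumulated by the helper from the message_start / message_delta events
    print(final_message.usage)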

tanguyRenaudieDatadog avatar Jul 31 '25 17:07 tanguyRenaudieDatadog