Caching does not work or does not return cache info when streaming is enabled
After investigating the issue in more detail, I found that caching does not work when streaming: true is set. However, when I disable streaming (streaming: false), caching works correctly and I receive the expected cache-related usage fields in the response.

With streaming disabled (streaming: false):
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 1454,
With streaming enabled (streaming: true):
"cache_creation_input_tokens": null,
"cache_read_input_tokens": null,
As you can see, the cache usage fields are null when streaming is turned on, which suggests that caching is not working in that mode, or at least that the cache information is not being returned.
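For reference, those numbers come from the usage object on the API response. With the raw anthropic Python SDK, the non-streaming call looks roughly like this (a sketch with placeholder file and model names, not my exact code):

```python
import anthropic

client = anthropic.Anthropic()

# Placeholder: the real prompt is well over the model's minimum cacheable length.
long_prompt = open("cached_context.txt").read()

response = client.messages.create(
    model="claude-3-5-sonnet-latest",  # placeholder model id
    max_tokens=300,
    system=[
        {"type": "text", "text": long_prompt, "cache_control": {"type": "ephemeral"}}
    ],
    messages=[{"role": "user", "content": "Summarize the document."}],
)
print(response.usage.cache_creation_input_tokens)  # 0 on a warm cache in the run above
print(response.usage.cache_read_input_tokens)      # 1454 in the run above
```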
Hi! Thanks for raising the issue -- do you have a way to reproduce?
Bumping this issue. I am facing exactly the same thing. Here is a dump of my code; let me know if you need anything else.
import anthropic
import requests
from bs4 import BeautifulSoup
import time
from typing import Dict, List, Union
from datetime import datetime
from provider_models import AnthropicModel, CLAUDE_HAIKU_35, CLAUDE_SONNET_4
import uuid
from provider_models import (
price_anthropic_call_with_input_caching,
price_anthropic_call_no_caching,
)
import logging
from anthropic import Stream
logger = logging.getLogger(__name__)
logging.basicConfig(filename="myapp.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("httpx").setLevel(logging.WARNING)
online_article = """
# Prompt caching
Prompt caching is a powerful feature that optimizes your API usage by allowing resuming from specific prefixes in your prompts. This approach significantly reduces processing time and costs for repetitive tasks or prompts with consistent elements.
Here's an example of how to implement prompt caching with the Messages API using a `cache_control` block:
........ (article longer than 2048 tokens)
"""
class ConversationHistory:
"""
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "<book>" + book_content + "</book>",
"cache_control": {"type": "ephemeral"}
},
{
"type": "text",
"text": "What is the title of this book? Only output the title."
}
]
}
]"""
def __init__(self):
self.turns = []
def add_turn_assistant(
self,
content_raw: str,
msg_id: str,
timestamp: datetime,
add_cache_control: bool = False,
):
        self.turns.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": content_raw}],
                "id": msg_id,  # ID added for tracking
                "timestamp": timestamp,  # timestamp added for tracking
            }
        )
if add_cache_control:
self.turns[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
def add_turn_user(
self,
content_raw,
msg_id: str,
timestamp: datetime,
add_cache_control: bool = False,
):
        self.turns.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": content_raw}],
                "id": msg_id,  # ID added for tracking
                "timestamp": timestamp,  # timestamp added for tracking
            }
        )
if add_cache_control: # not very important when saving the data
self.turns[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
def get_turns(self, add_cache_control: bool = True):
message_list = []
for idx, turn in enumerate(self.turns):
if turn.get("content"):
# Create a clean copy without any cache control
content_clean = []
for content_item in turn.get("content"):
clean_item = {"type": content_item["type"], "text": content_item["text"]}
content_clean.append(clean_item)
message_turn = {
"content": content_clean,
"role": turn.get("role"),
}
message_list.append(message_turn)
else:
raise ValueError(f"weird turn received {turn}")
# Only add cache control to the LAST user message (most recent)
if add_cache_control and message_list:
for i in range(len(message_list) - 1, -1, -1):
if message_list[i]["role"] == "user":
message_list[i]["content"][0]["cache_control"] = {"type": "ephemeral"}
break
return message_list
def get_turns_before(self, timestamp: datetime, include: bool = True):
result = []
for turn in self.turns:
if turn["timestamp"] < timestamp:
result.append(
{
"role": turn["role"],
"content": turn["content"],
}
)
elif include and turn["timestamp"] == timestamp:
result.append(
{
"role": turn["role"],
"content": turn["content"],
}
)
return result
class ProfilerChunk:
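    """A single streamed text chunk, with its arrival timestamp and parent message id."""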
def __init__(self, text: str, timestamp: datetime, msg_id: str):
self.text: str = text
self.timestamp: datetime = timestamp
self.msg_id: str = msg_id
self._token_count = None
self._baseline_tokens = None
def _get_baseline_tokens(self, model: AnthropicModel) -> int:
"""Get baseline token count for empty message"""
if self._baseline_tokens is not None:
return self._baseline_tokens
client = anthropic.Anthropic()
response = client.messages.count_tokens(
model=model.name, messages=[{"role": "user", "content": "yo"}]
)
self._baseline_tokens = response.input_tokens
return self._baseline_tokens
def count_tokens(self, model: AnthropicModel) -> int:
"""Count tokens using Anthropic's count_tokens method with baseline subtraction"""
if self._token_count is not None:
return self._token_count
client = anthropic.Anthropic()
response = client.messages.count_tokens(
model=model.name, messages=[{"role": "user", "content": self.text}]
)
baseline = self._get_baseline_tokens(model)
self._token_count = max(0, response.input_tokens - baseline + 1)
return self._token_count
@property
def num_tokens(self):
# Fallback to simple word count if no model provided
return len(self.text.split())
class ProfilerMessage:
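    """Aggregates the chunks and API usage metrics for one streamed message."""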
def __init__(self, msg_id: str):
self.msg_id = msg_id
self.chunks: List[ProfilerChunk] = []
self.input_tokens: int = None
self.cache_read_input_tokens: int = None
self.cache_creation_input_tokens: int = None
self.output_tokens: int = None
self.message : str = None
def add_chunk(self, text: str, timestamp: datetime):
chunk = ProfilerChunk(text, timestamp, self.msg_id)
self.chunks.append(chunk)
def update_message(self, message_str:str, usage_data):
"""Update message-level metrics from message_delta usage data"""
self.message = message_str
if hasattr(usage_data, "input_tokens"):
self.input_tokens = usage_data.input_tokens
else:
raise ValueError(f'no input tokens in usagedata {usage_data}')
        if hasattr(usage_data, "cache_read_input_tokens"):
            self.cache_read_input_tokens = usage_data.cache_read_input_tokens
        else:
            raise ValueError(f"no cache_read_input_tokens in {usage_data}")
        if hasattr(usage_data, "cache_creation_input_tokens"):
            self.cache_creation_input_tokens = usage_data.cache_creation_input_tokens
        if hasattr(usage_data, "output_tokens"):
            self.output_tokens = usage_data.output_tokens
class ProfilerStreaming:
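    """Collects per-message streaming stats (chunks, timing, usage) for a whole run."""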
def __init__(self, system_message: str, history: ConversationHistory):
self.start = None
self.messages: Dict[str, ProfilerMessage] = {}
self.system_message: str = system_message
self.history = history
def start_profiling(self):
self.start = datetime.now()
def add_chunk(self, text: str, timestamp: datetime, msg_id: str):
if msg_id not in self.messages:
self.messages[msg_id] = ProfilerMessage(msg_id)
if self.start is None:
self.start_profiling()
self.messages[msg_id].add_chunk(text, timestamp)
def update_message(self, msg_id: str, message_str: str, usage_data:dict):
"""Update message-level metrics from message_delta event"""
if msg_id not in self.messages:
self.messages[msg_id] = ProfilerMessage(msg_id)
self.messages[msg_id].update_message(message_str, usage_data)
def cumulative_tokens_received_per_second(self, msg_id: str, model):
tokens_received = list([0])
time_elapsed_list = list([self.start])
if msg_id in self.messages:
for chunk in self.messages[msg_id].chunks:
tokens_received += [tokens_received[-1] + chunk.count_tokens(model)]
time_elapsed_list += [chunk.timestamp]
return tokens_received, time_elapsed_list
def cumulative_cost_per_second(
self,
msg_id: str,
cache_used: bool,
model: AnthropicModel,
history: ConversationHistory,
):
# graph with cumulative cost as a function of time elapsed
prices_dpoints = list([0.0])
time_elapsed_list = list([self.start])
tokens_received = 0
if msg_id in self.messages:
for chunk in self.messages[msg_id].chunks:
tokens_received = tokens_received + chunk.num_tokens
                # Assume all input tokens are cache-read tokens when streaming with caching enabled.
if cache_used:
prices_dpoints += [
price_anthropic_call_with_input_caching(
model=model,
output_tokens=tokens_received,
system_message_str=self.system_message,
messages=history.get_turns_before(
chunk.timestamp, include=True
),
)
]
else:
prices_dpoints += [
price_anthropic_call_no_caching(
model=model,
output_tokens=tokens_received,
system_message_str=self.system_message,
messages=history.get_turns_before(
chunk.timestamp, include=True
),
)
]
time_elapsed_list += [chunk.timestamp]
return prices_dpoints, time_elapsed_list
def display_stats(
self, message_id: str, model: AnthropicModel
):
"""Display comprehensive stats for a specific message ID"""
if message_id not in self.messages:
print(f"ā No data found for message ID: {message_id}")
return
message = self.messages[message_id]
chunks = message.chunks
if not chunks:
print(f"ā No chunks found for message ID: {message_id}")
return
print(f"\nš Streaming Stats for Message ID: {message_id}")
print("=" * 60)
# Basic stats
total_chunks = len(chunks)
total_text = "".join([chunk.text for chunk in chunks])
total_characters = len(total_text)
# Time stats
first_chunk_time = chunks[0].timestamp
last_chunk_time = chunks[-1].timestamp
total_duration = (last_chunk_time - first_chunk_time).total_seconds()
# Token stats
total_tokens = sum([chunk.count_tokens(model) for chunk in chunks])
print(f"š Content Stats:")
print(f" ⢠Total chunks: {total_chunks}")
print(f" ⢠Total characters: {total_characters}")
print(f" ⢠Total tokens: {total_tokens}")
print(f" ⢠Characters per chunk (avg): {total_characters/total_chunks:.1f}")
print(f" ⢠Tokens per chunk (avg): {total_tokens/total_chunks:.1f}")
# Message-level metrics from API
print(f"\nš¢ API Token Metrics:")
print(f" ⢠Input tokens: {message.input_tokens}")
print(f" ⢠Cache read tokens: {message.cache_read_input_tokens}")
print(f" ⢠Cache write tokens: {message.cache_creation_input_tokens}")
print(f" ⢠Output tokens: {message.output_tokens}")
print(f"\nā±ļø Timing Stats:")
print(f" ⢠First chunk: {first_chunk_time.strftime('%H:%M:%S.%f')[:-3]}")
print(f" ⢠Last chunk: {last_chunk_time.strftime('%H:%M:%S.%f')[:-3]}")
print(f" ⢠Total duration: {total_duration:.3f} seconds")
if total_duration > 0:
print(f" ⢠Chunks per second: {total_chunks/total_duration:.1f}")
print(f" ⢠Tokens per second: {total_tokens/total_duration:.1f}")
print(f" ⢠Characters per second: {total_characters/total_duration:.1f}")
# # Cost estimation
# if add_cache_control:
# cache_used = True
# total_cost = price_anthropic_call_with_input_caching(
# model=model,
# output_tokens=message.output_tokens or total_tokens,
# system_message_str=self.system_message,
# messages=self.history.get_turns_before(
# last_chunk_time, include=True
# ),
# )
# else:
# total_cost = price_anthropic_call_no_caching(
# model=model,
# output_tokens=message.output_tokens or total_tokens,
# system_message_str=self.system_message,
# messages=self.history.get_turns_before(
# last_chunk_time, include=True
# ),
# )
# print(f"\nš° Cost Stats:")
# print(f" ⢠Cache used: {'Yes' if cache_used else 'No'}")
# print(f" ⢠Estimated total cost: ${total_cost:.6f}")
# if total_tokens > 0:
# print(f" ⢠Cost per token: ${total_cost/total_tokens:.8f}")
# print(f" ⢠Cost per character: ${total_cost/total_characters:.8f}")
print(f"\nš Chunk Details (first 5 and last 5):")
display_chunks = chunks[:5] + (chunks[-5:] if len(chunks) > 10 else chunks[5:])
for i, chunk in enumerate(display_chunks):
elapsed = (chunk.timestamp - first_chunk_time).total_seconds()
chunk_tokens = chunk.count_tokens(model)
print(
f" [{i+1:2d}] {elapsed:6.3f}s | {chunk_tokens:2d} tokens | '{chunk.text[:50]}'{'...' if len(chunk.text) > 50 else ''}"
)
if i == 4 and len(chunks) > 10:
print(f" ... ({len(chunks)-10} chunks omitted)")
print("=" * 60)
def run_message_with_streaming(
system_prompt: str,
user_query: str,
model: AnthropicModel = CLAUDE_HAIKU_35,
add_cache_control: bool = True,
):
"""
- system :str. eg. the whole paul graham essay
- user_query :str. eg. the question to ask about the essay. We will ask for a deterministic answer, if possible, so something like "Repeat back to me please!"
"""
client = anthropic.Anthropic()
conversation_history = ConversationHistory()
system_message = f"<file_contents> {system_prompt} </file_contents>"
print("š Terminal Streaming Animation with Anthropic API")
print("=" * 50)
profiler_stream = ProfilerStreaming(
system_message=system_message, history=conversation_history
)
for i, question in enumerate([user_query, user_query]):
if i == 0:
add_cache_control = True
print(f"\nš Turn {i}:")
print(f"User: {question}")
print(f"Assistant: ", end="", flush=True)
timestamp_turn = datetime.now()
conversation_history.add_turn_user(
question,
uuid.uuid4().hex,
timestamp_turn,
)
start_time = time.time()
full_response = ""
system_message_dict: Dict[str, Union[str, Dict[str, str]]] = {
"type": "text",
"text": system_message,
"cache_control": {"type": "ephemeral"} if add_cache_control else {}
}
# messages_turn = conversation_history.get_turns(add_cache_control)
messages_turn = [
{
"role": "user",
"content": [
{
"type": "text",
"text": question,
"cache_control": {"type": "ephemeral"} if add_cache_control else {}
}
]
}
]
message_id = None
message_usage = None
with client.messages.create(
model=model.name,
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
max_tokens=300,
system=[system_message_dict],
messages=messages_turn,
stream=True,
) as stream:
for event in stream:
if event.type == "message_start":
message_id = event.message.id
elif event.type == "content_block_delta":
if event.delta.type == "text_delta":
text_chunk = event.delta.text
print(text_chunk, end="", flush=True)
full_response += text_chunk
profiler_stream.add_chunk(
text=text_chunk,
timestamp=datetime.now(),
msg_id=message_id,
)
elif event.type == "message_delta":
# Capture message-level metrics like cache read tokens
if hasattr(event, "usage"):
message_usage = event.usage
logger.info(event)
logger.info(event.usage)
else:
logger.info(event)
time.sleep(0.01) # Small delay for visual effect
        profiler_stream.update_message(message_id, full_response, message_usage)
conversation_history.add_turn_assistant(
content_raw=full_response,
msg_id=message_id,
timestamp=datetime.now(),
)
end_time = time.time()
print(f"\nā±ļø Turn {i} completed in {end_time - start_time:.2f} seconds")
# Display stats for this message
profiler_stream.display_stats(message_id, model)
def run_message_without_streaming(
system_prompt: str,
user_query: str,
model: AnthropicModel = CLAUDE_HAIKU_35,
add_cache_control: bool = True,
):
"""
Non-streaming version that uses the create function to make API calls to Anthropic
- system_prompt: str. The system message content
- user_query: str. The question to ask
"""
client = anthropic.Anthropic()
conversation_history = ConversationHistory()
system_message = f"<file_contents> {system_prompt} </file_contents>"
print("š Non-Streaming API Call with Anthropic")
print("=" * 50)
for i, question in enumerate([user_query, user_query]):
print(f"\nTurn {i+1}:")
logger.info(f"Turn {i+1}:")
print(f"User: {question}")
# Add user input to conversation history
timestamp_turn = datetime.now()
conversation_history.add_turn_user(
question,
uuid.uuid4().hex,
timestamp_turn,
)
start_time = time.time()
if add_cache_control:
system_message_list = [{"type": "text", "text": system_message, "cache_control": {"type": "ephemeral"}}]
else:
system_message_list = [{"type": "text", "text": system_message}]
turns = conversation_history.get_turns(add_cache_control)
for turn in turns:
logger.info(f"Turn: {turn['role']} - {turn['content']}")
response = client.messages.create(
model=model.name,
extra_headers={
"anthropic-beta": "prompt-caching-2024-07-31"
},
max_tokens=300,
system=system_message_list,
messages=turns,
)
# Record the end time
end_time = time.time()
# Extract the assistant's reply
assistant_reply = response.content[0].text
print(f"Assistant: {assistant_reply}")
# Print token usage information
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
input_tokens_cache_read = getattr(response.usage, 'cache_read_input_tokens', '---')
input_tokens_cache_create = getattr(response.usage, 'cache_creation_input_tokens', '---')
print(f"User input tokens: {input_tokens}")
print(f"Output tokens: {output_tokens}")
print(f"Input tokens (cache read): {input_tokens_cache_read}")
print(f"Input tokens (cache write): {input_tokens_cache_create}")
# Calculate and print the elapsed time
elapsed_time = end_time - start_time
# Calculate the percentage of input prompt cached
total_input_tokens = input_tokens + (int(input_tokens_cache_read) if input_tokens_cache_read != '---' else 0)
percentage_cached = (int(input_tokens_cache_read) / total_input_tokens * 100 if input_tokens_cache_read != '---' and total_input_tokens > 0 else 0)
print(f"{percentage_cached:.1f}% of input prompt cached ({total_input_tokens} tokens)")
print(f"Time taken: {elapsed_time:.2f} seconds")
# Add assistant's reply to conversation history
conversation_history.add_turn_assistant(
content_raw=assistant_reply,
msg_id=response.id,
timestamp=datetime.now(),
add_cache_control=add_cache_control,
)
def display_stats(profiler_streaming: ProfilerStreaming):
print(f"")
def main():
# run_message_without_streaming(
# system_prompt=online_article,
# user_query="Repeat back to me the paragraph \"when to use the 1-hour cache\", ignoring the rest! You must not change any word. And you must directly provide the words as output, without any introductory statement to acknowledge this task",
# model=CLAUDE_SONNET_4,
# add_cache_control=True,
# )
run_message_with_streaming(
system_prompt=online_article,
user_query="Repeat back to me the paragraph \"when to use the 1-hour cache\", ignoring the rest! You must not change any word. And you must directly provide the words as output, without any introductory statement to acknowledge this task",
model=CLAUDE_SONNET_4,
add_cache_control=True,
)
if __name__ == "__main__":
main()
Prompt caching only works for me with the run_message_without_streaming path; with run_message_with_streaming, the cache usage fields come back as None.
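One more data point that may help with triage: run_message_with_streaming takes its usage from the message_delta event. If the cache fields are only populated on a different event (for example on message_start's message.usage), that would explain the None values, but I have not been able to confirm this. Here is a trimmed sketch of the comparison I have in mind (the user question is a placeholder; online_article and CLAUDE_SONNET_4 are as defined in my script above):

```python
import anthropic
from provider_models import CLAUDE_SONNET_4  # same wrapper object used in the script above

# Sketch: compare the usage reported on message_start vs message_delta.
client = anthropic.Anthropic()
start_usage, delta_usage, full_response = None, None, ""

with client.messages.create(
    model=CLAUDE_SONNET_4.name,
    max_tokens=300,
    system=[
        {"type": "text", "text": online_article, "cache_control": {"type": "ephemeral"}}
    ],
    messages=[{"role": "user", "content": "What is this article about?"}],
    stream=True,
) as stream:
    for event in stream:
        if event.type == "message_start":
            start_usage = event.message.usage  # does this carry the cache_* fields?
        elif event.type == "content_block_delta" and event.delta.type == "text_delta":
            full_response += event.delta.text
        elif event.type == "message_delta":
            delta_usage = event.usage  # what my profiler currently stores

print("message_start usage:", start_usage)
print("message_delta usage:", delta_usage)
```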