dspy
[Bug] Track Usage not working w/ FastAPI+Bedrock+dspy.Streamify
What happened?
When using dspy.settings.configure(track_usage=True) with an AWS Bedrock model in a FastAPI application that streams the answer, I get the following error:
Exception in thread Thread-1 (producer):
+ Exception Group Traceback (most recent call last):
| File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
| self.run()
| File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/threading.py", line 1012, in run
| self._target(*self._args, **self._kwargs)
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/streaming/streamify.py", line 243, in producer
| context.run(asyncio.run, runner())
| File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/runners.py", line 194, in run
| return runner.run(main)
| ^^^^^^^^^^^^^^^^
| File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/asyncio/runners.py", line 118, in run
| return self._loop.run_until_complete(task)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/streaming/streamify.py", line 237, in runner
| async for item in async_generator:
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/streaming/streamify.py", line 177, in async_streamer
| async with create_task_group() as tg, send_stream, receive_stream:
| ^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/anyio/_backends/_asyncio.py", line 772, in __aexit__
| raise BaseExceptionGroup(
| ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
+-+---------------- 1 ----------------
| Traceback (most recent call last):
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/streaming/streamify.py", line 171, in generator
| prediction = await program(*args, **kwargs)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/utils/asyncify.py", line 63, in async_program
| return await call_async(*args, **kwargs)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/asyncer/_main.py", line 382, in wrapper
| return await run_sync(
| ^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/asyncer/_compat.py", line 24, in run_sync
| return await anyio.to_thread.run_sync(
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/anyio/to_thread.py", line 56, in run_sync
| return await get_async_backend().run_sync_in_worker_thread(
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/anyio/_backends/_asyncio.py", line 2470, in run_sync_in_worker_thread
| return await future
| ^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/anyio/_backends/_asyncio.py", line 967, in run
| result = context.run(func, *args)
| ^^^^^^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/utils/asyncify.py", line 57, in wrapped_program
| return program(*a, **kw)
| ^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/utils/callback.py", line 343, in sync_wrapper
| raise exception
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/utils/callback.py", line 339, in sync_wrapper
| results = fn(instance, *args, **kwargs)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/primitives/program.py", line 29, in __call__
| output.set_lm_usage(usage_tracker.get_total_tokens())
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/utils/usage_tracker.py", line 59, in get_total_tokens
| total_usage = self._merge_usage_entries(total_usage, usage_entry)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| File "/Users/cmeyer/code/dspy_poc/.venv/lib/python3.12/site-packages/dspy/utils/usage_tracker.py", line 45, in _merge_usage_entries
| result[k] += v if v else 0
| TypeError: unsupported operand type(s) for +=: 'dict' and 'int'
+------------------------------------
Steps to reproduce
Here is a minimal example that reproduces the error:
```python
from dataclasses import asdict

import dspy
import ujson
from dspy.streaming import StatusMessageProvider, StreamListener, StreamResponse, StatusMessage
from fastapi.responses import StreamingResponse
from fastapi import FastAPI, HTTPException


class MyStatusMessageProvider(StatusMessageProvider):
    def tool_start_status_message(self, instance, inputs):
        return f"`Pythia is using tool {instance.name}...`"

    def tool_end_status_message(self, outputs):
        return ""

    def lm_start_status_message(self, instance, inputs):
        return "`Pythia is thinking...`"


def evaluate_math(expression: str):
    return dspy.PythonInterpreter({}).execute(expression)


def search_wikipedia(query: str):
    results = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(query, k=3)
    return [x['text'] for x in results]


def streaming_response(streamer):
    """
    Custom streaming response function that handles StatusMessage objects.
    This implements the same functionality as dspy_streaming_response but adds handling for StatusMessage objects.
    """
    from dspy.primitives.prediction import Prediction
    import litellm

    for value in streamer:
        if isinstance(value, StatusMessage):
            data = {"status": value.message}
            yield f"data: {ujson.dumps(data)}\n\n"
        elif isinstance(value, Prediction):
            data = {"prediction": {k: v for k, v in value.items(include_dspy=False)}}
            yield f"data: {ujson.dumps(data)}\n\n"
        elif isinstance(value, litellm.ModelResponseStream):
            data = {"chunk": value.json()}
            yield f"data: {ujson.dumps(data)}\n\n"
        elif isinstance(value, StreamResponse):
            data = {"chunk": {k: v for k, v in asdict(value).items()}}
            yield f"data: {ujson.dumps(data)}\n\n"
        elif isinstance(value, str) and value.startswith("data:"):
            yield value
        else:
            data = {"unknown": str(value)}
            yield f"data: {ujson.dumps(data)}\n\n"
    yield "data: [DONE]\n\n"


app = FastAPI(
    title="DSPy Program API",
    description="A simple API serving a DSPy Chain of Thought program",
    version="1.0.0"
)

# Configure DSPy with LiteLLM
dspy.configure(
    lm=dspy.LM(
        provider="litellm",
        model="bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0",
        aws_region_name="us-east-1",
        cache=False,
    )
)
dspy.settings.configure(track_usage=True)

react = dspy.ReAct("question -> answer: str", tools=[evaluate_math, search_wikipedia])

dspy_streamer = dspy.streamify(
    react,
    status_message_provider=MyStatusMessageProvider(),
    stream_listeners=[
        StreamListener(signature_field_name="reasoning"),
        StreamListener(signature_field_name="answer"),
    ],
    async_streaming=False,
)


@app.post("/predict")
def predict():
    stream = dspy_streamer(question="What is 9362158 divided by the year of birth of David Gregory of Kinnairdy castle?")
    return StreamingResponse(streaming_response(stream), media_type="text/event-stream")
```
I may do some more testing with other providers or without streaming.
Results of further testing:
- ❌ FastAPI + Bedrock + dspy.streamify
- ✅ FastAPI + OpenAI + dspy.streamify
- ✅ FastAPI + Bedrock (without streaming)
- ✅ Bedrock + dspy.streamify (simple script without API)
DSPy version
2.6.24
MORE INFO:
This issue does NOT occur when dspy.streamify is not used.
This is working:
```python
@app.post("/predict")
def predict():
    # stream = dspy_streamer(question="What is 9362158 divided by the year of birth of David Gregory of Kinnairdy castle?")
    # return StreamingResponse(streaming_response(stream), media_type="text/event-stream")
    result = react(question="What is 9362158 divided by the year of birth of David Gregory of Kinnairdy castle?")
    return {
        "status": "success",
        "data": result.toDict(),
        "usage": result.get_lm_usage(),
    }
```
MORE INFO:
The issue still occurs when I convert the FastAPI endpoint to an async endpoint with streaming, so it is not related to FastAPI's sync/async handling:
```python
dspy_streamer = dspy.asyncify(dspy_streamer)


@app.post("/predict")
async def predict():
    stream = await dspy_streamer(question="What is 9362158 divided by the year of birth of David Gregory of Kinnairdy castle?")
    return StreamingResponse(streaming_response(stream), media_type="text/event-stream")
```
This results in the same error.
MORE INFO:
This issue does NOT occur when using the OpenAI provider, i.e. replacing the LM with:
```python
dspy.configure(
    lm=dspy.LM(
        "openai/gpt-4o-mini",
        api_key="xxxxx",
        cache=False,
    )
)
```
To sum up:
- ❌ FastAPI + Bedrock + dspy.streamify
- ✅ FastAPI + OpenAI + dspy.streamify
- ✅ FastAPI + Bedrock (without streaming)
- ✅ Bedrock + dspy.streamify (simple script without API)
AI-generated monkey patch that solved my issue:
```python
import types
from typing import Any

from dspy.utils.usage_tracker import UsageTracker


def fixed_merge_usage_entries(self, usage_entry1, usage_entry2) -> dict[str, dict[str, Any]]:
    # Empty or missing entries: return a copy of the other one (or {} if both are missing)
    if usage_entry1 is None or len(usage_entry1) == 0:
        return dict(usage_entry2 or {})
    if usage_entry2 is None or len(usage_entry2) == 0:
        return dict(usage_entry1)

    result = dict(usage_entry2)
    for k, v in usage_entry1.items():
        if k in result:
            if isinstance(v, dict):
                if isinstance(result[k], dict):
                    # Both sides are nested dicts: merge recursively
                    result[k] = self._merge_usage_entries(result[k], v)
                else:
                    # If v is a dict but result[k] is not, replace result[k] with v
                    result[k] = dict(v)
            else:
                if isinstance(result[k], dict):
                    # If v is not a dict but result[k] is, keep result[k] as is
                    continue
                else:
                    # Both are scalar values (possibly None): treat None as 0 and add them
                    result[k] = result[k] or 0
                    v_value = v if v is not None else 0
                    result[k] += v_value
        else:
            # Key doesn't exist in result, just copy it over
            result[k] = v
    return result


UsageTracker._merge_usage_entries = types.MethodType(fixed_merge_usage_entries, UsageTracker)
```
@lambda-science Thanks for the detailed issue report! It looks like Bedrock can return None for this value, while most other providers set a default value.
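The failure can be reproduced without DSPy or Bedrock. The sketch below is a hypothetical stand-in for the merge step, modeled only on the line `result[k] += v if v else 0` from the traceback (it is not DSPy's actual implementation), next to a None-tolerant variant in the spirit of the monkey patch above. The usage entries and the `prompt_tokens_details` key are illustrative assumptions about what one call with a None value and another with a nested dict might look like.

```python
from typing import Any, Optional


def naive_merge(a: Optional[dict], b: Optional[dict]) -> dict:
    """Hypothetical merge mirroring `result[k] += v if v else 0` from the traceback."""
    if not a:
        return dict(b or {})
    if not b:
        return dict(a)
    result = dict(b)
    for k, v in a.items():
        if k not in result:
            result[k] = v
        elif isinstance(v, dict) and isinstance(result[k], dict):
            result[k] = naive_merge(result[k], v)
        else:
            # If result[k] is a dict and v is None, this becomes `dict += 0` -> TypeError
            result[k] += v if v else 0
    return result


def tolerant_merge(a: Optional[dict], b: Optional[dict]) -> dict[str, Any]:
    """None-tolerant merge: None counts as 0, and a nested dict wins over a scalar/None."""
    if not a:
        return dict(b or {})
    if not b:
        return dict(a)
    result = dict(b)
    for k, v in a.items():
        cur = result.get(k)
        if isinstance(v, dict) and isinstance(cur, dict):
            result[k] = tolerant_merge(cur, v)  # merge nested details recursively
        elif isinstance(v, dict):
            result[k] = dict(v)  # only `a` has a nested dict
        elif isinstance(cur, dict):
            continue  # only `b` has a nested dict; keep it
        else:
            result[k] = (cur or 0) + (v or 0)  # scalars, with None treated as 0
    return result


# Illustrative entries (assumed shapes): one call reports None for the nested details
bedrock_like = {"prompt_tokens": 10, "completion_tokens": 5, "prompt_tokens_details": None}
other = {"prompt_tokens": 20, "completion_tokens": 7, "prompt_tokens_details": {"cached_tokens": 0}}

try:
    naive_merge(bedrock_like, other)
except TypeError as exc:
    print("naive_merge failed:", exc)  # unsupported operand type(s) for +=: 'dict' and 'int'

print(tolerant_merge(bedrock_like, other))
```

The tolerant version keeps the nested dict when the other side is None instead of trying to add an integer to it, which is the same choice the monkey patch makes.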
@lambda-science Is this fully working for you now? I have a DSPy application served by FastAPI, and I get full usage when using sync, non-streaming modules. However, if I use dspy.streamify() on my app within an async context, I only get completion_tokens and no prompt_tokens, and therefore total_tokens equals completion_tokens.
I've walked through the LiteLLM code in the debugger and can see that all usage fields are being set, but at some point in the chunk iterator they seem to get stripped. The LiteLLM code is too dense and indirect for me to follow exactly where that happens. Interestingly, while I was stepping through the LiteLLM code, all of the usage fields looked right and total_tokens was the sum of prompt and completion tokens, which suggests it is being recomputed at some point.
This happens with both Nova and Claude models on Bedrock.
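One hypothesis consistent with this symptom (not confirmed anywhere in this thread) is that usage is spread across several streamed chunks, for example prompt tokens reported in an early chunk and completion tokens in a later one, and only the last chunk's usage survives before `total_tokens` is recomputed. The sketch below uses made-up chunk dicts, not real LiteLLM chunk objects, and hypothetical helper names to contrast a "last chunk wins" reduction with a field-wise one.

```python
from typing import Optional

# Made-up usage payloads from three streamed chunks (not real LiteLLM objects)
chunk_usages: list[Optional[dict]] = [
    {"prompt_tokens": 120, "completion_tokens": 0},
    None,  # intermediate content chunk without usage
    {"prompt_tokens": None, "completion_tokens": 45},
]


def last_chunk_wins(usages: list[Optional[dict]]) -> dict:
    """Hypothetical: keep only the most recent usage payload, then recompute the total."""
    last = next((u for u in reversed(usages) if u), {})
    prompt = last.get("prompt_tokens") or 0
    completion = last.get("completion_tokens") or 0
    return {"prompt_tokens": prompt, "completion_tokens": completion,
            "total_tokens": prompt + completion}


def field_wise(usages: list[Optional[dict]]) -> dict:
    """Hypothetical: for each field, keep the largest non-None value seen in any chunk."""
    merged: dict = {}
    for usage in usages:
        for key, value in (usage or {}).items():
            if value is not None:
                merged[key] = max(merged.get(key, 0), value)
    prompt = merged.get("prompt_tokens", 0)
    completion = merged.get("completion_tokens", 0)
    return {"prompt_tokens": prompt, "completion_tokens": completion,
            "total_tokens": prompt + completion}


print(last_chunk_wins(chunk_usages))  # prompt_tokens lost, total_tokens == completion_tokens
print(field_wise(chunk_usages))
```

Taking the maximum assumes each chunk reports cumulative counts; if chunks reported increments instead, summing the non-None values would be the right reduction.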
I haven't removed my monkey patch yet with the new version; I'll try later.
I tried adding your monkey patch and I still don't see the missing tokens. I'm interested to see whether it works for you. If it does, then there may have been a regression when promoting to v3b.
I've noticed that the fields mentioned above are missing by the time any of the UsageTracker methods, such as _merge_usage_entries, are called.
Lol, you're right. I just checked and I now get 0 prompt tokens, but it was working with my monkey patch on v2.7.
@chenmoneygithub, I am using version 3.0.3 of DSPy and have the same issue when using a model from Azure AI Studio with streamify.