agenta icon indicating copy to clipboard operation
agenta copied to clipboard

[AGE-147] [Bug] Nested tracing and token usage bug

Open mmabrouk opened this issue 1 year ago • 5 comments

The following code results in a wrong trace

import agenta
import agenta as ag
import litellm
import asyncio
from supported_models import get_all_supported_llm_models
litellm.drop_params = True

# Set up agenta, then obtain the tracing object that llm_call uses to attach
# attributes to its span.
ag.init()
tracing = ag.llm_tracing()

# Default prompt templates. Placeholders are filled with str.format() in
# generate(): {text} per input text, {concatenated_texts} for the synthesis.
prompts = {
    "summarization_prompt": """Summarize the main points from this text: {text}?""",
    "synthesis_prompt": "Create one paragraph synthesis from the following summaries from different texts: {concatenated_texts}",
}

# Models for which generate() enables the JSON response format (when forced)
# and sends the frequency/presence penalty parameters.
GPT_FORMAT_RESPONSE = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"]


# Default configuration values. The `_1` parameters drive the per-text
# summarization calls and the `_2` parameters drive the final synthesis call.
# A max_tokens value of -1 is translated to None by generate().
ag.config.default(
    temperature_1=ag.FloatParam(default=1, minval=0.0, maxval=2.0),
    model_1=ag.GroupedMultipleChoiceParam(
        default="gpt-3.5-turbo", choices=get_all_supported_llm_models()),
    max_tokens_1=ag.IntParam(-1, -1, 4000),
    summarization_prompt=ag.TextParam(prompts["summarization_prompt"]),
    top_p_1=ag.FloatParam(1),
    frequence_penalty_1=ag.FloatParam(default=0.0, minval=-2.0, maxval=2.0),
    presence_penalty_1=ag.FloatParam(default=0.0, minval=-2.0, maxval=2.0),
    force_json_1=ag.BinaryParam(False),
    temperature_2=ag.FloatParam(default=1, minval=0.0, maxval=2.0),
    model_2=ag.GroupedMultipleChoiceParam(
        default="gpt-3.5-turbo", choices=get_all_supported_llm_models()),
    max_tokens_2=ag.IntParam(-1, -1, 4000),
    synthesis_prompt=ag.TextParam(prompts["synthesis_prompt"]),
    top_p_2=ag.FloatParam(1),
    frequence_penalty_2=ag.FloatParam(default=0.0, minval=-2.0, maxval=2.0),
    presence_penalty_2=ag.FloatParam(default=0.0, minval=-2.0, maxval=2.0),
    force_json_2=ag.BinaryParam(False)
)


@ag.span(type="llm")
async def llm_call(completion_params: dict):
    """Run one chat completion through litellm and report its usage and cost.

    Args:
        completion_params: Keyword arguments forwarded unchanged to
            ``litellm.acompletion``. Must contain ``"model"`` and
            ``"temperature"``, which are also recorded on the current span.

    Returns:
        A dict with the first choice's text under ``"message"``, the token
        usage dict under ``"usage"``, and the result of
        ``ag.calculate_token_usage`` under ``"cost"``.
    """
    completion = await litellm.acompletion(**completion_params)

    # Record which model/temperature produced this span.
    span_model_config = {
        "model": completion_params["model"],
        "temperature": completion_params["temperature"],
    }
    tracing.set_span_attribute("model_config", span_model_config)

    usage = completion.usage.dict()
    result = {"message": completion.choices[0].message.content}
    result["usage"] = usage
    result["cost"] = ag.calculate_token_usage(completion_params["model"], usage)
    return result


def _build_completion_params(
    model,
    prompt,
    temperature,
    max_tokens,
    top_p,
    force_json,
    frequency_penalty,
    presence_penalty,
):
    """Assemble the keyword arguments for a single ``llm_call`` invocation."""
    in_format_list = model in GPT_FORMAT_RESPONSE
    params = {
        "model": model,
        "messages": [
            {"content": prompt, "role": "user"},
        ],
        "temperature": temperature,
        # A configured value of -1 is passed on as None.
        "max_tokens": max_tokens if max_tokens != -1 else None,
        "top_p": top_p,
        "response_format": (
            {"type": "json_object"}
            if force_json and in_format_list
            else {"type": "text"}
        ),
    }
    # Include frequency_penalty and presence_penalty only if supported
    if in_format_list:
        params["frequency_penalty"] = frequency_penalty
        params["presence_penalty"] = presence_penalty
    return params


def _merge_usage(total, usage):
    """Add every counter in ``usage`` into ``total`` in place."""
    for key, value in usage.items():
        if key in total:
            total[key] += value
        else:
            total[key] = value


# NOTE(review): default_keys holds the single string "text1, text2"; this looks
# like a typo for ["text1", "text2"]. Left unchanged because it is part of the
# entrypoint's signature default -- confirm and fix at the source.
@ag.entrypoint
async def generate(
    inputs: ag.DictInput = ag.DictInput(default_keys=["text1, text2"]),
):
    """Summarize every input text concurrently, then synthesize the summaries.

    Each value in ``inputs`` is summarized with the ``*_1`` configuration in
    its own asyncio task; the summaries are joined with ``" === "`` and passed
    to one more ``llm_call`` using the ``*_2`` configuration.

    Args:
        inputs: Mapping whose values are the texts to summarize.

    Returns:
        The dict returned by ``llm_call`` for the synthesis step (keys
        ``"message"``, ``"usage"``, ``"cost"``).

    Raises:
        ValueError: If any input text is empty. Raised before any LLM call is
            scheduled, so no task is left running in the background.
    """
    texts = list(inputs.values())
    if any(not text for text in texts):
        raise ValueError("Input text is empty")

    # Build every request up front so that a failure here (validation above or
    # prompt formatting) cannot orphan tasks that were already scheduled.
    summarization_requests = [
        _build_completion_params(
            model=ag.config.model_1,  # was ag.config.model, which is not declared
            prompt=ag.config.summarization_prompt.format(text=text),
            temperature=ag.config.temperature_1,
            max_tokens=ag.config.max_tokens_1,
            top_p=ag.config.top_p_1,
            force_json=ag.config.force_json_1,
            frequency_penalty=ag.config.frequence_penalty_1,
            presence_penalty=ag.config.presence_penalty_1,
        )
        for text in texts
    ]

    tasks = [
        asyncio.create_task(llm_call(request))
        for request in summarization_requests
    ]
    responses = await asyncio.gather(*tasks)
    print(responses)

    total_cost = sum(response['cost'] for response in responses)
    total_usage = {}
    for response in responses:
        _merge_usage(total_usage, response['usage'])

    concatenated_messages = " === ".join(
        response['message'] for response in responses
    )
    synthesis_completion_params = _build_completion_params(
        model=ag.config.model_2,
        prompt=ag.config.synthesis_prompt.format(
            concatenated_texts=concatenated_messages),
        temperature=ag.config.temperature_2,
        max_tokens=ag.config.max_tokens_2,
        top_p=ag.config.top_p_2,
        force_json=ag.config.force_json_2,
        frequency_penalty=ag.config.frequence_penalty_2,
        presence_penalty=ag.config.presence_penalty_2,
    )

    synthesis_response = await llm_call(synthesis_completion_params)

    # Aggregates across all three calls. They are computed but deliberately not
    # returned: only the synthesis response is handed back to the caller.
    final_cost = total_cost + synthesis_response['cost']
    final_usage = total_usage.copy()
    _merge_usage(final_usage, synthesis_response['usage'])
    # Now you have total_cost, total_usage, and concatenated_messages
    # You can return or use these values as needed
    return synthesis_response

Screenshot 2024-04-26 at 16.04.34.png

Note that the spans are nested for some reason, and that the trace shows a usage of 293 tokens although no usage is returned.

In addition, no model config is shown for the second span.

Here is the output from the backend:

[{"id":"662bb0a15bdb87e28cb758fb","name":"llm_call","parent_span_id":"662bb0a15bdb87e28cb758fa","created_at":"2024-04-26T13:48:21.435000","variant":{"variant_id":"662badf78b5b8d76b714b7de","variant_name":"app.default","revision":null},"environment":"playground","spankind":"LLM","status":"OK","metadata":{"cost":0.000773,"latency":3.456,"usage":{"prompt_tokens":374,"completion_tokens":106,"total_tokens":480}},"trace_id":"662bb0a15bdb87e28cb758f9","user_id":"—","children":[{"id":"662bb0a15bdb87e28cb758fc","name":"llm_call","parent_span_id":"662bb0a15bdb87e28cb758fb","created_at":"2024-04-26T13:48:20.443000","variant":{"variant_id":"662badf78b5b8d76b714b7de","variant_name":"app.default","revision":null},"environment":"playground","spankind":"LLM","status":"OK","metadata":{"cost":0.000751,"latency":2.463,"usage":{"prompt_tokens":394,"completion_tokens":80,"total_tokens":474}},"trace_id":"662bb0a15bdb87e28cb758f9","user_id":"—","children":[{"id":"662bb0a55bdb87e28cb758fd","name":"llm_call","parent_span_id":"662bb0a15bdb87e28cb758fc","created_at":"2024-04-26T13:48:24.024000","variant":{"variant_id":"662badf78b5b8d76b714b7de","variant_name":"app.default","revision":null},"environment":"playground","spankind":"LLM","status":"OK","metadata":{"cost":0.000483,"latency":2.588,"usage":{"prompt_tokens":206,"completion_tokens":87,"total_tokens":293}},"trace_id":"662bb0a15bdb87e28cb758f9","user_id":"—","children":null}]}]}]

From SyncLinear.com | AGE-147

mmabrouk avatar Apr 26 '24 14:04 mmabrouk

That the Trace has a usage of 293 although no usage is returned..

I can't reproduce this, @mmabrouk. For me, I get the usage (total, completion and prompt tokens):

Image

aybruhm avatar Apr 29 '24 13:04 aybruhm

@aybruhm and that is wrong, no?

  1. The two last spans are nested while nothing is nested in the code
  2. Where do the total tokens, completion tokens and prompt tokens for the trace come from? They seem to be the same as the ones in the last span

mmabrouk avatar Apr 29 '24 15:04 mmabrouk

@aybruhm: The issue is a complex race condition

mmabrouk avatar Apr 30 '24 07:04 mmabrouk

@israelvictory87 this has been fixed in the litellm PR, no?

mmabrouk avatar May 12 '24 16:05 mmabrouk

@israelvictory87 this has been fixed in the litellm PR, no?

Correct.

aybruhm avatar May 12 '24 18:05 aybruhm