
Running background tasks

Open Sara2823 opened this issue 1 year ago • 2 comments

How to add a FastAPI BackgroundTask to a LangServe endpoint?

We need to save chat history to a database after each request, without blocking the response to the user while the messages are saved.

This is how we add the endpoint to the app:

add_routes(
    app,
    agent,
    path="/agent",
    per_req_config_modifier=per_request_config_modifier,
    config_keys=["configurable", "metadata"]
)

Thanks

Sara2823 avatar Jun 27 '24 09:06 Sara2823

Hi @Sara2823, you can use the APIHandler directly if you need fine-grained control over the endpoint definition (e.g., to expose background tasks).

eyurtsev avatar Jun 27 '24 13:06 eyurtsev

Thanks @eyurtsev. This code works, but I had to copy the whole streaming function into the route so I could put the insertion function (insert_message) in the middle of the process. Is there a better solution I'm missing?

# Imports needed by this snippet. Note that StreamLogParameters and
# _with_validation_error_translation are private langserve internals and
# may change between releases.
import json

from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from langchain_core.tracers.log_stream import RunLogPatch
from langserve.api_handler import (
    APIHandler,
    StreamLogParameters,
    _with_validation_error_translation,
)
from sse_starlette import EventSourceResponse

api_handler = APIHandler(
    config_agent, path="/agent", per_req_config_modifier=per_request_config_modifier
)
@app.post("/agent/stream_log")
async def stream_endpoint(request: Request, background_tasks: BackgroundTasks):
    # Get user input and configuration
    config, input_ = await api_handler._get_config_and_input(
        request,
        "",
        endpoint="stream_log",
        server_config=None,
    )
    try:
        body = await request.json()
        with _with_validation_error_translation():
            stream_log_request = StreamLogParameters(**body)
    except json.JSONDecodeError:
        raise RequestValidationError(errors=["Invalid JSON body"])

    # Streaming loop
    async def _stream():
        try:
            async for chunk in api_handler._runnable.astream_log(
                input_,
                config=config,
                diff=True,
                include_names=stream_log_request.include_names,
                include_types=stream_log_request.include_types,
                include_tags=stream_log_request.include_tags,
                exclude_names=stream_log_request.exclude_names,
                exclude_types=stream_log_request.exclude_types,
                exclude_tags=stream_log_request.exclude_tags,
            ):
                if not isinstance(chunk, RunLogPatch):
                    raise AssertionError(
                        f"Expected a RunLog instance got {type(chunk)}"
                    )
                if (
                    api_handler._names_in_stream_allow_list is None
                    or api_handler._runnable.config.get("run_name")
                    in api_handler._names_in_stream_allow_list
                ):
                    data = {
                        "ops": chunk.ops,
                    }
                    yield {
                        "data": api_handler._serializer.dumps(data).decode("utf-8"),
                        "event": "data",
                    }
            # The last RunLogPatch holds the final agent output; schedule the
            # DB write so it runs after the response finishes streaming.
            # (Assumes the stream produced at least one chunk, so `chunk` is bound.)
            final_answer = chunk.ops[0]["value"]["output"].return_values["output"]
            background_tasks.add_task(
                insert_message,
                config["configurable"]["user_id"],
                "conversational",
                input_["input"],
                final_answer,
            )
        except BaseException:
            yield {
                "event": "error",
                "data": json.dumps(
                    {"status_code": 500, "message": "Internal Server Error"}
                ),
            }
            raise
    return EventSourceResponse(_stream())

Sara2823 avatar Jul 07 '24 12:07 Sara2823