Chat resume fails when using custom S3DataLayer

Open datason opened this issue 1 year ago • 3 comments

Describe the bug

I have implemented a custom S3DataLayer to support chat history for multiple users. However, the "Resume Chat" feature is not working as expected: when I click "Resume Chat," nothing happens (according to the debugger, the on_chat_resume function is never invoked). To illustrate the issue, I will attach a simple counter chat example. It responds to every message with a count of how many messages the user has sent so far.

To Reproduce

  1. Configure an AWS account with read/write access to the S3 service.
  2. Ensure your laptop is set up with AWS credentials.
  3. Create at least 2 users with at least 1 chat each.
  4. Try to resume the chat in the next session or even within the same session.
  5. All previous messages load, but you can't add new messages.

Expected behavior

The ability to write new messages in the resumed chat.

Desktop (please complete the following information):

  • OS: Windows
  • Browser: Brave
  • Version: 1.67.119

Additional context

S3 file structure: {bucket_key}/users/{user_id}/threads/{thread_id}/steps/{step_id}.json

Main script

import chainlit as cl
from chainlit.types import ThreadDict
from s3_datalayer import S3DataLayer
from auth_utils import *


cl.data._data_layer = S3DataLayer(
    bucket_name='YOUR_BUCKET_NAME',
    bucket_key='chainlit',
    user_thread_limit=100,
)


@cl.password_auth_callback
def auth_callback(username: str, password: str):
    if (username, password) == ("admin", "admin"):
        return cl.User(identifier="admin", metadata={"role": "admin", "provider": "credentials"})
    elif (username, password) == ("user", "user"):
        return cl.User(identifier="user", metadata={"role": "user", "provider": "credentials"})
    else:
        return None


@cl.on_message
async def on_message(message: cl.Message):
    counter = cl.user_session.get("counter", 0)
    counter += 1
    cl.user_session.set("counter", counter)

    app_user = cl.user_session.get("user")

    await cl.Message(content=f"Dear {app_user.identifier}, you sent {counter} message(s)!").send()


@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    if "metadata" in thread:
        await cl.Message(thread["metadata"], author="metadata", language="json").send()
    if "tags" in thread:
        await cl.Message(thread["tags"], author="tags", language="json").send()

s3_datalayer.py

import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import boto3
from chainlit.context import context
from chainlit.data import BaseDataLayer, BaseStorageClient, queue_until_user_message
from chainlit.step import StepDict
from chainlit.types import (
    PageInfo,
    PaginatedResponse,
    Pagination,
    ThreadDict,
    ThreadFilter,
)
from chainlit.user import PersistedUser, User
from parse_utils import parse_user_from_string


class S3DataLayer(BaseDataLayer):
    def __init__(
        self,
        bucket_name: str,
        bucket_key: str,
        user_thread_limit: int = 10,
        storage_provider: Optional[BaseStorageClient] = None,
    ):
        region_name = os.environ.get("AWS_REGION", "us-east-1")
        self.client = boto3.client("s3", region_name=region_name)  # type: ignore
        self.bucket_name = bucket_name
        self.bucket_key = bucket_key
        self.user_thread_limit = user_thread_limit
        self.storage_provider = storage_provider


    def _get_current_timestamp(self) -> str:
        # Use UTC so the trailing "Z" suffix is accurate.
        return datetime.utcnow().isoformat() + "Z"

    async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
        key = f"{self.bucket_key}/users/{identifier}/identity.json"
        try:
            response = self.client.get_object(Bucket=self.bucket_name, Key=key)
            user_data = json.loads(response["Body"].read())
            return PersistedUser(
                id=user_data["id"],
                identifier=user_data["identifier"],
                createdAt=user_data["createdAt"],
                metadata=user_data["metadata"],
            )
        except self.client.exceptions.NoSuchKey:
            return None

    async def create_user(self, user: "User") -> Optional["PersistedUser"]:
        if str(user.identifier) == 'None':
            raise ValueError('Creation of `None` user')
        ts = self._get_current_timestamp()
        metadata: Dict[Any, Any] = user.metadata

        item = {
            "id": user.identifier,
            "identifier": user.identifier,
            "metadata": metadata,
            "createdAt": ts,
        }

        key = f"{self.bucket_key}/users/{user.identifier}/identity.json"
        self.client.put_object(
            Bucket=self.bucket_name,
            Key=key,
            Body=json.dumps(item),
        )

        return PersistedUser(
            id=user.identifier,
            identifier=user.identifier,
            createdAt=ts,
            metadata=metadata,
        )

    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        try:
            if context.session.thread_id != step_dict['threadId']:
                raise ValueError('The context session threadId is not equal to the threadId of step.')

            key = f"{self.bucket_key}/users/{context.session.user.identifier}/threads/{step_dict['threadId']}/steps/{step_dict['id']}.json"
            self.client.put_object(
                Bucket=self.bucket_name,
                Key=key,
                Body=json.dumps(dict(step_dict)),
            )

        except Exception:
            # Fallback: scan every user's keys to find where this thread lives.
            prefix = f"{self.bucket_key}/users/"
            paginator = self.client.get_paginator('list_objects_v2')

            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
                for obj in page.get('Contents', []):
                    key = obj['Key']

                    if f"threads/{step_dict['threadId']}" in key:
                        user = parse_user_from_string(key)
                        if user is not None:
                            key = f"{self.bucket_key}/users/{user}/threads/{step_dict['threadId']}/steps/{step_dict['id']}.json"
                            self.client.put_object(
                                Bucket=self.bucket_name,
                                Key=key,
                                Body=json.dumps(dict(step_dict)),
                            )


    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        try:
            if context.session.thread_id != step_dict['threadId']:
                raise ValueError('The context session threadId is not equal to the threadId of step.')

            key = f"{self.bucket_key}/users/{context.session.user.identifier}/threads/{step_dict['threadId']}/steps/{step_dict['id']}.json"
            self.client.put_object(
                Bucket=self.bucket_name,
                Key=key,
                Body=json.dumps(dict(step_dict)),
            )
        except Exception:
            # Fallback: scan every user's keys to find where this thread lives.
            prefix = f"{self.bucket_key}/users/"
            paginator = self.client.get_paginator('list_objects_v2')

            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
                for obj in page.get('Contents', []):
                    key = obj['Key']

                    if f"threads/{step_dict['threadId']}/steps/{step_dict['id']}.json" in key:
                        user = parse_user_from_string(key)
                        if user is not None:
                            # here I need something like user, threadId, stepId
                            key = f"{self.bucket_key}/users/{user}/threads/{step_dict['threadId']}/steps/{step_dict['id']}.json"
                            self.client.put_object(
                                Bucket=self.bucket_name,
                                Key=key,
                                Body=json.dumps(dict(step_dict)),
                            )


    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        thread_id = context.session.thread_id
        user_id = context.session.user.identifier

        key = f"{self.bucket_key}/users/{user_id}/threads/{thread_id}/steps/{step_id}.json"
        self.client.delete_object(
            Bucket=self.bucket_name,
            Key=key)


    async def get_thread_author(self, thread_id: str) -> str:
        prefix = f"{self.bucket_key}/users/"
        paginator = self.client.get_paginator('list_objects_v2')

        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            for obj in page.get('Contents', []):
                key = obj['Key']

                if f'threads/{thread_id}/thread.json' in key:
                    response = self.client.get_object(Bucket=self.bucket_name, Key=key)
                    thread_data = json.loads(response["Body"].read())
                    return thread_data["userId"]


    async def delete_thread(self, thread_id: str):
        prefix = f"{self.bucket_key}/users/"
        paginator = self.client.get_paginator('list_objects_v2')

        objects_to_delete = []
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            for obj in page.get('Contents', []):
                key = obj['Key']
                if f'threads/{thread_id}' in key:
                    objects_to_delete.append({'Key': key})

        # Perform the batch delete
        delete_response = self.client.delete_objects(Bucket=self.bucket_name, Delete={'Objects': objects_to_delete})
        deleted = delete_response.get('Deleted', [])
        for obj in deleted:
            print(f"Deleted {obj['Key']}")


    async def list_threads(self, pagination: "Pagination", filters: "ThreadFilter") -> "PaginatedResponse[ThreadDict]":
        paginated_response = PaginatedResponse(
            data=[],
            pageInfo=PageInfo(
                hasNextPage=False, startCursor=pagination.cursor, endCursor=None
            ),
        )

        prefix = f"{self.bucket_key}/users/{filters.userId}/threads/"
        paginator = self.client.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            for obj in page.get('Contents', []):
                key = obj['Key']
                if 'steps' not in key and 'elements' not in key and key.endswith('thread.json'):
                    response = self.client.get_object(Bucket=self.bucket_name, Key=key)
                    thread_data = json.loads(response["Body"].read())
                    thread = ThreadDict(
                        id=thread_data["id"],
                        createdAt=thread_data["createdAt"],
                        name=thread_data["name"],
                    )
                    paginated_response.data.append(thread)

        return paginated_response


    async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
        # get user
        try:
            user = context.session.user.identifier
        except Exception:
            user = None  # resolved below by scanning keys
            prefix = f"{self.bucket_key}/users/"
            paginator = self.client.get_paginator('list_objects_v2')

            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
                for obj in page.get('Contents', []):
                    key = obj['Key']

                    if f'threads/{thread_id}/thread.json' in key:
                        user = parse_user_from_string(key)
        # get thread
        if user is not None:
            prefix = f"{self.bucket_key}/users/{user}/threads/{thread_id}/"
            response = self.client.get_object(Bucket=self.bucket_name, Key=prefix+"thread.json")
            thread_dict = json.loads(response["Body"].read())

            thread = ThreadDict(
                id=thread_dict["id"],
                createdAt=thread_dict["createdAt"],
                name=thread_dict.get("name"),
                userId=thread_dict.get("userId"),
                userIdentifier=thread_dict.get("userIdentifier"),
                tags=thread_dict.get("tags"),
                metadata=thread_dict.get("metadata"),
                steps=[],
                elements=[]
            )

            # collect steps
            steps = []
            paginator = self.client.get_paginator('list_objects_v2')
            pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix + "steps/")

            for page in pages:
                for obj in page.get('Contents', []):
                    key = obj['Key']
                    response = self.client.get_object(Bucket=self.bucket_name, Key=key)
                    item = json.loads(response["Body"].read())
                    steps.append(item)
            if len(steps) > 0:
                steps.sort(key=lambda i: i["createdAt"])
                thread["steps"] = steps

            # collect elements
            elements = []
            paginator = self.client.get_paginator('list_objects_v2')
            pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix + "elements/")

            for page in pages:
                for obj in page.get('Contents', []):
                    key = obj['Key']
                    response = self.client.get_object(Bucket=self.bucket_name, Key=key)
                    item = json.loads(response["Body"].read())
                    elements.append(item)

            thread["elements"] = elements

            return thread


    async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None):
        if user_id is not None and name is not None:

            ts = self._get_current_timestamp()
            item = {
                "id": thread_id,
                "createdAt": ts,
                "name": name,
                "userId": user_id,
                "tags": tags,
                "metadata": metadata,
            }

            key = f"{self.bucket_key}/users/{user_id}/threads/{thread_id}/thread.json"
            self.client.put_object(
                Bucket=self.bucket_name,
                Key=key,
                Body=json.dumps(item),
            )


    async def delete_user_session(self, id: str) -> bool:
        return True


    async def build_debug_url(self) -> str:
        return ""

datason commented on Jul 02 '24

It's interesting that the entire history loads well, but the on_chat_resume decorator never triggers. This is just astonishing. Perhaps the problem lies solely with it.

datason commented on Jul 02 '24

My best guess, just based on skimming through the code provided, is that Chainlit does not like your omission of the get_all_user_threads method. I believe this method is called on its own by Chainlit at times. I have noticed this function specifically getting called during the resume chat as well, which could lead to the issues you are describing.

Currently, your code isn't even using the user_thread_limit parameter you are providing the data layer; normally this parameter gets used in the get_all_user_threads method to determine how many chats to load for each user from their respective chat history. Would you be able to try and implement this function for your S3 Data Layer and see if that solves the issue?

Also, I'm not sure if this is related, but I think the delete_user_session function should just return False instead of True, based on the SQLAlchemy data layer.

AidanShipperley commented on Jul 05 '24

@datason were you able to find the cause of this? I am facing the same issue.

Dylan-Harden3 commented on Jul 16 '24

Hey did anyone find any workaround for this issue?

gajanandC commented on Jan 21 '25

@gajanandC It was a while ago, but my issue had to do with the userId/identifier fields; make sure the data layer always defines both for all the types, even though they are Optional.

Specifically here it expects the identifier to be defined while elsewhere it uses user_id.
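In other words, every persisted thread record should carry both identifier-style fields. A hedged sketch of what that could look like for the thread.json records in this issue (the helper name is made up; the field names follow Chainlit's ThreadDict):

```python
from datetime import datetime, timezone
from typing import Dict, List, Optional


def build_thread_record(
    thread_id: str,
    user: Dict[str, str],
    name: Optional[str] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict] = None,
) -> Dict:
    """Hypothetical helper: always persist both userId and userIdentifier,
    even though ThreadDict marks them Optional, so that every code path
    finds the field it reads."""
    return {
        "id": thread_id,
        "createdAt": datetime.now(timezone.utc).isoformat(),
        "name": name,
        "userId": user["id"],
        "userIdentifier": user["identifier"],
        "tags": tags or [],
        "metadata": metadata or {},
    }
```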

Dylan-Harden3 commented on Jan 21 '25

This issue is stale because it has been open for 14 days with no activity.

github-actions[bot] commented on Jul 17 '25

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] commented on Jul 25 '25