Chat resume fails when using custom S3DataLayer
**Describe the bug**
I have implemented a custom S3DataLayer to support chat history for multiple users. However, the "Resume Chat" feature is not working as expected: when I click "Resume Chat," nothing happens (the debugger shows that the on_chat_resume function is never invoked). To illustrate the issue, I am attaching a simple counter chat example that replies to every message with a count of how many messages the user has sent so far.
**To Reproduce**
- Configure an AWS account with read/write access to the S3 service.
- Set up AWS credentials on your machine.
- Create two users, each with at least one chat.
- Try to resume a chat in the next session, or even within the same session.
- The full message history loads, but you cannot send new messages.
**Expected behavior**
It should be possible to send new messages in the resumed chat.
**Desktop:**
- OS: Windows
- Browser: Brave
- Version: 1.67.119
**Additional context**
S3 file structure:

```
{bucket_key}/users/{user_id}/threads/{thread_id}/steps/{step_id}.json
```
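For example, with bucket_key='chainlit' and a user "admin", the resulting keys look like this (placeholders in angle brackets are illustrative; the elements/ prefix is used by get_thread below):

```
chainlit/users/admin/identity.json
chainlit/users/admin/threads/<thread_id>/thread.json
chainlit/users/admin/threads/<thread_id>/steps/<step_id>.json
chainlit/users/admin/threads/<thread_id>/elements/<element_id>.json
```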
Main script:

```python
import json

import chainlit as cl
from chainlit.types import ThreadDict

from auth_utils import *
from s3_datalayer import S3DataLayer

cl.data._data_layer = S3DataLayer(
    bucket_name='YOUR_BUCKET_NAME',
    bucket_key='chainlit',
    user_thread_limit=100,
)


@cl.password_auth_callback
def auth_callback(username: str, password: str):
    if (username, password) == ("admin", "admin"):
        return cl.User(identifier="admin", metadata={"role": "admin", "provider": "credentials"})
    elif (username, password) == ("user", "user"):
        return cl.User(identifier="user", metadata={"role": "user", "provider": "credentials"})
    else:
        return None


@cl.on_message
async def on_message(message: cl.Message):
    counter = cl.user_session.get("counter", 0)
    counter += 1
    cl.user_session.set("counter", counter)
    app_user = cl.user_session.get("user")
    await cl.Message(content=f"Dear {app_user.identifier}, you sent {counter} message(s)!").send()


@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    # Serialize the dict/list values before sending; message content is a string.
    if "metadata" in thread:
        await cl.Message(json.dumps(thread["metadata"]), author="metadata", language="json").send()
    if "tags" in thread:
        await cl.Message(json.dumps(thread["tags"]), author="tags", language="json").send()
```
s3_datalayer.py:

```python
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import boto3
from chainlit.context import context
from chainlit.data import BaseDataLayer, BaseStorageClient, queue_until_user_message
from chainlit.step import StepDict
from chainlit.types import (
    PageInfo,
    PaginatedResponse,
    Pagination,
    ThreadDict,
    ThreadFilter,
)
from chainlit.user import PersistedUser, User

from parse_utils import parse_user_from_string


class S3DataLayer(BaseDataLayer):
    def __init__(
        self,
        bucket_name: str,
        bucket_key: str,
        user_thread_limit: int = 10,
        storage_provider: Optional[BaseStorageClient] = None,
    ):
        region_name = os.environ.get("AWS_REGION", "us-east-1")
        self.client = boto3.client("s3", region_name=region_name)  # type: ignore
        self.bucket_name = bucket_name
        self.bucket_key = bucket_key
        self.user_thread_limit = user_thread_limit
        self.storage_provider = storage_provider

    def _get_current_timestamp(self) -> str:
        # Use UTC so the trailing "Z" (Zulu/UTC) suffix is actually correct.
        return datetime.utcnow().isoformat() + "Z"

    async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
        key = f"{self.bucket_key}/users/{identifier}/identity.json"
        try:
            response = self.client.get_object(Bucket=self.bucket_name, Key=key)
            user_data = json.loads(response["Body"].read())
            return PersistedUser(
                id=user_data["id"],
                identifier=user_data["identifier"],
                createdAt=user_data["createdAt"],
                metadata=user_data["metadata"],
            )
        except self.client.exceptions.NoSuchKey:
            return None

    async def create_user(self, user: "User") -> Optional["PersistedUser"]:
        if user.identifier is None or str(user.identifier) == "None":
            raise ValueError("Creation of `None` user")
        ts = self._get_current_timestamp()
        metadata: Dict[Any, Any] = user.metadata
        item = {
            "id": user.identifier,
            "identifier": user.identifier,
            "metadata": metadata,
            "createdAt": ts,
        }
        key = f"{self.bucket_key}/users/{user.identifier}/identity.json"
        self.client.put_object(
            Bucket=self.bucket_name,
            Key=key,
            Body=json.dumps(item),
        )
        return PersistedUser(
            id=user.identifier,
            identifier=user.identifier,
            createdAt=ts,
            metadata=metadata,
        )

    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        try:
            if context.session.thread_id != step_dict["threadId"]:
                raise ValueError("The context session threadId is not equal to the threadId of the step.")
            key = f"{self.bucket_key}/users/{context.session.user.identifier}/threads/{step_dict['threadId']}/steps/{step_dict['id']}.json"
            self.client.put_object(
                Bucket=self.bucket_name,
                Key=key,
                Body=json.dumps(dict(step_dict)),
            )
        except Exception:
            # Fallback: the step belongs to another session's thread, so scan
            # all users' prefixes to find which user owns this thread.
            prefix = f"{self.bucket_key}/users/"
            paginator = self.client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
                for obj in page.get("Contents", []):
                    key = obj["Key"]
                    if f"threads/{step_dict['threadId']}" in key:
                        user = parse_user_from_string(key)
                        if user is not None:
                            key = f"{self.bucket_key}/users/{user}/threads/{step_dict['threadId']}/steps/{step_dict['id']}.json"
                            self.client.put_object(
                                Bucket=self.bucket_name,
                                Key=key,
                                Body=json.dumps(dict(step_dict)),
                            )

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        try:
            if context.session.thread_id != step_dict["threadId"]:
                raise ValueError("The context session threadId is not equal to the threadId of the step.")
            key = f"{self.bucket_key}/users/{context.session.user.identifier}/threads/{step_dict['threadId']}/steps/{step_dict['id']}.json"
            self.client.put_object(
                Bucket=self.bucket_name,
                Key=key,
                Body=json.dumps(dict(step_dict)),
            )
        except Exception:
            # Fallback: recover the owning user from the existing step's key,
            # since the session context does not match this thread.
            prefix = f"{self.bucket_key}/users/"
            paginator = self.client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
                for obj in page.get("Contents", []):
                    key = obj["Key"]
                    if f"threads/{step_dict['threadId']}/steps/{step_dict['id']}.json" in key:
                        user = parse_user_from_string(key)
                        if user is not None:
                            key = f"{self.bucket_key}/users/{user}/threads/{step_dict['threadId']}/steps/{step_dict['id']}.json"
                            self.client.put_object(
                                Bucket=self.bucket_name,
                                Key=key,
                                Body=json.dumps(dict(step_dict)),
                            )

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        thread_id = context.session.thread_id
        user_id = context.session.user.identifier
        key = f"{self.bucket_key}/users/{user_id}/threads/{thread_id}/steps/{step_id}.json"
        self.client.delete_object(Bucket=self.bucket_name, Key=key)

    async def get_thread_author(self, thread_id: str) -> str:
        prefix = f"{self.bucket_key}/users/"
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if f"threads/{thread_id}/thread.json" in key:
                    response = self.client.get_object(Bucket=self.bucket_name, Key=key)
                    thread_data = json.loads(response["Body"].read())
                    return thread_data["userId"]
        # The declared return type is str, so fail loudly instead of
        # implicitly returning None when the thread is not found.
        raise ValueError(f"Author not found for thread {thread_id}")

    async def delete_thread(self, thread_id: str):
        prefix = f"{self.bucket_key}/users/"
        paginator = self.client.get_paginator("list_objects_v2")
        objects_to_delete = []
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if f"threads/{thread_id}" in key:
                    objects_to_delete.append({"Key": key})
        if not objects_to_delete:
            # delete_objects raises an error on an empty Objects list.
            return
        # Perform the batch delete.
        delete_response = self.client.delete_objects(
            Bucket=self.bucket_name, Delete={"Objects": objects_to_delete}
        )
        for obj in delete_response.get("Deleted", []):
            print(f"Deleted {obj['Key']}")

    async def list_threads(
        self, pagination: "Pagination", filters: "ThreadFilter"
    ) -> "PaginatedResponse[ThreadDict]":
        paginated_response = PaginatedResponse(
            data=[],
            pageInfo=PageInfo(
                hasNextPage=False, startCursor=pagination.cursor, endCursor=None
            ),
        )
        prefix = f"{self.bucket_key}/users/{filters.userId}/threads/"
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if "steps" not in key and "elements" not in key and key.endswith("thread.json"):
                    response = self.client.get_object(Bucket=self.bucket_name, Key=key)
                    thread_data = json.loads(response["Body"].read())
                    thread = ThreadDict(
                        id=thread_data["id"],
                        createdAt=thread_data["createdAt"],
                        name=thread_data["name"],
                    )
                    paginated_response.data.append(thread)
        return paginated_response

    async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
        # Work out which user owns the thread.
        user = None
        try:
            user = context.session.user.identifier
        except Exception:
            prefix = f"{self.bucket_key}/users/"
            paginator = self.client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
                for obj in page.get("Contents", []):
                    key = obj["Key"]
                    if f"threads/{thread_id}/thread.json" in key:
                        user = parse_user_from_string(key)
        # Load the thread itself.
        if user is not None:
            prefix = f"{self.bucket_key}/users/{user}/threads/{thread_id}/"
            response = self.client.get_object(Bucket=self.bucket_name, Key=prefix + "thread.json")
            thread_dict = json.loads(response["Body"].read())
            thread = ThreadDict(
                id=thread_dict["id"],
                createdAt=thread_dict["createdAt"],
                name=thread_dict.get("name"),
                userId=thread_dict.get("userId"),
                userIdentifier=thread_dict.get("userIdentifier"),
                tags=thread_dict.get("tags"),
                metadata=thread_dict.get("metadata"),
                steps=[],
                elements=[],
            )
            # Collect steps, sorted by creation time.
            steps = []
            paginator = self.client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix + "steps/"):
                for obj in page.get("Contents", []):
                    response = self.client.get_object(Bucket=self.bucket_name, Key=obj["Key"])
                    steps.append(json.loads(response["Body"].read()))
            if steps:
                steps.sort(key=lambda i: i["createdAt"])
            thread["steps"] = steps
            # Collect elements.
            elements = []
            paginator = self.client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix + "elements/"):
                for obj in page.get("Contents", []):
                    response = self.client.get_object(Bucket=self.bucket_name, Key=obj["Key"])
                    elements.append(json.loads(response["Body"].read()))
            thread["elements"] = elements
            return thread
        return None

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        if user_id is not None and name is not None:
            ts = self._get_current_timestamp()
            item = {
                "id": thread_id,
                "createdAt": ts,
                "name": name,
                "userId": user_id,
                "tags": tags,
                "metadata": metadata,
            }
            key = f"{self.bucket_key}/users/{user_id}/threads/{thread_id}/thread.json"
            self.client.put_object(
                Bucket=self.bucket_name,
                Key=key,
                Body=json.dumps(item),
            )

    async def delete_user_session(self, id: str) -> bool:
        return True

    async def build_debug_url(self) -> str:
        return ""
```
It's interesting that the entire history loads fine, yet the on_chat_resume decorator never triggers; this is just astonishing. Perhaps the problem lies solely with it.
My best guess, just from skimming the code provided, is that Chainlit does not like the omission of the get_all_user_threads method. I believe Chainlit calls this method on its own at times, and I have seen it called specifically during chat resume, which could explain the behavior you are describing.
Currently your code isn't even using the user_thread_limit parameter you pass to the data layer; normally that parameter is used in get_all_user_threads to decide how many chats to load from each user's history. Would you be able to try implementing this method for your S3 data layer and see if it solves the issue?
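For reference, here is a rough sketch of what that could look like for your S3 layout. The method name and signature are assumptions modeled on the SQLAlchemy data layer (optional user_id/thread_id filters, returning Optional[List[ThreadDict]]), so verify them against the Chainlit version you are running:

```python
# Hypothetical sketch for S3DataLayer; the signature is assumed from the
# SQLAlchemy data layer, not a verified Chainlit API.
async def get_all_user_threads(
    self, user_id: Optional[str] = None, thread_id: Optional[str] = None
) -> Optional[List[ThreadDict]]:
    prefix = f"{self.bucket_key}/users/{user_id}/threads/"
    threads: List[ThreadDict] = []
    paginator = self.client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if not key.endswith("thread.json"):
                continue  # skip step/element objects
            if thread_id is not None and f"/threads/{thread_id}/" not in key:
                continue  # honor the optional thread filter
            response = self.client.get_object(Bucket=self.bucket_name, Key=key)
            threads.append(json.loads(response["Body"].read()))
    # Newest first, truncated to the per-user limit from the constructor.
    threads.sort(key=lambda t: t.get("createdAt", ""), reverse=True)
    return threads[: self.user_thread_limit]
```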
I'm also not sure if this is related, but I think the delete_user_session function should just return False instead of True, based on the SQLAlchemy data layer.
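A minimal version of that change, assuming nothing else needs to happen in the method:

```python
async def delete_user_session(self, id: str) -> bool:
    # Mirror the SQLAlchemy data layer's behavior, per the note above.
    return False
```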
@datason, were you able to find the cause of this? I am facing the same issue.
Hey, did anyone find a workaround for this issue?
@gajanandC It was a while ago, but my issue had to do with the userId/identifier fields: make sure the data layer always defines both for all the types, even though they are Optional.
Specifically, here it expects the identifier to be defined, while elsewhere it uses user_id.
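Concretely, for the list_threads method above, that means filling in both fields when building the ThreadDict. This is a sketch against the posted code; the fallback from userIdentifier to userId is an assumption about how the thread JSON was written:

```python
# Always set both userId and userIdentifier, even though they are Optional
# in ThreadDict; chat resume appears to rely on them being present.
thread = ThreadDict(
    id=thread_data["id"],
    createdAt=thread_data["createdAt"],
    name=thread_data.get("name"),
    userId=thread_data.get("userId"),
    # Fall back to userId if the stored JSON never wrote userIdentifier.
    userIdentifier=thread_data.get("userIdentifier", thread_data.get("userId")),
    tags=thread_data.get("tags"),
    metadata=thread_data.get("metadata"),
    steps=[],
    elements=[],
)
```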
This issue is stale because it has been open for 14 days with no activity.
This issue was closed because it has been inactive for 7 days since being marked as stale.