async-openai
The messages in a deserialized `CreateChatCompletionRequest` are all `SystemMessage`s
I want to deserialize request JSON into a `CreateChatCompletionRequest`, but I found that the deserialized messages are all `System` messages.
code
use async_openai::types::{
ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
};
/// Reproduces the reported round-trip problem.
///
/// Builds a request holding one system message and one user message,
/// serializes it to JSON, parses that JSON back, and prints both forms so the
/// variant and `role` of every deserialized message can be compared with the
/// original.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build each message up front so the request builder below stays short.
    let system_message = ChatCompletionRequestSystemMessageArgs::default()
        .content("your are a calculator")
        .build()?;
    let user_message = ChatCompletionRequestUserMessageArgs::default()
        .content("what is the result of 1+1")
        .build()?;

    let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
        .messages([system_message.into(), user_message.into()])
        .build()?;

    // Serialize: this output is correct (each message keeps its own role).
    let json = serde_json::to_string(&request)?;
    println!("{json}");

    // Deserialize: this is where the variants came back wrong.
    let round_tripped: CreateChatCompletionRequest = serde_json::from_str(&json)?;
    println!("{round_tripped:?}");

    Ok(())
}
result
{"messages":[{"content":"your are a calculator","role":"system"},{"content":"what is the result of 1+1","role":"user"}],"model":""}
CreateChatCompletionRequest { messages: [System(ChatCompletionRequestSystemMessage { content: "your are a calculator", role: System, name: None }), System(ChatCompletionRequestSystemMessage { content: "what is the result of 1+1", role: User, name: None })], model: "", frequency_penalty: None, logit_bias: None, logprobs: None, top_logprobs: None, max_tokens: None, n: None, presence_penalty: None, response_format: None, seed: None, stop: None, stream: None, temperature: None, top_p: None, tools: None, tool_choice: None, user: None, function_call: None, functions: None }
I also have this issue when using actix_web.
I was banging my head on this for a bit, but I just pushed a fix to my branch.
Thanks to coco.codes from the NAMTAO Discord!
To solve the parent issue (the messages always deserializing as `System`), we add the attribute `#[serde(tag = "role", rename_all = "lowercase")]` to `ChatCompletionRequestMessage`.
This maps the `role` key to the matching variant of `ChatCompletionRequestMessage`. However, what tripped me up is that doing so consumes the `role` key, so the child `ChatCompletionRequestUserMessage` then fails to deserialize because it can no longer see the `role` key.
I solved this by deleting the `role` field from the child structs and implementing it on the parent as a method that matches on the enum variant (not even really needed — it turns out `role` is not actually used anywhere in the library or in my codebase).
I've verified this works in production across a bunch of different model providers, and I'm happy with this solution, though I don't know if it will be merged. You're free to merge from my fork if you like.
> I was banging my head on this for a bit, but I just pushed a fix to my branch.
> Thanks to coco.codes from the NAMTAO Discord!
> To solve the parent issue (the messages always deserializing as `System`), we add the attribute `#[serde(tag = "role", rename_all = "lowercase")]` to `ChatCompletionRequestMessage`.
> This maps the `role` key to the matching variant of `ChatCompletionRequestMessage`. However, what tripped me up is that doing so consumes the `role` key, so the child `ChatCompletionRequestUserMessage` then fails to deserialize because it can no longer see the `role` key.
> I solved this by deleting the `role` field from the child structs and implementing it on the parent as a method that matches on the enum variant (not even really needed — it turns out `role` is not actually used anywhere in the library or in my codebase).
> I've verified this works in production across a bunch of different model providers, and I'm happy with this solution, though I don't know if it will be merged. You're free to merge from my fork if you like.
Thank you, I will have a try.
I wrote a custom wrapper for serialization and deserialization:
use async_openai::types::{ChatCompletionRequestMessage};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::ser::SerializeStruct;
use serde_json::Value;
/// Newtype around [`ChatCompletionRequestMessage`] that carries this file's
/// custom `Serialize`/`Deserialize` impls, which encode each message as
/// `{"type": <variant>, "content": <inner message>}`.
#[derive(Debug)]
pub struct Message(ChatCompletionRequestMessage);

impl Message {
    /// Wraps an existing message; equivalent to `Message::from(enum_val)`.
    pub fn from_original(enum_val: ChatCompletionRequestMessage) -> Self {
        Self(enum_val)
    }

    /// Unwraps the inner message; equivalent to
    /// `ChatCompletionRequestMessage::from(self)`.
    pub fn into_original(self) -> ChatCompletionRequestMessage {
        self.0
    }
}

// Standard conversions so callers can use `.into()` / `From::from` instead of
// the ad-hoc constructors above (which are kept for backward compatibility).
impl From<ChatCompletionRequestMessage> for Message {
    fn from(message: ChatCompletionRequestMessage) -> Self {
        Self(message)
    }
}

impl From<Message> for ChatCompletionRequestMessage {
    fn from(message: Message) -> Self {
        message.0
    }
}
impl Serialize for Message {
    /// Serializes the wrapped message as a two-field struct:
    /// `type` holds the variant name (`"system"`, `"user"`, `"assistant"`,
    /// `"tool"` or `"function"`) and `content` holds the full inner message.
    ///
    /// # Errors
    /// Any error raised while serializing the inner message is returned via
    /// `S::Error`. The previous version converted the message through
    /// `serde_json::to_value(..).unwrap()`, which panicked on failure and
    /// allocated an intermediate JSON tree for every message.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut state = serializer.serialize_struct("Message", 2)?;
        // Each arm has a different inner type, so the field writes stay per-arm;
        // the inner message is serialized directly with the caller's serializer.
        match &self.0 {
            ChatCompletionRequestMessage::System(msg) => {
                state.serialize_field("type", "system")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::User(msg) => {
                state.serialize_field("type", "user")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Assistant(msg) => {
                state.serialize_field("type", "assistant")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Tool(msg) => {
                state.serialize_field("type", "tool")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Function(msg) => {
                state.serialize_field("type", "function")?;
                state.serialize_field("content", msg)?;
            }
        }
        state.end()
    }
}
impl<'de> Deserialize<'de> for Message {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let value: Value = Deserialize::deserialize(deserializer)?;
let msg_type = value.get("type").and_then(Value::as_str).ok_or_else(|| {
serde::de::Error::custom("Missing or invalid `type` field")
})?;
match msg_type {
"system" => {
let msg = serde_json::from_value(value["content"].clone()).map_err(|_| "Failed to deserialize ChatCompletionRequestSystemMessage").unwrap();
Ok(Message(ChatCompletionRequestMessage::System(msg)))
}
"user" => {
let msg = serde_json::from_value(value["content"].clone()).map_err(|_| "Failed to deserialize ChatCompletionRequestUserMessage").unwrap();
Ok(Message(ChatCompletionRequestMessage::User(msg)))
}
"assistant" => {
let msg = serde_json::from_value(value["content"].clone()).map_err(|_| "Failed to deserialize ChatCompletionRequestAssistantMessage").unwrap();
Ok(Message(ChatCompletionRequestMessage::Assistant(msg)))
}
"tool" => {
let msg = serde_json::from_value(value["content"].clone()).map_err(|_| "Failed to deserialize ChatCompletionRequestToolMessage").unwrap();
Ok(Message(ChatCompletionRequestMessage::Tool(msg)))
}
"function" => {
let msg = serde_json::from_value(value["content"].clone()).map_err(|_| "Failed to deserialize ChatCompletionRequestFunctionMessage").unwrap();
Ok(Message(ChatCompletionRequestMessage::Function(msg)))
}
_ => Err(serde::de::Error::unknown_variant(msg_type, &["system", "user", "assistant", "tool", "function"])),
}
}
}
Instead of requiring complex ser/de implementations like the one above, the types have been updated for proper serialization and deserialization in v0.23.0.
Thank you @sontallive for contributing the test too — it's included as part of the tests in https://github.com/64bit/async-openai/blob/main/async-openai/tests/ser_de.rs