Improve LLM's tool awareness
Feature request
Currently the tools feature does not return the name of the chosen function. This is due to how tools/functions are implemented in TGI (via constrained generation rather than a fine-tuned model).
It has been proposed that the internal structure be updated to force the generation to include the names (https://github.com/huggingface/text-generation-inference/pull/1650), and the function mechanism may be improved overall.
Opening this issue as a place for others' thoughts/ideas/uses on how to make tools as useful as possible.
I have solved the name issue by adding a const with the tool name to the original JSON grammar. I'm not familiar with PRs on GitHub but will try to open one in the next few days; otherwise I can leave the changes here, they are small and few. With these changes the output is 100% aligned with the OpenAI specs.
Regards
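For reference, this is the response shape the changes aim to match: a minimal sketch of an OpenAI-style tool call (the tool name and argument values here are made-up examples, not part of TGI):

use serde_json::json;

// Target shape of a single tool call in the chat response
// (field names follow the OpenAI chat-completions spec; values are hypothetical).
let expected_tool_call = json!({
    "id": "call_a1B2c3",                  // random id, e.g. from generate_random_id() below
    "type": "function",
    "function": {
        "name": "get_current_weather",    // hypothetical tool name
        "arguments": "{\"location\":\"Paris\",\"unit\":\"celsius\"}"  // JSON-encoded string
    }
});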
In server.rs, the extraction of the name and the tool parser look like this:
let (tool_calls, output) = if tool_grammar.is_some() {
    // gen_text should be valid json
    let gen_text_value: Value = serde_json::from_str(&generation.generated_text).map_err(|e| {
        (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: e.to_string(),
                error_type: "Input validation error".to_string(),
            }),
        )
    })?;
    let tool_call_id = generate_random_id();
    let tool_call = ToolCall {
        id: tool_call_id,
        r#type: "function".to_string(),
        function: FunctionDefinitionResponse {
            // Extract the function name from "function" -> "name", not as a constant
            name: gen_text_value
                .get("function") // Access the JSON value of "function"
                .and_then(|f| f.get("name")) // Directly access "name" inside "function"
                .and_then(|name| name.as_str()) // Ensure "name" is a string
                .unwrap_or("default_function_name") // Provide a default name if none is found
                .to_string(),
            // Serialize the JSON object obtained from "function" to an escaped JSON string
            arguments: gen_text_value
                .get("function") // Access the JSON value of "function"
                .map_or_else(
                    || Ok("{}".to_string()), // Use an empty JSON object if "function" does not exist
                    |f| {
                        // Remove the "name" key from the properties before serialization
                        let mut f_cloned = f.clone();
                        if let Value::Object(ref mut props) = f_cloned {
                            props.remove("name"); // Remove the "name" key
                        }
                        // Attempt to serialize the modified object to a String
                        serde_json::to_string(&f_cloned).map_err(|e| {
                            // Handle serialization errors, if any
                            (
                                StatusCode::UNPROCESSABLE_ENTITY,
                                Json(ErrorResponse {
                                    error: e.to_string(),
                                    error_type: "Input validation error".to_string(),
                                }),
                            )
                        })
                    },
                )?,
        },
    };
    (Some(vec![tool_call]), None)
} else {
    (None, Some(generation.generated_text))
};
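To make explicit what this parser expects, here is a small self-contained sketch; the generated text is a made-up example of what the grammar further below constrains the model to produce:

use serde_json::{json, Value};

fn main() {
    // Hypothetical constrained-generation output matching the tool grammar.
    let generated_text = r#"{"function": {"name": "get_current_weather", "location": "Paris"}}"#;
    let gen_text_value: Value = serde_json::from_str(generated_text).unwrap();

    // The function name is read from "function" -> "name".
    let name = gen_text_value["function"]["name"].as_str().unwrap();

    // The arguments are the "function" object with the "name" key removed,
    // re-serialized as an escaped JSON string.
    let mut f = gen_text_value["function"].clone();
    if let Value::Object(ref mut props) = f {
        props.remove("name");
    }
    let arguments = serde_json::to_string(&f).unwrap();

    assert_eq!(name, "get_current_weather");
    assert_eq!(arguments, json!({"location": "Paris"}).to_string());
}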
And the input grammar generator looks like this:
// First, generate `tools_str` and `tool_grammar` without depending on `tool_prompt`.
let (tools_str, tool_grammar) = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
    // Determine the tools to use based on `tool_choice`.
    let tools_to_use = match tool_choice {
        ToolType::FunctionName(name) => {
            vec![req_tools
                .iter()
                .find(|tool| tool.function.name == *name)
                .ok_or_else(|| {
                    (
                        StatusCode::UNPROCESSABLE_ENTITY,
                        Json(ErrorResponse {
                            error: "Tool choice not found in tool names".to_string(),
                            error_type: "Tool not found".to_string(),
                        }),
                    )
                })?
                .clone()]
        }
        ToolType::OneOf => req_tools.to_owned(),
    };

    // Map each tool to its function and parameters.
    let mut functions: HashMap<String, Value> = tools_to_use
        .iter()
        .map(|tool| {
            let func = tool.function.clone();
            // Clone the existing parameters, which are expected to be a JSON object
            let mut params = if let Value::Object(params) = &func.parameters {
                params.clone()
            } else {
                Map::new()
            };
            // Insert the function's description at the top level, outside of properties
            params.insert(
                "description".to_string(),
                Value::String(func.description.clone().unwrap_or_default()),
            );
            // Ensure 'properties' exists and is an object
            let properties = params
                .entry("properties".to_string())
                .or_insert_with(|| json!({}))
                .as_object_mut()
                .unwrap();
            // Insert the constant for the function name inside 'properties'
            properties.insert(
                "name".to_string(),
                json!({
                    "type": "string",
                    "const": func.name.clone(),
                    "description": "The name of the function"
                }),
            );
            // Check that 'required' exists and is an array; if not, create an empty array
            let required = params
                .entry("required".to_string())
                .or_insert_with(|| json!([]))
                .as_array_mut()
                .unwrap();
            // Add 'name' to the 'required' array if it is not already present
            if !required.iter().any(|r| r == "name") {
                required.push(json!("name"));
            }
            // Return the function name and its parameters (including the constant) as a JSON object
            (func.name.clone(), Value::Object(params))
        })
        .collect();

    // Add the error notification function by default, so the LLM can give feedback if required
    let mut text_response_properties = Map::new();
    text_response_properties.insert(
        "error".to_string(),
        serde_json::json!({
            "type": "string",
            "description": "The error or issue to notify"
        }),
    );
    text_response_properties.insert(
        "name".to_string(),
        serde_json::json!({
            "type": "string",
            "description": "The name of the function",
            "const": "notify_error"
        }),
    );
    let text_response_object = serde_json::json!({
        "description": "Useful to notify when a tool can not be called.",
        "properties": text_response_properties,
        "required": ["error", "name"],
        "type": "object"
    });
    functions.insert("notify_error".to_string(), text_response_object);

    // Collect function references from `tools_to_use`
    let mut function_refs: Vec<FunctionRef> = tools_to_use
        .iter()
        .map(|tool| FunctionRef {
            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
        })
        .collect();
    // Manually add the reference to `notify_error`
    function_refs.push(FunctionRef {
        ref_path: "#/$functions/notify_error".to_string(),
    });

    // Now `function_refs` includes all selected functions plus `notify_error`
    let tools = Tools {
        functions_map: FunctionsMap { functions },
        properties: Properties {
            function: function_refs,
        },
        required: vec!["function".to_string()],
    };

    // Serialize the `tools` object to a string.
    let tools_str = serde_json::to_string(&tools).map_err(|e| {
        (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: e.to_string(),
                error_type: "Input validation error".to_string(),
            }),
        )
    })?;
    (tools_str, Some(GrammarType::Json(serde_json::json!(tools))))
} else {
    (String::new(), None)
};

// Proceed only if tool_prompt is not None
if let Some(tool_prompt) = &req.tool_prompt {
    // Find the last message with role 'user'
    if let Some(last_user_message) = req.messages.iter_mut().rev().find(|msg| msg.role == "user") {
        // Build the additional content combining tool_prompt and tools_str
        let additional_content = format!(
            "\n\n---------------------------\n{}{}\n---------------------------",
            tool_prompt, tools_str
        );
        // If the last user message has existing content, append the additional content to it
        if let Some(content) = &mut last_user_message.content {
            content.push_str(&additional_content);
        } else {
            // If, for some reason, the content is None, replace it with the additional content
            last_user_message.content = Some(additional_content);
        }
    }
    // Note: if there is no message with role 'user', this block does nothing
}

// Apply the chat template to flatten the request into a single input.
let inputs = match infer.apply_chat_template(req.messages) {
    Ok(inputs) => inputs,
    Err(err) => {
        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
        tracing::error!("{err}");
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            }),
        ));
    }
};
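For a single hypothetical get_current_weather tool, the serialized grammar comes out roughly like the following. This is only a sketch: it assumes functions_map and properties are flattened into the top level, with $functions as the map key and $ref for the references, which is how I read the struct definitions; the exact serde layout may differ:

use serde_json::json;

// Approximate shape of the JSON-schema grammar sent to the model for one
// hypothetical tool plus the notify_error fallback (sketch, not the exact output).
let grammar_sketch = json!({
    "$functions": {
        "get_current_weather": {
            "type": "object",
            "description": "Get the current weather for a location",
            "properties": {
                "location": { "type": "string" },
                "name": {
                    "type": "string",
                    "const": "get_current_weather",
                    "description": "The name of the function"
                }
            },
            "required": ["location", "name"]
        },
        "notify_error": {
            "type": "object",
            "description": "Useful to notify when a tool can not be called.",
            "properties": {
                "error": { "type": "string", "description": "The error or issue to notify" },
                "name": { "type": "string", "const": "notify_error", "description": "The name of the function" }
            },
            "required": ["error", "name"]
        }
    },
    "properties": {
        "function": [
            { "$ref": "#/$functions/get_current_weather" },
            { "$ref": "#/$functions/notify_error" }
        ]
    },
    "required": ["function"]
});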
I added a notify function by default for fallback cases. With it, the LLM can report errors, a lack of context to select a tool, or other issues, instead of selecting a wrong tool.
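When the model takes that fallback path, the constrained output would look roughly like this (hypothetical error text); the response parser above then surfaces it as a regular notify_error tool call:

// Hypothetical fallback generation when the request lacks enough information
// to pick a tool; the grammar forces "name" to the notify_error const.
let fallback_output = r#"{"function": {"name": "notify_error", "error": "Missing location for the weather lookup"}}"#;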
Also updated the tool prompt to:
You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks sufficient information to make a precise tool selection: do not invent any tool's properties; instead, notify with an error message.\n\nJSON Schema:\n
Also added a random generator for the call_id:
// Requires the `rand` crate (0.8-style API).
use std::time::{SystemTime, UNIX_EPOCH};

use rand::{rngs::StdRng, Rng, SeedableRng};

fn generate_random_id() -> String {
    // Seed the RNG from the current time in milliseconds
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards")
        .as_millis();
    let seed = now as u64;
    let mut rng = StdRng::seed_from_u64(seed);
    // Build a 6-character alphanumeric suffix
    let id: String = (0..6)
        .map(|_| rng.sample(rand::distributions::Alphanumeric))
        .map(char::from)
        .collect();
    format!("call_{}", id)
}
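One caveat with this approach: because the RNG is reseeded from the current millisecond, two calls landing in the same millisecond produce identical IDs. A possible alternative, as a sketch assuming the rand 0.8 API, is to use the thread-local RNG instead of reseeding:

use rand::{distributions::Alphanumeric, Rng};

// Alternative sketch: the thread-local RNG avoids identical IDs for calls
// that happen within the same millisecond.
fn generate_random_id() -> String {
    let id: String = rand::thread_rng()
        .sample_iter(&Alphanumeric)
        .take(6)
        .map(char::from)
        .collect();
    format!("call_{}", id)
}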
It's my first time in Rust, I hope the result is acceptable.