Improve LLM's tool awareness

Open drbh opened this issue 1 year ago • 2 comments

Feature request

Currently, the tools feature does not return the name of the chosen function. This is a consequence of how tools/functions are implemented in TGI: through constrained generation rather than a fine-tuned model.

It has been proposed (https://github.com/huggingface/text-generation-inference/pull/1650) to update the internal structure so that generation is forced to include the function name; the function mechanism as a whole could also be improved.

Opening this issue as a place for others' thoughts, ideas, and use cases on how to make tools as useful as possible.

drbh • Mar 20 '24 16:03

I have solved the name issue by adding a const with the tool name to the original JSON grammar. I'm not familiar with PRs on GitHub, but I will try to open one in the next few days; otherwise I can leave the changes here, since they are small and few. With those changes the output is 100% aligned with the OpenAI spec.

regards

puppetm4st3r • Mar 21 '24 04:03

In server.rs, the extraction of the name and the tool parser look like this:

let (tool_calls, output) = if tool_grammar.is_some() {
            // gen_text should be valid json
            let gen_text_value: Value = serde_json::from_str(&generation.generated_text).map_err(|e| {
                (
                    StatusCode::UNPROCESSABLE_ENTITY,
                    Json(ErrorResponse {
                        error: e.to_string(),
                        error_type: "Input validation error".to_string(),
                    }),
                )
            })?;
        
            let tool_call_id = generate_random_id();
            let tool_call = ToolCall {
                id: tool_call_id,
                r#type: "function".to_string(),
                function: FunctionDefinitionResponse {
                    // Extract the function name from "function" -> "name", not as a constant
                    name: gen_text_value
                        .get("function") // Access the JSON value of "function"
                        .and_then(|f| f.get("name")) // Directly access "name" inside "function"
                        .and_then(|name| name.as_str()) // Ensure "name" is a string
                        .unwrap_or("default_function_name") // Provide a default name if none is found
                        .to_string(),
                    // Serialize the JSON object obtained from "function" to an escaped JSON string
                    arguments: gen_text_value
                        .get("function") // Access the JSON value of "function"
                        .map_or_else(
                            || Ok("{}".to_string()), // Use an empty JSON object if "function" does not exist
                            |f| {
                                // Remove the "name" key from properties before serialization
                                let mut f_cloned = f.clone();
                                if let Value::Object(ref mut props) = f_cloned {
                                    props.remove("name"); // Remove the "name" key
                                }
                                serde_json::to_string(&f_cloned) // Attempt to serialize the modified object to String
                                    .map_err(|e| { // Handle serialization error, if any
                                        (
                                            StatusCode::UNPROCESSABLE_ENTITY,
                                            Json(ErrorResponse {
                                                error: e.to_string(),
                                                error_type: "Input validation error".to_string(),
                                            }),
                                        )
                                    })
                            },
                        )?,
                },
            };                  
            (Some(vec![tool_call]), None)
        } else {
            (None, Some(generation.generated_text))
        };
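
To make the name handling above concrete: the grammar forces the model to emit a "function" object whose "name" field is a constant, and the handler then moves that field out and treats the remaining keys as the call arguments. The sketch below mirrors that split using only the standard library. A `BTreeMap<String, String>` stands in for the parsed `serde_json::Value` object, and `split_name` is a hypothetical helper for illustration, not a function in TGI.

```rust
use std::collections::BTreeMap;

/// Hypothetical helper: pull "name" out of the generated "function" object
/// and return it together with the remaining keys (the call arguments).
/// Falls back to the same default name the handler uses when "name" is absent.
fn split_name(mut function: BTreeMap<String, String>) -> (String, BTreeMap<String, String>) {
    let name = function
        .remove("name")
        .unwrap_or_else(|| "default_function_name".to_string());
    (name, function)
}
```

Calling `split_name` on a map holding `name = get_weather` and `location = Paris` yields the name `get_weather` and a one-entry argument map; an empty map yields the default name and no arguments.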

And the input grammar generator looks like this:

// First, generate `tools_str` and `tool_grammar` without depending on `tool_prompt`.
    let (tools_str, tool_grammar) = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
        // Determine the tools to use based on `tool_choice`.
        let tools_to_use = match tool_choice {
            ToolType::FunctionName(name) => {
                vec![req_tools
                    .iter()
                    .find(|tool| tool.function.name == *name)
                    .ok_or_else(|| {
                        (
                            StatusCode::UNPROCESSABLE_ENTITY,
                            Json(ErrorResponse {
                                error: "Tool choice not found in tool names".to_string(),
                                error_type: "Tool not found".to_string(),
                            }),
                        )
                    })?
                    .clone()]
            },
            ToolType::OneOf => req_tools.to_owned(),
        };

        // Map each tool to its function and parameters.
        let mut functions: HashMap<String, Value> = tools_to_use
            .iter()
            .map(|tool| {
                let func = tool.function.clone();
                
                // Clone the existing parameters, which are expected to be a JSON object
                let mut params = if let Value::Object(params) = &func.parameters {
                    params.clone()
                } else {
                    Map::new()
                };

                // Insert the function's description at the top level, outside of properties
                params.insert("description".to_string(), Value::String(func.description.clone().unwrap_or_default()));
                
                // Ensure 'properties' exists and is an object
                let properties = params.entry("properties".to_string()).or_insert_with(|| json!({})).as_object_mut().unwrap();

                // Insert the constant for the function name inside 'properties'
                properties.insert("name".to_string(), json!({
                    "type": "string",
                    "const": func.name.clone(),
                    "description": "The name of the function"
                }));

                // Check if 'required' exists, and it is an array. If not, create an empty array.
                let required = params.entry("required".to_string()).or_insert_with(|| json!([])).as_array_mut().unwrap();

                // Add 'name' to the 'required' array if it is not already present
                if !required.iter().any(|r| r == "name") {
                    required.push(json!("name"));
                }

                // Return the function name and its parameters (including the constant) as a JSON object
                (func.name.clone(), Value::Object(params))
            })
            .collect();

        // adds the error notification function for LLM feedback if required
        let mut text_response_properties = Map::new();
        text_response_properties.insert("error".to_string(), serde_json::json!({
            "type": "string",
            "description": "The error or issue to notify"
        }));
        text_response_properties.insert("name".to_string(), serde_json::json!({
            "type": "string",
            "description": "The name of the function",
            "const": "notify_error"
        }));
        let text_response_object = serde_json::json!({
            "description": "Useful to notify when a tool can not be called.",
            "properties": text_response_properties,
            "required": ["error", "name"],
            "type": "object"
        });
        functions.insert("notify_error".to_string(), text_response_object);

        // Collect function references from `tools_to_use`
        let mut function_refs: Vec<FunctionRef> = tools_to_use
            .iter()
            .map(|tool| FunctionRef {
                ref_path: format!("#/$functions/{}", tool.function.name),
            })
            .collect();

        // Manually add the reference to the `notify_error` fallback
        function_refs.push(FunctionRef {
            ref_path: "#/$functions/notify_error".to_string(),
        });

        // Now `function_refs` includes all selected functions plus `notify_error`
        let tools = Tools {
            functions_map: FunctionsMap { functions },
            properties: Properties {
                function: function_refs, 
            },
            required: vec!["function".to_string()],
        };

        // Serialize the `tools` object to a string.
        let tools_str = serde_json::to_string(&tools).map_err(|e| {
            (
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(ErrorResponse {
                    error: e.to_string(),
                    error_type: "Input validation error".to_string(),
                }),
            )
        })?;
        (tools_str, Some(GrammarType::Json(serde_json::json!(tools))))
    } else {
        (String::new(), None)
    };
    
    // Proceed only if tool_prompt is not None
    if let Some(tool_prompt) = &req.tool_prompt {
        // Find the last message with role 'user'
        if let Some(last_user_message) = req.messages.iter_mut().rev().find(|msg| msg.role == "user") {
            // Generate the additional content combining tool_prompt and tools_str
            let additional_content = format!("\n\n---------------------------\n{}{}\n---------------------------", tool_prompt, tools_str);
            
            // If the last user message has existing content, append the additional content to it
            if let Some(content) = &mut last_user_message.content {
                content.push_str(&additional_content); // Append to the existing content
            } else {
                // If, for some reason, the content is None, replace it with the additional_content
                last_user_message.content = Some(additional_content);
            }
        }
        // Note: If there is no message with role 'user', this block does nothing
    }

    // Apply the chat template to flatten the request into a single input.
    let inputs = match infer.apply_chat_template(req.messages) {
        Ok(inputs) => inputs,
        Err(err) => {
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            return Err((
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(ErrorResponse {
                    error: err.to_string(),
                    error_type: err.error_type().to_string(),
                }),
            ));
        }
    };

I added a notify function by default for fallback cases. With it, the LLM can report errors, a lack of context information needed to select a tool, or other kinds of problems, instead of selecting the wrong tool.
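
On the caller's side, the fallback shows up as an ordinary tool call whose function name is `notify_error`, so a client has to route on the name before executing anything. A minimal sketch of that routing, assuming the client already has the name and the raw JSON argument string; `ToolOutcome` and `classify` are hypothetical names, not part of TGI or any client library:

```rust
/// What a caller might do with one tool call returned by the server.
#[derive(Debug, PartialEq)]
enum ToolOutcome {
    /// A real tool was selected; carry its name and raw JSON arguments.
    Call { name: String, arguments: String },
    /// The model used the fallback instead of guessing a tool.
    ModelError(String),
}

/// Hypothetical client-side helper: route on the function name.
fn classify(name: &str, arguments: &str) -> ToolOutcome {
    if name == "notify_error" {
        ToolOutcome::ModelError(arguments.to_string())
    } else {
        ToolOutcome::Call {
            name: name.to_string(),
            arguments: arguments.to_string(),
        }
    }
}
```

Keeping the fallback as a distinct variant means the error text can be shown to the user or fed back to the model without ever being dispatched as a real tool.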

Also updated the tool prompt to: You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n

I also added a random generator for the call ID:

use rand::{distributions::Alphanumeric, Rng};

fn generate_random_id() -> String {
    // Use the thread-local, OS-seeded RNG. Seeding from the current
    // millisecond would give identical IDs to requests in the same millisecond.
    let id: String = rand::thread_rng()
        .sample_iter(&Alphanumeric)
        .take(6)
        .map(char::from)
        .collect();

    format!("call_{}", id)
}
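
If uniqueness matters more than randomness, a process-wide counter guarantees that two requests handled in the same millisecond never receive the same ID. The following is a minimal standard-library sketch under that assumption; `generate_unique_id` and `NEXT_CALL` are illustrative names, the IDs are predictable rather than random, and the format is not the one OpenAI uses.

```rust
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

/// Process-wide counter; each call takes the next value.
static NEXT_CALL: AtomicU64 = AtomicU64::new(0);

/// Hypothetical alternative: "call_" + millisecond timestamp + counter, in hex.
/// The counter alone makes IDs distinct within one process; the timestamp
/// only helps tell apart IDs produced by different runs.
fn generate_unique_id() -> String {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis())
        .unwrap_or(0); // Clock before the epoch: fall back to 0 instead of panicking
    let n = NEXT_CALL.fetch_add(1, Ordering::Relaxed);
    format!("call_{millis:x}_{n:x}")
}
```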

It's my first time writing Rust; I hope it's acceptable.

puppetm4st3r • Mar 21 '24 04:03