
[OpenAiChatModel] Inability to Pass `tool_calls` or `tool_call_id` and Discrepancy in Message Roles

Open RikJux opened this issue 1 year ago • 1 comment

Description

The current implementation of the createRequest method in the OpenAiChatModel class (see Additional Context below) has two main issues:

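  1. Inability to Pass tool_calls or tool_call_id: createRequest builds each message with the two-argument ChatCompletionMessage constructor shown below, which delegates to the canonical constructor with null for every remaining argument. As a result, tool_calls and tool_call_id are always null, and there is no way to set them.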

    public ChatCompletionMessage(Object content, Role role) {
        this(content, role, null, null, null);
    }
    
  2. Discrepancy in Message Roles: The MessageType enum in AbstractMessage does not match the Role enum used for OpenAI messages. MessageType has FUNCTION, while Role has TOOL instead.

    MessageType in AbstractMessage:

    public enum MessageType {
        USER("user"),
        ASSISTANT("assistant"),
        SYSTEM("system"),
        FUNCTION("function");

        private final String value;

        MessageType(String value) {
            this.value = value;
        }
    }
    

    Role in OpenAI messages (OpenAiApi.ChatCompletionMessage.Role):

    public enum Role {
        @JsonProperty("system")
        SYSTEM,
        @JsonProperty("user")
        USER,
        @JsonProperty("assistant")
        ASSISTANT,
        @JsonProperty("tool")
        TOOL
    }
    

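Because createRequest converts each message type by name with ChatCompletionMessage.Role.valueOf(m.getMessageType().name()), a FUNCTION message has no matching Role constant, and the conversion throws a java.lang.IllegalArgumentException.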
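The failure can be reproduced without Spring AI. The sketch below uses two hypothetical local enums that mirror the constants shown above (they are not the Spring AI classes) and applies the same name-based lookup that createRequest uses:

```java
import java.util.ArrayList;
import java.util.List;

public class RoleMismatchDemo {

    // Hypothetical mirror of the MessageType constants (not the real Spring AI enum).
    enum MessageType { USER, ASSISTANT, SYSTEM, FUNCTION }

    // Hypothetical mirror of OpenAiApi.ChatCompletionMessage.Role (not the real enum).
    enum Role { SYSTEM, USER, ASSISTANT, TOOL }

    /** Returns the MessageType constants that have no Role constant with the same name. */
    static List<MessageType> unmappable() {
        List<MessageType> failures = new ArrayList<>();
        for (MessageType type : MessageType.values()) {
            try {
                // Same name-based lookup as createRequest; throws if the name is missing.
                Role.valueOf(type.name());
            } catch (IllegalArgumentException e) {
                failures.add(type);
            }
        }
        return failures;
    }

    public static void main(String[] args) {
        // Prints: No matching Role for: [FUNCTION]
        System.out.println("No matching Role for: " + unmappable());
    }
}
```

Only FUNCTION fails, which matches the `No enum constant ...Role.FUNCTION` message in the stack trace below.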

Steps to Reproduce

Passing a FunctionMessage in a prompt results in an error:

@Test
public void testPromptWithTool() throws JsonProcessingException {
    Prompt prompt = new Prompt(new FunctionMessage(""));
    openAiChatModel.call(prompt);
}

Error:

java.lang.IllegalArgumentException: No enum constant org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role.FUNCTION

    at java.base/java.lang.Enum.valueOf(Enum.java:293)
    at org.springframework.ai.openai.api.OpenAiApi$ChatCompletionMessage$Role.valueOf(OpenAiApi.java:488)
    at org.springframework.ai.openai.OpenAiChatModel.lambda$createRequest$9(OpenAiChatModel.java:267)
    at java.base/java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:197)
    at java.base/java.util.Collections$2.tryAdvance(Collections.java:5073)
    at java.base/java.util.Collections$2.forEachRemaining(Collections.java:5081)
    at java.base/java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:509)
    at java.base/java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:499)
    at java.base/java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:575)
    at java.base/java.util.stream.AbstractPipeline.evaluateToArrayNode(AbstractPipeline.java:260)
    at java.base/java.util.stream.ReferencePipeline.toArray(ReferencePipeline.java:616)
    at java.base/java.util.stream.ReferencePipeline.toArray(ReferencePipeline.java:622)
    at java.base/java.util.stream.ReferencePipeline.toList(ReferencePipeline.java:627)
    at org.springframework.ai.openai.OpenAiChatModel.createRequest(OpenAiChatModel.java:268)
    at org.springframework.ai.openai.OpenAiChatModel.call(OpenAiChatModel.java:140)
    at it.ai.foundation.ServiceChatModelTest.testPromptWithTool(ServiceChatModelTest.java:143)
    at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)

Expected Behavior

The current implementation of the createRequest method does not handle every kind of message a user can put in a prompt. The expected behavior is as follows:

  1. Support for Tool Calls:

    • The createRequest method should pass tool_calls and tool_call_id to the ChatCompletionMessage constructor, so that tool calls and tool results in the chat prompt are sent to OpenAI correctly.
  2. Consistent Message Roles:

    • The MessageType enum in AbstractMessage should be aligned with the Role enum in OpenAI messages. Alternatively, a mapping should be provided to ensure that all message types are correctly translated and no errors occur due to missing enum constants.
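To illustrate what point 1 needs to carry, here is a rough sketch of the two messages in an OpenAI tool round trip: the assistant message holds tool_calls, and the follow-up tool message echoes the call's id in tool_call_id. The record is hypothetical; its five components are an assumption about the shape behind the two-argument constructor quoted above (content, role, name, tool_call_id, tool_calls), and the id, function name, and payloads are made-up sample values.

```java
import java.util.List;

public class ToolRoundTrip {

    // Hypothetical, simplified stand-ins; not the real OpenAiApi types.
    record ToolCall(String id, String functionName, String arguments) {}

    record Message(String content, String role, String name, String toolCallId, List<ToolCall> toolCalls) {

        // Assistant turn that asks the client to run one or more tools.
        static Message assistantWithToolCalls(List<ToolCall> calls) {
            return new Message(null, "assistant", null, null, calls);
        }

        // Tool-result turn; toolCallId must match the id of the call it answers.
        static Message toolResult(String toolCallId, String result) {
            return new Message(result, "tool", null, toolCallId, null);
        }
    }

    public static void main(String[] args) {
        ToolCall call = new ToolCall("call_1", "getWeather", "{\"city\":\"Paris\"}");
        List<Message> history = List.of(
                Message.assistantWithToolCalls(List.of(call)),
                Message.toolResult(call.id(), "{\"temp\":21}"));
        history.forEach(System.out::println);
    }
}
```

With only the two-argument constructor, neither the toolCalls of the first message nor the toolCallId of the second can be populated.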

Additional Context

Here is the current implementation of the createRequest method:

ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
    Set<String> functionsForThisRequest = new HashSet<>();
    List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m -> {
        // Add text content.
        List<MediaContent> contents = new ArrayList<>(List.of(new MediaContent(m.getContent())));
        if (!CollectionUtils.isEmpty(m.getMedia())) {
            // Add media content.
            contents.addAll(m.getMedia()
                .stream()
                .map(media -> new MediaContent(
                    new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))))
                .toList());
        }
        return new ChatCompletionMessage(contents, ChatCompletionMessage.Role.valueOf(m.getMessageType().name()));
    }).toList();
    // rest of the code...
}
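One way to replace the name-based `ChatCompletionMessage.Role.valueOf(m.getMessageType().name())` call in the code above is the explicit mapping suggested under Expected Behavior. A minimal sketch, again using hypothetical mirror enums rather than the real Spring AI types; it assumes Java 14+ for switch expressions:

```java
public class RoleMapping {

    // Hypothetical mirrors of the two enums shown in the issue (not the real Spring AI types).
    enum MessageType { USER, ASSISTANT, SYSTEM, FUNCTION }
    enum Role { SYSTEM, USER, ASSISTANT, TOOL }

    /** Maps a message type to an OpenAI role without relying on matching constant names. */
    static Role toRole(MessageType type) {
        // A switch expression over an enum must be exhaustive, so adding a constant
        // to MessageType without handling it here is a compile-time error.
        return switch (type) {
            case USER -> Role.USER;
            case ASSISTANT -> Role.ASSISTANT;
            case SYSTEM -> Role.SYSTEM;
            case FUNCTION -> Role.TOOL; // OpenAI sends function/tool results with role "tool"
        };
    }

    public static void main(String[] args) {
        for (MessageType type : MessageType.values()) {
            System.out.println(type + " -> " + toRole(type));
        }
    }
}
```

Compared with valueOf, the switch moves the failure from a runtime IllegalArgumentException to a compile error whenever the two enums drift apart.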

RikJux · Jun 17 '24 10:06

Thanks for reporting this issue. The function calling APIs have recently been refactored and improved, and the changes should address both points you reported.

  1. The ChatCompletionMessage in the OpenAiApi class accepts both toolCallId and toolCalls. See: https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java#L519
  2. The MessageType enum uses tool as the role name instead of function. See: https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java#L51

Can you confirm this issue is fixed?

ThomasVitale · Jul 31 '24 21:07

@RikJux please let us know if this is resolved. We will close this issue in 7 days if there is no response. Thank you.

csterwa · Sep 04 '24 16:09

Closing for now.

csterwa · Sep 11 '24 21:09