
Google Gemini doesn't work with `QueryTool`

Open shhlife opened this issue 10 months ago • 6 comments

When I use `GoogleDriversConfig` with `off_prompt=True`, I get the error `Unknown field for Schema: anyOf`:

from griptape.configs import Defaults
from griptape.configs.drivers import GoogleDriversConfig
from griptape.structures import Agent
from griptape.tools import QueryTool, WebScraperTool

Defaults.drivers_config = GoogleDriversConfig()

agent = Agent(tools=[WebScraperTool(off_prompt=True), QueryTool()])

agent.run(
    "How does off-prompt work? https://docs.griptape.ai/stable/griptape-framework/structures/task-memory/ "
)

Here's the error:

[01/14/25 11:36:25] INFO     PromptTask 6d49e7b57c0e4fb88b1b379e2e63e8ae
                             Input: How does off-prompt work? https://docs.griptape.ai/stable/griptape-framework/structures/task-memory/
                    ERROR    PromptTask 6d49e7b57c0e4fb88b1b379e2e63e8ae
                             Unknown field for Schema: anyOf
Traceback (most recent call last):
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\rules\message.py", line 36, in to_proto
    return self._descriptor(**value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Parameter to CopyFrom() must be instance of same class: expected <class 'Schema'> got <class 'dict'>.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\rules\message.py", line 36, in to_proto
    return self._descriptor(**value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Protocol message Schema has no "anyOf" field.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\tasks\base_task.py", line 163, in run
    self.output = self.try_run()
                  ^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\tasks\prompt_task.py", line 205, in try_run
    result = self.prompt_driver.run(self.prompt_stack)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\common\decorators.py", line 18, in decorator
    Observability.observe(
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\observability\observability.py", line 36, in observe
    return driver.observe(call)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\observability\no_op_observability_driver.py", line 16, in observe
    return call()
           ^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\common\observable.py", line 19, in __call__
    return self.func(*self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\base_prompt_driver.py", line 81, in run
    for attempt in self.retrying():
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\tenacity\__init__.py", line 443, in __iter__
    do = self.iter(retry_state=retry_state)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\tenacity\__init__.py", line 376, in iter
    result = action(retry_state)
             ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\tenacity\__init__.py", line 398, in <lambda>
    self._add_action_func(lambda rs: rs.outcome.result())
                                     ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\AppData\Local\Programs\Python\Python311\Lib\concurrent\futures\_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\AppData\Local\Programs\Python\Python311\Lib\concurrent\futures\_base.py", line 401, in __get_result
    raise self._exception
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\base_prompt_driver.py", line 85, in run
    result = self.__process_stream(prompt_stack) if self.stream else self.__process_run(prompt_stack)
                                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\base_prompt_driver.py", line 126, in __process_run
    return self.try_run(prompt_stack)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\common\decorators.py", line 18, in decorator
    Observability.observe(
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\observability\observability.py", line 36, in observe
    return driver.observe(call)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\observability\no_op_observability_driver.py", line 16, in observe
    return call()
           ^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\common\observable.py", line 19, in __call__
    return self.func(*self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\google_prompt_driver.py", line 79, in try_run
    params = self._base_params(prompt_stack)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\google_prompt_driver.py", line 153, in _base_params
    "tools": self.__to_google_tools(prompt_stack.tools),
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\griptape\drivers\prompt\google_prompt_driver.py", line 193, in __to_google_tools
    tool_declaration = types.FunctionDeclaration(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\google\generativeai\types\content_types.py", line 558, in __init__
    self._proto = protos.FunctionDeclaration(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\message.py", line 728, in __init__
    pb_value = marshal.to_proto(pb_type, value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\marshal.py", line 235, in to_proto
    pb_value = self.get_rule(proto_type=proto_type).to_proto(value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\rules\message.py", line 45, in to_proto
    return self._wrapper(value)._pb
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\message.py", line 728, in __init__
    pb_value = marshal.to_proto(pb_type, value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\marshal.py", line 233, in to_proto
    return {k: self.to_proto(recursive_type, v) for k, v in value.items()}
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\marshal.py", line 233, in <dictcomp>
    return {k: self.to_proto(recursive_type, v) for k, v in value.items()}
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\marshal.py", line 235, in to_proto
    pb_value = self.get_rule(proto_type=proto_type).to_proto(value)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\marshal\rules\message.py", line 45, in to_proto
    return self._wrapper(value)._pb
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\griptape-intro-demos\.venv\Lib\site-packages\proto\message.py", line 724, in __init__
    raise ValueError(
ValueError: Unknown field for Schema: anyOf

shhlife, Jan 13 '25 22:01

The reproduction can be simplified to:

from griptape.drivers import GooglePromptDriver
from griptape.structures import Agent
from griptape.tools import QueryTool

agent = Agent(
    prompt_driver=GooglePromptDriver(model="gemini-1.5-pro"), tools=[QueryTool()]
)

agent.run()

It's not an issue with `off_prompt`; it's an issue with `QueryTool`.

collindutter, Jan 13 '25 22:01

It looks like Google Gemini does not support `anyOf` in its JSON schemas. This is far from ideal, but in the meantime you can "fix" the tool by removing the use of `schema.Or`.

from __future__ import annotations

from attrs import define
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.drivers import GooglePromptDriver
from griptape.structures import Agent
from griptape.tools import QueryTool
from griptape.utils.decorators import activity
from schema import Literal, Schema


@define(kw_only=True)
class GeminiQueryTool(QueryTool):
    @activity(
        config={
            "description": "Can be used to search through textual content.",
            "schema": Schema(
                {
                    Literal(
                        "query", description="A natural language search query"
                    ): str,
                    Literal("content"): Schema(
                        {
                            "memory_name": str,
                            "artifact_namespace": str,
                        }
                    ),
                }
            ),
        },
    )
    def query(self, params: dict) -> ListArtifact | ErrorArtifact:
        return super().query(params)


agent = Agent(
    prompt_driver=GooglePromptDriver(model="gemini-1.5-flash"),
    tools=[GeminiQueryTool()],
)

agent.run()
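To check whether a schema still contains the keyword Gemini rejects, here is a minimal, stdlib-only sketch. `find_keyword_paths` is a hypothetical helper, not part of griptape. The `original` dict is a hand-written approximation of what `Or(str, Schema({...}))` renders to for the `content` parameter, not output captured from griptape.

```python
import copy
from typing import Any, Iterator


def find_keyword_paths(node: Any, keyword: str = "anyOf", path: str = "#") -> Iterator[str]:
    """Yield a JSON-pointer-style path for every object that contains `keyword`."""
    if isinstance(node, dict):
        if keyword in node:
            yield path
        for key, value in node.items():
            yield from find_keyword_paths(value, keyword, f"{path}/{key}")
    elif isinstance(node, list):
        for index, item in enumerate(node):
            yield from find_keyword_paths(item, keyword, f"{path}/{index}")


# Assumed, hand-written approximation of the rendered QueryTool parameters.
content_object = {
    "type": "object",
    "properties": {
        "memory_name": {"type": "string"},
        "artifact_namespace": {"type": "string"},
    },
    "required": ["memory_name", "artifact_namespace"],
}
original = {
    "type": "object",
    "properties": {
        "query": {"type": "string", "description": "A natural language search query"},
        "content": {"anyOf": [{"type": "string"}, content_object]},
    },
    "required": ["query", "content"],
}

# The patched tool drops the `str` alternative, leaving only the object branch.
patched = copy.deepcopy(original)
patched["properties"]["content"] = copy.deepcopy(content_object)

print(list(find_keyword_paths(original)))  # ['#/properties/content']
print(list(find_keyword_paths(patched)))   # []
```

You could run a check like this over each activity's schema before moving a tool to Gemini. It would flag the parameters that need a variant without `Or`.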

collindutter, Jan 13 '25 23:01

Yikes. :)

I can fix that if I'm using it on my own, but I'd rather not try to fix it in ComfyUI, where our customer is hitting it. Is there another fix we can use in the framework, or is this a biggie?

shhlife, Jan 13 '25 23:01

The issue boils down to this tool using `schema.Or`, which turns into `anyOf` when rendered as a JSON schema. Others are running into it here. Every solution I can think of would be a breaking change to `QueryTool` in the framework. Can you include this patched version of `QueryTool` in Comfy?
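The toy renderer below illustrates why `Or` ends up as `anyOf`. It is stdlib-only, and `OneOf` and `render` are hypothetical stand-ins, not the `schema` library's API. A plain type maps to `{"type": ...}` and a dict maps to an object. A set of alternatives has no single JSON type, so it can only be expressed as `anyOf`. Removing the keyword therefore means removing an alternative. That is presumably why a framework-side fix would change what `QueryTool` accepts.

```python
from typing import Any

# Toy mapping from Python types to JSON-schema type names (illustrative only).
JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


class OneOf:
    """Hypothetical stand-in for schema.Or: the value may match any option."""

    def __init__(self, *options: Any) -> None:
        self.options = options


def render(spec: Any) -> dict:
    """Render a toy spec (a type, a dict of specs, or OneOf) as JSON schema."""
    if isinstance(spec, OneOf):
        # Alternatives share no single "type", so they become an anyOf list.
        return {"anyOf": [render(option) for option in spec.options]}
    if isinstance(spec, dict):
        return {
            "type": "object",
            "properties": {key: render(value) for key, value in spec.items()},
            "required": list(spec),
        }
    return {"type": JSON_TYPES[spec]}


artifact_ref = {"memory_name": str, "artifact_namespace": str}
with_or = render({"query": str, "content": OneOf(str, artifact_ref)})
without_or = render({"query": str, "content": artifact_ref})

print(with_or["properties"]["content"])
print(without_or["properties"]["content"])
```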

collindutter, Jan 14 '25 00:01

I'm giving it a try, but I'm getting this error:

Traceback (most recent call last):
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\tools\base_tool.py", line 136, in run
    output = self.try_run(activity, subtask, action, output)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\common\decorators.py", line 18, in decorator
    Observability.observe(
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\observability\observability.py", line 36, in observe
    return driver.observe(call)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\drivers\observability\no_op_observability_driver.py", line 16, in observe
    return call()
           ^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\common\observable.py", line 19, in __call__
    return self.func(*self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\tools\base_tool.py", line 158, in try_run
    activity_result = activity(deepcopy(value))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\.venv\Lib\site-packages\griptape\utils\decorators.py", line 31, in wrapper
    return func(self, **_build_kwargs(func, params))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jason\Documents\GitHub\ComfyUI\custom_nodes\ComfyUI-Griptape\nodes\patches\gemini_query_tool.py", line 31, in query
    return super().query(params)
           ^^^^^^^^^^^^^
TypeError: super(type, obj): obj must be an instance or subtype of type

Will keep poking around unless you know a quick fix :)

shhlife, Jan 14 '25 01:01

Resolved it by copying the `query` code from `QueryTool` instead of relying on:

    def query(self, params: dict) -> ListArtifact | ErrorArtifact:
        return super().query(params)
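The `super(type, obj)` error has a likely explanation. By default, attrs' `@define` builds a slotted class, and that means it returns a *new* class object. Zero-argument `super()` is bound, through the method's `__class__` cell, to the class that was originally defined. `self` is an instance of the new class, so the check fails. attrs repairs those cells for plain methods. A wrapper such as `@activity` probably hides the inner function from that repair, but that part is an assumption. The stdlib-only sketch below reproduces the effect. `rebuild_class` is a hypothetical stand-in and not attrs' implementation. The sketch also shows that naming the class explicitly in `super(...)` avoids the error.

```python
def rebuild_class(cls: type) -> type:
    """Hypothetical class decorator that returns a brand-new class object,
    loosely imitating attrs' slotted classes (but without any cell repair)."""
    namespace = dict(cls.__dict__)
    # Drop per-class descriptors if present; type() recreates what it needs.
    namespace.pop("__dict__", None)
    namespace.pop("__weakref__", None)
    return type(cls.__name__, cls.__bases__, namespace)


class Base:
    def query(self, params: str) -> str:
        return f"base:{params}"


@rebuild_class
class ZeroArgSuper(Base):
    def query(self, params: str) -> str:
        # The implicit __class__ cell still points at the pre-decoration class.
        return super().query(params)


@rebuild_class
class ExplicitSuper(Base):
    def query(self, params: str) -> str:
        # The global name is looked up at call time, so it is the rebuilt class.
        return super(ExplicitSuper, self).query(params)


def outcome(cls: type) -> str:
    """Return the query result, or a short description of the TypeError."""
    try:
        return cls().query("x")
    except TypeError as exc:
        return f"TypeError: {exc}"


print(outcome(ZeroArgSuper))
print(outcome(ExplicitSuper))  # base:x
```

The analogous change in the patch would be `super(GeminiQueryTool, self).query(params)`, though that is untested here. Copying the method body, as done above, sidesteps the question entirely.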

shhlife, Jan 14 '25 01:01