atomic-agents icon indicating copy to clipboard operation
atomic-agents copied to clipboard

Add Tool.arun to MCP Factory

Open maximvlah opened this issue 7 months ago • 0 comments

It would be great to define `tool.arun` (bound to `run_tool_async`) so that the tool can be awaited from async code — for example, from a FastAPI handler, where the synchronous `run` cannot start its own event loop. I did this locally and it works for my case. Or am I missing something?


class MCPToolFactory:

    def _create_tool_classes(self, tool_definitions: List[MCPToolDefinition]) -> List[Type[BaseTool]]:
        """
        Create tool classes from definitions.

        Each generated class is a ``BaseTool`` subclass that exposes a synchronous
        ``run`` and an awaitable ``arun``; both call the same MCP tool.

        Args:
            tool_definitions: List of tool definitions

        Returns:
            List of dynamically generated BaseTool subclasses. Definitions whose
            class cannot be generated are logged and skipped.
        """
        generated_tools = []

        for definition in tool_definitions:
            try:
                tool_name = definition.name
                tool_description = definition.description or f"Dynamically generated tool for MCP tool: {tool_name}"
                input_schema_dict = definition.input_schema

                # Create input schema
                InputSchema = self.schema_transformer.create_model_from_schema(
                    input_schema_dict,
                    f"{tool_name}InputSchema",
                    tool_name,
                    f"Input schema for {tool_name}",
                )

                # Create output schema
                OutputSchema = type(
                    f"{tool_name}OutputSchema", (MCPToolOutputSchema,), {"__doc__": f"Output schema for {tool_name}"}
                )

                # Build run/arun in a helper so each pair closes over *this* tool's
                # schemas; closures defined directly in this loop would all see the
                # last iteration's OutputSchema when eventually called.
                run_tool_sync, run_tool_async = self._build_run_methods(InputSchema, OutputSchema)

                # Create the tool class
                tool_class = type(
                    tool_name,
                    (BaseTool,),
                    {
                        "input_schema": InputSchema,
                        "output_schema": OutputSchema,
                        "run": run_tool_sync,
                        "arun": run_tool_async,
                        "__doc__": tool_description,
                        "mcp_tool_name": tool_name,
                        "mcp_endpoint": self.mcp_endpoint,
                        "use_stdio": self.use_stdio,
                        "_client_session": self.client_session,
                        "_event_loop": self.event_loop,
                        "working_directory": self.working_directory,
                    },
                )

                generated_tools.append(tool_class)

            except Exception as e:
                logger.error(f"Error generating class for tool '{definition.name}': {e}", exc_info=True)
                continue

        return generated_tools

    @staticmethod
    def _build_run_methods(input_schema, output_schema):
        """
        Create the ``(run, arun)`` method pair for one generated tool class.

        Args:
            input_schema: The tool's input model (used for annotations only).
            output_schema: The tool's output model; both methods wrap the MCP
                result in an instance of it.

        Returns:
            Tuple ``(run_tool_sync, run_tool_async)`` to install as ``run``/``arun``.
        """

        def run_tool_sync(self, params: input_schema) -> output_schema:  # type: ignore
            """
            Execute the MCP tool synchronously.

            Drives the coroutine to completion itself (``_event_loop`` with a
            persistent session, otherwise ``asyncio.run``), so it raises if an
            event loop is already running in this thread; use ``arun`` there.

            Raises:
                RuntimeError: Wrapping any failure while executing the tool.
            """
            bound_tool_name = self.mcp_tool_name
            # Get arguments, excluding tool_name
            arguments = params.model_dump(exclude={"tool_name"}, exclude_none=True)

            try:
                if getattr(self, "_client_session", None) is not None:
                    # Use the always-on session/loop supplied at construction time.
                    loop: Optional[asyncio.AbstractEventLoop] = getattr(self, "_event_loop", None)
                    if loop is None:
                        raise RuntimeError("A persistent client session requires an event loop to run on.")
                    tool_result = loop.run_until_complete(MCPToolFactory._call_mcp_tool(self, arguments))
                else:
                    # Legacy behaviour – fresh connection on a fresh event loop.
                    tool_result = asyncio.run(MCPToolFactory._call_mcp_tool(self, arguments))

                return output_schema(result=MCPToolFactory._extract_content(tool_result))

            except Exception as e:
                logger.error(f"Error executing MCP tool '{bound_tool_name}': {e}", exc_info=True)
                raise RuntimeError(f"Failed to execute MCP tool '{bound_tool_name}': {e}") from e

        async def run_tool_async(self, params: input_schema) -> output_schema:  # type: ignore
            """
            Execute the MCP tool from async code (e.g. an async web handler).

            Awaits the call on the caller's running event loop instead of
            starting one, so it can be used where ``run`` would fail.

            Raises:
                RuntimeError: Wrapping any failure while executing the tool.
            """
            bound_tool_name = self.mcp_tool_name
            arguments = params.model_dump(exclude={"tool_name"}, exclude_none=True)

            try:
                # NOTE(review): a persistent ``_client_session`` is awaited on the
                # caller's loop here; if that session was opened on ``_event_loop``
                # (a different loop) this may not work — confirm before relying on it.
                tool_result = await MCPToolFactory._call_mcp_tool(self, arguments)
                return output_schema(result=MCPToolFactory._extract_content(tool_result))

            except Exception as e:
                logger.error(f"Error executing MCP tool '{bound_tool_name}': {e}", exc_info=True)
                raise RuntimeError(f"Failed to execute MCP tool '{bound_tool_name}': {e}") from e

        return run_tool_sync, run_tool_async

    @staticmethod
    async def _call_mcp_tool(tool, arguments):
        """
        Call the MCP tool configured on ``tool`` and return the raw result.

        Args:
            tool: Instance of a generated tool class; its ``mcp_tool_name``,
                ``mcp_endpoint``, ``use_stdio``, ``working_directory`` and
                ``_client_session`` attributes select the target and transport.
            arguments: Tool arguments; anything other than a dict is sent as ``{}``.

        Returns:
            Whatever ``ClientSession.call_tool`` returns.

        Raises:
            ValueError: If STDIO transport is selected with an empty command.
        """
        tool_name = tool.mcp_tool_name
        # Ensure arguments is a dict, even if empty
        call_args = arguments if isinstance(arguments, dict) else {}

        # Prefer the always-on session supplied at construction time.
        persistent_session: Optional[ClientSession] = getattr(tool, "_client_session", None)
        if persistent_session is not None:
            return await persistent_session.call_tool(name=tool_name, arguments=call_args)

        # Legacy behaviour – open a connection for this call only; the exit stack
        # closes the session and transport before returning.
        endpoint = tool.mcp_endpoint  # May be None when using external session
        async with AsyncExitStack() as stack:
            if tool.use_stdio:
                # Guard None/"" explicitly: shlex.split(None) reads stdin before Python 3.12.
                command_parts = shlex.split(endpoint) if endpoint else []
                if not command_parts:
                    raise ValueError("STDIO command string cannot be empty.")
                command, args = command_parts[0], command_parts[1:]
                logger.debug(f"Executing tool '{tool_name}' via STDIO: command='{command}', args={args}")
                server_params = StdioServerParameters(
                    command=command, args=args, env=None, cwd=getattr(tool, "working_directory", None)
                )
                read_stream, write_stream = await stack.enter_async_context(stdio_client(server_params))
            else:
                sse_endpoint = f"{endpoint}/sse"
                logger.debug(f"Executing tool '{tool_name}' via SSE: endpoint={sse_endpoint}")
                read_stream, write_stream = await stack.enter_async_context(sse_client(sse_endpoint))

            session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
            await session.initialize()
            return await session.call_tool(name=tool_name, arguments=call_args)

    @staticmethod
    def _extract_content(tool_result):
        """
        Return the ``content`` of an MCP call result.

        Reads ``content`` from a pydantic model or dict that has it; any other
        value is returned unchanged.
        """
        if isinstance(tool_result, BaseModel) and hasattr(tool_result, "content"):
            return tool_result.content
        if isinstance(tool_result, dict) and "content" in tool_result:
            return tool_result["content"]
        return tool_result

maximvlah avatar May 15 '25 17:05 maximvlah