Add Tool.arun to MCP Factory
It would be great to define `tool.arun` (bound to `run_tool_async`) so that it can be called from, for example, a FastAPI handler. Locally I did this and it works for my case. Or am I missing something?
```python
class MCPToolFactory:
    def _create_tool_classes(self, tool_definitions: List[MCPToolDefinition]) -> List[Type[BaseTool]]:
        """
        Create tool classes from definitions.

        Args:
            tool_definitions: List of tool definitions

        Returns:
            List of dynamically generated BaseTool subclasses
        """
        generated_tools = []
        for definition in tool_definitions:
            try:
                tool_name = definition.name
                tool_description = definition.description or f"Dynamically generated tool for MCP tool: {tool_name}"
                input_schema_dict = definition.input_schema

                # Create input schema
                InputSchema = self.schema_transformer.create_model_from_schema(
                    input_schema_dict,
                    f"{tool_name}InputSchema",
                    tool_name,
                    f"Input schema for {tool_name}",
                )

                # Create output schema
                OutputSchema = type(
                    f"{tool_name}OutputSchema", (MCPToolOutputSchema,), {"__doc__": f"Output schema for {tool_name}"}
                )
                # Create sync run method
                def run_tool_sync(self, params: InputSchema) -> OutputSchema:  # type: ignore
                    bound_tool_name = self.mcp_tool_name
                    bound_mcp_endpoint = self.mcp_endpoint  # May be None when using an external session
                    bound_use_stdio = self.use_stdio
                    persistent_session: Optional[ClientSession] = getattr(self, "_client_session", None)
                    loop: Optional[asyncio.AbstractEventLoop] = getattr(self, "_event_loop", None)
                    bound_working_directory = getattr(self, "working_directory", None)

                    # Get arguments, excluding tool_name
                    arguments = params.model_dump(exclude={"tool_name"}, exclude_none=True)

                    async def _connect_and_call():
                        stack = AsyncExitStack()
                        try:
                            if bound_use_stdio:
                                # Split the command string into the command and its arguments
                                command_parts = shlex.split(bound_mcp_endpoint)
                                if not command_parts:
                                    raise ValueError("STDIO command string cannot be empty.")
                                command = command_parts[0]
                                args = command_parts[1:]
                                logger.debug(f"Executing tool '{bound_tool_name}' via STDIO: command='{command}', args={args}")
                                server_params = StdioServerParameters(
                                    command=command, args=args, env=None, cwd=bound_working_directory
                                )
                                stdio_transport = await stack.enter_async_context(stdio_client(server_params))
                                read_stream, write_stream = stdio_transport
                            else:
                                sse_endpoint = f"{bound_mcp_endpoint}/sse"
                                logger.debug(f"Executing tool '{bound_tool_name}' via SSE: endpoint={sse_endpoint}")
                                sse_transport = await stack.enter_async_context(sse_client(sse_endpoint))
                                read_stream, write_stream = sse_transport
                            session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
                            await session.initialize()
                            # Ensure arguments is a dict, even if empty
                            call_args = arguments if isinstance(arguments, dict) else {}
                            tool_result = await session.call_tool(name=bound_tool_name, arguments=call_args)
                            return tool_result
                        finally:
                            await stack.aclose()

                    async def _call_with_persistent_session():
                        # Ensure arguments is a dict, even if empty
                        call_args = arguments if isinstance(arguments, dict) else {}
                        return await persistent_session.call_tool(name=bound_tool_name, arguments=call_args)

                    try:
                        if persistent_session is not None:
                            # Use the always-on session/loop supplied at construction time.
                            tool_result = cast(asyncio.AbstractEventLoop, loop).run_until_complete(
                                _call_with_persistent_session()
                            )
                        else:
                            # Legacy behaviour: open a fresh connection per invocation.
                            tool_result = asyncio.run(_connect_and_call())

                        # Process the result
                        if isinstance(tool_result, BaseModel) and hasattr(tool_result, "content"):
                            actual_result_content = tool_result.content
                        elif isinstance(tool_result, dict) and "content" in tool_result:
                            actual_result_content = tool_result["content"]
                        else:
                            actual_result_content = tool_result
                        return OutputSchema(result=actual_result_content)
                    except Exception as e:
                        logger.error(f"Error executing MCP tool '{bound_tool_name}': {e}", exc_info=True)
                        raise RuntimeError(f"Failed to execute MCP tool '{bound_tool_name}': {e}") from e
                # Create async run method
                async def run_tool_async(self, params: InputSchema) -> OutputSchema:  # type: ignore
                    bound_tool_name = self.mcp_tool_name
                    bound_mcp_endpoint = self.mcp_endpoint
                    bound_use_stdio = self.use_stdio
                    persistent_session: Optional[ClientSession] = getattr(self, "_client_session", None)
                    bound_working_directory = getattr(self, "working_directory", None)
                    arguments = params.model_dump(exclude={"tool_name"}, exclude_none=True)

                    async def _connect_and_call():
                        stack = AsyncExitStack()
                        try:
                            if bound_use_stdio:
                                command_parts = shlex.split(bound_mcp_endpoint)
                                if not command_parts:
                                    raise ValueError("STDIO command string cannot be empty.")
                                command = command_parts[0]
                                args = command_parts[1:]
                                logger.debug(f"Executing tool '{bound_tool_name}' via STDIO: command='{command}', args={args}")
                                server_params = StdioServerParameters(
                                    command=command, args=args, env=None, cwd=bound_working_directory
                                )
                                stdio_transport = await stack.enter_async_context(stdio_client(server_params))
                                read_stream, write_stream = stdio_transport
                            else:
                                sse_endpoint = f"{bound_mcp_endpoint}/sse"
                                logger.debug(f"Executing tool '{bound_tool_name}' via SSE: endpoint={sse_endpoint}")
                                sse_transport = await stack.enter_async_context(sse_client(sse_endpoint))
                                read_stream, write_stream = sse_transport
                            session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
                            await session.initialize()
                            call_args = arguments if isinstance(arguments, dict) else {}
                            return await session.call_tool(name=bound_tool_name, arguments=call_args)
                        finally:
                            await stack.aclose()

                    async def _call_with_persistent_session():
                        call_args = arguments if isinstance(arguments, dict) else {}
                        return await persistent_session.call_tool(name=bound_tool_name, arguments=call_args)

                    try:
                        if persistent_session is not None:
                            # Already on an event loop here, so just await the persistent session.
                            tool_result = await _call_with_persistent_session()
                        else:
                            tool_result = await _connect_and_call()

                        if isinstance(tool_result, BaseModel) and hasattr(tool_result, "content"):
                            actual_result_content = tool_result.content
                        elif isinstance(tool_result, dict) and "content" in tool_result:
                            actual_result_content = tool_result["content"]
                        else:
                            actual_result_content = tool_result
                        return OutputSchema(result=actual_result_content)
                    except Exception as e:
                        logger.error(f"Error executing MCP tool '{bound_tool_name}': {e}", exc_info=True)
                        raise RuntimeError(f"Failed to execute MCP tool '{bound_tool_name}': {e}") from e

                # Create the tool class
                tool_class = type(
                    tool_name,
                    (BaseTool,),
                    {
                        "input_schema": InputSchema,
                        "output_schema": OutputSchema,
                        "run": run_tool_sync,
                        "arun": run_tool_async,
                        "__doc__": tool_description,
                        "mcp_tool_name": tool_name,
                        "mcp_endpoint": self.mcp_endpoint,
                        "use_stdio": self.use_stdio,
                        "_client_session": self.client_session,
                        "_event_loop": self.event_loop,
                        "working_directory": self.working_directory,
                    },
                )
                generated_tools.append(tool_class)
            except Exception as e:
                logger.error(f"Error generating class for tool '{definition.name}': {e}", exc_info=True)
                continue
        return generated_tools
```
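
For context, this is roughly how I use it from a FastAPI handler. The snippet below is only a minimal sketch of the caller side: `my_generated_tool_classes` stands in for however your application obtains the generated tool classes from the factory, and the route/payload shape is made up for illustration. The point is that with `arun` bound, the handler can `await` the tool directly instead of calling the blocking `run()` inside a running event loop.

```python
from fastapi import FastAPI, HTTPException

app = FastAPI()

# Placeholder: build the tool classes at startup with the MCP factory,
# then index them by their MCP tool name for lookup in the handler.
tools = {tool_cls.mcp_tool_name: tool_cls for tool_cls in my_generated_tool_classes}

@app.post("/tools/{tool_name}")
async def call_tool(tool_name: str, payload: dict):
    tool_cls = tools.get(tool_name)
    if tool_cls is None:
        raise HTTPException(status_code=404, detail=f"Unknown tool: {tool_name}")
    # Validate the request body against the tool's generated input schema.
    params = tool_cls.input_schema(**payload)
    # arun awaits the MCP call on FastAPI's own event loop; no nested
    # run_until_complete / asyncio.run is needed.
    result = await tool_cls().arun(params)
    return {"result": result.result}
```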