fix(calculator): update math calculator prompt, enhance error handling and logging

pull/819/head
DP37 3 months ago committed by ascender1729
parent ea66e78154
commit a612352abd

@ -0,0 +1,83 @@
from swarms import Agent
from swarms.tools.mcp_integration import MCPServerSseParams
from loguru import logger
# System prompt for the math calculator agent.  It instructs the model to
# emit tool calls as single-line JSON objects ({"tool_name": ..., "a": ..., "b": ...})
# so the MCP layer can parse them with json.loads.
# NOTE(review): a similarly named MATH_AGENT_PROMPT appears to exist in
# swarms.prompts.agent_prompts — confirm whether this local copy is intentional.
MATH_AGENT_PROMPT = """
You are a helpful math calculator assistant.
Your role is to understand natural language math requests and perform calculations.
When asked to perform calculations:
1. Determine the operation (add, multiply, or divide)
2. Extract the numbers from the request
3. Use the appropriate math operation tool
FORMAT YOUR TOOL CALLS AS JSON with this format:
{"tool_name": "add", "a": <first_number>, "b": <second_number>}
or
{"tool_name": "multiply", "a": <first_number>, "b": <second_number>}
or
{"tool_name": "divide", "a": <first_number>, "b": <second_number>}
Always respond with a tool call in JSON format first, followed by a brief explanation.
"""
def initialize_math_system():
    """Build and return a math agent wired to the local MCP SSE server."""
    # Connection parameters for the MCP server (SSE transport on port 8000).
    server_params = MCPServerSseParams(
        url="http://0.0.0.0:8000",
        headers={"Content-Type": "application/json"},
        timeout=5.0,
        sse_read_timeout=30.0,
    )

    # Agent configured with the math prompt and the MCP server list.
    return Agent(
        agent_name="Math Assistant",
        agent_description="Friendly math calculator",
        system_prompt=MATH_AGENT_PROMPT,
        max_loops=1,
        mcp_servers=[server_params],  # MCP server configs are passed as a list
        model_name="gpt-3.5-turbo",
        verbose=True,  # surface extra runtime detail
    )
def main():
    """Interactive loop: read math questions and run them through the agent."""
    try:
        logger.info("Initializing math system...")
        agent = initialize_math_system()

        # Banner / usage instructions for the interactive session.
        for banner_line in (
            "\nMath Calculator Ready!",
            "Ask me any math question!",
            "Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'",
            "Type 'exit' to quit\n",
        ):
            print(banner_line)

        while True:
            try:
                question = input("What would you like to calculate? ").strip()
                if not question:
                    continue  # ignore blank input
                if question.lower() == 'exit':
                    break
                logger.info(f"Processing query: {question}")
                answer = agent.run(question)
                print(f"\nResult: {answer}\n")
            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                # Per-query failures are reported but do not end the session.
                logger.error(f"Error processing query: {e}")
                print(f"Sorry, there was an error: {str(e)}")
    except Exception as e:
        logger.error(f"System initialization error: {e}")
        print(f"Failed to start the math system: {str(e)}")
if __name__ == "__main__":
    main()

@ -1,53 +1,83 @@
from swarms import Agent
from swarms.tools.mcp_integration import MCPServerSseParams
from loguru import logger
from swarms import Agent # Comprehensive math prompt that encourages proper JSON formatting
from swarms.tools.mcp_integration import MCPServerSseParams MATH_AGENT_PROMPT = """
from swarms.prompts.agent_prompts import MATH_AGENT_PROMPT You are a helpful math calculator assistant.
from loguru import logger
Your role is to understand natural language math requests and perform calculations.
def initialize_math_system(): When asked to perform calculations:
"""Initialize the math agent with MCP server configuration."""
math_server = MCPServerSseParams( 1. Determine the operation (add, multiply, or divide)
url="http://0.0.0.0:8000", 2. Extract the numbers from the request
headers={"Content-Type": "application/json"}, 3. Use the appropriate math operation tool
timeout=5.0,
sse_read_timeout=30.0 FORMAT YOUR TOOL CALLS AS JSON with this format:
) {"tool_name": "add", "a": <first_number>, "b": <second_number>}
or
math_agent = Agent( {"tool_name": "multiply", "a": <first_number>, "b": <second_number>}
agent_name="Math Assistant", or
agent_description="Friendly math calculator", {"tool_name": "divide", "a": <first_number>, "b": <second_number>}
system_prompt=MATH_AGENT_PROMPT,
max_loops=1, Always respond with a tool call in JSON format first, followed by a brief explanation.
mcp_servers=[math_server], """
model_name="gpt-3.5-turbo"
) def initialize_math_system():
"""Initialize the math agent with MCP server configuration."""
return math_agent # Configure the MCP server connection
math_server = MCPServerSseParams(
def main(): url="http://0.0.0.0:8000",
math_agent = initialize_math_system() headers={"Content-Type": "application/json"},
timeout=5.0,
print("\nMath Calculator Ready!") sse_read_timeout=30.0
print("Ask me any math question!") )
print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'")
print("Type 'exit' to quit\n") # Create the agent with the MCP server configuration
math_agent = Agent(
while True: agent_name="Math Assistant",
try: agent_description="Friendly math calculator",
query = input("What would you like to calculate? ").strip() system_prompt=MATH_AGENT_PROMPT,
if not query: max_loops=1,
continue mcp_servers=[math_server], # Pass MCP server config as a list
if query.lower() == 'exit': model_name="gpt-3.5-turbo",
break verbose=True # Enable verbose mode to see more details
)
result = math_agent.run(query)
print(f"\nResult: {result}\n") return math_agent
except KeyboardInterrupt: def main():
print("\nGoodbye!") try:
break logger.info("Initializing math system...")
except Exception as e: math_agent = initialize_math_system()
logger.error(f"Error: {e}")
print("\nMath Calculator Ready!")
if __name__ == "__main__": print("Ask me any math question!")
main() print("Examples: 'what is 5 plus 3?' or 'can you multiply 4 and 6?'")
print("Type 'exit' to quit\n")
while True:
try:
query = input("What would you like to calculate? ").strip()
if not query:
continue
if query.lower() == 'exit':
break
logger.info(f"Processing query: {query}")
result = math_agent.run(query)
print(f"\nResult: {result}\n")
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
logger.error(f"Error processing query: {e}")
print(f"Sorry, there was an error: {str(e)}")
except Exception as e:
logger.error(f"System initialization error: {e}")
print(f"Failed to start the math system: {str(e)}")
if __name__ == "__main__":
main()

@ -1,38 +1,79 @@
from fastmcp import FastMCP from fastmcp import FastMCP
from loguru import logger from loguru import logger
import time
# Create the MCP server
mcp = FastMCP(host="0.0.0.0",
port=8000,
transport="sse",
require_session_id=False)
mcp = FastMCP(
host="0.0.0.0",
port=8000,
transport="sse",
require_session_id=False
)
# Define tools with proper type hints and docstrings
@mcp.tool() @mcp.tool()
def add(a: int, b: int) -> str: def add(a: int, b: int) -> str:
"""Add two numbers.""" """Add two numbers.
Args:
a (int): First number
b (int): Second number
Returns:
str: A message containing the sum
"""
logger.info(f"Adding {a} and {b}")
result = a + b result = a + b
return f"The sum of {a} and {b} is {result}" return f"The sum of {a} and {b} is {result}"
@mcp.tool() @mcp.tool()
def multiply(a: int, b: int) -> str: def multiply(a: int, b: int) -> str:
"""Multiply two numbers.""" """Multiply two numbers.
Args:
a (int): First number
b (int): Second number
Returns:
str: A message containing the product
"""
logger.info(f"Multiplying {a} and {b}")
result = a * b result = a * b
return f"The product of {a} and {b} is {result}" return f"The product of {a} and {b} is {result}"
@mcp.tool() @mcp.tool()
def divide(a: int, b: int) -> str: def divide(a: int, b: int) -> str:
"""Divide two numbers.""" """Divide two numbers.
Args:
a (int): Numerator
b (int): Denominator
Returns:
str: A message containing the division result or an error message
"""
logger.info(f"Dividing {a} by {b}")
if b == 0: if b == 0:
logger.warning("Division by zero attempted")
return "Cannot divide by zero" return "Cannot divide by zero"
result = a / b result = a / b
return f"{a} divided by {b} is {result}" return f"{a} divided by {b} is {result}"
if __name__ == "__main__": if __name__ == "__main__":
try: try:
logger.info("Starting math server on http://0.0.0.0:8000") logger.info("Starting math server on http://0.0.0.0:8000")
print("Math MCP Server is running. Press Ctrl+C to stop.")
# Add a small delay to ensure logging is complete before the server starts
time.sleep(0.5)
# Run the MCP server
mcp.run() mcp.run()
except KeyboardInterrupt:
logger.info("Server shutdown requested")
print("\nShutting down server...")
except Exception as e: except Exception as e:
logger.error(f"Server error: {e}") logger.error(f"Server error: {e}")
raise raise

@ -1,14 +1,13 @@
# Agent prompts for MCP testing and interactions # Agent prompts for MCP testing and interactions
MATH_AGENT_PROMPT = '''You are a helpful math calculator assistant. # Keeping the original format that already has JSON formatting
MATH_AGENT_PROMPT = """You are a helpful math calculator assistant.
Your role is to understand natural language math requests and perform calculations. Your role is to understand natural language math requests and perform calculations.
When asked to perform calculations: When asked to perform calculations:
1. Determine the operation (add, multiply, or divide) 1. Determine the operation (add, multiply, or divide)
2. Extract the numbers from the request 2. Extract the numbers from the request
3. Use the appropriate math operation tool 3. Use the appropriate math operation tool
Format your tool calls as JSON with the tool_name and parameters.
Respond conversationally but be concise.
Example: Example:
User: "what is 5 plus 3?" User: "what is 5 plus 3?"
@ -17,7 +16,8 @@ You: Using the add operation for 5 and 3
User: "multiply 4 times 6" User: "multiply 4 times 6"
You: Using multiply for 4 and 6 You: Using multiply for 4 and 6
{"tool_name": "multiply", "a": 4, "b": 6}''' {"tool_name": "multiply", "a": 4, "b": 6}
"""
FINANCE_AGENT_PROMPT = """You are a financial analysis agent with access to stock market data services. FINANCE_AGENT_PROMPT = """You are a financial analysis agent with access to stock market data services.
Key responsibilities: Key responsibilities:
@ -28,42 +28,40 @@ Key responsibilities:
Use the available MCP tools to fetch real market data rather than making assumptions.""" Use the available MCP tools to fetch real market data rather than making assumptions."""
def generate_agent_role_prompt(agent): def generate_agent_role_prompt(agent):
"""Generates the agent role prompt. """Generates the agent role prompt.
Args: agent (str): The type of the agent. Args: agent (str): The type of the agent.
Returns: str: The agent role prompt. Returns: str: The agent role prompt.
""" """
prompts = { prompts = {
"Finance Agent": ( "Finance Agent":
"You are a seasoned finance analyst AI assistant. Your" ("You are a seasoned finance analyst AI assistant. Your"
" primary goal is to compose comprehensive, astute," " primary goal is to compose comprehensive, astute,"
" impartial, and methodically arranged financial reports" " impartial, and methodically arranged financial reports"
" based on provided data and trends." " based on provided data and trends."),
), "Travel Agent":
"Travel Agent": ( ("You are a world-travelled AI tour guide assistant. Your"
"You are a world-travelled AI tour guide assistant. Your" " main purpose is to draft engaging, insightful,"
" main purpose is to draft engaging, insightful," " unbiased, and well-structured travel reports on given"
" unbiased, and well-structured travel reports on given" " locations, including history, attractions, and cultural"
" locations, including history, attractions, and cultural" " insights."),
" insights." "Academic Research Agent":
), ("You are an AI academic research assistant. Your primary"
"Academic Research Agent": ( " responsibility is to create thorough, academically"
"You are an AI academic research assistant. Your primary" " rigorous, unbiased, and systematically organized"
" responsibility is to create thorough, academically" " reports on a given research topic, following the"
" rigorous, unbiased, and systematically organized" " standards of scholarly work."),
" reports on a given research topic, following the" "Default Agent":
" standards of scholarly work." ("You are an AI critical thinker research assistant. Your"
), " sole purpose is to write well written, critically"
"Default Agent": ( " acclaimed, objective and structured reports on given"
"You are an AI critical thinker research assistant. Your" " text."),
" sole purpose is to write well written, critically"
" acclaimed, objective and structured reports on given"
" text."
),
} }
return prompts.get(agent, "No such agent") return prompts.get(agent, "No such agent")
def generate_report_prompt(question, research_summary): def generate_report_prompt(question, research_summary):
"""Generates the report prompt for the given question and research summary. """Generates the report prompt for the given question and research summary.
Args: question (str): The question to generate the report prompt for Args: question (str): The question to generate the report prompt for
@ -71,16 +69,15 @@ def generate_report_prompt(question, research_summary):
Returns: str: The report prompt for the given question and research summary Returns: str: The report prompt for the given question and research summary
""" """
return ( return (f'"""{research_summary}""" Using the above information,'
f'"""{research_summary}""" Using the above information,' f' answer the following question or topic: "{question}" in a'
f' answer the following question or topic: "{question}" in a' " detailed report -- The report should focus on the answer"
" detailed report -- The report should focus on the answer" " to the question, should be well structured, informative,"
" to the question, should be well structured, informative," " in depth, with facts and numbers if available, a minimum"
" in depth, with facts and numbers if available, a minimum" " of 1,200 words and with markdown syntax and apa format."
" of 1,200 words and with markdown syntax and apa format." " Write all source urls at the end of the report in apa"
" Write all source urls at the end of the report in apa" " format")
" format"
)
def generate_search_queries_prompt(question): def generate_search_queries_prompt(question):
"""Generates the search queries prompt for the given question. """Generates the search queries prompt for the given question.
@ -88,12 +85,11 @@ def generate_search_queries_prompt(question):
Returns: str: The search queries prompt for the given question Returns: str: The search queries prompt for the given question
""" """
return ( return ("Write 4 google search queries to search online that form an"
"Write 4 google search queries to search online that form an" f' objective opinion from the following: "{question}"You must'
f' objective opinion from the following: "{question}"You must' " respond with a list of strings in the following format:"
" respond with a list of strings in the following format:" ' ["query 1", "query 2", "query 3", "query 4"]')
' ["query 1", "query 2", "query 3", "query 4"]'
)
def generate_resource_report_prompt(question, research_summary): def generate_resource_report_prompt(question, research_summary):
"""Generates the resource report prompt for the given question and research summary. """Generates the resource report prompt for the given question and research summary.
@ -105,19 +101,18 @@ def generate_resource_report_prompt(question, research_summary):
Returns: Returns:
str: The resource report prompt for the given question and research summary. str: The resource report prompt for the given question and research summary.
""" """
return ( return (f'"""{research_summary}""" Based on the above information,'
f'"""{research_summary}""" Based on the above information,' " generate a bibliography recommendation report for the"
" generate a bibliography recommendation report for the" f' following question or topic: "{question}". The report'
f' following question or topic: "{question}". The report' " should provide a detailed analysis of each recommended"
" should provide a detailed analysis of each recommended" " resource, explaining how each source can contribute to"
" resource, explaining how each source can contribute to" " finding answers to the research question. Focus on the"
" finding answers to the research question. Focus on the" " relevance, reliability, and significance of each source."
" relevance, reliability, and significance of each source." " Ensure that the report is well-structured, informative,"
" Ensure that the report is well-structured, informative," " in-depth, and follows Markdown syntax. Include relevant"
" in-depth, and follows Markdown syntax. Include relevant" " facts, figures, and numbers whenever available. The report"
" facts, figures, and numbers whenever available. The report" " should have a minimum length of 1,200 words.")
" should have a minimum length of 1,200 words."
)
def generate_outline_report_prompt(question, research_summary): def generate_outline_report_prompt(question, research_summary):
"""Generates the outline report prompt for the given question and research summary. """Generates the outline report prompt for the given question and research summary.
@ -126,17 +121,16 @@ def generate_outline_report_prompt(question, research_summary):
Returns: str: The outline report prompt for the given question and research summary Returns: str: The outline report prompt for the given question and research summary
""" """
return ( return (f'"""{research_summary}""" Using the above information,'
f'"""{research_summary}""" Using the above information,' " generate an outline for a research report in Markdown"
" generate an outline for a research report in Markdown" f' syntax for the following question or topic: "{question}".'
f' syntax for the following question or topic: "{question}".' " The outline should provide a well-structured framework for"
" The outline should provide a well-structured framework for" " the research report, including the main sections,"
" the research report, including the main sections," " subsections, and key points to be covered. The research"
" subsections, and key points to be covered. The research" " report should be detailed, informative, in-depth, and a"
" report should be detailed, informative, in-depth, and a" " minimum of 1,200 words. Use appropriate Markdown syntax to"
" minimum of 1,200 words. Use appropriate Markdown syntax to" " format the outline and ensure readability.")
" format the outline and ensure readability."
)
def generate_concepts_prompt(question, research_summary): def generate_concepts_prompt(question, research_summary):
"""Generates the concepts prompt for the given question. """Generates the concepts prompt for the given question.
@ -145,15 +139,14 @@ def generate_concepts_prompt(question, research_summary):
Returns: str: The concepts prompt for the given question Returns: str: The concepts prompt for the given question
""" """
return ( return (f'"""{research_summary}""" Using the above information,'
f'"""{research_summary}""" Using the above information,' " generate a list of 5 main concepts to learn for a research"
" generate a list of 5 main concepts to learn for a research" f' report on the following question or topic: "{question}".'
f' report on the following question or topic: "{question}".' " The outline should provide a well-structured frameworkYou"
" The outline should provide a well-structured frameworkYou" " must respond with a list of strings in the following"
" must respond with a list of strings in the following" ' format: ["concepts 1", "concepts 2", "concepts 3",'
' format: ["concepts 1", "concepts 2", "concepts 3",' ' "concepts 4, concepts 5"]')
' "concepts 4, concepts 5"]'
)
def generate_lesson_prompt(concept): def generate_lesson_prompt(concept):
""" """
@ -164,16 +157,15 @@ def generate_lesson_prompt(concept):
str: The lesson prompt for the given concept. str: The lesson prompt for the given concept.
""" """
prompt = ( prompt = (f"generate a comprehensive lesson about {concept} in Markdown"
f"generate a comprehensive lesson about {concept} in Markdown" f" syntax. This should include the definitionof {concept},"
f" syntax. This should include the definitionof {concept}," " its historical background and development, its"
" its historical background and development, its" " applications or uses in differentfields, and notable"
" applications or uses in differentfields, and notable" f" events or facts related to {concept}.")
f" events or facts related to {concept}."
)
return prompt return prompt
def get_report_by_type(report_type): def get_report_by_type(report_type):
report_type_mapping = { report_type_mapping = {
"research_report": generate_report_prompt, "research_report": generate_report_prompt,

@ -2647,18 +2647,7 @@ class Agent:
else: else:
return str(response) return str(response)
async def mcp_execution_flow(self, tool_call):
"""Execute MCP tool call flow"""
try:
result = await execute_mcp_tool(
url=self.mcp_servers[0]["url"],
parameters=tool_call,
output_type="str",
)
return result
except Exception as e:
logger.error(f"Error executing tool call: {e}")
return f"Error executing tool call: {e}"
def sentiment_and_evaluator(self, response: str): def sentiment_and_evaluator(self, response: str):
if self.evaluator: if self.evaluator:
@ -2689,3 +2678,136 @@ class Agent:
role="Output Cleaner", role="Output Cleaner",
content=response, content=response,
) )
async def amcp_execution_flow(self, response: str) -> str:
    """Async implementation of the MCP execution flow.

    Parses the LLM response as JSON tool call(s) when possible, falling
    back to treating it as natural language.  Each call is executed
    against the configured MCP servers, outcomes are recorded in
    short-term memory, and a formatted result string is returned.

    Args:
        response (str): The response from the LLM containing tool calls
            or natural language.

    Returns:
        str: The result of executing the tool calls with preserved
            formatting, or an error summary.
    """
    try:
        # Import here to avoid circular imports; hoisted out of the
        # per-call loop so it runs once per invocation.
        from swarms.tools.mcp_integration import abatch_mcp_flow

        # Prefer structured JSON tool calls; otherwise pass raw text through.
        try:
            tool_calls = json.loads(response)
            logger.debug(f"Successfully parsed response as JSON: {tool_calls}")
        except json.JSONDecodeError:
            tool_calls = [response]
            logger.debug("Could not parse response as JSON, treating as natural language")

        # Normalize a single tool call (dict) into a one-element list.
        if isinstance(tool_calls, dict):
            tool_calls = [tool_calls]

        results = []
        errors = []
        logger.debug(
            f"Executing {len(tool_calls)} tool calls against {len(self.mcp_servers)} MCP servers"
        )

        for tool_call in tool_calls:
            try:
                logger.debug(f"Executing tool call: {tool_call}")
                # Execute the tool call against all MCP servers.
                result = await abatch_mcp_flow(self.mcp_servers, tool_call)
                if result:
                    logger.debug(f"Got result from MCP servers: {result}")
                    results.extend(result)
                    # Record the successful result with context.
                    self.short_memory.add(
                        role="assistant",
                        content=f"Tool execution result: {result}"
                    )
                else:
                    error_msg = "No result from tool execution"
                    errors.append(error_msg)
                    logger.debug(error_msg)
                    self.short_memory.add(role="error", content=error_msg)
            except Exception as e:
                error_msg = f"Error executing tool call: {str(e)}"
                errors.append(error_msg)
                logger.error(error_msg)
                self.short_memory.add(role="error", content=error_msg)

        # Format the final response.
        if results:
            if len(results) == 1:
                # Single result: return as-is to preserve formatting.
                return results[0]
            # Multiple results: number them for context.
            return "\n".join(
                f"Result {i}: {result}" for i, result in enumerate(results, 1)
            )
        if errors:
            if len(errors) == 1:
                return errors[0]
            return "Multiple errors occurred:\n" + "\n".join(
                f"- {err}" for err in errors
            )
        return "No results or errors returned"
    except Exception as e:
        error_msg = f"Error in MCP execution flow: {str(e)}"
        logger.error(error_msg)
        self.short_memory.add(role="error", content=error_msg)
        return error_msg
def mcp_execution_flow(self, response: str) -> str:
    """Synchronous wrapper for the async MCP execution flow.

    Runs ``amcp_execution_flow`` to completion.  When no event loop is
    running in this thread, the coroutine is executed with
    ``asyncio.run``.  When a loop *is* already running, blocking on it
    from the same thread would deadlock, so the coroutine is executed on
    a fresh event loop in a worker thread instead.

    Args:
        response (str): The response from the LLM containing tool calls
            or natural language.

    Returns:
        str: The result of executing the tool calls with preserved
            formatting, or an error message.
    """
    try:
        # Probe for a running loop without the deprecated
        # asyncio.get_event_loop() (which warns/errors when no loop
        # exists in recent Python versions).
        try:
            asyncio.get_running_loop()
            loop_running = True
        except RuntimeError:
            loop_running = False

        if loop_running:
            # Calling future.result() on the current running loop from
            # this thread would deadlock, so run the coroutine on a
            # separate thread with its own event loop.
            import concurrent.futures

            logger.debug("Event loop already running; executing MCP flow on a worker thread")
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(
                    asyncio.run, self.amcp_execution_flow(response)
                )
                return future.result(timeout=30)  # timeout prevents hanging

        logger.debug("No running event loop; executing MCP flow with asyncio.run")
        return asyncio.run(self.amcp_execution_flow(response))
    except Exception as e:
        error_msg = f"Error in MCP execution flow wrapper: {str(e)}"
        logger.error(error_msg)
        return error_msg

@ -1,255 +1,320 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import asyncio import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Literal from typing import Any, Dict, List, Optional, Literal, Union
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from loguru import logger from loguru import logger
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage from mcp.types import CallToolResult, JSONRPCMessage
from swarms.utils.any_to_str import any_to_str from swarms.utils.any_to_str import any_to_str
class MCPServer(abc.ABC): class MCPServer(abc.ABC):
"""Base class for Model Context Protocol servers.""" """Base class for Model Context Protocol servers."""
@abc.abstractmethod @abc.abstractmethod
async def connect(self) -> None: async def connect(self) -> None:
"""Establish connection to the MCP server.""" """Establish connection to the MCP server."""
pass pass
@property @property
@abc.abstractmethod @abc.abstractmethod
def name(self) -> str: def name(self) -> str:
"""Human-readable server name.""" """Human-readable server name."""
pass pass
@abc.abstractmethod @abc.abstractmethod
async def cleanup(self) -> None: async def cleanup(self) -> None:
"""Clean up resources and close connection.""" """Clean up resources and close connection."""
pass pass
@abc.abstractmethod @abc.abstractmethod
async def list_tools(self) -> List[MCPTool]: async def list_tools(self) -> List[MCPTool]:
"""List available MCP tools on the server.""" """List available MCP tools on the server."""
pass pass
@abc.abstractmethod @abc.abstractmethod
async def call_tool( async def call_tool(
self, tool_name: str, arguments: Dict[str, Any] | None self, tool_name: str, arguments: Dict[str, Any] | None
) -> CallToolResult: ) -> CallToolResult:
"""Invoke a tool by name with provided arguments.""" """Invoke a tool by name with provided arguments."""
pass pass
class _MCPServerWithClientSession(MCPServer, abc.ABC): class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Mixin providing ClientSession-based MCP communication.""" """Mixin providing ClientSession-based MCP communication."""
def __init__(self, cache_tools_list: bool = False): def __init__(self, cache_tools_list: bool = False):
self.session: Optional[ClientSession] = None self.session: Optional[ClientSession] = None
self.exit_stack: AsyncExitStack = AsyncExitStack() self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock = asyncio.Lock() self._cleanup_lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list self.cache_tools_list = cache_tools_list
self._cache_dirty = True self._cache_dirty = True
self._tools_list: Optional[List[MCPTool]] = None self._tools_list: Optional[List[MCPTool]] = None
@abc.abstractmethod @abc.abstractmethod
def create_streams( def create_streams(
self self
) -> AbstractAsyncContextManager[ ) -> AbstractAsyncContextManager[
tuple[ tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage], MemoryObjectSendStream[JSONRPCMessage],
] ]
]: ]:
"""Supply the read/write streams for the MCP transport.""" """Supply the read/write streams for the MCP transport."""
pass pass
async def __aenter__(self) -> MCPServer: async def __aenter__(self) -> MCPServer:
await self.connect() await self.connect()
return self # type: ignore return self # type: ignore
async def __aexit__(self, exc_type, exc_value, tb) -> None: async def __aexit__(self, exc_type, exc_value, tb) -> None:
await self.cleanup() await self.cleanup()
async def connect(self) -> None: async def connect(self) -> None:
"""Initialize transport and ClientSession.""" """Initialize transport and ClientSession."""
try: try:
transport = await self.exit_stack.enter_async_context( transport = await self.exit_stack.enter_async_context(
self.create_streams() self.create_streams()
) )
read, write = transport read, write = transport
session = await self.exit_stack.enter_async_context( session = await self.exit_stack.enter_async_context(
ClientSession(read, write) ClientSession(read, write)
) )
await session.initialize() await session.initialize()
self.session = session self.session = session
except Exception as e: except Exception as e:
logger.error(f"Error initializing MCP server: {e}") logger.error(f"Error initializing MCP server: {e}")
await self.cleanup() await self.cleanup()
raise raise
async def cleanup(self) -> None: async def cleanup(self) -> None:
"""Close session and transport.""" """Close session and transport."""
async with self._cleanup_lock: async with self._cleanup_lock:
try: try:
await self.exit_stack.aclose() await self.exit_stack.aclose()
except Exception as e: except Exception as e:
logger.error(f"Error during cleanup: {e}") logger.error(f"Error during cleanup: {e}")
finally: finally:
self.session = None self.session = None
async def list_tools(self) -> List[MCPTool]: async def list_tools(self) -> List[MCPTool]:
if not self.session: if not self.session:
raise RuntimeError("Server not connected. Call connect() first.") raise RuntimeError("Server not connected. Call connect() first.")
if self.cache_tools_list and not self._cache_dirty and self._tools_list: if self.cache_tools_list and not self._cache_dirty and self._tools_list:
return self._tools_list return self._tools_list
self._cache_dirty = False self._cache_dirty = False
self._tools_list = (await self.session.list_tools()).tools self._tools_list = (await self.session.list_tools()).tools
return self._tools_list # type: ignore return self._tools_list # type: ignore
async def call_tool( async def call_tool(
self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None self, tool_name: str | None = None, arguments: Dict[str, Any] | None = None
) -> CallToolResult: ) -> CallToolResult:
if not arguments: if not arguments:
raise ValueError("Arguments dict is required to call a tool") raise ValueError("Arguments dict is required to call a tool")
name = tool_name or arguments.get("tool_name") or arguments.get("name") name = tool_name or arguments.get("tool_name") or arguments.get("name")
if not name: if not name:
raise ValueError("Tool name missing in arguments") raise ValueError("Tool name missing in arguments")
if not self.session: if not self.session:
raise RuntimeError("Server not connected. Call connect() first.") raise RuntimeError("Server not connected. Call connect() first.")
return await self.session.call_tool(name, arguments) return await self.session.call_tool(name, arguments)
class MCPServerStdioParams(TypedDict):
    """Configuration for stdio transport.

    Only ``command`` is required; the remaining fields fall back to the
    defaults applied in ``MCPServerStdio.__init__``.
    """
    command: str  # executable used to launch the MCP server process
    args: NotRequired[List[str]]  # command-line arguments (default: [])
    env: NotRequired[Dict[str, str]]  # environment for the subprocess (default: None)
    cwd: NotRequired[str | Path]  # working directory for the subprocess (default: None)
    encoding: NotRequired[str]  # stream text encoding (default: "utf-8")
    encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]  # codec error policy (default: "strict")
class MCPServerStdio(_MCPServerWithClientSession):
    """MCP server over stdio transport."""

    def __init__(
        self,
        params: MCPServerStdioParams,
        cache_tools_list: bool = False,
        name: Optional[str] = None,
    ):
        """Build the stdio server description.

        Args:
            params: Transport configuration; only "command" is required.
            cache_tools_list: Cache the tool list between list_tools() calls.
            name: Human-readable label; defaults to "stdio:<command>".
        """
        super().__init__(cache_tools_list)
        # Translate the TypedDict into the SDK's parameter object,
        # filling in defaults for every optional field.
        self.params = StdioServerParameters(
            command=params["command"],
            args=params.get("args", []),
            env=params.get("env"),
            cwd=params.get("cwd"),
            encoding=params.get("encoding", "utf-8"),
            encoding_error_handler=params.get(
                "encoding_error_handler", "strict"
            ),
        )
        self._name = name if name else f"stdio:{self.params.command}"

    def create_streams(self) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Open the subprocess-backed receive/send message streams."""
        return stdio_client(self.params)

    @property
    def name(self) -> str:
        """Label identifying this server (used in logs and errors)."""
        return self._name
class MCPServerSseParams(TypedDict):
    """Configuration for HTTP+SSE transport.

    Only ``url`` is required; the remaining fields fall back to the
    defaults applied in ``MCPServerSse.create_streams``.
    """
    url: str  # URL of the MCP SSE endpoint
    headers: NotRequired[Dict[str, str]]  # extra HTTP headers to send
    timeout: NotRequired[float]  # connection timeout in seconds (default: 5)
    sse_read_timeout: NotRequired[float]  # SSE read timeout in seconds (default: 300)
class MCPServerSse(_MCPServerWithClientSession):
    """MCP server over HTTP with SSE transport."""

    def __init__(
        self,
        params: MCPServerSseParams,
        cache_tools_list: bool = False,
        name: Optional[str] = None,
    ):
        """Record the SSE connection settings.

        Args:
            params: Transport configuration; only "url" is required.
            cache_tools_list: Cache the tool list between list_tools() calls.
            name: Human-readable label; defaults to "sse:<url>".
        """
        super().__init__(cache_tools_list)
        self.params = params
        self._name = name if name else f"sse:{params['url']}"

    def create_streams(self) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Open the SSE-backed receive/send message streams."""
        config = self.params
        return sse_client(
            url=config["url"],
            headers=config.get("headers"),
            timeout=config.get("timeout", 5),
            sse_read_timeout=config.get("sse_read_timeout", 300),
        )

    @property
    def name(self) -> str:
        """Label identifying this server (used in logs and errors)."""
        return self._name
async def call_tool_fast(
    server: MCPServerSse, payload: Dict[str, Any] | str
) -> Any:
    """Connect to *server*, call one tool, and always clean up.

    Args:
        server: The SSE server wrapper to use (connected here, cleaned
            up in the finally block even on failure).
        payload: Tool-call arguments as a dict, or a JSON-encoded string
            of such a dict.

    Returns:
        The result object produced by server.call_tool().

    Raises:
        ValueError: If the payload cannot be interpreted as an
            arguments dict.
    """
    import json  # local import: only needed for string payloads

    try:
        await server.connect()
        if isinstance(payload, dict):
            arguments = payload
        else:
            # BUG FIX: a string payload used to be silently replaced
            # with None, which made call_tool() raise a misleading
            # "Arguments dict is required" error. Decode it instead.
            decoded = json.loads(payload)
            if not isinstance(decoded, dict):
                raise ValueError("String payload must encode a JSON object")
            arguments = decoded
        return await server.call_tool(arguments=arguments)
    finally:
        await server.cleanup()
async def mcp_flow_get_tool_schema(
    params: MCPServerSseParams,
) -> Any:
    """Fetch the tool schema (tool list) from one MCP server.

    The async context manager handles connecting to and cleaning up
    the server.
    """
    async with MCPServerSse(params) as server:
        return await server.list_tools()
async def mcp_flow(
    params: MCPServerSseParams,
    function_call: Dict[str, Any] | str,
) -> Any:
    """Call one tool on the MCP server described by *params*.

    The server lifetime is managed by the context manager; the actual
    invocation (including its own connect/cleanup) is delegated to
    call_tool_fast().
    """
    async with MCPServerSse(params) as server:
        result = await call_tool_fast(server, function_call)
    return result
async def _call_one_server(
    params: MCPServerSseParams, payload: Dict[str, Any] | str
) -> Any:
    """Call a single MCP server, guaranteeing cleanup on exit.

    Args:
        params: SSE connection settings for the server.
        payload: Tool-call arguments as a dict, or a JSON-encoded string
            of such a dict.

    Returns:
        The result object produced by server.call_tool().

    Raises:
        ValueError: If the payload cannot be interpreted as an
            arguments dict.
    """
    import json  # local import: only needed for string payloads

    server = MCPServerSse(params)
    try:
        await server.connect()
        if isinstance(payload, dict):
            arguments = payload
        else:
            # BUG FIX: string payloads were silently turned into None,
            # which always failed inside call_tool(). Decode them
            # instead (same handling as call_tool_fast).
            decoded = json.loads(payload)
            if not isinstance(decoded, dict):
                raise ValueError("String payload must encode a JSON object")
            arguments = decoded
        return await server.call_tool(arguments=arguments)
    finally:
        await server.cleanup()
async def abatch_mcp_flow(
    params: List[MCPServerSseParams], payload: Dict[str, Any] | str
) -> List[Any]:
    """Execute a batch of MCP calls concurrently.

    Args:
        params (List[MCPServerSseParams]): List of MCP server configurations
        payload (Dict[str, Any] | str): The payload to send to each server

    Returns:
        List[Any]: One entry per server, in order — the server's result
        on success, or an "Error in batch operation: ..." string for
        that server on failure.
    """
    if not params:
        logger.warning("No MCP servers provided for batch operation")
        return []
    # BUG FIX: the previous version used a bare gather(), so the first
    # failing server cancelled the whole batch and every result was
    # discarded despite the stated "return partial results" intent.
    # return_exceptions=True keeps successes and surfaces each failure.
    results = await asyncio.gather(
        *[_call_one_server(p, payload) for p in params],
        return_exceptions=True,
    )
    processed: List[Any] = []
    for result in results:
        if isinstance(result, BaseException):
            logger.error(f"Error in abatch_mcp_flow: {result}")
            processed.append(f"Error in batch operation: {str(result)}")
        else:
            processed.append(result)
    return processed
def batch_mcp_flow(
    params: List[MCPServerSseParams], payload: Dict[str, Any] | str
) -> List[Any]:
    """Sync wrapper for batch MCP operations.

    Safe to call both from plain sync code and from code that is
    already running inside an event loop: in the latter case the batch
    runs on a private event loop in a worker thread.

    Args:
        params (List[MCPServerSseParams]): List of MCP server configurations
        payload (Dict[str, Any] | str): The payload to send to each server

    Returns:
        List[Any]: Results from all MCP servers, or a one-element error
        list if the batch machinery itself fails.
    """
    import concurrent.futures  # local import: only needed on the nested-loop path

    if not params:
        logger.warning("No MCP servers provided for batch operation")
        return []
    try:
        try:
            asyncio.get_running_loop()
            in_running_loop = True
        except RuntimeError:
            in_running_loop = False
        if in_running_loop:
            # BUG FIX: the old code scheduled the coroutine on the
            # *current* thread's loop via run_coroutine_threadsafe()
            # and then blocked on future.result() — a guaranteed
            # deadlock, since the blocked thread is the one that must
            # run the coroutine. Run the batch on a fresh event loop in
            # a worker thread instead.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                future = pool.submit(
                    asyncio.run, abatch_mcp_flow(params, payload)
                )
                return future.result(timeout=30)  # timeout prevents hanging
        # No loop running on this thread: asyncio.run manages one for us
        # (also replaces the deprecated get_event_loop() dance).
        return asyncio.run(abatch_mcp_flow(params, payload))
    except Exception as e:
        logger.error(f"Error in batch_mcp_flow: {e}")
        return [f"Error in batch operation: {str(e)}"]
Loading…
Cancel
Save