From 89fc8c76098ba016d80594e3b19f8bdff514a7b3 Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Mon, 24 Mar 2025 11:00:36 -0700 Subject: [PATCH] cleanup --- docs/swarms_cloud/swarms_api.md | 721 ++------------------------------ swarms/tools/mcp_integration.py | 554 ++++++++++++++++++++++++ 2 files changed, 581 insertions(+), 694 deletions(-) create mode 100644 swarms/tools/mcp_integration.py diff --git a/docs/swarms_cloud/swarms_api.md b/docs/swarms_cloud/swarms_api.md index 35ab1dfc..f3a60678 100644 --- a/docs/swarms_cloud/swarms_api.md +++ b/docs/swarms_cloud/swarms_api.md @@ -916,692 +916,6 @@ if __name__ == "__main__": print(f" {i+1}. {rec}") ``` -### TypeScript/NodeJS Examples - -#### Financial Fraud Detection (TypeScript) - -This example demonstrates creating a swarm for financial transaction fraud detection. - -```typescript -import axios from 'axios'; -import * as fs from 'fs'; -import * as dotenv from 'dotenv'; - -// Load environment variables -dotenv.config(); - -// API Configuration -const API_BASE_URL = "https://api.swarms.world"; -const API_KEY = process.env.SWARMS_API_KEY; - -// Define interfaces for type safety -interface TransactionData { - transaction_id: string; - amount: number; - timestamp: string; - merchant: string; - merchant_category: string; - payment_method: string; - location: string; - device_id?: string; - ip_address?: string; -} - -interface UserProfile { - user_id: string; - account_age_days: number; - typical_transaction_amount: number; - typical_merchants: string[]; - typical_locations: string[]; - risk_score: number; -} - -interface FraudDetectionResult { - transaction_id: string; - fraud_score: number; - is_fraudulent: boolean; - risk_factors: string[]; - confidence: number; - recommended_action: string; - explanation: string; -} - -/** - * Detects potential fraud in financial transactions using a specialized agent swarm - * - * @param transactions - Array of transaction data to analyze - * @param userProfile - User profile information for context - * @param historicalPatterns - Description of historical patterns and behaviors - * @returns Promise resolving to fraud detection results - */ -async function detectTransactionFraud( - transactions: TransactionData[], - userProfile: UserProfile, - historicalPatterns: string -): Promise { - try { - // Prepare the task with all relevant information - const task = ` - Analyze the following financial transactions for potential fraud. - - USER PROFILE: - ${JSON.stringify(userProfile, null, 2)} - - HISTORICAL PATTERNS: - ${historicalPatterns} - - TRANSACTIONS TO ANALYZE: - ${JSON.stringify(transactions, null, 2)} - - For each transaction, determine the fraud risk score (0-100), whether it's likely fraudulent, - identified risk factors, confidence level, recommended action (allow, flag for review, block), - and a detailed explanation of the analysis. - - Return results for each transaction in a structured format. - `; - - // Define specialized fraud detection agents - const fraudDetectionAgents = [ - { - agent_name: "BehavioralAnalyst", - description: "Analyzes user behavior patterns to identify anomalies", - system_prompt: `You are an expert in behavioral analysis for fraud detection. - Your role is to analyze transaction patterns against historical user behavior - to identify potential anomalies or deviations that may indicate fraud. - - Consider: - - Timing patterns (day of week, time of day) - - Transaction amount patterns - - Merchant category patterns - - Geographic location patterns - - Device usage patterns - - Provide a detailed breakdown of behavioral anomalies with confidence scores.`, - model_name: "gpt-4o", - temperature: 0.2, - role: "analyst", - max_loops: 1 - }, - { - agent_name: "TechnicalFraudDetector", - description: "Analyzes technical indicators of fraud", - system_prompt: `You are a technical fraud detection specialist. - Your role is to analyze technical aspects of transactions to identify - potential indicators of fraud. - - Focus on: - - IP address analysis - - Device ID consistency - - Geolocation feasibility (impossible travel) - - Known fraud patterns - - Technical manipulation markers - - Provide a technical assessment with specific indicators of potential fraud.`, - model_name: "gpt-4o", - temperature: 0.2, - role: "analyst", - max_loops: 1 - }, - { - agent_name: "FinancialPatternAnalyst", - description: "Specializes in financial transaction patterns", - system_prompt: `You are a financial pattern analysis expert specializing in fraud detection. - Your role is to analyze the financial aspects of transactions for fraud indicators. - - Focus on: - - Transaction amount anomalies - - Merchant risk profiles - - Transaction velocity - - Transaction sequence patterns - - Round-number amounts or other suspicious values - - Provide a financial pattern analysis with risk assessment.`, - model_name: "gpt-4o", - temperature: 0.2, - role: "analyst", - max_loops: 1 - }, - { - agent_name: "FraudInvestigator", - description: "Synthesizes all analysis into final fraud determination", - system_prompt: `You are a senior fraud investigator responsible for making the final - determination on transaction fraud risk. - - Your role is to: - 1. Synthesize inputs from behavioral, technical, and financial analyses - 2. Weigh different risk factors appropriately - 3. Calculate an overall fraud risk score (0-100) - 4. Make a clear determination (legitimate, suspicious, fraudulent) - 5. Recommend specific actions (allow, review, block) - 6. Provide a clear, detailed explanation for each determination - - Balance false positives and false negatives appropriately.`, - model_name: "gpt-4o", - temperature: 0.3, - role: "manager", - max_loops: 1 - } - ]; - - // Create the swarm specification - const swarmSpec = { - name: "fraud-detection-swarm", - description: "Financial transaction fraud detection swarm", - agents: fraudDetectionAgents, - max_loops: 2, - swarm_type: "HiearchicalSwarm", - task: task, - return_history: false - }; - - // Execute the swarm - const response = await axios.post( - `${API_BASE_URL}/v1/swarm/completions`, - swarmSpec, - { - headers: { - 'x-api-key': API_KEY, - 'Content-Type': 'application/json' - } - } - ); - - if (response.status === 200) { - console.log(`Fraud detection completed. Cost: ${response.data.metadata.billing_info.total_cost}`); - return response.data.output.fraud_analysis as FraudDetectionResult[]; - } else { - throw new Error(`API request failed with status: ${response.status}`); - } - - } catch (error) { - console.error('Error in fraud detection:', error); - if (axios.isAxiosError(error) && error.response) { - console.error('API response:', error.response.data); - } - throw error; - } -} - -// Usage example -async function main() { - // Sample transaction data - const transactions: TransactionData[] = [ - { - transaction_id: "T-12345-89012", - amount: 1299.99, - timestamp: "2025-03-04T09:23:42Z", - merchant: "ElectronicsPro", - merchant_category: "Electronics", - payment_method: "Credit Card", - location: "San Francisco, CA", - device_id: "D-472910", - ip_address: "192.168.1.127" - }, - { - transaction_id: "T-12345-89013", - amount: 849.50, - timestamp: "2025-03-04T09:45:18Z", - merchant: "LuxuryBrands", - merchant_category: "Fashion", - payment_method: "Credit Card", - location: "Miami, FL", - device_id: "D-891245", - ip_address: "45.23.189.44" - }, - { - transaction_id: "T-12345-89014", - amount: 24.99, - timestamp: "2025-03-04T10:12:33Z", - merchant: "CoffeeDeluxe", - merchant_category: "Restaurant", - payment_method: "Mobile Wallet", - location: "San Francisco, CA", - device_id: "D-472910", - ip_address: "192.168.1.127" - } - ]; - - // Sample user profile - const userProfile: UserProfile = { - user_id: "U-78901-23456", - account_age_days: 487, - typical_transaction_amount: 150.75, - typical_merchants: ["Groceries", "Restaurant", "Retail", "Streaming"], - typical_locations: ["San Francisco, CA", "Oakland, CA"], - risk_score: 12 - }; - - // Sample historical patterns - const historicalPatterns = ` - User typically makes 15-20 transactions per month - Average transaction amount: $50-$200 - Largest previous transaction: $750 (furniture) - No previous transactions over $1000 - Typically shops in San Francisco Bay Area - No international transactions in past 12 months - Usually uses mobile wallet for small purchases (<$50) - Credit card used for larger purchases - Typical activity times: weekdays 8am-10pm PST - No previous purchases in Miami, FL - `; - - try { - const results = await detectTransactionFraud( - transactions, - userProfile, - historicalPatterns - ); - - console.log("Fraud Detection Results:"); - console.log(JSON.stringify(results, null, 2)); - - // Save results to file - fs.writeFileSync( - `fraud_detection_results_${new Date().toISOString().replace(/:/g, '-')}.json`, - JSON.stringify(results, null, 2) - ); - - // Process high-risk transactions - const highRiskTransactions = results.filter(r => r.fraud_score > 75); - if (highRiskTransactions.length > 0) { - console.log(`WARNING: ${highRiskTransactions.length} high-risk transactions detected!`); - // In production, you would trigger alerts, notifications, etc. - } - - } catch (error) { - console.error("Failed to complete fraud detection:", error); - } -} - -main(); -``` - -#### Healthcare Report Generation (TypeScript) - -This example demonstrates creating a swarm for generating comprehensive healthcare reports from patient data. - -```typescript -import axios from 'axios'; -import * as fs from 'fs'; -import * as path from 'path'; -import * as dotenv from 'dotenv'; - -// Load environment variables -dotenv.config(); - -// API Configuration -const API_BASE_URL = "https://api.swarms.world"; -const API_KEY = process.env.SWARMS_API_KEY; - -// Define interfaces -interface PatientData { - patient_id: string; - demographics: { - age: number; - gender: string; - ethnicity?: string; - weight_kg?: number; - height_cm?: number; - }; - vitals: { - blood_pressure?: string; - heart_rate?: number; - respiratory_rate?: number; - temperature_c?: number; - oxygen_saturation?: number; - }; - conditions: string[]; - medications: { - name: string; - dosage: string; - frequency: string; - }[]; - allergies: string[]; - lab_results: { - test_name: string; - value: number | string; - unit: string; - reference_range?: string; - collection_date: string; - }[]; - imaging_results?: { - study_type: string; - body_area: string; - findings: string; - date: string; - }[]; - notes?: string[]; -} - -interface MedicalReport { - summary: string; - assessment: { - primary_diagnosis: string; - secondary_diagnoses: string[]; - clinical_impression: string; - }; - recommendations: string[]; - medication_review: { - current_medications: { - medication: string; - assessment: string; - }[]; - potential_interactions: string[]; - recommended_changes: string[]; - }; - care_plan: { - short_term_goals: string[]; - long_term_goals: string[]; - follow_up: string; - }; - lab_interpretation: { - flagged_results: { - test: string; - value: string; - interpretation: string; - }[]; - trends: string[]; - }; - clinical_reasoning: string; -} - -/** - * Generates a comprehensive clinical report for a patient using a specialized medical agent swarm - * - * @param patientData - Structured patient medical data - * @param clinicalGuidelines - Relevant clinical guidelines to consider - * @param reportType - Type of report to generate (e.g., 'comprehensive', 'follow-up', 'specialist') - * @returns Promise resolving to structured medical report - */ -async function generateClinicalReport( - patientData: PatientData, - clinicalGuidelines: string, - reportType: string -): Promise { - try { - // Prepare detailed task description - const task = ` - Generate a comprehensive clinical report for the following patient: - - PATIENT DATA: - ${JSON.stringify(patientData, null, 2)} - - CLINICAL GUIDELINES TO CONSIDER: - ${clinicalGuidelines} - - REPORT TYPE: - ${reportType} - - Analyze all aspects of the patient's health status including symptoms, lab results, - medications, conditions, and other relevant factors. Provide a detailed clinical assessment, - evidence-based recommendations, medication review with potential interactions, - comprehensive care plan, and clear follow-up instructions. - - Structure the report to include a concise executive summary, detailed assessment with - clinical reasoning, specific actionable recommendations, and a clear care plan. - - Ensure all interpretations and recommendations align with current clinical guidelines - and evidence-based medicine. - `; - - // Use Auto Swarm Builder for this complex medical task - const swarmSpec = { - name: "clinical-report-generator", - description: "Medical report generation and analysis swarm", - swarm_type: "AutoSwarmBuilder", - task: task, - max_loops: 3, - return_history: false - }; - - // Execute the swarm - console.log("Generating clinical report..."); - const response = await axios.post( - `${API_BASE_URL}/v1/swarm/completions`, - swarmSpec, - { - headers: { - 'x-api-key': API_KEY, - 'Content-Type': 'application/json' - } - } - ); - - if (response.status === 200) { - const executionTime = response.data.metadata.execution_time_seconds; - const cost = response.data.metadata.billing_info.total_cost; - const numAgents = response.data.metadata.num_agents; - - console.log(`Report generation completed in ${executionTime.toFixed(2)} seconds`); - console.log(`Used ${numAgents} specialized medical agents`); - console.log(`Total cost: ${cost.toFixed(4)}`); - - return response.data.output.medical_report as MedicalReport; - } else { - throw new Error(`API request failed with status: ${response.status}`); - } - - } catch (error) { - console.error('Error generating clinical report:', error); - if (axios.isAxiosError(error) && error.response) { - console.error('API response:', error.response.data); - } - throw error; - } -} - -/** - * Schedules regular report generation for a patient - */ -async function scheduleRecurringReports( - patientId: string, - reportType: string, - intervalDays: number -): Promise { - try { - // Schedule the next report - const nextReportDate = new Date(); - nextReportDate.setDate(nextReportDate.getDate() + intervalDays); - - const scheduleSpec = { - name: `${patientId}-${reportType}-report`, - description: `Scheduled ${reportType} report for patient ${patientId}`, - task: `Generate a ${reportType} clinical report for patient ${patientId} following standard protocols. Retrieve the most recent patient data and produce a comprehensive clinical assessment with recommendations.`, - schedule: { - scheduled_time: nextReportDate.toISOString(), - timezone: "UTC" - } - }; - - const response = await axios.post( - `${API_BASE_URL}/v1/swarm/schedule`, - scheduleSpec, - { - headers: { - 'x-api-key': API_KEY, - 'Content-Type': 'application/json' - } - } - ); - - if (response.status === 200) { - console.log(`Successfully scheduled next report for ${nextReportDate.toISOString()}`); - console.log(`Job ID: ${response.data.job_id}`); - } else { - throw new Error(`Failed to schedule report: ${response.status}`); - } - - } catch (error) { - console.error('Error scheduling report:', error); - if (axios.isAxiosError(error) && error.response) { - console.error('API response:', error.response.data); - } - } -} - -// Usage example -async function main() { - // Sample patient data - const patientData: PatientData = { - patient_id: "P-12345-67890", - demographics: { - age: 58, - gender: "Male", - ethnicity: "Caucasian", - weight_kg: 92.5, - height_cm: 178 - }, - vitals: { - blood_pressure: "148/92", - heart_rate: 82, - respiratory_rate: 16, - temperature_c: 36.8, - oxygen_saturation: 96 - }, - conditions: [ - "Type 2 Diabetes Mellitus", - "Hypertension", - "Coronary Artery Disease", - "Hyperlipidemia", - "Obesity" - ], - medications: [ - { - name: "Metformin", - dosage: "1000mg", - frequency: "BID" - }, - { - name: "Lisinopril", - dosage: "20mg", - frequency: "QD" - }, - { - name: "Atorvastatin", - dosage: "40mg", - frequency: "QD" - }, - { - name: "Aspirin", - dosage: "81mg", - frequency: "QD" - } - ], - allergies: ["Penicillin", "Sulfa drugs"], - lab_results: [ - { - test_name: "HbA1c", - value: 8.2, - unit: "%", - reference_range: "<7.0", - collection_date: "2025-02-15" - }, - { - test_name: "Fasting Glucose", - value: 165, - unit: "mg/dL", - reference_range: "70-100", - collection_date: "2025-02-15" - }, - { - test_name: "LDL Cholesterol", - value: 118, - unit: "mg/dL", - reference_range: "<100", - collection_date: "2025-02-15" - }, - { - test_name: "HDL Cholesterol", - value: 38, - unit: "mg/dL", - reference_range: ">40", - collection_date: "2025-02-15" - }, - { - test_name: "eGFR", - value: 68, - unit: "mL/min/1.73m²", - reference_range: ">90", - collection_date: "2025-02-15" - } - ], - imaging_results: [ - { - study_type: "Cardiac CT Angiography", - body_area: "Heart and coronary arteries", - findings: "Moderate calcification in LAD. 50-70% stenosis in proximal LAD. No significant stenosis in other coronary arteries.", - date: "2025-01-10" - } - ], - notes: [ - "Patient reports increased fatigue over the past month", - "Complains of occasional chest discomfort with exertion", - "Currently following low-carb diet but admits to poor adherence", - "Exercise limited by knee pain" - ] - }; - - // Clinical guidelines to consider - const clinicalGuidelines = ` - ADA 2025 Guidelines for Type 2 Diabetes: - - HbA1c target <7.0% for most adults - - Consider less stringent targets (e.g., <8.0%) for patients with multiple comorbidities - - First-line therapy: Metformin + lifestyle modifications - - For patients with ASCVD: consider GLP-1 RA or SGLT2 inhibitor with proven CV benefit - - ACC/AHA 2024 Hypertension Guidelines: - - BP target <130/80 mmHg for patients with diabetes and/or CAD - - First-line: ACE inhibitor or ARB for patients with diabetes - - ACC/AHA 2024 Cholesterol Guidelines: - - LDL-C target <70 mg/dL for very high-risk ASCVD - - Consider adding ezetimibe or PCSK9 inhibitor for very high-risk patients not at goal - `; - - try { - // Generate the report - const report = await generateClinicalReport( - patientData, - clinicalGuidelines, - "comprehensive" - ); - - // Save report to file - const timestamp = new Date().toISOString().replace(/:/g, '-'); - const outputDir = './reports'; - - if (!fs.existsSync(outputDir)) { - fs.mkdirSync(outputDir); - } - - fs.writeFileSync( - path.join(outputDir, `clinical_report_${patientData.patient_id}_${timestamp}.json`), - JSON.stringify(report, null, 2) - ); - - // Display executive summary - console.log("\nREPORT SUMMARY:"); - console.log(report.summary); - - console.log("\nPRIMARY DIAGNOSIS:"); - console.log(report.assessment.primary_diagnosis); - - console.log("\nKEY RECOMMENDATIONS:"); - report.recommendations.forEach((rec, i) => { - console.log(` ${i+1}. ${rec}`); - }); - - // Schedule the next report in 90 days - await scheduleRecurringReports(patientData.patient_id, "follow-up", 90); - - } catch (error) { - console.error("Failed to complete report generation:", error); - } -} - -main(); -``` - ## Error Handling The Swarms API follows standard HTTP status codes for error responses: @@ -1632,8 +946,12 @@ The API enforces a rate limit of 100 requests per 60-second window. When exceede The API uses a credit-based billing system with costs calculated based on: 1. **Agent Count**: Base cost per agent + + 2. **Input Tokens**: Cost based on the size of input data and prompts + 3. **Output Tokens**: Cost based on the length of generated responses + 4. **Time of Day**: Reduced rates during nighttime hours (8 PM to 6 AM PT) Cost information is included in each response's metadata for transparency and forecasting. @@ -1641,26 +959,41 @@ Cost information is included in each response's metadata for transparency and fo ## Best Practices 1. **Task Description** + - Provide detailed, specific task descriptions + - Include all necessary context and constraints + - Structure complex inputs for easier processing 2. **Agent Configuration** + - For simple tasks, use `AutoSwarmBuilder` to automatically generate optimal agents + - For complex or specialized tasks, manually define agents with specific expertise + - Use appropriate `swarm_type` for your workflow pattern 3. **Production Implementation** - - Implement robust error handling and retries - - Log API responses for debugging and auditing - - Monitor costs closely during development and testing - - Use scheduled jobs for recurring tasks instead of continuous polling + + - Implement robust error handling and retries + + - Log API responses for debugging and auditing + + - Monitor costs closely during development and testing + + - Use scheduled jobs for recurring tasks instead of continuous polling 4. **Cost Optimization** - - Batch related tasks when possible - - Schedule non-urgent tasks during discount hours - - Carefully scope task descriptions to reduce token usage - - Cache results when appropriate + + - Batch related tasks when possible + + - Schedule non-urgent tasks during discount hours + + - Carefully scope task descriptions to reduce token usage + + - Cache results when appropriate + ## Support diff --git a/swarms/tools/mcp_integration.py b/swarms/tools/mcp_integration.py new file mode 100644 index 00000000..6c079890 --- /dev/null +++ b/swarms/tools/mcp_integration.py @@ -0,0 +1,554 @@ +from contextlib import AsyncExitStack +from types import TracebackType +from typing import ( + Any, + Callable, + Coroutine, + List, + Literal, + Optional, + TypedDict, + cast, +) + +from mcp import ClientSession, StdioServerParameters +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.types import ( + CallToolResult, + EmbeddedResource, + ImageContent, + PromptMessage, + TextContent, +) +from mcp.types import ( + Tool as MCPTool, +) + + +def convert_mcp_prompt_message_to_message( + message: PromptMessage, +) -> str: + """Convert an MCP prompt message to a string message. + + Args: + message: MCP prompt message to convert + + Returns: + a string message + """ + if message.content.type == "text": + if message.role == "user": + return str(message.content.text) + elif message.role == "assistant": + return str( + message.content.text + ) # Fixed attribute name from str to text + else: + raise ValueError( + f"Unsupported prompt message role: {message.role}" + ) + + raise ValueError( + f"Unsupported prompt message content type: {message.content.type}" + ) + + +async def load_mcp_prompt( + session: ClientSession, + name: str, + arguments: Optional[dict[str, Any]] = None, +) -> List[str]: + """Load MCP prompt and convert to messages.""" + response = await session.get_prompt(name, arguments) + + return [ + convert_mcp_prompt_message_to_message(message) + for message in response.messages + ] + + +DEFAULT_ENCODING = "utf-8" +DEFAULT_ENCODING_ERROR_HANDLER = "strict" + +DEFAULT_HTTP_TIMEOUT = 5 +DEFAULT_SSE_READ_TIMEOUT = 60 * 5 + + +class StdioConnection(TypedDict): + transport: Literal["stdio"] + + command: str + """The executable to run to start the server.""" + + args: list[str] + """Command line arguments to pass to the executable.""" + + env: dict[str, str] | None + """The environment to use when spawning the process.""" + + encoding: str + """The text encoding used when sending/receiving messages to the server.""" + + encoding_error_handler: Literal["strict", "ignore", "replace"] + """ + The text encoding error handler. + + See https://docs.python.org/3/library/codecs.html#codec-base-classes for + explanations of possible values + """ + + +class SSEConnection(TypedDict): + transport: Literal["sse"] + + url: str + """The URL of the SSE endpoint to connect to.""" + + headers: dict[str, Any] | None + """HTTP headers to send to the SSE endpoint""" + + timeout: float + """HTTP timeout""" + + sse_read_timeout: float + """SSE read timeout""" + + +NonTextContent = ImageContent | EmbeddedResource + + +def _convert_call_tool_result( + call_tool_result: CallToolResult, +) -> tuple[str | list[str], list[NonTextContent] | None]: + text_contents: list[TextContent] = [] + non_text_contents = [] + for content in call_tool_result.content: + if isinstance(content, TextContent): + text_contents.append(content) + else: + non_text_contents.append(content) + + tool_content: str | list[str] = [ + content.text for content in text_contents + ] + if len(text_contents) == 1: + tool_content = tool_content[0] + + if call_tool_result.isError: + raise ValueError("Error calling tool") + + return tool_content, non_text_contents or None + + +def convert_mcp_tool_to_function( + session: ClientSession, + tool: MCPTool, +) -> Callable[ + ..., + Coroutine[ + Any, Any, tuple[str | list[str], list[NonTextContent] | None] + ], +]: + """Convert an MCP tool to a callable function. + + NOTE: this tool can be executed only in a context of an active MCP client session. + + Args: + session: MCP client session + tool: MCP tool to convert + + Returns: + a callable function + """ + + async def call_tool( + **arguments: dict[str, Any], + ) -> tuple[str | list[str], list[NonTextContent] | None]: + """Execute the tool with the given arguments.""" + call_tool_result = await session.call_tool( + tool.name, arguments + ) + return _convert_call_tool_result(call_tool_result) + + # Add metadata as attributes to the function + call_tool.__name__ = tool.name + call_tool.__doc__ = tool.description or "" + call_tool.schema = tool.inputSchema + + return call_tool + + +async def load_mcp_tools(session: ClientSession) -> list[Callable]: + """Load all available MCP tools and convert them to callable functions.""" + tools = await session.list_tools() + return [ + convert_mcp_tool_to_function(session, tool) + for tool in tools.tools + ] + + +class MultiServerMCPClient: + """Client for connecting to multiple MCP servers and loading tools from them.""" + + def __init__( + self, + connections: dict[ + str, StdioConnection | SSEConnection + ] = None, + ) -> None: + """Initialize a MultiServerMCPClient with MCP servers connections. + + Args: + connections: A dictionary mapping server names to connection configurations. + Each configuration can be either a StdioConnection or SSEConnection. + If None, no initial connections are established. + + Example: + + ```python + async with MultiServerMCPClient( + { + "math": { + "command": "python", + # Make sure to update to the full absolute path to your math_server.py file + "args": ["/path/to/math_server.py"], + "transport": "stdio", + }, + "weather": { + # make sure you start your weather server on port 8000 + "url": "http://localhost:8000/sse", + "transport": "sse", + } + } + ) as client: + all_tools = client.get_tools() + ... + ``` + """ + self.connections = connections + self.exit_stack = AsyncExitStack() + self.sessions: dict[str, ClientSession] = {} + self.server_name_to_tools: dict[str, list[Callable]] = {} + + async def _initialize_session_and_load_tools( + self, server_name: str, session: ClientSession + ) -> None: + """Initialize a session and load tools from it. + + Args: + server_name: Name to identify this server connection + session: The ClientSession to initialize + """ + # Initialize the session + await session.initialize() + self.sessions[server_name] = session + + # Load tools from this server + server_tools = await load_mcp_tools(session) + self.server_name_to_tools[server_name] = server_tools + + async def connect_to_server( + self, + server_name: str, + *, + transport: Literal["stdio", "sse"] = "stdio", + **kwargs, + ) -> None: + """Connect to an MCP server using either stdio or SSE. + + This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse + based on the provided transport parameter. + + Args: + server_name: Name to identify this server connection + transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio" + **kwargs: Additional arguments to pass to the specific connection method + + Raises: + ValueError: If transport is not recognized + ValueError: If required parameters for the specified transport are missing + """ + if transport == "sse": + if "url" not in kwargs: + raise ValueError( + "'url' parameter is required for SSE connection" + ) + await self.connect_to_server_via_sse( + server_name, + url=kwargs["url"], + headers=kwargs.get("headers"), + timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT), + sse_read_timeout=kwargs.get( + "sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT + ), + ) + elif transport == "stdio": + if "command" not in kwargs: + raise ValueError( + "'command' parameter is required for stdio connection" + ) + if "args" not in kwargs: + raise ValueError( + "'args' parameter is required for stdio connection" + ) + await self.connect_to_server_via_stdio( + server_name, + command=kwargs["command"], + args=kwargs["args"], + env=kwargs.get("env"), + encoding=kwargs.get("encoding", DEFAULT_ENCODING), + encoding_error_handler=kwargs.get( + "encoding_error_handler", + DEFAULT_ENCODING_ERROR_HANDLER, + ), + ) + else: + raise ValueError( + f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'" + ) + + async def connect_to_server_via_stdio( + self, + server_name: str, + *, + command: str, + args: list[str], + env: dict[str, str] | None = None, + encoding: str = DEFAULT_ENCODING, + encoding_error_handler: Literal[ + "strict", "ignore", "replace" + ] = DEFAULT_ENCODING_ERROR_HANDLER, + ) -> None: + """Connect to a specific MCP server using stdio + + Args: + server_name: Name to identify this server connection + command: Command to execute + args: Arguments for the command + env: Environment variables for the command + encoding: Character encoding + encoding_error_handler: How to handle encoding errors + """ + server_params = StdioServerParameters( + command=command, + args=args, + env=env, + encoding=encoding, + encoding_error_handler=encoding_error_handler, + ) + + # Create and store the connection + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(server_params) + ) + read, write = stdio_transport + session = cast( + ClientSession, + await self.exit_stack.enter_async_context( + ClientSession(read, write) + ), + ) + + await self._initialize_session_and_load_tools( + server_name, session + ) + + async def connect_to_server_via_sse( + self, + server_name: str, + *, + url: str, + headers: dict[str, Any] | None = None, + timeout: float = DEFAULT_HTTP_TIMEOUT, + sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT, + ) -> None: + """Connect to a specific MCP server using SSE + + Args: + server_name: Name to identify this server connection + url: URL of the SSE server + headers: HTTP headers to send to the SSE endpoint + timeout: HTTP timeout + sse_read_timeout: SSE read timeout + """ + # Create and store the connection + sse_transport = await self.exit_stack.enter_async_context( + sse_client(url, headers, timeout, sse_read_timeout) + ) + read, write = sse_transport + session = cast( + ClientSession, + await self.exit_stack.enter_async_context( + ClientSession(read, write) + ), + ) + + await self._initialize_session_and_load_tools( + server_name, session + ) + + def get_tools(self) -> list[Callable]: + """Get a list of all tools from all connected servers.""" + all_tools: list[Callable] = [] + for server_tools in self.server_name_to_tools.values(): + all_tools.extend(server_tools) + return all_tools + + async def get_prompt( + self, + server_name: str, + prompt_name: str, + arguments: Optional[dict[str, Any]] = None, + ) -> List[str]: + """Get a prompt from a given MCP server.""" + session = self.sessions[server_name] + return await load_mcp_prompt(session, prompt_name, arguments) + + async def __aenter__(self) -> "MultiServerMCPClient": + try: + connections = self.connections or {} + for server_name, connection in connections.items(): + connection_dict = connection.copy() + transport = connection_dict.pop("transport") + if transport == "stdio": + await self.connect_to_server_via_stdio( + server_name, **connection_dict + ) + elif transport == "sse": + await self.connect_to_server_via_sse( + server_name, **connection_dict + ) + else: + raise ValueError( + f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'" + ) + return self + except Exception: + await self.exit_stack.aclose() + raise + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.exit_stack.aclose() + + +#!/usr/bin/env python3 +import asyncio +import os +import json +from typing import List, Any, Callable + +# # Import our MCP client module +# from mcp_client import MultiServerMCPClient + + +async def main(): + """Test script for demonstrating MCP client usage.""" + print("Starting MCP Client test...") + + # Create a connection to multiple MCP servers + # You'll need to update these paths to match your setup + async with MultiServerMCPClient( + { + "math": { + "transport": "stdio", + "command": "python", + "args": ["/path/to/math_server.py"], + "env": {"DEBUG": "1"}, + }, + "search": { + "transport": "sse", + "url": "http://localhost:8000/sse", + "headers": { + "Authorization": f"Bearer {os.environ.get('API_KEY', '')}" + }, + }, + } + ) as client: + # Get all available tools + tools = client.get_tools() + print(f"Found {len(tools)} tools across all servers") + + # Print tool information + for i, tool in enumerate(tools): + print(f"\nTool {i+1}: {tool.__name__}") + print(f" Description: {tool.__doc__}") + if hasattr(tool, "schema") and tool.schema: + print( + f" Schema: {json.dumps(tool.schema, indent=2)[:100]}..." + ) + + # Example: Use a specific tool if available + calculator_tool = next( + (t for t in tools if t.__name__ == "calculator"), None + ) + if calculator_tool: + print("\n\nTesting calculator tool:") + try: + # Call the tool as an async function + result, artifacts = await calculator_tool( + expression="2 + 2 * 3" + ) + print(f" Calculator result: {result}") + if artifacts: + print( + f" With {len(artifacts)} additional artifacts" + ) + except Exception as e: + print(f" Error using calculator: {e}") + + # Example: Load a prompt from a server + try: + print("\n\nTesting prompt loading:") + prompt_messages = await client.get_prompt( + "math", + "calculation_introduction", + {"user_name": "Test User"}, + ) + print( + f" Loaded prompt with {len(prompt_messages)} messages:" + ) + for i, msg in enumerate(prompt_messages): + print(f" Message {i+1}: {msg[:50]}...") + except Exception as e: + print(f" Error loading prompt: {e}") + + +async def create_custom_tool(): + """Example of creating a custom tool function.""" + + # Define a tool function with metadata + async def add_numbers(a: float, b: float) -> tuple[str, None]: + """Add two numbers together.""" + result = a + b + return f"The sum of {a} and {b} is {result}", None + + # Add metadata to the function + add_numbers.__name__ = "add_numbers" + add_numbers.__doc__ = ( + "Add two numbers together and return the result." + ) + add_numbers.schema = { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + } + + # Use the tool + result, _ = await add_numbers(a=5, b=7) + print(f"\nCustom tool result: {result}") + + +if __name__ == "__main__": + # Run both examples + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) + loop.run_until_complete(create_custom_tool())