pull/800/head
Kye Gomez 3 weeks ago
parent 9cb2500e58
commit 89fc8c7609

@ -916,692 +916,6 @@ if __name__ == "__main__":
print(f" {i+1}. {rec}")
```
### TypeScript/NodeJS Examples
#### Financial Fraud Detection (TypeScript)
This example demonstrates creating a swarm for financial transaction fraud detection.
```typescript
import axios from 'axios';
import * as fs from 'fs';
import * as dotenv from 'dotenv';
// Load environment variables
dotenv.config();
// API Configuration
const API_BASE_URL = "https://api.swarms.world";
const API_KEY = process.env.SWARMS_API_KEY;
// Define interfaces for type safety
interface TransactionData {
transaction_id: string;
amount: number;
timestamp: string;
merchant: string;
merchant_category: string;
payment_method: string;
location: string;
device_id?: string;
ip_address?: string;
}
interface UserProfile {
user_id: string;
account_age_days: number;
typical_transaction_amount: number;
typical_merchants: string[];
typical_locations: string[];
risk_score: number;
}
interface FraudDetectionResult {
transaction_id: string;
fraud_score: number;
is_fraudulent: boolean;
risk_factors: string[];
confidence: number;
recommended_action: string;
explanation: string;
}
/**
* Detects potential fraud in financial transactions using a specialized agent swarm
*
* @param transactions - Array of transaction data to analyze
* @param userProfile - User profile information for context
* @param historicalPatterns - Description of historical patterns and behaviors
* @returns Promise resolving to fraud detection results
*/
async function detectTransactionFraud(
transactions: TransactionData[],
userProfile: UserProfile,
historicalPatterns: string
): Promise<FraudDetectionResult[]> {
try {
// Prepare the task with all relevant information
const task = `
Analyze the following financial transactions for potential fraud.
USER PROFILE:
${JSON.stringify(userProfile, null, 2)}
HISTORICAL PATTERNS:
${historicalPatterns}
TRANSACTIONS TO ANALYZE:
${JSON.stringify(transactions, null, 2)}
For each transaction, determine the fraud risk score (0-100), whether it's likely fraudulent,
identified risk factors, confidence level, recommended action (allow, flag for review, block),
and a detailed explanation of the analysis.
Return results for each transaction in a structured format.
`;
// Define specialized fraud detection agents
const fraudDetectionAgents = [
{
agent_name: "BehavioralAnalyst",
description: "Analyzes user behavior patterns to identify anomalies",
system_prompt: `You are an expert in behavioral analysis for fraud detection.
Your role is to analyze transaction patterns against historical user behavior
to identify potential anomalies or deviations that may indicate fraud.
Consider:
- Timing patterns (day of week, time of day)
- Transaction amount patterns
- Merchant category patterns
- Geographic location patterns
- Device usage patterns
Provide a detailed breakdown of behavioral anomalies with confidence scores.`,
model_name: "gpt-4o",
temperature: 0.2,
role: "analyst",
max_loops: 1
},
{
agent_name: "TechnicalFraudDetector",
description: "Analyzes technical indicators of fraud",
system_prompt: `You are a technical fraud detection specialist.
Your role is to analyze technical aspects of transactions to identify
potential indicators of fraud.
Focus on:
- IP address analysis
- Device ID consistency
- Geolocation feasibility (impossible travel)
- Known fraud patterns
- Technical manipulation markers
Provide a technical assessment with specific indicators of potential fraud.`,
model_name: "gpt-4o",
temperature: 0.2,
role: "analyst",
max_loops: 1
},
{
agent_name: "FinancialPatternAnalyst",
description: "Specializes in financial transaction patterns",
system_prompt: `You are a financial pattern analysis expert specializing in fraud detection.
Your role is to analyze the financial aspects of transactions for fraud indicators.
Focus on:
- Transaction amount anomalies
- Merchant risk profiles
- Transaction velocity
- Transaction sequence patterns
- Round-number amounts or other suspicious values
Provide a financial pattern analysis with risk assessment.`,
model_name: "gpt-4o",
temperature: 0.2,
role: "analyst",
max_loops: 1
},
{
agent_name: "FraudInvestigator",
description: "Synthesizes all analysis into final fraud determination",
system_prompt: `You are a senior fraud investigator responsible for making the final
determination on transaction fraud risk.
Your role is to:
1. Synthesize inputs from behavioral, technical, and financial analyses
2. Weigh different risk factors appropriately
3. Calculate an overall fraud risk score (0-100)
4. Make a clear determination (legitimate, suspicious, fraudulent)
5. Recommend specific actions (allow, review, block)
6. Provide a clear, detailed explanation for each determination
Balance false positives and false negatives appropriately.`,
model_name: "gpt-4o",
temperature: 0.3,
role: "manager",
max_loops: 1
}
];
// Create the swarm specification
const swarmSpec = {
name: "fraud-detection-swarm",
description: "Financial transaction fraud detection swarm",
agents: fraudDetectionAgents,
max_loops: 2,
swarm_type: "HiearchicalSwarm",
task: task,
return_history: false
};
// Execute the swarm
const response = await axios.post(
`${API_BASE_URL}/v1/swarm/completions`,
swarmSpec,
{
headers: {
'x-api-key': API_KEY,
'Content-Type': 'application/json'
}
}
);
if (response.status === 200) {
console.log(`Fraud detection completed. Cost: ${response.data.metadata.billing_info.total_cost}`);
return response.data.output.fraud_analysis as FraudDetectionResult[];
} else {
throw new Error(`API request failed with status: ${response.status}`);
}
} catch (error) {
console.error('Error in fraud detection:', error);
if (axios.isAxiosError(error) && error.response) {
console.error('API response:', error.response.data);
}
throw error;
}
}
// Usage example
async function main() {
// Sample transaction data
const transactions: TransactionData[] = [
{
transaction_id: "T-12345-89012",
amount: 1299.99,
timestamp: "2025-03-04T09:23:42Z",
merchant: "ElectronicsPro",
merchant_category: "Electronics",
payment_method: "Credit Card",
location: "San Francisco, CA",
device_id: "D-472910",
ip_address: "192.168.1.127"
},
{
transaction_id: "T-12345-89013",
amount: 849.50,
timestamp: "2025-03-04T09:45:18Z",
merchant: "LuxuryBrands",
merchant_category: "Fashion",
payment_method: "Credit Card",
location: "Miami, FL",
device_id: "D-891245",
ip_address: "45.23.189.44"
},
{
transaction_id: "T-12345-89014",
amount: 24.99,
timestamp: "2025-03-04T10:12:33Z",
merchant: "CoffeeDeluxe",
merchant_category: "Restaurant",
payment_method: "Mobile Wallet",
location: "San Francisco, CA",
device_id: "D-472910",
ip_address: "192.168.1.127"
}
];
// Sample user profile
const userProfile: UserProfile = {
user_id: "U-78901-23456",
account_age_days: 487,
typical_transaction_amount: 150.75,
typical_merchants: ["Groceries", "Restaurant", "Retail", "Streaming"],
typical_locations: ["San Francisco, CA", "Oakland, CA"],
risk_score: 12
};
// Sample historical patterns
const historicalPatterns = `
User typically makes 15-20 transactions per month
Average transaction amount: $50-$200
Largest previous transaction: $750 (furniture)
No previous transactions over $1000
Typically shops in San Francisco Bay Area
No international transactions in past 12 months
Usually uses mobile wallet for small purchases (<$50)
Credit card used for larger purchases
Typical activity times: weekdays 8am-10pm PST
No previous purchases in Miami, FL
`;
try {
const results = await detectTransactionFraud(
transactions,
userProfile,
historicalPatterns
);
console.log("Fraud Detection Results:");
console.log(JSON.stringify(results, null, 2));
// Save results to file
fs.writeFileSync(
`fraud_detection_results_${new Date().toISOString().replace(/:/g, '-')}.json`,
JSON.stringify(results, null, 2)
);
// Process high-risk transactions
const highRiskTransactions = results.filter(r => r.fraud_score > 75);
if (highRiskTransactions.length > 0) {
console.log(`WARNING: ${highRiskTransactions.length} high-risk transactions detected!`);
// In production, you would trigger alerts, notifications, etc.
}
} catch (error) {
console.error("Failed to complete fraud detection:", error);
}
}
main();
```
#### Healthcare Report Generation (TypeScript)
This example demonstrates creating a swarm for generating comprehensive healthcare reports from patient data.
```typescript
import axios from 'axios';
import * as fs from 'fs';
import * as path from 'path';
import * as dotenv from 'dotenv';
// Load environment variables
dotenv.config();
// API Configuration
const API_BASE_URL = "https://api.swarms.world";
const API_KEY = process.env.SWARMS_API_KEY;
// Define interfaces
interface PatientData {
patient_id: string;
demographics: {
age: number;
gender: string;
ethnicity?: string;
weight_kg?: number;
height_cm?: number;
};
vitals: {
blood_pressure?: string;
heart_rate?: number;
respiratory_rate?: number;
temperature_c?: number;
oxygen_saturation?: number;
};
conditions: string[];
medications: {
name: string;
dosage: string;
frequency: string;
}[];
allergies: string[];
lab_results: {
test_name: string;
value: number | string;
unit: string;
reference_range?: string;
collection_date: string;
}[];
imaging_results?: {
study_type: string;
body_area: string;
findings: string;
date: string;
}[];
notes?: string[];
}
interface MedicalReport {
summary: string;
assessment: {
primary_diagnosis: string;
secondary_diagnoses: string[];
clinical_impression: string;
};
recommendations: string[];
medication_review: {
current_medications: {
medication: string;
assessment: string;
}[];
potential_interactions: string[];
recommended_changes: string[];
};
care_plan: {
short_term_goals: string[];
long_term_goals: string[];
follow_up: string;
};
lab_interpretation: {
flagged_results: {
test: string;
value: string;
interpretation: string;
}[];
trends: string[];
};
clinical_reasoning: string;
}
/**
* Generates a comprehensive clinical report for a patient using a specialized medical agent swarm
*
* @param patientData - Structured patient medical data
* @param clinicalGuidelines - Relevant clinical guidelines to consider
* @param reportType - Type of report to generate (e.g., 'comprehensive', 'follow-up', 'specialist')
* @returns Promise resolving to structured medical report
*/
async function generateClinicalReport(
patientData: PatientData,
clinicalGuidelines: string,
reportType: string
): Promise<MedicalReport> {
try {
// Prepare detailed task description
const task = `
Generate a comprehensive clinical report for the following patient:
PATIENT DATA:
${JSON.stringify(patientData, null, 2)}
CLINICAL GUIDELINES TO CONSIDER:
${clinicalGuidelines}
REPORT TYPE:
${reportType}
Analyze all aspects of the patient's health status including symptoms, lab results,
medications, conditions, and other relevant factors. Provide a detailed clinical assessment,
evidence-based recommendations, medication review with potential interactions,
comprehensive care plan, and clear follow-up instructions.
Structure the report to include a concise executive summary, detailed assessment with
clinical reasoning, specific actionable recommendations, and a clear care plan.
Ensure all interpretations and recommendations align with current clinical guidelines
and evidence-based medicine.
`;
// Use Auto Swarm Builder for this complex medical task
const swarmSpec = {
name: "clinical-report-generator",
description: "Medical report generation and analysis swarm",
swarm_type: "AutoSwarmBuilder",
task: task,
max_loops: 3,
return_history: false
};
// Execute the swarm
console.log("Generating clinical report...");
const response = await axios.post(
`${API_BASE_URL}/v1/swarm/completions`,
swarmSpec,
{
headers: {
'x-api-key': API_KEY,
'Content-Type': 'application/json'
}
}
);
if (response.status === 200) {
const executionTime = response.data.metadata.execution_time_seconds;
const cost = response.data.metadata.billing_info.total_cost;
const numAgents = response.data.metadata.num_agents;
console.log(`Report generation completed in ${executionTime.toFixed(2)} seconds`);
console.log(`Used ${numAgents} specialized medical agents`);
console.log(`Total cost: ${cost.toFixed(4)}`);
return response.data.output.medical_report as MedicalReport;
} else {
throw new Error(`API request failed with status: ${response.status}`);
}
} catch (error) {
console.error('Error generating clinical report:', error);
if (axios.isAxiosError(error) && error.response) {
console.error('API response:', error.response.data);
}
throw error;
}
}
/**
* Schedules regular report generation for a patient
*/
async function scheduleRecurringReports(
patientId: string,
reportType: string,
intervalDays: number
): Promise<void> {
try {
// Schedule the next report
const nextReportDate = new Date();
nextReportDate.setDate(nextReportDate.getDate() + intervalDays);
const scheduleSpec = {
name: `${patientId}-${reportType}-report`,
description: `Scheduled ${reportType} report for patient ${patientId}`,
task: `Generate a ${reportType} clinical report for patient ${patientId} following standard protocols. Retrieve the most recent patient data and produce a comprehensive clinical assessment with recommendations.`,
schedule: {
scheduled_time: nextReportDate.toISOString(),
timezone: "UTC"
}
};
const response = await axios.post(
`${API_BASE_URL}/v1/swarm/schedule`,
scheduleSpec,
{
headers: {
'x-api-key': API_KEY,
'Content-Type': 'application/json'
}
}
);
if (response.status === 200) {
console.log(`Successfully scheduled next report for ${nextReportDate.toISOString()}`);
console.log(`Job ID: ${response.data.job_id}`);
} else {
throw new Error(`Failed to schedule report: ${response.status}`);
}
} catch (error) {
console.error('Error scheduling report:', error);
if (axios.isAxiosError(error) && error.response) {
console.error('API response:', error.response.data);
}
}
}
// Usage example
async function main() {
// Sample patient data
const patientData: PatientData = {
patient_id: "P-12345-67890",
demographics: {
age: 58,
gender: "Male",
ethnicity: "Caucasian",
weight_kg: 92.5,
height_cm: 178
},
vitals: {
blood_pressure: "148/92",
heart_rate: 82,
respiratory_rate: 16,
temperature_c: 36.8,
oxygen_saturation: 96
},
conditions: [
"Type 2 Diabetes Mellitus",
"Hypertension",
"Coronary Artery Disease",
"Hyperlipidemia",
"Obesity"
],
medications: [
{
name: "Metformin",
dosage: "1000mg",
frequency: "BID"
},
{
name: "Lisinopril",
dosage: "20mg",
frequency: "QD"
},
{
name: "Atorvastatin",
dosage: "40mg",
frequency: "QD"
},
{
name: "Aspirin",
dosage: "81mg",
frequency: "QD"
}
],
allergies: ["Penicillin", "Sulfa drugs"],
lab_results: [
{
test_name: "HbA1c",
value: 8.2,
unit: "%",
reference_range: "<7.0",
collection_date: "2025-02-15"
},
{
test_name: "Fasting Glucose",
value: 165,
unit: "mg/dL",
reference_range: "70-100",
collection_date: "2025-02-15"
},
{
test_name: "LDL Cholesterol",
value: 118,
unit: "mg/dL",
reference_range: "<100",
collection_date: "2025-02-15"
},
{
test_name: "HDL Cholesterol",
value: 38,
unit: "mg/dL",
reference_range: ">40",
collection_date: "2025-02-15"
},
{
test_name: "eGFR",
value: 68,
unit: "mL/min/1.73m²",
reference_range: ">90",
collection_date: "2025-02-15"
}
],
imaging_results: [
{
study_type: "Cardiac CT Angiography",
body_area: "Heart and coronary arteries",
findings: "Moderate calcification in LAD. 50-70% stenosis in proximal LAD. No significant stenosis in other coronary arteries.",
date: "2025-01-10"
}
],
notes: [
"Patient reports increased fatigue over the past month",
"Complains of occasional chest discomfort with exertion",
"Currently following low-carb diet but admits to poor adherence",
"Exercise limited by knee pain"
]
};
// Clinical guidelines to consider
const clinicalGuidelines = `
ADA 2025 Guidelines for Type 2 Diabetes:
- HbA1c target <7.0% for most adults
- Consider less stringent targets (e.g., <8.0%) for patients with multiple comorbidities
- First-line therapy: Metformin + lifestyle modifications
- For patients with ASCVD: consider GLP-1 RA or SGLT2 inhibitor with proven CV benefit
ACC/AHA 2024 Hypertension Guidelines:
- BP target <130/80 mmHg for patients with diabetes and/or CAD
- First-line: ACE inhibitor or ARB for patients with diabetes
ACC/AHA 2024 Cholesterol Guidelines:
- LDL-C target <70 mg/dL for very high-risk ASCVD
- Consider adding ezetimibe or PCSK9 inhibitor for very high-risk patients not at goal
`;
try {
// Generate the report
const report = await generateClinicalReport(
patientData,
clinicalGuidelines,
"comprehensive"
);
// Save report to file
const timestamp = new Date().toISOString().replace(/:/g, '-');
const outputDir = './reports';
if (!fs.existsSync(outputDir)) {
fs.mkdirSync(outputDir);
}
fs.writeFileSync(
path.join(outputDir, `clinical_report_${patientData.patient_id}_${timestamp}.json`),
JSON.stringify(report, null, 2)
);
// Display executive summary
console.log("\nREPORT SUMMARY:");
console.log(report.summary);
console.log("\nPRIMARY DIAGNOSIS:");
console.log(report.assessment.primary_diagnosis);
console.log("\nKEY RECOMMENDATIONS:");
report.recommendations.forEach((rec, i) => {
console.log(` ${i+1}. ${rec}`);
});
// Schedule the next report in 90 days
await scheduleRecurringReports(patientData.patient_id, "follow-up", 90);
} catch (error) {
console.error("Failed to complete report generation:", error);
}
}
main();
```
## Error Handling
The Swarms API follows standard HTTP status codes for error responses:
@ -1632,8 +946,12 @@ The API enforces a rate limit of 100 requests per 60-second window. When exceede
The API uses a credit-based billing system with costs calculated based on:
1. **Agent Count**: Base cost per agent
2. **Input Tokens**: Cost based on the size of input data and prompts
3. **Output Tokens**: Cost based on the length of generated responses
4. **Time of Day**: Reduced rates during nighttime hours (8 PM to 6 AM PT)
Cost information is included in each response's metadata for transparency and forecasting.
@ -1641,26 +959,41 @@ Cost information is included in each response's metadata for transparency and fo
## Best Practices
1. **Task Description**
- Provide detailed, specific task descriptions
- Include all necessary context and constraints
- Structure complex inputs for easier processing
2. **Agent Configuration**
- For simple tasks, use `AutoSwarmBuilder` to automatically generate optimal agents
- For complex or specialized tasks, manually define agents with specific expertise
- Use appropriate `swarm_type` for your workflow pattern
3. **Production Implementation**
- Implement robust error handling and retries
- Log API responses for debugging and auditing
- Monitor costs closely during development and testing
- Use scheduled jobs for recurring tasks instead of continuous polling
- Implement robust error handling and retries
- Log API responses for debugging and auditing
- Monitor costs closely during development and testing
- Use scheduled jobs for recurring tasks instead of continuous polling
4. **Cost Optimization**
- Batch related tasks when possible
- Schedule non-urgent tasks during discount hours
- Carefully scope task descriptions to reduce token usage
- Cache results when appropriate
- Batch related tasks when possible
- Schedule non-urgent tasks during discount hours
- Carefully scope task descriptions to reduce token usage
- Cache results when appropriate
## Support

@ -0,0 +1,554 @@
from contextlib import AsyncExitStack
from types import TracebackType
from typing import (
Any,
Callable,
Coroutine,
List,
Literal,
Optional,
TypedDict,
cast,
)
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
PromptMessage,
TextContent,
)
from mcp.types import (
Tool as MCPTool,
)
def convert_mcp_prompt_message_to_message(
message: PromptMessage,
) -> str:
"""Convert an MCP prompt message to a string message.
Args:
message: MCP prompt message to convert
Returns:
a string message
"""
if message.content.type == "text":
if message.role == "user":
return str(message.content.text)
elif message.role == "assistant":
return str(
message.content.text
) # Fixed attribute name from str to text
else:
raise ValueError(
f"Unsupported prompt message role: {message.role}"
)
raise ValueError(
f"Unsupported prompt message content type: {message.content.type}"
)
async def load_mcp_prompt(
session: ClientSession,
name: str,
arguments: Optional[dict[str, Any]] = None,
) -> List[str]:
"""Load MCP prompt and convert to messages."""
response = await session.get_prompt(name, arguments)
return [
convert_mcp_prompt_message_to_message(message)
for message in response.messages
]
DEFAULT_ENCODING = "utf-8"
DEFAULT_ENCODING_ERROR_HANDLER = "strict"
DEFAULT_HTTP_TIMEOUT = 5
DEFAULT_SSE_READ_TIMEOUT = 60 * 5
class StdioConnection(TypedDict):
transport: Literal["stdio"]
command: str
"""The executable to run to start the server."""
args: list[str]
"""Command line arguments to pass to the executable."""
env: dict[str, str] | None
"""The environment to use when spawning the process."""
encoding: str
"""The text encoding used when sending/receiving messages to the server."""
encoding_error_handler: Literal["strict", "ignore", "replace"]
"""
The text encoding error handler.
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values
"""
class SSEConnection(TypedDict):
transport: Literal["sse"]
url: str
"""The URL of the SSE endpoint to connect to."""
headers: dict[str, Any] | None
"""HTTP headers to send to the SSE endpoint"""
timeout: float
"""HTTP timeout"""
sse_read_timeout: float
"""SSE read timeout"""
NonTextContent = ImageContent | EmbeddedResource
def _convert_call_tool_result(
call_tool_result: CallToolResult,
) -> tuple[str | list[str], list[NonTextContent] | None]:
text_contents: list[TextContent] = []
non_text_contents = []
for content in call_tool_result.content:
if isinstance(content, TextContent):
text_contents.append(content)
else:
non_text_contents.append(content)
tool_content: str | list[str] = [
content.text for content in text_contents
]
if len(text_contents) == 1:
tool_content = tool_content[0]
if call_tool_result.isError:
raise ValueError("Error calling tool")
return tool_content, non_text_contents or None
def convert_mcp_tool_to_function(
session: ClientSession,
tool: MCPTool,
) -> Callable[
...,
Coroutine[
Any, Any, tuple[str | list[str], list[NonTextContent] | None]
],
]:
"""Convert an MCP tool to a callable function.
NOTE: this tool can be executed only in a context of an active MCP client session.
Args:
session: MCP client session
tool: MCP tool to convert
Returns:
a callable function
"""
async def call_tool(
**arguments: dict[str, Any],
) -> tuple[str | list[str], list[NonTextContent] | None]:
"""Execute the tool with the given arguments."""
call_tool_result = await session.call_tool(
tool.name, arguments
)
return _convert_call_tool_result(call_tool_result)
# Add metadata as attributes to the function
call_tool.__name__ = tool.name
call_tool.__doc__ = tool.description or ""
call_tool.schema = tool.inputSchema
return call_tool
async def load_mcp_tools(session: ClientSession) -> list[Callable]:
"""Load all available MCP tools and convert them to callable functions."""
tools = await session.list_tools()
return [
convert_mcp_tool_to_function(session, tool)
for tool in tools.tools
]
class MultiServerMCPClient:
"""Client for connecting to multiple MCP servers and loading tools from them."""
def __init__(
self,
connections: dict[
str, StdioConnection | SSEConnection
] = None,
) -> None:
"""Initialize a MultiServerMCPClient with MCP servers connections.
Args:
connections: A dictionary mapping server names to connection configurations.
Each configuration can be either a StdioConnection or SSEConnection.
If None, no initial connections are established.
Example:
```python
async with MultiServerMCPClient(
{
"math": {
"command": "python",
# Make sure to update to the full absolute path to your math_server.py file
"args": ["/path/to/math_server.py"],
"transport": "stdio",
},
"weather": {
# make sure you start your weather server on port 8000
"url": "http://localhost:8000/sse",
"transport": "sse",
}
}
) as client:
all_tools = client.get_tools()
...
```
"""
self.connections = connections
self.exit_stack = AsyncExitStack()
self.sessions: dict[str, ClientSession] = {}
self.server_name_to_tools: dict[str, list[Callable]] = {}
async def _initialize_session_and_load_tools(
self, server_name: str, session: ClientSession
) -> None:
"""Initialize a session and load tools from it.
Args:
server_name: Name to identify this server connection
session: The ClientSession to initialize
"""
# Initialize the session
await session.initialize()
self.sessions[server_name] = session
# Load tools from this server
server_tools = await load_mcp_tools(session)
self.server_name_to_tools[server_name] = server_tools
async def connect_to_server(
self,
server_name: str,
*,
transport: Literal["stdio", "sse"] = "stdio",
**kwargs,
) -> None:
"""Connect to an MCP server using either stdio or SSE.
This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse
based on the provided transport parameter.
Args:
server_name: Name to identify this server connection
transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
**kwargs: Additional arguments to pass to the specific connection method
Raises:
ValueError: If transport is not recognized
ValueError: If required parameters for the specified transport are missing
"""
if transport == "sse":
if "url" not in kwargs:
raise ValueError(
"'url' parameter is required for SSE connection"
)
await self.connect_to_server_via_sse(
server_name,
url=kwargs["url"],
headers=kwargs.get("headers"),
timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
sse_read_timeout=kwargs.get(
"sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT
),
)
elif transport == "stdio":
if "command" not in kwargs:
raise ValueError(
"'command' parameter is required for stdio connection"
)
if "args" not in kwargs:
raise ValueError(
"'args' parameter is required for stdio connection"
)
await self.connect_to_server_via_stdio(
server_name,
command=kwargs["command"],
args=kwargs["args"],
env=kwargs.get("env"),
encoding=kwargs.get("encoding", DEFAULT_ENCODING),
encoding_error_handler=kwargs.get(
"encoding_error_handler",
DEFAULT_ENCODING_ERROR_HANDLER,
),
)
else:
raise ValueError(
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
)
async def connect_to_server_via_stdio(
self,
server_name: str,
*,
command: str,
args: list[str],
env: dict[str, str] | None = None,
encoding: str = DEFAULT_ENCODING,
encoding_error_handler: Literal[
"strict", "ignore", "replace"
] = DEFAULT_ENCODING_ERROR_HANDLER,
) -> None:
"""Connect to a specific MCP server using stdio
Args:
server_name: Name to identify this server connection
command: Command to execute
args: Arguments for the command
env: Environment variables for the command
encoding: Character encoding
encoding_error_handler: How to handle encoding errors
"""
server_params = StdioServerParameters(
command=command,
args=args,
env=env,
encoding=encoding,
encoding_error_handler=encoding_error_handler,
)
# Create and store the connection
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
read, write = stdio_transport
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(
ClientSession(read, write)
),
)
await self._initialize_session_and_load_tools(
server_name, session
)
async def connect_to_server_via_sse(
self,
server_name: str,
*,
url: str,
headers: dict[str, Any] | None = None,
timeout: float = DEFAULT_HTTP_TIMEOUT,
sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
) -> None:
"""Connect to a specific MCP server using SSE
Args:
server_name: Name to identify this server connection
url: URL of the SSE server
headers: HTTP headers to send to the SSE endpoint
timeout: HTTP timeout
sse_read_timeout: SSE read timeout
"""
# Create and store the connection
sse_transport = await self.exit_stack.enter_async_context(
sse_client(url, headers, timeout, sse_read_timeout)
)
read, write = sse_transport
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(
ClientSession(read, write)
),
)
await self._initialize_session_and_load_tools(
server_name, session
)
def get_tools(self) -> list[Callable]:
"""Get a list of all tools from all connected servers."""
all_tools: list[Callable] = []
for server_tools in self.server_name_to_tools.values():
all_tools.extend(server_tools)
return all_tools
async def get_prompt(
self,
server_name: str,
prompt_name: str,
arguments: Optional[dict[str, Any]] = None,
) -> List[str]:
"""Get a prompt from a given MCP server."""
session = self.sessions[server_name]
return await load_mcp_prompt(session, prompt_name, arguments)
async def __aenter__(self) -> "MultiServerMCPClient":
try:
connections = self.connections or {}
for server_name, connection in connections.items():
connection_dict = connection.copy()
transport = connection_dict.pop("transport")
if transport == "stdio":
await self.connect_to_server_via_stdio(
server_name, **connection_dict
)
elif transport == "sse":
await self.connect_to_server_via_sse(
server_name, **connection_dict
)
else:
raise ValueError(
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
)
return self
except Exception:
await self.exit_stack.aclose()
raise
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.exit_stack.aclose()
#!/usr/bin/env python3
import asyncio
import os
import json
from typing import List, Any, Callable
# # Import our MCP client module
# from mcp_client import MultiServerMCPClient
async def main():
"""Test script for demonstrating MCP client usage."""
print("Starting MCP Client test...")
# Create a connection to multiple MCP servers
# You'll need to update these paths to match your setup
async with MultiServerMCPClient(
{
"math": {
"transport": "stdio",
"command": "python",
"args": ["/path/to/math_server.py"],
"env": {"DEBUG": "1"},
},
"search": {
"transport": "sse",
"url": "http://localhost:8000/sse",
"headers": {
"Authorization": f"Bearer {os.environ.get('API_KEY', '')}"
},
},
}
) as client:
# Get all available tools
tools = client.get_tools()
print(f"Found {len(tools)} tools across all servers")
# Print tool information
for i, tool in enumerate(tools):
print(f"\nTool {i+1}: {tool.__name__}")
print(f" Description: {tool.__doc__}")
if hasattr(tool, "schema") and tool.schema:
print(
f" Schema: {json.dumps(tool.schema, indent=2)[:100]}..."
)
# Example: Use a specific tool if available
calculator_tool = next(
(t for t in tools if t.__name__ == "calculator"), None
)
if calculator_tool:
print("\n\nTesting calculator tool:")
try:
# Call the tool as an async function
result, artifacts = await calculator_tool(
expression="2 + 2 * 3"
)
print(f" Calculator result: {result}")
if artifacts:
print(
f" With {len(artifacts)} additional artifacts"
)
except Exception as e:
print(f" Error using calculator: {e}")
# Example: Load a prompt from a server
try:
print("\n\nTesting prompt loading:")
prompt_messages = await client.get_prompt(
"math",
"calculation_introduction",
{"user_name": "Test User"},
)
print(
f" Loaded prompt with {len(prompt_messages)} messages:"
)
for i, msg in enumerate(prompt_messages):
print(f" Message {i+1}: {msg[:50]}...")
except Exception as e:
print(f" Error loading prompt: {e}")
async def create_custom_tool():
"""Example of creating a custom tool function."""
# Define a tool function with metadata
async def add_numbers(a: float, b: float) -> tuple[str, None]:
"""Add two numbers together."""
result = a + b
return f"The sum of {a} and {b} is {result}", None
# Add metadata to the function
add_numbers.__name__ = "add_numbers"
add_numbers.__doc__ = (
"Add two numbers together and return the result."
)
add_numbers.schema = {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"},
},
"required": ["a", "b"],
}
# Use the tool
result, _ = await add_numbers(a=5, b=7)
print(f"\nCustom tool result: {result}")
if __name__ == "__main__":
# Run both examples
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.run_until_complete(create_custom_tool())
Loading…
Cancel
Save