diff --git a/docs/swarms_cloud/swarms_api.md b/docs/swarms_cloud/swarms_api.md index 14caf230..1895d7c6 100644 --- a/docs/swarms_cloud/swarms_api.md +++ b/docs/swarms_cloud/swarms_api.md @@ -1,532 +1,1587 @@ # Swarms API Documentation -The Swarms API is a powerful REST API designed to help you create, manage, and execute various types of swarms efficiently. Whether you need to run tasks sequentially, concurrently, or in a custom workflow, the Swarms API has you covered. +*Enterprise-grade Agent Swarm Management API* + +**Base URL**: `https://swarms-api-285321057562.us-east1.run.app` +**API Key Management**: [https://swarms.world/platform/api-keys](https://swarms.world/platform/api-keys) +**Documentation Version**: 1.0.0 +**Last Updated**: March 4, 2025 + + + +## Overview + +The Swarms API provides a robust, scalable infrastructure for deploying and managing intelligent agent swarms in the cloud. This enterprise-grade API enables organizations to create, execute, and orchestrate sophisticated AI agent workflows without managing the underlying infrastructure. + +Key capabilities include: + +- **Intelligent Swarm Management**: Create and execute swarms of specialized AI agents that collaborate to solve complex tasks +- **Automatic Agent Generation**: Dynamically create optimized agents based on task requirements +- **Multiple Swarm Architectures**: Choose from various swarm patterns to match your specific workflow needs +- **Scheduled Execution**: Set up automated, scheduled swarm executions +- **Comprehensive Logging**: Track and analyze all API interactions +- **Cost Management**: Predictable, transparent pricing with optimized resource utilization +- **Enterprise Security**: Full API key authentication and management + +Swarms API is designed for production use cases requiring sophisticated AI orchestration, with applications in finance, healthcare, legal, research, and other domains where complex reasoning and multi-agent collaboration are needed. -### Key Features: -- **Sequential Swarms**: Execute tasks one after another in a defined order. -- **Concurrent Swarms**: Run multiple tasks simultaneously to save time and resources. -- **Custom Workflows**: Design your own swarm workflows to fit your specific needs. +## Authentication -To get started, find your API key in the Swarms Cloud dashboard. [Get your API key here](https://swarms.world/platform/api-keys) +All API requests require a valid API key, which must be included in the header of each request: -## Base URL ``` -https://swarms-api-285321057562.us-east1.run.app +x-api-key: your_api_key_here ``` -## Authentication -All API requests (except `/health`) require authentication using an API key passed in the `x-api-key` header: +API keys can be obtained and managed at [https://swarms.world/platform/api-keys](https://swarms.world/platform/api-keys). 
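+
+As a quick check that your key works, the snippet below sends an authenticated health-check request from Python. It is a minimal sketch: the base URL, `/health` endpoint, and `x-api-key` header come from this document, while the `SWARMS_API_KEY` environment variable is simply an assumed place to keep the key.
+
+```python
+import os
+
+import requests
+
+# Minimal sketch: confirm connectivity and authentication before building a full workflow.
+# Assumes the API key is stored in the SWARMS_API_KEY environment variable.
+API_BASE_URL = "https://swarms-api-285321057562.us-east1.run.app"
+HEADERS = {"x-api-key": os.environ["SWARMS_API_KEY"]}
+
+response = requests.get(f"{API_BASE_URL}/health", headers=HEADERS)
+response.raise_for_status()
+print(response.json())  # expected: {"status": "ok"}
+```
+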
-```http -x-api-key: your_api_key_here -``` +## API Reference + +### Endpoints Summary + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/health` | GET | Simple health check endpoint | +| `/v1/swarm/completions` | POST | Run a swarm with specified configuration | +| `/v1/swarm/batch/completions` | POST | Run multiple swarms in batch mode | +| `/v1/swarm/schedule` | POST | Schedule a swarm to run at a specific time | +| `/v1/swarm/schedule` | GET | Get all scheduled swarm jobs | +| `/v1/swarm/schedule/{job_id}` | DELETE | Cancel a scheduled swarm job | +| `/v1/swarm/logs` | GET | Retrieve API request logs | + +### SwarmType Reference + +The `swarm_type` parameter defines the architecture and collaboration pattern of the agent swarm: + +| SwarmType | Description | +|-----------|-------------| +| `AgentRearrange` | Dynamically reorganizes the workflow between agents based on task requirements | +| `MixtureOfAgents` | Combines multiple agent types to tackle diverse aspects of a problem | +| `SpreadSheetSwarm` | Specialized for spreadsheet data analysis and manipulation | +| `SequentialWorkflow` | Agents work in a predefined sequence, each handling specific subtasks | +| `ConcurrentWorkflow` | Multiple agents work simultaneously on different aspects of the task | +| `GroupChat` | Agents collaborate in a discussion format to solve problems | +| `MultiAgentRouter` | Routes subtasks to specialized agents based on their capabilities | +| `AutoSwarmBuilder` | Automatically designs and builds an optimal swarm based on the task | +| `HiearchicalSwarm` | Organizes agents in a hierarchical structure with managers and workers | +| `MajorityVoting` | Uses a consensus mechanism where multiple agents vote on the best solution | +| `auto` | Automatically selects the most appropriate swarm type for the given task | -## Endpoints +### Endpoint Details -### Health Check -Check if the API is operational. +#### Health Check -**Endpoint:** `GET /health` -**Authentication Required:** No -**Response:** +Check if the API service is available and functioning correctly. + +**Endpoint**: `/health` +**Method**: GET +**Rate Limit**: 100 requests per 60 seconds + +**Example Request**: +```bash +curl -X GET "https://swarms-api-285321057562.us-east1.run.app/health" \ + -H "x-api-key: your_api_key_here" +``` + +**Example Response**: ```json { - "status": "ok" + "status": "ok" } ``` -### Single Swarm Completion -Run a single swarm with specified agents and tasks. 
- -**Endpoint:** `POST /v1/swarm/completions` -**Authentication Required:** Yes - -#### Request Parameters - -| Parameter | Type | Required | Default | Description | -|-----------|------|----------|---------|-------------| -| name | string | Optional | "swarms-01" | Name of the swarm (max 100 chars) | -| description | string | Optional | - | Description of the swarm (max 500 chars) | -| agents | array | Required | - | Array of agent configurations | -| max_loops | integer | Optional | 1 | Maximum number of iterations | -| swarm_type | string | Optional | - | Type of swarm workflow | -| task | string | Required | - | The task to be performed | -| img | string | Optional | - | Image URL if relevant | -| return_history | boolean | Optional | true | Whether to return the full conversation history | -| rules | string | Optional | - | Rules for the swarm to follow | -| rearrange_flow | string | Optional | - | Flow pattern for agent rearrangement | -| output_type | string | Optional | "str" | Output format ("str", "json", "dict", "yaml", "list") | -| schedule | object | Optional | - | Scheduling information for the swarm | - -#### Schedule Configuration Parameters - -| Parameter | Type | Required | Default | Description | -|-----------|------|----------|---------|-------------| -| scheduled_time | datetime | Required | - | When to execute the swarm (UTC) | -| timezone | string | Optional | "UTC" | Timezone for the scheduled time | - -#### Agent Configuration Parameters - -| Parameter | Type | Required | Default | Description | -|-----------|------|----------|---------|-------------| -| agent_name | string | Required | - | Name of the agent (max 100 chars) | -| description | string | Optional | - | Description of the agent (max 500 chars) | -| system_prompt | string | Optional | - | System prompt for the agent (max 500 chars) | -| model_name | string | Optional | "gpt-4o" | Model to be used by the agent | -| auto_generate_prompt | boolean | Optional | false | Whether to auto-generate prompts | -| max_tokens | integer | Optional | - | Maximum tokens for response | -| temperature | float | Optional | 0.5 | Temperature for response generation | -| role | string | Optional | "worker" | Role of the agent | -| max_loops | integer | Optional | 1 | Maximum iterations for this agent | - -## Available Swarm Types - -| Swarm Type | Description | -|------------|-------------| -| AgentRearrange | Rearranges agents dynamically to optimize task execution | -| MixtureOfAgents | Combines different agents to leverage their unique capabilities | -| SpreadSheetSwarm | Utilizes spreadsheet-like operations for data manipulation | -| SequentialWorkflow | Executes tasks in a predefined sequential order | -| ConcurrentWorkflow | Runs tasks concurrently to improve efficiency | -| GroupChat | Facilitates communication among agents in a chat format | -| MultiAgentRouter | Routes tasks to agents based on their expertise | -| AutoSwarmBuilder | Automatically constructs swarms based on task requirements | -| HiearchicalSwarm | Organizes agents in a hierarchy for complex tasks | -| auto | Automatically selects the most suitable swarm type | -| MajorityVoting | Uses majority voting to reach consensus on outcomes | - -## Job Scheduling Endpoints - -### Schedule a Swarm -Schedule a swarm to run at a specific time. +#### Run Swarm + +Run a swarm with the specified configuration to complete a task. 
+ +**Endpoint**: `/v1/swarm/completions` +**Method**: POST +**Rate Limit**: 100 requests per 60 seconds + +**Request Parameters**: -**Endpoint:** `POST /v1/swarm/schedule` -**Authentication Required:** Yes +| Field | Type | Description | Required | +|-------|------|-------------|----------| +| name | string | Identifier for the swarm | No | +| description | string | Description of the swarm's purpose | No | +| agents | Array | List of agent specifications | No | +| max_loops | integer | Maximum number of execution loops | No | +| swarm_type | SwarmType | Architecture of the swarm | No | +| rearrange_flow | string | Instructions for rearranging task flow | No | +| task | string | The main task for the swarm to accomplish | Yes | +| img | string | Optional image URL for the swarm | No | +| return_history | boolean | Whether to return execution history | No | +| rules | string | Guidelines for swarm behavior | No | +| schedule | ScheduleSpec | Scheduling information | No | -#### Request Format -Same as single swarm completion, with additional `schedule` object: +**Example Request**: +```bash +curl -X POST "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/completions" \ + -H "x-api-key: your_api_key_here" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "financial-analysis-swarm", + "description": "Analyzes financial data for risk assessment", + "swarm_type": "SequentialWorkflow", + "task": "Analyze the provided quarterly financials for Company XYZ and identify potential risk factors. Summarize key insights and provide recommendations for risk mitigation.", + "max_loops": 2, + "return_history": true + }' +``` +**Example Response**: ```json { - "name": "Scheduled Swarm", - "agents": [...], - "task": "Perform analysis", - "schedule": { - "scheduled_time": "2024-03-20T15:00:00Z", - "timezone": "America/New_York" + "status": "success", + "swarm_name": "financial-analysis-swarm", + "description": "Analyzes financial data for risk assessment", + "swarm_type": "SequentialWorkflow", + "task": "Analyze the provided quarterly financials for Company XYZ and identify potential risk factors. Summarize key insights and provide recommendations for risk mitigation.", + "output": { + "financial_analysis": { + "risk_factors": [...], + "key_insights": [...], + "recommendations": [...] + } + }, + "metadata": { + "max_loops": 2, + "num_agents": 3, + "execution_time_seconds": 12.45, + "completion_time": 1709563245.789, + "billing_info": { + "cost_breakdown": { + "agent_cost": 0.03, + "input_token_cost": 0.002134, + "output_token_cost": 0.006789, + "token_counts": { + "total_input_tokens": 1578, + "total_output_tokens": 3456, + "total_tokens": 5034, + "per_agent": {...} + }, + "num_agents": 3, + "execution_time_seconds": 12.45 + }, + "total_cost": 0.038923 } + } } ``` -### List Scheduled Jobs -Get all scheduled swarm jobs. +#### Run Batch Completions -**Endpoint:** `GET /v1/swarm/schedule` -**Authentication Required:** Yes +Run multiple swarms as a batch operation. -#### Response Format -```json -{ - "status": "success", - "scheduled_jobs": [ - { - "job_id": "swarm_analysis_1234567890", - "swarm_name": "Analysis Swarm", - "scheduled_time": "2024-03-20T15:00:00Z", - "timezone": "America/New_York" - } - ] -} -``` +**Endpoint**: `/v1/swarm/batch/completions` +**Method**: POST +**Rate Limit**: 100 requests per 60 seconds + +**Request Parameters**: -### Cancel Scheduled Job -Cancel a scheduled swarm job. 
+| Field | Type | Description | Required | +|-------|------|-------------|----------| +| swarms | Array | List of swarm specifications | Yes | -**Endpoint:** `DELETE /v1/swarm/schedule/{job_id}` -**Authentication Required:** Yes +**Example Request**: +```bash +curl -X POST "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/batch/completions" \ + -H "x-api-key: your_api_key_here" \ + -H "Content-Type: application/json" \ + -d '{ + "swarms": [ + { + "name": "risk-analysis", + "task": "Analyze risk factors for investment portfolio" + }, + { + "name": "market-sentiment", + "task": "Assess current market sentiment for technology sector" + } + ] + }' +``` -#### Response Format +**Example Response**: ```json -{ +[ + { "status": "success", - "message": "Scheduled job cancelled successfully", - "job_id": "swarm_analysis_1234567890" -} + "swarm_name": "risk-analysis", + "task": "Analyze risk factors for investment portfolio", + "output": {...}, + "metadata": {...} + }, + { + "status": "success", + "swarm_name": "market-sentiment", + "task": "Assess current market sentiment for technology sector", + "output": {...}, + "metadata": {...} + } +] ``` -## Billing and Credits - -The API uses a credit-based billing system with the following components: +#### Schedule Swarm -### Cost Calculation +Schedule a swarm to run at a specific time. -| Component | Cost | -|-----------|------| -| Base cost per agent | $0.01 | -| Input tokens (per 1M) | $2.00 | -| Output tokens (per 1M) | $6.00 | +**Endpoint**: `/v1/swarm/schedule` +**Method**: POST +**Rate Limit**: 100 requests per 60 seconds -Special pricing: -- California night time hours (8 PM to 6 AM PT): 75% discount on token costs -- Credits are deducted in the following order: - 1. Free credits - 2. Regular credits +**Request Parameters**: -Costs are calculated based on: -- Number of agents used -- Total input tokens (including system prompts and agent memory) -- Total output tokens generated -- Execution time +| Field | Type | Description | Required | +|-------|------|-------------|----------| +| name | string | Identifier for the swarm | No | +| description | string | Description of the swarm's purpose | No | +| agents | Array | List of agent specifications | No | +| max_loops | integer | Maximum number of execution loops | No | +| swarm_type | SwarmType | Architecture of the swarm | No | +| task | string | The main task for the swarm to accomplish | Yes | +| schedule | ScheduleSpec | Scheduling information | Yes | -## Error Handling +**Example Request**: +```bash +curl -X POST "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/schedule" \ + -H "x-api-key: your_api_key_here" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "daily-market-analysis", + "description": "Daily analysis of market conditions", + "task": "Analyze today's market movements and prepare a summary report", + "schedule": { + "scheduled_time": "2025-03-05T17:00:00Z", + "timezone": "UTC" + } + }' +``` -| HTTP Status Code | Description | -|-----------------|-------------| -| 402 | Insufficient credits | -| 403 | Invalid API key | -| 404 | Resource not found | -| 500 | Internal server error | +**Example Response**: +```json +{ + "status": "success", + "message": "Swarm scheduled successfully", + "job_id": "swarm_daily-market-analysis_1709563245", + "scheduled_time": "2025-03-05T17:00:00Z", + "timezone": "UTC" +} +``` -## Best Practices +#### Get Scheduled Jobs -1. Start with small swarms and gradually increase complexity -2. Monitor credit usage and token counts -3. 
Use appropriate max_loops values to control execution -4. Implement proper error handling for API responses -5. Consider using batch completions for multiple related tasks +Retrieve all scheduled swarm jobs. -## Response Structures +**Endpoint**: `/v1/swarm/schedule` +**Method**: GET +**Rate Limit**: 100 requests per 60 seconds -### Single Swarm Response +**Example Request**: +```bash +curl -X GET "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/schedule" \ + -H "x-api-key: your_api_key_here" +``` +**Example Response**: ```json { - "status": "success", - "swarm_name": "Test Swarm", - "description": "A test swarm", - "swarm_type": "ConcurrentWorkflow", - "task": "Write a blog post", - "output": { - // Swarm output here + "status": "success", + "scheduled_jobs": [ + { + "job_id": "swarm_daily-market-analysis_1709563245", + "swarm_name": "daily-market-analysis", + "scheduled_time": "2025-03-05T17:00:00Z", + "timezone": "UTC" }, - "metadata": { - "max_loops": 1, - "num_agents": 2, - "execution_time_seconds": 5.23, - "completion_time": 1647123456.789, - "billing_info": { - "cost_breakdown": { - "agent_cost": 0.02, - "input_token_cost": 0.015, - "output_token_cost": 0.045, - "token_counts": { - "total_input_tokens": 1500, - "total_output_tokens": 3000, - "total_tokens": 4500, - "per_agent": { - "agent1": { - "input_tokens": 750, - "output_tokens": 1500, - "total_tokens": 2250 - }, - "agent2": { - "input_tokens": 750, - "output_tokens": 1500, - "total_tokens": 2250 - } - } - }, - "num_agents": 2, - "execution_time_seconds": 5.23 - }, - "total_cost": 0.08 - } + { + "job_id": "swarm_weekly-report_1709563348", + "swarm_name": "weekly-report", + "scheduled_time": "2025-03-09T12:00:00Z", + "timezone": "UTC" } + ] } ``` -### Batch Swarm Response +#### Cancel Scheduled Job + +Cancel a previously scheduled swarm job. + +**Endpoint**: `/v1/swarm/schedule/{job_id}` +**Method**: DELETE +**Rate Limit**: 100 requests per 60 seconds + +**Path Parameters**: + +| Parameter | Description | +|-----------|-------------| +| job_id | ID of the scheduled job to cancel | +**Example Request**: +```bash +curl -X DELETE "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/schedule/swarm_daily-market-analysis_1709563245" \ + -H "x-api-key: your_api_key_here" +``` + +**Example Response**: ```json -[ - { - "status": "success", - "swarm_name": "Batch Swarm 1", - "output": {}, - "metadata": {} - }, - { - "status": "success", - "swarm_name": "Batch Swarm 2", - "output": {}, - "metadata": {} - } -] +{ + "status": "success", + "message": "Scheduled job cancelled successfully", + "job_id": "swarm_daily-market-analysis_1709563245" +} ``` -## Logs Endpoint +#### Get API Logs -### Get Swarm Logs -Retrieve execution logs for your API key. +Retrieve logs of API requests made with your API key. 
-**Endpoint:** `GET /v1/swarm/logs` -**Authentication Required:** Yes +**Endpoint**: `/v1/swarm/logs` +**Method**: GET +**Rate Limit**: 100 requests per 60 seconds -#### Response Format +**Example Request**: +```bash +curl -X GET "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/logs" \ + -H "x-api-key: your_api_key_here" +``` + +**Example Response**: ```json { - "status": "success", - "count": 2, - "logs": [ - { - "api_key": "masked", - "data": { - "swarm_name": "Test Swarm", - "task": "Write a blog post", - "execution_time": "2024-03-19T15:30:00Z", - "status": "success" - } - } - ] + "status": "success", + "count": 25, + "logs": [ + { + "id": "log_id_12345", + "api_key": "api_key_redacted", + "data": { + "action": "run_swarm", + "swarm_name": "financial-analysis-swarm", + "task": "Analyze quarterly financials...", + "timestamp": "2025-03-04T14:22:45Z" + } + }, + ... + ] } ``` -## Error Handling +## Data Models -The API uses standard HTTP status codes and provides detailed error messages: +### SwarmSpec -| HTTP Status Code | Description | Example Response | -|-----------------|-------------|------------------| -| 400 | Bad Request - Invalid parameters | `{"detail": "Invalid swarm configuration"}` | -| 401 | Unauthorized - Missing API key | `{"detail": "API key is required"}` | -| 402 | Payment Required - Insufficient credits | `{"detail": "Insufficient credits"}` | -| 403 | Forbidden - Invalid API key | `{"detail": "Invalid API key"}` | -| 429 | Too Many Requests - Rate limit exceeded | `{"detail": "Rate limit exceeded"}` | -| 500 | Internal Server Error | `{"detail": "Internal server error"}` | +The `SwarmSpec` model defines the configuration of a swarm. -## Rate Limiting +| Field | Type | Description | Required | +|-------|------|-------------|----------| +| name | string | Identifier for the swarm | No | +| description | string | Description of the swarm's purpose | No | +| agents | Array | List of agent specifications | No | +| max_loops | integer | Maximum number of execution loops | No | +| swarm_type | SwarmType | Architecture of the swarm | No | +| rearrange_flow | string | Instructions for rearranging task flow | No | +| task | string | The main task for the swarm to accomplish | Yes | +| img | string | Optional image URL for the swarm | No | +| return_history | boolean | Whether to return execution history | No | +| rules | string | Guidelines for swarm behavior | No | +| schedule | ScheduleSpec | Scheduling information | No | + +### AgentSpec + +The `AgentSpec` model defines the configuration of an individual agent. + +| Field | Type | Description | Required | +|-------|------|-------------|----------| +| agent_name | string | Unique name for the agent | Yes* | +| description | string | Description of the agent's purpose | No | +| system_prompt | string | Instructions for the agent | No | +| model_name | string | AI model to use (e.g., "gpt-4o") | Yes* | +| auto_generate_prompt | boolean | Whether to auto-generate prompts | No | +| max_tokens | integer | Maximum tokens in response | No | +| temperature | float | Randomness of responses (0-1) | No | +| role | string | Agent's role in the swarm | No | +| max_loops | integer | Maximum iterations for this agent | No | + +*Required if agents are manually specified; not required if using auto-generated agents + +### ScheduleSpec + +The `ScheduleSpec` model defines when a swarm should be executed. 
+ +| Field | Type | Description | Required | +|-------|------|-------------|----------| +| scheduled_time | datetime | Time when the swarm should run | Yes | +| timezone | string | Timezone for the scheduled time | No (defaults to "UTC") | -The API implements rate limiting to ensure fair usage: +## Production Examples -- **Rate Limit:** 100 requests per minute per IP address -- **Time Window:** 60 seconds -- **Response on Limit Exceeded:** HTTP 429 with retry-after header +### Python Examples -# Code Examples +#### Financial Risk Assessment (Python) -## Python -### Using requests +This example demonstrates creating a swarm for comprehensive financial risk assessment. ```python import requests +import json from datetime import datetime, timedelta -import pytz +# API Configuration +API_BASE_URL = "https://swarms-api-285321057562.us-east1.run.app" API_KEY = "your_api_key_here" -BASE_URL = "https://swarms-api-285321057562.us-east1.run.app" - -headers = { +HEADERS = { "x-api-key": API_KEY, "Content-Type": "application/json" } -def run_single_swarm(): - payload = { - "name": "Financial Analysis Swarm", - "description": "Market analysis swarm", - "agents": [ - { - "agent_name": "Market Analyst", - "description": "Analyzes market trends", - "system_prompt": "You are a financial analyst expert.", - "model_name": "gpt-4o", - "role": "worker", - "max_loops": 1 - } - ], - "max_loops": 1, - "swarm_type": "SequentialWorkflow", - "task": "Analyze current market trends in tech sector", - "return_history": True, - "rules": "Focus on major market indicators" +def financial_risk_assessment(company_data, market_conditions, risk_tolerance): + """ + Creates and runs a swarm to perform comprehensive financial risk assessment. + + Args: + company_data (str): Description or data about the company + market_conditions (str): Current market conditions + risk_tolerance (str): Risk tolerance level (e.g., "conservative", "moderate", "aggressive") + + Returns: + dict: Risk assessment results + """ + # Prepare the task description with all relevant information + task = f""" + Perform a comprehensive financial risk assessment with the following data: + + COMPANY DATA: + {company_data} + + MARKET CONDITIONS: + {market_conditions} + + RISK TOLERANCE: + {risk_tolerance} + + Analyze all potential risk factors including market risks, credit risks, + operational risks, and regulatory compliance risks. Quantify each risk factor + on a scale of 1-10 and provide specific mitigation strategies. + + Return a detailed report with executive summary, risk scores, detailed analysis, + and actionable recommendations. + """ + + # Define specialized financial agents + financial_analysts = [ + { + "agent_name": "MarketAnalyst", + "description": "Specialist in market risk assessment and forecasting", + "system_prompt": "You are an expert market analyst with deep expertise in financial markets. Analyze market conditions, trends, and external factors that could impact financial performance. Provide quantitative and qualitative analysis of market-related risks.", + "model_name": "gpt-4o", + "temperature": 0.3, + "role": "analyst", + "max_loops": 1 + }, + { + "agent_name": "CreditRiskAnalyst", + "description": "Expert in assessing credit and counterparty risks", + "system_prompt": "You are a specialist in credit risk analysis with experience in banking and financial institutions. Evaluate creditworthiness, default probabilities, and counterparty exposures. 
Provide detailed analysis of credit-related risks and recommended safeguards.", + "model_name": "gpt-4o", + "temperature": 0.2, + "role": "analyst", + "max_loops": 1 + }, + { + "agent_name": "RegulatoryExpert", + "description": "Expert in financial regulations and compliance", + "system_prompt": "You are a regulatory compliance expert with deep knowledge of financial regulations. Identify potential regulatory risks, compliance issues, and governance concerns. Recommend compliance measures and risk mitigation strategies.", + "model_name": "gpt-4o", + "temperature": 0.2, + "role": "analyst", + "max_loops": 1 + }, + { + "agent_name": "RiskSynthesizer", + "description": "Integrates all risk factors into comprehensive assessment", + "system_prompt": "You are a senior risk management professional responsible for synthesizing multiple risk analyses into a coherent, comprehensive risk assessment. Integrate analyses from various domains, resolve conflicting assessments, and provide a holistic view of risk exposure with prioritized recommendations.", + "model_name": "gpt-4o", + "temperature": 0.4, + "role": "manager", + "max_loops": 1 + } + ] + + # Create the swarm specification + swarm_spec = { + "name": "financial-risk-assessment", + "description": "Comprehensive financial risk assessment swarm", + "agents": financial_analysts, + "max_loops": 2, + "swarm_type": "HiearchicalSwarm", + "task": task, + "return_history": True } + # Execute the swarm response = requests.post( - f"{BASE_URL}/v1/swarm/completions", - headers=headers, - json=payload + f"{API_BASE_URL}/v1/swarm/completions", + headers=HEADERS, + json=swarm_spec ) - return response.json() - -def schedule_swarm(): - # Schedule for 1 hour from now - scheduled_time = datetime.now(pytz.UTC) + timedelta(hours=1) + if response.status_code == 200: + result = response.json() + print(f"Risk assessment completed. Cost: ${result['metadata']['billing_info']['total_cost']}") + return result["output"] + else: + print(f"Error: {response.status_code} - {response.text}") + return None + +# Usage example +if __name__ == "__main__": + company_data = """ + XYZ Financial Services + Annual Revenue: $125M + Current Debt: $45M + Credit Rating: BBB+ + Primary Markets: North America, Europe + Key Products: Asset management, retirement planning, commercial lending + Recent Events: Expanding into Asian markets, New CEO appointed 6 months ago + """ + + market_conditions = """ + Current interest rates rising (Federal Reserve increased rates by 0.25% last month) + Inflation at 3.2% (12-month outlook projects 3.5-4.0%) + Market volatility index (VIX) at 22.4 (elevated) + Regulatory environment: New financial reporting requirements taking effect next quarter + Sector performance: Financial services sector underperforming broader market by 2.7% + """ - payload = { - "name": "Scheduled Analysis", - "agents": [ - { - "agent_name": "Analyst", - "system_prompt": "You are a market analyst.", - "model_name": "gpt-4o", - "role": "worker" + risk_tolerance = "moderate" + + result = financial_risk_assessment(company_data, market_conditions, risk_tolerance) + + if result: + # Process and use the risk assessment + print(json.dumps(result, indent=2)) + + # Optionally, schedule a follow-up assessment + tomorrow = datetime.utcnow() + timedelta(days=30) + schedule_spec = { + "name": "monthly-risk-update", + "description": "Monthly update to risk assessment", + "task": f"Update the risk assessment for XYZ Financial Services based on current market conditions. 
Previous assessment: {json.dumps(result)}", + "schedule": { + "scheduled_time": tomorrow.isoformat() + "Z", + "timezone": "UTC" } - ], - "task": "Analyze tech trends", - "schedule": { - "scheduled_time": scheduled_time.isoformat(), - "timezone": "America/New_York" } + + schedule_response = requests.post( + f"{API_BASE_URL}/v1/swarm/schedule", + headers=HEADERS, + json=schedule_spec + ) + + if schedule_response.status_code == 200: + print("Follow-up assessment scheduled successfully") + print(schedule_response.json()) +``` + +#### Healthcare Patient Data Analysis (Python) + +This example demonstrates creating a swarm for analyzing patient health data and generating insights. + +```python +import requests +import json +import os +from datetime import datetime + +# API Configuration +API_BASE_URL = "https://swarms-api-285321057562.us-east1.run.app" +API_KEY = os.environ.get("SWARMS_API_KEY") +HEADERS = { + "x-api-key": API_KEY, + "Content-Type": "application/json" +} + +def analyze_patient_health_data(patient_data, medical_history, lab_results, treatment_goals): + """ + Creates and runs a swarm to analyze patient health data and generate insights. + + Args: + patient_data (str): Basic patient information + medical_history (str): Patient's medical history + lab_results (str): Recent laboratory results + treatment_goals (str): Treatment objectives + + Returns: + dict: Comprehensive health analysis and recommendations + """ + # Prepare the detailed task description + task = f""" + Perform a comprehensive analysis of the following patient health data: + + PATIENT INFORMATION: + {patient_data} + + MEDICAL HISTORY: + {medical_history} + + LABORATORY RESULTS: + {lab_results} + + TREATMENT GOALS: + {treatment_goals} + + Analyze all aspects of the patient's health status, identify potential concerns, + evaluate treatment effectiveness, and provide evidence-based recommendations for + optimizing care. Consider medication interactions, lifestyle factors, and preventive measures. + + Return a detailed clinical report with key findings, risk stratification, + prioritized recommendations, and suggested follow-up timeline. 
+ """ + + # Create the swarm specification with auto-generated agents + # (letting the system create specialized medical experts) + swarm_spec = { + "name": "patient-health-analysis", + "description": "Comprehensive patient health data analysis", + "swarm_type": "AutoSwarmBuilder", + "task": task, + "max_loops": 3, + "return_history": True } - response = requests.post( - f"{BASE_URL}/v1/swarm/schedule", - headers=headers, - json=payload - ) + # Execute the swarm + try: + response = requests.post( + f"{API_BASE_URL}/v1/swarm/completions", + headers=HEADERS, + json=swarm_spec + ) + + response.raise_for_status() + result = response.json() + + # Log the execution metadata + execution_time = result["metadata"]["execution_time_seconds"] + cost = result["metadata"]["billing_info"]["total_cost"] + num_agents = result["metadata"]["num_agents"] + + print(f"Analysis completed in {execution_time:.2f} seconds") + print(f"Used {num_agents} specialized medical agents") + print(f"Total cost: ${cost:.4f}") + + # Return just the analysis results + return result["output"] + + except requests.exceptions.RequestException as e: + print(f"API request failed: {str(e)}") + if hasattr(e, 'response') and e.response: + print(f"Response: {e.response.text}") + return None + except Exception as e: + print(f"Error: {str(e)}") + return None + +# Usage example +if __name__ == "__main__": + # Sample patient data (would typically come from EHR system) + patient_data = """ + ID: PT-28456 + Age: 67 + Gender: Female + Height: 162 cm + Weight: 78 kg + Vitals: + - Blood Pressure: 142/88 mmHg + - Heart Rate: 76 bpm + - Respiratory Rate: 16/min + - Temperature: 37.1°C + - Oxygen Saturation: 97% + """ + + medical_history = """ + Diagnoses: + - Type 2 Diabetes Mellitus (diagnosed 12 years ago) + - Hypertension (diagnosed 8 years ago) + - Osteoarthritis (knees, diagnosed 5 years ago) + - Hyperlipidemia + + Surgical History: + - Cholecystectomy (15 years ago) + - Right knee arthroscopy (3 years ago) + + Medications: + - Metformin 1000mg BID + - Lisinopril 20mg daily + - Atorvastatin 40mg daily + - Aspirin 81mg daily + - Acetaminophen 500mg PRN for joint pain + + Allergies: + - Penicillin (rash) + - Sulfa drugs (hives) + + Family History: + - Father: MI at age 70, died at 76 + - Mother: Breast cancer at 68, Type 2 Diabetes, died at 82 + - Sister: Type 2 Diabetes, Hypertension + """ + + lab_results = """ + CBC (2 days ago): + - WBC: 7.2 x10^9/L (normal) + - RBC: 4.1 x10^12/L (low-normal) + - Hemoglobin: 12.8 g/dL (low-normal) + - Hematocrit: 38% (low-normal) + - Platelets: 245 x10^9/L (normal) - return response.json() + Comprehensive Metabolic Panel: + - Glucose (fasting): 142 mg/dL (elevated) + - HbA1c: 7.8% (elevated) + - BUN: 22 mg/dL (normal) + - Creatinine: 1.1 mg/dL (normal) + - eGFR: 62 mL/min/1.73m² (mildly reduced) + - Sodium: 138 mEq/L (normal) + - Potassium: 4.2 mEq/L (normal) + - Chloride: 101 mEq/L (normal) + - Calcium: 9.4 mg/dL (normal) + - ALT: 32 U/L (normal) + - AST: 28 U/L (normal) + + Lipid Panel: + - Total Cholesterol: 198 mg/dL + - Triglycerides: 172 mg/dL (elevated) + - HDL: 42 mg/dL (low) + - LDL: 122 mg/dL (borderline elevated) + + Urinalysis: + - Microalbumin/Creatinine ratio: 45 mg/g (elevated) + """ + + treatment_goals = """ + Primary Goals: + - Improve glycemic control (target HbA1c < 7.0%) + - Blood pressure control (target < 130/80 mmHg) + - Lipid management (target LDL < 100 mg/dL) + - Renal protection (reduce microalbuminuria) + - Weight management (target BMI < 27) + - Pain management for osteoarthritis 
+ - Maintain functional independence + + Patient Preferences: + - Prefers to minimize medication changes if possible + - Interested in dietary approaches + - Concerned about memory changes + - Limited exercise tolerance due to knee pain + """ + + result = analyze_patient_health_data(patient_data, medical_history, lab_results, treatment_goals) + + if result: + # Write the analysis to a report file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + with open(f"patient_analysis_{timestamp}.json", "w") as f: + json.dump(result, f, indent=2) + + print(f"Analysis saved to patient_analysis_{timestamp}.json") + + # Display key findings + if "key_findings" in result: + print("\nKEY FINDINGS:") + for i, finding in enumerate(result["key_findings"]): + print(f" {i+1}. {finding}") + + # Display recommendations + if "recommendations" in result: + print("\nRECOMMENDATIONS:") + for i, rec in enumerate(result["recommendations"]): + print(f" {i+1}. {rec}") +``` -def get_scheduled_jobs(): - response = requests.get( - f"{BASE_URL}/v1/swarm/schedule", - headers=headers - ) - return response.json() +### TypeScript/NodeJS Examples -def cancel_scheduled_job(job_id: str): - response = requests.delete( - f"{BASE_URL}/v1/swarm/schedule/{job_id}", - headers=headers - ) - return response.json() +#### Financial Fraud Detection (TypeScript) -def get_swarm_logs(): - response = requests.get( - f"{BASE_URL}/v1/swarm/logs", - headers=headers - ) - return response.json() -``` +This example demonstrates creating a swarm for financial transaction fraud detection. -## Node.js -### Using Fetch API - -```javascript -const API_KEY = 'your_api_key_here'; -const BASE_URL = 'https://swarms-api-285321057562.us-east1.run.app'; - -const headers = { - 'x-api-key': API_KEY, - 'Content-Type': 'application/json' -}; - -// Schedule a swarm -async function scheduleSwarm() { - const scheduledTime = new Date(); - scheduledTime.setHours(scheduledTime.getHours() + 1); - - const payload = { - name: 'Scheduled Analysis', - agents: [{ - agent_name: 'Analyst', - system_prompt: 'You are a market analyst.', - model_name: 'gpt-4o', - role: 'worker' - }], - task: 'Analyze tech trends', - schedule: { - scheduled_time: scheduledTime.toISOString(), - timezone: 'America/New_York' - } - }; +```typescript +import axios from 'axios'; +import * as fs from 'fs'; +import * as dotenv from 'dotenv'; + +// Load environment variables +dotenv.config(); + +// API Configuration +const API_BASE_URL = "https://swarms-api-285321057562.us-east1.run.app"; +const API_KEY = process.env.SWARMS_API_KEY; + +// Define interfaces for type safety +interface TransactionData { + transaction_id: string; + amount: number; + timestamp: string; + merchant: string; + merchant_category: string; + payment_method: string; + location: string; + device_id?: string; + ip_address?: string; +} + +interface UserProfile { + user_id: string; + account_age_days: number; + typical_transaction_amount: number; + typical_merchants: string[]; + typical_locations: string[]; + risk_score: number; +} +interface FraudDetectionResult { + transaction_id: string; + fraud_score: number; + is_fraudulent: boolean; + risk_factors: string[]; + confidence: number; + recommended_action: string; + explanation: string; +} + +/** + * Detects potential fraud in financial transactions using a specialized agent swarm + * + * @param transactions - Array of transaction data to analyze + * @param userProfile - User profile information for context + * @param historicalPatterns - Description of historical patterns and 
behaviors + * @returns Promise resolving to fraud detection results + */ +async function detectTransactionFraud( + transactions: TransactionData[], + userProfile: UserProfile, + historicalPatterns: string +): Promise { try { - const response = await fetch(`${BASE_URL}/v1/swarm/schedule`, { - method: 'POST', - headers, - body: JSON.stringify(payload) - }); + // Prepare the task with all relevant information + const task = ` + Analyze the following financial transactions for potential fraud. + + USER PROFILE: + ${JSON.stringify(userProfile, null, 2)} + + HISTORICAL PATTERNS: + ${historicalPatterns} + + TRANSACTIONS TO ANALYZE: + ${JSON.stringify(transactions, null, 2)} + + For each transaction, determine the fraud risk score (0-100), whether it's likely fraudulent, + identified risk factors, confidence level, recommended action (allow, flag for review, block), + and a detailed explanation of the analysis. + + Return results for each transaction in a structured format. + `; + + // Define specialized fraud detection agents + const fraudDetectionAgents = [ + { + agent_name: "BehavioralAnalyst", + description: "Analyzes user behavior patterns to identify anomalies", + system_prompt: `You are an expert in behavioral analysis for fraud detection. + Your role is to analyze transaction patterns against historical user behavior + to identify potential anomalies or deviations that may indicate fraud. + + Consider: + - Timing patterns (day of week, time of day) + - Transaction amount patterns + - Merchant category patterns + - Geographic location patterns + - Device usage patterns + + Provide a detailed breakdown of behavioral anomalies with confidence scores.`, + model_name: "gpt-4o", + temperature: 0.2, + role: "analyst", + max_loops: 1 + }, + { + agent_name: "TechnicalFraudDetector", + description: "Analyzes technical indicators of fraud", + system_prompt: `You are a technical fraud detection specialist. + Your role is to analyze technical aspects of transactions to identify + potential indicators of fraud. + + Focus on: + - IP address analysis + - Device ID consistency + - Geolocation feasibility (impossible travel) + - Known fraud patterns + - Technical manipulation markers + + Provide a technical assessment with specific indicators of potential fraud.`, + model_name: "gpt-4o", + temperature: 0.2, + role: "analyst", + max_loops: 1 + }, + { + agent_name: "FinancialPatternAnalyst", + description: "Specializes in financial transaction patterns", + system_prompt: `You are a financial pattern analysis expert specializing in fraud detection. + Your role is to analyze the financial aspects of transactions for fraud indicators. + + Focus on: + - Transaction amount anomalies + - Merchant risk profiles + - Transaction velocity + - Transaction sequence patterns + - Round-number amounts or other suspicious values + + Provide a financial pattern analysis with risk assessment.`, + model_name: "gpt-4o", + temperature: 0.2, + role: "analyst", + max_loops: 1 + }, + { + agent_name: "FraudInvestigator", + description: "Synthesizes all analysis into final fraud determination", + system_prompt: `You are a senior fraud investigator responsible for making the final + determination on transaction fraud risk. + + Your role is to: + 1. Synthesize inputs from behavioral, technical, and financial analyses + 2. Weigh different risk factors appropriately + 3. Calculate an overall fraud risk score (0-100) + 4. Make a clear determination (legitimate, suspicious, fraudulent) + 5. 
Recommend specific actions (allow, review, block) + 6. Provide a clear, detailed explanation for each determination + + Balance false positives and false negatives appropriately.`, + model_name: "gpt-4o", + temperature: 0.3, + role: "manager", + max_loops: 1 + } + ]; + + // Create the swarm specification + const swarmSpec = { + name: "fraud-detection-swarm", + description: "Financial transaction fraud detection swarm", + agents: fraudDetectionAgents, + max_loops: 2, + swarm_type: "HiearchicalSwarm", + task: task, + return_history: false + }; + + // Execute the swarm + const response = await axios.post( + `${API_BASE_URL}/v1/swarm/completions`, + swarmSpec, + { + headers: { + 'x-api-key': API_KEY, + 'Content-Type': 'application/json' + } + } + ); + + if (response.status === 200) { + console.log(`Fraud detection completed. Cost: ${response.data.metadata.billing_info.total_cost}`); + return response.data.output.fraud_analysis as FraudDetectionResult[]; + } else { + throw new Error(`API request failed with status: ${response.status}`); + } - return await response.json(); } catch (error) { - console.error('Error:', error); + console.error('Error in fraud detection:', error); + if (axios.isAxiosError(error) && error.response) { + console.error('API response:', error.response.data); + } throw error; } } -// Get scheduled jobs -async function getScheduledJobs() { +// Usage example +async function main() { + // Sample transaction data + const transactions: TransactionData[] = [ + { + transaction_id: "T-12345-89012", + amount: 1299.99, + timestamp: "2025-03-04T09:23:42Z", + merchant: "ElectronicsPro", + merchant_category: "Electronics", + payment_method: "Credit Card", + location: "San Francisco, CA", + device_id: "D-472910", + ip_address: "192.168.1.127" + }, + { + transaction_id: "T-12345-89013", + amount: 849.50, + timestamp: "2025-03-04T09:45:18Z", + merchant: "LuxuryBrands", + merchant_category: "Fashion", + payment_method: "Credit Card", + location: "Miami, FL", + device_id: "D-891245", + ip_address: "45.23.189.44" + }, + { + transaction_id: "T-12345-89014", + amount: 24.99, + timestamp: "2025-03-04T10:12:33Z", + merchant: "CoffeeDeluxe", + merchant_category: "Restaurant", + payment_method: "Mobile Wallet", + location: "San Francisco, CA", + device_id: "D-472910", + ip_address: "192.168.1.127" + } + ]; + + // Sample user profile + const userProfile: UserProfile = { + user_id: "U-78901-23456", + account_age_days: 487, + typical_transaction_amount: 150.75, + typical_merchants: ["Groceries", "Restaurant", "Retail", "Streaming"], + typical_locations: ["San Francisco, CA", "Oakland, CA"], + risk_score: 12 + }; + + // Sample historical patterns + const historicalPatterns = ` + User typically makes 15-20 transactions per month + Average transaction amount: $50-$200 + Largest previous transaction: $750 (furniture) + No previous transactions over $1000 + Typically shops in San Francisco Bay Area + No international transactions in past 12 months + Usually uses mobile wallet for small purchases (<$50) + Credit card used for larger purchases + Typical activity times: weekdays 8am-10pm PST + No previous purchases in Miami, FL + `; + try { - const response = await fetch(`${BASE_URL}/v1/swarm/schedule`, { - headers - }); - return await response.json(); + const results = await detectTransactionFraud( + transactions, + userProfile, + historicalPatterns + ); + + console.log("Fraud Detection Results:"); + console.log(JSON.stringify(results, null, 2)); + + // Save results to file + fs.writeFileSync( + 
`fraud_detection_results_${new Date().toISOString().replace(/:/g, '-')}.json`, + JSON.stringify(results, null, 2) + ); + + // Process high-risk transactions + const highRiskTransactions = results.filter(r => r.fraud_score > 75); + if (highRiskTransactions.length > 0) { + console.log(`WARNING: ${highRiskTransactions.length} high-risk transactions detected!`); + // In production, you would trigger alerts, notifications, etc. + } + } catch (error) { - console.error('Error:', error); - throw error; + console.error("Failed to complete fraud detection:", error); } } -// Cancel scheduled job -async function cancelScheduledJob(jobId) { +main(); +``` + +#### Healthcare Report Generation (TypeScript) + +This example demonstrates creating a swarm for generating comprehensive healthcare reports from patient data. + +```typescript +import axios from 'axios'; +import * as fs from 'fs'; +import * as path from 'path'; +import * as dotenv from 'dotenv'; + +// Load environment variables +dotenv.config(); + +// API Configuration +const API_BASE_URL = "https://swarms-api-285321057562.us-east1.run.app"; +const API_KEY = process.env.SWARMS_API_KEY; + +// Define interfaces +interface PatientData { + patient_id: string; + demographics: { + age: number; + gender: string; + ethnicity?: string; + weight_kg?: number; + height_cm?: number; + }; + vitals: { + blood_pressure?: string; + heart_rate?: number; + respiratory_rate?: number; + temperature_c?: number; + oxygen_saturation?: number; + }; + conditions: string[]; + medications: { + name: string; + dosage: string; + frequency: string; + }[]; + allergies: string[]; + lab_results: { + test_name: string; + value: number | string; + unit: string; + reference_range?: string; + collection_date: string; + }[]; + imaging_results?: { + study_type: string; + body_area: string; + findings: string; + date: string; + }[]; + notes?: string[]; +} + +interface MedicalReport { + summary: string; + assessment: { + primary_diagnosis: string; + secondary_diagnoses: string[]; + clinical_impression: string; + }; + recommendations: string[]; + medication_review: { + current_medications: { + medication: string; + assessment: string; + }[]; + potential_interactions: string[]; + recommended_changes: string[]; + }; + care_plan: { + short_term_goals: string[]; + long_term_goals: string[]; + follow_up: string; + }; + lab_interpretation: { + flagged_results: { + test: string; + value: string; + interpretation: string; + }[]; + trends: string[]; + }; + clinical_reasoning: string; +} + +/** + * Generates a comprehensive clinical report for a patient using a specialized medical agent swarm + * + * @param patientData - Structured patient medical data + * @param clinicalGuidelines - Relevant clinical guidelines to consider + * @param reportType - Type of report to generate (e.g., 'comprehensive', 'follow-up', 'specialist') + * @returns Promise resolving to structured medical report + */ +async function generateClinicalReport( + patientData: PatientData, + clinicalGuidelines: string, + reportType: string +): Promise { try { - const response = await fetch(`${BASE_URL}/v1/swarm/schedule/${jobId}`, { - method: 'DELETE', - headers - }); - return await response.json(); + // Prepare detailed task description + const task = ` + Generate a comprehensive clinical report for the following patient: + + PATIENT DATA: + ${JSON.stringify(patientData, null, 2)} + + CLINICAL GUIDELINES TO CONSIDER: + ${clinicalGuidelines} + + REPORT TYPE: + ${reportType} + + Analyze all aspects of the patient's health status 
including symptoms, lab results, + medications, conditions, and other relevant factors. Provide a detailed clinical assessment, + evidence-based recommendations, medication review with potential interactions, + comprehensive care plan, and clear follow-up instructions. + + Structure the report to include a concise executive summary, detailed assessment with + clinical reasoning, specific actionable recommendations, and a clear care plan. + + Ensure all interpretations and recommendations align with current clinical guidelines + and evidence-based medicine. + `; + + // Use Auto Swarm Builder for this complex medical task + const swarmSpec = { + name: "clinical-report-generator", + description: "Medical report generation and analysis swarm", + swarm_type: "AutoSwarmBuilder", + task: task, + max_loops: 3, + return_history: false + }; + + // Execute the swarm + console.log("Generating clinical report..."); + const response = await axios.post( + `${API_BASE_URL}/v1/swarm/completions`, + swarmSpec, + { + headers: { + 'x-api-key': API_KEY, + 'Content-Type': 'application/json' + } + } + ); + + if (response.status === 200) { + const executionTime = response.data.metadata.execution_time_seconds; + const cost = response.data.metadata.billing_info.total_cost; + const numAgents = response.data.metadata.num_agents; + + console.log(`Report generation completed in ${executionTime.toFixed(2)} seconds`); + console.log(`Used ${numAgents} specialized medical agents`); + console.log(`Total cost: ${cost.toFixed(4)}`); + + return response.data.output.medical_report as MedicalReport; + } else { + throw new Error(`API request failed with status: ${response.status}`); + } + } catch (error) { - console.error('Error:', error); + console.error('Error generating clinical report:', error); + if (axios.isAxiosError(error) && error.response) { + console.error('API response:', error.response.data); + } throw error; } } -``` - -## Shell (cURL) -### Schedule a Swarm +/** + * Schedules regular report generation for a patient + */ +async function scheduleRecurringReports( + patientId: string, + reportType: string, + intervalDays: number +): Promise { + try { + // Schedule the next report + const nextReportDate = new Date(); + nextReportDate.setDate(nextReportDate.getDate() + intervalDays); + + const scheduleSpec = { + name: `${patientId}-${reportType}-report`, + description: `Scheduled ${reportType} report for patient ${patientId}`, + task: `Generate a ${reportType} clinical report for patient ${patientId} following standard protocols. 
Retrieve the most recent patient data and produce a comprehensive clinical assessment with recommendations.`, + schedule: { + scheduled_time: nextReportDate.toISOString(), + timezone: "UTC" + } + }; + + const response = await axios.post( + `${API_BASE_URL}/v1/swarm/schedule`, + scheduleSpec, + { + headers: { + 'x-api-key': API_KEY, + 'Content-Type': 'application/json' + } + } + ); + + if (response.status === 200) { + console.log(`Successfully scheduled next report for ${nextReportDate.toISOString()}`); + console.log(`Job ID: ${response.data.job_id}`); + } else { + throw new Error(`Failed to schedule report: ${response.status}`); + } + + } catch (error) { + console.error('Error scheduling report:', error); + if (axios.isAxiosError(error) && error.response) { + console.error('API response:', error.response.data); + } + } +} -```bash -curl -X POST "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/schedule" \ - -H "x-api-key: your_api_key_here" \ - -H "Content-Type: application/json" \ - -d '{ - "name": "Scheduled Analysis", - "agents": [ +// Usage example +async function main() { + // Sample patient data + const patientData: PatientData = { + patient_id: "P-12345-67890", + demographics: { + age: 58, + gender: "Male", + ethnicity: "Caucasian", + weight_kg: 92.5, + height_cm: 178 + }, + vitals: { + blood_pressure: "148/92", + heart_rate: 82, + respiratory_rate: 16, + temperature_c: 36.8, + oxygen_saturation: 96 + }, + conditions: [ + "Type 2 Diabetes Mellitus", + "Hypertension", + "Coronary Artery Disease", + "Hyperlipidemia", + "Obesity" + ], + medications: [ + { + name: "Metformin", + dosage: "1000mg", + frequency: "BID" + }, + { + name: "Lisinopril", + dosage: "20mg", + frequency: "QD" + }, + { + name: "Atorvastatin", + dosage: "40mg", + frequency: "QD" + }, + { + name: "Aspirin", + dosage: "81mg", + frequency: "QD" + } + ], + allergies: ["Penicillin", "Sulfa drugs"], + lab_results: [ + { + test_name: "HbA1c", + value: 8.2, + unit: "%", + reference_range: "<7.0", + collection_date: "2025-02-15" + }, + { + test_name: "Fasting Glucose", + value: 165, + unit: "mg/dL", + reference_range: "70-100", + collection_date: "2025-02-15" + }, + { + test_name: "LDL Cholesterol", + value: 118, + unit: "mg/dL", + reference_range: "<100", + collection_date: "2025-02-15" + }, + { + test_name: "HDL Cholesterol", + value: 38, + unit: "mg/dL", + reference_range: ">40", + collection_date: "2025-02-15" + }, + { + test_name: "eGFR", + value: 68, + unit: "mL/min/1.73m²", + reference_range: ">90", + collection_date: "2025-02-15" + } + ], + imaging_results: [ { - "agent_name": "Analyst", - "system_prompt": "You are a market analyst.", - "model_name": "gpt-4o", - "role": "worker" + study_type: "Cardiac CT Angiography", + body_area: "Heart and coronary arteries", + findings: "Moderate calcification in LAD. 50-70% stenosis in proximal LAD. 
No significant stenosis in other coronary arteries.", + date: "2025-01-10" } ], - "task": "Analyze tech trends", - "schedule": { - "scheduled_time": "2024-03-20T15:00:00Z", - "timezone": "America/New_York" + notes: [ + "Patient reports increased fatigue over the past month", + "Complains of occasional chest discomfort with exertion", + "Currently following low-carb diet but admits to poor adherence", + "Exercise limited by knee pain" + ] + }; + + // Clinical guidelines to consider + const clinicalGuidelines = ` + ADA 2025 Guidelines for Type 2 Diabetes: + - HbA1c target <7.0% for most adults + - Consider less stringent targets (e.g., <8.0%) for patients with multiple comorbidities + - First-line therapy: Metformin + lifestyle modifications + - For patients with ASCVD: consider GLP-1 RA or SGLT2 inhibitor with proven CV benefit + + ACC/AHA 2024 Hypertension Guidelines: + - BP target <130/80 mmHg for patients with diabetes and/or CAD + - First-line: ACE inhibitor or ARB for patients with diabetes + + ACC/AHA 2024 Cholesterol Guidelines: + - LDL-C target <70 mg/dL for very high-risk ASCVD + - Consider adding ezetimibe or PCSK9 inhibitor for very high-risk patients not at goal + `; + + try { + // Generate the report + const report = await generateClinicalReport( + patientData, + clinicalGuidelines, + "comprehensive" + ); + + // Save report to file + const timestamp = new Date().toISOString().replace(/:/g, '-'); + const outputDir = './reports'; + + if (!fs.existsSync(outputDir)) { + fs.mkdirSync(outputDir); } - }' + + fs.writeFileSync( + path.join(outputDir, `clinical_report_${patientData.patient_id}_${timestamp}.json`), + JSON.stringify(report, null, 2) + ); + + // Display executive summary + console.log("\nREPORT SUMMARY:"); + console.log(report.summary); + + console.log("\nPRIMARY DIAGNOSIS:"); + console.log(report.assessment.primary_diagnosis); + + console.log("\nKEY RECOMMENDATIONS:"); + report.recommendations.forEach((rec, i) => { + console.log(` ${i+1}. ${rec}`); + }); + + // Schedule the next report in 90 days + await scheduleRecurringReports(patientData.patient_id, "follow-up", 90); + + } catch (error) { + console.error("Failed to complete report generation:", error); + } +} + +main(); ``` -### Get Scheduled Jobs +## Error Handling -```bash -curl -X GET "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/schedule" \ - -H "x-api-key: your_api_key_here" -``` +The Swarms API follows standard HTTP status codes for error responses: -### Cancel Scheduled Job +| Status Code | Meaning | Handling Strategy | +|-------------|---------|-------------------| +| 400 | Bad Request | Validate request parameters before sending | +| 401 | Unauthorized | Check API key validity | +| 403 | Forbidden | Verify API key permissions | +| 404 | Not Found | Check endpoint URL and resource IDs | +| 429 | Too Many Requests | Implement exponential backoff retry logic | +| 500 | Internal Server Error | Retry with backoff, then contact support | -```bash -curl -X DELETE "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/schedule/job_id_here" \ - -H "x-api-key: your_api_key_here" +Error responses include a detailed message explaining the issue: + +```json +{ + "detail": "Failed to create swarm: Invalid swarm_type specified" +} ``` -### Get Swarm Logs +## Rate Limiting + +The API enforces a rate limit of 100 requests per 60-second window. When exceeded, a 429 status code is returned. Implement appropriate retry logic with exponential backoff in production applications. 
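+
+A minimal Python sketch of such a retry wrapper is shown below. The backoff parameters (`max_retries`, `base_delay`) are illustrative choices rather than values prescribed by the API, and the `Retry-After` handling is a defensive assumption in case the server includes that header on 429 responses.
+
+```python
+import time
+
+import requests
+
+def post_with_backoff(url, headers, payload, max_retries=5, base_delay=1.0):
+    """POST with exponential backoff on HTTP 429 (rate limited) responses."""
+    response = None
+    for attempt in range(max_retries):
+        response = requests.post(url, headers=headers, json=payload, timeout=300)
+        if response.status_code != 429:
+            return response
+        # Prefer the server's Retry-After hint when it is a plain number of seconds;
+        # otherwise back off exponentially: base_delay, 2x, 4x, ...
+        retry_after = response.headers.get("Retry-After")
+        delay = float(retry_after) if retry_after and retry_after.isdigit() else base_delay * (2 ** attempt)
+        time.sleep(delay)
+    return response  # still rate limited after all retries; the caller decides how to proceed
+```
+
+In production you would typically extend this to retry transient 5xx errors as well and add random jitter to the delay so that simultaneous clients do not retry in lockstep.
+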
-```bash -curl -X GET "https://swarms-api-285321057562.us-east1.run.app/v1/swarm/logs" \ - -H "x-api-key: your_api_key_here" -``` \ No newline at end of file +## Billing & Cost Management + +The API uses a credit-based billing system with costs calculated based on: + +1. **Agent Count**: Base cost per agent +2. **Input Tokens**: Cost based on the size of input data and prompts +3. **Output Tokens**: Cost based on the length of generated responses +4. **Time of Day**: Reduced rates during nighttime hours (8 PM to 6 AM PT) + +Cost information is included in each response's metadata for transparency and forecasting. + +## Best Practices + +1. **Task Description** + - Provide detailed, specific task descriptions + - Include all necessary context and constraints + - Structure complex inputs for easier processing + +2. **Agent Configuration** + - For simple tasks, use `AutoSwarmBuilder` to automatically generate optimal agents + - For complex or specialized tasks, manually define agents with specific expertise + - Use appropriate `swarm_type` for your workflow pattern + +3. **Production Implementation** + - Implement robust error handling and retries + - Log API responses for debugging and auditing + - Monitor costs closely during development and testing + - Use scheduled jobs for recurring tasks instead of continuous polling + +4. **Cost Optimization** + - Batch related tasks when possible + - Schedule non-urgent tasks during discount hours + - Carefully scope task descriptions to reduce token usage + - Cache results when appropriate + +## Support + +For technical assistance with the Swarms API, please contact: + +- Documentation: [https://docs.swarms.world](https://docs.swarms.world) +- Email: kye@swarms.world +- Community Discord: [https://discord.gg/swarms](https://discord.gg/swarms) +- Swarms Marketplace: [https://swarms.world](https://swarms.world) +- Swarms AI Website: [https://swarms.ai](https://swarms.ai) \ No newline at end of file diff --git a/swarms_op.py b/swarms_op.py new file mode 100644 index 00000000..fe655b42 --- /dev/null +++ b/swarms_op.py @@ -0,0 +1,964 @@ +""" +MultiModelOptimizer: A high-performance optimizer for training multiple transformer models simultaneously. + +This optimizer implements several advanced techniques: +1. Gradient accumulation with dynamic batch sizing +2. Hierarchical parameter synchronization +3. Memory-efficient gradient sharing with shape compatibility +4. Adaptive learning rate scheduling per model +5. Convergence acceleration via momentum tuning +6. Robust error handling for production environments + +Author: Claude 3.7 Sonnet +License: MIT +""" + +import math +from typing import Dict, List, Optional, Tuple, Callable +from collections import defaultdict + +import torch +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from loguru import logger +import numpy as np + + +class MultiModelOptimizer(Optimizer): + """ + An optimizer designed for training multiple models simultaneously with shared gradient information, + adaptive learning rates, and efficient memory usage. + + Args: + models (Dict[str, nn.Module]): Dictionary mapping model names to model instances + lr (float, optional): Initial learning rate. Default: 1e-3 + betas (Tuple[float, float], optional): Coefficients for computing running averages of gradient and its square. Default: (0.9, 0.999) + eps (float, optional): Term added to denominator for numerical stability. 
Default: 1e-8 + weight_decay (float, optional): Weight decay coefficient. Default: 0 + amsgrad (bool, optional): Whether to use the AMSGrad variant. Default: False + grad_sync_frequency (int, optional): How often to synchronize gradients between models. Default: 1 + warmup_steps (int, optional): Number of warmup steps for learning rate. Default: 1000 + model_weights (Dict[str, float], optional): Relative importance weights for each model. Default: None + gradient_accumulation_steps (int, optional): Number of steps to accumulate gradients before update. Default: 1 + clip_grad_norm (float, optional): Maximum norm for gradient clipping. Default: None + use_cosine_schedule (bool, optional): Whether to use cosine annealing schedule. Default: True + sync_every_step (bool, optional): Whether to synchronize parameters on every step. Default: False + """ + + def __init__( + self, + models: Dict[str, nn.Module], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + amsgrad: bool = False, + grad_sync_frequency: int = 1, + warmup_steps: int = 1000, + model_weights: Optional[Dict[str, float]] = None, + gradient_accumulation_steps: int = 1, + clip_grad_norm: Optional[float] = None, + use_cosine_schedule: bool = True, + sync_every_step: bool = False, + ): + + # Initialize model weights if not provided + if model_weights is None: + model_weights = {name: 1.0 for name in models.keys()} + + # Normalize weights to sum to 1 + total_weight = sum(model_weights.values()) + self.model_weights = { + k: v / total_weight for k, v in model_weights.items() + } + + # Store models + self.models = models + + # Collect all parameters from all models + parameters = [] + self.model_param_groups: Dict[str, List[Dict]] = {} + + for model_name, model in models.items(): + model_params = [] + for param in model.parameters(): + if param.requires_grad: + param_dict = { + "params": [param], + "model_name": model_name, + } + parameters.append(param_dict) + model_params.append(param_dict) + self.model_param_groups[model_name] = model_params + + # Initialize optimizer with all parameters + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ) + super(MultiModelOptimizer, self).__init__( + parameters, defaults + ) + + # Additional settings + self.grad_sync_frequency = grad_sync_frequency + self.warmup_steps = warmup_steps + self.step_count = 0 + self.gradient_accumulation_steps = gradient_accumulation_steps + self.current_accumulation_step = 0 + self.clip_grad_norm = clip_grad_norm + self.use_cosine_schedule = use_cosine_schedule + self.sync_every_step = sync_every_step + + # Metrics and tracking + self.model_losses: Dict[str, List[float]] = defaultdict(list) + self.model_gradients: Dict[str, torch.Tensor] = {} + self.shared_gradient_cache: Dict[str, torch.Tensor] = {} + + # Set up gradient sharing structures + self.param_name_to_model = {} + for name, model in self.models.items(): + for param_name, _ in model.named_parameters(): + self.param_name_to_model[f"{name}.{param_name}"] = ( + name + ) + + # Configure logger + logger.configure( + handlers=[ + { + "sink": "logs/multi_model_optimizer.log", + "level": "INFO", + }, + {"sink": lambda msg: print(msg), "level": "INFO"}, + ] + ) + + logger.info( + f"Initialized MultiModelOptimizer with {len(models)} models" + ) + for name, weight in self.model_weights.items(): + logger.info(f"Model {name} weight: {weight:.4f}") + + def get_lr_multiplier(self) -> float: + """Calculate the 
learning rate multiplier based on warmup and schedule.""" + if self.step_count < self.warmup_steps: + # Linear warmup + return float(self.step_count) / float( + max(1, self.warmup_steps) + ) + + if self.use_cosine_schedule: + # Cosine decay after warmup + decay_steps = max(1, self.step_count - self.warmup_steps) + cosine_decay = 0.5 * ( + 1 + + math.cos( + math.pi + * decay_steps + / (10000 * self.gradient_accumulation_steps) + ) + ) + return max( + 0.1, cosine_decay + ) # Don't let LR go below 10% of base value + + return 1.0 # Constant LR after warmup if not using cosine + + def share_gradients(self): + """Share gradient information across models for similar parameters.""" + # First, collect all gradients by parameter type and shape + param_type_shape_grads = defaultdict(list) + + for model_name, model in self.models.items(): + for param_name, param in model.named_parameters(): + if param.grad is not None: + # Classify parameter by name pattern and include shape to ensure compatibility + param_type = self._classify_parameter(param_name) + param_shape = param.shape + key = (param_type, param_shape) + param_type_shape_grads[key].append( + (model_name, param_name, param.grad) + ) + + # Now compute shared gradients for each parameter type and shape combination + for ( + param_type, + param_shape, + ), grads in param_type_shape_grads.items(): + if len(grads) <= 1: + continue # Skip if only one model has this parameter type+shape + + cache_key = f"{param_type}_{param_shape}" + + # Compute weighted average gradient for this parameter type+shape + for model_name, param_name, grad in grads: + weight = self.model_weights[model_name] + + # Initialize shared gradient for this parameter if not exists + if cache_key not in self.shared_gradient_cache: + self.shared_gradient_cache[cache_key] = ( + torch.zeros_like(grad) + ) + + # Add weighted contribution + self.shared_gradient_cache[cache_key].add_( + grad * weight + ) + + # Now apply a fraction of the shared gradient back to each model's parameter + for model_name, param_name, _ in grads: + param = self.models[model_name].get_parameter( + param_name + ) + if param.grad is not None: + # Mix original gradient with shared gradient + sharing_ratio = 0.2 # 20% shared, 80% original + param.grad.mul_(1 - sharing_ratio).add_( + self.shared_gradient_cache[cache_key] + * sharing_ratio + ) + + # Clear the cache for next iteration + self.shared_gradient_cache.clear() + + def _classify_parameter(self, param_name: str) -> str: + """Classify parameter by name to determine which parameters should share gradients.""" + # First, make sure we include the model architecture in the classification + # to prevent mixing parameters from different architectures + model_type = "unknown" + if "bert" in param_name: + model_type = "bert" + elif "gpt" in param_name: + model_type = "gpt" + elif "roberta" in param_name: + model_type = "roberta" + elif "transformer" in param_name: + model_type = "transformer" + + # Then classify by parameter type + param_type = "other" + if ( + "query" in param_name + or "key" in param_name + or "value" in param_name + ): + param_type = "attention" + elif ( + "dense" in param_name + or "fc" in param_name + or "ffn" in param_name + ): + param_type = "ffn" + elif "embedding" in param_name: + param_type = "embedding" + elif "norm" in param_name or "layer_norm" in param_name: + param_type = "norm" + elif "bias" in param_name: + param_type = "bias" + else: + param_type = param_name.split(".")[ + -1 + ] # Use the last component of the name + + # Combine 
model type and parameter type for more specific classification + return f"{model_type}_{param_type}" + + def step( + self, closure: Optional[Callable[[], float]] = None + ) -> Optional[float]: + """Perform a single optimization step, handling gradient accumulation and sync.""" + loss = None + if closure is not None: + loss = closure() + + self.current_accumulation_step += 1 + + # Only perform the update after accumulating enough gradients + if ( + self.current_accumulation_step + < self.gradient_accumulation_steps + ): + return loss + + self.current_accumulation_step = 0 + self.step_count += 1 + + # Apply gradient clipping if configured + if self.clip_grad_norm is not None: + for model_name, model in self.models.items(): + torch.nn.utils.clip_grad_norm_( + model.parameters(), self.clip_grad_norm + ) + + # Share gradients between models if it's time + if self.step_count % self.grad_sync_frequency == 0: + self.share_gradients() + + # Calculate the current learning rate multiplier + lr_multiplier = self.get_lr_multiplier() + + # Apply optimizer update for each parameter group + for group in self.param_groups: + # Get model-specific learning rate adjustment + model_name = group["model_name"] + model_weight = self.model_weights[model_name] + + # Adjust lr based on model weight and global multiplier + model_lr_multiplier = lr_multiplier * ( + 0.5 + 0.5 * model_weight + ) # Scale between 50-150% based on weight + + # Extract parameters for this group + p = group["params"][0] + if p.grad is None: + continue + + # State initialization + state = self.state[p] + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group["amsgrad"]: + state["max_exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + # Extract optimizer parameters + beta1, beta2 = group["betas"] + exp_avg, exp_avg_sq = ( + state["exp_avg"], + state["exp_avg_sq"], + ) + + # Update step count + state["step"] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + p.grad, p.grad, value=1 - beta2 + ) + + # Apply AMSGrad if enabled + if group["amsgrad"]: + max_exp_avg_sq = state["max_exp_avg_sq"] + torch.maximum( + max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq + ) + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + # Bias correction + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = ( + group["lr"] + * model_lr_multiplier + * math.sqrt(bias_correction2) + / bias_correction1 + ) + + # Apply weight decay if configured + if group["weight_decay"] > 0: + p.data.add_( + p.data, + alpha=-group["weight_decay"] + * group["lr"] + * model_lr_multiplier, + ) + + # Update parameter + p.data.addcdiv_(exp_avg, denom, value=-step_size) + + # Synchronize parameters if configured to do so every step + if self.sync_every_step: + self.synchronize_similar_parameters() + + return loss + + def synchronize_similar_parameters(self): + """Synchronize similar parameters across models to promote convergence.""" + # Only sync occasionally + if self.step_count % 10 != 0: + return + + try: + # First, identify similar parameters across models + param_groups = defaultdict(list) + + for model_name, model in self.models.items(): + for param_name, param in 
model.named_parameters(): + # Only sync parameters of the same shape + param_type = self._classify_parameter(param_name) + param_shape = param.shape + param_groups[(param_type, param_shape)].append( + (model_name, param_name, param) + ) + + # For each group of similar parameters, synchronize values + for ( + param_type, + param_shape, + ), params in param_groups.items(): + if len(params) <= 1: + continue # Skip if only one parameter in this group + + # Calculate weighted average + avg_param = None + total_weight = 0.0 + + for model_name, _, param in params: + weight = self.model_weights[model_name] + total_weight += weight + + if avg_param is None: + avg_param = param.data.clone() * weight + else: + avg_param.add_(param.data * weight) + + if total_weight > 0: + avg_param.div_(total_weight) + + # Mix original parameters with the average (soft sync) + sync_ratio = 0.1 # 10% shared, 90% original + for _, _, param in params: + param.data.mul_(1 - sync_ratio).add_( + avg_param * sync_ratio + ) + except Exception as e: + logger.error( + f"Error during parameter synchronization: {e}" + ) + logger.error("Skipping synchronization for this step") + + def log_metrics(self, model_losses: Dict[str, float]): + """Log training metrics and update loss tracking.""" + for model_name, loss in model_losses.items(): + self.model_losses[model_name].append(loss) + + # Log metrics every 100 steps + if self.step_count % 100 == 0: + avg_losses = { + name: np.mean(losses[-100:]) + for name, losses in self.model_losses.items() + if losses + } + current_lr = ( + self.param_groups[0]["lr"] * self.get_lr_multiplier() + ) + + logger.info(f"Step {self.step_count}") + logger.info(f"Current learning rate: {current_lr:.6f}") + for model_name, avg_loss in avg_losses.items(): + logger.info( + f"Model {model_name} - Avg loss: {avg_loss:.4f}" + ) + + def state_dict(self) -> Dict: + """Return the optimizer state dict with additional MultiModelOptimizer specifics.""" + state_dict = super(MultiModelOptimizer, self).state_dict() + state_dict["model_weights"] = self.model_weights + state_dict["step_count"] = self.step_count + state_dict["current_accumulation_step"] = ( + self.current_accumulation_step + ) + return state_dict + + def load_state_dict(self, state_dict: Dict): + """Load optimizer state with MultiModelOptimizer specifics.""" + self.model_weights = state_dict.pop("model_weights") + self.step_count = state_dict.pop("step_count") + self.current_accumulation_step = state_dict.pop( + "current_accumulation_step" + ) + super(MultiModelOptimizer, self).load_state_dict(state_dict) + + +class MultiModelScheduler(_LRScheduler): + """ + A learning rate scheduler designed to work with MultiModelOptimizer, + supporting different schedules for different models based on their convergence rates. + + Args: + optimizer (MultiModelOptimizer): The optimizer to schedule + total_steps (int): Total number of training steps + warmup_steps (int, optional): Number of warmup steps. Default: 1000 + min_lr_ratio (float, optional): Minimum learning rate as a fraction of max. Default: 0.1 + model_schedule_weights (Dict[str, float], optional): Per-model schedule weights. Default: None + last_epoch (int, optional): The index of the last epoch. 
Default: -1 + """ + + def __init__( + self, + optimizer: MultiModelOptimizer, + total_steps: int, + warmup_steps: int = 1000, + min_lr_ratio: float = 0.1, + model_schedule_weights: Optional[Dict[str, float]] = None, + last_epoch: int = -1, + ): + + self.total_steps = total_steps + self.warmup_steps = warmup_steps + self.min_lr_ratio = min_lr_ratio + + # Use optimizer's model weights if not provided + if model_schedule_weights is None: + self.model_schedule_weights = optimizer.model_weights + else: + self.model_schedule_weights = model_schedule_weights + + self.model_convergence_rates: Dict[str, float] = { + name: 1.0 for name in self.model_schedule_weights.keys() + } + super(MultiModelScheduler, self).__init__( + optimizer, last_epoch + ) + + def get_lr(self): + """Calculate learning rates for all parameter groups.""" + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`." + ) + + # Apply warmup + if self.last_epoch < self.warmup_steps: + lr_scale = float(self.last_epoch) / float( + max(1, self.warmup_steps) + ) + else: + # Cosine decay after warmup + progress = float( + self.last_epoch - self.warmup_steps + ) / float(max(1, self.total_steps - self.warmup_steps)) + lr_scale = max( + self.min_lr_ratio, + 0.5 * (1.0 + math.cos(math.pi * progress)), + ) + + # Apply model-specific adjustments based on convergence rates + lrs = [] + for group in self.optimizer.param_groups: + model_name = group["model_name"] + # Adjust learning rate based on model convergence rate + model_lr = group["initial_lr"] * lr_scale + + # Apply model-specific adjustment + if model_name in self.model_convergence_rates: + # Models with higher convergence rates get lower learning rates + conv_rate = self.model_convergence_rates[model_name] + model_lr *= max(0.5, min(1.5, 1.0 / conv_rate)) + + lrs.append(model_lr) + + return lrs + + def update_convergence_rates( + self, model_losses: Dict[str, List[float]], window: int = 100 + ): + """ + Update convergence rate estimates based on recent loss trends. 
+ + Args: + model_losses: Dictionary mapping model names to their loss histories + window: Number of steps to consider for convergence estimation + """ + for model_name, losses in model_losses.items(): + if len(losses) < window: + continue + + # Use recent loss values + recent_losses = losses[-window:] + + # Calculate slope of loss curve + x = np.arange(len(recent_losses)) + y = np.array(recent_losses) + + # Simple linear regression to estimate convergence rate + slope, _ = np.polyfit(x, y, 1) + + # Normalize slope to a convergence rate + # Negative slope is good (loss is decreasing) + norm_rate = 1.0 / (1.0 + abs(slope)) + + # Update with exponential moving average + old_rate = self.model_convergence_rates.get( + model_name, 1.0 + ) + self.model_convergence_rates[model_name] = ( + 0.9 * old_rate + 0.1 * norm_rate + ) + + # Log updated convergence rates + logger.info("Updated model convergence rates:") + for model_name, rate in self.model_convergence_rates.items(): + logger.info(f" {model_name}: {rate:.4f}") + + +# Usage example with real dataset +def example_usage_with_real_data(): + """Example demonstrating how to use MultiModelOptimizer with real data from GLUE.""" + try: + # Import required libraries + from transformers import ( + BertForSequenceClassification, + GPT2ForSequenceClassification, + RobertaForSequenceClassification, + BertTokenizer, + GPT2Tokenizer, + RobertaTokenizer, + DataCollatorWithPadding, + ) + from datasets import load_dataset + from torch.utils.data import DataLoader + + # Set up logging + logger.info( + "=== Starting MultiModelOptimizer example with real GLUE data ===" + ) + + # Load SST-2 dataset from GLUE (small sentiment classification dataset) + logger.info("Loading SST-2 dataset from GLUE...") + sst2_dataset = load_dataset("glue", "sst2") + train_dataset = sst2_dataset["train"].select( + range(1000) + ) # Use only 1000 examples for quick training + + # Load tokenizers + logger.info("Loading tokenizers...") + bert_tokenizer = BertTokenizer.from_pretrained( + "bert-base-uncased" + ) + gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + roberta_tokenizer = RobertaTokenizer.from_pretrained( + "roberta-base" + ) + + # Add padding token to GPT2 tokenizer (it doesn't have one by default) + if gpt2_tokenizer.pad_token is None: + gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token + + # Tokenization functions + def tokenize_bert(examples): + return bert_tokenizer( + examples["sentence"], truncation=True, max_length=128 + ) + + def tokenize_gpt2(examples): + return gpt2_tokenizer( + examples["sentence"], truncation=True, max_length=128 + ) + + def tokenize_roberta(examples): + return roberta_tokenizer( + examples["sentence"], truncation=True, max_length=128 + ) + + # Tokenize datasets for each model + logger.info("Tokenizing datasets...") + bert_dataset = train_dataset.map(tokenize_bert, batched=True) + gpt2_dataset = train_dataset.map(tokenize_gpt2, batched=True) + roberta_dataset = train_dataset.map( + tokenize_roberta, batched=True + ) + + # Set format for PyTorch + bert_dataset.set_format( + type="torch", + columns=["input_ids", "attention_mask", "label"], + ) + gpt2_dataset.set_format( + type="torch", + columns=["input_ids", "attention_mask", "label"], + ) + roberta_dataset.set_format( + type="torch", + columns=["input_ids", "attention_mask", "label"], + ) + + # Create data collators + bert_data_collator = DataCollatorWithPadding( + tokenizer=bert_tokenizer + ) + gpt2_data_collator = DataCollatorWithPadding( + tokenizer=gpt2_tokenizer + ) + 
roberta_data_collator = DataCollatorWithPadding( + tokenizer=roberta_tokenizer + ) + + # Create dataloaders + logger.info("Creating dataloaders...") + batch_size = 16 + bert_dataloader = DataLoader( + bert_dataset, + batch_size=batch_size, + collate_fn=bert_data_collator, + ) + gpt2_dataloader = DataLoader( + gpt2_dataset, + batch_size=batch_size, + collate_fn=gpt2_data_collator, + ) + roberta_dataloader = DataLoader( + roberta_dataset, + batch_size=batch_size, + collate_fn=roberta_data_collator, + ) + + # Load models for sequence classification + logger.info( + "Loading transformer models for sequence classification..." + ) + models = { + "bert": BertForSequenceClassification.from_pretrained( + "bert-base-uncased", num_labels=2 + ), + "gpt2": GPT2ForSequenceClassification.from_pretrained( + "gpt2", num_labels=2 + ), + "roberta": RobertaForSequenceClassification.from_pretrained( + "roberta-base", num_labels=2 + ), + } + + # Set up optimizer with different weights for each model + logger.info("Setting up MultiModelOptimizer...") + optimizer = MultiModelOptimizer( + models=models, + lr=3e-5, + betas=(0.9, 0.999), + weight_decay=0.01, + model_weights={"bert": 1.0, "gpt2": 0.7, "roberta": 1.3}, + gradient_accumulation_steps=2, + clip_grad_norm=1.0, + warmup_steps=100, + grad_sync_frequency=50, + ) + + # Set up scheduler + scheduler = MultiModelScheduler( + optimizer=optimizer, total_steps=5000, warmup_steps=100 + ) + + # Move models to GPU if available + device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + logger.info(f"Using device: {device}") + + for model_name, model in models.items(): + models[model_name] = model.to(device) + + # Create iterator function for each dataloader + def infinite_iterator(dataloader): + while True: + for batch in dataloader: + yield batch + + bert_iter = infinite_iterator(bert_dataloader) + gpt2_iter = infinite_iterator(gpt2_dataloader) + roberta_iter = infinite_iterator(roberta_dataloader) + + # Define metrics for tracking + from sklearn.metrics import accuracy_score + + total_steps = 1000 # Total training steps + eval_every = 100 # Evaluate every 100 steps + best_accuracy = {"bert": 0.0, "gpt2": 0.0, "roberta": 0.0} + + logger.info(f"Starting training for {total_steps} steps...") + + # Training loop + for step in range(total_steps): + # Zero gradients + optimizer.zero_grad() + + losses = {} + + try: + # For BERT + bert_batch = next(bert_iter) + bert_batch = { + k: v.to(device) for k, v in bert_batch.items() + } + bert_outputs = models["bert"](**bert_batch) + bert_loss = bert_outputs.loss + bert_loss.backward() + losses["bert"] = bert_loss.item() + + # For GPT2 + gpt2_batch = next(gpt2_iter) + gpt2_batch = { + k: v.to(device) for k, v in gpt2_batch.items() + } + gpt2_outputs = models["gpt2"](**gpt2_batch) + gpt2_loss = gpt2_outputs.loss + gpt2_loss.backward() + losses["gpt2"] = gpt2_loss.item() + + # For RoBERTa + roberta_batch = next(roberta_iter) + roberta_batch = { + k: v.to(device) for k, v in roberta_batch.items() + } + roberta_outputs = models["roberta"](**roberta_batch) + roberta_loss = roberta_outputs.loss + roberta_loss.backward() + losses["roberta"] = roberta_loss.item() + + # Log metrics + optimizer.log_metrics(losses) + + # Step the optimizer and scheduler + optimizer.step() + scheduler.step() + + # Update convergence rates periodically + if step % 100 == 0: + scheduler.update_convergence_rates( + optimizer.model_losses + ) + + # Evaluate periodically + if step > 0 and step % eval_every == 0: + logger.info(f"Evaluating at step 
{step}...") + + # Create a small evaluation set + eval_dataset = sst2_dataset["validation"].select( + range(100) + ) + + for model_name, model in models.items(): + model.eval() + + # Tokenize evaluation data based on model type + if model_name == "bert": + tokenizer = bert_tokenizer + tokenize_fn = tokenize_bert + elif model_name == "gpt2": + tokenizer = gpt2_tokenizer + tokenize_fn = tokenize_gpt2 + else: # roberta + tokenizer = roberta_tokenizer + tokenize_fn = tokenize_roberta + + eval_tokenized = eval_dataset.map( + tokenize_fn, batched=True + ) + eval_tokenized.set_format( + type="torch", + columns=[ + "input_ids", + "attention_mask", + "label", + ], + ) + + # Create dataloader + eval_collator = DataCollatorWithPadding( + tokenizer=tokenizer + ) + eval_dataloader = DataLoader( + eval_tokenized, + batch_size=16, + collate_fn=eval_collator, + ) + + # Evaluate + all_preds = [] + all_labels = [] + + with torch.no_grad(): + for batch in eval_dataloader: + batch = { + k: v.to(device) + for k, v in batch.items() + } + outputs = model(**batch) + logits = outputs.logits + preds = ( + torch.argmax(logits, dim=-1) + .cpu() + .numpy() + ) + labels = batch["label"].cpu().numpy() + + all_preds.extend(preds) + all_labels.extend(labels) + + # Calculate accuracy + accuracy = accuracy_score( + all_labels, all_preds + ) + logger.info( + f"Model {model_name} - Accuracy: {accuracy:.4f}" + ) + + # Save best model + if accuracy > best_accuracy[model_name]: + best_accuracy[model_name] = accuracy + torch.save( + model.state_dict(), + f"best_{model_name}_model.pt", + ) + logger.info( + f"Saved new best {model_name} model with accuracy {accuracy:.4f}" + ) + + model.train() + + except RuntimeError as e: + logger.error( + f"Error during training step {step}: {e}" + ) + logger.error("Skipping this step and continuing...") + optimizer.zero_grad() + continue + + # Save checkpoint every 500 steps + if step > 0 and step % 500 == 0: + logger.info(f"Saving checkpoint at step {step}...") + torch.save( + { + "step": step, + "model_states": { + name: model.state_dict() + for name, model in models.items() + }, + "optimizer_state": optimizer.state_dict(), + "scheduler_state": scheduler.state_dict(), + "best_accuracy": best_accuracy, + }, + f"checkpoint_step_{step}.pt", + ) + + # Final evaluation and results + logger.info("=== Training complete! Final results ===") + for model_name, acc in best_accuracy.items(): + logger.info(f"Best {model_name} accuracy: {acc:.4f}") + + except Exception as e: + logger.error( + f"Fatal error in example_usage_with_real_data: {e}" + ) + import traceback + + logger.error(traceback.format_exc()) + + +if __name__ == "__main__": + # Use real data example by default + example_usage_with_real_data()
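
The GLUE example above downloads datasets and pretrained transformers; for a quick smoke test of the optimizer and scheduler, a minimal sketch with two toy models is enough. This assumes the file is importable as `swarms_op`; the toy model definitions, synthetic batches, and hyperparameters are illustrative only:

```python
import torch
import torch.nn as nn

from swarms_op import MultiModelOptimizer, MultiModelScheduler

# Two small stand-ins for the transformer models used in the real-data example
models = {
    "model_a": nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2)),
    "model_b": nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2)),
}

optimizer = MultiModelOptimizer(
    models=models,
    lr=1e-3,
    model_weights={"model_a": 1.0, "model_b": 0.5},
    gradient_accumulation_steps=1,
    warmup_steps=10,
)
scheduler = MultiModelScheduler(optimizer, total_steps=200, warmup_steps=10)
loss_fn = nn.CrossEntropyLoss()

for step in range(200):
    optimizer.zero_grad()
    losses = {}
    for name, model in models.items():
        x = torch.randn(8, 16)           # synthetic input batch
        y = torch.randint(0, 2, (8,))    # synthetic binary labels
        loss = loss_fn(model(x), y)
        loss.backward()
        losses[name] = loss.item()
    optimizer.log_metrics(losses)  # feeds the per-model loss tracking
    optimizer.step()               # accumulation, gradient sharing, and Adam-style update
    scheduler.step()               # warmup + cosine decay with per-model adjustment
```

Because `grad_sync_frequency` defaults to 1, gradients for parameters of matching type and shape are blended across the two models on every update; raising that value makes the models train more independently.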