From 9abe300548a4cc39d53f8b08388d3df60235c2cb Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Mon, 16 Dec 2024 22:04:10 -0800 Subject: [PATCH] [FEAT][Agent][save][load] [FIX][openai assistants] --- auto_test_eval.py | 348 ++++++++ docs/mkdocs.yml | 36 +- example.py | 19 +- github_summarizer_agent.py | 189 ++++ new_features_examples/async_agent.py | 44 + .../openai_assistant_wrapper.py | 14 + pyproject.toml | 2 +- swarms/agents/openai_assistant.py | 186 ++-- swarms/structs/__init__.py | 16 - swarms/structs/agent.py | 823 +++++++++++------- swarms/structs/agent_memory_manager.py | 419 +++++++++ swarms/structs/async_workflow.py | 38 +- swarms/structs/safe_loading.py | 258 ++++++ swarms/tools/base_tool.py | 12 +- swarms/utils/litellm_wrapper.py | 5 +- swarms/utils/loguru_logger.py | 2 +- test_agent_features.py | 598 +++++++++++++ 17 files changed, 2544 insertions(+), 465 deletions(-) create mode 100644 auto_test_eval.py create mode 100644 github_summarizer_agent.py create mode 100644 new_features_examples/async_agent.py create mode 100644 new_features_examples/openai_assistant_wrapper.py create mode 100644 swarms/structs/agent_memory_manager.py create mode 100644 swarms/structs/safe_loading.py create mode 100644 test_agent_features.py diff --git a/auto_test_eval.py b/auto_test_eval.py new file mode 100644 index 00000000..fd282013 --- /dev/null +++ b/auto_test_eval.py @@ -0,0 +1,348 @@ +import json +import os +import platform +import sys +import traceback +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +import psutil +import requests +from loguru import logger +from swarm_models import OpenAIChat + +from swarms.structs.agent import Agent + + +@dataclass +class SwarmSystemInfo: + """System information for Swarms issue reports.""" + + os_name: str + os_version: str + python_version: str + cpu_usage: float + memory_usage: float + disk_usage: float + swarms_version: str # Added Swarms version tracking + cuda_available: bool # Added CUDA availability check + gpu_info: Optional[str] # Added GPU information + + +class SwarmsIssueReporter: + """ + Production-grade GitHub issue reporter specifically designed for the Swarms library. + Automatically creates detailed issues for the https://github.com/kyegomez/swarms repository. + + Features: + - Swarms-specific error categorization + - Automatic version and dependency tracking + - CUDA and GPU information collection + - Integration with Swarms logging system + - Detailed environment information + """ + + REPO_OWNER = "kyegomez" + REPO_NAME = "swarms" + ISSUE_CATEGORIES = { + "agent": ["agent", "automation"], + "memory": ["memory", "storage"], + "tool": ["tools", "integration"], + "llm": ["llm", "model"], + "performance": ["performance", "optimization"], + "compatibility": ["compatibility", "environment"], + } + + def __init__( + self, + github_token: str, + rate_limit: int = 10, + rate_period: int = 3600, + log_file: str = "swarms_issues.log", + enable_duplicate_check: bool = True, + ): + """ + Initialize the Swarms Issue Reporter. 
+
+        Args:
+            github_token (str): GitHub personal access token
+            rate_limit (int): Maximum number of issues to create per rate_period
+            rate_period (int): Time period for rate limiting in seconds
+            log_file (str): Path to log file
+            enable_duplicate_check (bool): Whether to check for duplicate issues
+        """
+        self.github_token = github_token or os.getenv(
+            "GITHUB_API_KEY"
+        )
+        self.rate_limit = rate_limit
+        self.rate_period = rate_period
+        self.enable_duplicate_check = enable_duplicate_check
+
+        # Initialize logging
+        log_path = os.path.join(os.getcwd(), "logs", log_file)
+        os.makedirs(os.path.dirname(log_path), exist_ok=True)
+        logger.add(
+            log_path,
+            rotation="1 day",
+            retention="1 month",
+            compression="zip",
+        )
+
+        # Issue tracking
+        self.issues_created = []
+        self.last_issue_time = datetime.now()
+
+        # Validate GitHub token
+        # self._validate_github_credentials()
+
+    def _get_swarms_version(self) -> str:
+        """Get the installed version of Swarms."""
+        try:
+            import swarms
+
+            return swarms.__version__
+        except Exception:
+            return "Unknown"
+
+    def _get_gpu_info(self) -> Tuple[bool, Optional[str]]:
+        """Get GPU information and CUDA availability."""
+        try:
+            import torch
+
+            cuda_available = torch.cuda.is_available()
+            if cuda_available:
+                gpu_info = torch.cuda.get_device_name(0)
+                return cuda_available, gpu_info
+            return False, None
+        except Exception:
+            return False, None
+
+    def _get_system_info(self) -> SwarmSystemInfo:
+        """Collect system and Swarms-specific information."""
+        cuda_available, gpu_info = self._get_gpu_info()
+
+        return SwarmSystemInfo(
+            os_name=platform.system(),
+            os_version=platform.version(),
+            python_version=sys.version,
+            cpu_usage=psutil.cpu_percent(),
+            memory_usage=psutil.virtual_memory().percent,
+            disk_usage=psutil.disk_usage("/").percent,
+            swarms_version=self._get_swarms_version(),
+            cuda_available=cuda_available,
+            gpu_info=gpu_info,
+        )
+
+    def _categorize_error(
+        self, error: Exception, context: Dict
+    ) -> List[str]:
+        """Categorize the error and return appropriate labels."""
+        error_str = f"{error} {type(error).__name__}".lower()
+
+        labels = ["bug", "automated"]
+
+        # Check error message and context for category keywords
+        for (
+            category,
+            category_labels,
+        ) in self.ISSUE_CATEGORIES.items():
+            if any(
+                keyword in error_str for keyword in category_labels
+            ):
+                labels.extend(category_labels)
+                break
+
+        # Add severity label based on error type
+        if issubclass(type(error), (SystemError, MemoryError)):
+            labels.append("severity:critical")
+        elif issubclass(type(error), (ValueError, TypeError)):
+            labels.append("severity:medium")
+        else:
+            labels.append("severity:low")
+
+        return list(set(labels))  # Remove duplicates
+
+    def _format_swarms_issue_body(
+        self,
+        error: Exception,
+        system_info: SwarmSystemInfo,
+        context: Dict,
+    ) -> str:
+        """Format the issue body with Swarms-specific information."""
+        return f"""
+        ## Swarms Error Report
+        - **Error Type**: {type(error).__name__}
+        - **Error Message**: {str(error)}
+        - **Swarms Version**: {system_info.swarms_version}
+
+        ## Environment Information
+        - **OS**: {system_info.os_name} {system_info.os_version}
+        - **Python Version**: {system_info.python_version}
+        - **CUDA Available**: {system_info.cuda_available}
+        - **GPU**: {system_info.gpu_info or "N/A"}
+        - **CPU Usage**: {system_info.cpu_usage}%
+        - **Memory Usage**: {system_info.memory_usage}%
+        - **Disk Usage**: {system_info.disk_usage}%
+
+        ## Stack Trace
+        {traceback.format_exc()}
+
+        ## Context
+        {json.dumps(context, indent=2)}
+
+        ## 
Dependencies
+        {self._get_dependencies_info()}
+
+        ## Time of Occurrence
+        {datetime.now().isoformat()}
+
+        ---
+        *This issue was automatically generated by SwarmsIssueReporter*
+        """
+
+    def _get_dependencies_info(self) -> str:
+        """Get information about installed dependencies."""
+        try:
+            import pkg_resources
+
+            deps = []
+            for dist in pkg_resources.working_set:
+                deps.append(f"- {dist.key} {dist.version}")
+            return "\n".join(deps)
+        except Exception:
+            return "Unable to fetch dependency information"
+
+    def _check_rate_limit(self) -> bool:
+        """Check if we're within rate limits."""
+        now = datetime.now()
+        time_diff = (now - self.last_issue_time).total_seconds()
+
+        if (
+            len(self.issues_created) >= self.rate_limit
+            and time_diff < self.rate_period
+        ):
+            logger.warning("Rate limit exceeded for issue creation")
+            return False
+
+        # Clean up old issues from tracking
+        self.issues_created = [
+            created_at
+            for created_at in self.issues_created
+            if (now - created_at).total_seconds() < self.rate_period
+        ]
+
+        return True
+
+    def report_swarms_issue(
+        self,
+        error: Exception,
+        agent: Optional[Agent] = None,
+        context: Dict[str, Any] = None,
+        priority: str = "normal",
+    ) -> Optional[int]:
+        """
+        Report a Swarms-specific issue to GitHub.
+
+        Args:
+            error (Exception): The exception to report
+            agent (Optional[Agent]): The Swarms agent instance that encountered the error
+            context (Dict[str, Any]): Additional context about the error
+            priority (str): Issue priority ("low", "normal", "high", "critical")
+
+        Returns:
+            Optional[int]: Issue number if created successfully
+        """
+        try:
+            if not self._check_rate_limit():
+                logger.warning(
+                    "Skipping issue creation due to rate limit"
+                )
+                return None
+
+            # Collect system information
+            system_info = self._get_system_info()
+
+            # Prepare context with agent information if available
+            full_context = context or {}
+            if agent:
+                full_context.update(
+                    {
+                        "agent_name": agent.agent_name,
+                        "agent_description": agent.agent_description,
+                        "max_loops": agent.max_loops,
+                        "context_length": agent.context_length,
+                    }
+                )
+
+            # Create issue title
+            title = f"[{type(error).__name__}] {str(error)[:100]}"
+            if agent:
+                title = f"[Agent: {agent.agent_name}] {title}"
+
+            # Get appropriate labels
+            labels = self._categorize_error(error, full_context)
+            labels.append(f"priority:{priority}")
+
+            # Create the issue
+            url = f"https://api.github.com/repos/{self.REPO_OWNER}/{self.REPO_NAME}/issues"
+            data = {
+                "title": title,
+                "body": self._format_swarms_issue_body(
+                    error, system_info, full_context
+                ),
+                "labels": labels,
+            }
+
+            response = requests.post(
+                url,
+                headers={
+                    "Authorization": f"token {self.github_token}"
+                },
+                json=data,
+            )
+            response.raise_for_status()
+
+            issue_number = response.json()["number"]
+            logger.info(
+                f"Successfully created Swarms issue #{issue_number}"
+            )
+
+            # Record this issue so _check_rate_limit can track it
+            self.issues_created.append(datetime.now())
+            self.last_issue_time = datetime.now()
+
+            return issue_number
+
+        except Exception as e:
+            logger.error(f"Error creating Swarms issue: {str(e)}")
+            return None
+
+
+# from swarms import Agent
+# from swarm_models import OpenAIChat
+# from swarms.utils.issue_reporter import SwarmsIssueReporter
+# import os
+
+# Setup the reporter with your GitHub token
+reporter = SwarmsIssueReporter(
+    github_token=os.getenv("GITHUB_API_KEY")
+)
+
+
+# Force an error to test the reporter
+try:
+    # Create an agent, then run it with invalid input to raise an error
+    model = OpenAIChat(model_name="gpt-4o")
+    agent = Agent(agent_name="Test-Agent", max_loops=1)
+
+    
result = agent.run(None) + + raise ValueError("test") +except Exception as e: + # Report the issue + issue_number = reporter.report_swarms_issue( + error=e, + agent=agent, + context={"task": "test_run"}, + priority="high", + ) + print(f"Created issue number: {issue_number}") diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 0f04373f..44029f92 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -192,15 +192,6 @@ nav: - Conversation: "swarms/structs/conversation.md" # - Task: "swarms/structs/task.md" - Full API Reference: "swarms/framework/reference.md" - - Contributing: - - Contributing: "swarms/contributing.md" - - Tests: "swarms/framework/test.md" - - Code Cleanliness: "swarms/framework/code_cleanliness.md" - - Philosophy: "swarms/concept/philosophy.md" - - Changelog: - - Swarms 5.6.8: "swarms/changelog/5_6_8.md" - - Swarms 5.8.1: "swarms/changelog/5_8_1.md" - - Swarms 5.9.2: "swarms/changelog/changelog_new.md" - Swarm Models: - Overview: "swarms/models/index.md" # - Models Available: "swarms/models/index.md" @@ -254,16 +245,25 @@ nav: # - Tools API: # - Overview: "swarms_platform/tools_api.md" # - Add Tools: "swarms_platform/fetch_tools.md" - - Guides: - - Unlocking Efficiency and Cost Savings in Healthcare; How Swarms of LLM Agents Can Revolutionize Medical Operations and Save Millions: "guides/healthcare_blog.md" - - Understanding Agent Evaluation Mechanisms: "guides/agent_evals.md" - - Agent Glossary: "swarms/glossary.md" - - The Ultimate Technical Guide to the Swarms CLI; A Step-by-Step Developers Guide: "swarms/cli/cli_guide.md" - - Prompting Guide: - - The Essence of Enterprise-Grade Prompting: "swarms/prompts/essence.md" - - An Analysis on Prompting Strategies: "swarms/prompts/overview.md" - - Managing Prompts in Production: "swarms/prompts/main.md" + # - Guides: + # - Unlocking Efficiency and Cost Savings in Healthcare; How Swarms of LLM Agents Can Revolutionize Medical Operations and Save Millions: "guides/healthcare_blog.md" + # - Understanding Agent Evaluation Mechanisms: "guides/agent_evals.md" + # - Agent Glossary: "swarms/glossary.md" + # - The Ultimate Technical Guide to the Swarms CLI; A Step-by-Step Developers Guide: "swarms/cli/cli_guide.md" + # - Prompting Guide: + # - The Essence of Enterprise-Grade Prompting: "swarms/prompts/essence.md" + # - An Analysis on Prompting Strategies: "swarms/prompts/overview.md" + # - Managing Prompts in Production: "swarms/prompts/main.md" - Community: + - Contributing: + - Contributing: "swarms/contributing.md" + - Tests: "swarms/framework/test.md" + - Code Cleanliness: "swarms/framework/code_cleanliness.md" + - Philosophy: "swarms/concept/philosophy.md" + - Changelog: + - Swarms 5.6.8: "swarms/changelog/5_6_8.md" + - Swarms 5.8.1: "swarms/changelog/5_8_1.md" + - Swarms 5.9.2: "swarms/changelog/changelog_new.md" - Bounty Program: "corporate/bounty_program.md" - Corporate: - Culture: "corporate/culture.md" diff --git a/example.py b/example.py index 76c23353..362c4f59 100644 --- a/example.py +++ b/example.py @@ -2,6 +2,10 @@ from swarms import Agent from swarms.prompts.finance_agent_sys_prompt import ( FINANCIAL_AGENT_SYS_PROMPT, ) +from swarm_models import OpenAIChat + +model = OpenAIChat(model_name="gpt-4o") + # Initialize the agent agent = Agent( @@ -9,20 +13,21 @@ agent = Agent( agent_description="Personal finance advisor agent", system_prompt=FINANCIAL_AGENT_SYS_PROMPT + "Output the token when you're done creating a portfolio of etfs, index, funds, and more for AI", - model_name="gpt-4o", # Use any model from litellm - 
max_loops="auto", + max_loops=1, + llm=model, dynamic_temperature_enabled=True, user_name="Kye", retry_attempts=3, - streaming_on=True, - context_length=16000, + # streaming_on=True, + context_length=8192, return_step_meta=False, output_type="str", # "json", "dict", "csv" OR "string" "yaml" and auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task - max_tokens=16000, # max output tokens - interactive=True, + max_tokens=4000, # max output tokens + # interactive=True, stopping_token="", - execute_tool=True, + saved_state_path="agent_00.json", + interactive=False, ) agent.run( diff --git a/github_summarizer_agent.py b/github_summarizer_agent.py new file mode 100644 index 00000000..c461c307 --- /dev/null +++ b/github_summarizer_agent.py @@ -0,0 +1,189 @@ +import requests +import datetime +from typing import List, Dict, Tuple +from loguru import logger +from swarms import Agent +from swarm_models import OpenAIChat + +# GitHub API Configurations +GITHUB_REPO = "kyegomez/swarms" # Swarms GitHub repository +GITHUB_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/commits" + +# Initialize Loguru +logger.add( + "commit_summary.log", + rotation="1 MB", + level="INFO", + backtrace=True, + diagnose=True, +) + + +# Step 1: Fetch the latest commits from GitHub +def fetch_latest_commits( + repo_url: str, limit: int = 5 +) -> List[Dict[str, str]]: + """ + Fetch the latest commits from a public GitHub repository. + """ + logger.info( + f"Fetching the latest {limit} commits from {repo_url}" + ) + try: + params = {"per_page": limit} + response = requests.get(repo_url, params=params) + response.raise_for_status() + + commits = response.json() + commit_data = [] + + for commit in commits: + commit_data.append( + { + "sha": commit["sha"][:7], # Short commit hash + "author": commit["commit"]["author"]["name"], + "message": commit["commit"]["message"], + "date": commit["commit"]["author"]["date"], + } + ) + + logger.success("Successfully fetched commit data") + return commit_data + + except Exception as e: + logger.error(f"Error fetching commits: {e}") + raise + + +# Step 2: Format commits and fetch current time +def format_commits_with_time( + commits: List[Dict[str, str]] +) -> Tuple[str, str]: + """ + Format commit data into a readable string and return current time. + """ + current_time = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) + logger.info(f"Formatting commits at {current_time}") + + commit_summary = "\n".join( + [ + f"- `{commit['sha']}` by {commit['author']} on {commit['date']}: {commit['message']}" + for commit in commits + ] + ) + + logger.success("Commits formatted successfully") + return current_time, commit_summary + + +# Step 3: Build a dynamic system prompt +def build_custom_system_prompt( + current_time: str, commit_summary: str +) -> str: + """ + Build a dynamic system prompt with the current time and commit summary. + """ + logger.info("Building the custom system prompt for the agent") + prompt = f""" +You are a software analyst tasked with summarizing the latest commits from the Swarms GitHub repository. + +The current time is **{current_time}**. + +Here are the latest commits: +{commit_summary} + +**Your task**: +1. Summarize the changes into a clear and concise table in **markdown format**. +2. Highlight the key improvements and fixes. +3. End your output with the token ``. + +Make sure the table includes the following columns: Commit SHA, Author, Date, and Commit Message. 
+""" + logger.success("System prompt created successfully") + return prompt + + +# Step 4: Initialize the Agent +def initialize_agent() -> Agent: + """ + Initialize the Swarms agent with OpenAI model. + """ + logger.info("Initializing the agent with GPT-4o") + model = OpenAIChat(model_name="gpt-4o") + + agent = Agent( + agent_name="Commit-Summarization-Agent", + agent_description="Fetch and summarize GitHub commits for Swarms repository.", + system_prompt="", # Will set dynamically + max_loops=1, + llm=model, + dynamic_temperature_enabled=True, + user_name="Kye", + retry_attempts=3, + context_length=8192, + return_step_meta=False, + output_type="str", + auto_generate_prompt=False, + max_tokens=4000, + stopping_token="", + interactive=False, + ) + logger.success("Agent initialized successfully") + return agent + + +# Step 5: Run the Agent with Data +def summarize_commits_with_agent(agent: Agent, prompt: str) -> str: + """ + Pass the system prompt to the agent and fetch the result. + """ + logger.info("Sending data to the agent for summarization") + try: + result = agent.run( + f"{prompt}", + all_cores=True, + ) + logger.success("Agent completed the summarization task") + return result + except Exception as e: + logger.error(f"Agent encountered an error: {e}") + raise + + +# Main Execution +if __name__ == "__main__": + try: + logger.info("Starting commit summarization process") + + # Fetch latest commits + latest_commits = fetch_latest_commits(GITHUB_API_URL, limit=5) + + # Format commits and get current time + current_time, commit_summary = format_commits_with_time( + latest_commits + ) + + # Build the custom system prompt + custom_system_prompt = build_custom_system_prompt( + current_time, commit_summary + ) + + # Initialize agent + agent = initialize_agent() + + # Set the dynamic system prompt + agent.system_prompt = custom_system_prompt + + # Run the agent and summarize commits + result = summarize_commits_with_agent( + agent, custom_system_prompt + ) + + # Print the result + print("### Commit Summary in Markdown:") + print(result) + + except Exception as e: + logger.critical(f"Process failed: {e}") diff --git a/new_features_examples/async_agent.py b/new_features_examples/async_agent.py new file mode 100644 index 00000000..5c23a8b8 --- /dev/null +++ b/new_features_examples/async_agent.py @@ -0,0 +1,44 @@ +from swarms import Agent +from swarms.prompts.finance_agent_sys_prompt import ( + FINANCIAL_AGENT_SYS_PROMPT, +) +from swarm_models import OpenAIChat + +model = OpenAIChat(model_name="gpt-4o") + + +# Initialize the agent +agent = Agent( + agent_name="Financial-Analysis-Agent", + agent_description="Personal finance advisor agent", + system_prompt=FINANCIAL_AGENT_SYS_PROMPT + + "Output the token when you're done creating a portfolio of etfs, index, funds, and more for AI", + max_loops=1, + llm=model, + dynamic_temperature_enabled=True, + user_name="Kye", + retry_attempts=3, + # streaming_on=True, + context_length=8192, + return_step_meta=False, + output_type="str", # "json", "dict", "csv" OR "string" "yaml" and + auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task + max_tokens=4000, # max output tokens + # interactive=True, + stopping_token="", + saved_state_path="agent_00.json", + interactive=False, +) + + +async def run_agent(): + await agent.arun( + "Create a table of super high growth opportunities for AI. I have $40k to invest in ETFs, index funds, and more. 
Please create a table in markdown.", + all_cores=True, + ) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(run_agent()) diff --git a/new_features_examples/openai_assistant_wrapper.py b/new_features_examples/openai_assistant_wrapper.py new file mode 100644 index 00000000..2944ec11 --- /dev/null +++ b/new_features_examples/openai_assistant_wrapper.py @@ -0,0 +1,14 @@ +from swarms.prompts.finance_agent_sys_prompt import ( + FINANCIAL_AGENT_SYS_PROMPT, +) +from swarms.agents.openai_assistant import OpenAIAssistant + +agent = OpenAIAssistant( + name="test", instructions=FINANCIAL_AGENT_SYS_PROMPT +) + +print( + agent.run( + "Create a table of super high growth opportunities for AI. I have $40k to invest in ETFs, index funds, and more. Please create a table in markdown.", + ) +) diff --git a/pyproject.toml b/pyproject.toml index 90a05195..541bcbac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "6.6.9" +version = "6.7.0" description = "Swarms - TGSC" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/agents/openai_assistant.py b/swarms/agents/openai_assistant.py index acedf362..2a29e1bf 100644 --- a/swarms/agents/openai_assistant.py +++ b/swarms/agents/openai_assistant.py @@ -1,14 +1,47 @@ -from typing import Optional, List, Dict, Any, Callable +import json +import os +import subprocess +import sys import time -from openai import OpenAI +from typing import Any, Callable, Dict, List, Optional + +from loguru import logger + from swarms.structs.agent import Agent -import json + + +def check_openai_package(): + """Check if the OpenAI package is installed, and install it if not.""" + try: + import openai + + return openai + except ImportError: + logger.info( + "OpenAI package not found. Attempting to install..." + ) + + # Attempt to install the OpenAI package + try: + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "openai"] + ) + logger.info("OpenAI package installed successfully.") + import openai # Re-import the package after installation + + return openai + except subprocess.CalledProcessError as e: + logger.error(f"Failed to install OpenAI package: {e}") + raise RuntimeError( + "OpenAI package installation failed." + ) from e + class OpenAIAssistant(Agent): """ OpenAI Assistant wrapper for the swarms framework. Integrates OpenAI's Assistants API with the swarms architecture. - + Example: >>> assistant = OpenAIAssistant( ... name="Math Tutor", @@ -22,6 +55,7 @@ class OpenAIAssistant(Agent): def __init__( self, name: str, + description: str = "Standard openai assistant wrapper", instructions: Optional[str] = None, model: str = "gpt-4o", tools: Optional[List[Dict[str, Any]]] = None, @@ -29,47 +63,63 @@ class OpenAIAssistant(Agent): metadata: Optional[Dict[str, Any]] = None, functions: Optional[List[Dict[str, Any]]] = None, *args, - **kwargs + **kwargs, ): """Initialize the OpenAI Assistant. 
Args: name: Name of the assistant instructions: System instructions for the assistant - model: Model to use (default: gpt-4-turbo-preview) + model: Model to use (default: gpt-4o) tools: List of tools to enable (code_interpreter, retrieval) file_ids: List of file IDs to attach metadata: Additional metadata functions: List of custom functions to make available """ + self.name = name + self.description = description + self.instructions = instructions + self.model = model + self.tools = tools + self.file_ids = file_ids + self.metadata = metadata + self.functions = functions + super().__init__(*args, **kwargs) - + # Initialize tools list with any provided functions self.tools = tools or [] if functions: for func in functions: - self.tools.append({ - "type": "function", - "function": func - }) + self.tools.append( + {"type": "function", "function": func} + ) # Create the OpenAI Assistant - self.client = OpenAI() + openai = check_openai_package() + self.client = openai.OpenAI( + api_key=os.getenv("OPENAI_API_KEY") + ) self.assistant = self.client.beta.assistants.create( name=name, instructions=instructions, model=model, tools=self.tools, - file_ids=file_ids or [], - metadata=metadata or {} + # file_ids=file_ids or [], + metadata=metadata or {}, ) - + # Store available functions self.available_functions: Dict[str, Callable] = {} - - def add_function(self, func: Callable, description: str, parameters: Dict[str, Any]) -> None: + + def add_function( + self, + func: Callable, + description: str, + parameters: Dict[str, Any], + ) -> None: """Add a function that the assistant can call. - + Args: func: The function to make available to the assistant description: Description of what the function does @@ -78,27 +128,23 @@ class OpenAIAssistant(Agent): func_dict = { "name": func.__name__, "description": description, - "parameters": parameters + "parameters": parameters, } - + # Add to tools list - self.tools.append({ - "type": "function", - "function": func_dict - }) - + self.tools.append({"type": "function", "function": func_dict}) + # Store function reference self.available_functions[func.__name__] = func - + # Update assistant with new tools self.assistant = self.client.beta.assistants.update( - assistant_id=self.assistant.id, - tools=self.tools + assistant_id=self.assistant.id, tools=self.tools ) def _handle_tool_calls(self, run, thread_id: str) -> None: """Handle any required tool calls during a run. - + This method processes any tool calls required by the assistant during execution. It extracts function calls, executes them with provided arguments, and submits the results back to the assistant. 
@@ -114,38 +160,46 @@ class OpenAIAssistant(Agent): Exception: If there are errors executing the tool calls """ while run.status == "requires_action": - tool_calls = run.required_action.submit_tool_outputs.tool_calls + tool_calls = ( + run.required_action.submit_tool_outputs.tool_calls + ) tool_outputs = [] - + for tool_call in tool_calls: if tool_call.type == "function": # Get function details function_name = tool_call.function.name - function_args = json.loads(tool_call.function.arguments) - + function_args = json.loads( + tool_call.function.arguments + ) + # Call function if available if function_name in self.available_functions: - function_response = self.available_functions[function_name](**function_args) - tool_outputs.append({ - "tool_call_id": tool_call.id, - "output": str(function_response) - }) - + function_response = self.available_functions[ + function_name + ](**function_args) + tool_outputs.append( + { + "tool_call_id": tool_call.id, + "output": str(function_response), + } + ) + # Submit outputs back to the run run = self.client.beta.threads.runs.submit_tool_outputs( thread_id=thread_id, run_id=run.id, - tool_outputs=tool_outputs + tool_outputs=tool_outputs, ) - + # Wait for processing run = self._wait_for_run(run) - + return run def _wait_for_run(self, run) -> Any: """Wait for a run to complete and handle any required actions. - + This method polls the OpenAI API to check the status of a run until it completes or fails. It handles intermediate states like required actions and implements exponential backoff. @@ -161,10 +215,9 @@ class OpenAIAssistant(Agent): """ while True: run = self.client.beta.threads.runs.retrieve( - thread_id=run.thread_id, - run_id=run.id + thread_id=run.thread_id, run_id=run.id ) - + if run.status == "completed": break elif run.status == "requires_action": @@ -172,15 +225,17 @@ class OpenAIAssistant(Agent): if run.status == "completed": break elif run.status in ["failed", "expired"]: - raise Exception(f"Run failed with status: {run.status}") - + raise Exception( + f"Run failed with status: {run.status}" + ) + time.sleep(3) # Wait 3 seconds before checking again - + return run def _ensure_thread(self): """Ensure a thread exists for the conversation. - + This method checks if there is an active thread for the current conversation. If no thread exists, it creates a new one. This maintains conversation context across multiple interactions. @@ -188,10 +243,11 @@ class OpenAIAssistant(Agent): Side Effects: Sets self.thread if it doesn't exist """ - if not self.thread: - self.thread = self.client.beta.threads.create() + self.thread = self.client.beta.threads.create() - def add_message(self, content: str, file_ids: Optional[List[str]] = None) -> None: + def add_message( + self, content: str, file_ids: Optional[List[str]] = None + ) -> None: """Add a message to the thread. This method adds a new user message to the conversation thread. 
It ensures @@ -211,20 +267,18 @@ class OpenAIAssistant(Agent): thread_id=self.thread.id, role="user", content=content, - file_ids=file_ids or [] + # file_ids=file_ids or [], ) def _get_response(self) -> str: """Get the latest assistant response from the thread.""" messages = self.client.beta.threads.messages.list( - thread_id=self.thread.id, - order="desc", - limit=1 + thread_id=self.thread.id, order="desc", limit=1 ) - + if not messages.data: return "" - + message = messages.data[0] if message.role == "assistant": return message.content[0].text.value @@ -235,25 +289,25 @@ class OpenAIAssistant(Agent): Args: task: The task or prompt to send to the assistant - + Returns: The assistant's response as a string """ self._ensure_thread() - + # Add the user message self.add_message(task) - + # Create and run the assistant run = self.client.beta.threads.runs.create( thread_id=self.thread.id, assistant_id=self.assistant.id, - instructions=self.instructions + instructions=self.instructions, ) - + # Wait for completion run = self._wait_for_run(run) - + # Only get and return the response if run completed successfully if run.status == "completed": return self._get_response() @@ -261,4 +315,4 @@ class OpenAIAssistant(Agent): def call(self, task: str, *args, **kwargs) -> str: """Alias for run() to maintain compatibility with different agent interfaces.""" - return self.run(task, *args, **kwargs) \ No newline at end of file + return self.run(task, *args, **kwargs) diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index 5b85864f..16a93512 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -73,22 +73,6 @@ from swarms.structs.utils import ( find_token_in_text, parse_tasks, ) -from swarms.structs.swarm_router import ( - SwarmRouter, - SwarmType, - swarm_router, -) -from swarms.structs.swarm_arange import SwarmRearrange -from swarms.structs.multi_agent_exec import ( - run_agents_concurrently, - run_agents_concurrently_async, - run_single_agent, - run_agents_concurrently_multiprocess, - run_agents_sequentially, - run_agents_with_different_tasks, - run_agent_with_timeout, - run_agents_with_resource_monitoring, -) from swarms.structs.async_workflow import AsyncWorkflow __all__ = [ diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index b9df9157..caedb951 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -1,14 +1,13 @@ -from datetime import datetime import asyncio import json import logging import os import random -import sys import threading import time import uuid from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from typing import ( Any, Callable, @@ -22,9 +21,14 @@ from typing import ( import toml import yaml + +# from swarms.utils.loguru_logger import initialize_logger +from loguru import logger from pydantic import BaseModel from swarm_models.tiktoken_wrapper import TikTokenizer + from swarms.agents.ape_agent import auto_generate_prompt +from swarms.artifacts.main_artifact import Artifact from swarms.prompts.agent_system_prompts import AGENT_SYSTEM_PROMPT_3 from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, @@ -38,22 +42,19 @@ from swarms.schemas.base_schemas import ( ) from swarms.structs.concat import concat_strings from swarms.structs.conversation import Conversation -from swarms.tools.base_tool import BaseTool -from swarms.tools.func_calling_utils import ( - prepare_output_for_output_model, +from swarms.structs.safe_loading import ( + 
SafeLoaderUtils, + SafeStateManager, ) +from swarms.tools.base_tool import BaseTool from swarms.tools.tool_parse_exec import parse_and_execute_json from swarms.utils.data_to_text import data_to_text from swarms.utils.file_processing import create_file_in_folder +from swarms.utils.formatter import formatter from swarms.utils.pdf_to_text import pdf_to_text -from swarms.artifacts.main_artifact import Artifact -from swarms.utils.loguru_logger import initialize_logger from swarms.utils.wrapper_clusterop import ( exec_callable_with_clusterops, ) -from swarms.utils.formatter import formatter - -logger = initialize_logger(log_folder="agents") # Utils @@ -136,7 +137,6 @@ class Agent: callback (Callable): The callback function metadata (Dict[str, Any]): The metadata callbacks (List[Callable]): The list of callback functions - logger_handler (Any): The logger handler search_algorithm (Callable): The search algorithm logs_to_filename (str): The filename for the logs evaluator (Callable): The evaluator function @@ -271,7 +271,6 @@ class Agent: callback: Optional[Callable] = None, metadata: Optional[Dict[str, Any]] = None, callbacks: Optional[List[Callable]] = None, - logger_handler: Optional[Any] = sys.stderr, search_algorithm: Optional[Callable] = None, logs_to_filename: Optional[str] = None, evaluator: Optional[Callable] = None, # Custom LLM or agent @@ -297,7 +296,6 @@ class Agent: algorithm_of_thoughts: bool = False, tree_of_thoughts: bool = False, tool_choice: str = "auto", - execute_tool: bool = False, rules: str = None, # type: ignore planning: Optional[str] = False, planning_prompt: Optional[str] = None, @@ -319,7 +317,7 @@ class Agent: use_cases: Optional[List[Dict[str, str]]] = None, step_pool: List[Step] = [], print_every_step: Optional[bool] = False, - time_created: Optional[float] = time.strftime( + time_created: Optional[str] = time.strftime( "%Y-%m-%d %H:%M:%S", time.localtime() ), agent_output: ManySteps = None, @@ -340,6 +338,7 @@ class Agent: all_gpus: bool = False, model_name: str = None, llm_args: dict = None, + load_state_path: str = None, *args, **kwargs, ): @@ -390,7 +389,6 @@ class Agent: self.callback = callback self.metadata = metadata self.callbacks = callbacks - self.logger_handler = logger_handler self.search_algorithm = search_algorithm self.logs_to_filename = logs_to_filename self.evaluator = evaluator @@ -414,7 +412,6 @@ class Agent: self.algorithm_of_thoughts = algorithm_of_thoughts self.tree_of_thoughts = tree_of_thoughts self.tool_choice = tool_choice - self.execute_tool = execute_tool self.planning = planning self.planning_prompt = planning_prompt self.custom_planning_prompt = custom_planning_prompt @@ -457,6 +454,7 @@ class Agent: self.all_gpus = all_gpus self.model_name = model_name self.llm_args = llm_args + self.load_state_path = load_state_path # Initialize the short term memory self.short_memory = Conversation( @@ -546,19 +544,6 @@ class Agent: tool.__name__: tool for tool in tools } - # Set the logger handler - if exists(logger_handler): - log_file_path = os.path.join( - self.workspace_dir, f"{self.agent_name}.log" - ) - logger.add( - log_file_path, - level="INFO", - colorize=True, - backtrace=True, - diagnose=True, - ) - # If the tool schema exists or a list of base models exists then convert the tool schema into an openai schema if exists(tool_schema) or exists(list_base_models): threading.Thread( @@ -593,20 +578,23 @@ class Agent: # Telemetry Processor to log agent data threading.Thread(target=self.log_agent_data).start() - 
threading.Thread(target=self.llm_handling()) + if self.llm is not None and self.model_name is not None: + self.llm = self.llm_handling() def llm_handling(self): + from swarms.utils.litellm_wrapper import LiteLLM - if self.llm is None: - from swarms.utils.litellm_wrapper import LiteLLM + if self.llm_args is not None: + llm = LiteLLM(model_name=self.model_name, **self.llm_args) - if self.llm_args is not None: - self.llm = LiteLLM( - model_name=self.model_name, **self.llm_args - ) + else: + llm = LiteLLM( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) - else: - self.llm = LiteLLM(model_name=self.model_name) + return llm def check_if_no_prompt_then_autogenerate(self, task: str = None): """ @@ -820,6 +808,9 @@ class Agent: # Print the user's request + if self.autosave: + self.save() + # Print the request if print_task is True: formatter.print_panel( @@ -904,13 +895,6 @@ class Agent: # Check and execute tools if self.tools is not None: self.parse_and_execute_tools(response) - # if tool_result: - # self.update_tool_usage( - # step_meta["step_id"], - # tool_result["tool"], - # tool_result["args"], - # tool_result["response"], - # ) # Add the response to the memory self.short_memory.add( @@ -944,6 +928,12 @@ class Agent: success = True # Mark as successful to exit the retry loop except Exception as e: + + self.log_agent_data() + + if self.autosave is True: + self.save() + logger.error( f"Attempt {attempt+1}: Error generating" f" response: {e}" @@ -951,6 +941,12 @@ class Agent: attempt += 1 if not success: + + self.log_agent_data() + + if self.autosave is True: + self.save() + logger.error( "Failed to generate a valid response after" " retry attempts." @@ -994,8 +990,10 @@ class Agent: time.sleep(self.loop_interval) if self.autosave is True: - logger.info("Autosaving agent state.") - self.save_state() + self.log_agent_data() + + if self.autosave is True: + self.save() # Apply the cleaner function to the response if self.output_cleaner is not None: @@ -1037,10 +1035,9 @@ class Agent: self.artifacts_file_extension, ) - try: - self.log_agent_data() - except Exception: - pass + self.log_agent_data() + if self.autosave is True: + self.save() # More flexible output types if ( @@ -1050,7 +1047,10 @@ class Agent: return concat_strings(all_responses) elif self.output_type == "list": return all_responses - elif self.output_type == "json": + elif ( + self.output_type == "json" + or self.return_step_meta is True + ): return self.agent_output.model_dump_json(indent=4) elif self.output_type == "csv": return self.dict_to_csv( @@ -1062,8 +1062,6 @@ class Agent: return yaml.safe_dump( self.agent_output.model_dump(), sort_keys=False ) - elif self.return_step_meta is True: - return self.agent_output.model_dump_json(indent=4) elif self.return_history is True: history = self.short_memory.get_str() @@ -1077,18 +1075,74 @@ class Agent: ) except Exception as error: - self.log_agent_data() - logger.info( - f"Error running agent: {error} optimize your input parameters" - ) - raise error + self._handle_run_error(error) except KeyboardInterrupt as error: - self.log_agent_data() - logger.info( - f"Error running agent: {error} optimize your input parameters" + self._handle_run_error(error) + + def _handle_run_error(self, error: any): + self.log_agent_data() + + if self.autosave is True: + self.save() + + logger.info( + f"Error detected running your agent {self.agent_name} \n Error {error} \n Optimize your input parameters and or add an issue on the swarms github and contact our 
team on Discord for support."
+        )
+        raise error
+
+    async def arun(
+        self,
+        task: Optional[str] = None,
+        img: Optional[str] = None,
+        is_last: bool = False,
+        device: str = "cpu",  # gpu
+        device_id: int = 1,
+        all_cores: bool = True,
+        do_not_use_cluster_ops: bool = True,
+        all_gpus: bool = False,
+        *args,
+        **kwargs,
+    ) -> Any:
+        """
+        Asynchronously runs the agent with the specified parameters.
+
+        Args:
+            task (Optional[str]): The task to be performed. Defaults to None.
+            img (Optional[str]): The image to be processed. Defaults to None.
+            is_last (bool): Indicates if this is the last task. Defaults to False.
+            device (str): The device to use for execution. Defaults to "cpu".
+            device_id (int): The ID of the GPU to use if device is set to "gpu". Defaults to 1.
+            all_cores (bool): If True, uses all available CPU cores. Defaults to True.
+            do_not_use_cluster_ops (bool): If True, does not use cluster operations. Defaults to True.
+            all_gpus (bool): If True, uses all available GPUs. Defaults to False.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            Any: The result of the asynchronous operation.
+
+        Raises:
+            Exception: If an error occurs during the asynchronous operation.
+        """
+        try:
+            return await asyncio.to_thread(
+                self.run,
+                task=task,
+                img=img,
+                is_last=is_last,
+                device=device,
+                device_id=device_id,
+                all_cores=all_cores,
+                do_not_use_cluster_ops=do_not_use_cluster_ops,
+                all_gpus=all_gpus,
+                *args,
+                **kwargs,
             )
-        raise error
+        except Exception as error:
+            # _handle_run_error is synchronous and re-raises, so no await
+            self._handle_run_error(error)
 
     def __call__(
         self,
@@ -1096,8 +1150,10 @@ class Agent:
         task: Optional[str] = None,
         img: Optional[str] = None,
         is_last: bool = False,
         device: str = "cpu",  # gpu
-        device_id: int = 0,
+        device_id: int = 1,
         all_cores: bool = True,
+        do_not_use_cluster_ops: bool = True,
+        all_gpus: bool = False,
         *args,
         **kwargs,
     ) -> Any:
@@ -1112,33 +1168,19 @@ class Agent:
             all_cores (bool): If True, uses all available CPU cores. Defaults to True.
         """
         try:
-            if task is not None:
-                return self.run(
-                    task=task,
-                    is_last=is_last,
-                    device=device,
-                    device_id=device_id,
-                    all_cores=all_cores,
-                    *args,
-                    **kwargs,
-                )
-            elif img is not None:
-                return self.run(
-                    img=img,
-                    is_last=is_last,
-                    device=device,
-                    device_id=device_id,
-                    all_cores=all_cores,
-                    *args,
-                    **kwargs,
-                )
-            else:
-                raise ValueError(
-                    "Either 'task' or 'img' must be provided."
-                )
+            return self.run(
+                task=task,
+                img=img,
+                is_last=is_last,
+                device=device,
+                device_id=device_id,
+                all_cores=all_cores,
+                do_not_use_cluster_ops=do_not_use_cluster_ops,
+                all_gpus=all_gpus,
+                *args,
+                **kwargs,
+            )
         except Exception as error:
-            logger.error(f"Error calling agent: {error}")
-            raise error
+            self._handle_run_error(error)
 
     def dict_to_csv(self, data: dict) -> str:
         """
@@ -1165,33 +1207,31 @@ class Agent:
         return output.getvalue()
 
     def parse_and_execute_tools(self, response: str, *args, **kwargs):
-        # Try executing the tool
-        if self.execute_tool is not False:
-            try:
-                logger.info("Executing tool...")
-
-                # try to Execute the tool and return a string
-                out = parse_and_execute_json(
-                    self.tools,
-                    response,
-                    parse_md=True,
-                    *args,
-                    **kwargs,
-                )
+        try:
+            logger.info("Executing tool...")
+
+            # Try to execute the tool and return a string
+            out = parse_and_execute_json(
+                functions=self.tools,
+                json_string=response,
+                parse_md=True,
+                *args,
+                **kwargs,
+            )
 
-                out = str(out)
+            out = str(out)
 
-                logger.info(f"Tool Output: {out}")
+            logger.info(f"Tool Output: {out}")
 
-                # Add the output to the memory
-                self.short_memory.add(
-                    role="Tool Executor",
-                    content=out,
-                )
+            # Add the output to the memory
+            self.short_memory.add(
+                role="Tool Executor",
+                content=out,
+            )
 
-            except Exception as error:
-                logger.error(f"Error executing tool: {error}")
-                raise error
+        except Exception as error:
+            logger.error(f"Error executing tool: {error}")
+            raise error
 
     def add_memory(self, message: str):
         """Add a memory to the agent
@@ -1203,6 +1243,7 @@ class Agent:
             _type_: _description_
         """
         logger.info(f"Adding memory: {message}")
+
         return self.short_memory.add(
            role=self.agent_name, content=message
        )
@@ -1261,7 +1302,9 @@ class Agent:
         try:
             logger.info(f"Running concurrent tasks: {tasks}")
             futures = [
-                self.executor.submit(self.run, task, *args, **kwargs)
+                self.executor.submit(
+                    self.run, task=task, *args, **kwargs
+                )
                 for task in tasks
             ]
            results = [future.result() for future in futures]
@@ -1289,94 +1332,345 @@ class Agent:
         except Exception as error:
             logger.info(f"Error running bulk run: {error}", "red")
 
-    def save(self) -> None:
-        """Save the agent history to a file.
-
-        Args:
-            file_path (_type_): _description_
-        """
-        file_path = (
-            f"{self.saved_state_path}.json"
-            or f"{self.agent_name}.json"
-            or f"{self.saved_state_path}.json"
-        )
+    async def arun_batched(
+        self,
+        tasks: List[str],
+        *args,
+        **kwargs,
+    ):
+        """Asynchronously runs a batch of tasks."""
         try:
-            create_file_in_folder(
-                self.workspace_dir,
-                file_path,
-                self.to_json(),
-            )
-            logger.info(f"Saved agent history to: {file_path}")
+            # Create a list of coroutines for each task
+            coroutines = [
+                self.arun(task=task, *args, **kwargs)
+                for task in tasks
+            ]
+            # Use asyncio.gather to run them concurrently
+            results = await asyncio.gather(*coroutines)
+            return results
         except Exception as error:
-            logger.error(f"Error saving agent history: {error}")
-            raise error
+            logger.error(f"Error running batched tasks: {error}")
+            raise
 
-    def load(self, file_path: str) -> None:
+    def save(self, file_path: str = None) -> None:
         """
-        Load the agent history from a file, excluding the LLM.
+        Save the agent state to a file using SafeStateManager with atomic writing
+        and backup functionality. Automatically handles complex objects and class instances.
 
         Args:
-            file_path (str): The path to the file containing the saved agent history.
+            file_path (str, optional): Custom path to save the state.
+                                    If None, uses configured paths. 
Raises: - FileNotFoundError: If the specified file path does not exist - json.JSONDecodeError: If the file contains invalid JSON - AttributeError: If there are issues setting agent attributes + OSError: If there are filesystem-related errors Exception: For other unexpected errors """ try: - file_path = ( - f"{self.saved_state_path}.json" - or f"{self.agent_name}.json" - or f"{self.saved_state_path}.json" + # Determine the save path + resolved_path = ( + file_path + or self.saved_state_path + or f"{self.agent_name}_state.json" + ) + + # Ensure path has .json extension + if not resolved_path.endswith(".json"): + resolved_path += ".json" + + # Create full path including workspace directory + full_path = os.path.join( + self.workspace_dir, resolved_path + ) + backup_path = full_path + ".backup" + temp_path = full_path + ".temp" + + # Ensure workspace directory exists + os.makedirs(os.path.dirname(full_path), exist_ok=True) + + # First save to temporary file using SafeStateManager + SafeStateManager.save_state(self, temp_path) + + # If current file exists, create backup + if os.path.exists(full_path): + try: + os.replace(full_path, backup_path) + except Exception as e: + logger.warning(f"Could not create backup: {e}") + + # Move temporary file to final location + os.replace(temp_path, full_path) + + # Clean up old backup if everything succeeded + if os.path.exists(backup_path): + try: + os.remove(backup_path) + except Exception as e: + logger.warning( + f"Could not remove backup file: {e}" + ) + + # Log saved state information if verbose + if self.verbose: + self._log_saved_state_info(full_path) + + logger.info( + f"Successfully saved agent state to: {full_path}" ) - if not os.path.exists(file_path): - raise FileNotFoundError( - f"File not found at path: {file_path}" + # Handle additional component saves + self._save_additional_components(full_path) + + except OSError as e: + logger.error( + f"Filesystem error while saving agent state: {e}" + ) + raise + except Exception as e: + logger.error(f"Unexpected error saving agent state: {e}") + raise + + def _save_additional_components(self, base_path: str) -> None: + """Save additional agent components like memory.""" + try: + # Save long term memory if it exists + if ( + hasattr(self, "long_term_memory") + and self.long_term_memory is not None + ): + memory_path = ( + f"{os.path.splitext(base_path)[0]}_memory.json" ) + try: + self.long_term_memory.save(memory_path) + logger.info( + f"Saved long-term memory to: {memory_path}" + ) + except Exception as e: + logger.warning( + f"Could not save long-term memory: {e}" + ) - with open(file_path, "r") as file: + # Save memory manager if it exists + if ( + hasattr(self, "memory_manager") + and self.memory_manager is not None + ): + manager_path = f"{os.path.splitext(base_path)[0]}_memory_manager.json" try: - data = json.load(file) - except json.JSONDecodeError as e: - logger.error( - f"Invalid JSON in file {file_path}: {str(e)}" + self.memory_manager.save_memory_snapshot( + manager_path + ) + logger.info( + f"Saved memory manager state to: {manager_path}" + ) + except Exception as e: + logger.warning( + f"Could not save memory manager: {e}" ) - raise - if not isinstance(data, dict): - raise ValueError( - f"Expected dict data but got {type(data)}" + except Exception as e: + logger.warning(f"Error saving additional components: {e}") + + def enable_autosave(self, interval: int = 300) -> None: + """ + Enable automatic saving of agent state using SafeStateManager at specified intervals. 
+ + Args: + interval (int): Time between saves in seconds. Defaults to 300 (5 minutes). + """ + + def autosave_loop(): + while self.autosave: + try: + self.save() + if self.verbose: + logger.debug( + f"Autosaved agent state (interval: {interval}s)" + ) + except Exception as e: + logger.error(f"Autosave failed: {e}") + time.sleep(interval) + + self.autosave = True + self.autosave_thread = threading.Thread( + target=autosave_loop, + daemon=True, + name=f"{self.agent_name}_autosave", + ) + self.autosave_thread.start() + logger.info(f"Enabled autosave with {interval}s interval") + + def disable_autosave(self) -> None: + """Disable automatic saving of agent state.""" + if hasattr(self, "autosave"): + self.autosave = False + if hasattr(self, "autosave_thread"): + self.autosave_thread.join(timeout=1) + delattr(self, "autosave_thread") + logger.info("Disabled autosave") + + def cleanup(self) -> None: + """Cleanup method to be called on exit. Ensures final state is saved.""" + try: + if getattr(self, "autosave", False): + logger.info( + "Performing final autosave before exit..." ) + self.disable_autosave() + self.save() + except Exception as e: + logger.error(f"Error during cleanup: {e}") + + def load(self, file_path: str = None) -> None: + """ + Load agent state from a file using SafeStateManager. + Automatically preserves class instances and complex objects. - # Store current LLM - current_llm = self.llm + Args: + file_path (str, optional): Path to load state from. + If None, uses default path from agent config. - try: - for key, value in data.items(): - if key != "llm": - setattr(self, key, value) - except AttributeError as e: - logger.error( - f"Error setting agent attribute: {str(e)}" + Raises: + FileNotFoundError: If state file doesn't exist + Exception: If there's an error during loading + """ + try: + # Resolve load path conditionally with a check for self.load_state_path + resolved_path = ( + file_path + or self.load_state_path + or ( + f"{self.saved_state_path}.json" + if self.saved_state_path + else ( + f"{self.agent_name}.json" + if self.agent_name + else ( + f"{self.workspace_dir}/{self.agent_name}_state.json" + if self.workspace_dir and self.agent_name + else None + ) + ) ) - raise + ) - # Restore LLM - self.llm = current_llm + # Load state using SafeStateManager + SafeStateManager.load_state(self, resolved_path) - logger.info( - f"Successfully loaded agent history from: {file_path}" + # Reinitialize any necessary runtime components + self._reinitialize_after_load() + + if self.verbose: + self._log_loaded_state_info(resolved_path) + + except FileNotFoundError: + logger.error(f"State file not found: {resolved_path}") + raise + except Exception as e: + logger.error(f"Error loading agent state: {e}") + raise + + def _reinitialize_after_load(self) -> None: + """ + Reinitialize necessary components after loading state. + Called automatically after load() to ensure all components are properly set up. 
+ """ + try: + # Reinitialize conversation if needed + if ( + not hasattr(self, "short_memory") + or self.short_memory is None + ): + self.short_memory = Conversation( + system_prompt=self.system_prompt, + time_enabled=True, + user=self.user_name, + rules=self.rules, + ) + + # Reinitialize executor if needed + if not hasattr(self, "executor") or self.executor is None: + self.executor = ThreadPoolExecutor( + max_workers=os.cpu_count() + ) + + # # Reinitialize tool structure if needed + # if hasattr(self, 'tools') and (self.tools or getattr(self, 'list_base_models', None)): + # self.tool_struct = BaseTool( + # tools=self.tools, + # base_models=getattr(self, 'list_base_models', None), + # tool_system_prompt=self.tool_system_prompt + # ) + + except Exception as e: + logger.error(f"Error reinitializing components: {e}") + raise + + def _log_saved_state_info(self, file_path: str) -> None: + """Log information about saved state for debugging""" + try: + state_dict = SafeLoaderUtils.create_state_dict(self) + preserved = SafeLoaderUtils.preserve_instances(self) + + logger.info(f"Saved agent state to: {file_path}") + logger.debug( + f"Saved {len(state_dict)} configuration values" + ) + logger.debug( + f"Preserved {len(preserved)} class instances" ) + if self.verbose: + logger.debug("Preserved instances:") + for name, instance in preserved.items(): + logger.debug( + f" - {name}: {type(instance).__name__}" + ) except Exception as e: - logger.error( - f"Unexpected error loading agent history: {str(e)}" + logger.error(f"Error logging state info: {e}") + + def _log_loaded_state_info(self, file_path: str) -> None: + """Log information about loaded state for debugging""" + try: + state_dict = SafeLoaderUtils.create_state_dict(self) + preserved = SafeLoaderUtils.preserve_instances(self) + + logger.info(f"Loaded agent state from: {file_path}") + logger.debug( + f"Loaded {len(state_dict)} configuration values" + ) + logger.debug( + f"Preserved {len(preserved)} class instances" ) - raise - return None + if self.verbose: + logger.debug("Current class instances:") + for name, instance in preserved.items(): + logger.debug( + f" - {name}: {type(instance).__name__}" + ) + except Exception as e: + logger.error(f"Error logging state info: {e}") + + def get_saveable_state(self) -> Dict[str, Any]: + """ + Get a dictionary of all saveable state values. + Useful for debugging or manual state inspection. + + Returns: + Dict[str, Any]: Dictionary of saveable values + """ + return SafeLoaderUtils.create_state_dict(self) + + def get_preserved_instances(self) -> Dict[str, Any]: + """ + Get a dictionary of all preserved class instances. + Useful for debugging or manual state inspection. + + Returns: + Dict[str, Any]: Dictionary of preserved instances + """ + return SafeLoaderUtils.preserve_instances(self) def graceful_shutdown(self): """Gracefully shutdown the system saving the state""" @@ -1470,24 +1764,6 @@ class Agent: def get_llm_parameters(self): return str(vars(self.llm)) - def save_state(self, *args, **kwargs) -> None: - """ - Saves the current state of the agent to a JSON file, including the llm parameters. - - Args: - file_path (str): The path to the JSON file where the state will be saved. 
- - Example: - >>> agent.save_state('saved_flow.json') - """ - try: - logger.info(f"Saving Agent {self.agent_name}") - self.save() - logger.info("Saved agent state") - except Exception as error: - logger.error(f"Error saving agent state: {error}") - raise error - def update_system_prompt(self, system_prompt: str): """Upddate the system message""" self.system_prompt = system_prompt @@ -1722,53 +1998,6 @@ class Agent: except Exception as e: print(f"Error occurred during sentiment analysis: {e}") - def count_and_shorten_context_window( - self, history: str, *args, **kwargs - ): - """ - Count the number of tokens in the context window and shorten it if it exceeds the limit. - - Args: - history (str): The history of the conversation. - - Returns: - str: The shortened context window. - """ - # Count the number of tokens in the context window - count = self.tokenizer.count_tokens(history) - - # Shorten the context window if it exceeds the limit, keeping the last n tokens, need to implement the indexing - if count > self.context_length: - history = history[-self.context_length :] - - return history - - def output_cleaner_and_output_type( - self, response: str, *args, **kwargs - ): - """ - Applies the output cleaner function to the response and prepares the output for the output model. - - Args: - response (str): The response to be processed. - - Returns: - str: The processed response. - """ - # Apply the cleaner function to the response - if self.output_cleaner is not None: - logger.info("Applying output cleaner to response.") - response = self.output_cleaner(response) - logger.info(f"Response after output cleaner: {response}") - - # Prepare the output for the output model - if self.output_type is not None: - # logger.info("Preparing output for output model.") - response = prepare_output_for_output_model(response) - print(f"Response after output model: {response}") - - return response - def stream_response( self, response: str, delay: float = 0.001 ) -> None: @@ -1800,37 +2029,6 @@ class Agent: except Exception as e: print(f"An error occurred during streaming: {e}") - def dynamic_context_window(self): - """ - dynamic_context_window essentially clears everything execep - the system prompt and leaves the rest of the contxt window - for RAG query tokens - - """ - # Count the number of tokens in the short term memory - logger.info("Dynamic context window shuffling enabled") - count = self.tokenizer.count_tokens( - self.short_memory.return_history_as_string() - ) - logger.info(f"Number of tokens in memory: {count}") - - # Dynamically allocating everything except the system prompt to be dynamic - # We need to query the short_memory dict, for the system prompt slot - # Then delete everything after that - - if count > self.context_length: - self.short_memory = self.short_memory[ - -self.context_length : - ] - logger.info( - f"Short term memory has been truncated to {self.context_length} tokens" - ) - else: - logger.info("Short term memory is within the limit") - - # Return the memory as a string or update the short term memory - # return memory - def check_available_tokens(self): # Log the amount of tokens left in the memory and in the task if self.tokenizer is not None: @@ -1856,58 +2054,6 @@ class Agent: return out - def truncate_string_by_tokens( - self, input_string: str, limit: int - ) -> str: - """ - Truncate a string if it exceeds a specified number of tokens using a given tokenizer. - - :param input_string: The input string to be tokenized and truncated. 
 
-    def count_and_shorten_context_window(
-        self, history: str, *args, **kwargs
-    ):
-        """
-        Count the number of tokens in the context window and shorten it if it exceeds the limit.
-
-        Args:
-            history (str): The history of the conversation.
-
-        Returns:
-            str: The shortened context window.
-        """
-        # Count the number of tokens in the context window
-        count = self.tokenizer.count_tokens(history)
-
-        # Shorten the context window if it exceeds the limit, keeping the last n tokens, need to implement the indexing
-        if count > self.context_length:
-            history = history[-self.context_length :]
-
-        return history
-
-    def output_cleaner_and_output_type(
-        self, response: str, *args, **kwargs
-    ):
-        """
-        Applies the output cleaner function to the response and prepares the output for the output model.
-
-        Args:
-            response (str): The response to be processed.
-
-        Returns:
-            str: The processed response.
-        """
-        # Apply the cleaner function to the response
-        if self.output_cleaner is not None:
-            logger.info("Applying output cleaner to response.")
-            response = self.output_cleaner(response)
-            logger.info(f"Response after output cleaner: {response}")
-
-        # Prepare the output for the output model
-        if self.output_type is not None:
-            # logger.info("Preparing output for output model.")
-            response = prepare_output_for_output_model(response)
-            print(f"Response after output model: {response}")
-
-        return response
-
     def stream_response(
         self, response: str, delay: float = 0.001
     ) -> None:
@@ -1800,37 +2029,6 @@ class Agent:
         except Exception as e:
             print(f"An error occurred during streaming: {e}")
 
-    def dynamic_context_window(self):
-        """
-        dynamic_context_window essentially clears everything execep
-        the system prompt and leaves the rest of the contxt window
-        for RAG query tokens
-
-        """
-        # Count the number of tokens in the short term memory
-        logger.info("Dynamic context window shuffling enabled")
-        count = self.tokenizer.count_tokens(
-            self.short_memory.return_history_as_string()
-        )
-        logger.info(f"Number of tokens in memory: {count}")
-
-        # Dynamically allocating everything except the system prompt to be dynamic
-        # We need to query the short_memory dict, for the system prompt slot
-        # Then delete everything after that
-
-        if count > self.context_length:
-            self.short_memory = self.short_memory[
-                -self.context_length :
-            ]
-            logger.info(
-                f"Short term memory has been truncated to {self.context_length} tokens"
-            )
-        else:
-            logger.info("Short term memory is within the limit")
-
-        # Return the memory as a string or update the short term memory
-        # return memory
-
     def check_available_tokens(self):
         # Log the amount of tokens left in the memory and in the task
         if self.tokenizer is not None:
@@ -1856,58 +2054,6 @@ class Agent:
 
         return out
 
-    def truncate_string_by_tokens(
-        self, input_string: str, limit: int
-    ) -> str:
-        """
-        Truncate a string if it exceeds a specified number of tokens using a given tokenizer.
-
-        :param input_string: The input string to be tokenized and truncated.
-        :param tokenizer: The tokenizer function to be used for tokenizing the input string.
-        :param max_tokens: The maximum number of tokens allowed.
-        :return: The truncated string if it exceeds the maximum number of tokens; otherwise, the original string.
-        """
-        # Tokenize the input string
-        tokens = self.tokenizer.count_tokens(input_string)
-
-        # Check if the number of tokens exceeds the maximum limit
-        if len(tokens) > limit:
-            # Truncate the tokens to the maximum allowed tokens
-            truncated_tokens = tokens[: self.context_length]
-            # Join the truncated tokens back to a string
-            truncated_string = " ".join(truncated_tokens)
-            return truncated_string
-        else:
-            return input_string
-
-    def tokens_operations(self, input_string: str) -> str:
-        """
-        Perform various operations on tokens of an input string.
-
-        :param input_string: The input string to be processed.
-        :return: The processed string.
-        """
-        # Tokenize the input string
-        tokens = self.tokenizer.count_tokens(input_string)
-
-        # Check if the number of tokens exceeds the maximum limit
-        if len(tokens) > self.context_length:
-            # Truncate the tokens to the maximum allowed tokens
-            truncated_tokens = tokens[: self.context_length]
-            # Join the truncated tokens back to a string
-            truncated_string = " ".join(truncated_tokens)
-            return truncated_string
-        else:
-            # Log the amount of tokens left in the memory and in the task
-            if self.tokenizer is not None:
-                tokens_used = self.tokenizer.count_tokens(
-                    self.short_memory.return_history_as_string()
-                )
-                logger.info(
-                    f"Tokens available: {tokens_used - self.context_length}"
-                )
-            return input_string
-
     def parse_function_call_and_execute(self, response: str):
         """
         Parses a function call from the given response and executes it.
@@ -2269,7 +2415,6 @@ class Agent:
 
             The result of the method call on the `llm` object.
         """
-        # Check if the llm has a __call__, or run, or any other method
         if hasattr(self.llm, "__call__"):
             return self.llm(task, *args, **kwargs)
         elif hasattr(self.llm, "run"):
@@ -2306,9 +2451,8 @@ class Agent:
         device_id: Optional[int] = 0,
         all_cores: Optional[bool] = True,
         scheduled_run_date: Optional[datetime] = None,
-        do_not_use_cluster_ops: Optional[bool] = False,
+        do_not_use_cluster_ops: Optional[bool] = True,
         all_gpus: Optional[bool] = False,
-        generate_speech: Optional[bool] = False,
         *args,
         **kwargs,
     ) -> Any:
@@ -2341,6 +2485,7 @@ class Agent:
         device_id = device_id or self.device_id
         all_cores = all_cores or self.all_cores
         all_gpus = all_gpus or self.all_gpus
+
         do_not_use_cluster_ops = (
            do_not_use_cluster_ops or self.do_not_use_cluster_ops
        )
@@ -2358,7 +2503,7 @@ class Agent:
            return self._run(
                task=task,
                img=img,
-                generate_speech=generate_speech * args,
+                *args,
                **kwargs,
            )
 
@@ -2371,17 +2516,15 @@ class Agent:
            func=self._run,
            task=task,
            img=img,
-            generate_speech=generate_speech,
            *args,
            **kwargs,
        )
 
        except ValueError as e:
-            logger.error(f"Invalid device specified: {e}")
-            raise e
+            self._handle_run_error(e)
+
        except Exception as e:
-            logger.error(f"An error occurred during execution: {e}")
-            raise e
 
    def handle_artifacts(
        self, text: str, file_output_path: str, file_extension: str
    ) -> None:
        """Handle creating and saving artifacts with error handling."""
        try:
            # Ensure file_extension starts with a dot
-            if not file_extension.startswith('.'):
-                file_extension = '.' + file_extension
+            if not file_extension.startswith("."):
+                file_extension = "." + file_extension
 
            # If file_output_path doesn't have an extension, treat it as a directory
            # and create a default filename based on timestamp
@@ -2412,18 +2555,26 @@ class Agent:
                edit_count=0,
            )
 
-            logger.info(f"Saving artifact with extension: {file_extension}")
+            logger.info(
+                f"Saving artifact with extension: {file_extension}"
+            )
            artifact.save_as(file_extension)
-            logger.success(f"Successfully saved artifact to {full_path}")
+            logger.success(
+                f"Successfully saved artifact to {full_path}"
+            )
 
        except ValueError as e:
-            logger.error(f"Invalid input values for artifact: {str(e)}")
+            logger.error(
+                f"Invalid input values for artifact: {str(e)}"
+            )
            raise
        except IOError as e:
            logger.error(f"Error saving artifact to file: {str(e)}")
            raise
        except Exception as e:
-            logger.error(f"Unexpected error handling artifact: {str(e)}")
+            logger.error(
+                f"Unexpected error handling artifact: {str(e)}"
+            )
            raise
 
    def showcase_config(self):
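With the hunks above, `generate_speech` is gone from `run()` and cluster ops become opt-in: `do_not_use_cluster_ops` now defaults to `True`. A hedged sketch of the resulting call surface (not part of the diff; `agent` is any configured Agent):

# Minimal sketch, not part of the diff: calling run() after this change.
response = agent.run(
    "Summarize the repository README",
    device="cpu",                  # honored only when cluster ops are used
    do_not_use_cluster_ops=False,  # opt back in to device-aware execution
)
print(response)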
diff --git a/swarms/structs/agent_memory_manager.py b/swarms/structs/agent_memory_manager.py
new file mode 100644
index 00000000..0f506fc4
--- /dev/null
+++ b/swarms/structs/agent_memory_manager.py
@@ -0,0 +1,419 @@
+import json
+import logging
+import time
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+import yaml
+from pydantic import BaseModel, Field
+from swarm_models.tiktoken_wrapper import TikTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryMetadata(BaseModel):
+    """Metadata for memory entries"""
+
+    # default_factory gives every instance a fresh timestamp and id,
+    # instead of values frozen when the class is first imported
+    timestamp: Optional[float] = Field(default_factory=time.time)
+    role: Optional[str] = None
+    agent_name: Optional[str] = None
+    session_id: Optional[str] = None
+    memory_type: Optional[str] = None  # 'short_term' or 'long_term'
+    token_count: Optional[int] = None
+    message_id: Optional[str] = Field(
+        default_factory=lambda: str(uuid.uuid4())
+    )
+
+
+class MemoryEntry(BaseModel):
+    """Single memory entry with content and metadata"""
+
+    content: Optional[str] = None
+    metadata: Optional[MemoryMetadata] = None
+
+
+class MemoryConfig(BaseModel):
+    """Configuration for memory manager"""
+
+    max_short_term_tokens: Optional[int] = 4096
+    max_entries: Optional[int] = None
+    system_messages_token_buffer: Optional[int] = 1000
+    enable_long_term_memory: Optional[bool] = False
+    auto_archive: Optional[bool] = True
+    archive_threshold: Optional[float] = 0.8  # Archive when 80% full
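A minimal usage sketch of these models with the `MemoryManager` defined just below (not part of the diff; relies on the default `TikTokenizer` and no long-term store):

# Minimal sketch, not part of the diff: config + manager basics.
from swarms.structs.agent_memory_manager import (
    MemoryConfig,
    MemoryManager,
)

manager = MemoryManager(
    config=MemoryConfig(max_short_term_tokens=2048)
)  # tokenizer defaults to TikTokenizer

manager.add_memory(
    content="You are a helpful assistant.",
    role="system",
    agent_name="demo-agent",
    session_id="session-1",
    is_system=True,
)
manager.add_memory(
    content="User asked about token budgeting.",
    role="user",
    agent_name="demo-agent",
    session_id="session-1",
)
print(manager.get_memory_stats())  # counts, token usage percent, etc.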
+
+
+class MemoryManager:
+    """
+    Manages both short-term and long-term memory for an agent, handling token limits,
+    archival, and context retrieval.
+
+    Args:
+        config (MemoryConfig): Configuration for memory management
+        tokenizer (Optional[Any]): Tokenizer to use for token counting
+        long_term_memory (Optional[Any]): Vector store or database for long-term storage
+    """
+
+    def __init__(
+        self,
+        config: MemoryConfig,
+        tokenizer: Optional[Any] = None,
+        long_term_memory: Optional[Any] = None,
+    ):
+        self.config = config
+        self.tokenizer = tokenizer or TikTokenizer()
+        self.long_term_memory = long_term_memory
+
+        # Initialize memories
+        self.short_term_memory: List[MemoryEntry] = []
+        self.system_messages: List[MemoryEntry] = []
+
+        # Memory statistics
+        self.total_tokens_processed: int = 0
+        self.archived_entries_count: int = 0
+
+    def create_memory_entry(
+        self,
+        content: str,
+        role: str,
+        agent_name: str,
+        session_id: str,
+        memory_type: str = "short_term",
+    ) -> MemoryEntry:
+        """Create a new memory entry with metadata"""
+        metadata = MemoryMetadata(
+            timestamp=time.time(),
+            role=role,
+            agent_name=agent_name,
+            session_id=session_id,
+            memory_type=memory_type,
+            token_count=self.tokenizer.count_tokens(content),
+        )
+        return MemoryEntry(content=content, metadata=metadata)
+
+    def add_memory(
+        self,
+        content: str,
+        role: str,
+        agent_name: str,
+        session_id: str,
+        is_system: bool = False,
+    ) -> None:
+        """Add a new memory entry to appropriate storage"""
+        entry = self.create_memory_entry(
+            content=content,
+            role=role,
+            agent_name=agent_name,
+            session_id=session_id,
+            memory_type="system" if is_system else "short_term",
+        )
+
+        if is_system:
+            self.system_messages.append(entry)
+        else:
+            self.short_term_memory.append(entry)
+
+        # Check if archiving is needed
+        if self.should_archive():
+            self.archive_old_memories()
+
+        self.total_tokens_processed += entry.metadata.token_count
+
+    def get_current_token_count(self) -> int:
+        """Get total tokens in short-term memory"""
+        return sum(
+            entry.metadata.token_count
+            for entry in self.short_term_memory
+        )
+
+    def get_system_messages_token_count(self) -> int:
+        """Get total tokens in system messages"""
+        return sum(
+            entry.metadata.token_count
+            for entry in self.system_messages
+        )
+
+    def should_archive(self) -> bool:
+        """Check if archiving is needed based on configuration"""
+        if not self.config.auto_archive:
+            return False
+
+        current_usage = (
+            self.get_current_token_count()
+            / self.config.max_short_term_tokens
+        )
+        return current_usage >= self.config.archive_threshold
+
+    def archive_old_memories(self) -> None:
+        """Move older memories to long-term storage"""
+        if not self.long_term_memory:
+            logger.warning(
+                "No long-term memory storage configured for archiving"
+            )
+            return
+
+        while self.should_archive():
+            # Get oldest non-system message
+            if not self.short_term_memory:
+                break
+
+            oldest_entry = self.short_term_memory.pop(0)
+
+            # Store in long-term memory
+            self.store_in_long_term_memory(oldest_entry)
+            self.archived_entries_count += 1
+
+    def store_in_long_term_memory(self, entry: MemoryEntry) -> None:
+        """Store a memory entry in long-term memory"""
+        if self.long_term_memory is None:
+            logger.warning(
+                "Attempted to store in non-existent long-term memory"
+            )
+            return
+
+        try:
+            self.long_term_memory.add(str(entry.model_dump()))
+        except Exception as e:
+            logger.error(f"Error storing in long-term memory: {e}")
+            # Re-add to short-term if storage fails
+            self.short_term_memory.insert(0, entry)
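Archival is driven entirely by `should_archive` and `archive_old_memories` above: once short-term usage crosses `archive_threshold`, the oldest entries are popped into the long-term store via its `add` method. A sketch with a hypothetical in-memory store (not part of the diff):

# Minimal sketch, not part of the diff: auto-archival with a dummy store.
from swarms.structs.agent_memory_manager import MemoryConfig, MemoryManager

class DummyStore:
    """Hypothetical long-term store with the add/query duck type used here."""
    def __init__(self):
        self.items = []

    def add(self, item):
        self.items.append(item)

    def query(self, q):
        return [i for i in self.items if q in i]

manager = MemoryManager(
    config=MemoryConfig(max_short_term_tokens=64),  # tiny budget on purpose
    long_term_memory=DummyStore(),
)
for i in range(20):
    manager.add_memory(
        content=f"note {i}: something worth remembering",
        role="user",
        agent_name="demo-agent",
        session_id="session-1",
    )
print(manager.archived_entries_count)  # > 0 once usage crossed 80%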
+
+    def get_relevant_context(
+        self, query: str, max_tokens: Optional[int] = None
+    ) -> str:
+        """
+        Get relevant context from both memory types
+
+        Args:
+            query (str): Query to match against memories
+            max_tokens (Optional[int]): Maximum tokens to return
+
+        Returns:
+            str: Combined relevant context
+        """
+        contexts = []
+
+        # Add system messages first
+        for entry in self.system_messages:
+            contexts.append(entry.content)
+
+        # Add short-term memory
+        for entry in reversed(self.short_term_memory):
+            contexts.append(entry.content)
+
+        # Query long-term memory if available
+        if self.long_term_memory is not None:
+            long_term_context = self.long_term_memory.query(query)
+            if long_term_context:
+                contexts.append(str(long_term_context))
+
+        # Combine and truncate if needed
+        combined = "\n".join(contexts)
+        if max_tokens:
+            combined = self.truncate_to_token_limit(
+                combined, max_tokens
+            )
+
+        return combined
+
+    def truncate_to_token_limit(
+        self, text: str, max_tokens: int
+    ) -> str:
+        """Truncate text to fit within token limit"""
+        current_tokens = self.tokenizer.count_tokens(text)
+
+        if current_tokens <= max_tokens:
+            return text
+
+        # Truncate by splitting into sentences and rebuilding
+        sentences = text.split(". ")
+        result = []
+        current_count = 0
+
+        for sentence in sentences:
+            sentence_tokens = self.tokenizer.count_tokens(sentence)
+            if current_count + sentence_tokens <= max_tokens:
+                result.append(sentence)
+                current_count += sentence_tokens
+            else:
+                break
+
+        return ". ".join(result)
+
+    def clear_short_term_memory(
+        self, preserve_system: bool = True
+    ) -> None:
+        """Clear short-term memory with option to preserve system messages"""
+        if not preserve_system:
+            self.system_messages.clear()
+        self.short_term_memory.clear()
+        # Parenthesized so the base message logs in both branches
+        logger.info(
+            "Cleared short-term memory"
+            + (
+                " (preserved system messages)"
+                if preserve_system
+                else ""
+            )
+        )
+
+    def get_memory_stats(self) -> Dict[str, Any]:
+        """Get detailed memory statistics"""
+        return {
+            "short_term_messages": len(self.short_term_memory),
+            "system_messages": len(self.system_messages),
+            "current_tokens": self.get_current_token_count(),
+            "system_tokens": self.get_system_messages_token_count(),
+            "max_tokens": self.config.max_short_term_tokens,
+            "token_usage_percent": round(
+                (
+                    self.get_current_token_count()
+                    / self.config.max_short_term_tokens
+                )
+                * 100,
+                2,
+            ),
+            "has_long_term_memory": self.long_term_memory is not None,
+            "archived_entries": self.archived_entries_count,
+            "total_tokens_processed": self.total_tokens_processed,
+        }
+
+    def save_memory_snapshot(self, file_path: str) -> None:
+        """Save current memory state to file"""
+        try:
+            data = {
+                "timestamp": datetime.now().isoformat(),
+                "config": self.config.model_dump(),
+                "system_messages": [
+                    entry.model_dump()
+                    for entry in self.system_messages
+                ],
+                "short_term_memory": [
+                    entry.model_dump()
+                    for entry in self.short_term_memory
+                ],
+                "stats": self.get_memory_stats(),
+            }
+
+            with open(file_path, "w") as f:
+                if file_path.endswith(".yaml"):
+                    yaml.dump(data, f)
+                else:
+                    json.dump(data, f, indent=2)
+
+            logger.info(f"Saved memory snapshot to {file_path}")
+
+        except Exception as e:
+            logger.error(f"Error saving memory snapshot: {e}")
+            raise
+
+    def load_memory_snapshot(self, file_path: str) -> None:
+        """Load memory state from file"""
+        try:
+            with open(file_path, "r") as f:
+                if file_path.endswith(".yaml"):
+                    data = yaml.safe_load(f)
+                else:
+                    data = json.load(f)
+
+            self.config = MemoryConfig(**data["config"])
+            self.system_messages = [
+                MemoryEntry(**entry)
+                for entry in data["system_messages"]
+            ]
+            self.short_term_memory = [
+                MemoryEntry(**entry)
+                for entry in data["short_term_memory"]
+            ]
+
+            logger.info(f"Loaded memory snapshot from {file_path}")
+
+        except Exception as e:
+            logger.error(f"Error loading memory snapshot: {e}")
+            raise
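Snapshots pick their serializer from the file extension, and `load_memory_snapshot` restores the config and both memory tiers. Continuing the manager sketch above (not part of the diff):

# Minimal sketch, not part of the diff: snapshot round-trip.
# .yaml selects YAML; any other extension selects JSON.
manager.save_memory_snapshot("memory_snapshot.yaml")
manager.load_memory_snapshot("memory_snapshot.yaml")
print(manager.get_memory_stats()["short_term_messages"])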
logger.info(f"Loaded memory snapshot from {file_path}") + + except Exception as e: + logger.error(f"Error loading memory snapshot: {e}") + raise + + def search_memories( + self, query: str, memory_type: str = "all" + ) -> List[MemoryEntry]: + """ + Search through memories of specified type + + Args: + query (str): Search query + memory_type (str): Type of memories to search ("short_term", "system", "long_term", or "all") + + Returns: + List[MemoryEntry]: Matching memory entries + """ + results = [] + + if memory_type in ["short_term", "all"]: + results.extend( + [ + entry + for entry in self.short_term_memory + if query.lower() in entry.content.lower() + ] + ) + + if memory_type in ["system", "all"]: + results.extend( + [ + entry + for entry in self.system_messages + if query.lower() in entry.content.lower() + ] + ) + + if ( + memory_type in ["long_term", "all"] + and self.long_term_memory is not None + ): + long_term_results = self.long_term_memory.query(query) + if long_term_results: + # Convert long-term results to MemoryEntry format + for result in long_term_results: + content = str(result) + metadata = MemoryMetadata( + timestamp=time.time(), + role="long_term", + agent_name="system", + session_id="long_term", + memory_type="long_term", + token_count=self.tokenizer.count_tokens( + content + ), + ) + results.append( + MemoryEntry( + content=content, metadata=metadata + ) + ) + + return results + + def get_memory_by_timeframe( + self, start_time: float, end_time: float + ) -> List[MemoryEntry]: + """Get memories within a specific timeframe""" + return [ + entry + for entry in self.short_term_memory + if start_time <= entry.metadata.timestamp <= end_time + ] + + def export_memories( + self, file_path: str, format: str = "json" + ) -> None: + """Export memories to file in specified format""" + data = { + "system_messages": [ + entry.model_dump() for entry in self.system_messages + ], + "short_term_memory": [ + entry.model_dump() for entry in self.short_term_memory + ], + "stats": self.get_memory_stats(), + } + + with open(file_path, "w") as f: + if format == "yaml": + yaml.dump(data, f) + else: + json.dump(data, f, indent=2) diff --git a/swarms/structs/async_workflow.py b/swarms/structs/async_workflow.py index 02ebe4df..f0b8ac1e 100644 --- a/swarms/structs/async_workflow.py +++ b/swarms/structs/async_workflow.py @@ -1,9 +1,10 @@ import asyncio -from typing import Any, Callable, List, Optional +from typing import Any, List from swarms.structs.base_workflow import BaseWorkflow from swarms.structs.agent import Agent from swarms.utils.loguru_logger import logger + class AsyncWorkflow(BaseWorkflow): def __init__( self, @@ -13,7 +14,7 @@ class AsyncWorkflow(BaseWorkflow): dashboard: bool = False, autosave: bool = False, verbose: bool = False, - **kwargs + **kwargs, ): super().__init__(agents=agents, **kwargs) self.name = name @@ -26,17 +27,25 @@ class AsyncWorkflow(BaseWorkflow): self.results = [] self.loop = None - async def _execute_agent_task(self, agent: Agent, task: str) -> Any: + async def _execute_agent_task( + self, agent: Agent, task: str + ) -> Any: """Execute a single agent task asynchronously""" try: if self.verbose: - logger.info(f"Agent {agent.agent_name} processing task: {task}") + logger.info( + f"Agent {agent.agent_name} processing task: {task}" + ) result = await agent.arun(task) if self.verbose: - logger.info(f"Agent {agent.agent_name} completed task") + logger.info( + f"Agent {agent.agent_name} completed task" + ) return result except Exception as e: - logger.error(f"Error 
diff --git a/swarms/structs/async_workflow.py b/swarms/structs/async_workflow.py
index 02ebe4df..f0b8ac1e 100644
--- a/swarms/structs/async_workflow.py
+++ b/swarms/structs/async_workflow.py
@@ -1,9 +1,10 @@
 import asyncio
-from typing import Any, Callable, List, Optional
+from typing import Any, List
 
 from swarms.structs.base_workflow import BaseWorkflow
 from swarms.structs.agent import Agent
 from swarms.utils.loguru_logger import logger
+
 
 class AsyncWorkflow(BaseWorkflow):
     def __init__(
         self,
@@ -13,7 +14,7 @@ class AsyncWorkflow(BaseWorkflow):
         dashboard: bool = False,
         autosave: bool = False,
         verbose: bool = False,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(agents=agents, **kwargs)
         self.name = name
@@ -26,17 +27,25 @@ class AsyncWorkflow(BaseWorkflow):
         self.results = []
         self.loop = None
 
-    async def _execute_agent_task(self, agent: Agent, task: str) -> Any:
+    async def _execute_agent_task(
+        self, agent: Agent, task: str
+    ) -> Any:
         """Execute a single agent task asynchronously"""
         try:
             if self.verbose:
-                logger.info(f"Agent {agent.agent_name} processing task: {task}")
+                logger.info(
+                    f"Agent {agent.agent_name} processing task: {task}"
+                )
             result = await agent.arun(task)
             if self.verbose:
-                logger.info(f"Agent {agent.agent_name} completed task")
+                logger.info(
+                    f"Agent {agent.agent_name} completed task"
+                )
             return result
         except Exception as e:
-            logger.error(f"Error in agent {agent.agent_name}: {str(e)}")
+            logger.error(
+                f"Error in agent {agent.agent_name}: {str(e)}"
+            )
             return str(e)
 
     async def run(self, task: str) -> List[Any]:
@@ -46,17 +55,22 @@ class AsyncWorkflow(BaseWorkflow):
 
         try:
             # Create tasks for all agents
-            tasks = [self._execute_agent_task(agent, task) for agent in self.agents]
-
+            tasks = [
+                self._execute_agent_task(agent, task)
+                for agent in self.agents
+            ]
+
             # Execute all tasks concurrently
-            self.results = await asyncio.gather(*tasks, return_exceptions=True)
-
+            self.results = await asyncio.gather(
+                *tasks, return_exceptions=True
+            )
+
             if self.autosave:
                 # TODO: Implement autosave logic here
                 pass
-
+
             return self.results
         except Exception as e:
             logger.error(f"Error in workflow execution: {str(e)}")
-            raise
\ No newline at end of file
+            raise
diff --git a/swarms/structs/safe_loading.py b/swarms/structs/safe_loading.py
new file mode 100644
index 00000000..ce026ce6
--- /dev/null
+++ b/swarms/structs/safe_loading.py
@@ -0,0 +1,258 @@
+import inspect
+import json
+import logging
+import os
+from datetime import datetime
+from typing import Any, Dict, Set
+from uuid import UUID
+
+logger = logging.getLogger(__name__)
+
+
+class SafeLoaderUtils:
+    """
+    Utility class for safely loading and saving object states while automatically
+    detecting and preserving class instances and complex objects.
+    """
+
+    @staticmethod
+    def is_class_instance(obj: Any) -> bool:
+        """
+        Detect if an object is a class instance (excluding built-in types).
+
+        Args:
+            obj: Object to check
+
+        Returns:
+            bool: True if object is a class instance
+        """
+        if obj is None:
+            return False
+
+        # Get the type of the object
+        obj_type = type(obj)
+
+        # Check if it's a class instance but not a built-in type
+        return (
+            hasattr(obj, "__dict__")
+            and not isinstance(obj, type)
+            and obj_type.__module__ != "builtins"
+        )
+
+    @staticmethod
+    def is_safe_type(value: Any) -> bool:
+        """
+        Check if a value is of a safe, serializable type.
+
+        Args:
+            value: Value to check
+
+        Returns:
+            bool: True if the value is safe to serialize
+        """
+        # Basic safe types
+        safe_types = (
+            type(None),
+            bool,
+            int,
+            float,
+            str,
+            datetime,
+            UUID,
+        )
+
+        if isinstance(value, safe_types):
+            return True
+
+        # Check containers
+        if isinstance(value, (list, tuple)):
+            return all(
+                SafeLoaderUtils.is_safe_type(item) for item in value
+            )
+
+        if isinstance(value, dict):
+            return all(
+                isinstance(k, str) and SafeLoaderUtils.is_safe_type(v)
+                for k, v in value.items()
+            )
+
+        # Check for common serializable types
+        try:
+            json.dumps(value)
+            return True
+        except (TypeError, OverflowError, ValueError):
+            return False
+
+    @staticmethod
+    def get_class_attributes(obj: Any) -> Set[str]:
+        """
+        Get all attributes of a class, including inherited ones.
+
+        Args:
+            obj: Object to inspect
+
+        Returns:
+            Set[str]: Set of attribute names
+        """
+        attributes = set()
+
+        # Get all attributes from class and parent classes
+        for cls in inspect.getmro(type(obj)):
+            attributes.update(cls.__dict__.keys())
+
+        # Add instance attributes
+        attributes.update(obj.__dict__.keys())
+
+        return attributes
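A quick sketch of what the predicates above accept and reject (not part of the diff):

# Minimal sketch, not part of the diff: the safety predicates in action.
from datetime import datetime
from swarms.structs.safe_loading import SafeLoaderUtils

print(SafeLoaderUtils.is_safe_type({"a": [1, 2.0, None], "b": datetime.now()}))  # True
print(SafeLoaderUtils.is_safe_type(lambda x: x))  # False: not JSON-serializable

class Tokenizer:
    def __init__(self):
        self.vocab = {}

print(SafeLoaderUtils.is_class_instance(Tokenizer()))  # True: user-defined instance
print(SafeLoaderUtils.is_class_instance("text"))       # False: builtin type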
+
+    @staticmethod
+    def create_state_dict(obj: Any) -> Dict[str, Any]:
+        """
+        Create a dictionary of safe values from an object's state.
+
+        Args:
+            obj: Object to create state dict from
+
+        Returns:
+            Dict[str, Any]: Dictionary of safe values
+        """
+        state_dict = {}
+
+        for attr_name in SafeLoaderUtils.get_class_attributes(obj):
+            # Skip private attributes
+            if attr_name.startswith("_"):
+                continue
+
+            try:
+                value = getattr(obj, attr_name, None)
+                if SafeLoaderUtils.is_safe_type(value):
+                    state_dict[attr_name] = value
+            except Exception as e:
+                logger.debug(f"Skipped attribute {attr_name}: {e}")
+
+        return state_dict
+
+    @staticmethod
+    def preserve_instances(obj: Any) -> Dict[str, Any]:
+        """
+        Automatically detect and preserve all class instances in an object.
+
+        Args:
+            obj: Object to preserve instances from
+
+        Returns:
+            Dict[str, Any]: Dictionary of preserved instances
+        """
+        preserved = {}
+
+        for attr_name in SafeLoaderUtils.get_class_attributes(obj):
+            if attr_name.startswith("_"):
+                continue
+
+            try:
+                value = getattr(obj, attr_name, None)
+                if SafeLoaderUtils.is_class_instance(value):
+                    preserved[attr_name] = value
+            except Exception as e:
+                logger.debug(f"Could not preserve {attr_name}: {e}")
+
+        return preserved
+
+
+class SafeStateManager:
+    """
+    Manages saving and loading object states while automatically handling
+    class instances and complex objects.
+    """
+
+    @staticmethod
+    def save_state(obj: Any, file_path: str) -> None:
+        """
+        Save an object's state to a file, automatically handling complex objects.
+
+        Args:
+            obj: Object to save state from
+            file_path: Path to save state to
+        """
+        try:
+            # Create state dict with only safe values
+            state_dict = SafeLoaderUtils.create_state_dict(obj)
+
+            # Ensure directory exists (dirname is empty for bare
+            # filenames, and os.makedirs("") would raise)
+            dir_name = os.path.dirname(file_path)
+            if dir_name:
+                os.makedirs(dir_name, exist_ok=True)
+
+            # Save to file
+            with open(file_path, "w") as f:
+                json.dump(state_dict, f, indent=4, default=str)
+
+            logger.info(f"Successfully saved state to: {file_path}")
+
+        except Exception as e:
+            logger.error(f"Error saving state: {e}")
+            raise
+
+    @staticmethod
+    def load_state(obj: Any, file_path: str) -> None:
+        """
+        Load state into an object while preserving class instances.
+
+        Args:
+            obj: Object to load state into
+            file_path: Path to load state from
+        """
+        try:
+            # Verify file exists
+            if not os.path.exists(file_path):
+                raise FileNotFoundError(
+                    f"State file not found: {file_path}"
+                )
+
+            # Preserve existing instances
+            preserved = SafeLoaderUtils.preserve_instances(obj)
+
+            # Load state
+            with open(file_path, "r") as f:
+                state_dict = json.load(f)
+
+            # Set safe values
+            for key, value in state_dict.items():
+                if (
+                    not key.startswith("_")
+                    and key not in preserved
+                    and SafeLoaderUtils.is_safe_type(value)
+                ):
+                    setattr(obj, key, value)
+
+            # Restore preserved instances
+            for key, value in preserved.items():
+                setattr(obj, key, value)
+
+            logger.info(
+                f"Successfully loaded state from: {file_path}"
+            )
+
+        except Exception as e:
+            logger.error(f"Error loading state: {e}")
+            raise
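A minimal round-trip sketch of `SafeStateManager` (not part of the diff): JSON-safe attributes are written to disk, while instance-valued attributes survive a load untouched.

# Minimal sketch, not part of the diff: save/load round-trip.
from swarms.structs.safe_loading import SafeStateManager

class Engine:
    pass

class Service:
    def __init__(self):
        self.name = "svc"       # safe: serialized
        self.retries = 3        # safe: serialized
        self.engine = Engine()  # instance: preserved, never written

svc = Service()
SafeStateManager.save_state(svc, "service_state.json")
svc.retries = 99
SafeStateManager.load_state(svc, "service_state.json")
assert svc.retries == 3 and isinstance(svc.engine, Engine)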
+
+
+# # Example decorator for easy integration
+# def safe_state_methods(cls: Type) -> Type:
+#     """
+#     Class decorator to add safe state loading/saving methods to a class.
+
+#     Args:
+#         cls: Class to decorate
+
+#     Returns:
+#         Type: Decorated class
+#     """
+#     def save(self, file_path: str) -> None:
+#         SafeStateManager.save_state(self, file_path)
+
+#     def load(self, file_path: str) -> None:
+#         SafeStateManager.load_state(self, file_path)
+
+#     cls.save = save
+#     cls.load = load
+#     return cls
diff --git a/swarms/tools/base_tool.py b/swarms/tools/base_tool.py
index 09b3c506..04319db8 100644
--- a/swarms/tools/base_tool.py
+++ b/swarms/tools/base_tool.py
@@ -400,10 +400,8 @@ class BaseTool(BaseModel):
             logger.info(
                 f"Converting tool: {name} into an OpenAI certified function calling schema. Add documentation and type hints."
             )
-            tool_schema = (
-                get_openai_function_schema_from_func(
-                    tool, name=name, description=description
-                )
+            tool_schema = get_openai_function_schema_from_func(
+                tool, name=name, description=description
             )
 
             logger.info(
@@ -420,10 +418,12 @@ class BaseTool(BaseModel):
         if tool_schemas:
             combined_schema = {
                 "type": "function",
-                "functions": [schema["function"] for schema in tool_schemas]
+                "functions": [
+                    schema["function"] for schema in tool_schemas
+                ],
             }
             return json.dumps(combined_schema, indent=4)
-
+
         return None
 
     def check_func_if_have_docs(self, func: callable):
diff --git a/swarms/utils/litellm_wrapper.py b/swarms/utils/litellm_wrapper.py
index 9b2c4829..2dbdc97e 100644
--- a/swarms/utils/litellm_wrapper.py
+++ b/swarms/utils/litellm_wrapper.py
@@ -77,7 +77,7 @@ class LiteLLM:
             str: The content of the response from the model.
         """
         try:
-
+
             messages = self._prepare_messages(task)
 
             response = completion(
@@ -85,14 +85,15 @@ class LiteLLM:
                 messages=messages,
                 stream=self.stream,
                 temperature=self.temperature,
-                # max_completion_tokens=self.max_tokens,
                 max_tokens=self.max_tokens,
                 *args,
                 **kwargs,
             )
+
             content = response.choices[
                 0
             ].message.content  # Accessing the content
+
             return content
         except Exception as error:
             print(error)
diff --git a/swarms/utils/loguru_logger.py b/swarms/utils/loguru_logger.py
index 4b54fb59..af5c7239 100644
--- a/swarms/utils/loguru_logger.py
+++ b/swarms/utils/loguru_logger.py
@@ -34,4 +34,4 @@ def initialize_logger(log_folder: str = "logs"):
         retention="10 days",
         # compression="zip",
     )
-    return logger
\ No newline at end of file
+    return logger
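For reviewers: after the `base_tool.py` change above, all tool schemas are merged into a single function-typed object. Roughly this shape, inferred from the list comprehension in that hunk (hedged, not a documented format):

# Hedged sketch, not part of the diff: approximate combined-schema shape.
combined_schema = {
    "type": "function",
    "functions": [
        # one OpenAI function schema per registered tool, e.g.
        # {"name": "sample_tool", "description": "...", "parameters": {...}},
    ],
}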
entry") + assert ( + "Test memory entry" + in agent.short_memory.return_history_as_string() + ) + + # Test memory query + agent.memory_query("Test query") + + # Test token counting + tokens = agent.check_available_tokens() + assert isinstance(tokens, int), "Token count should be an integer" + + print("✓ Memory management test passed") + + +def test_agent_output_formats(): + """Test all available output formats""" + print("\nTesting all output formats...") + + model = OpenAIChat(model_name="gpt-4o") + test_task = "Say hello!" + + output_types = { + "str": str, + "string": str, + "list": str, # JSON string containing list + "json": str, # JSON string + "dict": dict, + "yaml": str, + } + + for output_type, expected_type in output_types.items(): + agent = Agent( + agent_name=f"{output_type.capitalize()}-Output-Agent", + llm=model, + max_loops=1, + output_type=output_type, + ) + + response = agent.run(test_task) + assert ( + response is not None + ), f"{output_type} output should not be None" + + if output_type == "yaml": + # Verify YAML can be parsed + try: + yaml.safe_load(response) + print(f"✓ {output_type} output valid") + except yaml.YAMLError: + assert False, f"Invalid YAML output for {output_type}" + elif output_type in ["json", "list"]: + # Verify JSON can be parsed + try: + json.loads(response) + print(f"✓ {output_type} output valid") + except json.JSONDecodeError: + assert False, f"Invalid JSON output for {output_type}" + + print("✓ Output formats test passed") + + +def test_agent_state_management(): + """Test comprehensive state management functionality""" + print("\nTesting state management...") + + model = OpenAIChat(model_name="gpt-4o") + + # Create temporary directory for test files + with tempfile.TemporaryDirectory() as temp_dir: + state_path = os.path.join(temp_dir, "agent_state.json") + + # Create agent with initial state + agent1 = Agent( + agent_name="State-Test-Agent", + llm=model, + max_loops=1, + saved_state_path=state_path, + ) + + # Add some data to the agent + agent1.run("Remember this: Test message 1") + agent1.add_memory("Test message 2") + + # Save state + agent1.save() + assert os.path.exists(state_path), "State file not created" + + # Create new agent and load state + agent2 = Agent( + agent_name="State-Test-Agent", llm=model, max_loops=1 + ) + agent2.load(state_path) + + # Verify state loaded correctly + history2 = agent2.short_memory.return_history_as_string() + assert ( + "Test message 1" in history2 + ), "State not loaded correctly" + assert ( + "Test message 2" in history2 + ), "Memory not loaded correctly" + + # Test autosave functionality + agent3 = Agent( + agent_name="Autosave-Test-Agent", + llm=model, + max_loops=1, + saved_state_path=os.path.join( + temp_dir, "autosave_state.json" + ), + autosave=True, + ) + + agent3.run("Test autosave") + time.sleep(2) # Wait for autosave + assert os.path.exists( + os.path.join(temp_dir, "autosave_state.json") + ), "Autosave file not created" + + print("✓ State management test passed") + + +def test_agent_tools_and_execution(): + """Test agent tool handling and execution""" + print("\nTesting tools and execution...") + + def sample_tool(x: int, y: int) -> int: + """Sample tool that adds two numbers""" + return x + y + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Tools-Test-Agent", + llm=model, + max_loops=1, + tools=[sample_tool], + ) + + # Test adding tools + agent.add_tool(lambda x: x * 2) + assert len(agent.tools) == 2, "Tool not added correctly" + + # Test removing tools + 
agent.remove_tool(sample_tool) + assert len(agent.tools) == 1, "Tool not removed correctly" + + # Test tool execution + response = agent.run("Calculate 2 + 2 using the sample tool") + assert response is not None, "Tool execution failed" + + print("✓ Tools and execution test passed") + + +def test_agent_concurrent_execution(): + """Test agent concurrent execution capabilities""" + print("\nTesting concurrent execution...") + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Concurrent-Test-Agent", llm=model, max_loops=1 + ) + + # Test bulk run + tasks = [ + {"task": "Count to 3"}, + {"task": "Say hello"}, + {"task": "Tell a short joke"}, + ] + + responses = agent.bulk_run(tasks) + assert len(responses) == len(tasks), "Not all tasks completed" + assert all( + response is not None for response in responses + ), "Some tasks failed" + + # Test concurrent tasks + concurrent_responses = agent.run_concurrent_tasks( + ["Task 1", "Task 2", "Task 3"] + ) + assert ( + len(concurrent_responses) == 3 + ), "Not all concurrent tasks completed" + + print("✓ Concurrent execution test passed") + + +def test_agent_error_handling(): + """Test agent error handling and recovery""" + print("\nTesting error handling...") + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Error-Test-Agent", + llm=model, + max_loops=1, + retry_attempts=3, + retry_interval=1, + ) + + # Test invalid tool execution + try: + agent.parse_and_execute_tools("invalid_json") + print("✓ Invalid tool execution handled") + except Exception: + assert True, "Expected error caught" + + # Test recovery after error + response = agent.run("Continue after error") + assert response is not None, "Agent failed to recover after error" + + print("✓ Error handling test passed") + + +def test_agent_configuration(): + """Test agent configuration and parameters""" + print("\nTesting agent configuration...") + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Config-Test-Agent", + llm=model, + max_loops=1, + temperature=0.7, + max_tokens=4000, + context_length=8192, + ) + + # Test configuration methods + agent.update_system_prompt("New system prompt") + agent.update_max_loops(2) + agent.update_loop_interval(2) + + # Verify updates + assert agent.max_loops == 2, "Max loops not updated" + assert agent.loop_interval == 2, "Loop interval not updated" + + # Test configuration export + config_dict = agent.to_dict() + assert isinstance( + config_dict, dict + ), "Configuration export failed" + + # Test YAML export + yaml_config = agent.to_yaml() + assert isinstance(yaml_config, str), "YAML export failed" + + print("✓ Configuration test passed") + + +def test_agent_with_stopping_condition(): + """Test agent with custom stopping condition""" + print("\nTesting agent with stopping condition...") + + def custom_stopping_condition(response: str) -> bool: + return "STOP" in response.upper() + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Stopping-Condition-Agent", + llm=model, + max_loops=5, + stopping_condition=custom_stopping_condition, + ) + + response = agent.run("Count up until you see the word STOP") + assert response is not None, "Stopping condition test failed" + print("✓ Stopping condition test passed") + + +def test_agent_with_retry_mechanism(): + """Test agent retry mechanism""" + print("\nTesting agent retry mechanism...") + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Retry-Test-Agent", + llm=model, + max_loops=1, + retry_attempts=3, + 
retry_interval=1, + ) + + response = agent.run("Tell me a joke.") + assert response is not None, "Retry mechanism test failed" + print("✓ Retry mechanism test passed") + + +def test_bulk_and_filtered_operations(): + """Test bulk operations and response filtering""" + print("\nTesting bulk and filtered operations...") + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Bulk-Filter-Test-Agent", llm=model, max_loops=1 + ) + + # Test bulk run + bulk_tasks = [ + {"task": "What is 2+2?"}, + {"task": "Name a color"}, + {"task": "Count to 3"}, + ] + bulk_responses = agent.bulk_run(bulk_tasks) + assert len(bulk_responses) == len( + bulk_tasks + ), "Bulk run should return same number of responses as tasks" + + # Test response filtering + agent.add_response_filter("color") + filtered_response = agent.filtered_run( + "What is your favorite color?" + ) + assert ( + "[FILTERED]" in filtered_response + ), "Response filter not applied" + + print("✓ Bulk and filtered operations test passed") + + +async def test_async_operations(): + """Test asynchronous operations""" + print("\nTesting async operations...") + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Async-Test-Agent", llm=model, max_loops=1 + ) + + # Test single async run + response = await agent.arun("What is 1+1?") + assert response is not None, "Async run failed" + + # Test concurrent async runs + tasks = ["Task 1", "Task 2", "Task 3"] + responses = await asyncio.gather( + *[agent.arun(task) for task in tasks] + ) + assert len(responses) == len( + tasks + ), "Not all async tasks completed" + + print("✓ Async operations test passed") + + +def test_memory_and_state_persistence(): + """Test memory management and state persistence""" + print("\nTesting memory and state persistence...") + + with tempfile.TemporaryDirectory() as temp_dir: + state_path = os.path.join(temp_dir, "test_state.json") + + # Create agent with memory configuration + model = OpenAIChat(model_name="gpt-4o") + agent1 = Agent( + agent_name="Memory-State-Test-Agent", + llm=model, + max_loops=1, + saved_state_path=state_path, + context_length=8192, + autosave=True, + ) + + # Test memory operations + agent1.add_memory("Important fact: The sky is blue") + agent1.memory_query("What color is the sky?") + + # Save state + agent1.save() + + # Create new agent and load state + agent2 = Agent( + agent_name="Memory-State-Test-Agent", + llm=model, + max_loops=1, + ) + agent2.load(state_path) + + # Verify memory persistence + memory_content = ( + agent2.short_memory.return_history_as_string() + ) + assert ( + "sky is blue" in memory_content + ), "Memory not properly persisted" + + print("✓ Memory and state persistence test passed") + + +def test_sentiment_and_evaluation(): + """Test sentiment analysis and response evaluation""" + print("\nTesting sentiment analysis and evaluation...") + + def mock_sentiment_analyzer(text): + """Mock sentiment analyzer that returns a score between 0 and 1""" + return 0.7 if "positive" in text.lower() else 0.3 + + def mock_evaluator(response): + """Mock evaluator that checks response quality""" + return "GOOD" if len(response) > 10 else "BAD" + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Sentiment-Eval-Test-Agent", + llm=model, + max_loops=1, + sentiment_analyzer=mock_sentiment_analyzer, + sentiment_threshold=0.5, + evaluator=mock_evaluator, + ) + + # Test sentiment analysis + agent.run("Generate a positive message") + + # Test evaluation + agent.run("Generate a detailed response") + + 
print("✓ Sentiment and evaluation test passed") + + +def test_tool_management(): + """Test tool management functionality""" + print("\nTesting tool management...") + + def tool1(x: int) -> int: + """Sample tool 1""" + return x * 2 + + def tool2(x: int) -> int: + """Sample tool 2""" + return x + 2 + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Tool-Test-Agent", + llm=model, + max_loops=1, + tools=[tool1], + ) + + # Test adding tools + agent.add_tool(tool2) + assert len(agent.tools) == 2, "Tool not added correctly" + + # Test removing tools + agent.remove_tool(tool1) + assert len(agent.tools) == 1, "Tool not removed correctly" + + # Test adding multiple tools + agent.add_tools([tool1, tool2]) + assert len(agent.tools) == 3, "Multiple tools not added correctly" + + print("✓ Tool management test passed") + + +def test_system_prompt_and_configuration(): + """Test system prompt and configuration updates""" + print("\nTesting system prompt and configuration...") + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Config-Test-Agent", llm=model, max_loops=1 + ) + + # Test updating system prompt + new_prompt = "You are a helpful assistant." + agent.update_system_prompt(new_prompt) + assert ( + agent.system_prompt == new_prompt + ), "System prompt not updated" + + # Test configuration updates + agent.update_max_loops(5) + assert agent.max_loops == 5, "Max loops not updated" + + agent.update_loop_interval(2) + assert agent.loop_interval == 2, "Loop interval not updated" + + # Test configuration export + config_dict = agent.to_dict() + assert isinstance( + config_dict, dict + ), "Configuration export failed" + + print("✓ System prompt and configuration test passed") + + +def test_agent_with_dynamic_temperature(): + """Test agent with dynamic temperature""" + print("\nTesting agent with dynamic temperature...") + + model = OpenAIChat(model_name="gpt-4o") + agent = Agent( + agent_name="Dynamic-Temp-Agent", + llm=model, + max_loops=2, + dynamic_temperature_enabled=True, + ) + + response = agent.run("Generate a creative story.") + assert response is not None, "Dynamic temperature test failed" + print("✓ Dynamic temperature test passed") + + +def run_all_tests(): + """Run all test functions""" + print("Starting Extended Agent functional tests...\n") + + test_functions = [ + test_basic_agent_functionality, + test_memory_management, + test_agent_output_formats, + test_agent_state_management, + test_agent_tools_and_execution, + test_agent_concurrent_execution, + test_agent_error_handling, + test_agent_configuration, + test_agent_with_stopping_condition, + test_agent_with_retry_mechanism, + test_agent_with_dynamic_temperature, + test_bulk_and_filtered_operations, + test_memory_and_state_persistence, + test_sentiment_and_evaluation, + test_tool_management, + test_system_prompt_and_configuration, + ] + + # Run synchronous tests + total_tests = len(test_functions) + 1 # +1 for async test + passed_tests = 0 + + for test in test_functions: + try: + test() + passed_tests += 1 + except Exception as e: + print(f"✗ Test {test.__name__} failed: {str(e)}") + + # Run async test + try: + asyncio.run(test_async_operations()) + passed_tests += 1 + except Exception as e: + print(f"✗ Async operations test failed: {str(e)}") + + print("\nExtended Test Summary:") + print(f"Total Tests: {total_tests}") + print(f"Passed: {passed_tests}") + print(f"Failed: {total_tests - passed_tests}") + print(f"Success Rate: {(passed_tests/total_tests)*100:.2f}%") + + +if __name__ == "__main__": + 
run_all_tests()
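The suite above is plain assert-based scripting rather than pytest, so it runs directly; every test constructs `OpenAIChat(model_name="gpt-4o")`, so an OpenAI key is presumably required in the environment. A single test can also be exercised in isolation (not part of the diff):

# Minimal sketch, not part of the diff: running one test in isolation.
# Assumes OPENAI_API_KEY is set, since each test builds an OpenAIChat model.
import os

assert os.getenv("OPENAI_API_KEY"), "set OPENAI_API_KEY first"

from test_agent_features import test_basic_agent_functionality

test_basic_agent_functionality()  # prints "✓ ..." on success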