diff --git a/examples/structs/swarms/experimental/sap.py b/examples/structs/swarms/experimental/sap.py index 4fca469b..d8a3b188 100644 --- a/examples/structs/swarms/experimental/sap.py +++ b/examples/structs/swarms/experimental/sap.py @@ -5,6 +5,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional import chromadb +import psutil from dotenv import load_dotenv from loguru import logger from pydantic import BaseModel, Field @@ -20,307 +21,121 @@ load_dotenv() # Initialize ChromaDB client chroma_client = chromadb.Client() -# Create a ChromaDB collection to store tasks, responses, and all swarm activity -swarm_collection = chroma_client.create_collection( - name="swarm_activity" -) +# Collection for swarm activity (tasks, responses, messages) +swarm_activity = chroma_client.create_collection(name="swarm_activity") +# Collection for agent capabilities +agent_capabilities = chroma_client.create_collection(name="agent_capabilities") -class InteractionLog(BaseModel): - """ - Pydantic model to log all interactions between agents, tasks, and responses. - """ - interaction_id: str = Field( - default_factory=lambda: str(uuid.uuid4()), - description="Unique ID for the interaction.", - ) +class Message(BaseModel): + message_id: str = Field(default_factory=lambda: str(uuid.uuid4())) agent_name: str - task: str + message_type: str # e.g., "task", "request", "response" + content: Any timestamp: datetime = Field(default_factory=datetime.utcnow) - response: Optional[Dict[str, Any]] = None - status: str = Field( - description="The status of the interaction, e.g., 'completed', 'failed'." - ) - neighbors: Optional[List[str]] = ( - None # Names of neighboring agents involved - ) - conversation_id: Optional[str] = Field( - default_factory=lambda: str(uuid.uuid4()), - description="Unique ID for the conversation history.", - ) - - -class AgentHealthStatus(BaseModel): - """ - Pydantic model to log and monitor agent health. - """ + conversation_id: Optional[str] = None + +class AgentHealth(BaseModel): agent_name: str timestamp: datetime = Field(default_factory=datetime.utcnow) - status: str = Field( - default="available", - description="Agent health status, e.g., 'available', 'busy', 'failed'.", - ) - active_tasks: int = Field( - 0, - description="Number of active tasks assigned to this agent.", - ) - load: float = Field( - 0.0, - description="Current load on the agent (CPU or memory usage).", - ) + status: str = "available" # available, busy, failed + active_tasks: int = 0 + system_load: float = 0.0 # Placeholder for actual system load class Swarm: - """ - A scalable swarm architecture where agents can communicate by posting and querying all activities to ChromaDB. - Every input task, response, and action by the agents is logged to the vector database for persistent tracking. - - Attributes: - agents (List[Agent]): A list of initialized agents. - chroma_client (chroma.Client): An instance of the ChromaDB client for agent-to-agent communication. - api_key (str): The OpenAI API key. - health_statuses (Dict[str, AgentHealthStatus]): A dictionary to monitor agent health statuses. - """ - - def __init__( - self, - agents: List[Agent], - chroma_client: chromadb.Client, - api_key: str, - ) -> None: - """ - Initializes the swarm with agents and a ChromaDB client for vector storage and communication. - - Args: - agents (List[Agent]): A list of initialized agents. - chroma_client (chroma.Client): The ChromaDB client for handling vector embeddings. - api_key (str): The OpenAI API key. - """ + def __init__(self, agents: List[Agent], chroma_client: chromadb.Client, api_key: str): self.agents = agents self.chroma_client = chroma_client self.api_key = api_key - self.health_statuses: Dict[str, AgentHealthStatus] = { - agent.agent_name: AgentHealthStatus( - agent_name=agent.agent_name - ) - for agent in agents - } + self.health: Dict[str, AgentHealth] = {} + self.register_agents() logger.info(f"Swarm initialized with {len(agents)} agents.") - def _log_to_db( - self, data: Dict[str, Any], description: str - ) -> None: - """ - Logs a dictionary of data into the ChromaDB collection as a new entry. - - Args: - data (Dict[str, Any]): The data to log in the database (task, response, etc.). - description (str): Description of the action (e.g., 'task', 'response'). - """ - logger.info(f"Logging {description} to the database: {data}") - swarm_collection.add( - documents=[str(data)], - ids=[str(uuid.uuid4())], # Unique ID for each entry - metadatas=[ - { - "description": description, - "timestamp": datetime.utcnow().isoformat(), - } - ], - ) - logger.info( - f"{description.capitalize()} logged successfully." - ) - - async def _find_most_relevant_agent( - self, task: str - ) -> Optional[Agent]: - """ - Finds the agent whose system prompt is most relevant to the given task by querying ChromaDB. - If no relevant agents are found, return None and log a message. - - Args: - task (str): The task for which to find the most relevant agent. - - Returns: - Optional[Agent]: The most relevant agent for the task, or None if no relevant agent is found. - """ - logger.info( - f"Searching for the most relevant agent for the task: {task}" - ) - - # Query ChromaDB collection for nearest neighbor to the task - result = swarm_collection.query( - query_texts=[task], n_results=4 - ) - - # Check if the query result contains any data - if not result["ids"] or not result["ids"][0]: - logger.error( - "No relevant agents found for the given task." + def register_agents(self): + for agent in self.agents: + self.health[agent.agent_name] = AgentHealth(agent_name=agent.agent_name) + agent_capabilities.add( + documents=[agent.system_prompt], + ids=[agent.agent_name], + metadatas=[{"agent_name": agent.agent_name}], ) - return None # No agent found, return None - - # Extract the agent ID from the result and find the corresponding agent - agent_id = result["ids"][0][0] - most_relevant_agent = next( - ( - agent - for agent in self.agents - if agent.agent_name == agent_id - ), - None, - ) - if most_relevant_agent: - logger.info( - f"Most relevant agent for task '{task}' is {most_relevant_agent.agent_name}." + async def find_best_agent(self, task: str) -> Optional[Agent]: + results = agent_capabilities.query(query_texts=[task], n_results=1) + if results["ids"] and results["ids"][0]: + agent_name = results["ids"][0][0] + agent = next((a for a in self.agents if a.agent_name == agent_name), None) + if agent: + logger.info(f"Best agent for task '{task}' is {agent.agent_name}.") + return agent + logger.warning(f"No suitable agent found for task: {task}") + return None + + def post_message(self, message: Message): + swarm_activity.add( + documents=[message.model_dump_json()], + ids=[message.message_id], + metadatas=[message.model_dump()], # Store metadata for querying + ) + + def query_messages(self, query: str, message_type: Optional[str] = None, n_results: int = 5) -> List[Message]: + filter_query = {} + if message_type: + filter_query = {"message_type": message_type} + results = swarm_activity.query(query_texts=[query], n_results=n_results, where=filter_query) + + messages = [] + if results["documents"]: + for doc in results['documents'][0]: # Because the query returns a list of lists of documents + try: + messages.append(Message.model_parse_raw(doc)) + except Exception as e: + logger.error(f"Error parsing message document: {e}") + + return messages + + + async def run_task(self, task: str, conversation_id: Optional[str] = None) -> Optional[Any]: + agent = await self.find_best_agent(task) + if not agent: + return None + + self.update_agent_health(agent, "busy") + + self.post_message(Message(agent_name="Swarm", message_type="task", content=task, conversation_id=conversation_id)) + + try: + result = await agent.run(task) + self.post_message( + Message( + agent_name=agent.agent_name, + message_type="response", + content=result, + conversation_id=conversation_id, + ) ) - else: - logger.error("No matching agent found in the agent list.") - - return most_relevant_agent - - def _monitor_health(self, agent: Agent) -> None: - """ - Monitors the health status of agents and logs it to the database. + self.update_agent_health(agent, "available") + return result + except Exception as e: + self.update_agent_health(agent, "failed") + logger.error(f"Agent {agent.agent_name} failed to execute task: {e}") + return None - Args: - agent (Agent): The agent whose health is being monitored. - """ - current_status = self.health_statuses[agent.agent_name] - current_status.active_tasks += ( - 1 # Example increment for active tasks - ) - current_status.status = ( - "busy" if current_status.active_tasks > 0 else "available" - ) - current_status.load = 0.5 # Placeholder for real load data - logger.info( - f"Agent {agent.agent_name} is currently {current_status.status} with load {current_status.load}." - ) + def update_agent_health(self, agent: Agent, status: str): + health = self.health.get(agent.agent_name) + if health: + health.status = status + health.active_tasks = agent.active_tasks # Assuming agent tracks active tasks + health.system_load = psutil.cpu_percent() + health.timestamp = datetime.utcnow() + self.post_message(Message(agent_name="Swarm", message_type="health_update", content=health.model_dump())) - # Log health status to the database - self._log_to_db(current_status.dict(), "health status") + def run(self, task: str, conversation_id: Optional[str] = None) -> Any: + return asyncio.run(self.run_task(task, conversation_id)) - def post_message(self, agent: Agent, message: str) -> None: - """ - Posts a message from an agent to the shared database. - - Args: - agent (Agent): The agent posting the message. - message (str): The message to be posted. - """ - logger.info( - f"Agent {agent.agent_name} posting message: {message}" - ) - message_data = { - "agent_name": agent.agent_name, - "message": message, - "timestamp": datetime.utcnow().isoformat(), - } - self._log_to_db(message_data, "message") - - def query_messages( - self, query: str, n_results: int = 5 - ) -> List[Dict[str, Any]]: - """ - Queries the database for relevant messages. - - Args: - query (str): The query message or task for which to retrieve related messages. - n_results (int, optional): The number of relevant messages to retrieve. Defaults to 5. - - Returns: - List[Dict[str, Any]]: A list of relevant messages and their metadata. - """ - logger.info(f"Querying the database for query: {query}") - results = swarm_collection.query( - query_texts=[query], n_results=n_results - ) - logger.info( - f"Found {len(results['documents'])} relevant messages." - ) - return results - - async def run_async(self, task: str) -> None: - """ - Main entry point to find the most relevant agent, submit the task, and allow agents to - query the database to understand the task's history. Logs every task and response. - - Args: - task (str): The task to be completed. - """ - # Query past messages to understand task history - past_messages = self.query_messages(task) - logger.info( - f"Past messages related to task '{task}': {past_messages}" - ) - - # Find the most relevant agent - agent = await self._find_most_relevant_agent(task) - - if agent is None: - logger.error( - f"No relevant agent found for task: {task}. Task submission aborted." - ) - return # Exit the function if no relevant agent is found - - # Submit the task to the agent if found - await self._submit_task_to_agent(agent, task) - - async def _submit_task_to_agent( - self, agent: Agent, task: str - ) -> Dict[str, Any]: - """ - Submits a task to the specified agent and logs the result asynchronously. - - Args: - agent (Agent): The agent to which the task will be submitted. - task (str): The task to be solved. - - Returns: - Dict[str, Any]: The result of the task from the agent. - """ - if agent is None: - logger.error("No agent provided for task submission.") - return - - logger.info( - f"Submitting task '{task}' to agent {agent.agent_name}." - ) - - interaction_log = InteractionLog( - agent_name=agent.agent_name, task=task, status="started" - ) - - # Log the task as a message to the shared database - self._log_to_db( - {"task": task, "agent_name": agent.agent_name}, "task" - ) - - result = await agent.run(task) - - interaction_log.response = result - interaction_log.status = "completed" - interaction_log.timestamp = datetime.utcnow() - - logger.info( - f"Task completed by agent {agent.agent_name}. Logged interaction: {interaction_log.dict()}" - ) - - # Log the result as a message to the shared database - self._log_to_db( - {"response": result, "agent_name": agent.agent_name}, - "response", - ) - - return result - - def run(self, task: str, *args, **kwargs): - return asyncio.run(self.run_async(task)) # Initialize the OpenAI model and agents @@ -357,4 +172,10 @@ swarm = Swarm( # Execute tasks asynchronously task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?" -print(swarm.run(task)) +conversation_id = str(uuid.uuid4()) # Create a conversation ID +print(swarm.run(task, conversation_id)) + + +# Example of querying messages related to the conversation: +past_messages = swarm.query_messages(query="", message_type="response", n_results=10) +print(f"Past messages in conversation {conversation_id}: {past_messages}")