Update sap.py

pull/590/head
kirill670 4 months ago committed by GitHub
parent edc293cb6f
commit 8feebf0fb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -5,6 +5,7 @@ from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import chromadb import chromadb
import psutil
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger from loguru import logger
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -20,307 +21,121 @@ load_dotenv()
# Initialize ChromaDB client # Initialize ChromaDB client
chroma_client = chromadb.Client() chroma_client = chromadb.Client()
# Create a ChromaDB collection to store tasks, responses, and all swarm activity # Collection for swarm activity (tasks, responses, messages)
swarm_collection = chroma_client.create_collection( swarm_activity = chroma_client.create_collection(name="swarm_activity")
name="swarm_activity"
)
# Collection for agent capabilities
agent_capabilities = chroma_client.create_collection(name="agent_capabilities")
class InteractionLog(BaseModel):
"""
Pydantic model to log all interactions between agents, tasks, and responses.
"""
interaction_id: str = Field( class Message(BaseModel):
default_factory=lambda: str(uuid.uuid4()), message_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
description="Unique ID for the interaction.",
)
agent_name: str agent_name: str
task: str message_type: str # e.g., "task", "request", "response"
content: Any
timestamp: datetime = Field(default_factory=datetime.utcnow) timestamp: datetime = Field(default_factory=datetime.utcnow)
response: Optional[Dict[str, Any]] = None conversation_id: Optional[str] = None
status: str = Field(
description="The status of the interaction, e.g., 'completed', 'failed'."
)
neighbors: Optional[List[str]] = (
None # Names of neighboring agents involved
)
conversation_id: Optional[str] = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID for the conversation history.",
)
class AgentHealthStatus(BaseModel):
"""
Pydantic model to log and monitor agent health.
"""
class AgentHealth(BaseModel):
agent_name: str agent_name: str
timestamp: datetime = Field(default_factory=datetime.utcnow) timestamp: datetime = Field(default_factory=datetime.utcnow)
status: str = Field( status: str = "available" # available, busy, failed
default="available", active_tasks: int = 0
description="Agent health status, e.g., 'available', 'busy', 'failed'.", system_load: float = 0.0 # Placeholder for actual system load
)
active_tasks: int = Field(
0,
description="Number of active tasks assigned to this agent.",
)
load: float = Field(
0.0,
description="Current load on the agent (CPU or memory usage).",
)
class Swarm: class Swarm:
""" def __init__(self, agents: List[Agent], chroma_client: chromadb.Client, api_key: str):
A scalable swarm architecture where agents can communicate by posting and querying all activities to ChromaDB.
Every input task, response, and action by the agents is logged to the vector database for persistent tracking.
Attributes:
agents (List[Agent]): A list of initialized agents.
chroma_client (chroma.Client): An instance of the ChromaDB client for agent-to-agent communication.
api_key (str): The OpenAI API key.
health_statuses (Dict[str, AgentHealthStatus]): A dictionary to monitor agent health statuses.
"""
def __init__(
self,
agents: List[Agent],
chroma_client: chromadb.Client,
api_key: str,
) -> None:
"""
Initializes the swarm with agents and a ChromaDB client for vector storage and communication.
Args:
agents (List[Agent]): A list of initialized agents.
chroma_client (chroma.Client): The ChromaDB client for handling vector embeddings.
api_key (str): The OpenAI API key.
"""
self.agents = agents self.agents = agents
self.chroma_client = chroma_client self.chroma_client = chroma_client
self.api_key = api_key self.api_key = api_key
self.health_statuses: Dict[str, AgentHealthStatus] = { self.health: Dict[str, AgentHealth] = {}
agent.agent_name: AgentHealthStatus( self.register_agents()
agent_name=agent.agent_name
)
for agent in agents
}
logger.info(f"Swarm initialized with {len(agents)} agents.") logger.info(f"Swarm initialized with {len(agents)} agents.")
def _log_to_db( def register_agents(self):
self, data: Dict[str, Any], description: str for agent in self.agents:
) -> None: self.health[agent.agent_name] = AgentHealth(agent_name=agent.agent_name)
""" agent_capabilities.add(
Logs a dictionary of data into the ChromaDB collection as a new entry. documents=[agent.system_prompt],
ids=[agent.agent_name],
Args: metadatas=[{"agent_name": agent.agent_name}],
data (Dict[str, Any]): The data to log in the database (task, response, etc.).
description (str): Description of the action (e.g., 'task', 'response').
"""
logger.info(f"Logging {description} to the database: {data}")
swarm_collection.add(
documents=[str(data)],
ids=[str(uuid.uuid4())], # Unique ID for each entry
metadatas=[
{
"description": description,
"timestamp": datetime.utcnow().isoformat(),
}
],
)
logger.info(
f"{description.capitalize()} logged successfully."
)
async def _find_most_relevant_agent(
self, task: str
) -> Optional[Agent]:
"""
Finds the agent whose system prompt is most relevant to the given task by querying ChromaDB.
If no relevant agents are found, return None and log a message.
Args:
task (str): The task for which to find the most relevant agent.
Returns:
Optional[Agent]: The most relevant agent for the task, or None if no relevant agent is found.
"""
logger.info(
f"Searching for the most relevant agent for the task: {task}"
)
# Query ChromaDB collection for nearest neighbor to the task
result = swarm_collection.query(
query_texts=[task], n_results=4
)
# Check if the query result contains any data
if not result["ids"] or not result["ids"][0]:
logger.error(
"No relevant agents found for the given task."
) )
return None # No agent found, return None
# Extract the agent ID from the result and find the corresponding agent
agent_id = result["ids"][0][0]
most_relevant_agent = next(
(
agent
for agent in self.agents
if agent.agent_name == agent_id
),
None,
)
if most_relevant_agent: async def find_best_agent(self, task: str) -> Optional[Agent]:
logger.info( results = agent_capabilities.query(query_texts=[task], n_results=1)
f"Most relevant agent for task '{task}' is {most_relevant_agent.agent_name}." if results["ids"] and results["ids"][0]:
agent_name = results["ids"][0][0]
agent = next((a for a in self.agents if a.agent_name == agent_name), None)
if agent:
logger.info(f"Best agent for task '{task}' is {agent.agent_name}.")
return agent
logger.warning(f"No suitable agent found for task: {task}")
return None
def post_message(self, message: Message):
    """Persist a message in the ``swarm_activity`` collection.

    The complete message is stored as a JSON document keyed by its
    ``message_id``. Only flat scalar fields are written as metadata:
    ``message.model_dump()`` contains a ``datetime`` and an arbitrary
    ``content`` payload (a dict for health updates), which ChromaDB does not
    accept as metadata values. The scalar fields kept here are the ones
    usable in ``where`` filters.

    Args:
        message: The message to store.
    """
    metadata = {
        "message_id": message.message_id,
        "agent_name": message.agent_name,
        "message_type": message.message_type,
        "timestamp": message.timestamp.isoformat(),
    }
    # Leave the key out entirely when there is no conversation, rather than
    # storing a None value.
    if message.conversation_id is not None:
        metadata["conversation_id"] = message.conversation_id

    swarm_activity.add(
        documents=[message.model_dump_json()],
        ids=[message.message_id],
        metadatas=[metadata],
    )
def query_messages(self, query: str, message_type: Optional[str] = None, n_results: int = 5) -> List[Message]:
    """Return stored messages that are semantically closest to ``query``.

    Args:
        query: Text used for the similarity search in ``swarm_activity``.
        message_type: When given, only messages whose ``message_type``
            metadata equals this value are considered.
        n_results: Maximum number of messages to return.

    Returns:
        The matching messages, parsed from their stored JSON documents.
        Documents that cannot be parsed are logged and skipped.
    """
    # Pass None rather than an empty dict when there is nothing to filter on.
    where = {"message_type": message_type} if message_type else None
    results = swarm_activity.query(query_texts=[query], n_results=n_results, where=where)

    messages = []
    if results["documents"]:
        # The query returns one inner list of documents per query text.
        for doc in results["documents"][0]:
            try:
                # Documents were written with model_dump_json(), so they are
                # read back with the matching pydantic v2 JSON validator.
                messages.append(Message.model_validate_json(doc))
            except (ValueError, TypeError) as e:
                logger.error(f"Error parsing message document: {e}")
    return messages
async def run_task(self, task: str, conversation_id: Optional[str] = None) -> Optional[Any]:
agent = await self.find_best_agent(task)
if not agent:
return None
self.update_agent_health(agent, "busy")
self.post_message(Message(agent_name="Swarm", message_type="task", content=task, conversation_id=conversation_id))
try:
result = await agent.run(task)
self.post_message(
Message(
agent_name=agent.agent_name,
message_type="response",
content=result,
conversation_id=conversation_id,
)
) )
else: self.update_agent_health(agent, "available")
logger.error("No matching agent found in the agent list.") return result
except Exception as e:
return most_relevant_agent self.update_agent_health(agent, "failed")
logger.error(f"Agent {agent.agent_name} failed to execute task: {e}")
def _monitor_health(self, agent: Agent) -> None: return None
"""
Monitors the health status of agents and logs it to the database.
Args: def update_agent_health(self, agent: Agent, status: str):
agent (Agent): The agent whose health is being monitored. health = self.health.get(agent.agent_name)
""" if health:
current_status = self.health_statuses[agent.agent_name] health.status = status
current_status.active_tasks += ( health.active_tasks = agent.active_tasks # Assuming agent tracks active tasks
1 # Example increment for active tasks health.system_load = psutil.cpu_percent()
) health.timestamp = datetime.utcnow()
current_status.status = ( self.post_message(Message(agent_name="Swarm", message_type="health_update", content=health.model_dump()))
"busy" if current_status.active_tasks > 0 else "available"
)
current_status.load = 0.5 # Placeholder for real load data
logger.info(
f"Agent {agent.agent_name} is currently {current_status.status} with load {current_status.load}."
)
# Log health status to the database def run(self, task: str, conversation_id: Optional[str] = None) -> Any:
self._log_to_db(current_status.dict(), "health status") return asyncio.run(self.run_task(task, conversation_id))
def post_message(self, agent: Agent, message: str) -> None:
    """
    Record a message sent by an agent in the shared database.

    Args:
        agent (Agent): The agent that is posting.
        message (str): The text to record.
    """
    sender = agent.agent_name
    logger.info(f"Agent {sender} posting message: {message}")
    self._log_to_db(
        {
            "agent_name": sender,
            "message": message,
            "timestamp": datetime.utcnow().isoformat(),
        },
        "message",
    )
def query_messages(
    self, query: str, n_results: int = 5
) -> Dict[str, Any]:
    """
    Queries the database for relevant messages.

    Args:
        query (str): The query message or task for which to retrieve related messages.
        n_results (int, optional): The number of relevant messages to retrieve. Defaults to 5.

    Returns:
        Dict[str, Any]: The raw query result returned by the collection, unchanged.
    """
    logger.info(f"Querying the database for query: {query}")
    results = swarm_collection.query(
        query_texts=[query], n_results=n_results
    )
    # The result holds one inner list per query text, so the hit count for
    # our single query is the length of the first inner list, not of the
    # outer list (which is always 1 here).
    documents = results["documents"] or [[]]
    logger.info(f"Found {len(documents[0])} relevant messages.")
    return results
async def run_async(self, task: str) -> None:
    """
    Look up the task's history, pick the most relevant agent and hand the
    task to it.

    Args:
        task (str): The task to be completed.
    """
    # Look up earlier activity related to this task and log it
    history = self.query_messages(task)
    logger.info(f"Past messages related to task '{task}': {history}")

    # Pick an agent; submit only when one was found
    selected = await self._find_most_relevant_agent(task)
    if selected is not None:
        await self._submit_task_to_agent(selected, task)
        return

    logger.error(
        f"No relevant agent found for task: {task}. Task submission aborted."
    )
async def _submit_task_to_agent(
    self, agent: Agent, task: str
) -> Optional[Dict[str, Any]]:
    """
    Submits a task to the specified agent and logs the result asynchronously.

    Args:
        agent (Agent): The agent to which the task will be submitted.
        task (str): The task to be solved.

    Returns:
        Optional[Dict[str, Any]]: The result of the task from the agent, or
        None when no agent was provided.

    Raises:
        Exception: Whatever ``agent.run`` raises is re-raised after the
        interaction has been marked as failed and logged.
    """
    if agent is None:
        logger.error("No agent provided for task submission.")
        return None

    logger.info(
        f"Submitting task '{task}' to agent {agent.agent_name}."
    )

    interaction_log = InteractionLog(
        agent_name=agent.agent_name, task=task, status="started"
    )

    # Log the task as a message to the shared database
    self._log_to_db(
        {"task": task, "agent_name": agent.agent_name}, "task"
    )

    try:
        result = await agent.run(task)
    except Exception:
        # Record the failure before propagating so the interaction does not
        # silently remain in the "started" state.
        interaction_log.status = "failed"
        interaction_log.timestamp = datetime.utcnow()
        logger.error(
            f"Task failed on agent {agent.agent_name}. Logged interaction: {interaction_log.dict()}"
        )
        raise

    interaction_log.response = result
    interaction_log.status = "completed"
    interaction_log.timestamp = datetime.utcnow()

    logger.info(
        f"Task completed by agent {agent.agent_name}. Logged interaction: {interaction_log.dict()}"
    )

    # Log the result as a message to the shared database
    self._log_to_db(
        {"response": result, "agent_name": agent.agent_name},
        "response",
    )

    return result
def run(self, task: str, *args, **kwargs):
    """
    Synchronous entry point: drive ``run_async(task)`` to completion.

    Extra positional and keyword arguments are accepted but not used.
    """
    pending = self.run_async(task)
    return asyncio.run(pending)
# Initialize the OpenAI model and agents # Initialize the OpenAI model and agents
@ -357,4 +172,10 @@ swarm = Swarm(
# Execute tasks asynchronously # Execute tasks asynchronously
task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?" task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?"
print(swarm.run(task)) conversation_id = str(uuid.uuid4()) # Create a conversation ID
print(swarm.run(task, conversation_id))
# Example of querying messages related to the conversation:
# search with the task text and keep only the responses that were recorded
# under this run's conversation ID, so the printed list matches its label.
past_messages = [
    message
    for message in swarm.query_messages(query=task, message_type="response", n_results=10)
    if message.conversation_id == conversation_id
]
print(f"Past messages in conversation {conversation_id}: {past_messages}")

Loading…
Cancel
Save