Update sap.py

pull/590/head
kirill670 3 months ago committed by GitHub
parent edc293cb6f
commit 8feebf0fb6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -5,6 +5,7 @@ from datetime import datetime
from typing import Any, Dict, List, Optional
import chromadb
import psutil
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel, Field
@ -20,307 +21,121 @@ load_dotenv()
# Initialize ChromaDB client
chroma_client = chromadb.Client()
# Create a ChromaDB collection to store tasks, responses, and all swarm activity
swarm_collection = chroma_client.create_collection(
name="swarm_activity"
)
# Collection for swarm activity (tasks, responses, messages)
swarm_activity = chroma_client.create_collection(name="swarm_activity")
# Collection for agent capabilities
agent_capabilities = chroma_client.create_collection(name="agent_capabilities")
class InteractionLog(BaseModel):
"""
Pydantic model to log all interactions between agents, tasks, and responses.
"""
interaction_id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID for the interaction.",
)
class Message(BaseModel):
message_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
agent_name: str
task: str
message_type: str # e.g., "task", "request", "response"
content: Any
timestamp: datetime = Field(default_factory=datetime.utcnow)
response: Optional[Dict[str, Any]] = None
status: str = Field(
description="The status of the interaction, e.g., 'completed', 'failed'."
)
neighbors: Optional[List[str]] = (
None # Names of neighboring agents involved
)
conversation_id: Optional[str] = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID for the conversation history.",
)
class AgentHealthStatus(BaseModel):
"""
Pydantic model to log and monitor agent health.
"""
conversation_id: Optional[str] = None
class AgentHealth(BaseModel):
agent_name: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
status: str = Field(
default="available",
description="Agent health status, e.g., 'available', 'busy', 'failed'.",
)
active_tasks: int = Field(
0,
description="Number of active tasks assigned to this agent.",
)
load: float = Field(
0.0,
description="Current load on the agent (CPU or memory usage).",
)
status: str = "available" # available, busy, failed
active_tasks: int = 0
system_load: float = 0.0 # Placeholder for actual system load
class Swarm:
"""
A scalable swarm architecture where agents can communicate by posting and querying all activities to ChromaDB.
Every input task, response, and action by the agents is logged to the vector database for persistent tracking.
Attributes:
agents (List[Agent]): A list of initialized agents.
chroma_client (chroma.Client): An instance of the ChromaDB client for agent-to-agent communication.
api_key (str): The OpenAI API key.
health_statuses (Dict[str, AgentHealthStatus]): A dictionary to monitor agent health statuses.
"""
def __init__(
self,
agents: List[Agent],
chroma_client: chromadb.Client,
api_key: str,
) -> None:
"""
Initializes the swarm with agents and a ChromaDB client for vector storage and communication.
Args:
agents (List[Agent]): A list of initialized agents.
chroma_client (chroma.Client): The ChromaDB client for handling vector embeddings.
api_key (str): The OpenAI API key.
"""
def __init__(self, agents: List[Agent], chroma_client: chromadb.Client, api_key: str):
self.agents = agents
self.chroma_client = chroma_client
self.api_key = api_key
self.health_statuses: Dict[str, AgentHealthStatus] = {
agent.agent_name: AgentHealthStatus(
agent_name=agent.agent_name
)
for agent in agents
}
self.health: Dict[str, AgentHealth] = {}
self.register_agents()
logger.info(f"Swarm initialized with {len(agents)} agents.")
def _log_to_db(
self, data: Dict[str, Any], description: str
) -> None:
"""
Logs a dictionary of data into the ChromaDB collection as a new entry.
Args:
data (Dict[str, Any]): The data to log in the database (task, response, etc.).
description (str): Description of the action (e.g., 'task', 'response').
"""
logger.info(f"Logging {description} to the database: {data}")
swarm_collection.add(
documents=[str(data)],
ids=[str(uuid.uuid4())], # Unique ID for each entry
metadatas=[
{
"description": description,
"timestamp": datetime.utcnow().isoformat(),
}
],
)
logger.info(
f"{description.capitalize()} logged successfully."
)
async def _find_most_relevant_agent(
self, task: str
) -> Optional[Agent]:
"""
Finds the agent whose system prompt is most relevant to the given task by querying ChromaDB.
If no relevant agents are found, return None and log a message.
Args:
task (str): The task for which to find the most relevant agent.
Returns:
Optional[Agent]: The most relevant agent for the task, or None if no relevant agent is found.
"""
logger.info(
f"Searching for the most relevant agent for the task: {task}"
)
# Query ChromaDB collection for nearest neighbor to the task
result = swarm_collection.query(
query_texts=[task], n_results=4
)
# Check if the query result contains any data
if not result["ids"] or not result["ids"][0]:
logger.error(
"No relevant agents found for the given task."
def register_agents(self):
for agent in self.agents:
self.health[agent.agent_name] = AgentHealth(agent_name=agent.agent_name)
agent_capabilities.add(
documents=[agent.system_prompt],
ids=[agent.agent_name],
metadatas=[{"agent_name": agent.agent_name}],
)
return None # No agent found, return None
# Extract the agent ID from the result and find the corresponding agent
agent_id = result["ids"][0][0]
most_relevant_agent = next(
(
agent
for agent in self.agents
if agent.agent_name == agent_id
),
None,
)
if most_relevant_agent:
logger.info(
f"Most relevant agent for task '{task}' is {most_relevant_agent.agent_name}."
async def find_best_agent(self, task: str) -> Optional[Agent]:
results = agent_capabilities.query(query_texts=[task], n_results=1)
if results["ids"] and results["ids"][0]:
agent_name = results["ids"][0][0]
agent = next((a for a in self.agents if a.agent_name == agent_name), None)
if agent:
logger.info(f"Best agent for task '{task}' is {agent.agent_name}.")
return agent
logger.warning(f"No suitable agent found for task: {task}")
return None
def post_message(self, message: Message):
swarm_activity.add(
documents=[message.model_dump_json()],
ids=[message.message_id],
metadatas=[message.model_dump()], # Store metadata for querying
)
def query_messages(self, query: str, message_type: Optional[str] = None, n_results: int = 5) -> List[Message]:
filter_query = {}
if message_type:
filter_query = {"message_type": message_type}
results = swarm_activity.query(query_texts=[query], n_results=n_results, where=filter_query)
messages = []
if results["documents"]:
for doc in results['documents'][0]: # Because the query returns a list of lists of documents
try:
messages.append(Message.model_parse_raw(doc))
except Exception as e:
logger.error(f"Error parsing message document: {e}")
return messages
async def run_task(self, task: str, conversation_id: Optional[str] = None) -> Optional[Any]:
agent = await self.find_best_agent(task)
if not agent:
return None
self.update_agent_health(agent, "busy")
self.post_message(Message(agent_name="Swarm", message_type="task", content=task, conversation_id=conversation_id))
try:
result = await agent.run(task)
self.post_message(
Message(
agent_name=agent.agent_name,
message_type="response",
content=result,
conversation_id=conversation_id,
)
)
else:
logger.error("No matching agent found in the agent list.")
return most_relevant_agent
def _monitor_health(self, agent: Agent) -> None:
"""
Monitors the health status of agents and logs it to the database.
self.update_agent_health(agent, "available")
return result
except Exception as e:
self.update_agent_health(agent, "failed")
logger.error(f"Agent {agent.agent_name} failed to execute task: {e}")
return None
Args:
agent (Agent): The agent whose health is being monitored.
"""
current_status = self.health_statuses[agent.agent_name]
current_status.active_tasks += (
1 # Example increment for active tasks
)
current_status.status = (
"busy" if current_status.active_tasks > 0 else "available"
)
current_status.load = 0.5 # Placeholder for real load data
logger.info(
f"Agent {agent.agent_name} is currently {current_status.status} with load {current_status.load}."
)
def update_agent_health(self, agent: Agent, status: str):
health = self.health.get(agent.agent_name)
if health:
health.status = status
health.active_tasks = agent.active_tasks # Assuming agent tracks active tasks
health.system_load = psutil.cpu_percent()
health.timestamp = datetime.utcnow()
self.post_message(Message(agent_name="Swarm", message_type="health_update", content=health.model_dump()))
# Log health status to the database
self._log_to_db(current_status.dict(), "health status")
def run(self, task: str, conversation_id: Optional[str] = None) -> Any:
return asyncio.run(self.run_task(task, conversation_id))
def post_message(self, agent: Agent, message: str) -> None:
"""
Posts a message from an agent to the shared database.
Args:
agent (Agent): The agent posting the message.
message (str): The message to be posted.
"""
logger.info(
f"Agent {agent.agent_name} posting message: {message}"
)
message_data = {
"agent_name": agent.agent_name,
"message": message,
"timestamp": datetime.utcnow().isoformat(),
}
self._log_to_db(message_data, "message")
def query_messages(
self, query: str, n_results: int = 5
) -> List[Dict[str, Any]]:
"""
Queries the database for relevant messages.
Args:
query (str): The query message or task for which to retrieve related messages.
n_results (int, optional): The number of relevant messages to retrieve. Defaults to 5.
Returns:
List[Dict[str, Any]]: A list of relevant messages and their metadata.
"""
logger.info(f"Querying the database for query: {query}")
results = swarm_collection.query(
query_texts=[query], n_results=n_results
)
logger.info(
f"Found {len(results['documents'])} relevant messages."
)
return results
async def run_async(self, task: str) -> None:
"""
Main entry point to find the most relevant agent, submit the task, and allow agents to
query the database to understand the task's history. Logs every task and response.
Args:
task (str): The task to be completed.
"""
# Query past messages to understand task history
past_messages = self.query_messages(task)
logger.info(
f"Past messages related to task '{task}': {past_messages}"
)
# Find the most relevant agent
agent = await self._find_most_relevant_agent(task)
if agent is None:
logger.error(
f"No relevant agent found for task: {task}. Task submission aborted."
)
return # Exit the function if no relevant agent is found
# Submit the task to the agent if found
await self._submit_task_to_agent(agent, task)
async def _submit_task_to_agent(
self, agent: Agent, task: str
) -> Dict[str, Any]:
"""
Submits a task to the specified agent and logs the result asynchronously.
Args:
agent (Agent): The agent to which the task will be submitted.
task (str): The task to be solved.
Returns:
Dict[str, Any]: The result of the task from the agent.
"""
if agent is None:
logger.error("No agent provided for task submission.")
return
logger.info(
f"Submitting task '{task}' to agent {agent.agent_name}."
)
interaction_log = InteractionLog(
agent_name=agent.agent_name, task=task, status="started"
)
# Log the task as a message to the shared database
self._log_to_db(
{"task": task, "agent_name": agent.agent_name}, "task"
)
result = await agent.run(task)
interaction_log.response = result
interaction_log.status = "completed"
interaction_log.timestamp = datetime.utcnow()
logger.info(
f"Task completed by agent {agent.agent_name}. Logged interaction: {interaction_log.dict()}"
)
# Log the result as a message to the shared database
self._log_to_db(
{"response": result, "agent_name": agent.agent_name},
"response",
)
return result
def run(self, task: str, *args, **kwargs):
return asyncio.run(self.run_async(task))
# Initialize the OpenAI model and agents
@ -357,4 +172,10 @@ swarm = Swarm(
# Execute tasks asynchronously
task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?"
print(swarm.run(task))
conversation_id = str(uuid.uuid4()) # Create a conversation ID
print(swarm.run(task, conversation_id))
# Example of querying messages related to the conversation:
past_messages = swarm.query_messages(query="", message_type="response", n_results=10)
print(f"Past messages in conversation {conversation_id}: {past_messages}")

Loading…
Cancel
Save