|
|
|
@ -5,6 +5,7 @@ from datetime import datetime
|
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
import chromadb
|
|
|
|
|
import psutil
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
from loguru import logger
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
@ -20,307 +21,121 @@ load_dotenv()
|
|
|
|
|
# Initialize ChromaDB client
|
|
|
|
|
chroma_client = chromadb.Client()
|
|
|
|
|
|
|
|
|
|
# Create a ChromaDB collection to store tasks, responses, and all swarm activity
|
|
|
|
|
swarm_collection = chroma_client.create_collection(
|
|
|
|
|
name="swarm_activity"
|
|
|
|
|
)
|
|
|
|
|
# Collection for swarm activity (tasks, responses, messages)
|
|
|
|
|
swarm_activity = chroma_client.create_collection(name="swarm_activity")
|
|
|
|
|
|
|
|
|
|
# Collection for agent capabilities
|
|
|
|
|
agent_capabilities = chroma_client.create_collection(name="agent_capabilities")
|
|
|
|
|
|
|
|
|
|
class InteractionLog(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
Pydantic model to log all interactions between agents, tasks, and responses.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
interaction_id: str = Field(
|
|
|
|
|
default_factory=lambda: str(uuid.uuid4()),
|
|
|
|
|
description="Unique ID for the interaction.",
|
|
|
|
|
)
|
|
|
|
|
class Message(BaseModel):
|
|
|
|
|
message_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
|
|
|
|
agent_name: str
|
|
|
|
|
task: str
|
|
|
|
|
message_type: str # e.g., "task", "request", "response"
|
|
|
|
|
content: Any
|
|
|
|
|
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
|
|
|
|
response: Optional[Dict[str, Any]] = None
|
|
|
|
|
status: str = Field(
|
|
|
|
|
description="The status of the interaction, e.g., 'completed', 'failed'."
|
|
|
|
|
)
|
|
|
|
|
neighbors: Optional[List[str]] = (
|
|
|
|
|
None # Names of neighboring agents involved
|
|
|
|
|
)
|
|
|
|
|
conversation_id: Optional[str] = Field(
|
|
|
|
|
default_factory=lambda: str(uuid.uuid4()),
|
|
|
|
|
description="Unique ID for the conversation history.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentHealthStatus(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
Pydantic model to log and monitor agent health.
|
|
|
|
|
"""
|
|
|
|
|
conversation_id: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentHealth(BaseModel):
|
|
|
|
|
agent_name: str
|
|
|
|
|
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
|
|
|
|
status: str = Field(
|
|
|
|
|
default="available",
|
|
|
|
|
description="Agent health status, e.g., 'available', 'busy', 'failed'.",
|
|
|
|
|
)
|
|
|
|
|
active_tasks: int = Field(
|
|
|
|
|
0,
|
|
|
|
|
description="Number of active tasks assigned to this agent.",
|
|
|
|
|
)
|
|
|
|
|
load: float = Field(
|
|
|
|
|
0.0,
|
|
|
|
|
description="Current load on the agent (CPU or memory usage).",
|
|
|
|
|
)
|
|
|
|
|
status: str = "available" # available, busy, failed
|
|
|
|
|
active_tasks: int = 0
|
|
|
|
|
system_load: float = 0.0 # Placeholder for actual system load
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Swarm:
|
|
|
|
|
"""
|
|
|
|
|
A scalable swarm architecture where agents can communicate by posting and querying all activities to ChromaDB.
|
|
|
|
|
Every input task, response, and action by the agents is logged to the vector database for persistent tracking.
|
|
|
|
|
|
|
|
|
|
Attributes:
|
|
|
|
|
agents (List[Agent]): A list of initialized agents.
|
|
|
|
|
chroma_client (chroma.Client): An instance of the ChromaDB client for agent-to-agent communication.
|
|
|
|
|
api_key (str): The OpenAI API key.
|
|
|
|
|
health_statuses (Dict[str, AgentHealthStatus]): A dictionary to monitor agent health statuses.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
agents: List[Agent],
|
|
|
|
|
chroma_client: chromadb.Client,
|
|
|
|
|
api_key: str,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Initializes the swarm with agents and a ChromaDB client for vector storage and communication.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
agents (List[Agent]): A list of initialized agents.
|
|
|
|
|
chroma_client (chroma.Client): The ChromaDB client for handling vector embeddings.
|
|
|
|
|
api_key (str): The OpenAI API key.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, agents: List[Agent], chroma_client: chromadb.Client, api_key: str):
|
|
|
|
|
self.agents = agents
|
|
|
|
|
self.chroma_client = chroma_client
|
|
|
|
|
self.api_key = api_key
|
|
|
|
|
self.health_statuses: Dict[str, AgentHealthStatus] = {
|
|
|
|
|
agent.agent_name: AgentHealthStatus(
|
|
|
|
|
agent_name=agent.agent_name
|
|
|
|
|
)
|
|
|
|
|
for agent in agents
|
|
|
|
|
}
|
|
|
|
|
self.health: Dict[str, AgentHealth] = {}
|
|
|
|
|
self.register_agents()
|
|
|
|
|
logger.info(f"Swarm initialized with {len(agents)} agents.")
|
|
|
|
|
|
|
|
|
|
def _log_to_db(
|
|
|
|
|
self, data: Dict[str, Any], description: str
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Logs a dictionary of data into the ChromaDB collection as a new entry.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
data (Dict[str, Any]): The data to log in the database (task, response, etc.).
|
|
|
|
|
description (str): Description of the action (e.g., 'task', 'response').
|
|
|
|
|
"""
|
|
|
|
|
logger.info(f"Logging {description} to the database: {data}")
|
|
|
|
|
swarm_collection.add(
|
|
|
|
|
documents=[str(data)],
|
|
|
|
|
ids=[str(uuid.uuid4())], # Unique ID for each entry
|
|
|
|
|
metadatas=[
|
|
|
|
|
{
|
|
|
|
|
"description": description,
|
|
|
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"{description.capitalize()} logged successfully."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def _find_most_relevant_agent(
|
|
|
|
|
self, task: str
|
|
|
|
|
) -> Optional[Agent]:
|
|
|
|
|
"""
|
|
|
|
|
Finds the agent whose system prompt is most relevant to the given task by querying ChromaDB.
|
|
|
|
|
If no relevant agents are found, return None and log a message.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
task (str): The task for which to find the most relevant agent.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Optional[Agent]: The most relevant agent for the task, or None if no relevant agent is found.
|
|
|
|
|
"""
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Searching for the most relevant agent for the task: {task}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Query ChromaDB collection for nearest neighbor to the task
|
|
|
|
|
result = swarm_collection.query(
|
|
|
|
|
query_texts=[task], n_results=4
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Check if the query result contains any data
|
|
|
|
|
if not result["ids"] or not result["ids"][0]:
|
|
|
|
|
logger.error(
|
|
|
|
|
"No relevant agents found for the given task."
|
|
|
|
|
def register_agents(self):
|
|
|
|
|
for agent in self.agents:
|
|
|
|
|
self.health[agent.agent_name] = AgentHealth(agent_name=agent.agent_name)
|
|
|
|
|
agent_capabilities.add(
|
|
|
|
|
documents=[agent.system_prompt],
|
|
|
|
|
ids=[agent.agent_name],
|
|
|
|
|
metadatas=[{"agent_name": agent.agent_name}],
|
|
|
|
|
)
|
|
|
|
|
return None # No agent found, return None
|
|
|
|
|
|
|
|
|
|
# Extract the agent ID from the result and find the corresponding agent
|
|
|
|
|
agent_id = result["ids"][0][0]
|
|
|
|
|
most_relevant_agent = next(
|
|
|
|
|
(
|
|
|
|
|
agent
|
|
|
|
|
for agent in self.agents
|
|
|
|
|
if agent.agent_name == agent_id
|
|
|
|
|
),
|
|
|
|
|
None,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if most_relevant_agent:
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Most relevant agent for task '{task}' is {most_relevant_agent.agent_name}."
|
|
|
|
|
async def find_best_agent(self, task: str) -> Optional[Agent]:
|
|
|
|
|
results = agent_capabilities.query(query_texts=[task], n_results=1)
|
|
|
|
|
if results["ids"] and results["ids"][0]:
|
|
|
|
|
agent_name = results["ids"][0][0]
|
|
|
|
|
agent = next((a for a in self.agents if a.agent_name == agent_name), None)
|
|
|
|
|
if agent:
|
|
|
|
|
logger.info(f"Best agent for task '{task}' is {agent.agent_name}.")
|
|
|
|
|
return agent
|
|
|
|
|
logger.warning(f"No suitable agent found for task: {task}")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def post_message(self, message: Message):
|
|
|
|
|
swarm_activity.add(
|
|
|
|
|
documents=[message.model_dump_json()],
|
|
|
|
|
ids=[message.message_id],
|
|
|
|
|
metadatas=[message.model_dump()], # Store metadata for querying
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def query_messages(self, query: str, message_type: Optional[str] = None, n_results: int = 5) -> List[Message]:
|
|
|
|
|
filter_query = {}
|
|
|
|
|
if message_type:
|
|
|
|
|
filter_query = {"message_type": message_type}
|
|
|
|
|
results = swarm_activity.query(query_texts=[query], n_results=n_results, where=filter_query)
|
|
|
|
|
|
|
|
|
|
messages = []
|
|
|
|
|
if results["documents"]:
|
|
|
|
|
for doc in results['documents'][0]: # Because the query returns a list of lists of documents
|
|
|
|
|
try:
|
|
|
|
|
messages.append(Message.model_parse_raw(doc))
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Error parsing message document: {e}")
|
|
|
|
|
|
|
|
|
|
return messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def run_task(self, task: str, conversation_id: Optional[str] = None) -> Optional[Any]:
|
|
|
|
|
agent = await self.find_best_agent(task)
|
|
|
|
|
if not agent:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
self.update_agent_health(agent, "busy")
|
|
|
|
|
|
|
|
|
|
self.post_message(Message(agent_name="Swarm", message_type="task", content=task, conversation_id=conversation_id))
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
result = await agent.run(task)
|
|
|
|
|
self.post_message(
|
|
|
|
|
Message(
|
|
|
|
|
agent_name=agent.agent_name,
|
|
|
|
|
message_type="response",
|
|
|
|
|
content=result,
|
|
|
|
|
conversation_id=conversation_id,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
logger.error("No matching agent found in the agent list.")
|
|
|
|
|
|
|
|
|
|
return most_relevant_agent
|
|
|
|
|
|
|
|
|
|
def _monitor_health(self, agent: Agent) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Monitors the health status of agents and logs it to the database.
|
|
|
|
|
self.update_agent_health(agent, "available")
|
|
|
|
|
return result
|
|
|
|
|
except Exception as e:
|
|
|
|
|
self.update_agent_health(agent, "failed")
|
|
|
|
|
logger.error(f"Agent {agent.agent_name} failed to execute task: {e}")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
agent (Agent): The agent whose health is being monitored.
|
|
|
|
|
"""
|
|
|
|
|
current_status = self.health_statuses[agent.agent_name]
|
|
|
|
|
current_status.active_tasks += (
|
|
|
|
|
1 # Example increment for active tasks
|
|
|
|
|
)
|
|
|
|
|
current_status.status = (
|
|
|
|
|
"busy" if current_status.active_tasks > 0 else "available"
|
|
|
|
|
)
|
|
|
|
|
current_status.load = 0.5 # Placeholder for real load data
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Agent {agent.agent_name} is currently {current_status.status} with load {current_status.load}."
|
|
|
|
|
)
|
|
|
|
|
def update_agent_health(self, agent: Agent, status: str):
|
|
|
|
|
health = self.health.get(agent.agent_name)
|
|
|
|
|
if health:
|
|
|
|
|
health.status = status
|
|
|
|
|
health.active_tasks = agent.active_tasks # Assuming agent tracks active tasks
|
|
|
|
|
health.system_load = psutil.cpu_percent()
|
|
|
|
|
health.timestamp = datetime.utcnow()
|
|
|
|
|
self.post_message(Message(agent_name="Swarm", message_type="health_update", content=health.model_dump()))
|
|
|
|
|
|
|
|
|
|
# Log health status to the database
|
|
|
|
|
self._log_to_db(current_status.dict(), "health status")
|
|
|
|
|
def run(self, task: str, conversation_id: Optional[str] = None) -> Any:
|
|
|
|
|
return asyncio.run(self.run_task(task, conversation_id))
|
|
|
|
|
|
|
|
|
|
def post_message(self, agent: Agent, message: str) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Posts a message from an agent to the shared database.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
agent (Agent): The agent posting the message.
|
|
|
|
|
message (str): The message to be posted.
|
|
|
|
|
"""
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Agent {agent.agent_name} posting message: {message}"
|
|
|
|
|
)
|
|
|
|
|
message_data = {
|
|
|
|
|
"agent_name": agent.agent_name,
|
|
|
|
|
"message": message,
|
|
|
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
|
|
|
}
|
|
|
|
|
self._log_to_db(message_data, "message")
|
|
|
|
|
|
|
|
|
|
def query_messages(
|
|
|
|
|
self, query: str, n_results: int = 5
|
|
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
|
|
"""
|
|
|
|
|
Queries the database for relevant messages.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
query (str): The query message or task for which to retrieve related messages.
|
|
|
|
|
n_results (int, optional): The number of relevant messages to retrieve. Defaults to 5.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List[Dict[str, Any]]: A list of relevant messages and their metadata.
|
|
|
|
|
"""
|
|
|
|
|
logger.info(f"Querying the database for query: {query}")
|
|
|
|
|
results = swarm_collection.query(
|
|
|
|
|
query_texts=[query], n_results=n_results
|
|
|
|
|
)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Found {len(results['documents'])} relevant messages."
|
|
|
|
|
)
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
async def run_async(self, task: str) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Main entry point to find the most relevant agent, submit the task, and allow agents to
|
|
|
|
|
query the database to understand the task's history. Logs every task and response.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
task (str): The task to be completed.
|
|
|
|
|
"""
|
|
|
|
|
# Query past messages to understand task history
|
|
|
|
|
past_messages = self.query_messages(task)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Past messages related to task '{task}': {past_messages}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Find the most relevant agent
|
|
|
|
|
agent = await self._find_most_relevant_agent(task)
|
|
|
|
|
|
|
|
|
|
if agent is None:
|
|
|
|
|
logger.error(
|
|
|
|
|
f"No relevant agent found for task: {task}. Task submission aborted."
|
|
|
|
|
)
|
|
|
|
|
return # Exit the function if no relevant agent is found
|
|
|
|
|
|
|
|
|
|
# Submit the task to the agent if found
|
|
|
|
|
await self._submit_task_to_agent(agent, task)
|
|
|
|
|
|
|
|
|
|
async def _submit_task_to_agent(
|
|
|
|
|
self, agent: Agent, task: str
|
|
|
|
|
) -> Dict[str, Any]:
|
|
|
|
|
"""
|
|
|
|
|
Submits a task to the specified agent and logs the result asynchronously.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
agent (Agent): The agent to which the task will be submitted.
|
|
|
|
|
task (str): The task to be solved.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Dict[str, Any]: The result of the task from the agent.
|
|
|
|
|
"""
|
|
|
|
|
if agent is None:
|
|
|
|
|
logger.error("No agent provided for task submission.")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Submitting task '{task}' to agent {agent.agent_name}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
interaction_log = InteractionLog(
|
|
|
|
|
agent_name=agent.agent_name, task=task, status="started"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Log the task as a message to the shared database
|
|
|
|
|
self._log_to_db(
|
|
|
|
|
{"task": task, "agent_name": agent.agent_name}, "task"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
result = await agent.run(task)
|
|
|
|
|
|
|
|
|
|
interaction_log.response = result
|
|
|
|
|
interaction_log.status = "completed"
|
|
|
|
|
interaction_log.timestamp = datetime.utcnow()
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Task completed by agent {agent.agent_name}. Logged interaction: {interaction_log.dict()}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Log the result as a message to the shared database
|
|
|
|
|
self._log_to_db(
|
|
|
|
|
{"response": result, "agent_name": agent.agent_name},
|
|
|
|
|
"response",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
def run(self, task: str, *args, **kwargs):
|
|
|
|
|
return asyncio.run(self.run_async(task))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Initialize the OpenAI model and agents
|
|
|
|
@ -357,4 +172,10 @@ swarm = Swarm(
|
|
|
|
|
|
|
|
|
|
# Execute tasks asynchronously
|
|
|
|
|
task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?"
|
|
|
|
|
print(swarm.run(task))
|
|
|
|
|
conversation_id = str(uuid.uuid4()) # Create a conversation ID
|
|
|
|
|
print(swarm.run(task, conversation_id))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Example of querying messages related to the conversation:
|
|
|
|
|
past_messages = swarm.query_messages(query="", message_type="response", n_results=10)
|
|
|
|
|
print(f"Past messages in conversation {conversation_id}: {past_messages}")
|
|
|
|
|