pull/590/head
Your Name 3 months ago
parent 8cf8761a8f
commit c51ef4b72d

@ -6,15 +6,19 @@ from swarms import create_agents_from_yaml
load_dotenv() load_dotenv()
# Path to your YAML file # Path to your YAML file
yaml_file = 'agents_config.yaml' yaml_file = "agents_config.yaml"
try: try:
# Create agents and run tasks (using 'both' to return agents and task results) # Create agents and run tasks (using 'both' to return agents and task results)
agents, task_results = create_agents_from_yaml(yaml_file, return_type="both") agents, task_results = create_agents_from_yaml(
yaml_file, return_type="both"
)
# Print the results of the tasks # Print the results of the tasks
for result in task_results: for result in task_results:
print(f"Agent: {result['agent_name']} | Task: {result['task']} | Output: {result.get('output', 'Error encountered')}") print(
f"Agent: {result['agent_name']} | Task: {result['task']} | Output: {result.get('output', 'Error encountered')}"
)
except Exception as e: except Exception as e:
logger.error(f"An error occurred: {e}") logger.error(f"An error occurred: {e}")

@ -166,6 +166,7 @@ nav:
- GroupChat: "swarms/structs/group_chat.md" - GroupChat: "swarms/structs/group_chat.md"
- AgentRegistry: "swarms/structs/agent_registry.md" - AgentRegistry: "swarms/structs/agent_registry.md"
- SpreadSheetSwarm: "swarms/structs/spreadsheet_swarm.md" - SpreadSheetSwarm: "swarms/structs/spreadsheet_swarm.md"
- ForestSwarm: "swarms/structs/forest_swarm.md"
- Workflows: - Workflows:
- ConcurrentWorkflow: "swarms/structs/concurrentworkflow.md" - ConcurrentWorkflow: "swarms/structs/concurrentworkflow.md"
- SequentialWorkflow: "swarms/structs/sequential_workflow.md" - SequentialWorkflow: "swarms/structs/sequential_workflow.md"

@ -0,0 +1,141 @@
# Forest Swarm
This documentation describes the **ForestSwarm** that organizes agents into trees. Each agent specializes in processing specific tasks. Trees are collections of agents, each assigned based on their relevance to a task through keyword extraction and embedding-based similarity.
The architecture allows for efficient task assignment by selecting the most relevant agent from a set of trees. Tasks are processed asynchronously, with agents selected based on task relevance, calculated by the similarity of system prompts and task keywords.
---
### Class: `TreeAgent`
`TreeAgent` represents an individual agent responsible for handling a specific task. Agents are initialized with a **system prompt** and are responsible for dynamically determining their relevance to a given task.
#### Attributes
| **Attribute** | **Type** | **Description** |
|--------------------------|------------------|---------------------------------------------------------------------------------|
| `system_prompt` | `str` | A string that defines the agent's area of expertise and task-handling capability.|
| `llm` | `callable` | The language model (LLM) used to process tasks (e.g., GPT-4). |
| `agent_name` | `str` | The name of the agent. |
| `system_prompt_embedding`| `tensor` | Embedding of the system prompt for similarity-based task matching. |
| `relevant_keywords` | `List[str]` | Keywords dynamically extracted from the system prompt to assist in task matching.|
| `distance` | `Optional[float]`| The computed distance between agents based on embedding similarity. |
#### Methods
| **Method** | **Input** | **Output** | **Description** |
|--------------------|---------------------------------|--------------------|---------------------------------------------------------------------------------|
| `calculate_distance(other_agent: TreeAgent)` | `other_agent: TreeAgent` | `float` | Calculates the cosine similarity between this agent and another agent. |
| `run_task(task: str)` | `task: str` | `Any` | Executes the task, logs the input/output, and returns the result. |
| `is_relevant_for_task(task: str, threshold: float = 0.7)` | `task: str, threshold: float` | `bool` | Checks if the agent is relevant for the task using keyword matching or embedding similarity.|
---
### Class: `Tree`
`Tree` organizes multiple agents into a hierarchical structure, where agents are sorted based on their relevance to tasks.
#### Attributes
| **Attribute** | **Type** | **Description** |
|--------------------------|------------------|---------------------------------------------------------------------------------|
| `tree_name` | `str` | The name of the tree (represents a domain of agents, e.g., "Financial Tree"). |
| `agents` | `List[TreeAgent]`| List of agents belonging to this tree. |
#### Methods
| **Method** | **Input** | **Output** | **Description** |
|--------------------|---------------------------------|--------------------|---------------------------------------------------------------------------------|
| `calculate_agent_distances()` | `None` | `None` | Calculates and assigns distances between agents based on similarity of prompts. |
| `find_relevant_agent(task: str)` | `task: str` | `Optional[TreeAgent]` | Finds the most relevant agent for a task based on keyword and embedding similarity. |
| `log_tree_execution(task: str, selected_agent: TreeAgent, result: Any)` | `task: str, selected_agent: TreeAgent, result: Any` | `None` | Logs details of the task execution by the selected agent. |
---
### Class: `ForestSwarm`
`ForestSwarm` is the main class responsible for managing multiple trees. It oversees task delegation by finding the most relevant tree and agent for a given task.
#### Attributes
| **Attribute** | **Type** | **Description** |
|--------------------------|------------------|---------------------------------------------------------------------------------|
| `trees` | `List[Tree]` | List of trees containing agents organized by domain. |
#### Methods
| **Method** | **Input** | **Output** | **Description** |
|--------------------|---------------------------------|--------------------|---------------------------------------------------------------------------------|
| `find_relevant_tree(task: str)` | `task: str` | `Optional[Tree]` | Searches across all trees to find the most relevant tree based on task requirements.|
| `run(task: str)` | `task: str` | `Any` | Executes the task by finding the most relevant agent from the relevant tree. |
## Full Code Example
---
## Example Workflow
1. **Create Agents**: Agents are initialized with varying system prompts, representing different areas of expertise (e.g., stock analysis, tax filing).
2. **Create Trees**: Agents are grouped into trees, with each tree representing a domain (e.g., "Financial Tree", "Investment Tree").
3. **Run Task**: When a task is submitted, the system traverses through all trees and finds the most relevant agent to handle the task.
4. **Task Execution**: The selected agent processes the task, and the result is returned.
```plaintext
Task: "Our company is incorporated in Delaware, how do we do our taxes for free?"
```
**Process**:
- The system searches through the `Financial Tree` and `Investment Tree`.
- The most relevant agent (likely the "Tax Filing Agent") is selected based on keyword matching and prompt similarity.
- The task is processed, and the result is logged and returned.
---
## Analysis of the Swarm Architecture
The **Swarm Architecture** leverages a hierarchical structure (forest) composed of individual trees, each containing agents specialized in specific domains. This design allows for:
- **Modular and Scalable Organization**: By separating agents into trees, it is easy to expand or contract the system by adding or removing trees or agents.
- **Task Specialization**: Each agent is specialized, which ensures that tasks are matched with the most appropriate agent based on relevance and expertise.
- **Dynamic Matching**: The architecture uses both keyword-based and embedding-based matching to assign tasks, ensuring a high level of accuracy in agent selection.
- **Logging and Accountability**: Each task execution is logged in detail, providing transparency and an audit trail of which agent handled which task and the results produced.
- **Asynchronous Task Execution**: The architecture can be adapted for asynchronous task processing, making it scalable and suitable for large-scale task handling in real-time systems.
---
## Mermaid Diagram of the Swarm Architecture
```mermaid
graph TD
A[ForestSwarm] --> B[Financial Tree]
A --> C[Investment Tree]
B --> D[Stock Analysis Agent]
B --> E[Financial Planning Agent]
B --> F[Retirement Strategy Agent]
C --> G[Tax Filing Agent]
C --> H[Investment Strategy Agent]
C --> I[ROTH IRA Agent]
subgraph Tree Agents
D[Stock Analysis Agent]
E[Financial Planning Agent]
F[Retirement Strategy Agent]
G[Tax Filing Agent]
H[Investment Strategy Agent]
I[ROTH IRA Agent]
end
```
### Explanation of the Diagram
- **ForestSwarm**: Represents the top-level structure managing multiple trees.
- **Trees**: In the example, two trees exist—**Financial Tree** and **Investment Tree**—each containing agents related to specific domains.
- **Agents**: Each agent within the tree is responsible for handling tasks in its area of expertise. Agents within a tree are organized based on their prompt similarity (distance).
---
### Summary
This **Multi-Agent Tree Structure** provides an efficient, scalable, and accurate architecture for delegating and executing tasks based on domain-specific expertise. The combination of hierarchical organization, dynamic task matching, and logging ensures reliability, performance, and transparency in task execution.

@ -0,0 +1,44 @@
from swarms.structs.tree_swarm import TreeAgent, Tree, ForestSwarm
# Example Usage:
# Create agents with varying system prompts and dynamically generated distances/keywords
agents_tree1 = [
TreeAgent(
system_prompt="Stock Analysis Agent",
agent_name="Stock Analysis Agent",
),
TreeAgent(
system_prompt="Financial Planning Agent",
agent_name="Financial Planning Agent",
),
TreeAgent(
agent_name="Retirement Strategy Agent",
system_prompt="Retirement Strategy Agent",
),
]
agents_tree2 = [
TreeAgent(
system_prompt="Tax Filing Agent",
agent_name="Tax Filing Agent",
),
TreeAgent(
system_prompt="Investment Strategy Agent",
agent_name="Investment Strategy Agent",
),
TreeAgent(
system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent"
),
]
# Create trees
tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1)
tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2)
# Create the ForestSwarm
multi_agent_structure = ForestSwarm(trees=[tree1, tree2])
# Run a task
task = "Our company is incorporated in delaware, how do we do our taxes for free?"
output = multi_agent_structure.run(task)
print(output)

360
sap.py

@ -0,0 +1,360 @@
import asyncio
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
import chromadb
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel, Field
from swarm_models import OpenAIChat
from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
load_dotenv()
# Initialize ChromaDB client
chroma_client = chromadb.Client()
# Create a ChromaDB collection to store tasks, responses, and all swarm activity
swarm_collection = chroma_client.create_collection(
name="swarm_activity"
)
class InteractionLog(BaseModel):
"""
Pydantic model to log all interactions between agents, tasks, and responses.
"""
interaction_id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID for the interaction.",
)
agent_name: str
task: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
response: Optional[Dict[str, Any]] = None
status: str = Field(
description="The status of the interaction, e.g., 'completed', 'failed'."
)
neighbors: Optional[List[str]] = (
None # Names of neighboring agents involved
)
conversation_id: Optional[str] = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID for the conversation history.",
)
class AgentHealthStatus(BaseModel):
"""
Pydantic model to log and monitor agent health.
"""
agent_name: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
status: str = Field(
default="available",
description="Agent health status, e.g., 'available', 'busy', 'failed'.",
)
active_tasks: int = Field(
0,
description="Number of active tasks assigned to this agent.",
)
load: float = Field(
0.0,
description="Current load on the agent (CPU or memory usage).",
)
class Swarm:
"""
A scalable swarm architecture where agents can communicate by posting and querying all activities to ChromaDB.
Every input task, response, and action by the agents is logged to the vector database for persistent tracking.
Attributes:
agents (List[Agent]): A list of initialized agents.
chroma_client (chroma.Client): An instance of the ChromaDB client for agent-to-agent communication.
api_key (str): The OpenAI API key.
health_statuses (Dict[str, AgentHealthStatus]): A dictionary to monitor agent health statuses.
"""
def __init__(
self,
agents: List[Agent],
chroma_client: chromadb.Client,
api_key: str,
) -> None:
"""
Initializes the swarm with agents and a ChromaDB client for vector storage and communication.
Args:
agents (List[Agent]): A list of initialized agents.
chroma_client (chroma.Client): The ChromaDB client for handling vector embeddings.
api_key (str): The OpenAI API key.
"""
self.agents = agents
self.chroma_client = chroma_client
self.api_key = api_key
self.health_statuses: Dict[str, AgentHealthStatus] = {
agent.agent_name: AgentHealthStatus(
agent_name=agent.agent_name
)
for agent in agents
}
logger.info(f"Swarm initialized with {len(agents)} agents.")
def _log_to_db(
self, data: Dict[str, Any], description: str
) -> None:
"""
Logs a dictionary of data into the ChromaDB collection as a new entry.
Args:
data (Dict[str, Any]): The data to log in the database (task, response, etc.).
description (str): Description of the action (e.g., 'task', 'response').
"""
logger.info(f"Logging {description} to the database: {data}")
swarm_collection.add(
documents=[str(data)],
ids=[str(uuid.uuid4())], # Unique ID for each entry
metadatas=[
{
"description": description,
"timestamp": datetime.utcnow().isoformat(),
}
],
)
logger.info(
f"{description.capitalize()} logged successfully."
)
async def _find_most_relevant_agent(
self, task: str
) -> Optional[Agent]:
"""
Finds the agent whose system prompt is most relevant to the given task by querying ChromaDB.
If no relevant agents are found, return None and log a message.
Args:
task (str): The task for which to find the most relevant agent.
Returns:
Optional[Agent]: The most relevant agent for the task, or None if no relevant agent is found.
"""
logger.info(
f"Searching for the most relevant agent for the task: {task}"
)
# Query ChromaDB collection for nearest neighbor to the task
result = swarm_collection.query(
query_texts=[task], n_results=4
)
# Check if the query result contains any data
if not result["ids"] or not result["ids"][0]:
logger.error(
"No relevant agents found for the given task."
)
return None # No agent found, return None
# Extract the agent ID from the result and find the corresponding agent
agent_id = result["ids"][0][0]
most_relevant_agent = next(
(
agent
for agent in self.agents
if agent.agent_name == agent_id
),
None,
)
if most_relevant_agent:
logger.info(
f"Most relevant agent for task '{task}' is {most_relevant_agent.agent_name}."
)
else:
logger.error("No matching agent found in the agent list.")
return most_relevant_agent
def _monitor_health(self, agent: Agent) -> None:
"""
Monitors the health status of agents and logs it to the database.
Args:
agent (Agent): The agent whose health is being monitored.
"""
current_status = self.health_statuses[agent.agent_name]
current_status.active_tasks += (
1 # Example increment for active tasks
)
current_status.status = (
"busy" if current_status.active_tasks > 0 else "available"
)
current_status.load = 0.5 # Placeholder for real load data
logger.info(
f"Agent {agent.agent_name} is currently {current_status.status} with load {current_status.load}."
)
# Log health status to the database
self._log_to_db(current_status.dict(), "health status")
def post_message(self, agent: Agent, message: str) -> None:
"""
Posts a message from an agent to the shared database.
Args:
agent (Agent): The agent posting the message.
message (str): The message to be posted.
"""
logger.info(
f"Agent {agent.agent_name} posting message: {message}"
)
message_data = {
"agent_name": agent.agent_name,
"message": message,
"timestamp": datetime.utcnow().isoformat(),
}
self._log_to_db(message_data, "message")
def query_messages(
self, query: str, n_results: int = 5
) -> List[Dict[str, Any]]:
"""
Queries the database for relevant messages.
Args:
query (str): The query message or task for which to retrieve related messages.
n_results (int, optional): The number of relevant messages to retrieve. Defaults to 5.
Returns:
List[Dict[str, Any]]: A list of relevant messages and their metadata.
"""
logger.info(f"Querying the database for query: {query}")
results = swarm_collection.query(
query_texts=[query], n_results=n_results
)
logger.info(
f"Found {len(results['documents'])} relevant messages."
)
return results
async def run_async(self, task: str) -> None:
"""
Main entry point to find the most relevant agent, submit the task, and allow agents to
query the database to understand the task's history. Logs every task and response.
Args:
task (str): The task to be completed.
"""
# Query past messages to understand task history
past_messages = self.query_messages(task)
logger.info(
f"Past messages related to task '{task}': {past_messages}"
)
# Find the most relevant agent
agent = await self._find_most_relevant_agent(task)
if agent is None:
logger.error(
f"No relevant agent found for task: {task}. Task submission aborted."
)
return # Exit the function if no relevant agent is found
# Submit the task to the agent if found
await self._submit_task_to_agent(agent, task)
async def _submit_task_to_agent(
self, agent: Agent, task: str
) -> Dict[str, Any]:
"""
Submits a task to the specified agent and logs the result asynchronously.
Args:
agent (Agent): The agent to which the task will be submitted.
task (str): The task to be solved.
Returns:
Dict[str, Any]: The result of the task from the agent.
"""
if agent is None:
logger.error("No agent provided for task submission.")
return
logger.info(
f"Submitting task '{task}' to agent {agent.agent_name}."
)
interaction_log = InteractionLog(
agent_name=agent.agent_name, task=task, status="started"
)
# Log the task as a message to the shared database
self._log_to_db(
{"task": task, "agent_name": agent.agent_name}, "task"
)
result = await agent.run(task)
interaction_log.response = result
interaction_log.status = "completed"
interaction_log.timestamp = datetime.utcnow()
logger.info(
f"Task completed by agent {agent.agent_name}. Logged interaction: {interaction_log.dict()}"
)
# Log the result as a message to the shared database
self._log_to_db(
{"response": result, "agent_name": agent.agent_name},
"response",
)
return result
def run(self, task: str, *args, **kwargs):
return asyncio.run(self.run_async(task))
# Initialize the OpenAI model and agents
api_key = os.getenv("OPENAI_API_KEY")
model = OpenAIChat(
openai_api_key=api_key, model_name="gpt-4o-mini", temperature=0.1
)
# Example agent creation
agent = Agent(
agent_name="Financial-Analysis-Agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
llm=model,
max_loops=1,
autosave=True,
dashboard=False,
verbose=True,
dynamic_temperature_enabled=True,
saved_state_path="finance_agent.json",
user_name="swarms_corp",
retry_attempts=1,
context_length=200000,
return_step_meta=False,
)
# Example agents list
agents_list = [agent]
# Create the swarm
swarm = Swarm(
agents=agents_list, chroma_client=chroma_client, api_key=api_key
)
# Execute tasks asynchronously
task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?"
print(swarm.run(task))

@ -11,7 +11,9 @@ from swarms.agents.stopping_conditions import (
check_success, check_success,
) )
from swarms.agents.tool_agent import ToolAgent from swarms.agents.tool_agent import ToolAgent
from swarms.agents.create_agents_from_yaml import create_agents_from_yaml from swarms.agents.create_agents_from_yaml import (
create_agents_from_yaml,
)
__all__ = [ __all__ = [

@ -1,14 +1,19 @@
import os import os
import yaml import yaml
from dotenv import load_dotenv
from loguru import logger from loguru import logger
from swarms.structs.agent import Agent
from swarm_models import OpenAIChat from swarm_models import OpenAIChat
from dotenv import load_dotenv
from swarms.structs.agent import Agent
load_dotenv() load_dotenv()
# Function to create and optionally run agents from a YAML file # Function to create and optionally run agents from a YAML file
def create_agents_from_yaml(yaml_file: str, return_type: str = "agents", *args, **kwargs): def create_agents_from_yaml(
yaml_file: str, return_type: str = "agents", *args, **kwargs
):
""" """
Create agents based on configurations defined in a YAML file. Create agents based on configurations defined in a YAML file.
If a 'task' is provided in the YAML, the agent will execute the task after creation. If a 'task' is provided in the YAML, the agent will execute the task after creation.
@ -32,13 +37,17 @@ def create_agents_from_yaml(yaml_file: str, return_type: str = "agents", *args,
# Load the YAML configuration # Load the YAML configuration
logger.info(f"Loading YAML file {yaml_file}") logger.info(f"Loading YAML file {yaml_file}")
with open(yaml_file, 'r') as file: with open(yaml_file, "r") as file:
config = yaml.safe_load(file) config = yaml.safe_load(file)
# Ensure agents key exists # Ensure agents key exists
if "agents" not in config: if "agents" not in config:
logger.error("The YAML configuration does not contain 'agents'.") logger.error(
raise ValueError("The YAML configuration does not contain 'agents'.") "The YAML configuration does not contain 'agents'."
)
raise ValueError(
"The YAML configuration does not contain 'agents'."
)
# List to store created agents and task results # List to store created agents and task results
agents = [] agents = []
@ -49,10 +58,16 @@ def create_agents_from_yaml(yaml_file: str, return_type: str = "agents", *args,
logger.info(f"Creating agent: {agent_config['agent_name']}") logger.info(f"Creating agent: {agent_config['agent_name']}")
# Get the OpenAI API key from environment or YAML config # Get the OpenAI API key from environment or YAML config
api_key = os.getenv("OPENAI_API_KEY") or agent_config["model"].get("openai_api_key") api_key = os.getenv("OPENAI_API_KEY") or agent_config[
"model"
].get("openai_api_key")
if not api_key: if not api_key:
logger.error(f"API key is missing for agent: {agent_config['agent_name']}") logger.error(
raise ValueError(f"API key is missing for agent: {agent_config['agent_name']}") f"API key is missing for agent: {agent_config['agent_name']}"
)
raise ValueError(
f"API key is missing for agent: {agent_config['agent_name']}"
)
# Create an instance of OpenAIChat model # Create an instance of OpenAIChat model
model = OpenAIChat( model = OpenAIChat(
@ -60,23 +75,38 @@ def create_agents_from_yaml(yaml_file: str, return_type: str = "agents", *args,
model_name=agent_config["model"]["model_name"], model_name=agent_config["model"]["model_name"],
temperature=agent_config["model"]["temperature"], temperature=agent_config["model"]["temperature"],
max_tokens=agent_config["model"]["max_tokens"], max_tokens=agent_config["model"]["max_tokens"],
*args, **kwargs # Pass any additional arguments to the model *args,
**kwargs, # Pass any additional arguments to the model
) )
# Ensure the system prompt is provided # Ensure the system prompt is provided
if "system_prompt" not in agent_config: if "system_prompt" not in agent_config:
logger.error(f"System prompt is missing for agent: {agent_config['agent_name']}") logger.error(
raise ValueError(f"System prompt is missing for agent: {agent_config['agent_name']}") f"System prompt is missing for agent: {agent_config['agent_name']}"
)
raise ValueError(
f"System prompt is missing for agent: {agent_config['agent_name']}"
)
# Dynamically choose the system prompt based on the agent config # Dynamically choose the system prompt based on the agent config
try: try:
system_prompt = globals().get(agent_config["system_prompt"]) system_prompt = globals().get(
agent_config["system_prompt"]
)
if not system_prompt: if not system_prompt:
logger.error(f"System prompt {agent_config['system_prompt']} not found.") logger.error(
raise ValueError(f"System prompt {agent_config['system_prompt']} not found.") f"System prompt {agent_config['system_prompt']} not found."
)
raise ValueError(
f"System prompt {agent_config['system_prompt']} not found."
)
except KeyError: except KeyError:
logger.error(f"System prompt {agent_config['system_prompt']} is not valid.") logger.error(
raise ValueError(f"System prompt {agent_config['system_prompt']} is not valid.") f"System prompt {agent_config['system_prompt']} is not valid."
)
raise ValueError(
f"System prompt {agent_config['system_prompt']} is not valid."
)
# Initialize the agent using the configuration # Initialize the agent using the configuration
agent = Agent( agent = Agent(
@ -87,38 +117,55 @@ def create_agents_from_yaml(yaml_file: str, return_type: str = "agents", *args,
autosave=agent_config.get("autosave", True), autosave=agent_config.get("autosave", True),
dashboard=agent_config.get("dashboard", False), dashboard=agent_config.get("dashboard", False),
verbose=agent_config.get("verbose", False), verbose=agent_config.get("verbose", False),
dynamic_temperature_enabled=agent_config.get("dynamic_temperature_enabled", False), dynamic_temperature_enabled=agent_config.get(
"dynamic_temperature_enabled", False
),
saved_state_path=agent_config.get("saved_state_path"), saved_state_path=agent_config.get("saved_state_path"),
user_name=agent_config.get("user_name", "default_user"), user_name=agent_config.get("user_name", "default_user"),
retry_attempts=agent_config.get("retry_attempts", 1), retry_attempts=agent_config.get("retry_attempts", 1),
context_length=agent_config.get("context_length", 100000), context_length=agent_config.get("context_length", 100000),
return_step_meta=agent_config.get("return_step_meta", False), return_step_meta=agent_config.get(
"return_step_meta", False
),
output_type=agent_config.get("output_type", "str"), output_type=agent_config.get("output_type", "str"),
*args, **kwargs # Pass any additional arguments to the agent *args,
**kwargs, # Pass any additional arguments to the agent
) )
logger.info(f"Agent {agent_config['agent_name']} created successfully.") logger.info(
f"Agent {agent_config['agent_name']} created successfully."
)
agents.append(agent) agents.append(agent)
# Check if a task is provided, and if so, run the agent # Check if a task is provided, and if so, run the agent
task = agent_config.get("task") task = agent_config.get("task")
if task: if task:
logger.info(f"Running task '{task}' with agent {agent_config['agent_name']}") logger.info(
f"Running task '{task}' with agent {agent_config['agent_name']}"
)
try: try:
output = agent.run(task) output = agent.run(task)
logger.info(f"Output for agent {agent_config['agent_name']}: {output}") logger.info(
task_results.append({ f"Output for agent {agent_config['agent_name']}: {output}"
"agent_name": agent_config["agent_name"], )
"task": task, task_results.append(
"output": output {
}) "agent_name": agent_config["agent_name"],
"task": task,
"output": output,
}
)
except Exception as e: except Exception as e:
logger.error(f"Error running task for agent {agent_config['agent_name']}: {e}") logger.error(
task_results.append({ f"Error running task for agent {agent_config['agent_name']}: {e}"
"agent_name": agent_config["agent_name"], )
"task": task, task_results.append(
"error": str(e) {
}) "agent_name": agent_config["agent_name"],
"task": task,
"error": str(e),
}
)
# Return results based on the `return_type` # Return results based on the `return_type`
if return_type == "agents": if return_type == "agents":
@ -131,6 +178,7 @@ def create_agents_from_yaml(yaml_file: str, return_type: str = "agents", *args,
logger.error(f"Invalid return_type: {return_type}") logger.error(f"Invalid return_type: {return_type}")
raise ValueError(f"Invalid return_type: {return_type}") raise ValueError(f"Invalid return_type: {return_type}")
# # Usage example # # Usage example
# yaml_file = 'agents_config.yaml' # yaml_file = 'agents_config.yaml'

@ -523,11 +523,14 @@ class Agent:
# Telemetry Processor to log agent data # Telemetry Processor to log agent data
threading.Thread(target=self.log_agent_data).start() threading.Thread(target=self.log_agent_data).start()
if load_yaml_path is not None: if load_yaml_path is not None:
from swarms.agents.create_agents_from_yaml import create_agents_from_yaml from swarms.agents.create_agents_from_yaml import (
create_agents_from_yaml,
)
create_agents_from_yaml(load_yaml_path, return_type="tasks") create_agents_from_yaml(
load_yaml_path, return_type="tasks"
)
def set_system_prompt(self, system_prompt: str): def set_system_prompt(self, system_prompt: str):
"""Set the system prompt""" """Set the system prompt"""

@ -0,0 +1,362 @@
import os
import uuid
from collections import Counter
from datetime import datetime
from typing import Any, List, Optional
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, util
from swarm_models import OpenAIChat
from swarms import Agent
load_dotenv()
# Get the OpenAI API key from the environment variable
api_key = os.getenv("OPENAI_API_KEY")
# # Set up loguru to log into a file and console
# logger.add(
# "multi_agent_log_{time}.log",
# format="{time} {level} {message}",
# level="DEBUG",
# rotation="10 MB",
# )
# Pretrained model for embeddings
embedding_model = SentenceTransformer(
"all-MiniLM-L6-v2"
) # A small, fast model for embedding
# Create an instance of the OpenAIChat class
model = OpenAIChat(
openai_api_key=api_key,
model_name="gpt-4o-mini",
temperature=0.1,
max_tokens=2000,
)
# Pydantic Models for Logging
class AgentLogInput(BaseModel):
log_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), alias="id"
)
agent_name: str
task: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
class AgentLogOutput(BaseModel):
log_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), alias="id"
)
agent_name: str
result: Any
timestamp: datetime = Field(default_factory=datetime.utcnow)
class TreeLog(BaseModel):
log_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), alias="id"
)
tree_name: str
task: str
selected_agent: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
result: Any
def extract_keywords(prompt: str, top_n: int = 5) -> List[str]:
"""
A simplified keyword extraction function using basic word splitting instead of NLTK tokenization.
"""
words = prompt.lower().split()
filtered_words = [word for word in words if word.isalnum()]
word_counts = Counter(filtered_words)
return [word for word, _ in word_counts.most_common(top_n)]
class TreeAgent(Agent):
"""
A specialized Agent class that contains information about the system prompt's
locality and allows for dynamic chaining of agents in trees.
"""
def __init__(
self,
system_prompt: str = None,
llm: callable = model,
agent_name: Optional[str] = None,
*args,
**kwargs,
):
agent_name = agent_name
super().__init__(
system_prompt=system_prompt,
llm=llm,
agent_name=agent_name,
*args,
**kwargs,
)
self.system_prompt_embedding = embedding_model.encode(
system_prompt, convert_to_tensor=True
)
# Automatically extract keywords from system prompt
self.relevant_keywords = extract_keywords(system_prompt)
# Distance is now calculated based on similarity between agents' prompts
self.distance = None # Will be dynamically calculated later
def calculate_distance(self, other_agent: "TreeAgent") -> float:
"""
Calculate the distance between this agent and another agent using embedding similarity.
Args:
other_agent (TreeAgent): Another agent in the tree.
Returns:
float: Distance score between 0 and 1, with 0 being close and 1 being far.
"""
similarity = util.pytorch_cos_sim(
self.system_prompt_embedding,
other_agent.system_prompt_embedding,
).item()
distance = (
1 - similarity
) # Closer agents have a smaller distance
return distance
def run_task(self, task: str) -> Any:
input_log = AgentLogInput(
agent_name=self.agent_name,
task=task,
timestamp=datetime.now(),
)
logger.info(f"Running task on {self.agent_name}: {task}")
logger.debug(f"Input Log: {input_log.json()}")
result = self.run(task)
output_log = AgentLogOutput(
agent_name=self.agent_name,
result=result,
timestamp=datetime.now(),
)
logger.info(f"Task result from {self.agent_name}: {result}")
logger.debug(f"Output Log: {output_log.json()}")
return result
def is_relevant_for_task(
self, task: str, threshold: float = 0.7
) -> bool:
"""
Checks if the agent is relevant for the given task using both keyword matching and embedding similarity.
Args:
task (str): The task to be executed.
threshold (float): The cosine similarity threshold for embedding-based matching.
Returns:
bool: True if the agent is relevant, False otherwise.
"""
# Check if any of the relevant keywords are present in the task (case-insensitive)
keyword_match = any(
keyword.lower() in task.lower()
for keyword in self.relevant_keywords
)
# Perform embedding similarity match if keyword match is not found
if not keyword_match:
task_embedding = embedding_model.encode(
task, convert_to_tensor=True
)
similarity = util.pytorch_cos_sim(
self.system_prompt_embedding, task_embedding
).item()
logger.info(
f"Semantic similarity between task and {self.agent_name}: {similarity:.2f}"
)
return similarity >= threshold
return True # Return True if keyword match is found
class Tree:
def __init__(self, tree_name: str, agents: List[TreeAgent]):
"""
Initializes a tree of agents.
Args:
tree_name (str): The name of the tree.
agents (List[TreeAgent]): A list of agents in the tree.
"""
self.tree_name = tree_name
self.agents = agents
self.calculate_agent_distances()
def calculate_agent_distances(self):
"""
Automatically calculate and assign distances between agents in the tree based on prompt similarity.
"""
logger.info(
f"Calculating distances between agents in tree '{self.tree_name}'"
)
for i, agent in enumerate(self.agents):
if i > 0:
agent.distance = agent.calculate_distance(
self.agents[i - 1]
)
else:
agent.distance = 0 # First agent is closest
# Sort agents by distance after calculation
self.agents.sort(key=lambda agent: agent.distance)
def find_relevant_agent(self, task: str) -> Optional[TreeAgent]:
"""
Finds the most relevant agent in the tree for the given task based on its system prompt.
Uses both keyword and semantic similarity matching.
Args:
task (str): The task or query for which we need to find a relevant agent.
Returns:
Optional[TreeAgent]: The most relevant agent, or None if no match found.
"""
logger.info(
f"Searching relevant agent in tree '{self.tree_name}' for task: {task}"
)
for agent in self.agents:
if agent.is_relevant_for_task(task):
return agent
logger.warning(
f"No relevant agent found in tree '{self.tree_name}' for task: {task}"
)
return None
def log_tree_execution(
self, task: str, selected_agent: TreeAgent, result: Any
) -> None:
"""
Logs the execution details of a tree, including selected agent and result.
"""
tree_log = TreeLog(
tree_name=self.tree_name,
task=task,
selected_agent=selected_agent.agent_name,
timestamp=datetime.now(),
result=result,
)
logger.info(
f"Tree '{self.tree_name}' executed task with agent '{selected_agent.agent_name}'"
)
logger.debug(f"Tree Log: {tree_log.json()}")
class ForestSwarm:
def __init__(self, trees: List[Tree], *args, **kwargs):
"""
Initializes the structure with multiple trees of agents.
Args:
trees (List[Tree]): A list of trees in the structure.
"""
self.trees = trees
# Add auto grouping based on trees.
# Add auto group agents
def find_relevant_tree(self, task: str) -> Optional[Tree]:
"""
Finds the most relevant tree based on the given task.
Args:
task (str): The task or query for which we need to find a relevant tree.
Returns:
Optional[Tree]: The most relevant tree, or None if no match found.
"""
logger.info(
f"Searching for the most relevant tree for task: {task}"
)
for tree in self.trees:
if tree.find_relevant_agent(task):
return tree
logger.warning(f"No relevant tree found for task: {task}")
return None
def run(self, task: str) -> Any:
"""
Executes the given task by finding the most relevant tree and agent within that tree.
Args:
task (str): The task or query to be executed.
Returns:
Any: The result of the task after it has been processed by the agents.
"""
logger.info(
f"Running task across MultiAgentTreeStructure: {task}"
)
relevant_tree = self.find_relevant_tree(task)
if relevant_tree:
agent = relevant_tree.find_relevant_agent(task)
if agent:
result = agent.run_task(task)
relevant_tree.log_tree_execution(task, agent, result)
return result
else:
logger.error(
"Task could not be completed: No relevant agent or tree found."
)
return "No relevant agent found to handle this task."
# # Example Usage:
# # Create agents with varying system prompts and dynamically generated distances/keywords
# agents_tree1 = [
# TreeAgent(
# system_prompt="Stock Analysis Agent",
# agent_name="Stock Analysis Agent",
# ),
# TreeAgent(
# system_prompt="Financial Planning Agent",
# agent_name="Financial Planning Agent",
# ),
# TreeAgent(
# agent_name="Retirement Strategy Agent",
# system_prompt="Retirement Strategy Agent",
# ),
# ]
# agents_tree2 = [
# TreeAgent(
# system_prompt="Tax Filing Agent",
# agent_name="Tax Filing Agent",
# ),
# TreeAgent(
# system_prompt="Investment Strategy Agent",
# agent_name="Investment Strategy Agent",
# ),
# TreeAgent(
# system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent"
# ),
# ]
# # Create trees
# tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1)
# tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2)
# # Create the ForestSwarm
# multi_agent_structure = ForestSwarm(trees=[tree1, tree2])
# # Run a task
# task = "Our company is incorporated in delaware, how do we do our taxes for free?"
# output = multi_agent_structure.run(task)
# print(output)

@ -3,11 +3,12 @@ from unittest.mock import patch
from swarms import create_agents_from_yaml from swarms import create_agents_from_yaml
import os import os
class TestCreateAgentsFromYaml(unittest.TestCase): class TestCreateAgentsFromYaml(unittest.TestCase):
def setUp(self): def setUp(self):
# Mock the environment variable for API key # Mock the environment variable for API key
os.environ['OPENAI_API_KEY'] = 'fake-api-key' os.environ["OPENAI_API_KEY"] = "fake-api-key"
# Mock agent configuration YAML content # Mock agent configuration YAML content
self.valid_yaml_content = """ self.valid_yaml_content = """
@ -53,9 +54,15 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
task: "What is the best strategy for long-term stock investment?" task: "What is the best strategy for long-term stock investment?"
""" """
@patch('builtins.open', new_callable=unittest.mock.mock_open, read_data="") @patch(
@patch('yaml.safe_load') "builtins.open",
def test_create_agents_return_agents(self, mock_safe_load, mock_open): new_callable=unittest.mock.mock_open,
read_data="",
)
@patch("yaml.safe_load")
def test_create_agents_return_agents(
self, mock_safe_load, mock_open
):
# Mock YAML content parsing # Mock YAML content parsing
mock_safe_load.return_value = { mock_safe_load.return_value = {
"agents": [ "agents": [
@ -65,7 +72,7 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
"openai_api_key": "fake-api-key", "openai_api_key": "fake-api-key",
"model_name": "gpt-4o-mini", "model_name": "gpt-4o-mini",
"temperature": 0.1, "temperature": 0.1,
"max_tokens": 2000 "max_tokens": 2000,
}, },
"system_prompt": "financial_agent_sys_prompt", "system_prompt": "financial_agent_sys_prompt",
"max_loops": 1, "max_loops": 1,
@ -79,20 +86,32 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
"context_length": 200000, "context_length": 200000,
"return_step_meta": False, "return_step_meta": False,
"output_type": "str", "output_type": "str",
"task": "How can I establish a ROTH IRA to buy stocks and get a tax break?" "task": "How can I establish a ROTH IRA to buy stocks and get a tax break?",
} }
] ]
} }
# Test if agents are returned correctly # Test if agents are returned correctly
agents = create_agents_from_yaml('fake_yaml_path.yaml', return_type="agents") agents = create_agents_from_yaml(
"fake_yaml_path.yaml", return_type="agents"
)
self.assertEqual(len(agents), 1) self.assertEqual(len(agents), 1)
self.assertEqual(agents[0].agent_name, "Financial-Analysis-Agent") self.assertEqual(
agents[0].agent_name, "Financial-Analysis-Agent"
)
@patch('builtins.open', new_callable=unittest.mock.mock_open, read_data="") @patch(
@patch('yaml.safe_load') "builtins.open",
@patch('swarms.Agent.run', return_value="Task completed successfully") new_callable=unittest.mock.mock_open,
def test_create_agents_return_tasks(self, mock_agent_run, mock_safe_load, mock_open): read_data="",
)
@patch("yaml.safe_load")
@patch(
"swarms.Agent.run", return_value="Task completed successfully"
)
def test_create_agents_return_tasks(
self, mock_agent_run, mock_safe_load, mock_open
):
# Mock YAML content parsing # Mock YAML content parsing
mock_safe_load.return_value = { mock_safe_load.return_value = {
"agents": [ "agents": [
@ -102,7 +121,7 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
"openai_api_key": "fake-api-key", "openai_api_key": "fake-api-key",
"model_name": "gpt-4o-mini", "model_name": "gpt-4o-mini",
"temperature": 0.1, "temperature": 0.1,
"max_tokens": 2000 "max_tokens": 2000,
}, },
"system_prompt": "financial_agent_sys_prompt", "system_prompt": "financial_agent_sys_prompt",
"max_loops": 1, "max_loops": 1,
@ -116,20 +135,30 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
"context_length": 200000, "context_length": 200000,
"return_step_meta": False, "return_step_meta": False,
"output_type": "str", "output_type": "str",
"task": "How can I establish a ROTH IRA to buy stocks and get a tax break?" "task": "How can I establish a ROTH IRA to buy stocks and get a tax break?",
} }
] ]
} }
# Test if tasks are executed and results are returned # Test if tasks are executed and results are returned
task_results = create_agents_from_yaml('fake_yaml_path.yaml', return_type="tasks") task_results = create_agents_from_yaml(
"fake_yaml_path.yaml", return_type="tasks"
)
self.assertEqual(len(task_results), 1) self.assertEqual(len(task_results), 1)
self.assertEqual(task_results[0]['agent_name'], "Financial-Analysis-Agent") self.assertEqual(
self.assertIsNotNone(task_results[0]['output']) task_results[0]["agent_name"], "Financial-Analysis-Agent"
)
self.assertIsNotNone(task_results[0]["output"])
@patch('builtins.open', new_callable=unittest.mock.mock_open, read_data="") @patch(
@patch('yaml.safe_load') "builtins.open",
def test_create_agents_return_both(self, mock_safe_load, mock_open): new_callable=unittest.mock.mock_open,
read_data="",
)
@patch("yaml.safe_load")
def test_create_agents_return_both(
self, mock_safe_load, mock_open
):
# Mock YAML content parsing # Mock YAML content parsing
mock_safe_load.return_value = { mock_safe_load.return_value = {
"agents": [ "agents": [
@ -139,7 +168,7 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
"openai_api_key": "fake-api-key", "openai_api_key": "fake-api-key",
"model_name": "gpt-4o-mini", "model_name": "gpt-4o-mini",
"temperature": 0.1, "temperature": 0.1,
"max_tokens": 2000 "max_tokens": 2000,
}, },
"system_prompt": "financial_agent_sys_prompt", "system_prompt": "financial_agent_sys_prompt",
"max_loops": 1, "max_loops": 1,
@ -153,31 +182,48 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
"context_length": 200000, "context_length": 200000,
"return_step_meta": False, "return_step_meta": False,
"output_type": "str", "output_type": "str",
"task": "How can I establish a ROTH IRA to buy stocks and get a tax break?" "task": "How can I establish a ROTH IRA to buy stocks and get a tax break?",
} }
] ]
} }
# Test if both agents and tasks are returned # Test if both agents and tasks are returned
agents, task_results = create_agents_from_yaml('fake_yaml_path.yaml', return_type="both") agents, task_results = create_agents_from_yaml(
"fake_yaml_path.yaml", return_type="both"
)
self.assertEqual(len(agents), 1) self.assertEqual(len(agents), 1)
self.assertEqual(len(task_results), 1) self.assertEqual(len(task_results), 1)
self.assertEqual(agents[0].agent_name, "Financial-Analysis-Agent") self.assertEqual(
self.assertIsNotNone(task_results[0]['output']) agents[0].agent_name, "Financial-Analysis-Agent"
)
self.assertIsNotNone(task_results[0]["output"])
@patch('builtins.open', new_callable=unittest.mock.mock_open, read_data="") @patch(
@patch('yaml.safe_load') "builtins.open",
new_callable=unittest.mock.mock_open,
read_data="",
)
@patch("yaml.safe_load")
def test_missing_agents_in_yaml(self, mock_safe_load, mock_open): def test_missing_agents_in_yaml(self, mock_safe_load, mock_open):
# Mock YAML content with missing "agents" key # Mock YAML content with missing "agents" key
mock_safe_load.return_value = {} mock_safe_load.return_value = {}
# Test if the function raises an error for missing "agents" key # Test if the function raises an error for missing "agents" key
with self.assertRaises(ValueError) as context: with self.assertRaises(ValueError) as context:
create_agents_from_yaml('fake_yaml_path.yaml', return_type="agents") create_agents_from_yaml(
self.assertTrue("The YAML configuration does not contain 'agents'." in str(context.exception)) "fake_yaml_path.yaml", return_type="agents"
)
self.assertTrue(
"The YAML configuration does not contain 'agents'."
in str(context.exception)
)
@patch('builtins.open', new_callable=unittest.mock.mock_open, read_data="") @patch(
@patch('yaml.safe_load') "builtins.open",
new_callable=unittest.mock.mock_open,
read_data="",
)
@patch("yaml.safe_load")
def test_invalid_return_type(self, mock_safe_load, mock_open): def test_invalid_return_type(self, mock_safe_load, mock_open):
# Mock YAML content parsing # Mock YAML content parsing
mock_safe_load.return_value = { mock_safe_load.return_value = {
@ -188,7 +234,7 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
"openai_api_key": "fake-api-key", "openai_api_key": "fake-api-key",
"model_name": "gpt-4o-mini", "model_name": "gpt-4o-mini",
"temperature": 0.1, "temperature": 0.1,
"max_tokens": 2000 "max_tokens": 2000,
}, },
"system_prompt": "financial_agent_sys_prompt", "system_prompt": "financial_agent_sys_prompt",
"max_loops": 1, "max_loops": 1,
@ -202,15 +248,20 @@ class TestCreateAgentsFromYaml(unittest.TestCase):
"context_length": 200000, "context_length": 200000,
"return_step_meta": False, "return_step_meta": False,
"output_type": "str", "output_type": "str",
"task": "How can I establish a ROTH IRA to buy stocks and get a tax break?" "task": "How can I establish a ROTH IRA to buy stocks and get a tax break?",
} }
] ]
} }
# Test if an error is raised for invalid return_type # Test if an error is raised for invalid return_type
with self.assertRaises(ValueError) as context: with self.assertRaises(ValueError) as context:
create_agents_from_yaml('fake_yaml_path.yaml', return_type="invalid_type") create_agents_from_yaml(
self.assertTrue("Invalid return_type" in str(context.exception)) "fake_yaml_path.yaml", return_type="invalid_type"
)
self.assertTrue(
"Invalid return_type" in str(context.exception)
)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

Loading…
Cancel
Save