[IMPROVEMENT][Remove sentence transformers from forest swarm and instead use litellm] [UPDATE][DOCs]

pull/1054/head
Kye Gomez 3 days ago
parent 438029dbe5
commit 5c79016afa

@ -1,17 +1,40 @@
# Forest Swarm
This documentation describes the **ForestSwarm** that organizes agents into trees. Each agent specializes in processing specific tasks. Trees are collections of agents, each assigned based on their relevance to a task through keyword extraction and embedding-based similarity.
The architecture allows for efficient task assignment by selecting the most relevant agent from a set of trees. Tasks are processed asynchronously, with agents selected based on task relevance, calculated by the similarity of system prompts and task keywords.
This documentation describes the **ForestSwarm** that organizes agents into trees. Each agent specializes in processing specific tasks. Trees are collections of agents, each assigned based on their relevance to a task through keyword extraction and **litellm-based embedding similarity**.
The architecture allows for efficient task assignment by selecting the most relevant agent from a set of trees. Tasks are processed asynchronously, with agents selected based on task relevance, calculated by the similarity of system prompts and task keywords using **litellm embeddings** and cosine similarity calculations.
## Module Path: `swarms.structs.tree_swarm`
---
### Utility Functions
#### `extract_keywords(prompt: str, top_n: int = 5) -> List[str]`
Extracts relevant keywords from a text prompt using basic word splitting and frequency counting.
**Parameters:**
- `prompt` (str): The text to extract keywords from
- `top_n` (int): Maximum number of keywords to return
**Returns:**
- `List[str]`: List of extracted keywords sorted by frequency
#### `cosine_similarity(vec1: List[float], vec2: List[float]) -> float`
Calculates the cosine similarity between two embedding vectors.
**Parameters:**
- `vec1` (List[float]): First embedding vector
- `vec2` (List[float]): Second embedding vector
**Returns:**
- `float`: Cosine similarity score between -1 and 1
---
### Class: `TreeAgent`
`TreeAgent` represents an individual agent responsible for handling a specific task. Agents are initialized with a **system prompt** and are responsible for dynamically determining their relevance to a given task.
`TreeAgent` represents an individual agent responsible for handling a specific task. Agents are initialized with a **system prompt** and use **litellm embeddings** to dynamically determine their relevance to a given task.
#### Attributes
@ -20,136 +43,205 @@ The architecture allows for efficient task assignment by selecting the most rele
| `system_prompt` | `str` | A string that defines the agent's area of expertise and task-handling capability.|
| `llm` | `callable` | The language model (LLM) used to process tasks (e.g., GPT-4). |
| `agent_name` | `str` | The name of the agent. |
| `system_prompt_embedding`| `tensor` | Embedding of the system prompt for similarity-based task matching. |
| `system_prompt_embedding`| `List[float]` | **litellm-generated embedding** of the system prompt for similarity-based task matching.|
| `relevant_keywords` | `List[str]` | Keywords dynamically extracted from the system prompt to assist in task matching.|
| `distance` | `Optional[float]`| The computed distance between agents based on embedding similarity. |
| `embedding_model_name` | `str` | **Name of the litellm embedding model** (default: "text-embedding-ada-002"). |
#### Methods
| **Method** | **Input** | **Output** | **Description** |
|--------------------|---------------------------------|--------------------|---------------------------------------------------------------------------------|
| `calculate_distance(other_agent: TreeAgent)` | `other_agent: TreeAgent` | `float` | Calculates the cosine similarity between this agent and another agent. |
| `_get_embedding(text: str)` | `text: str` | `List[float]` | **Internal method to generate embeddings using litellm.** |
| `calculate_distance(other_agent: TreeAgent)` | `other_agent: TreeAgent` | `float` | Calculates the **cosine similarity distance** between this agent and another agent.|
| `run_task(task: str)` | `task: str` | `Any` | Executes the task, logs the input/output, and returns the result. |
| `is_relevant_for_task(task: str, threshold: float = 0.7)` | `task: str, threshold: float` | `bool` | Checks if the agent is relevant for the task using keyword matching or embedding similarity.|
| `is_relevant_for_task(task: str, threshold: float = 0.7)` | `task: str, threshold: float` | `bool` | Checks if the agent is relevant for the task using **keyword matching and litellm embedding similarity**.|
---
### Class: `Tree`
`Tree` organizes multiple agents into a hierarchical structure, where agents are sorted based on their relevance to tasks.
`Tree` organizes multiple agents into a hierarchical structure, where agents are sorted based on their relevance to tasks using **litellm embeddings**.
#### Attributes
| **Attribute** | **Type** | **Description** |
|--------------------------|------------------|---------------------------------------------------------------------------------|
| `tree_name` | `str` | The name of the tree (represents a domain of agents, e.g., "Financial Tree"). |
| `agents` | `List[TreeAgent]`| List of agents belonging to this tree. |
| `agents` | `List[TreeAgent]`| List of agents belonging to this tree, **sorted by embedding-based distance**. |
#### Methods
| **Method** | **Input** | **Output** | **Description** |
|--------------------|---------------------------------|--------------------|---------------------------------------------------------------------------------|
| `calculate_agent_distances()` | `None` | `None` | Calculates and assigns distances between agents based on similarity of prompts. |
| `find_relevant_agent(task: str)` | `task: str` | `Optional[TreeAgent]` | Finds the most relevant agent for a task based on keyword and embedding similarity. |
| `calculate_agent_distances()` | `None` | `None` | **Calculates and assigns distances between agents based on litellm embedding similarity of prompts.** |
| `find_relevant_agent(task: str)` | `task: str` | `Optional[TreeAgent]` | **Finds the most relevant agent for a task based on keyword and litellm embedding similarity.** |
| `log_tree_execution(task: str, selected_agent: TreeAgent, result: Any)` | `task: str, selected_agent: TreeAgent, result: Any` | `None` | Logs details of the task execution by the selected agent. |
---
### Class: `ForestSwarm`
`ForestSwarm` is the main class responsible for managing multiple trees. It oversees task delegation by finding the most relevant tree and agent for a given task.
`ForestSwarm` is the main class responsible for managing multiple trees. It oversees task delegation by finding the most relevant tree and agent for a given task using **litellm embeddings**.
#### Attributes
| **Attribute** | **Type** | **Description** |
|--------------------------|------------------|---------------------------------------------------------------------------------|
| `name` | `str` | Name of the forest swarm. |
| `description` | `str` | Description of the forest swarm. |
| `trees` | `List[Tree]` | List of trees containing agents organized by domain. |
| `shared_memory` | `Any` | Shared memory object for inter-tree communication. |
| `rules` | `str` | Rules governing the forest swarm behavior. |
| `conversation` | `Conversation` | Conversation object for tracking interactions. |
#### Methods
| **Method** | **Input** | **Output** | **Description** |
|--------------------|---------------------------------|--------------------|---------------------------------------------------------------------------------|
| `find_relevant_tree(task: str)` | `task: str` | `Optional[Tree]` | Searches across all trees to find the most relevant tree based on task requirements.|
| `run(task: str)` | `task: str` | `Any` | Executes the task by finding the most relevant agent from the relevant tree. |
| `find_relevant_tree(task: str)` | `task: str` | `Optional[Tree]` | **Searches across all trees to find the most relevant tree based on litellm embedding similarity.**|
| `run(task: str, img: str = None, *args, **kwargs)` | `task: str, img: str, *args, **kwargs` | `Any` | **Executes the task by finding the most relevant agent from the relevant tree using litellm embeddings.**|
---
### Pydantic Models for Logging
#### `AgentLogInput`
Input log model for tracking agent task execution.
**Fields:**
- `log_id` (str): Unique identifier for the log entry
- `agent_name` (str): Name of the agent executing the task
- `task` (str): Description of the task being executed
- `timestamp` (datetime): When the task was started
#### `AgentLogOutput`
Output log model for tracking agent task completion.
**Fields:**
- `log_id` (str): Unique identifier for the log entry
- `agent_name` (str): Name of the agent that completed the task
- `result` (Any): Result/output from the task execution
- `timestamp` (datetime): When the task was completed
#### `TreeLog`
Tree execution log model for tracking tree-level operations.
**Fields:**
- `log_id` (str): Unique identifier for the log entry
- `tree_name` (str): Name of the tree that executed the task
- `task` (str): Description of the task that was executed
- `selected_agent` (str): Name of the agent selected for the task
- `timestamp` (datetime): When the task was executed
- `result` (Any): Result/output from the task execution
---
## Full Code Example
```python
from swarms.structs.tree_swarm import TreeAgent, Tree, ForestSwarm
# Example Usage:
# Create agents with varying system prompts and dynamically generated distances/keywords
agents_tree1 = [
TreeAgent(
system_prompt="Stock Analysis Agent",
agent_name="Stock Analysis Agent",
system_prompt="I am a financial advisor specializing in investment planning, retirement strategies, and tax optimization for individuals and businesses.",
agent_name="Financial Advisor",
),
TreeAgent(
system_prompt="Financial Planning Agent",
agent_name="Financial Planning Agent",
system_prompt="I am a tax expert with deep knowledge of corporate taxation, Delaware incorporation benefits, and free tax filing options for businesses.",
agent_name="Tax Expert",
),
TreeAgent(
agent_name="Retirement Strategy Agent",
system_prompt="Retirement Strategy Agent",
system_prompt="I am a retirement planning specialist who helps individuals and businesses create comprehensive retirement strategies and investment plans.",
agent_name="Retirement Planner",
),
]
agents_tree2 = [
TreeAgent(
system_prompt="Tax Filing Agent",
agent_name="Tax Filing Agent",
system_prompt="I am a stock market analyst who provides insights on market trends, stock recommendations, and portfolio optimization strategies.",
agent_name="Stock Analyst",
),
TreeAgent(
system_prompt="Investment Strategy Agent",
agent_name="Investment Strategy Agent",
system_prompt="I am an investment strategist specializing in portfolio diversification, risk management, and market analysis.",
agent_name="Investment Strategist",
),
TreeAgent(
system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent"
system_prompt="I am a ROTH IRA specialist who helps individuals optimize their retirement accounts and tax advantages.",
agent_name="ROTH IRA Specialist",
),
]
# Create trees
tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1)
tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2)
tree1 = Tree(tree_name="Financial Services Tree", agents=agents_tree1)
tree2 = Tree(tree_name="Investment & Trading Tree", agents=agents_tree2)
# Create the ForestSwarm
multi_agent_structure = ForestSwarm(trees=[tree1, tree2])
forest_swarm = ForestSwarm(
name="Financial Services Forest",
description="A comprehensive financial services multi-agent system",
trees=[tree1, tree2]
)
# Run a task
task = "Our company is incorporated in delaware, how do we do our taxes for free?"
output = multi_agent_structure.run(task)
task = "Our company is incorporated in Delaware, how do we do our taxes for free?"
output = forest_swarm.run(task)
print(output)
```
---
## Example Workflow
1. **Create Agents**: Agents are initialized with varying system prompts, representing different areas of expertise (e.g., stock analysis, tax filing).
2. **Create Trees**: Agents are grouped into trees, with each tree representing a domain (e.g., "Financial Tree", "Investment Tree").
3. **Run Task**: When a task is submitted, the system traverses through all trees and finds the most relevant agent to handle the task.
4. **Task Execution**: The selected agent processes the task, and the result is returned.
1. **Create Agents**: Agents are initialized with varying system prompts, representing different areas of expertise (e.g., financial planning, tax filing).
2. **Generate Embeddings**: Each agent's system prompt is converted to **litellm embeddings** for semantic similarity calculations.
3. **Create Trees**: Agents are grouped into trees, with each tree representing a domain (e.g., "Financial Services Tree", "Investment & Trading Tree").
4. **Calculate Distances**: **litellm embeddings** are used to calculate semantic distances between agents within each tree.
5. **Run Task**: When a task is submitted, the system:
- Generates **litellm embeddings** for the task
- Searches through all trees using **cosine similarity**
- Finds the most relevant agent based on **embedding similarity and keyword matching**
6. **Task Execution**: The selected agent processes the task, and the result is returned and logged.
```plaintext
Task: "Our company is incorporated in Delaware, how do we do our taxes for free?"
```
**Process**:
- The system searches through the `Financial Tree` and `Investment Tree`.
- The most relevant agent (likely the "Tax Filing Agent") is selected based on keyword matching and prompt similarity.
- The task is processed, and the result is logged and returned.
- The system generates **litellm embeddings** for the task
- Searches through the `Financial Services Tree` and `Investment & Trading Tree`
- Uses **cosine similarity** to find the most relevant agent (likely the "Tax Expert")
- The task is processed, and the result is logged and returned
---
## Key Features
### **litellm Integration**
- **Embedding Generation**: Uses litellm's `embedding()` function for generating high-quality embeddings
- **Model Flexibility**: Supports various embedding models (default: "text-embedding-ada-002")
- **Error Handling**: Robust fallback mechanisms for embedding failures
### **Semantic Similarity**
- **Cosine Similarity**: Implements efficient cosine similarity calculations for vector comparisons
- **Threshold-based Selection**: Configurable similarity thresholds for agent selection
- **Hybrid Matching**: Combines keyword matching with semantic similarity for optimal results
### **Dynamic Agent Organization**
- **Automatic Distance Calculation**: Agents are automatically organized by semantic similarity
- **Real-time Relevance**: Task relevance is calculated dynamically using current embeddings
- **Scalable Architecture**: Easy to add/remove agents and trees without manual configuration
---
## Analysis of the Swarm Architecture
The **Swarm Architecture** leverages a hierarchical structure (forest) composed of individual trees, each containing agents specialized in specific domains. This design allows for:
The **ForestSwarm Architecture** leverages a hierarchical structure (forest) composed of individual trees, each containing agents specialized in specific domains. This design allows for:
- **Modular and Scalable Organization**: By separating agents into trees, it is easy to expand or contract the system by adding or removing trees or agents.
- **Task Specialization**: Each agent is specialized, which ensures that tasks are matched with the most appropriate agent based on relevance and expertise.
- **Dynamic Matching**: The architecture uses both keyword-based and embedding-based matching to assign tasks, ensuring a high level of accuracy in agent selection.
- **Task Specialization**: Each agent is specialized, which ensures that tasks are matched with the most appropriate agent based on **litellm embedding similarity** and expertise.
- **Dynamic Matching**: The architecture uses both keyword-based and **litellm embedding-based matching** to assign tasks, ensuring a high level of accuracy in agent selection.
- **Logging and Accountability**: Each task execution is logged in detail, providing transparency and an audit trail of which agent handled which task and the results produced.
- **Asynchronous Task Execution**: The architecture can be adapted for asynchronous task processing, making it scalable and suitable for large-scale task handling in real-time systems.
@ -159,35 +251,65 @@ The **Swarm Architecture** leverages a hierarchical structure (forest) composed
```mermaid
graph TD
A[ForestSwarm] --> B[Financial Tree]
A --> C[Investment Tree]
A[ForestSwarm] --> B[Financial Services Tree]
A --> C[Investment & Trading Tree]
B --> D[Stock Analysis Agent]
B --> E[Financial Planning Agent]
B --> F[Retirement Strategy Agent]
B --> D[Financial Advisor]
B --> E[Tax Expert]
B --> F[Retirement Planner]
C --> G[Tax Filing Agent]
C --> H[Investment Strategy Agent]
C --> I[ROTH IRA Agent]
subgraph Tree Agents
D[Stock Analysis Agent]
E[Financial Planning Agent]
F[Retirement Strategy Agent]
G[Tax Filing Agent]
H[Investment Strategy Agent]
I[ROTH IRA Agent]
C --> G[Stock Analyst]
C --> H[Investment Strategist]
C --> I[ROTH IRA Specialist]
subgraph Embedding Process
J[litellm Embeddings] --> K[Cosine Similarity]
K --> L[Agent Selection]
end
subgraph Task Processing
M[Task Input] --> N[Generate Task Embedding]
N --> O[Find Relevant Tree]
O --> P[Find Relevant Agent]
P --> Q[Execute Task]
Q --> R[Log Results]
end
```
### Explanation of the Diagram
- **ForestSwarm**: Represents the top-level structure managing multiple trees.
- **Trees**: In the example, two trees exist—**Financial Tree** and **Investment Tree**—each containing agents related to specific domains.
- **Agents**: Each agent within the tree is responsible for handling tasks in its area of expertise. Agents within a tree are organized based on their prompt similarity (distance).
- **Trees**: In the example, two trees exist—**Financial Services Tree** and **Investment & Trading Tree**—each containing agents related to specific domains.
- **Agents**: Each agent within the tree is responsible for handling tasks in its area of expertise. Agents within a tree are organized based on their **litellm embedding similarity** (distance).
- **Embedding Process**: Shows how **litellm embeddings** are used for similarity calculations and agent selection.
- **Task Processing**: Illustrates the complete workflow from task input to result logging.
---
## Testing
The ForestSwarm implementation includes comprehensive unit tests that can be run independently:
```bash
python test_forest_swarm.py
```
The test suite covers:
- **Utility Functions**: `extract_keywords`, `cosine_similarity`
- **Pydantic Models**: `AgentLogInput`, `AgentLogOutput`, `TreeLog`
- **Core Classes**: `TreeAgent`, `Tree`, `ForestSwarm`
- **Edge Cases**: Error handling, empty inputs, null values
- **Integration**: End-to-end task execution workflows
---
### Summary
This **Multi-Agent Tree Structure** provides an efficient, scalable, and accurate architecture for delegating and executing tasks based on domain-specific expertise. The combination of hierarchical organization, dynamic task matching, and logging ensures reliability, performance, and transparency in task execution.
This **ForestSwarm Architecture** provides an efficient, scalable, and accurate architecture for delegating and executing tasks based on domain-specific expertise. The combination of hierarchical organization, **litellm-based semantic similarity**, dynamic task matching, and comprehensive logging ensures reliability, performance, and transparency in task execution.
**Key Advantages:**
- **High Accuracy**: litellm embeddings provide superior semantic understanding
- **Scalability**: Easy to add new agents, trees, and domains
- **Flexibility**: Configurable similarity thresholds and embedding models
- **Robustness**: Comprehensive error handling and fallback mechanisms
- **Transparency**: Detailed logging and audit trails for all operations

@ -0,0 +1,279 @@
#!/usr/bin/env python3
"""
ForestSwarm Example Script
This script demonstrates the ForestSwarm functionality with realistic examples
of financial services and investment management agents.
"""
from swarms.structs.tree_swarm import TreeAgent, Tree, ForestSwarm
def create_financial_services_forest():
"""Create a comprehensive financial services forest with multiple specialized agents."""
print("🌳 Creating Financial Services Forest...")
# Financial Services Tree - Personal Finance & Planning
financial_agents = [
TreeAgent(
system_prompt="""I am a certified financial planner specializing in personal finance,
budgeting, debt management, and financial goal setting. I help individuals create
comprehensive financial plans and make informed decisions about their money.""",
agent_name="Personal Financial Planner",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am a tax preparation specialist with expertise in individual and
small business tax returns. I help clients maximize deductions, understand tax laws,
and file taxes accurately and on time.""",
agent_name="Tax Preparation Specialist",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am a retirement planning expert who helps individuals and families
plan for retirement. I specialize in 401(k)s, IRAs, Social Security optimization,
and creating sustainable retirement income strategies.""",
agent_name="Retirement Planning Expert",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am a debt management counselor who helps individuals and families
get out of debt and build financial stability. I provide strategies for debt
consolidation, negotiation, and creating sustainable repayment plans.""",
agent_name="Debt Management Counselor",
model_name="gpt-4o",
),
]
# Investment & Trading Tree - Market Analysis & Portfolio Management
investment_agents = [
TreeAgent(
system_prompt="""I am a stock market analyst who provides insights on market trends,
stock recommendations, and portfolio optimization strategies. I analyze company
fundamentals, market conditions, and economic indicators to help investors make
informed decisions.""",
agent_name="Stock Market Analyst",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am an investment strategist specializing in portfolio diversification,
risk management, and asset allocation. I help investors create balanced portfolios
that align with their risk tolerance and financial goals.""",
agent_name="Investment Strategist",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am a cryptocurrency and blockchain expert who provides insights on
digital assets, DeFi protocols, and emerging blockchain technologies. I help
investors understand the risks and opportunities in the crypto market.""",
agent_name="Cryptocurrency Expert",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am a real estate investment advisor who helps investors evaluate
real estate opportunities, understand market trends, and build real estate
portfolios for long-term wealth building.""",
agent_name="Real Estate Investment Advisor",
model_name="gpt-4o",
),
]
# Business & Corporate Tree - Business Finance & Strategy
business_agents = [
TreeAgent(
system_prompt="""I am a business financial advisor specializing in corporate finance,
business valuation, mergers and acquisitions, and strategic financial planning
for small to medium-sized businesses.""",
agent_name="Business Financial Advisor",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am a Delaware incorporation specialist with deep knowledge of
corporate formation, tax benefits, legal requirements, and ongoing compliance
for businesses incorporating in Delaware.""",
agent_name="Delaware Incorporation Specialist",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am a startup funding advisor who helps entrepreneurs secure
funding through venture capital, angel investors, crowdfunding, and other
financing options. I provide guidance on business plans, pitch decks, and
investor relations.""",
agent_name="Startup Funding Advisor",
model_name="gpt-4o",
),
TreeAgent(
system_prompt="""I am a business tax strategist who helps businesses optimize their
tax position through strategic planning, entity structure optimization, and
compliance with federal, state, and local tax laws.""",
agent_name="Business Tax Strategist",
model_name="gpt-4o",
),
]
# Create trees
financial_tree = Tree(
"Personal Finance & Planning", financial_agents
)
investment_tree = Tree("Investment & Trading", investment_agents)
business_tree = Tree(
"Business & Corporate Finance", business_agents
)
# Create the forest
forest = ForestSwarm(
name="Comprehensive Financial Services Forest",
description="A multi-agent system providing expert financial advice across personal, investment, and business domains",
trees=[financial_tree, investment_tree, business_tree],
)
print(
f"✅ Created forest with {len(forest.trees)} trees and {sum(len(tree.agents) for tree in forest.trees)} agents"
)
return forest
def demonstrate_agent_selection(forest):
"""Demonstrate how the forest selects the most relevant agent for different types of questions."""
print("\n🎯 Demonstrating Agent Selection...")
# Test questions covering different domains
test_questions = [
{
"question": "How much should I save monthly for retirement if I want to retire at 65?",
"expected_agent": "Retirement Planning Expert",
"category": "Personal Finance",
},
{
"question": "What are the best investment strategies for a 401k retirement plan?",
"expected_agent": "Investment Strategist",
"category": "Investment",
},
{
"question": "Our company is incorporated in Delaware, how do we do our taxes for free?",
"expected_agent": "Delaware Incorporation Specialist",
"category": "Business",
},
{
"question": "Which tech stocks should I consider for my investment portfolio?",
"expected_agent": "Stock Market Analyst",
"category": "Investment",
},
{
"question": "How can I consolidate my credit card debt and create a repayment plan?",
"expected_agent": "Debt Management Counselor",
"category": "Personal Finance",
},
{
"question": "What are the benefits of incorporating in Delaware vs. other states?",
"expected_agent": "Delaware Incorporation Specialist",
"category": "Business",
},
]
for i, test_case in enumerate(test_questions, 1):
print(f"\n--- Test Case {i}: {test_case['category']} ---")
print(f"Question: {test_case['question']}")
print(f"Expected Agent: {test_case['expected_agent']}")
try:
# Find the relevant tree
relevant_tree = forest.find_relevant_tree(
test_case["question"]
)
if relevant_tree:
print(f"Selected Tree: {relevant_tree.tree_name}")
# Find the relevant agent
relevant_agent = relevant_tree.find_relevant_agent(
test_case["question"]
)
if relevant_agent:
print(
f"Selected Agent: {relevant_agent.agent_name}"
)
# Check if the selection matches expectation
if (
test_case["expected_agent"]
in relevant_agent.agent_name
):
print(
"✅ Agent selection matches expectation!"
)
else:
print(
"⚠️ Agent selection differs from expectation"
)
print(
f" Expected: {test_case['expected_agent']}"
)
print(
f" Selected: {relevant_agent.agent_name}"
)
else:
print("❌ No relevant agent found")
else:
print("❌ No relevant tree found")
except Exception as e:
print(f"❌ Error during agent selection: {e}")
def run_sample_tasks(forest):
"""Run sample tasks to demonstrate the forest's capabilities."""
print("\n🚀 Running Sample Tasks...")
sample_tasks = [
"What are the key benefits of incorporating a business in Delaware?",
"How should I allocate my investment portfolio if I'm 30 years old?",
"What's the best way to start saving for retirement in my 20s?",
]
for i, task in enumerate(sample_tasks, 1):
print(f"\n--- Task {i} ---")
print(f"Task: {task}")
try:
result = forest.run(task)
print(
f"Result: {result[:200]}..."
if len(str(result)) > 200
else f"Result: {result}"
)
except Exception as e:
print(f"❌ Task execution failed: {e}")
def main():
"""Main function to demonstrate ForestSwarm functionality."""
print("🌲 ForestSwarm Demonstration")
print("=" * 60)
try:
# Create the forest
forest = create_financial_services_forest()
# Demonstrate agent selection
demonstrate_agent_selection(forest)
# Run sample tasks
run_sample_tasks(forest)
print(
"\n🎉 ForestSwarm demonstration completed successfully!"
)
except Exception as e:
print(f"\n❌ Error during demonstration: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

@ -0,0 +1,43 @@
from swarms.structs.tree_swarm import TreeAgent, Tree, ForestSwarm
# Create agents with varying system prompts and dynamically generated distances/keywords
agents_tree1 = [
TreeAgent(
system_prompt="Stock Analysis Agent",
agent_name="Stock Analysis Agent",
),
TreeAgent(
system_prompt="Financial Planning Agent",
agent_name="Financial Planning Agent",
),
TreeAgent(
agent_name="Retirement Strategy Agent",
system_prompt="Retirement Strategy Agent",
),
]
agents_tree2 = [
TreeAgent(
system_prompt="Tax Filing Agent",
agent_name="Tax Filing Agent",
),
TreeAgent(
system_prompt="Investment Strategy Agent",
agent_name="Investment Strategy Agent",
),
TreeAgent(
system_prompt="ROTH IRA Agent", agent_name="ROTH IRA Agent"
),
]
# Create trees
tree1 = Tree(tree_name="Financial Tree", agents=agents_tree1)
tree2 = Tree(tree_name="Investment Tree", agents=agents_tree2)
# Create the ForestSwarm
multi_agent_structure = ForestSwarm(trees=[tree1, tree2])
# Run a task
task = "Our company is incorporated in delaware, how do we do our taxes for free?"
output = multi_agent_structure.run(task)
print(output)

@ -5,9 +5,6 @@ import numpy as np
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential
from swarms.utils.auto_download_check_packages import (
auto_check_and_download_package,
)
from swarms.utils.loguru_logger import initialize_logger
logger = initialize_logger(log_folder="swarm_matcher")
@ -48,18 +45,16 @@ class SwarmMatcher:
try:
import torch
except ImportError:
auto_check_and_download_package(
"torch", package_manager="pip", upgrade=True
raise ImportError(
"torch package not found. Pip install torch."
)
import torch
try:
import transformers
except ImportError:
auto_check_and_download_package(
"transformers", package_manager="pip", upgrade=True
raise ImportError(
"transformers package not found. Pip install transformers."
)
import transformers
self.torch = torch
try:

@ -3,51 +3,89 @@ from collections import Counter
from datetime import datetime
from typing import Any, List, Optional
import numpy as np
from litellm import embedding
from pydantic import BaseModel, Field
from swarms.structs.agent import Agent
from swarms.utils.loguru_logger import initialize_logger
from swarms.utils.auto_download_check_packages import (
auto_check_and_download_package,
)
from swarms.structs.conversation import Conversation
from swarms.utils.loguru_logger import initialize_logger
logger = initialize_logger(log_folder="tree_swarm")
# Pydantic Models for Logging
class AgentLogInput(BaseModel):
"""
Input log model for tracking agent task execution.
Attributes:
log_id (str): Unique identifier for the log entry
agent_name (str): Name of the agent executing the task
task (str): Description of the task being executed
timestamp (datetime): When the task was started
"""
log_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), alias="id"
)
agent_name: str
task: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
timestamp: datetime = Field(default_factory=lambda: datetime.now(datetime.UTC))
class AgentLogOutput(BaseModel):
"""
Output log model for tracking agent task completion.
Attributes:
log_id (str): Unique identifier for the log entry
agent_name (str): Name of the agent that completed the task
result (Any): Result/output from the task execution
timestamp (datetime): When the task was completed
"""
log_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), alias="id"
)
agent_name: str
result: Any
timestamp: datetime = Field(default_factory=datetime.utcnow)
timestamp: datetime = Field(default_factory=lambda: datetime.now(datetime.UTC))
class TreeLog(BaseModel):
"""
Tree execution log model for tracking tree-level operations.
Attributes:
log_id (str): Unique identifier for the log entry
tree_name (str): Name of the tree that executed the task
task (str): Description of the task that was executed
selected_agent (str): Name of the agent selected for the task
timestamp (datetime): When the task was executed
result (Any): Result/output from the task execution
"""
log_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), alias="id"
)
tree_name: str
task: str
selected_agent: str
timestamp: datetime = Field(default_factory=datetime.utcnow)
timestamp: datetime = Field(default_factory=lambda: datetime.now(datetime.UTC))
result: Any
def extract_keywords(prompt: str, top_n: int = 5) -> List[str]:
"""
A simplified keyword extraction function using basic word splitting instead of NLTK tokenization.
Args:
prompt (str): The text prompt to extract keywords from
top_n (int): Maximum number of keywords to return
Returns:
List[str]: List of extracted keywords
"""
words = prompt.lower().split()
filtered_words = [word for word in words if word.isalnum()]
@ -55,6 +93,30 @@ def extract_keywords(prompt: str, top_n: int = 5) -> List[str]:
return [word for word, _ in word_counts.most_common(top_n)]
def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
"""
Calculate cosine similarity between two vectors.
Args:
vec1 (List[float]): First vector
vec2 (List[float]): Second vector
Returns:
float: Cosine similarity score between 0 and 1
"""
vec1 = np.array(vec1)
vec2 = np.array(vec2)
dot_product = np.dot(vec1, vec2)
norm1 = np.linalg.norm(vec1)
norm2 = np.linalg.norm(vec2)
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
class TreeAgent(Agent):
"""
A specialized Agent class that contains information about the system prompt's
@ -68,9 +130,23 @@ class TreeAgent(Agent):
system_prompt: str = None,
model_name: str = "gpt-4o",
agent_name: Optional[str] = None,
embedding_model_name: str = "text-embedding-ada-002",
*args,
**kwargs,
):
"""
Initialize a TreeAgent with litellm embedding capabilities.
Args:
name (str): Name of the agent
description (str): Description of the agent
system_prompt (str): System prompt for the agent
model_name (str): Name of the language model to use
agent_name (Optional[str]): Alternative name for the agent
embedding_model_name (str): Name of the embedding model to use
*args: Additional positional arguments
**kwargs: Additional keyword arguments
"""
agent_name = agent_name
super().__init__(
name=name,
@ -81,33 +157,64 @@ class TreeAgent(Agent):
*args,
**kwargs,
)
self.embedding_model_name = embedding_model_name
try:
import sentence_transformers
except ImportError:
auto_check_and_download_package(
"sentence-transformers", package_manager="pip"
# Generate system prompt embedding using litellm
if system_prompt:
self.system_prompt_embedding = self._get_embedding(
system_prompt
)
import sentence_transformers
self.sentence_transformers = sentence_transformers
# Pretrained model for embeddings
self.embedding_model = (
sentence_transformers.SentenceTransformer(
"all-MiniLM-L6-v2"
)
)
self.system_prompt_embedding = self.embedding_model.encode(
system_prompt, convert_to_tensor=True
)
else:
self.system_prompt_embedding = None
# Automatically extract keywords from system prompt
self.relevant_keywords = extract_keywords(system_prompt)
self.relevant_keywords = (
extract_keywords(system_prompt) if system_prompt else []
)
# Distance is now calculated based on similarity between agents' prompts
self.distance = None # Will be dynamically calculated later
def _get_embedding(self, text: str) -> List[float]:
"""
Get embedding for a given text using litellm.
Args:
text (str): Text to embed
Returns:
List[float]: Embedding vector
"""
try:
response = embedding(
model=self.embedding_model_name, input=[text]
)
logger.info(f"Embedding type: {type(response)}")
# print(response)
# Handle different response structures from litellm
if hasattr(response, "data") and response.data:
if hasattr(response.data[0], "embedding"):
return response.data[0].embedding
elif (
isinstance(response.data[0], dict)
and "embedding" in response.data[0]
):
return response.data[0]["embedding"]
else:
logger.error(
f"Unexpected response structure: {response.data[0]}"
)
return [0.0] * 1536
else:
logger.error(
f"Unexpected response structure: {response}"
)
return [0.0] * 1536
except Exception as e:
logger.error(f"Error getting embedding: {e}")
# Return a zero vector as fallback
return [0.0] * 1536 # Default OpenAI embedding dimension
def calculate_distance(self, other_agent: "TreeAgent") -> float:
"""
Calculate the distance between this agent and another agent using embedding similarity.
@ -118,10 +225,16 @@ class TreeAgent(Agent):
Returns:
float: Distance score between 0 and 1, with 0 being close and 1 being far.
"""
similarity = self.sentence_transformers.util.pytorch_cos_sim(
if (
not self.system_prompt_embedding
or not other_agent.system_prompt_embedding
):
return 1.0 # Maximum distance if embeddings are not available
similarity = cosine_similarity(
self.system_prompt_embedding,
other_agent.system_prompt_embedding,
).item()
)
distance = (
1 - similarity
) # Closer agents have a smaller distance
@ -130,6 +243,18 @@ class TreeAgent(Agent):
def run_task(
self, task: str, img: str = None, *args, **kwargs
) -> Any:
"""
Execute a task and log the input and output.
Args:
task (str): The task to execute
img (str): Optional image input
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
Any: Result of the task execution
"""
input_log = AgentLogInput(
agent_name=self.agent_name,
task=task,
@ -157,7 +282,7 @@ class TreeAgent(Agent):
Checks if the agent is relevant for the given task using both keyword matching and embedding similarity.
Args:
task (str): The task to be executed.
task (str): The task or query for which we need to find a relevant agent.
threshold (float): The cosine similarity threshold for embedding-based matching.
Returns:
@ -170,14 +295,10 @@ class TreeAgent(Agent):
)
# Perform embedding similarity match if keyword match is not found
if not keyword_match:
task_embedding = self.embedding_model.encode(
task, convert_to_tensor=True
)
similarity = (
self.sentence_transformers.util.pytorch_cos_sim(
self.system_prompt_embedding, task_embedding
).item()
if not keyword_match and self.system_prompt_embedding:
task_embedding = self._get_embedding(task)
similarity = cosine_similarity(
self.system_prompt_embedding, task_embedding
)
logger.info(
f"Semantic similarity between task and {self.agent_name}: {similarity:.2f}"
@ -203,6 +324,9 @@ class Tree:
def calculate_agent_distances(self):
"""
Automatically calculate and assign distances between agents in the tree based on prompt similarity.
This method computes the semantic distance between consecutive agents using their system prompt
embeddings and sorts the agents by distance for optimal task routing.
"""
logger.info(
f"Calculating distances between agents in tree '{self.tree_name}'"
@ -271,10 +395,16 @@ class ForestSwarm:
**kwargs,
):
"""
Initializes the structure with multiple trees of agents.
Initialize a ForestSwarm with multiple trees of agents.
Args:
trees (List[Tree]): A list of trees in the structure.
name (str): Name of the forest swarm
description (str): Description of the forest swarm
trees (List[Tree]): A list of trees in the structure
shared_memory (Any): Shared memory object for inter-tree communication
rules (str): Rules governing the forest swarm behavior
*args: Additional positional arguments
**kwargs: Additional keyword arguments
"""
self.name = name
self.description = description
@ -290,13 +420,13 @@ class ForestSwarm:
def find_relevant_tree(self, task: str) -> Optional[Tree]:
"""
Finds the most relevant tree based on the given task.
Find the most relevant tree based on the given task.
Args:
task (str): The task or query for which we need to find a relevant tree.
task (str): The task or query for which we need to find a relevant tree
Returns:
Optional[Tree]: The most relevant tree, or None if no match found.
Optional[Tree]: The most relevant tree, or None if no match found
"""
logger.info(
f"Searching for the most relevant tree for task: {task}"
@ -309,13 +439,16 @@ class ForestSwarm:
def run(self, task: str, img: str = None, *args, **kwargs) -> Any:
"""
Executes the given task by finding the most relevant tree and agent within that tree.
Execute the given task by finding the most relevant tree and agent within that tree.
Args:
task (str): The task or query to be executed.
task (str): The task or query to be executed
img (str): Optional image input for vision-enabled tasks
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
Any: The result of the task after it has been processed by the agents.
Any: The result of the task after it has been processed by the agents
"""
try:
logger.info(

@ -1,6 +1,7 @@
from swarms.utils.auto_download_check_packages import (
auto_check_and_download_package,
)
from typing import Any
try:
@ -22,7 +23,7 @@ except ImportError:
class StringStoppingCriteria(transformers.StoppingCriteria):
def __init__(
self, tokenizer: transformers.PreTrainedTokenizer, prompt_length: int # type: ignore
self, tokenizer: Any, prompt_length: int # type: ignore
):
self.tokenizer = tokenizer
self.prompt_length = prompt_length
@ -48,7 +49,7 @@ class StringStoppingCriteria(transformers.StoppingCriteria):
class NumberStoppingCriteria(transformers.StoppingCriteria):
def __init__(
self,
tokenizer: transformers.PreTrainedTokenizer, # type: ignore
tokenizer: Any, # type: ignore
prompt_length: int,
precision: int = 3,
):

@ -0,0 +1,653 @@
import sys
from swarms.structs.tree_swarm import (
TreeAgent,
Tree,
ForestSwarm,
AgentLogInput,
AgentLogOutput,
TreeLog,
extract_keywords,
cosine_similarity,
)
# Test Results Tracking
test_results = {"passed": 0, "failed": 0, "total": 0}
def assert_equal(actual, expected, test_name):
"""Assert that actual equals expected, track test results."""
test_results["total"] += 1
if actual == expected:
test_results["passed"] += 1
print(f"✅ PASS: {test_name}")
return True
else:
test_results["failed"] += 1
print(f"❌ FAIL: {test_name}")
print(f" Expected: {expected}")
print(f" Actual: {actual}")
return False
def assert_true(condition, test_name):
"""Assert that condition is True, track test results."""
test_results["total"] += 1
if condition:
test_results["passed"] += 1
print(f"✅ PASS: {test_name}")
return True
else:
test_results["failed"] += 1
print(f"❌ FAIL: {test_name}")
print(" Condition was False")
return False
def assert_false(condition, test_name):
"""Assert that condition is False, track test results."""
test_results["total"] += 1
if not condition:
test_results["passed"] += 1
print(f"✅ PASS: {test_name}")
return True
else:
test_results["failed"] += 1
print(f"❌ FAIL: {test_name}")
print(" Condition was True")
return False
def assert_is_instance(obj, expected_type, test_name):
"""Assert that obj is an instance of expected_type, track test results."""
test_results["total"] += 1
if isinstance(obj, expected_type):
test_results["passed"] += 1
print(f"✅ PASS: {test_name}")
return True
else:
test_results["failed"] += 1
print(f"❌ FAIL: {test_name}")
print(f" Expected type: {expected_type}")
print(f" Actual type: {type(obj)}")
return False
def assert_not_none(obj, test_name):
"""Assert that obj is not None, track test results."""
test_results["total"] += 1
if obj is not None:
test_results["passed"] += 1
print(f"✅ PASS: {test_name}")
return True
else:
test_results["failed"] += 1
print(f"❌ FAIL: {test_name}")
print(" Object was None")
return False
# Test Data
SAMPLE_SYSTEM_PROMPTS = {
"financial_advisor": "I am a financial advisor specializing in investment planning, retirement strategies, and tax optimization for individuals and businesses.",
"tax_expert": "I am a tax expert with deep knowledge of corporate taxation, Delaware incorporation benefits, and free tax filing options for businesses.",
"stock_analyst": "I am a stock market analyst who provides insights on market trends, stock recommendations, and portfolio optimization strategies.",
"retirement_planner": "I am a retirement planning specialist who helps individuals and businesses create comprehensive retirement strategies and investment plans.",
}
SAMPLE_TASKS = {
"tax_question": "Our company is incorporated in Delaware, how do we do our taxes for free?",
"investment_question": "What are the best investment strategies for a 401k retirement plan?",
"stock_question": "Which tech stocks should I consider for my investment portfolio?",
"retirement_question": "How much should I save monthly for retirement if I want to retire at 65?",
}
# Test Functions
def test_extract_keywords():
"""Test the extract_keywords function."""
print("\n🧪 Testing extract_keywords function...")
# Test basic keyword extraction
text = (
"financial advisor investment planning retirement strategies"
)
keywords = extract_keywords(text, top_n=3)
assert_equal(
len(keywords),
3,
"extract_keywords returns correct number of keywords",
)
assert_true(
"financial" in keywords,
"extract_keywords includes 'financial'",
)
assert_true(
"investment" in keywords,
"extract_keywords includes 'investment'",
)
# Test with punctuation and case
text = "Tax Expert! Corporate Taxation, Delaware Incorporation."
keywords = extract_keywords(text, top_n=5)
assert_true(
"tax" in keywords,
"extract_keywords handles punctuation and case",
)
assert_true(
"corporate" in keywords,
"extract_keywords handles punctuation and case",
)
# Test empty string
keywords = extract_keywords("", top_n=3)
assert_equal(
len(keywords), 0, "extract_keywords handles empty string"
)
def test_cosine_similarity():
"""Test the cosine_similarity function."""
print("\n🧪 Testing cosine_similarity function...")
# Test identical vectors
vec1 = [1.0, 0.0, 0.0]
vec2 = [1.0, 0.0, 0.0]
similarity = cosine_similarity(vec1, vec2)
assert_equal(
similarity,
1.0,
"cosine_similarity returns 1.0 for identical vectors",
)
# Test orthogonal vectors
vec1 = [1.0, 0.0, 0.0]
vec2 = [0.0, 1.0, 0.0]
similarity = cosine_similarity(vec1, vec2)
assert_equal(
similarity,
0.0,
"cosine_similarity returns 0.0 for orthogonal vectors",
)
# Test opposite vectors
vec1 = [1.0, 0.0, 0.0]
vec2 = [-1.0, 0.0, 0.0]
similarity = cosine_similarity(vec1, vec2)
assert_equal(
similarity,
-1.0,
"cosine_similarity returns -1.0 for opposite vectors",
)
# Test zero vectors
vec1 = [0.0, 0.0, 0.0]
vec2 = [1.0, 0.0, 0.0]
similarity = cosine_similarity(vec1, vec2)
assert_equal(
similarity, 0.0, "cosine_similarity handles zero vectors"
)
def test_agent_log_models():
"""Test the Pydantic log models."""
print("\n🧪 Testing Pydantic log models...")
# Test AgentLogInput
log_input = AgentLogInput(
agent_name="test_agent", task="test_task"
)
assert_is_instance(
log_input,
AgentLogInput,
"AgentLogInput creates correct instance",
)
assert_not_none(
log_input.log_id, "AgentLogInput generates log_id"
)
assert_equal(
log_input.agent_name,
"test_agent",
"AgentLogInput stores agent_name",
)
assert_equal(
log_input.task, "test_task", "AgentLogInput stores task"
)
# Test AgentLogOutput
log_output = AgentLogOutput(
agent_name="test_agent", result="test_result"
)
assert_is_instance(
log_output,
AgentLogOutput,
"AgentLogOutput creates correct instance",
)
assert_not_none(
log_output.log_id, "AgentLogOutput generates log_id"
)
assert_equal(
log_output.result,
"test_result",
"AgentLogOutput stores result",
)
# Test TreeLog
tree_log = TreeLog(
tree_name="test_tree",
task="test_task",
selected_agent="test_agent",
result="test_result",
)
assert_is_instance(
tree_log, TreeLog, "TreeLog creates correct instance"
)
assert_not_none(tree_log.log_id, "TreeLog generates log_id")
assert_equal(
tree_log.tree_name, "test_tree", "TreeLog stores tree_name"
)
def test_tree_agent_initialization():
"""Test TreeAgent initialization and basic properties."""
print("\n🧪 Testing TreeAgent initialization...")
# Test basic initialization
agent = TreeAgent(
name="Test Agent",
description="A test agent",
system_prompt=SAMPLE_SYSTEM_PROMPTS["financial_advisor"],
agent_name="financial_advisor",
)
assert_is_instance(
agent, TreeAgent, "TreeAgent creates correct instance"
)
assert_equal(
agent.agent_name,
"financial_advisor",
"TreeAgent stores agent_name",
)
assert_equal(
agent.embedding_model_name,
"text-embedding-ada-002",
"TreeAgent has default embedding model",
)
assert_true(
len(agent.relevant_keywords) > 0,
"TreeAgent extracts keywords from system prompt",
)
assert_not_none(
agent.system_prompt_embedding,
"TreeAgent generates system prompt embedding",
)
# Test with custom embedding model
agent_custom = TreeAgent(
system_prompt="Test prompt",
embedding_model_name="custom-model",
)
assert_equal(
agent_custom.embedding_model_name,
"custom-model",
"TreeAgent accepts custom embedding model",
)
def test_tree_agent_distance_calculation():
"""Test TreeAgent distance calculation between agents."""
print("\n🧪 Testing TreeAgent distance calculation...")
agent1 = TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["financial_advisor"],
agent_name="financial_advisor",
)
agent2 = TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["tax_expert"],
agent_name="tax_expert",
)
agent3 = TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["stock_analyst"],
agent_name="stock_analyst",
)
# Test distance calculation
distance1 = agent1.calculate_distance(agent2)
distance2 = agent1.calculate_distance(agent3)
assert_true(
0.0 <= distance1 <= 1.0, "Distance is between 0 and 1"
)
assert_true(
0.0 <= distance2 <= 1.0, "Distance is between 0 and 1"
)
assert_true(isinstance(distance1, float), "Distance is a float")
# Test that identical agents have distance 0
identical_agent = TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["financial_advisor"],
agent_name="identical_advisor",
)
distance_identical = agent1.calculate_distance(identical_agent)
assert_true(
distance_identical < 0.1,
"Identical agents have very small distance",
)
def test_tree_agent_task_relevance():
"""Test TreeAgent task relevance checking."""
print("\n🧪 Testing TreeAgent task relevance...")
tax_agent = TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["tax_expert"],
agent_name="tax_expert",
)
# Test keyword matching
tax_task = SAMPLE_TASKS["tax_question"]
is_relevant = tax_agent.is_relevant_for_task(
tax_task, threshold=0.7
)
assert_true(is_relevant, "Tax agent is relevant for tax question")
# Test non-relevant task
stock_task = SAMPLE_TASKS["stock_question"]
is_relevant = tax_agent.is_relevant_for_task(
stock_task, threshold=0.7
)
# This might be True due to semantic similarity, so we just check it's a boolean
assert_true(
isinstance(is_relevant, bool),
"Task relevance returns boolean",
)
def test_tree_initialization():
"""Test Tree initialization and agent organization."""
print("\n🧪 Testing Tree initialization...")
agents = [
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["financial_advisor"],
agent_name="financial_advisor",
),
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["tax_expert"],
agent_name="tax_expert",
),
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["stock_analyst"],
agent_name="stock_analyst",
),
]
tree = Tree("Financial Services Tree", agents)
assert_equal(
tree.tree_name,
"Financial Services Tree",
"Tree stores tree_name",
)
assert_equal(len(tree.agents), 3, "Tree contains all agents")
assert_true(
all(hasattr(agent, "distance") for agent in tree.agents),
"All agents have distance calculated",
)
# Test that agents are sorted by distance
distances = [agent.distance for agent in tree.agents]
assert_true(
distances == sorted(distances),
"Agents are sorted by distance",
)
def test_tree_agent_finding():
"""Test Tree agent finding functionality."""
print("\n🧪 Testing Tree agent finding...")
agents = [
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["financial_advisor"],
agent_name="financial_advisor",
),
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["tax_expert"],
agent_name="tax_expert",
),
]
tree = Tree("Test Tree", agents)
# Test finding relevant agent
tax_task = SAMPLE_TASKS["tax_question"]
relevant_agent = tree.find_relevant_agent(tax_task)
assert_not_none(
relevant_agent, "Tree finds relevant agent for tax task"
)
# Test finding agent for unrelated task
unrelated_task = "How do I cook pasta?"
relevant_agent = tree.find_relevant_agent(unrelated_task)
# This might return None or an agent depending on similarity threshold
assert_true(
relevant_agent is None
or isinstance(relevant_agent, TreeAgent),
"Tree handles unrelated tasks",
)
def test_forest_swarm_initialization():
"""Test ForestSwarm initialization."""
print("\n🧪 Testing ForestSwarm initialization...")
agents_tree1 = [
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["financial_advisor"],
agent_name="financial_advisor",
)
]
agents_tree2 = [
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["tax_expert"],
agent_name="tax_expert",
)
]
tree1 = Tree("Financial Tree", agents_tree1)
tree2 = Tree("Tax Tree", agents_tree2)
forest = ForestSwarm(
name="Test Forest",
description="A test forest",
trees=[tree1, tree2],
)
assert_equal(
forest.name, "Test Forest", "ForestSwarm stores name"
)
assert_equal(
forest.description,
"A test forest",
"ForestSwarm stores description",
)
assert_equal(
len(forest.trees), 2, "ForestSwarm contains all trees"
)
assert_not_none(
forest.conversation, "ForestSwarm creates conversation object"
)
def test_forest_swarm_tree_finding():
"""Test ForestSwarm tree finding functionality."""
print("\n🧪 Testing ForestSwarm tree finding...")
agents_tree1 = [
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["financial_advisor"],
agent_name="financial_advisor",
)
]
agents_tree2 = [
TreeAgent(
system_prompt=SAMPLE_SYSTEM_PROMPTS["tax_expert"],
agent_name="tax_expert",
)
]
tree1 = Tree("Financial Tree", agents_tree1)
tree2 = Tree("Tax Tree", agents_tree2)
forest = ForestSwarm(trees=[tree1, tree2])
# Test finding relevant tree for tax question
tax_task = SAMPLE_TASKS["tax_question"]
relevant_tree = forest.find_relevant_tree(tax_task)
assert_not_none(
relevant_tree, "ForestSwarm finds relevant tree for tax task"
)
# Test finding relevant tree for financial question
financial_task = SAMPLE_TASKS["investment_question"]
relevant_tree = forest.find_relevant_tree(financial_task)
assert_not_none(
relevant_tree,
"ForestSwarm finds relevant tree for financial task",
)
def test_forest_swarm_execution():
"""Test ForestSwarm task execution."""
print("\n🧪 Testing ForestSwarm task execution...")
# Create a simple forest with one tree and one agent
agent = TreeAgent(
system_prompt="I am a helpful assistant that can answer questions about Delaware incorporation and taxes.",
agent_name="delaware_expert",
)
tree = Tree("Delaware Tree", [agent])
forest = ForestSwarm(trees=[tree])
# Test task execution
task = "What are the benefits of incorporating in Delaware?"
try:
result = forest.run(task)
assert_not_none(
result, "ForestSwarm returns result from task execution"
)
assert_true(isinstance(result, str), "Result is a string")
except Exception as e:
# If execution fails due to external dependencies, that's okay for unit tests
print(
f"⚠️ Task execution failed (expected in unit test environment): {e}"
)
def test_edge_cases():
"""Test edge cases and error handling."""
print("\n🧪 Testing edge cases and error handling...")
# Test TreeAgent with None system prompt
agent_no_prompt = TreeAgent(
system_prompt=None, agent_name="no_prompt_agent"
)
assert_equal(
len(agent_no_prompt.relevant_keywords),
0,
"Agent with None prompt has empty keywords",
)
assert_true(
agent_no_prompt.system_prompt_embedding is None,
"Agent with None prompt has None embedding",
)
# Test Tree with empty agents list
empty_tree = Tree("Empty Tree", [])
assert_equal(
len(empty_tree.agents), 0, "Empty tree has no agents"
)
# Test ForestSwarm with empty trees list
empty_forest = ForestSwarm(trees=[])
assert_equal(
len(empty_forest.trees), 0, "Empty forest has no trees"
)
# Test cosine_similarity with empty vectors
empty_vec = []
vec = [1.0, 0.0, 0.0]
similarity = cosine_similarity(empty_vec, vec)
assert_equal(
similarity, 0.0, "cosine_similarity handles empty vectors"
)
def run_all_tests():
"""Run all unit tests and display results."""
print("🚀 Starting ForestSwarm Unit Tests...")
print("=" * 60)
# Run all test functions
test_functions = [
test_extract_keywords,
test_cosine_similarity,
test_agent_log_models,
test_tree_agent_initialization,
test_tree_agent_distance_calculation,
test_tree_agent_task_relevance,
test_tree_initialization,
test_tree_agent_finding,
test_forest_swarm_initialization,
test_forest_swarm_tree_finding,
test_forest_swarm_execution,
test_edge_cases,
]
for test_func in test_functions:
try:
test_func()
except Exception as e:
test_results["total"] += 1
test_results["failed"] += 1
print(f"❌ ERROR: {test_func.__name__} - {e}")
# Display results
print("\n" + "=" * 60)
print("📊 TEST RESULTS SUMMARY")
print("=" * 60)
print(f"Total Tests: {test_results['total']}")
print(f"Passed: {test_results['passed']}")
print(f"Failed: {test_results['failed']}")
success_rate = (
(test_results["passed"] / test_results["total"]) * 100
if test_results["total"] > 0
else 0
)
print(f"Success Rate: {success_rate:.1f}%")
if test_results["failed"] == 0:
print(
"\n🎉 All tests passed! ForestSwarm is working correctly."
)
else:
print(
f"\n⚠️ {test_results['failed']} test(s) failed. Please review the failures above."
)
return test_results["failed"] == 0
if __name__ == "__main__":
success = run_all_tests()
sys.exit(0 if success else 1)
Loading…
Cancel
Save