commit
32e9ec48b8
@ -0,0 +1,56 @@
|
||||
from swarms.structs.aop import AOP
|
||||
|
||||
aop = AOP(
|
||||
name="example_system",
|
||||
description="A simple example of tools, agents, and swarms",
|
||||
url="http://localhost:8000/sse",
|
||||
)
|
||||
|
||||
# print(
|
||||
# aop.call_tool_or_agent(
|
||||
# url="http://localhost:8000/sse",
|
||||
# name="calculator",
|
||||
# arguments={"operation": "add", "x": 1, "y": 2},
|
||||
# output_type="list",
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
# print(
|
||||
# aop.call_tool_or_agent_batched(
|
||||
# url="http://localhost:8000/sse",
|
||||
# names=["calculator", "calculator"],
|
||||
# arguments=[{"operation": "add", "x": 1, "y": 2}, {"operation": "multiply", "x": 3, "y": 4}],
|
||||
# output_type="list",
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
# print(
|
||||
# aop.call_tool_or_agent_concurrently(
|
||||
# url="http://localhost:8000/sse",
|
||||
# names=["calculator", "calculator"],
|
||||
# arguments=[{"operation": "add", "x": 1, "y": 2}, {"operation": "multiply", "x": 3, "y": 4}],
|
||||
# output_type="list",
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
# print(aop.list_agents())
|
||||
|
||||
# print(aop.list_tools())
|
||||
|
||||
# print(aop.list_swarms())
|
||||
|
||||
# print(aop.list_all(url="http://localhost:8000/sse"))
|
||||
|
||||
# print(any_to_str(aop.list_all()))
|
||||
|
||||
# print(aop.search_if_tool_exists(name="calculator"))
|
||||
|
||||
# out = aop.list_tool_parameters(name="calculator")
|
||||
# print(type(out))
|
||||
# print(out)
|
||||
|
||||
print(aop.list_agents())
|
||||
print(aop.list_swarms())
|
@ -0,0 +1,66 @@
|
||||
from swarms.structs.aop import AOP
|
||||
|
||||
# Initialize the AOP instance
|
||||
aop = AOP(
|
||||
name="example_system",
|
||||
description="A simple example of tools, agents, and swarms",
|
||||
)
|
||||
|
||||
|
||||
# Define a simple tool
|
||||
@aop.tool(name="calculator", description="A simple calculator tool")
|
||||
async def calculator(operation: str, x: float, y: float):
|
||||
"""
|
||||
Performs basic arithmetic operations
|
||||
"""
|
||||
if operation == "add":
|
||||
return x + y
|
||||
elif operation == "multiply":
|
||||
return x * y
|
||||
else:
|
||||
raise ValueError("Unsupported operation")
|
||||
|
||||
|
||||
# Define an agent that uses the calculator tool
|
||||
@aop.agent(
|
||||
name="math_agent",
|
||||
description="Agent that performs mathematical operations",
|
||||
)
|
||||
async def math_agent(operation: str, numbers: list[float]):
|
||||
"""
|
||||
Agent that chains multiple calculations together
|
||||
"""
|
||||
result = numbers[0]
|
||||
for num in numbers[1:]:
|
||||
# Using the calculator tool within the agent
|
||||
result = await aop.call_tool_or_agent(
|
||||
"calculator",
|
||||
{"operation": operation, "x": result, "y": num},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# Define a swarm that coordinates multiple agents
|
||||
@aop.swarm(
|
||||
name="math_swarm",
|
||||
description="Swarm that coordinates mathematical operations",
|
||||
)
|
||||
async def math_swarm(numbers: list[float]):
|
||||
"""
|
||||
Swarm that performs multiple operations on a set of numbers
|
||||
"""
|
||||
# Perform addition and multiplication in parallel
|
||||
results = await aop.call_tool_or_agent_concurrently(
|
||||
names=["math_agent", "math_agent"],
|
||||
arguments=[
|
||||
{"operation": "add", "numbers": numbers},
|
||||
{"operation": "multiply", "numbers": numbers},
|
||||
],
|
||||
)
|
||||
|
||||
return {"sum": results[0], "product": results[1]}
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
aop.run_sse()
|
@ -0,0 +1,291 @@
|
||||
import concurrent.futures
|
||||
from typing import Dict, Optional
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from swarms import Agent
|
||||
|
||||
import replicate
|
||||
|
||||
from swarms.utils.str_to_dict import str_to_dict
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def generate_key(prefix: str = "run") -> str:
|
||||
"""
|
||||
Generates an API key similar to OpenAI's format (sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX).
|
||||
|
||||
Args:
|
||||
prefix (str): The prefix for the API key. Defaults to "sk".
|
||||
|
||||
Returns:
|
||||
str: An API key string in format: prefix-<48 random characters>
|
||||
"""
|
||||
# Create random string of letters and numbers
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_part = "".join(secrets.choice(alphabet) for _ in range(28))
|
||||
return f"{prefix}-{random_part}"
|
||||
|
||||
|
||||
def _generate_media(
|
||||
prompt: str = None, modalities: list = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Generate media content (images or videos) based on text prompts using AI models.
|
||||
|
||||
Args:
|
||||
prompt (str): Text description of the content to be generated
|
||||
modalities (list): List of media types to generate (e.g., ["image", "video"])
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Dictionary containing file paths of generated media
|
||||
"""
|
||||
if not prompt or not modalities:
|
||||
raise ValueError("Prompt and modalities must be provided")
|
||||
|
||||
input = {"prompt": prompt}
|
||||
results = {}
|
||||
|
||||
def _generate_image(input: Dict) -> str:
|
||||
"""Generate an image and return the file path."""
|
||||
output = replicate.run(
|
||||
"black-forest-labs/flux-dev", input=input
|
||||
)
|
||||
file_paths = []
|
||||
|
||||
for index, item in enumerate(output):
|
||||
unique_id = str(uuid.uuid4())
|
||||
artifact = item.read()
|
||||
file_path = f"output_{unique_id}_{index}.webp"
|
||||
|
||||
with open(file_path, "wb") as file:
|
||||
file.write(artifact)
|
||||
|
||||
file_paths.append(file_path)
|
||||
|
||||
return file_paths
|
||||
|
||||
def _generate_video(input: Dict) -> str:
|
||||
"""Generate a video and return the file path."""
|
||||
output = replicate.run("luma/ray", input=input)
|
||||
unique_id = str(uuid.uuid4())
|
||||
artifact = output.read()
|
||||
file_path = f"output_{unique_id}.mp4"
|
||||
|
||||
with open(file_path, "wb") as file:
|
||||
file.write(artifact)
|
||||
|
||||
return file_path
|
||||
|
||||
for modality in modalities:
|
||||
if modality == "image":
|
||||
results["images"] = _generate_image(input)
|
||||
elif modality == "video":
|
||||
results["video"] = _generate_video(input)
|
||||
else:
|
||||
raise ValueError(f"Unsupported modality: {modality}")
|
||||
|
||||
print(results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def generate_media(
|
||||
modalities: list,
|
||||
prompt: Optional[str] = None,
|
||||
count: int = 1,
|
||||
) -> Dict:
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=count
|
||||
) as executor:
|
||||
# Create list of identical tasks to run concurrently
|
||||
futures = [
|
||||
executor.submit(
|
||||
_generate_media,
|
||||
prompt=prompt, # Fix: Pass as keyword arguments
|
||||
modalities=modalities,
|
||||
)
|
||||
for _ in range(count)
|
||||
]
|
||||
|
||||
# Wait for all tasks to complete and collect results
|
||||
results = [
|
||||
future.result()
|
||||
for future in concurrent.futures.as_completed(futures)
|
||||
]
|
||||
|
||||
return {"results": results}
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "generate_media",
|
||||
"description": "Generate different types of media content (image, video, or music) based on text prompts using AI models.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"modality": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": ["image", "video", "music"],
|
||||
},
|
||||
"description": "The type of media content to generate",
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Text description of the content to be generated. Specialize it for the modality at hand. For example, if you are generating an image, the prompt should be a description of the image you want to see. If you are generating a video, the prompt should be a description of the video you want to see. If you are generating music, the prompt should be a description of the music you want to hear.",
|
||||
},
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "Number of outputs to generate (1-4)",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"modality",
|
||||
"prompt",
|
||||
"count",
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
MEDIA_GENERATION_SYSTEM_PROMPT = """
|
||||
You are an expert AI Media Generation Assistant, specialized in crafting precise and effective prompts for generating images, videos, and music. Your role is to help users create high-quality media content by understanding their requests and translating them into optimal prompts.
|
||||
|
||||
GENERAL GUIDELINES:
|
||||
- Always analyze the user's request carefully to determine the appropriate modality (image, video, or music)
|
||||
- Maintain a balanced level of detail in prompts - specific enough to capture the desired outcome but not overly verbose
|
||||
- Consider the technical limitations and capabilities of AI generation systems
|
||||
- When unclear, ask for clarification about specific details or preferences
|
||||
|
||||
MODALITY-SPECIFIC GUIDELINES:
|
||||
|
||||
1. IMAGE GENERATION:
|
||||
- Structure prompts with primary subject first, followed by style, mood, and technical specifications
|
||||
- Include relevant art styles when specified (e.g., "digital art", "oil painting", "watercolor", "photorealistic")
|
||||
- Consider composition elements (foreground, background, lighting, perspective)
|
||||
- Use specific adjectives for clarity (instead of "beautiful", specify "vibrant", "ethereal", "gritty", etc.)
|
||||
|
||||
Example image prompts:
|
||||
- "A serene Japanese garden at sunset, with cherry blossoms falling, painted in traditional ukiyo-e style, soft pastel colors"
|
||||
- "Cyberpunk cityscape at night, neon lights reflecting in rain puddles, hyper-realistic digital art style"
|
||||
|
||||
2. VIDEO GENERATION:
|
||||
- Describe the sequence of events clearly
|
||||
- Specify camera movements if relevant (pan, zoom, tracking shot)
|
||||
- Include timing and transitions when necessary
|
||||
- Focus on dynamic elements and motion
|
||||
|
||||
Example video prompts:
|
||||
- "Timelapse of a flower blooming in a garden, close-up shot, soft natural lighting, 10-second duration"
|
||||
- "Drone shot flying through autumn forest, camera slowly rising above the canopy, revealing mountains in the distance"
|
||||
|
||||
3. MUSIC GENERATION:
|
||||
- Specify genre, tempo, and mood
|
||||
- Mention key instruments or sounds
|
||||
- Include emotional qualities and intensity
|
||||
- Reference similar artists or styles if relevant
|
||||
|
||||
Example music prompts:
|
||||
- "Calm ambient electronic music with soft synthesizer pads, gentle piano melodies, 80 BPM, suitable for meditation"
|
||||
- "Upbeat jazz fusion track with prominent bass line, dynamic drums, and horn section, inspired by Weather Report"
|
||||
|
||||
COUNT HANDLING:
|
||||
- When multiple outputs are requested (1-4), maintain consistency while introducing subtle variations
|
||||
- For images: Vary composition or perspective while maintaining style
|
||||
- For videos: Adjust camera angles or timing while keeping the core concept
|
||||
- For music: Modify instrument arrangements or tempo while preserving the genre and mood
|
||||
|
||||
PROMPT OPTIMIZATION PROCESS:
|
||||
1. Identify core requirements from user input
|
||||
2. Determine appropriate modality
|
||||
3. Add necessary style and technical specifications
|
||||
4. Adjust detail level based on complexity
|
||||
5. Consider count and create variations if needed
|
||||
|
||||
EXAMPLES OF HANDLING USER REQUESTS:
|
||||
|
||||
User: "I want a fantasy landscape"
|
||||
Assistant response: {
|
||||
"modality": "image",
|
||||
"prompt": "Majestic fantasy landscape with floating islands, crystal waterfalls, and ancient magical ruins, ethereal lighting, digital art style with rich colors",
|
||||
"count": 1
|
||||
}
|
||||
|
||||
User: "Create 3 variations of a peaceful nature scene"
|
||||
Assistant response: {
|
||||
"modality": "image",
|
||||
"prompt": "Tranquil forest clearing with morning mist, sunbeams filtering through ancient trees, photorealistic style with soft natural lighting",
|
||||
"count": 1
|
||||
}
|
||||
|
||||
IMPORTANT CONSIDERATIONS:
|
||||
- Avoid harmful, unethical, or inappropriate content
|
||||
- Respect copyright and intellectual property guidelines
|
||||
- Maintain consistency with brand guidelines when specified
|
||||
- Consider technical limitations of current AI generation systems
|
||||
|
||||
"""
|
||||
|
||||
# Initialize the agent with the new system prompt
|
||||
agent = Agent(
|
||||
agent_name="Media-Generation-Agent",
|
||||
agent_description="AI Media Generation Assistant",
|
||||
system_prompt=MEDIA_GENERATION_SYSTEM_PROMPT,
|
||||
max_loops=1,
|
||||
tools_list_dictionary=tools,
|
||||
output_type="final",
|
||||
)
|
||||
|
||||
|
||||
def create_agent(task: str):
|
||||
output = str_to_dict(agent.run(task))
|
||||
|
||||
print(output)
|
||||
print(type(output))
|
||||
|
||||
prompt = output["prompt"]
|
||||
count = output["count"]
|
||||
modalities = output["modality"]
|
||||
|
||||
output = generate_media(
|
||||
modalities=modalities,
|
||||
prompt=prompt,
|
||||
count=count,
|
||||
)
|
||||
|
||||
run_id = generate_key()
|
||||
|
||||
total_cost = 0
|
||||
|
||||
for modality in modalities:
|
||||
if modality == "image":
|
||||
total_cost += 0.1
|
||||
elif modality == "video":
|
||||
total_cost += 1
|
||||
|
||||
result = {
|
||||
"id": run_id,
|
||||
"success": True,
|
||||
"prompt": prompt,
|
||||
"count": count,
|
||||
"modality": modalities,
|
||||
"total_cost": total_cost,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
task = "Create 3 super kawaii variations of a magical Chinese mountain garden scene in anime style! 🌸✨ Include adorable elements like: cute koi fish swimming in crystal ponds, fluffy clouds floating around misty peaks, tiny pagodas with twinkling lights, and playful pandas hiding in bamboo groves. Make it extra magical with sparkles and soft pastel colors! Create both a video and an image for each variation. Just 1."
|
||||
output = create_agent(task)
|
||||
print("✨ Yay! Here's your super cute creation! ✨")
|
||||
print(output)
|
@ -1,334 +0,0 @@
|
||||
# ClusterOps API Reference
|
||||
|
||||
ClusterOps is a Python library for managing and executing tasks across CPU and GPU resources in a distributed computing environment. It provides functions for resource discovery, task execution, and performance monitoring.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
|
||||
$ pip3 install clusterops
|
||||
|
||||
```
|
||||
|
||||
## Table of Contents
|
||||
1. [CPU Operations](#cpu-operations)
|
||||
2. [GPU Operations](#gpu-operations)
|
||||
3. [Utility Functions](#utility-functions)
|
||||
4. [Resource Monitoring](#resource-monitoring)
|
||||
|
||||
## CPU Operations
|
||||
|
||||
### `list_available_cpus()`
|
||||
|
||||
Lists all available CPU cores.
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `List[int]` | A list of available CPU core indices. |
|
||||
|
||||
#### Raises
|
||||
| Exception | Description |
|
||||
|-----------|-------------|
|
||||
| `RuntimeError` | If no CPUs are found. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import list_available_cpus
|
||||
|
||||
available_cpus = list_available_cpus()
|
||||
print(f"Available CPU cores: {available_cpus}")
|
||||
```
|
||||
|
||||
### `execute_on_cpu(cpu_id: int, func: Callable, *args: Any, **kwargs: Any) -> Any`
|
||||
|
||||
Executes a callable on a specific CPU.
|
||||
|
||||
#### Parameters
|
||||
| Name | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `cpu_id` | `int` | The CPU core to run the function on. |
|
||||
| `func` | `Callable` | The function to be executed. |
|
||||
| `*args` | `Any` | Arguments for the callable. |
|
||||
| `**kwargs` | `Any` | Keyword arguments for the callable. |
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `Any` | The result of the function execution. |
|
||||
|
||||
#### Raises
|
||||
| Exception | Description |
|
||||
|-----------|-------------|
|
||||
| `ValueError` | If the CPU core specified is invalid. |
|
||||
| `RuntimeError` | If there is an error executing the function on the CPU. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import execute_on_cpu
|
||||
|
||||
def sample_task(n: int) -> int:
|
||||
return n * n
|
||||
|
||||
result = execute_on_cpu(0, sample_task, 10)
|
||||
print(f"Result of sample task on CPU 0: {result}")
|
||||
```
|
||||
|
||||
### `execute_with_cpu_cores(core_count: int, func: Callable, *args: Any, **kwargs: Any) -> Any`
|
||||
|
||||
Executes a callable using a specified number of CPU cores.
|
||||
|
||||
#### Parameters
|
||||
| Name | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `core_count` | `int` | The number of CPU cores to run the function on. |
|
||||
| `func` | `Callable` | The function to be executed. |
|
||||
| `*args` | `Any` | Arguments for the callable. |
|
||||
| `**kwargs` | `Any` | Keyword arguments for the callable. |
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `Any` | The result of the function execution. |
|
||||
|
||||
#### Raises
|
||||
| Exception | Description |
|
||||
|-----------|-------------|
|
||||
| `ValueError` | If the number of CPU cores specified is invalid or exceeds available cores. |
|
||||
| `RuntimeError` | If there is an error executing the function on the specified CPU cores. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import execute_with_cpu_cores
|
||||
|
||||
def parallel_task(n: int) -> int:
|
||||
return sum(range(n))
|
||||
|
||||
result = execute_with_cpu_cores(4, parallel_task, 1000000)
|
||||
print(f"Result of parallel task using 4 CPU cores: {result}")
|
||||
```
|
||||
|
||||
## GPU Operations
|
||||
|
||||
### `list_available_gpus() -> List[str]`
|
||||
|
||||
Lists all available GPUs.
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `List[str]` | A list of available GPU names. |
|
||||
|
||||
#### Raises
|
||||
| Exception | Description |
|
||||
|-----------|-------------|
|
||||
| `RuntimeError` | If no GPUs are found. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import list_available_gpus
|
||||
|
||||
available_gpus = list_available_gpus()
|
||||
print(f"Available GPUs: {available_gpus}")
|
||||
```
|
||||
|
||||
### `select_best_gpu() -> Optional[int]`
|
||||
|
||||
Selects the GPU with the most free memory.
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `Optional[int]` | The GPU ID of the best available GPU, or None if no GPUs are available. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import select_best_gpu
|
||||
|
||||
best_gpu = select_best_gpu()
|
||||
if best_gpu is not None:
|
||||
print(f"Best GPU for execution: GPU {best_gpu}")
|
||||
else:
|
||||
print("No GPUs available")
|
||||
```
|
||||
|
||||
### `execute_on_gpu(gpu_id: int, func: Callable, *args: Any, **kwargs: Any) -> Any`
|
||||
|
||||
Executes a callable on a specific GPU using Ray.
|
||||
|
||||
#### Parameters
|
||||
| Name | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `gpu_id` | `int` | The GPU to run the function on. |
|
||||
| `func` | `Callable` | The function to be executed. |
|
||||
| `*args` | `Any` | Arguments for the callable. |
|
||||
| `**kwargs` | `Any` | Keyword arguments for the callable. |
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `Any` | The result of the function execution. |
|
||||
|
||||
#### Raises
|
||||
| Exception | Description |
|
||||
|-----------|-------------|
|
||||
| `ValueError` | If the GPU index is invalid. |
|
||||
| `RuntimeError` | If there is an error executing the function on the GPU. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import execute_on_gpu
|
||||
|
||||
def gpu_task(n: int) -> int:
|
||||
return n ** 2
|
||||
|
||||
result = execute_on_gpu(0, gpu_task, 10)
|
||||
print(f"Result of GPU task on GPU 0: {result}")
|
||||
```
|
||||
|
||||
### `execute_on_multiple_gpus(gpu_ids: List[int], func: Callable, all_gpus: bool = False, timeout: float = None, *args: Any, **kwargs: Any) -> List[Any]`
|
||||
|
||||
Executes a callable across multiple GPUs using Ray.
|
||||
|
||||
#### Parameters
|
||||
| Name | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `gpu_ids` | `List[int]` | The list of GPU IDs to run the function on. |
|
||||
| `func` | `Callable` | The function to be executed. |
|
||||
| `all_gpus` | `bool` | Whether to use all available GPUs (default: False). |
|
||||
| `timeout` | `float` | Timeout for the execution in seconds (default: None). |
|
||||
| `*args` | `Any` | Arguments for the callable. |
|
||||
| `**kwargs` | `Any` | Keyword arguments for the callable. |
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `List[Any]` | A list of results from the execution on each GPU. |
|
||||
|
||||
#### Raises
|
||||
| Exception | Description |
|
||||
|-----------|-------------|
|
||||
| `ValueError` | If any GPU index is invalid. |
|
||||
| `RuntimeError` | If there is an error executing the function on the GPUs. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import execute_on_multiple_gpus
|
||||
|
||||
def multi_gpu_task(n: int) -> int:
|
||||
return n ** 3
|
||||
|
||||
results = execute_on_multiple_gpus([0, 1], multi_gpu_task, 5)
|
||||
print(f"Results of multi-GPU task: {results}")
|
||||
```
|
||||
|
||||
### `distributed_execute_on_gpus(gpu_ids: List[int], func: Callable, *args: Any, **kwargs: Any) -> List[Any]`
|
||||
|
||||
Executes a callable across multiple GPUs and nodes using Ray's distributed task scheduling.
|
||||
|
||||
#### Parameters
|
||||
| Name | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `gpu_ids` | `List[int]` | The list of GPU IDs across nodes to run the function on. |
|
||||
| `func` | `Callable` | The function to be executed. |
|
||||
| `*args` | `Any` | Arguments for the callable. |
|
||||
| `**kwargs` | `Any` | Keyword arguments for the callable. |
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `List[Any]` | A list of results from the execution on each GPU. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import distributed_execute_on_gpus
|
||||
|
||||
def distributed_task(n: int) -> int:
|
||||
return n ** 4
|
||||
|
||||
results = distributed_execute_on_gpus([0, 1, 2, 3], distributed_task, 3)
|
||||
print(f"Results of distributed GPU task: {results}")
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
### `retry_with_backoff(func: Callable, retries: int = RETRY_COUNT, delay: float = RETRY_DELAY, *args: Any, **kwargs: Any) -> Any`
|
||||
|
||||
Retries a callable function with exponential backoff in case of failure.
|
||||
|
||||
#### Parameters
|
||||
| Name | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `func` | `Callable` | The function to execute with retries. |
|
||||
| `retries` | `int` | Number of retries (default: RETRY_COUNT from env). |
|
||||
| `delay` | `float` | Delay between retries in seconds (default: RETRY_DELAY from env). |
|
||||
| `*args` | `Any` | Arguments for the callable. |
|
||||
| `**kwargs` | `Any` | Keyword arguments for the callable. |
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `Any` | The result of the function execution. |
|
||||
|
||||
#### Raises
|
||||
| Exception | Description |
|
||||
|-----------|-------------|
|
||||
| `Exception` | After all retries fail. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import retry_with_backoff
|
||||
|
||||
def unstable_task():
|
||||
# Simulating an unstable task that might fail
|
||||
import random
|
||||
if random.random() < 0.5:
|
||||
raise Exception("Task failed")
|
||||
return "Task succeeded"
|
||||
|
||||
result = retry_with_backoff(unstable_task, retries=5, delay=1)
|
||||
print(f"Result of unstable task: {result}")
|
||||
```
|
||||
|
||||
## Resource Monitoring
|
||||
|
||||
### `monitor_resources()`
|
||||
|
||||
Continuously monitors CPU and GPU resources and logs alerts when thresholds are crossed.
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import monitor_resources
|
||||
|
||||
# Start monitoring resources
|
||||
monitor_resources()
|
||||
```
|
||||
|
||||
### `profile_execution(func: Callable, *args: Any, **kwargs: Any) -> Any`
|
||||
|
||||
Profiles the execution of a task, collecting metrics like execution time and CPU/GPU usage.
|
||||
|
||||
#### Parameters
|
||||
| Name | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `func` | `Callable` | The function to profile. |
|
||||
| `*args` | `Any` | Arguments for the callable. |
|
||||
| `**kwargs` | `Any` | Keyword arguments for the callable. |
|
||||
|
||||
#### Returns
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `Any` | The result of the function execution along with the collected metrics. |
|
||||
|
||||
#### Example
|
||||
```python
|
||||
from clusterops import profile_execution
|
||||
|
||||
def cpu_intensive_task():
|
||||
return sum(i*i for i in range(10000000))
|
||||
|
||||
result = profile_execution(cpu_intensive_task)
|
||||
print(f"Result of profiled task: {result}")
|
||||
```
|
||||
|
||||
This API reference provides a comprehensive overview of the ClusterOps library's main functions, their parameters, return values, and usage examples. It should help users understand and utilize the library effectively for managing and executing tasks across CPU and GPU resources in a distributed computing environment.
|
@ -0,0 +1,77 @@
|
||||
# 🔗 Links & Resources
|
||||
|
||||
Welcome to the Swarms ecosystem. Click any tile below to explore our products, community, documentation, and social platforms.
|
||||
|
||||
---
|
||||
|
||||
<style>
|
||||
.resource-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
|
||||
gap: 1rem;
|
||||
margin-top: 1.5rem;
|
||||
}
|
||||
|
||||
.resource-card {
|
||||
display: block;
|
||||
padding: 1.2rem;
|
||||
border-radius: 12px;
|
||||
background: #1e1e2f;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
text-align: center;
|
||||
font-weight: 600;
|
||||
transition: transform 0.2s ease, background 0.3s ease;
|
||||
box-shadow: 0 4px 20px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.resource-card:hover {
|
||||
transform: translateY(-4px);
|
||||
background: #2a2a3d;
|
||||
}
|
||||
</style>
|
||||
|
||||
<div class="resource-grid">
|
||||
|
||||
<a class="resource-card" href="https://swarms.world/platform/chat" target="_blank">🗣️ Swarms Chat</a>
|
||||
|
||||
<a class="resource-card" href="https://swarms.world" target="_blank">🛍️ Swarms Marketplace</a>
|
||||
|
||||
<a class="resource-card" href="https://docs.swarms.world/en/latest/swarms_cloud/swarms_api/" target="_blank">📚 Swarms API Docs</a>
|
||||
|
||||
<a class="resource-card" href="https://www.swarms.xyz/programs/startups" target="_blank">🚀 Swarms Startup Program</a>
|
||||
|
||||
<a class="resource-card" href="https://github.com/kyegomez/swarms" target="_blank">💻 GitHub: Swarms (Python)</a>
|
||||
|
||||
<a class="resource-card" href="https://github.com/The-Swarm-Corporation/swarms-rs" target="_blank">🦀 GitHub: Swarms (Rust)</a>
|
||||
|
||||
<a class="resource-card" href="https://discord.gg/jM3Z6M9uMq" target="_blank">💬 Join Our Discord</a>
|
||||
|
||||
<a class="resource-card" href="https://t.me/swarmsgroupchat" target="_blank">📱 Telegram Group</a>
|
||||
|
||||
<a class="resource-card" href="https://x.com/swarms_corp" target="_blank">🐦 Twitter / X</a>
|
||||
|
||||
<a class="resource-card" href="https://medium.com/@kyeg" target="_blank">✍️ Swarms Blog on Medium</a>
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## 💡 Quick Summary
|
||||
|
||||
| Category | Link |
|
||||
|--------------|----------------------------------------------------------------------|
|
||||
| API Docs | [docs.swarms.world](https://docs.swarms.world/en/latest/swarms_cloud/swarms_api/) |
|
||||
| GitHub | [kyegomez/swarms](https://github.com/kyegomez/swarms) |
|
||||
| GitHub (Rust)| [The-Swarm-Corporation/swarms-rs](https://github.com/The-Swarm-Corporation/swarms-rs) |
|
||||
| Chat UI | [swarms.world/platform/chat](https://swarms.world/platform/chat) |
|
||||
| Marketplace | [swarms.world](https://swarms.world) |
|
||||
| Startup App | [Apply Here](https://www.swarms.xyz/programs/startups) |
|
||||
| Discord | [Join Now](https://discord.gg/jM3Z6M9uMq) |
|
||||
| Telegram | [Group Chat](https://t.me/swarmsgroupchat) |
|
||||
| Twitter/X | [@swarms_corp](https://x.com/swarms_corp) |
|
||||
| Blog | [medium.com/@kyeg](https://medium.com/@kyeg) |
|
||||
|
||||
---
|
||||
|
||||
> 🐝 Swarms is building the agentic internet. Join the movement and build the future with us.
|
@ -1,278 +0,0 @@
|
||||
# AsyncWorkflow Documentation
|
||||
|
||||
The `AsyncWorkflow` class represents an asynchronous workflow that executes tasks concurrently using multiple agents. It allows for efficient task management, leveraging Python's `asyncio` for concurrent execution.
|
||||
|
||||
## Key Features
|
||||
- **Concurrent Task Execution**: Distribute tasks across multiple agents asynchronously.
|
||||
- **Configurable Workers**: Limit the number of concurrent workers (agents) for better resource management.
|
||||
- **Autosave Results**: Optionally save the task execution results automatically.
|
||||
- **Verbose Logging**: Enable detailed logging to monitor task execution.
|
||||
- **Error Handling**: Gracefully handles exceptions raised by agents during task execution.
|
||||
|
||||
---
|
||||
|
||||
## Attributes
|
||||
| Attribute | Type | Description |
|
||||
|-------------------|---------------------|-----------------------------------------------------------------------------|
|
||||
| `name` | `str` | The name of the workflow. |
|
||||
| `agents` | `List[Agent]` | A list of agents participating in the workflow. |
|
||||
| `max_workers` | `int` | The maximum number of concurrent workers (default: 5). |
|
||||
| `dashboard` | `bool` | Whether to display a dashboard (currently not implemented). |
|
||||
| `autosave` | `bool` | Whether to autosave task results (default: `False`). |
|
||||
| `verbose` | `bool` | Whether to enable detailed logging (default: `False`). |
|
||||
| `task_pool` | `List` | A pool of tasks to be executed. |
|
||||
| `results` | `List` | A list to store results of executed tasks. |
|
||||
| `loop` | `asyncio.EventLoop` | The event loop for asynchronous execution. |
|
||||
|
||||
---
|
||||
|
||||
**Description**:
|
||||
Initializes the `AsyncWorkflow` with specified agents, configuration, and options.
|
||||
|
||||
**Parameters**:
|
||||
- `name` (`str`): Name of the workflow. Default: "AsyncWorkflow".
|
||||
- `agents` (`List[Agent]`): A list of agents. Default: `None`.
|
||||
- `max_workers` (`int`): The maximum number of workers. Default: `5`.
|
||||
- `dashboard` (`bool`): Enable dashboard visualization (placeholder for future implementation).
|
||||
- `autosave` (`bool`): Enable autosave of task results. Default: `False`.
|
||||
- `verbose` (`bool`): Enable detailed logging. Default: `False`.
|
||||
- `**kwargs`: Additional parameters for `BaseWorkflow`.
|
||||
|
||||
---
|
||||
|
||||
### `_execute_agent_task`
|
||||
```python
|
||||
async def _execute_agent_task(self, agent: Agent, task: str) -> Any:
|
||||
```
|
||||
**Description**:
|
||||
Executes a single task asynchronously using a given agent.
|
||||
|
||||
**Parameters**:
|
||||
- `agent` (`Agent`): The agent responsible for executing the task.
|
||||
- `task` (`str`): The task to be executed.
|
||||
|
||||
**Returns**:
|
||||
- `Any`: The result of the task execution or an error message in case of an exception.
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
result = await workflow._execute_agent_task(agent, "Sample Task")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### `run`
|
||||
```python
|
||||
async def run(self, task: str) -> List[Any]:
|
||||
```
|
||||
**Description**:
|
||||
Executes the specified task concurrently across all agents.
|
||||
|
||||
**Parameters**:
|
||||
- `task` (`str`): The task to be executed by all agents.
|
||||
|
||||
**Returns**:
|
||||
- `List[Any]`: A list of results or error messages returned by the agents.
|
||||
|
||||
**Raises**:
|
||||
- `ValueError`: If no agents are provided in the workflow.
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
agents = [Agent("Agent1"), Agent("Agent2")]
|
||||
workflow = AsyncWorkflow(agents=agents, verbose=True)
|
||||
|
||||
results = asyncio.run(workflow.run("Process Data"))
|
||||
print(results)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Production-Grade Financial Example: Multiple Agents
|
||||
### Example: Stock Analysis and Investment Strategy
|
||||
```python
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from swarm_models import OpenAIChat
|
||||
|
||||
from swarms.structs.async_workflow import (
|
||||
SpeakerConfig,
|
||||
SpeakerRole,
|
||||
create_default_workflow,
|
||||
run_workflow_with_retry,
|
||||
)
|
||||
from swarms.prompts.finance_agent_sys_prompt import (
|
||||
FINANCIAL_AGENT_SYS_PROMPT,
|
||||
)
|
||||
from swarms.structs.agent import Agent
|
||||
|
||||
|
||||
async def create_specialized_agents() -> List[Agent]:
|
||||
"""Create a set of specialized agents for financial analysis"""
|
||||
|
||||
# Base model configuration
|
||||
model = OpenAIChat(model_name="gpt-4o")
|
||||
|
||||
# Financial Analysis Agent
|
||||
financial_agent = Agent(
|
||||
agent_name="Financial-Analysis-Agent",
|
||||
agent_description="Personal finance advisor agent",
|
||||
system_prompt=FINANCIAL_AGENT_SYS_PROMPT
|
||||
+ "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
|
||||
max_loops=1,
|
||||
llm=model,
|
||||
dynamic_temperature_enabled=True,
|
||||
user_name="Kye",
|
||||
retry_attempts=3,
|
||||
context_length=8192,
|
||||
return_step_meta=False,
|
||||
output_type="str",
|
||||
auto_generate_prompt=False,
|
||||
max_tokens=4000,
|
||||
stopping_token="<DONE>",
|
||||
saved_state_path="financial_agent.json",
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
# Risk Assessment Agent
|
||||
risk_agent = Agent(
|
||||
agent_name="Risk-Assessment-Agent",
|
||||
agent_description="Investment risk analysis specialist",
|
||||
system_prompt="Analyze investment risks and provide risk scores. Output <DONE> when analysis is complete.",
|
||||
max_loops=1,
|
||||
llm=model,
|
||||
dynamic_temperature_enabled=True,
|
||||
user_name="Kye",
|
||||
retry_attempts=3,
|
||||
context_length=8192,
|
||||
output_type="str",
|
||||
max_tokens=4000,
|
||||
stopping_token="<DONE>",
|
||||
saved_state_path="risk_agent.json",
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
# Market Research Agent
|
||||
research_agent = Agent(
|
||||
agent_name="Market-Research-Agent",
|
||||
agent_description="AI and tech market research specialist",
|
||||
system_prompt="Research AI market trends and growth opportunities. Output <DONE> when research is complete.",
|
||||
max_loops=1,
|
||||
llm=model,
|
||||
dynamic_temperature_enabled=True,
|
||||
user_name="Kye",
|
||||
retry_attempts=3,
|
||||
context_length=8192,
|
||||
output_type="str",
|
||||
max_tokens=4000,
|
||||
stopping_token="<DONE>",
|
||||
saved_state_path="research_agent.json",
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
return [financial_agent, risk_agent, research_agent]
|
||||
|
||||
|
||||
async def main():
|
||||
# Create specialized agents
|
||||
agents = await create_specialized_agents()
|
||||
|
||||
# Create workflow with group chat enabled
|
||||
workflow = create_default_workflow(
|
||||
agents=agents,
|
||||
name="AI-Investment-Analysis-Workflow",
|
||||
enable_group_chat=True,
|
||||
)
|
||||
|
||||
# Configure speaker roles
|
||||
workflow.speaker_system.add_speaker(
|
||||
SpeakerConfig(
|
||||
role=SpeakerRole.COORDINATOR,
|
||||
agent=agents[0], # Financial agent as coordinator
|
||||
priority=1,
|
||||
concurrent=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
|
||||
workflow.speaker_system.add_speaker(
|
||||
SpeakerConfig(
|
||||
role=SpeakerRole.CRITIC,
|
||||
agent=agents[1], # Risk agent as critic
|
||||
priority=2,
|
||||
concurrent=True,
|
||||
)
|
||||
)
|
||||
|
||||
workflow.speaker_system.add_speaker(
|
||||
SpeakerConfig(
|
||||
role=SpeakerRole.EXECUTOR,
|
||||
agent=agents[2], # Research agent as executor
|
||||
priority=2,
|
||||
concurrent=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Investment analysis task
|
||||
investment_task = """
|
||||
Create a comprehensive investment analysis for a $40k portfolio focused on AI growth opportunities:
|
||||
1. Identify high-growth AI ETFs and index funds
|
||||
2. Analyze risks and potential returns
|
||||
3. Create a diversified portfolio allocation
|
||||
4. Provide market trend analysis
|
||||
Present the results in a structured markdown format.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Run workflow with retry
|
||||
result = await run_workflow_with_retry(
|
||||
workflow=workflow, task=investment_task, max_retries=3
|
||||
)
|
||||
|
||||
print("\nWorkflow Results:")
|
||||
print("================")
|
||||
|
||||
# Process and display agent outputs
|
||||
for output in result.agent_outputs:
|
||||
print(f"\nAgent: {output.agent_name}")
|
||||
print("-" * (len(output.agent_name) + 8))
|
||||
print(output.output)
|
||||
|
||||
# Display group chat history if enabled
|
||||
if workflow.enable_group_chat:
|
||||
print("\nGroup Chat Discussion:")
|
||||
print("=====================")
|
||||
for msg in workflow.speaker_system.message_history:
|
||||
print(f"\n{msg.role} ({msg.agent_name}):")
|
||||
print(msg.content)
|
||||
|
||||
# Save detailed results
|
||||
if result.metadata.get("shared_memory_keys"):
|
||||
print("\nShared Insights:")
|
||||
print("===============")
|
||||
for key in result.metadata["shared_memory_keys"]:
|
||||
value = workflow.shared_memory.get(key)
|
||||
if value:
|
||||
print(f"\n{key}:")
|
||||
print(value)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Workflow failed: {str(e)}")
|
||||
|
||||
finally:
|
||||
await workflow.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the example
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
```
|
||||
|
||||
|
||||
---
|
@ -0,0 +1,186 @@
|
||||
# Deep Research Swarm
|
||||
|
||||
!!! abstract "Overview"
|
||||
The Deep Research Swarm is a powerful, production-grade research system that conducts comprehensive analysis across multiple domains using parallel processing and advanced AI agents.
|
||||
|
||||
Key Features:
|
||||
|
||||
- Parallel search processing
|
||||
|
||||
- Multi-agent research coordination
|
||||
|
||||
- Advanced information synthesis
|
||||
|
||||
- Automated query generation
|
||||
|
||||
- Concurrent task execution
|
||||
|
||||
## Getting Started
|
||||
|
||||
!!! tip "Quick Installation"
|
||||
```bash
|
||||
pip install swarms
|
||||
```
|
||||
|
||||
=== "Basic Usage"
|
||||
```python
|
||||
from swarms.structs import DeepResearchSwarm
|
||||
|
||||
# Initialize the swarm
|
||||
swarm = DeepResearchSwarm(
|
||||
name="MyResearchSwarm",
|
||||
output_type="json",
|
||||
max_loops=1
|
||||
)
|
||||
|
||||
# Run a single research task
|
||||
results = swarm.run("What are the latest developments in quantum computing?")
|
||||
```
|
||||
|
||||
=== "Batch Processing"
|
||||
```python
|
||||
# Run multiple research tasks in parallel
|
||||
tasks = [
|
||||
"What are the environmental impacts of electric vehicles?",
|
||||
"How is AI being used in drug discovery?",
|
||||
]
|
||||
batch_results = swarm.batched_run(tasks)
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
!!! info "Constructor Arguments"
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `name` | str | "DeepResearchSwarm" | Name identifier for the swarm |
|
||||
| `description` | str | "A swarm that conducts..." | Description of the swarm's purpose |
|
||||
| `research_agent` | Agent | research_agent | Custom research agent instance |
|
||||
| `max_loops` | int | 1 | Maximum number of research iterations |
|
||||
| `nice_print` | bool | True | Enable formatted console output |
|
||||
| `output_type` | str | "json" | Output format ("json" or "string") |
|
||||
| `max_workers` | int | CPU_COUNT * 2 | Maximum concurrent threads |
|
||||
| `token_count` | bool | False | Enable token counting |
|
||||
| `research_model_name` | str | "gpt-4o-mini" | Model to use for research |
|
||||
|
||||
## Core Methods
|
||||
|
||||
### Run
|
||||
!!! example "Single Task Execution"
|
||||
```python
|
||||
results = swarm.run("What are the latest breakthroughs in fusion energy?")
|
||||
```
|
||||
|
||||
### Batched Run
|
||||
!!! example "Parallel Task Execution"
|
||||
```python
|
||||
tasks = [
|
||||
"What are current AI safety initiatives?",
|
||||
"How is CRISPR being used in agriculture?",
|
||||
]
|
||||
results = swarm.batched_run(tasks)
|
||||
```
|
||||
|
||||
### Step
|
||||
!!! example "Single Step Execution"
|
||||
```python
|
||||
results = swarm.step("Analyze recent developments in renewable energy storage")
|
||||
```
|
||||
|
||||
## Domain-Specific Examples
|
||||
|
||||
=== "Scientific Research"
|
||||
```python
|
||||
science_swarm = DeepResearchSwarm(
|
||||
name="ScienceSwarm",
|
||||
output_type="json",
|
||||
max_loops=2 # More iterations for thorough research
|
||||
)
|
||||
|
||||
results = science_swarm.run(
|
||||
"What are the latest experimental results in quantum entanglement?"
|
||||
)
|
||||
```
|
||||
|
||||
=== "Market Research"
|
||||
```python
|
||||
market_swarm = DeepResearchSwarm(
|
||||
name="MarketSwarm",
|
||||
output_type="json"
|
||||
)
|
||||
|
||||
results = market_swarm.run(
|
||||
"What are the emerging trends in electric vehicle battery technology market?"
|
||||
)
|
||||
```
|
||||
|
||||
=== "News Analysis"
|
||||
```python
|
||||
news_swarm = DeepResearchSwarm(
|
||||
name="NewsSwarm",
|
||||
output_type="string" # Human-readable output
|
||||
)
|
||||
|
||||
results = news_swarm.run(
|
||||
"What are the global economic impacts of recent geopolitical events?"
|
||||
)
|
||||
```
|
||||
|
||||
=== "Medical Research"
|
||||
```python
|
||||
medical_swarm = DeepResearchSwarm(
|
||||
name="MedicalSwarm",
|
||||
max_loops=2
|
||||
)
|
||||
|
||||
results = medical_swarm.run(
|
||||
"What are the latest clinical trials for Alzheimer's treatment?"
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced Features
|
||||
|
||||
??? note "Custom Research Agent"
|
||||
```python
|
||||
from swarms import Agent
|
||||
|
||||
custom_agent = Agent(
|
||||
agent_name="SpecializedResearcher",
|
||||
system_prompt="Your specialized prompt here",
|
||||
model_name="gpt-4"
|
||||
)
|
||||
|
||||
swarm = DeepResearchSwarm(
|
||||
research_agent=custom_agent,
|
||||
max_loops=2
|
||||
)
|
||||
```
|
||||
|
||||
??? note "Parallel Processing Control"
|
||||
```python
|
||||
swarm = DeepResearchSwarm(
|
||||
max_workers=8, # Limit to 8 concurrent threads
|
||||
nice_print=False # Disable console output for production
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
!!! success "Recommended Practices"
|
||||
1. **Query Formulation**: Be specific and clear in your research queries
|
||||
2. **Resource Management**: Adjust `max_workers` based on your system's capabilities
|
||||
3. **Output Handling**: Use appropriate `output_type` for your use case
|
||||
    4. **Error Handling**: Implement try-except blocks around swarm operations
|
||||
5. **Model Selection**: Choose appropriate models based on research complexity
|
||||
|
||||
## Limitations
|
||||
|
||||
!!! warning "Known Limitations"
|
||||
|
||||
- Requires valid API keys for external services
|
||||
|
||||
- Performance depends on system resources
|
||||
|
||||
- Rate limits may apply to external API calls
|
||||
|
||||
- Token limits apply to model responses
|
||||
|
@ -0,0 +1,342 @@
|
||||
# Swarms API as MCP
|
||||
|
||||
- Launch MCP server as a tool
|
||||
- Put `SWARMS_API_KEY` in `.env`
|
||||
- Client side code below
|
||||
|
||||
|
||||
## Server Side
|
||||
|
||||
```python
|
||||
# server.py
|
||||
from datetime import datetime
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
from pydantic import BaseModel, Field
|
||||
from swarms import SwarmType
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
class AgentSpec(BaseModel):
|
||||
agent_name: Optional[str] = Field(
|
||||
description="The unique name assigned to the agent, which identifies its role and functionality within the swarm.",
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
description="A detailed explanation of the agent's purpose, capabilities, and any specific tasks it is designed to perform.",
|
||||
)
|
||||
system_prompt: Optional[str] = Field(
|
||||
description="The initial instruction or context provided to the agent, guiding its behavior and responses during execution.",
|
||||
)
|
||||
model_name: Optional[str] = Field(
|
||||
default="gpt-4o-mini",
|
||||
description="The name of the AI model that the agent will utilize for processing tasks and generating outputs. For example: gpt-4o, gpt-4o-mini, openai/o3-mini",
|
||||
)
|
||||
auto_generate_prompt: Optional[bool] = Field(
|
||||
default=False,
|
||||
description="A flag indicating whether the agent should automatically create prompts based on the task requirements.",
|
||||
)
|
||||
max_tokens: Optional[int] = Field(
|
||||
default=8192,
|
||||
description="The maximum number of tokens that the agent is allowed to generate in its responses, limiting output length.",
|
||||
)
|
||||
temperature: Optional[float] = Field(
|
||||
default=0.5,
|
||||
description="A parameter that controls the randomness of the agent's output; lower values result in more deterministic responses.",
|
||||
)
|
||||
role: Optional[str] = Field(
|
||||
default="worker",
|
||||
description="The designated role of the agent within the swarm, which influences its behavior and interaction with other agents.",
|
||||
)
|
||||
max_loops: Optional[int] = Field(
|
||||
default=1,
|
||||
description="The maximum number of times the agent is allowed to repeat its task, enabling iterative processing if necessary.",
|
||||
)
|
||||
# New fields for RAG functionality
|
||||
rag_collection: Optional[str] = Field(
|
||||
None,
|
||||
description="The Qdrant collection name for RAG functionality. If provided, this agent will perform RAG queries.",
|
||||
)
|
||||
rag_documents: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Documents to ingest into the Qdrant collection for RAG. (List of text strings)",
|
||||
)
|
||||
tools: Optional[List[Dict[str, Any]]] = Field(
|
||||
None,
|
||||
description="A dictionary of tools that the agent can use to complete its task.",
|
||||
)
|
||||
|
||||
|
||||
class AgentCompletion(BaseModel):
|
||||
"""
|
||||
Configuration for a single agent that works together as a swarm to accomplish tasks.
|
||||
"""
|
||||
|
||||
agent: AgentSpec = Field(
|
||||
...,
|
||||
description="The agent to run.",
|
||||
)
|
||||
task: Optional[str] = Field(
|
||||
...,
|
||||
description="The task to run.",
|
||||
)
|
||||
img: Optional[str] = Field(
|
||||
None,
|
||||
description="An optional image URL that may be associated with the swarm's task or representation.",
|
||||
)
|
||||
output_type: Optional[str] = Field(
|
||||
"list",
|
||||
description="The type of output to return.",
|
||||
)
|
||||
|
||||
|
||||
class AgentCompletionResponse(BaseModel):
|
||||
"""
|
||||
Response from an agent completion.
|
||||
"""
|
||||
|
||||
agent_id: str = Field(
|
||||
...,
|
||||
description="The unique identifier for the agent that completed the task.",
|
||||
)
|
||||
agent_name: str = Field(
|
||||
...,
|
||||
description="The name of the agent that completed the task.",
|
||||
)
|
||||
agent_description: str = Field(
|
||||
...,
|
||||
description="The description of the agent that completed the task.",
|
||||
)
|
||||
messages: Any = Field(
|
||||
...,
|
||||
description="The messages from the agent completion.",
|
||||
)
|
||||
|
||||
cost: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="The cost of the agent completion.",
|
||||
)
|
||||
|
||||
|
||||
class Agents(BaseModel):
|
||||
"""Configuration for a collection of agents that work together as a swarm to accomplish tasks."""
|
||||
|
||||
agents: List[AgentSpec] = Field(
|
||||
description="A list containing the specifications of each agent that will participate in the swarm, detailing their roles and functionalities."
|
||||
)
|
||||
|
||||
|
||||
class ScheduleSpec(BaseModel):
|
||||
scheduled_time: datetime = Field(
|
||||
...,
|
||||
description="The exact date and time (in UTC) when the swarm is scheduled to execute its tasks.",
|
||||
)
|
||||
timezone: Optional[str] = Field(
|
||||
"UTC",
|
||||
description="The timezone in which the scheduled time is defined, allowing for proper scheduling across different regions.",
|
||||
)
|
||||
|
||||
|
||||
class SwarmSpec(BaseModel):
|
||||
name: Optional[str] = Field(
|
||||
None,
|
||||
description="The name of the swarm, which serves as an identifier for the group of agents and their collective task.",
|
||||
max_length=100,
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
None,
|
||||
description="A comprehensive description of the swarm's objectives, capabilities, and intended outcomes.",
|
||||
)
|
||||
agents: Optional[List[AgentSpec]] = Field(
|
||||
None,
|
||||
description="A list of agents or specifications that define the agents participating in the swarm.",
|
||||
)
|
||||
max_loops: Optional[int] = Field(
|
||||
default=1,
|
||||
description="The maximum number of execution loops allowed for the swarm, enabling repeated processing if needed.",
|
||||
)
|
||||
swarm_type: Optional[SwarmType] = Field(
|
||||
None,
|
||||
description="The classification of the swarm, indicating its operational style and methodology.",
|
||||
)
|
||||
rearrange_flow: Optional[str] = Field(
|
||||
None,
|
||||
description="Instructions on how to rearrange the flow of tasks among agents, if applicable.",
|
||||
)
|
||||
task: Optional[str] = Field(
|
||||
None,
|
||||
description="The specific task or objective that the swarm is designed to accomplish.",
|
||||
)
|
||||
img: Optional[str] = Field(
|
||||
None,
|
||||
description="An optional image URL that may be associated with the swarm's task or representation.",
|
||||
)
|
||||
return_history: Optional[bool] = Field(
|
||||
True,
|
||||
description="A flag indicating whether the swarm should return its execution history along with the final output.",
|
||||
)
|
||||
rules: Optional[str] = Field(
|
||||
None,
|
||||
description="Guidelines or constraints that govern the behavior and interactions of the agents within the swarm.",
|
||||
)
|
||||
schedule: Optional[ScheduleSpec] = Field(
|
||||
None,
|
||||
description="Details regarding the scheduling of the swarm's execution, including timing and timezone information.",
|
||||
)
|
||||
tasks: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="A list of tasks that the swarm should complete.",
|
||||
)
|
||||
messages: Optional[List[Dict[str, Any]]] = Field(
|
||||
None,
|
||||
description="A list of messages that the swarm should complete.",
|
||||
)
|
||||
# rag_on: Optional[bool] = Field(
|
||||
# None,
|
||||
# description="A flag indicating whether the swarm should use RAG.",
|
||||
# )
|
||||
# collection_name: Optional[str] = Field(
|
||||
# None,
|
||||
# description="The name of the collection to use for RAG.",
|
||||
# )
|
||||
stream: Optional[bool] = Field(
|
||||
False,
|
||||
description="A flag indicating whether the swarm should stream its output.",
|
||||
)
|
||||
|
||||
|
||||
class SwarmCompletionResponse(BaseModel):
|
||||
"""
|
||||
Response from a swarm completion.
|
||||
"""
|
||||
|
||||
status: str = Field(..., description="The status of the swarm completion.")
|
||||
swarm_name: str = Field(..., description="The name of the swarm.")
|
||||
description: str = Field(..., description="Description of the swarm.")
|
||||
swarm_type: str = Field(..., description="The type of the swarm.")
|
||||
task: str = Field(
|
||||
..., description="The task that the swarm is designed to accomplish."
|
||||
)
|
||||
output: List[Dict[str, Any]] = Field(
|
||||
..., description="The output generated by the swarm."
|
||||
)
|
||||
number_of_agents: int = Field(
|
||||
..., description="The number of agents involved in the swarm."
|
||||
)
|
||||
# "input_config": Optional[Dict[str, Any]] = Field(None, description="The input configuration for the swarm.")
|
||||
|
||||
|
||||
BASE_URL = "https://swarms-api-285321057562.us-east1.run.app"
|
||||
|
||||
|
||||
# Create an MCP server
|
||||
mcp = FastMCP("swarms-api")
|
||||
|
||||
|
||||
# Add an addition tool
|
||||
@mcp.tool(name="swarm_completion", description="Run a swarm completion.")
|
||||
def swarm_completion(swarm: SwarmSpec) -> Dict[str, Any]:
|
||||
api_key = os.getenv("SWARMS_API_KEY")
|
||||
headers = {"x-api-key": api_key, "Content-Type": "application/json"}
|
||||
|
||||
payload = swarm.model_dump()
|
||||
|
||||
response = requests.post(f"{BASE_URL}/v1/swarm/completions", json=payload, headers=headers)
|
||||
|
||||
return response.json()
|
||||
|
||||
@mcp.tool(name="swarms_available", description="Get the list of available swarms.")
|
||||
async def swarms_available() -> Any:
|
||||
"""
|
||||
Get the list of available swarms.
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{BASE_URL}/v1/models/available", headers=headers)
|
||||
response.raise_for_status() # Raise an error for bad responses
|
||||
return response.json()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run(transport="sse")
|
||||
```
|
||||
|
||||
## Client side
|
||||
|
||||
- Call the tool with its name and the payload config
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from fastmcp import Client
|
||||
|
||||
swarm_config = {
|
||||
"name": "Simple Financial Analysis",
|
||||
"description": "A swarm to analyze financial data",
|
||||
"agents": [
|
||||
{
|
||||
"agent_name": "Data Analyzer",
|
||||
"description": "Looks at financial data",
|
||||
"system_prompt": "Analyze the data.",
|
||||
"model_name": "gpt-4o",
|
||||
"role": "worker",
|
||||
"max_loops": 1,
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.5,
|
||||
"auto_generate_prompt": False,
|
||||
},
|
||||
{
|
||||
"agent_name": "Risk Analyst",
|
||||
"description": "Checks risk levels",
|
||||
"system_prompt": "Evaluate the risks.",
|
||||
"model_name": "gpt-4o",
|
||||
"role": "worker",
|
||||
"max_loops": 1,
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.5,
|
||||
"auto_generate_prompt": False,
|
||||
},
|
||||
{
|
||||
"agent_name": "Strategy Checker",
|
||||
"description": "Validates strategies",
|
||||
"system_prompt": "Review the strategy.",
|
||||
"model_name": "gpt-4o",
|
||||
"role": "worker",
|
||||
"max_loops": 1,
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.5,
|
||||
"auto_generate_prompt": False,
|
||||
},
|
||||
],
|
||||
"max_loops": 1,
|
||||
"swarm_type": "SequentialWorkflow",
|
||||
"task": "Analyze the financial data and provide insights.",
|
||||
"return_history": False, # Added required field
|
||||
"stream": False, # Added required field
|
||||
"rules": None, # Added optional field
|
||||
"img": None, # Added optional field
|
||||
}
|
||||
|
||||
|
||||
async def swarm_completion():
|
||||
"""Connect to a server over SSE and fetch available swarms."""
|
||||
|
||||
async with Client(
|
||||
transport="http://localhost:8000/sse"
|
||||
) as client:
|
||||
# Basic connectivity testing
|
||||
# print("Ping check:", await client.ping())
|
||||
# print("Available tools:", await client.list_tools())
|
||||
# print("Swarms available:", await client.call_tool("swarms_available", None))
|
||||
result = await client.call_tool("swarm_completion", {"swarm": swarm_config})
|
||||
print("Swarm completion:", result)
|
||||
|
||||
|
||||
# Execute the function
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(swarm_completion())
|
||||
```
|
@ -0,0 +1,68 @@
|
||||
# 🔐 Swarms x Phala Deployment Guide
|
||||
|
||||
This guide will walk you through deploying your project to Phala's Trusted Execution Environment (TEE).
|
||||
|
||||
## 📋 Prerequisites
|
||||
|
||||
- Docker installed on your system
|
||||
- A DockerHub account
|
||||
- Access to Phala Cloud dashboard
|
||||
|
||||
## 🛡️ TEE Overview
|
||||
|
||||
For detailed instructions about Trusted Execution Environment setup, please refer to our [TEE Documentation](./tee/README.md).
|
||||
|
||||
## 🚀 Deployment Steps
|
||||
|
||||
### 1. Build and Publish Docker Image
|
||||
|
||||
```bash
|
||||
# Build the Docker image
|
||||
docker build -t <your-dockerhub-username>/swarm-agent-node:latest .
|
||||
|
||||
# Push to DockerHub
|
||||
docker push <your-dockerhub-username>/swarm-agent-node:latest
|
||||
```
|
||||
|
||||
### 2. Deploy to Phala Cloud
|
||||
|
||||
Choose one of these deployment methods:
|
||||
- Use [tee-cloud-cli](https://github.com/Phala-Network/tee-cloud-cli) (Recommended)
|
||||
- Deploy manually via the [Phala Cloud Dashboard](https://cloud.phala.network/)
|
||||
|
||||
### 3. Verify TEE Attestation
|
||||
|
||||
Visit the [TEE Attestation Explorer](https://proof.t16z.com/) to check and verify your agent's TEE proof.
|
||||
|
||||
## 📝 Docker Configuration
|
||||
|
||||
Below is a sample Docker Compose configuration for your Swarms agent:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
swarms-agent-server:
|
||||
image: swarms-agent-node:latest
|
||||
platform: linux/amd64
|
||||
volumes:
|
||||
- /var/run/tappd.sock:/var/run/tappd.sock
|
||||
- swarms:/app
|
||||
restart: always
|
||||
ports:
|
||||
- 8000:8000
|
||||
command: # Sample MCP Server
|
||||
- /bin/sh
|
||||
- -c
|
||||
- |
|
||||
cd /app/mcp_example
|
||||
python mcp_test.py
|
||||
volumes:
|
||||
swarms:
|
||||
```
|
||||
|
||||
## 📚 Additional Resources
|
||||
|
||||
For more comprehensive documentation and examples, visit our [Official Documentation](https://docs.swarms.world/en/latest/).
|
||||
|
||||
---
|
||||
|
||||
> **Note**: Make sure to replace `<your-dockerhub-username>` with your actual DockerHub username when building and pushing the image.
|
@ -0,0 +1,353 @@
|
||||
# swarms-rs
|
||||
|
||||
!!! note "Modern AI Agent Framework"
|
||||
swarms-rs is a powerful Rust framework for building autonomous AI agents powered by LLMs, equipped with robust tools and memory capabilities. Designed for various applications from trading analysis to healthcare diagnostics.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
cargo add swarms-rs
|
||||
```
|
||||
|
||||
!!! tip "Compatible with Rust 1.70+"
|
||||
This library requires Rust 1.70 or later. Make sure your Rust toolchain is up to date.
|
||||
|
||||
### Required Environment Variables
|
||||
|
||||
```bash
|
||||
# API keys (set the ones required by the provider you use)
|
||||
OPENAI_API_KEY="your_openai_api_key_here"
|
||||
DEEPSEEK_API_KEY="your_deepseek_api_key_here"
|
||||
```
|
||||
|
||||
### Quick Start
|
||||
|
||||
Here's a simple example to get you started with swarms-rs:
|
||||
|
||||
```rust
|
||||
use std::env;
|
||||
use anyhow::Result;
|
||||
use swarms_rs::{llm::provider::openai::OpenAI, structs::agent::Agent};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Load environment variables from .env file
|
||||
dotenv::dotenv().ok();
|
||||
|
||||
// Initialize tracing for better debugging
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.with_line_number(true)
|
||||
.with_file(true),
|
||||
)
|
||||
.init();
|
||||
|
||||
// Set up your LLM client
|
||||
let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
|
||||
let client = OpenAI::new(api_key).set_model("gpt-4-turbo");
|
||||
|
||||
// Create a basic agent
|
||||
let agent = client
|
||||
.agent_builder()
|
||||
.system_prompt("You are a helpful assistant.")
|
||||
.agent_name("BasicAgent")
|
||||
.user_name("User")
|
||||
.build();
|
||||
|
||||
// Run the agent with a user query
|
||||
let response = agent
|
||||
.run("Tell me about Rust programming.".to_owned())
|
||||
.await?;
|
||||
|
||||
println!("{}", response);
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### Agents
|
||||
|
||||
Agents in swarms-rs are autonomous entities that can:
|
||||
|
||||
- Perform complex reasoning based on LLM capabilities
|
||||
- Use tools to interact with external systems
|
||||
- Maintain persistent memory
|
||||
- Execute multi-step plans
|
||||
|
||||
## Agent Configuration
|
||||
|
||||
### Core Parameters
|
||||
|
||||
| Parameter | Description | Default | Required |
|
||||
|-----------|-------------|---------|----------|
|
||||
| `system_prompt` | Initial instructions/role for the agent | - | Yes |
|
||||
| `agent_name` | Name identifier for the agent | - | Yes |
|
||||
| `user_name` | Name for the user interacting with agent | - | Yes |
|
||||
| `max_loops` | Maximum number of reasoning loops | 1 | No |
|
||||
| `retry_attempts` | Number of retry attempts on failure | 1 | No |
|
||||
| `enable_autosave` | Enable state persistence | false | No |
|
||||
| `save_state_dir` | Directory for saving agent state | None | No |
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
You can enhance your agent's capabilities with:
|
||||
|
||||
- **Planning**: Enable structured planning for complex tasks
|
||||
- **Memory**: Persistent storage for agent state
|
||||
- **Tools**: External capabilities through MCP protocol
|
||||
|
||||
!!! warning "Resource Usage"
|
||||
Setting high values for `max_loops` can increase API usage and costs. Start with lower values and adjust as needed.
|
||||
|
||||
## Examples
|
||||
|
||||
### Specialized Agent for Cryptocurrency Analysis
|
||||
|
||||
```rust
|
||||
use std::env;
|
||||
use anyhow::Result;
|
||||
use swarms_rs::{llm::provider::openai::OpenAI, structs::agent::Agent};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
dotenv::dotenv().ok();
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.with_line_number(true)
|
||||
.with_file(true),
|
||||
)
|
||||
.init();
|
||||
|
||||
let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
|
||||
let client = OpenAI::new(api_key).set_model("gpt-4-turbo");
|
||||
|
||||
let agent = client
|
||||
.agent_builder()
|
||||
.system_prompt(
|
||||
"You are a sophisticated cryptocurrency analysis assistant specialized in:
|
||||
1. Technical analysis of crypto markets
|
||||
2. Fundamental analysis of blockchain projects
|
||||
3. Market sentiment analysis
|
||||
4. Risk assessment
|
||||
5. Trading patterns recognition
|
||||
|
||||
When analyzing cryptocurrencies, always consider:
|
||||
- Market capitalization and volume
|
||||
- Historical price trends
|
||||
- Project fundamentals and technology
|
||||
- Recent news and developments
|
||||
- Market sentiment indicators
|
||||
- Potential risks and opportunities
|
||||
|
||||
Provide clear, data-driven insights and always include relevant disclaimers about market volatility."
|
||||
)
|
||||
.agent_name("CryptoAnalyst")
|
||||
.user_name("Trader")
|
||||
.enable_autosave()
|
||||
.max_loops(3) // Increased for more thorough analysis
|
||||
.save_state_dir("./crypto_analysis/")
|
||||
.enable_plan("Break down the crypto analysis into systematic steps:
|
||||
1. Gather market data
|
||||
2. Analyze technical indicators
|
||||
3. Review fundamental factors
|
||||
4. Assess market sentiment
|
||||
5. Provide comprehensive insights".to_owned())
|
||||
.build();
|
||||
|
||||
let response = agent
|
||||
.run("What are your thoughts on Bitcoin's current market position?".to_owned())
|
||||
.await?;
|
||||
|
||||
println!("{}", response);
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
## Using Tools with MCP
|
||||
|
||||
### Model Context Protocol (MCP)
|
||||
|
||||
swarms-rs supports the Model Context Protocol (MCP), enabling agents to interact with external tools through standardized interfaces.
|
||||
|
||||
!!! info "What is MCP?"
|
||||
MCP (Model Context Protocol) provides a standardized way for LLMs to interact with external tools, giving your agents access to real-world data and capabilities beyond language processing.
|
||||
|
||||
### Supported MCP Server Types
|
||||
|
||||
- **STDIO MCP Servers**: Connect to command-line tools implementing the MCP protocol
|
||||
- **SSE MCP Servers**: Connect to web-based MCP servers using Server-Sent Events
|
||||
|
||||
### Tool Integration
|
||||
|
||||
Add tools to your agent during configuration:
|
||||
|
||||
```rust
|
||||
let agent = client
|
||||
.agent_builder()
|
||||
.system_prompt("You are a helpful assistant with access to tools.")
|
||||
.agent_name("ToolAgent")
|
||||
.user_name("User")
|
||||
// Add STDIO MCP server
|
||||
.add_stdio_mcp_server("uvx", ["mcp-hn"])
|
||||
.await
|
||||
// Add SSE MCP server
|
||||
.add_sse_mcp_server("file-browser", "http://127.0.0.1:8000/sse")
|
||||
.await
|
||||
.build();
|
||||
```
|
||||
|
||||
### Full MCP Agent Example
|
||||
|
||||
```rust
|
||||
use std::env;
|
||||
use anyhow::Result;
|
||||
use swarms_rs::{llm::provider::openai::OpenAI, structs::agent::Agent};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
dotenv::dotenv().ok();
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.with_line_number(true)
|
||||
.with_file(true),
|
||||
)
|
||||
.init();
|
||||
|
||||
let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
|
||||
let client = OpenAI::new(api_key).set_model("gpt-4-turbo");
|
||||
|
||||
let agent = client
|
||||
.agent_builder()
|
||||
.system_prompt("You are a helpful assistant with access to news and file system tools.")
|
||||
.agent_name("SwarmsAgent")
|
||||
.user_name("User")
|
||||
// Add Hacker News tool
|
||||
.add_stdio_mcp_server("uvx", ["mcp-hn"])
|
||||
.await
|
||||
// Add filesystem tool
|
||||
// To set up: uvx mcp-proxy --sse-port=8000 -- npx -y @modelcontextprotocol/server-filesystem ~
|
||||
.add_sse_mcp_server("file-browser", "http://127.0.0.1:8000/sse")
|
||||
.await
|
||||
.retry_attempts(2)
|
||||
.max_loops(3)
|
||||
.build();
|
||||
|
||||
// Use the news tool
|
||||
let news_response = agent
|
||||
.run("Get the top 3 stories of today from Hacker News".to_owned())
|
||||
.await?;
|
||||
println!("NEWS RESPONSE:\n{}", news_response);
|
||||
|
||||
// Use the filesystem tool
|
||||
let fs_response = agent.run("List files in my home directory".to_owned()).await?;
|
||||
println!("FILESYSTEM RESPONSE:\n{}", fs_response);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
## Setting Up MCP Tools
|
||||
|
||||
### Installing MCP Servers
|
||||
|
||||
To use MCP servers with swarms-rs, you'll need to install the appropriate tools:
|
||||
|
||||
1. **uv Package Manager**:
|
||||
```bash
|
||||
curl -sSf https://raw.githubusercontent.com/astral-sh/uv/main/install.sh | sh
|
||||
```
|
||||
|
||||
2. **MCP-HN** (Hacker News MCP server):
|
||||
```bash
|
||||
   uv tool install mcp-hn
|
||||
```
|
||||
|
||||
3. **Setting up an SSE MCP server**:
|
||||
```bash
|
||||
# Start file system MCP server over SSE
|
||||
uvx mcp-proxy --sse-port=8000 -- npx -y @modelcontextprotocol/server-filesystem ~
|
||||
```
|
||||
|
||||
## FAQ
|
||||
|
||||
### General Questions
|
||||
|
||||
??? question "What LLM providers are supported?"
|
||||
swarms-rs currently supports:
|
||||
|
||||
- OpenAI (GPT models)
|
||||
|
||||
- DeepSeek AI
|
||||
|
||||
- More providers coming soon
|
||||
|
||||
??? question "How does state persistence work?"
|
||||
When `enable_autosave` is set to `true`, the agent will save its state to the directory specified in `save_state_dir`. This includes conversation history and tool states, allowing the agent to resume from where it left off.
|
||||
|
||||
??? question "What is the difference between `max_loops` and `retry_attempts`?"
|
||||
- `max_loops`: Controls how many reasoning steps the agent can take for a single query
|
||||
|
||||
- `retry_attempts`: Specifies how many times the agent will retry if an error occurs
|
||||
|
||||
### MCP Tools
|
||||
|
||||
??? question "How do I create my own MCP server?"
|
||||
You can create your own MCP server by implementing the MCP protocol. Check out the [MCP documentation](https://github.com/modelcontextprotocol/spec) for details on the protocol specification.
|
||||
|
||||
??? question "Can I use tools without MCP?"
|
||||
Currently, swarms-rs is designed to use the MCP protocol for tool integration. This provides a standardized way for agents to interact with external systems.
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
Optimize your agent's performance by:
|
||||
|
||||
1. **Crafting Effective System Prompts**:
|
||||
- Be specific about the agent's role and capabilities
|
||||
|
||||
- Include clear instructions on how to use available tools
|
||||
|
||||
- Define success criteria for the agent's responses
|
||||
|
||||
2. **Tuning Loop Parameters**:
|
||||
|
||||
- Start with lower values for `max_loops` and increase as needed
|
||||
|
||||
- Consider the complexity of tasks when setting loop limits
|
||||
|
||||
3. **Strategic Tool Integration**:
|
||||
|
||||
- Only integrate tools that are necessary for the agent's tasks
|
||||
|
||||
- Provide clear documentation in the system prompt about when to use each tool
|
||||
|
||||
### Security Considerations
|
||||
|
||||
!!! danger "Security Notice"
|
||||
When using file system tools or other system-level access, always be careful about permissions. Limit the scope of what your agent can access, especially in production environments.
|
||||
|
||||
## Coming Soon
|
||||
|
||||
- Memory plugins for different storage backends
|
||||
|
||||
- Additional LLM providers
|
||||
|
||||
- Group agent coordination
|
||||
|
||||
- Function calling
|
||||
|
||||
- Custom tool development framework
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions to swarms-rs are welcome! Check out our [GitHub repository](https://github.com/The-Swarm-Corporation/swarms-rs) for more information.
|
@ -0,0 +1,55 @@
|
||||
# swarms-rs 🚀
|
||||
|
||||
<div class="badges" align="center">
|
||||
<img src="https://img.shields.io/github/workflow/status/The-Swarm-Corporation/swarms-rs/CI" alt="Build Status">
|
||||
<img src="https://img.shields.io/crates/v/swarm-rs" alt="Version">
|
||||
<img src="https://img.shields.io/crates/l/swarm-rs" alt="License">
|
||||
</div>
|
||||
|
||||
## 📖 Overview
|
||||
|
||||
**swarms-rs** is an enterprise-grade, production-ready multi-agent orchestration framework built in Rust, designed to handle the most demanding tasks with unparalleled speed and efficiency. By leveraging Rust's bleeding-edge performance and safety features, swarms-rs provides a powerful and scalable solution for orchestrating complex multi-agent systems across various industries.
|
||||
|
||||
## ✨ Key Benefits
|
||||
|
||||
### ⚡ Extreme Performance
|
||||
|
||||
<div class="grid cards" markdown>
|
||||
|
||||
- **Multi-Threaded Architecture**
|
||||
- Utilize the full potential of modern multi-core processors
|
||||
|
||||
- Zero-cost abstractions and fearless concurrency
|
||||
|
||||
- Minimal overhead with maximum throughput
|
||||
|
||||
- Optimal resource utilization
|
||||
|
||||
- **Bleeding-Edge Speed**
|
||||
|
||||
- Near-zero latency execution
|
||||
|
||||
- Lightning-fast performance
|
||||
|
||||
- Ideal for high-frequency applications
|
||||
|
||||
- Perfect for real-time systems
|
||||
</div>
|
||||
|
||||
## 🔗 Quick Links
|
||||
|
||||
<div class="grid cards" markdown>
|
||||
|
||||
- [:fontawesome-brands-github: GitHub](https://github.com/The-Swarm-Corporation/swarms-rs)
|
||||
- Browse the source code
|
||||
- Contribute to the project
|
||||
- Report issues
|
||||
|
||||
- [:package: Crates.io](https://crates.io/crates/swarm-rs)
|
||||
- Download the latest version
|
||||
- View package statistics
|
||||
|
||||
- [:book: Documentation](https://docs.rs/swarm-rs/0.1.4/swarm_rs/)
|
||||
- Read the API documentation
|
||||
- Learn how to use swarms-rs
|
||||
</div>
|
@ -0,0 +1,13 @@
|
||||
"""Example: compose a content-creation swarm with AutoSwarmBuilder and run one task."""

from swarms.structs.auto_swarm_builder import AutoSwarmBuilder

# Description handed to the builder when it composes the swarm.
SWARM_DESCRIPTION = "A swarm of specialized AI agents for research, writing, editing, and publishing that maintain brand consistency across channels while automating distribution."

# The single task the example runs.
TASK = "Build agents for research, writing, editing, and publishing to enhance brand consistency and automate distribution across channels."

builder = AutoSwarmBuilder(
    name="ContentCreation-Swarm",
    description=SWARM_DESCRIPTION,
    max_loops=1,
)

result = builder.run(TASK)
print(result)
|
@ -0,0 +1,16 @@
|
||||
"""Example: stream a completion through the LiteLLM wrapper.

Fixes from the previous version: the old code printed the stream object
itself before iterating it, and rebound the loop's source variable (`out`)
inside the loop body, which was confusing and error-prone.
"""

from swarms.utils.litellm_wrapper import LiteLLM

llm = LiteLLM(
    model_name="gpt-4o-mini",
    temperature=0.5,
    max_tokens=1000,
    stream=True,  # with streaming on, run() yields chunks instead of a string
)

stream = llm.run("What is the capital of France?")

# Each streamed chunk carries an incremental "delta" payload.
for chunk in stream:
    delta = chunk["choices"][0]["delta"]
    print(type(delta))
    print(delta)
|
@ -0,0 +1,58 @@
|
||||
"""Example: a financial-analysis Agent plus an OpenAI-style stock-price tool schema."""

from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
    FINANCIAL_AGENT_SYS_PROMPT,
)

# Parameter schema for the stock-price lookup tool, kept separate so the
# tool list below stays readable.
_STOCK_PRICE_PARAMETERS = {
    "type": "object",
    "properties": {
        "ticker": {
            "type": "string",
            "description": "The stock ticker symbol of the company, e.g. AAPL for Apple Inc.",
        },
        "include_history": {
            "type": "boolean",
            "description": "Indicates whether to include historical price data along with the current price.",
        },
        "time": {
            "type": "string",
            "format": "date-time",
            "description": "Optional parameter to specify the time for which the stock data is requested, in ISO 8601 format.",
        },
    },
    "required": [
        "ticker",
        "include_history",
        "time",
    ],
}

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Retrieve the current stock price and related information for a specified company.",
            "parameters": _STOCK_PRICE_PARAMETERS,
        },
    }
]


# Initialize the agent. Tool/MCP wiring is left disabled, as in the original
# example; uncomment to enable.
agent = Agent(
    agent_name="Financial-Analysis-Agent",
    agent_description="Personal finance advisor agent",
    system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
    max_loops=1,
    # tools_list_dictionary=tools,
    # mcp_servers=["http://localhost:8000/sse"],
    # output_type="dict-all-except-first",
)

response = agent.run(
    "What is the current stock price for Apple Inc. (AAPL)? Include historical price data.",
)
print(response)
|
@ -0,0 +1,72 @@
|
||||
"""Example: run a paper-idea Agent that must emit a structured idea via a tool schema."""

from swarms.prompts.paper_idea_agent import (
    PAPER_IDEA_AGENT_SYSTEM_PROMPT,
)
from swarms import Agent
from swarms.utils.any_to_str import any_to_str

# Every field of the idea schema is a required string parameter, so the
# schema is generated from this single field table.
_IDEA_FIELDS = {
    "name": "Concise identifier for the paper idea",
    "title": "Academic paper title",
    "short_hypothesis": "Core hypothesis in 1-2 sentences",
    "related_work": "Key papers and how this differs from existing work",
    "abstract": "Complete paper abstract",
    "experiments": "Detailed experimental plan",
    "risk_factors": "Known challenges and constraints",
}

tools = [
    {
        "type": "function",
        "function": {
            "name": "generate_paper_idea",
            "description": "Generate a structured academic paper idea with all required components.",
            "parameters": {
                "type": "object",
                "properties": {
                    field: {"type": "string", "description": description}
                    for field, description in _IDEA_FIELDS.items()
                },
                "required": list(_IDEA_FIELDS),
            },
        },
    }
]

agent = Agent(
    agent_name="Paper Idea Agent",
    agent_role="You are an experienced AI researcher tasked with proposing high-impact research ideas.",
    system_prompt=PAPER_IDEA_AGENT_SYSTEM_PROMPT,
    tools_list_dictionary=tools,
    max_loops=1,
    model_name="gpt-4o-mini",
    output_type="final",
)

out = agent.run(
    "Generate a paper idea for collaborative foundation transformer models"
)
print(any_to_str(out))
|
@ -0,0 +1,120 @@
|
||||
"""Example: time and profile one paper-idea Agent run (time.time + cProfile)."""

import cProfile
import time

from swarms.prompts.paper_idea_agent import (
    PAPER_IDEA_AGENT_SYSTEM_PROMPT,
)
from swarms import Agent
from swarms.utils.any_to_str import any_to_str

print("All imports completed...")


# Every field of the idea schema is a required string parameter, so the
# schema is generated from this single field table.
_IDEA_FIELDS = {
    "name": "Concise identifier for the paper idea",
    "title": "Academic paper title",
    "short_hypothesis": "Core hypothesis in 1-2 sentences",
    "related_work": "Key papers and how this differs from existing work",
    "abstract": "Complete paper abstract",
    "experiments": "Detailed experimental plan",
    "risk_factors": "Known challenges and constraints",
}

tools = [
    {
        "type": "function",
        "function": {
            "name": "generate_paper_idea",
            "description": "Generate a structured academic paper idea with all required components.",
            "parameters": {
                "type": "object",
                "properties": {
                    field: {"type": "string", "description": description}
                    for field, description in _IDEA_FIELDS.items()
                },
                "required": list(_IDEA_FIELDS),
            },
        },
    }
]


def generate_paper_idea():
    """Create the Paper Idea Agent, run it once, and report the wall-clock time."""
    print("Starting generate_paper_idea function...")
    try:
        print("Creating agent...")
        agent = Agent(
            agent_name="Paper Idea Agent",
            agent_role="You are an experienced AI researcher tasked with proposing high-impact research ideas.",
            system_prompt=PAPER_IDEA_AGENT_SYSTEM_PROMPT,
            tools_list_dictionary=tools,
            max_loops=1,
            model_name="gpt-4o-mini",
            output_type="final",
        )

        print("Agent created, starting run...")
        started = time.time()
        out = agent.run(
            "Generate a paper idea for collaborative foundation transformer models"
        )
        elapsed = time.time() - started

        print(f"Execution time: {elapsed:.2f} seconds")
        print("Output:", any_to_str(out))
        return out
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        raise


print("Defining main block...")
if __name__ == "__main__":
    print("Entering main block...")

    # Basic timing first
    print("\nRunning basic timing...")
    generate_paper_idea()

    # Then with profiler
    print("\nRunning with profiler...")
    profiler = cProfile.Profile()
    profiler.enable()
    generate_paper_idea()
    profiler.disable()
    profiler.print_stats(sort="cumulative")

print("Script completed.")
|
@ -0,0 +1,31 @@
|
||||
# System prompt for the Paper Idea Agent: defines the researcher persona and
# forces every reply into the strict "IDEA JSON" schema below so downstream
# code can parse it automatically.
PAPER_IDEA_AGENT_SYSTEM_PROMPT = """
You are an experienced AI researcher tasked with proposing high-impact research ideas. Your ideas should:

- Be novel and creative
- Think outside conventional boundaries
- Start from simple, elegant questions or observations
- Be distinguishable from existing literature
- Be feasible within academic lab resources
- Be publishable at top ML conferences
- Be implementable using the provided codebase

Your responses must follow this strict format:


IDEA JSON Structure:
{
    "Name": "Concise identifier",
    "Title": "Academic paper title",
    "Short Hypothesis": "Core hypothesis in 1-2 sentences",
    "Related Work": "Key papers and how this differs",
    "Abstract": "Complete paper abstract",
    "Experiments": "Detailed experimental plan",
    "Risk Factors and Limitations": "Known challenges and constraints"
}

Important Guidelines:
- Perform at least one literature search before finalizing any idea
- Ensure JSON formatting is valid for automated parsing
- Keep proposals clear and implementable
"""
|
@ -1,729 +0,0 @@
|
||||
"""
|
||||
TalkHier: A hierarchical multi-agent framework for content generation and refinement.
|
||||
Implements structured communication and evaluation protocols.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from swarms import Agent
|
||||
from swarms.structs.conversation import Conversation
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRole(Enum):
    """Defines the possible roles for agents in the system."""

    SUPERVISOR = "supervisor"  # orchestrates the workflow and selects evaluations
    GENERATOR = "generator"  # produces the initial content
    EVALUATOR = "evaluator"  # scores content against generated criteria
    REVISOR = "revisor"  # rewrites content based on evaluator feedback
|
||||
|
||||
|
||||
@dataclass
class CommunicationEvent:
    """Represents a structured communication event between agents.

    Attributes:
        message: Main message payload exchanged between the agents.
        background: Optional context accompanying the message.
        intermediate_output: Optional structured payload produced so far.
        sender: Name of the sending agent.
        receiver: Name of the receiving agent.
        timestamp: Creation time as a string; stamped per instance.
    """

    message: str
    background: Optional[str] = None
    intermediate_output: Optional[Dict[str, Any]] = None
    sender: str = ""
    receiver: str = ""
    # BUG FIX: the previous default `str(datetime.now())` was evaluated once at
    # class-definition time, so every event shared the same stale timestamp.
    # An empty default is now filled in per instance in __post_init__.
    timestamp: str = ""

    def __post_init__(self) -> None:
        # Stamp each event at creation time unless a timestamp was supplied.
        if not self.timestamp:
            self.timestamp = str(datetime.now())
|
||||
|
||||
|
||||
class TalkHier:
|
||||
"""
|
||||
A hierarchical multi-agent system for content generation and refinement.
|
||||
|
||||
Implements the TalkHier framework with structured communication protocols
|
||||
and hierarchical refinement processes.
|
||||
|
||||
Attributes:
|
||||
max_iterations: Maximum number of refinement iterations
|
||||
quality_threshold: Minimum score required for content acceptance
|
||||
model_name: Name of the LLM model to use
|
||||
base_path: Path for saving agent states
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_iterations: int = 3,
|
||||
quality_threshold: float = 0.8,
|
||||
model_name: str = "gpt-4",
|
||||
base_path: Optional[str] = None,
|
||||
return_string: bool = False,
|
||||
):
|
||||
"""Initialize the TalkHier system."""
|
||||
self.max_iterations = max_iterations
|
||||
self.quality_threshold = quality_threshold
|
||||
self.model_name = model_name
|
||||
self.return_string = return_string
|
||||
self.base_path = (
|
||||
Path(base_path) if base_path else Path("./agent_states")
|
||||
)
|
||||
self.base_path.mkdir(exist_ok=True)
|
||||
|
||||
# Initialize agents
|
||||
self._init_agents()
|
||||
|
||||
# Create conversation
|
||||
self.conversation = Conversation()
|
||||
|
||||
def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Safely parse JSON string, handling various formats and potential errors.
|
||||
|
||||
Args:
|
||||
json_str: String containing JSON data
|
||||
|
||||
Returns:
|
||||
Parsed dictionary
|
||||
"""
|
||||
try:
|
||||
# Try direct JSON parsing
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
# Try extracting JSON from potential text wrapper
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{.*\}", json_str, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
# Try extracting from markdown code blocks
|
||||
code_block_match = re.search(
|
||||
r"```(?:json)?\s*(\{.*?\})\s*```",
|
||||
json_str,
|
||||
re.DOTALL,
|
||||
)
|
||||
if code_block_match:
|
||||
return json.loads(code_block_match.group(1))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract JSON: {str(e)}")
|
||||
|
||||
# Fallback: create structured dict from text
|
||||
return {
|
||||
"content": json_str,
|
||||
"metadata": {
|
||||
"parsed": False,
|
||||
"timestamp": str(datetime.now()),
|
||||
},
|
||||
}
|
||||
|
||||
    def _get_criteria_generator_prompt(self) -> str:
        """Get the prompt for the criteria generator agent.

        Returns:
            System prompt instructing the agent to emit a JSON object whose
            "criteria" mapping is consumed downstream when building the
            evaluators' inputs.
        """
        return """You are a Criteria Generator agent responsible for creating task-specific evaluation criteria.
        Analyze the task and generate appropriate evaluation criteria based on:
        - Task type and complexity
        - Required domain knowledge
        - Target audience expectations
        - Quality requirements

        Output all responses in strict JSON format:
        {
            "criteria": {
                "criterion_name": {
                    "description": "Detailed description of what this criterion measures",
                    "importance": "Weight from 0.0-1.0 indicating importance",
                    "evaluation_guide": "Guidelines for how to evaluate this criterion"
                }
            },
            "metadata": {
                "task_type": "Classification of the task type",
                "complexity_level": "Assessment of task complexity",
                "domain_focus": "Primary domain or field of the task"
            }
        }"""
|
||||
|
||||
def _init_agents(self) -> None:
|
||||
"""Initialize all agents with their specific roles and prompts."""
|
||||
# Main supervisor agent
|
||||
self.main_supervisor = Agent(
|
||||
agent_name="Main-Supervisor",
|
||||
system_prompt=self._get_supervisor_prompt(),
|
||||
model_name=self.model_name,
|
||||
max_loops=1,
|
||||
saved_state_path=str(
|
||||
self.base_path / "main_supervisor.json"
|
||||
),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Generator agent
|
||||
self.generator = Agent(
|
||||
agent_name="Content-Generator",
|
||||
system_prompt=self._get_generator_prompt(),
|
||||
model_name=self.model_name,
|
||||
max_loops=1,
|
||||
saved_state_path=str(self.base_path / "generator.json"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Criteria Generator agent
|
||||
self.criteria_generator = Agent(
|
||||
agent_name="Criteria-Generator",
|
||||
system_prompt=self._get_criteria_generator_prompt(),
|
||||
model_name=self.model_name,
|
||||
max_loops=1,
|
||||
saved_state_path=str(
|
||||
self.base_path / "criteria_generator.json"
|
||||
),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Evaluators without criteria (will be set during run)
|
||||
self.evaluators = []
|
||||
for i in range(3):
|
||||
self.evaluators.append(
|
||||
Agent(
|
||||
agent_name=f"Evaluator-{i}",
|
||||
system_prompt=self._get_evaluator_prompt(i),
|
||||
model_name=self.model_name,
|
||||
max_loops=1,
|
||||
saved_state_path=str(
|
||||
self.base_path / f"evaluator_{i}.json"
|
||||
),
|
||||
verbose=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Revisor agent
|
||||
self.revisor = Agent(
|
||||
agent_name="Content-Revisor",
|
||||
system_prompt=self._get_revisor_prompt(),
|
||||
model_name=self.model_name,
|
||||
max_loops=1,
|
||||
saved_state_path=str(self.base_path / "revisor.json"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
def _generate_dynamic_criteria(self, task: str) -> Dict[str, str]:
|
||||
"""
|
||||
Generate dynamic evaluation criteria based on the task.
|
||||
|
||||
Args:
|
||||
task: Content generation task description
|
||||
|
||||
Returns:
|
||||
Dictionary containing dynamic evaluation criteria
|
||||
"""
|
||||
# Example dynamic criteria generation logic
|
||||
if "technical" in task.lower():
|
||||
return {
|
||||
"accuracy": "Technical correctness and source reliability",
|
||||
"clarity": "Readability and logical structure",
|
||||
"depth": "Comprehensive coverage of technical details",
|
||||
"engagement": "Interest level and relevance to the audience",
|
||||
"technical_quality": "Grammar, spelling, and formatting",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"accuracy": "Factual correctness and source reliability",
|
||||
"clarity": "Readability and logical structure",
|
||||
"coherence": "Logical consistency and argument structure",
|
||||
"engagement": "Interest level and relevance to the audience",
|
||||
"completeness": "Coverage of the topic and depth",
|
||||
"technical_quality": "Grammar, spelling, and formatting",
|
||||
}
|
||||
|
||||
    def _get_supervisor_prompt(self) -> str:
        """Get the prompt for the supervisor agent.

        Returns:
            System prompt instructing the supervisor to orchestrate the
            pipeline and reply in the strict JSON envelope shown below.
        """
        return """You are a Supervisor agent responsible for orchestrating the content generation process and selecting the best evaluation criteria.

        You must:
        1. Analyze tasks and develop strategies
        2. Review multiple evaluator feedback
        3. Select the most appropriate evaluation based on:
           - Completeness of criteria
           - Relevance to task
           - Quality of feedback
        4. Provide clear instructions for content revision

        Output all responses in strict JSON format:
        {
            "thoughts": {
                "task_analysis": "Analysis of requirements, audience, scope",
                "strategy": "Step-by-step plan and success metrics",
                "evaluation_selection": {
                    "chosen_evaluator": "ID of selected evaluator",
                    "reasoning": "Why this evaluation was chosen",
                    "key_criteria": ["List of most important criteria"]
                }
            },
            "next_action": {
                "agent": "Next agent to engage",
                "instruction": "Detailed instructions with context"
            }
        }"""
|
||||
|
||||
    def _get_generator_prompt(self) -> str:
        """Get the prompt for the generator agent.

        Returns:
            System prompt instructing the generator to produce content and
            metadata in the strict JSON envelope shown below.
        """
        return """You are a Generator agent responsible for creating high-quality, original content. Your role is to produce content that is engaging, informative, and tailored to the target audience.

        When generating content:
        - Thoroughly research and fact-check all information
        - Structure content logically with clear flow
        - Use appropriate tone and language for the target audience
        - Include relevant examples and explanations
        - Ensure content is original and plagiarism-free
        - Consider SEO best practices where applicable

        Output all responses in strict JSON format:
        {
            "content": {
                "main_body": "The complete generated content with proper formatting and structure",
                "metadata": {
                    "word_count": "Accurate word count of main body",
                    "target_audience": "Detailed audience description",
                    "key_points": ["List of main points covered"],
                    "sources": ["List of reference sources if applicable"],
                    "readability_level": "Estimated reading level",
                    "tone": "Description of content tone"
                }
            }
        }"""
|
||||
|
||||
    def _get_evaluator_prompt(self, evaluator_id: int) -> str:
        """Get the base prompt for an evaluator agent.

        Args:
            evaluator_id: Index identifying this evaluator in its cohort;
                interpolated into the prompt so replies are attributable.

        Returns:
            System prompt instructing the evaluator to derive criteria and
            score content in the strict JSON envelope shown below. (Braces
            are doubled because this is an f-string.)
        """
        return f"""You are Evaluator {evaluator_id}, responsible for critically assessing content quality. Your evaluation must be thorough, objective, and constructive.

        When receiving content to evaluate:
        1. First analyze the task description to determine appropriate evaluation criteria
        2. Generate specific criteria based on task requirements
        3. Evaluate content against these criteria
        4. Provide detailed feedback for each criterion

        Output all responses in strict JSON format:
        {{
            "generated_criteria": {{
                "criteria_name": "description of what this criterion measures",
                // Add more criteria based on task analysis
            }},
            "scores": {{
                "overall": "0.0-1.0 composite score",
                "categories": {{
                    // Scores for each generated criterion
                    "criterion_name": "0.0-1.0 score with evidence"
                }}
            }},
            "feedback": [
                "Specific, actionable improvement suggestions per criterion"
            ],
            "strengths": ["Notable positive aspects"],
            "weaknesses": ["Areas needing improvement"]
        }}"""
|
||||
|
||||
    def _get_revisor_prompt(self) -> str:
        """Get the prompt for the revisor agent.

        Returns:
            System prompt instructing the revisor to incorporate evaluator
            feedback and reply in the strict JSON envelope shown below.
        """
        return """You are a Revisor agent responsible for improving content based on evaluator feedback. Your role is to enhance content while maintaining its core message and purpose.

        When revising content:
        - Address all evaluator feedback systematically
        - Maintain consistency in tone and style
        - Preserve accurate information
        - Enhance clarity and flow
        - Fix technical issues
        - Optimize for target audience
        - Track all changes made

        Output all responses in strict JSON format:
        {
            "revised_content": {
                "main_body": "Complete revised content incorporating all improvements",
                "metadata": {
                    "word_count": "Updated word count",
                    "changes_made": [
                        "Detailed list of specific changes and improvements",
                        "Reasoning for each major revision",
                        "Feedback points addressed"
                    ],
                    "improvement_summary": "Overview of main enhancements",
                    "preserved_elements": ["Key elements maintained from original"],
                    "revision_approach": "Strategy used for revisions"
                }
            }
        }"""
|
||||
|
||||
def _generate_criteria_for_task(
    self, task: str
) -> Dict[str, Any]:
    """Ask the criteria-generator agent for task-specific evaluation criteria.

    Args:
        task: The content-generation task to derive criteria for.

    Returns:
        The parsed criteria payload, or ``{"criteria": {}}`` if the
        generator call or parsing fails for any reason.
    """
    request = {
        "task": task,
        "instruction": "Generate specific evaluation criteria for this task.",
    }
    try:
        raw_response = self.criteria_generator.run(json.dumps(request))
        self.conversation.add(
            role="Criteria-Generator", content=raw_response
        )
        return self._safely_parse_json(raw_response)
    except Exception as e:
        logger.error(f"Error generating criteria: {str(e)}")
        return {"criteria": {}}
|
||||
|
||||
def _create_comm_event(
    self, sender: Agent, receiver: Agent, response: Dict
) -> CommunicationEvent:
    """Create a structured communication event between agents.

    Missing fields in *response* default to empty values so the event
    can always be constructed.
    """
    message = response.get("message", "")
    background = response.get("background", "")
    payload = response.get("intermediate_output", {})
    return CommunicationEvent(
        message=message,
        background=background,
        intermediate_output=payload,
        sender=sender.agent_name,
        receiver=receiver.agent_name,
    )
|
||||
|
||||
def _evaluate_content(
    self, content: Union[str, Dict], task: str
) -> Dict[str, Any]:
    """Coordinate evaluation process with parallel evaluator execution.

    Generates task-specific criteria, fans the content out to every
    evaluator in a thread pool, then asks the main supervisor to
    synthesize the individual evaluations into one aggregate result.

    Args:
        content: Content to evaluate — a JSON string or an already-parsed dict.
        task: The originating task description, forwarded to the evaluators.

    Returns:
        The supervisor's aggregated evaluation dict, or the fallback
        evaluation if anything in the pipeline raises.
    """
    try:
        content_dict = (
            self._safely_parse_json(content)
            if isinstance(content, str)
            else content
        )
        criteria_data = self._generate_criteria_for_task(task)

        def run_evaluator(evaluator, eval_input):
            # Tag each response with the evaluator's name so the
            # supervisor can attribute the feedback.
            response = evaluator.run(json.dumps(eval_input))
            return {
                "evaluator_id": evaluator.agent_name,
                "evaluation": self._safely_parse_json(response),
            }

        # Every evaluator receives an identical payload.
        eval_inputs = [
            {
                "task": task,
                "content": content_dict,
                "criteria": criteria_data.get("criteria", {}),
            }
            for _ in self.evaluators
        ]

        # Run all evaluators concurrently; executor.map preserves the
        # order of self.evaluators in the result list.
        with ThreadPoolExecutor() as executor:
            evaluations = list(
                executor.map(
                    lambda x: run_evaluator(*x),
                    zip(self.evaluators, eval_inputs),
                )
            )

        supervisor_input = {
            "evaluations": evaluations,
            "task": task,
            "instruction": "Synthesize feedback",
        }
        supervisor_response = self.main_supervisor.run(
            json.dumps(supervisor_input)
        )
        aggregated_eval = self._safely_parse_json(
            supervisor_response
        )

        # Track communication
        comm_event = self._create_comm_event(
            self.main_supervisor, self.revisor, aggregated_eval
        )
        self.conversation.add(
            role="Communication",
            content=json.dumps(asdict(comm_event)),
        )

        return aggregated_eval

    except Exception as e:
        logger.error(f"Evaluation error: {str(e)}")
        return self._get_fallback_evaluation()
|
||||
|
||||
def _get_fallback_evaluation(self) -> Dict[str, Any]:
|
||||
"""Get a safe fallback evaluation result."""
|
||||
return {
|
||||
"scores": {
|
||||
"overall": 0.5,
|
||||
"categories": {
|
||||
"accuracy": 0.5,
|
||||
"clarity": 0.5,
|
||||
"coherence": 0.5,
|
||||
},
|
||||
},
|
||||
"feedback": ["Evaluation failed"],
|
||||
"metadata": {
|
||||
"timestamp": str(datetime.now()),
|
||||
"is_fallback": True,
|
||||
},
|
||||
}
|
||||
|
||||
def _aggregate_evaluations(
|
||||
self, evaluations: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Aggregate multiple evaluation results into a single evaluation."""
|
||||
try:
|
||||
# Collect all unique criteria from evaluators
|
||||
all_criteria = set()
|
||||
for eval_data in evaluations:
|
||||
categories = eval_data.get("scores", {}).get(
|
||||
"categories", {}
|
||||
)
|
||||
all_criteria.update(categories.keys())
|
||||
|
||||
# Initialize score aggregation
|
||||
aggregated_scores = {
|
||||
criterion: [] for criterion in all_criteria
|
||||
}
|
||||
overall_scores = []
|
||||
all_feedback = []
|
||||
|
||||
# Collect scores and feedback
|
||||
for eval_data in evaluations:
|
||||
scores = eval_data.get("scores", {})
|
||||
overall_scores.append(scores.get("overall", 0.5))
|
||||
|
||||
categories = scores.get("categories", {})
|
||||
for criterion in all_criteria:
|
||||
if criterion in categories:
|
||||
aggregated_scores[criterion].append(
|
||||
categories.get(criterion, 0.5)
|
||||
)
|
||||
|
||||
all_feedback.extend(eval_data.get("feedback", []))
|
||||
|
||||
# Calculate means
|
||||
def safe_mean(scores: List[float]) -> float:
|
||||
return sum(scores) / len(scores) if scores else 0.5
|
||||
|
||||
return {
|
||||
"scores": {
|
||||
"overall": safe_mean(overall_scores),
|
||||
"categories": {
|
||||
criterion: safe_mean(scores)
|
||||
for criterion, scores in aggregated_scores.items()
|
||||
},
|
||||
},
|
||||
"feedback": list(set(all_feedback)),
|
||||
"metadata": {
|
||||
"evaluator_count": len(evaluations),
|
||||
"criteria_used": list(all_criteria),
|
||||
"timestamp": str(datetime.now()),
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in evaluation aggregation: {str(e)}")
|
||||
return self._get_fallback_evaluation()
|
||||
|
||||
def _evaluate_and_revise(
    self, content: Union[str, Dict], task: str
) -> Dict[str, Any]:
    """Coordinate evaluation and revision process.

    Evaluates *content*, extracts the supervisor-selected evaluation and
    its reasoning, then asks the generator agent to revise the content
    against that feedback.

    Returns:
        Dict with "content" (revised content — or the original content
        on failure) and "evaluation" (the full evaluation result — or
        the fallback evaluation on failure).
    """
    try:
        # Get evaluations and supervisor selection
        evaluation_result = self._evaluate_content(content, task)

        # Extract selected evaluation and supervisor reasoning
        # (both default to {} if the supervisor omitted them).
        selected_evaluation = evaluation_result.get(
            "selected_evaluation", {}
        )
        supervisor_reasoning = evaluation_result.get(
            "supervisor_reasoning", {}
        )

        # Prepare revision input with selected evaluation
        revision_input = {
            "content": content,
            "evaluation": selected_evaluation,
            "supervisor_feedback": supervisor_reasoning,
            "instruction": "Revise the content based on the selected evaluation feedback",
        }

        # Get revision from content generator
        revision_response = self.generator.run(
            json.dumps(revision_input)
        )
        revised_content = self._safely_parse_json(
            revision_response
        )

        return {
            "content": revised_content,
            "evaluation": evaluation_result,
        }
    except Exception as e:
        logger.error(f"Evaluation and revision error: {str(e)}")
        return {
            "content": content,
            "evaluation": self._get_fallback_evaluation(),
        }
|
||||
|
||||
def run(self, task: str) -> Dict[str, Any]:
    """
    Generate and iteratively refine content based on the given task.

    Args:
        task: Content generation task description

    Returns:
        Dictionary containing final content and metadata, or the full
        conversation history as a string when ``self.return_string`` is set.
    """
    logger.info(f"Starting content generation for task: {task}")

    # Safe defaults: with max_iterations == 0 (or an exception before the
    # first iteration completes) neither `current_content` nor `evaluation`
    # would otherwise be assigned, and the fallthrough return below would
    # raise NameError.
    current_content: Dict[str, Any] = {}
    evaluation: Dict[str, Any] = self._get_fallback_evaluation()

    try:
        # Get initial direction from supervisor
        supervisor_response = self.main_supervisor.run(task)

        self.conversation.add(
            role=self.main_supervisor.agent_name,
            content=supervisor_response,
        )

        supervisor_data = self._safely_parse_json(
            supervisor_response
        )

        # Generate initial content from the supervisor's plan
        generator_response = self.generator.run(
            json.dumps(supervisor_data.get("next_action", {}))
        )

        self.conversation.add(
            role=self.generator.agent_name,
            content=generator_response,
        )

        current_content = self._safely_parse_json(
            generator_response
        )

        for iteration in range(self.max_iterations):
            logger.info(f"Starting iteration {iteration + 1}")

            # Evaluate and revise content
            result = self._evaluate_and_revise(
                current_content, task
            )
            evaluation = result["evaluation"]
            current_content = result["content"]

            # Check if quality threshold is met
            selected_eval = evaluation.get(
                "selected_evaluation", {}
            )
            overall_score = selected_eval.get("scores", {}).get(
                "overall", 0.0
            )

            if overall_score >= self.quality_threshold:
                logger.info(
                    "Quality threshold met, returning content"
                )
                return {
                    "content": current_content.get(
                        "content", {}
                    ).get("main_body", ""),
                    "final_score": overall_score,
                    "iterations": iteration + 1,
                    "metadata": {
                        "content_metadata": current_content.get(
                            "content", {}
                        ).get("metadata", {}),
                        "evaluation": evaluation,
                    },
                }

            # Add to conversation history
            self.conversation.add(
                role=self.generator.agent_name,
                content=json.dumps(current_content),
            )

        logger.warning(
            "Max iterations reached without meeting quality threshold"
        )

    except Exception as e:
        logger.error(f"Error in generate_and_refine: {str(e)}")
        current_content = {
            "content": {"main_body": f"Error: {str(e)}"}
        }
        evaluation = self._get_fallback_evaluation()

    if self.return_string:
        return self.conversation.return_history_as_string()
    else:
        return {
            "content": current_content.get("content", {}).get(
                "main_body", ""
            ),
            # .get with defaults: the last evaluation may come straight
            # from the supervisor and is not guaranteed to carry "scores".
            "final_score": evaluation.get("scores", {}).get(
                "overall", 0.0
            ),
            "iterations": self.max_iterations,
            "metadata": {
                "content_metadata": current_content.get(
                    "content", {}
                ).get("metadata", {}),
                "evaluation": evaluation,
                "error": "Max iterations reached",
            },
        }
|
||||
|
||||
def save_state(self) -> None:
    """Persist the state of every managed agent, logging (not raising) failures."""
    managed_agents = [
        self.main_supervisor,
        self.generator,
        *self.evaluators,
        self.revisor,
    ]
    for agent in managed_agents:
        try:
            agent.save_state()
        except Exception as e:
            logger.error(
                f"Error saving state for {agent.agent_name}: {str(e)}"
            )
|
||||
|
||||
def load_state(self) -> None:
    """Restore the saved state of every managed agent, logging (not raising) failures."""
    managed_agents = [
        self.main_supervisor,
        self.generator,
        *self.evaluators,
        self.revisor,
    ]
    for agent in managed_agents:
        try:
            agent.load_state()
        except Exception as e:
            logger.error(
                f"Error loading state for {agent.agent_name}: {str(e)}"
            )
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# try:
|
||||
# talkhier = TalkHier(
|
||||
# max_iterations=1,
|
||||
# quality_threshold=0.8,
|
||||
# model_name="gpt-4o",
|
||||
# return_string=False,
|
||||
# )
|
||||
|
||||
# # Ask for user input
|
||||
# task = input("Enter the content generation task description: ")
|
||||
# result = talkhier.run(task)
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in main execution: {str(e)}")
|
@ -0,0 +1,557 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
from fastmcp import FastMCP, Client
|
||||
from loguru import logger
|
||||
from swarms.utils.any_to_str import any_to_str
|
||||
|
||||
|
||||
class AOP:
|
||||
"""
|
||||
Agent-Orchestration Protocol (AOP) class for managing tools, agents, and swarms.
|
||||
|
||||
This class provides decorators and methods for registering and running various components
|
||||
in a Swarms environment. It handles logging, metadata management, and execution control.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the AOP instance
|
||||
description (str): A description of the AOP instance
|
||||
mcp (FastMCP): The underlying FastMCP instance for managing components
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    name: Optional[str] = None,
    description: Optional[str] = None,
    url: Optional[str] = "http://localhost:8000/sse",
    *args,
    **kwargs,
):
    """
    Initialize the AOP instance.

    Args:
        name (str): The name of the AOP instance
        description (str): A description of the AOP instance
        url (str): The URL of the MCP instance; used as the default
            endpoint by list_all() and the client helper methods.
        *args: Additional positional arguments passed to FastMCP
        **kwargs: Additional keyword arguments passed to FastMCP
    """
    logger.info(f"[AOP] Initializing AOP instance: {name}")
    self.name = name
    self.description = description
    self.url = url

    # Local registries. Only the tool() decorator populates self.tools;
    # nothing in this class writes to self.swarms (agents/swarms are
    # registered directly on the MCP instead).
    self.tools = {}
    self.swarms = {}

    self.mcp = FastMCP(name=name, *args, **kwargs)

    logger.success(
        f"[AOP] Successfully initialized AOP instance: {name}"
    )
|
||||
|
||||
def tool(
    self,
    name: Optional[str] = None,
    description: Optional[str] = None,
):
    """
    Decorator to register an MCP tool with optional metadata.

    This decorator registers a function as a tool in the MCP system. It handles
    logging, metadata management, and execution tracking.

    Args:
        name (Optional[str]): Custom name for the tool. If None, uses function name
        description (Optional[str]): Custom description. If None, uses function docstring

    Returns:
        Callable: A decorator function that registers the tool
    """
    logger.debug(
        f"[AOP] Creating tool decorator with name={name}, description={description}"
    )

    def decorator(func: Callable):
        tool_name = name or func.__name__
        tool_description = description or (
            inspect.getdoc(func) or ""
        )

        logger.debug(
            f"[AOP] Registering tool: {tool_name} - {tool_description}"
        )

        # Local registry keeps the unprefixed name; the MCP-side
        # registration below uses a "tool_" prefix, so lookups against
        # list_all() see "tool_<name>", not "<name>".
        self.tools[tool_name] = {
            "name": tool_name,
            "description": tool_description,
            "function": func,
        }

        @self.mcp.tool(
            name=f"tool_{tool_name}", description=tool_description
        )
        @wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            # NOTE(review): `await func(...)` assumes the decorated
            # function is a coroutine function — confirm sync functions
            # are rejected or adapted upstream.
            logger.info(
                f"[TOOL:{tool_name}] ➤ called with args={args}, kwargs={kwargs}"
            )
            try:
                result = await func(*args, **kwargs)
                logger.success(f"[TOOL:{tool_name}] ✅ completed")
                return result
            except Exception as e:
                logger.error(
                    f"[TOOL:{tool_name}] ❌ failed with error: {str(e)}"
                )
                raise

        return wrapper

    return decorator
|
||||
|
||||
def agent(
    self,
    name: Optional[str] = None,
    description: Optional[str] = None,
):
    """
    Decorator to define an agent entry point.

    This decorator registers a function as an agent in the MCP system. It handles
    logging, metadata management, and execution tracking for agent operations.

    Args:
        name (Optional[str]): Custom name for the agent. If None, uses 'agent_' + function name
        description (Optional[str]): Custom description. If None, uses function docstring

    Returns:
        Callable: A decorator function that registers the agent
    """
    logger.debug(
        f"[AOP] Creating agent decorator with name={name}, description={description}"
    )

    def decorator(func: Callable):
        agent_name = name or f"agent_{func.__name__}"
        agent_description = description or (
            inspect.getdoc(func) or ""
        )

        # Registered directly under agent_name (no extra prefix);
        # list_agents() finds these because the default name contains
        # the substring "agent".
        @self.mcp.tool(
            name=agent_name, description=agent_description
        )
        @wraps(func)
        async def wrapper(*args, **kwargs):
            logger.info(f"[AGENT:{agent_name}] 👤 Starting")
            try:
                result = await func(*args, **kwargs)
                logger.success(
                    f"[AGENT:{agent_name}] ✅ Finished"
                )
                return result
            except Exception as e:
                logger.error(
                    f"[AGENT:{agent_name}] ❌ failed with error: {str(e)}"
                )
                raise

        # Marker attributes for external introspection of agent wrappers.
        wrapper._is_agent = True
        wrapper._agent_name = agent_name
        wrapper._agent_description = agent_description
        return wrapper

    return decorator
|
||||
|
||||
def swarm(
    self,
    name: Optional[str] = None,
    description: Optional[str] = None,
):
    """
    Decorator to define a swarm controller.

    This decorator registers a function as a swarm controller in the MCP system.
    It handles logging, metadata management, and execution tracking for swarm operations.

    Args:
        name (Optional[str]): Custom name for the swarm. If None, uses 'swarm_' + function name
        description (Optional[str]): Custom description. If None, uses function docstring

    Returns:
        Callable: A decorator function that registers the swarm
    """
    logger.debug(
        f"[AOP] Creating swarm decorator with name={name}, description={description}"
    )

    def decorator(func: Callable):
        swarm_name = name or f"swarm_{func.__name__}"
        swarm_description = description or (
            inspect.getdoc(func) or ""
        )

        # Registered under swarm_name; list_swarms() matches on the
        # substring "swarm" in the registered name.
        @self.mcp.tool(
            name=swarm_name, description=swarm_description
        )
        @wraps(func)
        async def wrapper(*args, **kwargs):
            logger.info(
                f"[SWARM:{swarm_name}] 🐝 Spawning swarm..."
            )
            try:
                result = await func(*args, **kwargs)
                logger.success(
                    f"[SWARM:{swarm_name}] 🐝 Completed"
                )
                return result
            except Exception as e:
                logger.error(
                    f"[SWARM:{swarm_name}] ❌ failed with error: {str(e)}"
                )
                raise

        # Marker attributes for external introspection of swarm wrappers.
        wrapper._is_swarm = True
        wrapper._swarm_name = swarm_name
        wrapper._swarm_description = swarm_description
        return wrapper

    return decorator
|
||||
|
||||
def run(self, method: Literal["stdio", "sse"], *args, **kwargs):
    """
    Run the MCP with the specified transport method.

    Args:
        method (Literal['stdio', 'sse']): The execution method to use
        *args: Additional positional arguments for the run method
        **kwargs: Additional keyword arguments for the run method

    Returns:
        Any: The result of the MCP run operation

    Raises:
        Exception: Re-raised from the underlying MCP run after logging.
    """
    logger.info(f"[AOP] Running MCP with method: {method}")
    try:
        outcome = self.mcp.run(method, *args, **kwargs)
        logger.success(
            f"[AOP] Successfully ran MCP with method: {method}"
        )
    except Exception as e:
        logger.error(
            f"[AOP] Failed to run MCP with method {method}: {str(e)}"
        )
        raise
    return outcome
|
||||
|
||||
def run_stdio(self, *args, **kwargs):
    """
    Run the MCP using standard I/O transport.

    Thin convenience wrapper around run("stdio", ...).

    Args:
        *args: Additional positional arguments for the run method
        **kwargs: Additional keyword arguments for the run method

    Returns:
        Any: The result of the MCP run operation
    """
    logger.info("[AOP] Running MCP with stdio method")
    return self.run("stdio", *args, **kwargs)
|
||||
|
||||
def run_sse(self, *args, **kwargs):
    """
    Run the MCP using Server-Sent Events transport.

    Thin convenience wrapper around run("sse", ...).

    Args:
        *args: Additional positional arguments for the run method
        **kwargs: Additional keyword arguments for the run method

    Returns:
        Any: The result of the MCP run operation
    """
    logger.info("[AOP] Running MCP with SSE method")
    return self.run("sse", *args, **kwargs)
|
||||
|
||||
def list_available(
    self, output_type: Literal["str", "list"] = "str"
):
    """
    List all available tools in the MCP.

    Args:
        output_type: "str" renders the listing via any_to_str; "list"
            returns the raw value from the MCP.

    Returns:
        The registered tools, rendered per *output_type*.

    Raises:
        ValueError: If *output_type* is neither "str" nor "list".
    """
    if output_type == "list":
        return self.mcp.list_tools()
    if output_type == "str":
        return any_to_str(self.mcp.list_tools())
    raise ValueError(f"Invalid output type: {output_type}")
|
||||
|
||||
async def check_utility_exists(
    self, url: str, name: str, *args, **kwargs
):
    """Return True if a tool/agent named *name* is registered at *url*.

    Args:
        url: MCP endpoint to query.
        name: Exact registered name to look for.
        *args, **kwargs: Forwarded to the fastmcp Client constructor.

    Returns:
        bool: Whether any registered tool matches *name* exactly.
    """
    async with Client(url, *args, **kwargs) as client:
        # client.list_tools() is a coroutine (it is awaited in _list_all);
        # the previous code iterated the unawaited coroutine, which fails
        # at runtime. Await it before scanning.
        tools = await client.list_tools()
        return any(tool.name == name for tool in tools)
|
||||
|
||||
async def _call_tool(
    self, url: str, name: str, arguments: dict, *args, **kwargs
):
    """Open a client session against *url* and invoke the named tool.

    Args:
        url: MCP endpoint to connect to.
        name: Registered name of the tool to call.
        arguments: Keyword payload passed to the tool.
        *args, **kwargs: Forwarded to the fastmcp Client constructor.

    Returns:
        The tool's result, or None if the connection or the call fails
        (errors are logged, never raised).
    """
    try:
        async with Client(url, *args, **kwargs) as client:
            result = await client.call_tool(name, arguments)
            logger.info(
                f"Client connected: {client.is_connected()}"
            )
            return result
    except Exception as e:
        logger.error(f"Error calling tool: {e}")
        return None
|
||||
|
||||
def call_tool(
    self,
    url: str,
    name: str,
    arguments: dict,
    *args,
    **kwargs,
):
    """Synchronous wrapper around _call_tool (runs it in a fresh event loop)."""
    coroutine = self._call_tool(url, name, arguments, *args, **kwargs)
    return asyncio.run(coroutine)
|
||||
|
||||
def call_tool_or_agent(
    self,
    url: str,
    name: str,
    arguments: dict,
    output_type: Literal["str", "list"] = "str",
    *args,
    **kwargs,
):
    """
    Execute a tool or agent by name.

    Args:
        url (str): MCP endpoint to call against.
        name (str): The name of the tool or agent to execute
        arguments (dict): The arguments to pass to the tool or agent
        output_type: "str" stringifies the result via any_to_str;
            "list" returns it untouched.

    Raises:
        ValueError: If *output_type* is neither "str" nor "list".
    """
    if output_type not in ("str", "list"):
        raise ValueError(f"Invalid output type: {output_type}")
    result = self.call_tool(url=url, name=name, arguments=arguments)
    return any_to_str(result) if output_type == "str" else result
|
||||
|
||||
def call_tool_or_agent_batched(
    self,
    url: str,
    names: list[str],
    arguments: list[dict],
    output_type: Literal["str", "list"] = "str",
    *args,
    **kwargs,
):
    """
    Execute a list of tools or agents by name, sequentially.

    Args:
        url (str): MCP endpoint to call against.
        names (list[str]): The names of the tools or agents to execute
        arguments (list[dict]): Per-call argument dicts, zipped pairwise
            with *names* (extra entries in the longer list are ignored).
        output_type: Shape of each element in the returned list.

    Raises:
        ValueError: If *output_type* is neither "str" nor "list".
    """
    if output_type == "str":
        # NOTE(review): call_tool_or_agent defaults to output_type="str",
        # so any_to_str here re-stringifies an already-stringified value —
        # confirm the double conversion is intended.
        return [
            any_to_str(
                self.call_tool_or_agent(
                    url=url,
                    name=name,
                    arguments=argument,
                    *args,
                    **kwargs,
                )
            )
            for name, argument in zip(names, arguments)
        ]
    elif output_type == "list":
        return [
            self.call_tool_or_agent(
                url=url,
                name=name,
                arguments=argument,
                *args,
                **kwargs,
            )
            for name, argument in zip(names, arguments)
        ]
    else:
        raise ValueError(f"Invalid output type: {output_type}")
|
||||
|
||||
def call_tool_or_agent_concurrently(
    self,
    url: str,
    names: list[str],
    arguments: list[dict],
    output_type: Literal["str", "list"] = "str",
    *args,
    **kwargs,
):
    """
    Execute a list of tools or agents by name concurrently.

    Args:
        url (str): MCP endpoint to call against.
        names (list[str]): The names of the tools or agents to execute
        arguments (list[dict]): The arguments to pass to the tools or agents

    Returns:
        Results in completion order (NOT input order), rendered per
        *output_type*. An empty *names* list yields an empty result.

    Raises:
        ValueError: If *output_type* is neither "str" nor "list".
    """
    outputs = []
    # Guard: ThreadPoolExecutor raises ValueError for max_workers=0, so
    # the previous code crashed whenever names was empty.
    if names:
        with ThreadPoolExecutor(max_workers=len(names)) as executor:
            futures = [
                executor.submit(
                    self.call_tool_or_agent,
                    url=url,
                    name=name,
                    arguments=argument,
                    *args,
                    **kwargs,
                )
                for name, argument in zip(names, arguments)
            ]
            for future in as_completed(futures):
                outputs.append(future.result())

    if output_type == "str":
        return any_to_str(outputs)
    elif output_type == "list":
        return outputs
    else:
        raise ValueError(f"Invalid output type: {output_type}")
|
||||
|
||||
def call_swarm(
    self,
    url: str,
    name: str,
    arguments: dict,
    output_type: Literal["str", "list"] = "str",
    *args,
    **kwargs,
):
    """
    Execute a swarm by name.

    Args:
        url (str): MCP endpoint to call against.
        name (str): The name of the swarm to execute
        arguments (dict): Payload passed to the swarm tool.
        output_type: "str" stringifies the result via any_to_str;
            "list" returns it untouched.

    Raises:
        ValueError: If *output_type* is neither "str" nor "list".
    """
    if output_type not in ("str", "list"):
        raise ValueError(f"Invalid output type: {output_type}")
    result = asyncio.run(
        self._call_tool(
            url=url,
            name=name,
            arguments=arguments,
        )
    )
    return any_to_str(result) if output_type == "str" else result
|
||||
|
||||
def list_agents(
    self, output_type: Literal["str", "list"] = "str"
):
    """
    List all available agents in the MCP.

    Matches any registered entry whose name contains "agent".
    Note: output_type is accepted but not applied — a list of dicts is
    always returned, mirroring the original behavior.

    Returns:
        list: A list of all registered agents
    """
    return [
        entry
        for entry in self.list_all()
        if "agent" in entry["name"]
    ]
|
||||
|
||||
def list_swarms(
    self, output_type: Literal["str", "list"] = "str"
):
    """
    List all available swarms in the MCP.

    Matches any registered entry whose name contains "swarm".
    Note: output_type is accepted but not applied — a list of dicts is
    always returned, mirroring the original behavior.

    Returns:
        list: A list of all registered swarms
    """
    return [
        entry
        for entry in self.list_all()
        if "swarm" in entry["name"]
    ]
|
||||
|
||||
async def _list_all(self):
    # Open a short-lived client session against our own configured URL
    # and fetch the full tool listing (tools, agents, and swarms all
    # register as MCP tools).
    async with Client(self.url) as client:
        return await client.list_tools()
|
||||
|
||||
def list_all(self):
    """Synchronously fetch every registered entry as a plain dict.

    Returns:
        list[dict]: model_dump() of each tool object reported by the MCP.
    """
    tools = asyncio.run(self._list_all())
    return [tool.model_dump() for tool in tools]
|
||||
|
||||
def list_tool_parameters(self, name: str):
    """Return the full metadata dict for the named tool, or None if absent."""
    matches = (
        tool for tool in self.list_all() if tool["name"] == name
    )
    return next(matches, None)
|
||||
|
||||
def search_if_tool_exists(self, name: str):
    """Return True if any registered entry has exactly this name."""
    return any(
        tool["name"] == name for tool in self.list_all()
    )
|
||||
|
||||
def search(
    self,
    type: Literal["tool", "agent", "swarm"],
    name: str,
    output_type: Literal["str", "list"] = "str",
):
    """
    Search for a tool, agent, or swarm by name.

    Args:
        type (Literal["tool", "agent", "swarm"]): The type of the item to search for
        name (str): The name of the item to search for

    Returns:
        dict: The item if found, otherwise None

    NOTE(review): `type` and `output_type` are accepted but never used —
    the lookup matches on name alone and always returns a dict or None.
    Confirm whether type filtering was intended here.
    """
    all_items = self.list_all()
    for item in all_items:
        if item["name"] == name:
            return item
    return None
|
@ -1,818 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from swarms.structs.agent import Agent
|
||||
from swarms.structs.base_workflow import BaseWorkflow
|
||||
from swarms.utils.loguru_logger import initialize_logger
|
||||
|
||||
# Base logger initialization
|
||||
logger = initialize_logger("async_workflow")
|
||||
|
||||
|
||||
# Pydantic models for structured data
|
||||
class AgentOutput(BaseModel):
    """Structured record of a single agent's execution of one task."""

    agent_id: str  # unique identifier of the executing agent
    agent_name: str
    task_id: str
    input: str  # raw task input handed to the agent
    output: Any  # agent result; concrete type depends on the agent
    start_time: datetime
    end_time: datetime
    status: str  # NOTE(review): value set (e.g. "success"/"error") not visible here — confirm
    error: Optional[str] = None  # populated only when the task failed
|
||||
|
||||
|
||||
class WorkflowOutput(BaseModel):
    """Aggregate result of one workflow run across all of its agents."""

    workflow_id: str  # unique id of the run (uuid4 in AsyncWorkflow)
    workflow_name: str
    start_time: datetime
    end_time: datetime
    total_agents: int  # number of agents dispatched
    successful_tasks: int
    failed_tasks: int
    agent_outputs: List[AgentOutput]  # per-agent detail records
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SpeakerRole(str, Enum):
    """Roles a speaker can take in the group-chat speaker system.

    Subclasses str so role values serialize naturally in JSON/pydantic.
    """

    COORDINATOR = "coordinator"
    CRITIC = "critic"
    EXECUTOR = "executor"
    VALIDATOR = "validator"
    DEFAULT = "default"  # fallback role used by add_default_speakers
|
||||
|
||||
|
||||
class SpeakerMessage(BaseModel):
    """One utterance produced by a speaker (or an error placeholder).

    SpeakerSystem._execute_speaker emits these for successes, timeouts
    (content=None, metadata["error"]="Timeout"), and failures
    (content=None, metadata["error"]=<exception text>).
    """

    role: SpeakerRole
    content: Any  # agent output; None on timeout/error
    timestamp: datetime
    agent_name: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class GroupChatConfig(BaseModel):
    """Tunable parameters for group-chat execution."""

    max_loops: int = 10  # maximum conversation rounds
    timeout_per_turn: float = 30.0  # seconds allotted to each turn
    require_all_speakers: bool = False  # NOTE(review): not consulted in this file — confirm used elsewhere
    allow_concurrent: bool = True
    save_history: bool = True
|
||||
|
||||
|
||||
@dataclass
class SharedMemoryItem:
    """A single keyed entry in SharedMemory, with provenance."""

    key: str
    value: Any
    timestamp: datetime  # set at write time via datetime.utcnow()
    author: str  # name of the writer
    # NOTE(review): defaults to None rather than {} — SharedMemory.set
    # substitutes {} itself, but direct constructors may pass None through.
    metadata: Dict[str, Any] = None
|
||||
|
||||
|
||||
@dataclass
class SpeakerConfig:
    """Binds an agent to a speaker role with execution settings."""

    role: SpeakerRole
    agent: Any  # expected to expose .arun(input) and .agent_name
    priority: int = 0  # NOTE(review): not consulted in this file — confirm used elsewhere
    concurrent: bool = True  # whether this speaker may run alongside others
    timeout: float = 30.0  # per-invocation timeout in seconds
    required: bool = False
|
||||
|
||||
|
||||
class SharedMemory:
    """Thread-safe shared memory implementation with persistence.

    Values are stored as SharedMemoryItem records under string keys. All
    access is serialized through a single lock; when a persistence path
    is configured, every write is mirrored to a JSON file on disk and
    reloaded on construction.
    """

    def __init__(self, persistence_path: Optional[str] = None):
        """
        Args:
            persistence_path: Optional JSON file used to persist and
                restore the store. None disables persistence entirely.
        """
        self._memory = {}
        self._lock = threading.Lock()
        self._persistence_path = persistence_path
        self._load_from_disk()

    def set(
        self,
        key: str,
        value: Any,
        author: str,
        metadata: Dict[str, Any] = None,
    ) -> None:
        """Store *value* under *key*, stamping author and time, then persist."""
        with self._lock:
            item = SharedMemoryItem(
                key=key,
                value=value,
                timestamp=datetime.utcnow(),
                author=author,
                metadata=metadata or {},
            )
            self._memory[key] = item
            self._persist_to_disk()

    def get(self, key: str) -> Optional[Any]:
        """Return the stored value for *key*, or None if absent."""
        with self._lock:
            item = self._memory.get(key)
            return item.value if item else None

    def get_with_metadata(
        self, key: str
    ) -> Optional[SharedMemoryItem]:
        """Return the full SharedMemoryItem for *key*, or None if absent."""
        with self._lock:
            return self._memory.get(key)

    def _persist_to_disk(self) -> None:
        """Write the whole store to the persistence file, if configured."""
        if self._persistence_path:
            with open(self._persistence_path, "w") as f:
                # default=str: SharedMemoryItem.timestamp is a datetime,
                # which json cannot serialize natively — without this,
                # every set() with persistence enabled raised TypeError.
                json.dump(
                    {k: asdict(v) for k, v in self._memory.items()},
                    f,
                    default=str,
                )

    def _load_from_disk(self) -> None:
        """Restore the store from disk.

        Note: timestamps round-trip as ISO-format strings, not datetime
        objects, because of the stringifying serialization above.
        """
        if self._persistence_path and os.path.exists(
            self._persistence_path
        ):
            with open(self._persistence_path, "r") as f:
                data = json.load(f)
                self._memory = {
                    k: SharedMemoryItem(**v) for k, v in data.items()
                }
|
||||
|
||||
|
||||
class SpeakerSystem:
    """Manages speaker interactions and group chat functionality.

    Holds at most one speaker per role and converts every agent
    invocation — including timeouts and exceptions — into a
    SpeakerMessage, so callers never see a raised error.
    """

    def __init__(self, default_timeout: float = 30.0):
        # One registered speaker per role; later registrations replace.
        self.speakers: Dict[SpeakerRole, SpeakerConfig] = {}
        self.message_history: List[SpeakerMessage] = []
        self.default_timeout = default_timeout
        self._lock = threading.Lock()

    def add_speaker(self, config: SpeakerConfig) -> None:
        """Register (or replace) the speaker for config.role."""
        with self._lock:
            self.speakers[config.role] = config

    def remove_speaker(self, role: SpeakerRole) -> None:
        """Remove the speaker for *role*; no-op if not registered."""
        with self._lock:
            self.speakers.pop(role, None)

    async def _execute_speaker(
        self,
        config: SpeakerConfig,
        input_data: Any,
        context: Dict[str, Any] = None,
    ) -> SpeakerMessage:
        """Run one speaker's agent under its timeout, never raising.

        Timeouts and agent exceptions become SpeakerMessage records with
        content=None and the error recorded in metadata.
        """
        try:
            result = await asyncio.wait_for(
                config.agent.arun(input_data), timeout=config.timeout
            )

            return SpeakerMessage(
                role=config.role,
                content=result,
                timestamp=datetime.utcnow(),
                agent_name=config.agent.agent_name,
                metadata={"context": context or {}},
            )
        except asyncio.TimeoutError:
            return SpeakerMessage(
                role=config.role,
                content=None,
                timestamp=datetime.utcnow(),
                agent_name=config.agent.agent_name,
                metadata={"error": "Timeout"},
            )
        except Exception as e:
            return SpeakerMessage(
                role=config.role,
                content=None,
                timestamp=datetime.utcnow(),
                agent_name=config.agent.agent_name,
                metadata={"error": str(e)},
            )
|
||||
|
||||
|
||||
class AsyncWorkflow(BaseWorkflow):
    """Enhanced asynchronous workflow with advanced speaker system.

    Runs a pool of agents concurrently on a task, optionally preceded by a
    speaker phase (multi-turn group chat, or concurrent + sequential
    speakers), with shared memory, a rotating-file logger per instance, and
    optional autosave of results to disk.
    """

    def __init__(
        self,
        name: str = "AsyncWorkflow",
        agents: List[Agent] = None,
        max_workers: int = 5,
        dashboard: bool = False,
        autosave: bool = False,
        verbose: bool = False,
        log_path: str = "workflow.log",
        shared_memory_path: Optional[str] = "shared_memory.json",
        enable_group_chat: bool = False,
        group_chat_config: Optional[GroupChatConfig] = None,
        **kwargs,
    ):
        """Create the workflow.

        Args:
            name: Human-readable workflow name.
            agents: Agents executed concurrently by ``run``.
            max_workers: Advertised worker limit (recorded in run metadata).
            dashboard: Dashboard flag (stored; not used by this class).
            autosave: Persist results when the task context exits.
            verbose: DEBUG-level logging when True, else INFO.
            log_path: Path of the rotating log file.
            shared_memory_path: JSON file backing SharedMemory, or None.
            enable_group_chat: Run speakers as a multi-turn group chat.
            group_chat_config: Group-chat settings; defaults to
                ``GroupChatConfig()``.
        """
        super().__init__(agents=agents, **kwargs)
        self.workflow_id = str(uuid.uuid4())
        self.name = name
        self.agents = agents or []
        self.max_workers = max_workers
        self.dashboard = dashboard
        self.autosave = autosave
        self.verbose = verbose
        self.task_pool = []
        self.results = []
        self.shared_memory = SharedMemory(shared_memory_path)
        self.speaker_system = SpeakerSystem()
        self.enable_group_chat = enable_group_chat
        self.group_chat_config = (
            group_chat_config or GroupChatConfig()
        )
        self._setup_logging(log_path)
        self.metadata = {}

    def _setup_logging(self, log_path: str) -> None:
        """Configure a rotating file logger unique to this workflow."""
        # Logger name includes the workflow UUID so instances don't share
        # handlers.
        self.logger = logging.getLogger(
            f"workflow_{self.workflow_id}"
        )
        self.logger.setLevel(
            logging.DEBUG if self.verbose else logging.INFO
        )

        handler = RotatingFileHandler(
            log_path, maxBytes=10 * 1024 * 1024, backupCount=5
        )
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)

    def add_default_speakers(self) -> None:
        """Add all agents as default concurrent speakers."""
        for agent in self.agents:
            config = SpeakerConfig(
                role=SpeakerRole.DEFAULT,
                agent=agent,
                concurrent=True,
                timeout=30.0,
                required=False,
            )
            self.speaker_system.add_speaker(config)

    async def run_concurrent_speakers(
        self, task: str, context: Dict[str, Any] = None
    ) -> List[SpeakerMessage]:
        """Run all speakers flagged ``concurrent`` in parallel."""
        concurrent_tasks = [
            self.speaker_system._execute_speaker(
                config, task, context
            )
            for config in self.speaker_system.speakers.values()
            if config.concurrent
        ]

        results = await asyncio.gather(
            *concurrent_tasks, return_exceptions=True
        )
        # Drop raw exceptions; _execute_speaker already wraps failures in
        # SpeakerMessage metadata.
        return [r for r in results if isinstance(r, SpeakerMessage)]

    async def run_sequential_speakers(
        self, task: str, context: Dict[str, Any] = None
    ) -> List[SpeakerMessage]:
        """Run non-concurrent speakers one at a time, ordered by priority."""
        results = []
        for config in sorted(
            self.speaker_system.speakers.values(),
            key=lambda x: x.priority,
        ):
            if not config.concurrent:
                result = await self.speaker_system._execute_speaker(
                    config, task, context
                )
                results.append(result)
        return results

    async def run_group_chat(
        self, initial_message: str, context: Dict[str, Any] = None
    ) -> List[SpeakerMessage]:
        """Run a multi-turn group chat among the registered speakers.

        Turn 0 feeds *initial_message*; later turns feed the content of the
        most recent message. Raises ValueError when group chat is disabled.
        """
        if not self.enable_group_chat:
            raise ValueError(
                "Group chat is not enabled for this workflow"
            )

        messages: List[SpeakerMessage] = []
        current_turn = 0

        while current_turn < self.group_chat_config.max_loops:
            turn_context = {
                "turn": current_turn,
                "history": messages,
                **(context or {}),
            }

            if self.group_chat_config.allow_concurrent:
                turn_messages = await self.run_concurrent_speakers(
                    (
                        initial_message
                        if current_turn == 0
                        else messages[-1].content
                    ),
                    turn_context,
                )
            else:
                turn_messages = await self.run_sequential_speakers(
                    (
                        initial_message
                        if current_turn == 0
                        else messages[-1].content
                    ),
                    turn_context,
                )

            messages.extend(turn_messages)

            # Check if we should continue the conversation
            if self._should_end_group_chat(messages):
                break

            current_turn += 1

        if self.group_chat_config.save_history:
            self.speaker_system.message_history.extend(messages)

        return messages

    def _should_end_group_chat(
        self, messages: List[SpeakerMessage]
    ) -> bool:
        """Return True when the group chat should stop early.

        NOTE(review): both branches below fall through to ``False``, so the
        chat only ends early on an empty message list — confirm whether the
        required-speakers check was meant to return True when satisfied.
        """
        if not messages:
            return True

        # Check if all required speakers have participated
        if self.group_chat_config.require_all_speakers:
            participating_roles = {msg.role for msg in messages}
            required_roles = {
                role
                for role, config in self.speaker_system.speakers.items()
                if config.required
            }
            if not required_roles.issubset(participating_roles):
                return False

        return False

    @asynccontextmanager
    async def task_context(self):
        """Context manager for task execution with autosave on exit."""
        start_time = datetime.utcnow()
        try:
            yield
        finally:
            end_time = datetime.utcnow()
            if self.autosave:
                await self._save_results(start_time, end_time)

    async def _execute_agent_task(
        self, agent: Agent, task: str
    ) -> AgentOutput:
        """Run one agent on *task*, returning a success or error AgentOutput.

        Never raises: failures are captured in the returned record.
        """
        start_time = datetime.utcnow()
        task_id = str(uuid.uuid4())

        try:
            self.logger.info(
                f"Agent {agent.agent_name} starting task {task_id}: {task}"
            )

            result = await agent.arun(task)

            end_time = datetime.utcnow()
            self.logger.info(
                f"Agent {agent.agent_name} completed task {task_id}"
            )

            return AgentOutput(
                agent_id=str(id(agent)),
                agent_name=agent.agent_name,
                task_id=task_id,
                input=task,
                output=result,
                start_time=start_time,
                end_time=end_time,
                status="success",
            )

        except Exception as e:
            end_time = datetime.utcnow()
            self.logger.error(
                f"Error in agent {agent.agent_name} task {task_id}: {str(e)}",
                exc_info=True,
            )

            return AgentOutput(
                agent_id=str(id(agent)),
                agent_name=agent.agent_name,
                task_id=task_id,
                input=task,
                output=None,
                start_time=start_time,
                end_time=end_time,
                status="error",
                error=str(e),
            )

    async def run(self, task: str) -> WorkflowOutput:
        """Execute the workflow: speaker phase, then all agents concurrently.

        Raises:
            ValueError: when no agents are configured.
            Exception: re-raises any critical error after logging it.
        """
        if not self.agents:
            raise ValueError("No agents provided to the workflow")

        async with self.task_context():
            start_time = datetime.utcnow()

            try:
                # Run speakers first if enabled
                speaker_outputs = []
                if self.enable_group_chat:
                    speaker_outputs = await self.run_group_chat(task)
                else:
                    concurrent_outputs = (
                        await self.run_concurrent_speakers(task)
                    )
                    sequential_outputs = (
                        await self.run_sequential_speakers(task)
                    )
                    speaker_outputs = (
                        concurrent_outputs + sequential_outputs
                    )

                # Store speaker outputs in shared memory
                self.shared_memory.set(
                    "speaker_outputs",
                    [msg.dict() for msg in speaker_outputs],
                    "workflow",
                )

                # Create tasks for all agents
                tasks = [
                    self._execute_agent_task(agent, task)
                    for agent in self.agents
                ]

                # Execute all tasks concurrently
                agent_outputs = await asyncio.gather(
                    *tasks, return_exceptions=True
                )

                end_time = datetime.utcnow()

                # Calculate success/failure counts
                successful_tasks = sum(
                    1
                    for output in agent_outputs
                    if isinstance(output, AgentOutput)
                    and output.status == "success"
                )
                failed_tasks = len(agent_outputs) - successful_tasks

                return WorkflowOutput(
                    workflow_id=self.workflow_id,
                    workflow_name=self.name,
                    start_time=start_time,
                    end_time=end_time,
                    total_agents=len(self.agents),
                    successful_tasks=successful_tasks,
                    failed_tasks=failed_tasks,
                    agent_outputs=[
                        output
                        for output in agent_outputs
                        if isinstance(output, AgentOutput)
                    ],
                    metadata={
                        "max_workers": self.max_workers,
                        "shared_memory_keys": list(
                            self.shared_memory._memory.keys()
                        ),
                        "group_chat_enabled": self.enable_group_chat,
                        "total_speaker_messages": len(
                            speaker_outputs
                        ),
                        "speaker_outputs": [
                            msg.dict() for msg in speaker_outputs
                        ],
                    },
                )

            except Exception as e:
                self.logger.error(
                    f"Critical workflow error: {str(e)}",
                    exc_info=True,
                )
                raise

    async def _save_results(
        self, start_time: datetime, end_time: datetime
    ) -> None:
        """Save workflow results to a timestamped JSON file on disk.

        NOTE(review): ``self.results`` is not populated by ``run`` in this
        class, so the saved "results" list may always be empty — confirm
        upstream usage.
        """
        if not self.autosave:
            return

        output_dir = "workflow_outputs"
        os.makedirs(output_dir, exist_ok=True)

        filename = f"{output_dir}/workflow_{self.workflow_id}_{end_time.strftime('%Y%m%d_%H%M%S')}.json"

        try:
            with open(filename, "w") as f:
                json.dump(
                    {
                        "workflow_id": self.workflow_id,
                        "start_time": start_time.isoformat(),
                        "end_time": end_time.isoformat(),
                        "results": [
                            (
                                asdict(result)
                                if hasattr(result, "__dict__")
                                else (
                                    result.dict()
                                    if hasattr(result, "dict")
                                    else str(result)
                                )
                            )
                            for result in self.results
                        ],
                        "speaker_history": [
                            msg.dict()
                            for msg in self.speaker_system.message_history
                        ],
                        "metadata": self.metadata,
                    },
                    f,
                    default=str,
                    indent=2,
                )

            # Fix: log the actual destination path (the placeholder had
            # been lost, leaving a literal "(unknown)" in the message).
            self.logger.info(f"Workflow results saved to {filename}")
        except Exception as e:
            self.logger.error(
                f"Error saving workflow results: {str(e)}"
            )

    def _validate_config(self) -> None:
        """Validate workflow configuration; raise ValueError on bad settings."""
        if self.max_workers < 1:
            raise ValueError("max_workers must be at least 1")

        if (
            self.enable_group_chat
            and not self.speaker_system.speakers
        ):
            raise ValueError(
                "Group chat enabled but no speakers configured"
            )

        for config in self.speaker_system.speakers.values():
            if config.timeout <= 0:
                raise ValueError(
                    f"Invalid timeout for speaker {config.role}"
                )

    async def cleanup(self) -> None:
        """Release workflow resources: close log handlers, autosave, clear memory."""
        try:
            # Close any open file handlers
            for handler in self.logger.handlers[:]:
                handler.close()
                self.logger.removeHandler(handler)

            # Persist final state
            if self.autosave:
                end_time = datetime.utcnow()
                await self._save_results(
                    (
                        self.results[0].start_time
                        if self.results
                        else end_time
                    ),
                    end_time,
                )

            # Clear shared memory if configured
            self.shared_memory._memory.clear()

        except Exception as e:
            self.logger.error(f"Error during cleanup: {str(e)}")
            raise
|
||||
|
||||
|
||||
# Utility functions for the workflow
|
||||
def create_default_workflow(
    agents: List[Agent],
    name: str = "DefaultWorkflow",
    enable_group_chat: bool = False,
) -> AsyncWorkflow:
    """Build an AsyncWorkflow with opinionated defaults.

    Enables dashboard, autosave, and verbose logging; sizes the worker
    pool to the number of agents; registers every agent as a default
    concurrent speaker.
    """
    # Default group-chat behavior: up to five turns, run concurrently,
    # and don't demand that every speaker participates.
    chat_config = GroupChatConfig(
        max_loops=5,
        allow_concurrent=True,
        require_all_speakers=False,
    )

    wf = AsyncWorkflow(
        name=name,
        agents=agents,
        max_workers=len(agents),
        dashboard=True,
        autosave=True,
        verbose=True,
        enable_group_chat=enable_group_chat,
        group_chat_config=chat_config,
    )
    wf.add_default_speakers()
    return wf
|
||||
|
||||
|
||||
async def run_workflow_with_retry(
    workflow: AsyncWorkflow,
    task: str,
    max_retries: int = 3,
    retry_delay: float = 1.0,
) -> WorkflowOutput:
    """Run *workflow* on *task*, retrying failures with exponential backoff.

    Args:
        workflow: Workflow whose ``run`` coroutine is attempted.
        task: Task string passed through to ``workflow.run``.
        max_retries: Total number of attempts (must be >= 1).
        retry_delay: Initial delay between attempts, doubled each retry.

    Returns:
        The first successful ``WorkflowOutput``.

    Raises:
        ValueError: if ``max_retries`` is less than 1.
        Exception: the last workflow error once attempts are exhausted.
    """
    if max_retries < 1:
        # Guard: with zero attempts the loop body would never run and the
        # function would implicitly return None, violating the declared
        # WorkflowOutput return type.
        raise ValueError("max_retries must be at least 1")

    for attempt in range(max_retries):
        try:
            return await workflow.run(task)
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            workflow.logger.warning(
                f"Attempt {attempt + 1} failed, retrying in {retry_delay} seconds: {str(e)}"
            )
            await asyncio.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff
|
||||
|
||||
|
||||
# async def create_specialized_agents() -> List[Agent]:
|
||||
# """Create a set of specialized agents for financial analysis"""
|
||||
|
||||
# # Base model configuration
|
||||
# model = OpenAIChat(model_name="gpt-4o")
|
||||
|
||||
# # Financial Analysis Agent
|
||||
# financial_agent = Agent(
|
||||
# agent_name="Financial-Analysis-Agent",
|
||||
# agent_description="Personal finance advisor agent",
|
||||
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT +
|
||||
# "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
|
||||
# max_loops=1,
|
||||
# llm=model,
|
||||
# dynamic_temperature_enabled=True,
|
||||
# user_name="Kye",
|
||||
# retry_attempts=3,
|
||||
# context_length=8192,
|
||||
# return_step_meta=False,
|
||||
# output_type="str",
|
||||
# auto_generate_prompt=False,
|
||||
# max_tokens=4000,
|
||||
# stopping_token="<DONE>",
|
||||
# saved_state_path="financial_agent.json",
|
||||
# interactive=False,
|
||||
# )
|
||||
|
||||
# # Risk Assessment Agent
|
||||
# risk_agent = Agent(
|
||||
# agent_name="Risk-Assessment-Agent",
|
||||
# agent_description="Investment risk analysis specialist",
|
||||
# system_prompt="Analyze investment risks and provide risk scores. Output <DONE> when analysis is complete.",
|
||||
# max_loops=1,
|
||||
# llm=model,
|
||||
# dynamic_temperature_enabled=True,
|
||||
# user_name="Kye",
|
||||
# retry_attempts=3,
|
||||
# context_length=8192,
|
||||
# output_type="str",
|
||||
# max_tokens=4000,
|
||||
# stopping_token="<DONE>",
|
||||
# saved_state_path="risk_agent.json",
|
||||
# interactive=False,
|
||||
# )
|
||||
|
||||
# # Market Research Agent
|
||||
# research_agent = Agent(
|
||||
# agent_name="Market-Research-Agent",
|
||||
# agent_description="AI and tech market research specialist",
|
||||
# system_prompt="Research AI market trends and growth opportunities. Output <DONE> when research is complete.",
|
||||
# max_loops=1,
|
||||
# llm=model,
|
||||
# dynamic_temperature_enabled=True,
|
||||
# user_name="Kye",
|
||||
# retry_attempts=3,
|
||||
# context_length=8192,
|
||||
# output_type="str",
|
||||
# max_tokens=4000,
|
||||
# stopping_token="<DONE>",
|
||||
# saved_state_path="research_agent.json",
|
||||
# interactive=False,
|
||||
# )
|
||||
|
||||
# return [financial_agent, risk_agent, research_agent]
|
||||
|
||||
# async def main():
|
||||
# # Create specialized agents
|
||||
# agents = await create_specialized_agents()
|
||||
|
||||
# # Create workflow with group chat enabled
|
||||
# workflow = create_default_workflow(
|
||||
# agents=agents,
|
||||
# name="AI-Investment-Analysis-Workflow",
|
||||
# enable_group_chat=True
|
||||
# )
|
||||
|
||||
# # Configure speaker roles
|
||||
# workflow.speaker_system.add_speaker(
|
||||
# SpeakerConfig(
|
||||
# role=SpeakerRole.COORDINATOR,
|
||||
# agent=agents[0], # Financial agent as coordinator
|
||||
# priority=1,
|
||||
# concurrent=False,
|
||||
# required=True
|
||||
# )
|
||||
# )
|
||||
|
||||
# workflow.speaker_system.add_speaker(
|
||||
# SpeakerConfig(
|
||||
# role=SpeakerRole.CRITIC,
|
||||
# agent=agents[1], # Risk agent as critic
|
||||
# priority=2,
|
||||
# concurrent=True
|
||||
# )
|
||||
# )
|
||||
|
||||
# workflow.speaker_system.add_speaker(
|
||||
# SpeakerConfig(
|
||||
# role=SpeakerRole.EXECUTOR,
|
||||
# agent=agents[2], # Research agent as executor
|
||||
# priority=2,
|
||||
# concurrent=True
|
||||
# )
|
||||
# )
|
||||
|
||||
# # Investment analysis task
|
||||
# investment_task = """
|
||||
# Create a comprehensive investment analysis for a $40k portfolio focused on AI growth opportunities:
|
||||
# 1. Identify high-growth AI ETFs and index funds
|
||||
# 2. Analyze risks and potential returns
|
||||
# 3. Create a diversified portfolio allocation
|
||||
# 4. Provide market trend analysis
|
||||
# Present the results in a structured markdown format.
|
||||
# """
|
||||
|
||||
# try:
|
||||
# # Run workflow with retry
|
||||
# result = await run_workflow_with_retry(
|
||||
# workflow=workflow,
|
||||
# task=investment_task,
|
||||
# max_retries=3
|
||||
# )
|
||||
|
||||
# print("\nWorkflow Results:")
|
||||
# print("================")
|
||||
|
||||
# # Process and display agent outputs
|
||||
# for output in result.agent_outputs:
|
||||
# print(f"\nAgent: {output.agent_name}")
|
||||
# print("-" * (len(output.agent_name) + 8))
|
||||
# print(output.output)
|
||||
|
||||
# # Display group chat history if enabled
|
||||
# if workflow.enable_group_chat:
|
||||
# print("\nGroup Chat Discussion:")
|
||||
# print("=====================")
|
||||
# for msg in workflow.speaker_system.message_history:
|
||||
# print(f"\n{msg.role} ({msg.agent_name}):")
|
||||
# print(msg.content)
|
||||
|
||||
# # Save detailed results
|
||||
# if result.metadata.get("shared_memory_keys"):
|
||||
# print("\nShared Insights:")
|
||||
# print("===============")
|
||||
# for key in result.metadata["shared_memory_keys"]:
|
||||
# value = workflow.shared_memory.get(key)
|
||||
# if value:
|
||||
# print(f"\n{key}:")
|
||||
# print(value)
|
||||
|
||||
# except Exception as e:
|
||||
# print(f"Workflow failed: {str(e)}")
|
||||
|
||||
# finally:
|
||||
# await workflow.cleanup()
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # Run the example
|
||||
# asyncio.run(main())
|
@ -1,844 +0,0 @@
|
||||
"""
|
||||
OctoToolsSwarm: A multi-agent system for complex reasoning.
|
||||
Implements the OctoTools framework using swarms.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
import math # Import the math module
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from swarms import Agent
|
||||
from swarms.structs.conversation import Conversation
|
||||
|
||||
# from exa_search import exa_search as web_search_execute
|
||||
|
||||
|
||||
# Pull API keys and other secrets from a local .env file into the environment.
load_dotenv()

# Module-wide logging: INFO by default; all OctoTools components log
# through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolType(Enum):
    """Defines the types of tools available.

    Values are stable string identifiers used to reference tools.
    """

    IMAGE_CAPTIONER = "image_captioner"  # describe an image in prose
    OBJECT_DETECTOR = "object_detector"  # enumerate objects in an image
    WEB_SEARCH = "web_search"  # query the web for information
    PYTHON_CALCULATOR = "python_calculator"  # evaluate numeric expressions
    # Add more tool types as needed
|
||||
|
||||
|
||||
@dataclass
class Tool:
    """
    An external capability the swarm can invoke.

    Attributes:
        name: Unique name of the tool.
        description: Human-readable summary of what the tool does.
        metadata: Arbitrary tool metadata (e.g. expected input types).
        execute_func: Callable implementing the tool's logic.
    """

    name: str
    description: str
    metadata: Dict[str, Any]
    execute_func: Callable

    def execute(self, **kwargs):
        """Invoke the tool; on failure, log and return an error payload."""
        try:
            result = self.execute_func(**kwargs)
        except Exception as e:
            # Never propagate tool failures; callers receive a dict with
            # an "error" key instead.
            logger.error(
                f"Error executing tool {self.name}: {str(e)}"
            )
            return {"error": str(e)}
        return result
|
||||
|
||||
|
||||
class AgentRole(Enum):
    """Defines the roles for agents in the OctoTools system."""

    PLANNER = "planner"  # analyzes the query and predicts the next action
    VERIFIER = "verifier"  # checks context/memory for completeness
    SUMMARIZER = "summarizer"  # synthesizes the final answer
|
||||
|
||||
class OctoToolsSwarm:
|
||||
"""
|
||||
A multi-agent system implementing the OctoTools framework.
|
||||
|
||||
Attributes:
|
||||
model_name: Name of the LLM model to use.
|
||||
max_iterations: Maximum number of action-execution iterations.
|
||||
base_path: Path for saving agent states.
|
||||
tools: List of available Tool objects.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        tools: List[Tool],
        model_name: str = "gemini/gemini-2.0-flash",
        max_iterations: int = 10,
        base_path: Optional[str] = None,
    ):
        """Initialize the OctoToolsSwarm system.

        Args:
            tools: Tools made available to the planner; keyed by name.
            model_name: LLM identifier passed to every agent.
            max_iterations: Cap on action-execution iterations.
            base_path: Directory for saved agent states; defaults to
                ``./octotools_states``.
        """
        self.model_name = model_name
        self.max_iterations = max_iterations
        self.base_path = (
            Path(base_path)
            if base_path
            else Path("./octotools_states")
        )
        self.base_path.mkdir(exist_ok=True)
        # Index tools by name for O(1) lookup during execution.
        self.tools = {
            tool.name: tool for tool in tools
        }

        # Initialize planner/verifier/summarizer agents
        self._init_agents()

        # Create conversation tracker and memory
        self.conversation = Conversation()
        self.memory = []  # Store the trajectory of actions and results
|
||||
|
||||
    def _init_agents(self) -> None:
        """Initialize all agents with their specific roles and prompts.

        Creates three agents sharing ``self.model_name``; each persists
        its state under ``self.base_path``.
        """
        # Planner: up to 3 loops so it can refine multi-step plans.
        self.planner = Agent(
            agent_name="OctoTools-Planner",
            system_prompt=self._get_planner_prompt(),
            model_name=self.model_name,
            max_loops=3,
            saved_state_path=str(self.base_path / "planner.json"),
            verbose=True,
        )

        # Verifier: single-pass check of context/memory consistency.
        self.verifier = Agent(
            agent_name="OctoTools-Verifier",
            system_prompt=self._get_verifier_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "verifier.json"),
            verbose=True,
        )

        # Summarizer: single-pass synthesis of the final answer.
        self.summarizer = Agent(
            agent_name="OctoTools-Summarizer",
            system_prompt=self._get_summarizer_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "summarizer.json"),
            verbose=True,
        )
|
||||
|
||||
    def _get_planner_prompt(self) -> str:
        """Build the planner agent's system prompt.

        Interpolates the current tool names/descriptions into a few-shot
        prompt; doubled braces in the f-string are literal JSON braces.
        """
        tool_descriptions = "\n".join(
            [
                f"- {tool_name}: {self.tools[tool_name].description}"
                for tool_name in self.tools
            ]
        )
        return f"""You are the Planner in the OctoTools framework. Your role is to analyze the user's query,
        identify required skills, suggest relevant tools, and plan the steps to solve the problem.

        1. **Analyze the user's query:** Understand the requirements and identify the necessary skills and potentially relevant tools.
        2. **Perform high-level planning:** Create a rough outline of how tools might be used to solve the problem.
        3. **Perform low-level planning (action prediction):** At each step, select the best tool to use and formulate a specific sub-goal for that tool, considering the current context.

        Available Tools:
        {tool_descriptions}

        Output your response in JSON format. Here are examples for different stages:

        **Query Analysis (High-Level Planning):**
        Example Input:
        Query: "What is the capital of France?"

        Example Output:
        ```json
        {{
            "summary": "The user is asking for the capital of France.",
            "required_skills": ["knowledge retrieval"],
            "relevant_tools": ["Web_Search_Tool"]
        }}
        ```

        **Action Prediction (Low-Level Planning):**
        Example Input:
        Context: {{ "query": "What is the capital of France?", "available_tools": ["Web_Search_Tool"] }}

        Example Output:
        ```json
        {{
            "justification": "The Web_Search_Tool can be used to directly find the capital of France.",
            "context": {{}},
            "sub_goal": "Search the web for 'capital of France'.",
            "tool_name": "Web_Search_Tool"
        }}
        ```
        Another Example:
        Context: {{"query": "How many objects are in the image?", "available_tools": ["Image_Captioner_Tool", "Object_Detector_Tool"], "image": "objects.png"}}

        Example Output:
        ```json
        {{
            "justification": "First, get a general description of the image to understand the context.",
            "context": {{ "image": "objects.png" }},
            "sub_goal": "Generate a description of the image.",
            "tool_name": "Image_Captioner_Tool"
        }}
        ```

        Example for Finding Square Root:
        Context: {{"query": "What is the square root of the number of objects in the image?", "available_tools": ["Object_Detector_Tool", "Python_Calculator_Tool"], "image": "objects.png", "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"]}}

        Example Output:
        ```json
        {{
            "justification": "We have detected 4 objects in the image. Now we need to find the square root of 4.",
            "context": {{}},
            "sub_goal": "Calculate the square root of 4",
            "tool_name": "Python_Calculator_Tool"
        }}
        ```

        Your output MUST be a single, valid JSON object with the following keys:
        - justification (string): Your reasoning.
        - context (dict): A dictionary containing relevant information.
        - sub_goal (string): The specific instruction for the tool.
        - tool_name (string): The EXACT name of the tool to use.

        Do NOT include any text outside of the JSON object.
        """
|
||||
|
||||
    def _get_verifier_prompt(self) -> str:
        """Build the verifier agent's system prompt.

        Plain (non-f) string: the embedded JSON braces are literal.
        """
        return """You are the Context Verifier in the OctoTools framework. Your role is to analyze the current context
        and memory to determine if the problem is solved, if there are any inconsistencies, or if further steps are needed.

        Output your response in JSON format:

        Expected output structure:
        ```json
        {
            "completeness": "Indicate whether the query is fully, partially, or not answered.",
            "inconsistencies": "List any inconsistencies found in the context or memory.",
            "verification_needs": "List any information that needs further verification.",
            "ambiguities": "List any ambiguities found in the context or memory.",
            "stop_signal": true/false
        }
        ```

        Example Input:
        Context: { "last_result": { "result": "Caption: The image shows a cat." } }
        Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Image_Captioner_Tool" } } ]

        Example Output:
        ```json
        {
            "completeness": "partial",
            "inconsistencies": [],
            "verification_needs": ["Object detection to confirm the presence of a cat."],
            "ambiguities": [],
            "stop_signal": false
        }
        ```

        Another Example:
        Context: { "last_result": { "result": ["Detected object: cat"] } }
        Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } } ]

        Example Output:
        ```json
        {
            "completeness": "yes",
            "inconsistencies": [],
            "verification_needs": [],
            "ambiguities": [],
            "stop_signal": true
        }
        ```

        Square Root Example:
        Context: {
            "query": "What is the square root of the number of objects in the image?",
            "image": "example.png",
            "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"],
            "Python_Calculator_Tool_result": "Result of 4**0.5 is 2.0"
        }
        Memory: [
            { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } },
            { "component": "Action Predictor", "result": { "tool_name": "Python_Calculator_Tool" } }
        ]

        Example Output:
        ```json
        {
            "completeness": "yes",
            "inconsistencies": [],
            "verification_needs": [],
            "ambiguities": [],
            "stop_signal": true
        }
        ```
        """
|
||||
|
||||
    def _get_summarizer_prompt(self) -> str:
        """Build the summarizer agent's system prompt.

        Plain (non-f) string: the embedded JSON braces are literal.
        """
        return """You are the Solution Summarizer in the OctoTools framework. Your role is to synthesize the final
        answer to the user's query based on the complete trajectory of actions and results.

        Output your response in JSON format:

        Expected output structure:
        ```json
        {
            "final_answer": "Provide a clear and concise answer to the original query."
        }
        ```
        Example Input:
        Memory: [
            {"component": "Query Analyzer", "result": {"summary": "Find the capital of France."}},
            {"component": "Action Predictor", "result": {"tool_name": "Web_Search_Tool"}},
            {"component": "Tool Execution", "result": {"result": "The capital of France is Paris."}}
        ]

        Example Output:
        ```json
        {
            "final_answer": "The capital of France is Paris."
        }
        ```

        Square Root Example:
        Memory: [
            {"component": "Query Analyzer", "result": {"summary": "Find the square root of the number of objects in the image."}},
            {"component": "Action Predictor", "result": {"tool_name": "Object_Detector_Tool", "sub_goal": "Detect objects in the image"}},
            {"component": "Tool Execution", "result": {"result": ["object1", "object2", "object3", "object4"]}},
            {"component": "Action Predictor", "result": {"tool_name": "Python_Calculator_Tool", "sub_goal": "Calculate the square root of 4"}},
            {"component": "Tool Execution", "result": {"result": "Result of 4**0.5 is 2.0"}}
        ]

        Example Output:
        ```json
        {
            "final_answer": "The square root of the number of objects in the image is 2.0. There are 4 objects in the image, and the square root of 4 is 2.0."
        }
        ```
        """
|
||||
|
||||
def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
|
||||
"""Safely parse JSON, handling errors and using recursive descent."""
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"JSONDecodeError: Attempting to extract JSON from: {json_str}"
|
||||
)
|
||||
try:
|
||||
# More robust JSON extraction with recursive descent
|
||||
def extract_json(s):
|
||||
stack = []
|
||||
start = -1
|
||||
for i, c in enumerate(s):
|
||||
if c == "{":
|
||||
if not stack:
|
||||
start = i
|
||||
stack.append(c)
|
||||
elif c == "}":
|
||||
if stack:
|
||||
stack.pop()
|
||||
if not stack and start != -1:
|
||||
return s[start : i + 1]
|
||||
return None
|
||||
|
||||
extracted_json = extract_json(json_str)
|
||||
if extracted_json:
|
||||
logger.info(f"Extracted JSON: {extracted_json}")
|
||||
return json.loads(extracted_json)
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to extract JSON using recursive descent."
|
||||
)
|
||||
return {
|
||||
"error": "Failed to parse JSON",
|
||||
"content": json_str,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during JSON extraction: {e}")
|
||||
return {
|
||||
"error": "Failed to parse JSON",
|
||||
"content": json_str,
|
||||
}
|
||||
|
||||
def _execute_tool(
|
||||
self, tool_name: str, context: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Executes a tool based on its name and provided context."""
|
||||
if tool_name not in self.tools:
|
||||
return {"error": f"Tool '{tool_name}' not found."}
|
||||
|
||||
tool = self.tools[tool_name]
|
||||
try:
|
||||
# For Python Calculator tool, handle object counts from Object Detector
|
||||
if tool_name == "Python_Calculator_Tool":
|
||||
# Check for object detector results
|
||||
object_detector_result = context.get(
|
||||
"Object_Detector_Tool_result"
|
||||
)
|
||||
if object_detector_result and isinstance(
|
||||
object_detector_result, list
|
||||
):
|
||||
# Calculate the number of objects
|
||||
num_objects = len(object_detector_result)
|
||||
# If sub_goal doesn't already contain an expression, create one
|
||||
if (
|
||||
"sub_goal" in context
|
||||
and "Calculate the square root"
|
||||
in context["sub_goal"]
|
||||
):
|
||||
context["expression"] = f"{num_objects}**0.5"
|
||||
elif "expression" not in context:
|
||||
# Default to square root if no expression is specified
|
||||
context["expression"] = f"{num_objects}**0.5"
|
||||
|
||||
# Filter context: only pass expected inputs to the tool
|
||||
valid_inputs = {
|
||||
k: v
|
||||
for k, v in context.items()
|
||||
if k in tool.metadata.get("input_types", {})
|
||||
}
|
||||
result = tool.execute(**valid_inputs)
|
||||
return {"result": result}
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing tool {tool_name}: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _run_agent(
    self, agent: Agent, input_prompt: str
) -> Dict[str, Any]:
    """Invoke a swarms agent with *input_prompt* prefixed by its system
    prompt, then coerce the raw reply into a dict via
    ``_safely_parse_json``.

    Returns:
        The parsed response dict, or ``{"error": ...}`` if the agent
        raises.
    """
    try:
        # Prepend the system prompt so the agent sees its role first.
        prompt = f"{agent.system_prompt}\n\n{input_prompt}"
        raw_reply = agent.run(prompt)

        logger.info(
            f"DEBUG: Raw agent response: {raw_reply}"
        )

        # The agent is assumed to return its text directly; parse it
        # (with error fallback) before handing it back to the caller.
        return self._safely_parse_json(raw_reply)

    except Exception as e:
        logger.exception(
            f"Error running agent {agent.agent_name}: {e}"
        )
        return {
            "error": f"Agent {agent.agent_name} failed: {str(e)}"
        }
|
||||
|
||||
def run(
    self, query: str, image: Optional[str] = None
) -> Dict[str, Any]:
    """Execute *query* through the plan -> act -> verify -> summarize workflow.

    Args:
        query: Natural-language task to solve.
        image: Optional image reference added to the tool context.

    Returns:
        Dict with ``final_answer``, ``trajectory`` (the step-by-step
        memory), and ``conversation`` — or an ``error`` variant carrying
        the same trajectory/conversation keys.
    """
    logger.info(f"Starting task: {query}")

    try:
        # Step 1: Query Analysis (High-Level Planning)
        planner_input = (
            f"Analyze the following query and determine the necessary skills and"
            f" relevant tools: {query}"
        )
        query_analysis = self._run_agent(
            self.planner, planner_input
        )

        # _run_agent signals failure via an "error" key rather than raising.
        if "error" in query_analysis:
            return {
                "error": f"Planner query analysis failed: {query_analysis['error']}",
                "trajectory": self.memory,
                "conversation": self.conversation.return_history_as_string(),
            }

        # Step 0 in memory is the initial analysis; loop steps start at 1.
        self.memory.append(
            {
                "step": 0,
                "component": "Query Analyzer",
                "result": query_analysis,
            }
        )
        self.conversation.add(
            role=self.planner.agent_name,
            content=json.dumps(query_analysis),
        )

        # Initialize context with the query and image (if provided)
        context = {"query": query}
        if image:
            context["image"] = image

        # Add available tools to context
        if "relevant_tools" in query_analysis:
            context["available_tools"] = query_analysis[
                "relevant_tools"
            ]
        else:
            # If no relevant tools specified, make all tools available
            context["available_tools"] = list(self.tools.keys())

        step_count = 1

        # Step 2: Iterative Action-Execution Loop
        while step_count <= self.max_iterations:
            logger.info(
                f"Starting iteration {step_count} of {self.max_iterations}"
            )

            # Step 2a: Action Prediction (Low-Level Planning)
            action_planner_input = (
                f"Current Context: {json.dumps(context)}\nAvailable Tools:"
                f" {', '.join(context.get('available_tools', list(self.tools.keys())))}\nPlan the"
                " next step."
            )
            action = self._run_agent(
                self.planner, action_planner_input
            )
            if "error" in action:
                logger.error(
                    f"Error in action prediction: {action['error']}"
                )
                return {
                    "error": f"Planner action prediction failed: {action['error']}",
                    "trajectory": self.memory,
                    "conversation": self.conversation.return_history_as_string(),
                }
            self.memory.append(
                {
                    "step": step_count,
                    "component": "Action Predictor",
                    "result": action,
                }
            )
            self.conversation.add(
                role=self.planner.agent_name,
                content=json.dumps(action),
            )

            # Input Validation for Action (Relaxed): a malformed action
            # ends the loop (falls through to summarization) instead of
            # aborting the whole run.
            if (
                not isinstance(action, dict)
                or "tool_name" not in action
                or "sub_goal" not in action
            ):
                error_msg = (
                    "Action prediction did not return required fields (tool_name,"
                    " sub_goal) or was not a dictionary."
                )
                logger.error(error_msg)
                self.memory.append(
                    {
                        "step": step_count,
                        "component": "Error",
                        "result": error_msg,
                    }
                )
                break

            # Step 2b: Execute Tool
            tool_execution_context = {
                **context,
                **action.get(
                    "context", {}
                ),  # Add any additional context
                "sub_goal": action[
                    "sub_goal"
                ],  # Pass sub_goal to tool
            }

            tool_result = self._execute_tool(
                action["tool_name"], tool_execution_context
            )

            self.memory.append(
                {
                    "step": step_count,
                    "component": "Tool Execution",
                    "result": tool_result,
                }
            )

            # Step 2c: Context Update - Store result with a descriptive key
            # (e.g. "Object_Detector_Tool_result") so later steps can read it.
            if "result" in tool_result:
                context[f"{action['tool_name']}_result"] = (
                    tool_result["result"]
                )
            if "error" in tool_result:
                context[f"{action['tool_name']}_error"] = (
                    tool_result["error"]
                )

            # Step 2d: Context Verification
            verifier_input = (
                f"Current Context: {json.dumps(context)}\nMemory:"
                f" {json.dumps(self.memory)}\nQuery: {query}"
            )
            verification = self._run_agent(
                self.verifier, verifier_input
            )
            if "error" in verification:
                return {
                    "error": f"Verifier failed: {verification['error']}",
                    "trajectory": self.memory,
                    "conversation": self.conversation.return_history_as_string(),
                }

            self.memory.append(
                {
                    "step": step_count,
                    "component": "Context Verifier",
                    "result": verification,
                }
            )
            self.conversation.add(
                role=self.verifier.agent_name,
                content=json.dumps(verification),
            )

            # Check for stop signal from Verifier (must be literal True,
            # not merely truthy).
            if verification.get("stop_signal") is True:
                logger.info(
                    "Received stop signal from verifier. Stopping iterations."
                )
                break

            # Safety mechanism - if we've executed the same tool multiple times
            same_tool_count = sum(
                1
                for m in self.memory
                if m.get("component") == "Action Predictor"
                and m.get("result", {}).get("tool_name")
                == action.get("tool_name")
            )

            if same_tool_count > 3:
                logger.warning(
                    f"Tool {action.get('tool_name')} used more than 3 times. Forcing stop."
                )
                break

            step_count += 1

        # Step 3: Solution Summarization — runs whether the loop ended by
        # stop signal, validation break, or max_iterations.
        summarizer_input = f"Complete Trajectory: {json.dumps(self.memory)}\nOriginal Query: {query}"

        summarization = self._run_agent(
            self.summarizer, summarizer_input
        )
        if "error" in summarization:
            return {
                "error": f"Summarizer failed: {summarization['error']}",
                "trajectory": self.memory,
                "conversation": self.conversation.return_history_as_string(),
            }
        self.conversation.add(
            role=self.summarizer.agent_name,
            content=json.dumps(summarization),
        )

        return {
            "final_answer": summarization.get(
                "final_answer", "No answer found."
            ),
            "trajectory": self.memory,
            "conversation": self.conversation.return_history_as_string(),
        }

    except Exception as e:
        logger.exception(
            f"Unexpected error in run method: {e}"
        )  # More detailed
        return {
            "error": str(e),
            "trajectory": self.memory,
            "conversation": self.conversation.return_history_as_string(),
        }
|
||||
|
||||
def save_state(self) -> None:
    """Persist the state of the planner, verifier, and summarizer agents.

    A failure for one agent is logged and does not prevent the others
    from saving.
    """
    for member in (self.planner, self.verifier, self.summarizer):
        try:
            member.save_state()
        except Exception as e:
            logger.error(
                f"Error saving state for {member.agent_name}: {str(e)}"
            )
|
||||
|
||||
def load_state(self) -> None:
    """Restore the saved state of the planner, verifier, and summarizer.

    A failure for one agent is logged and does not prevent the others
    from loading.
    """
    for member in (self.planner, self.verifier, self.summarizer):
        try:
            member.load_state()
        except Exception as e:
            logger.error(
                f"Error loading state for {member.agent_name}: {str(e)}"
            )
|
||||
|
||||
|
||||
# --- Example Usage ---
|
||||
|
||||
|
||||
# Define dummy tool functions (replace with actual implementations)
|
||||
def image_captioner_execute(
    image: str, prompt: str = "Describe the image", **kwargs
) -> str:
    """Stand-in captioner: logs its arguments and returns a canned caption."""
    call_info = f"image_captioner_execute called with image: {image}, prompt: {prompt}"
    print(call_info)
    caption = f"Caption for {image}: A descriptive caption (dummy)."
    return caption
|
||||
|
||||
|
||||
def object_detector_execute(
    image: str, labels: Optional[List[str]] = None, **kwargs
) -> List[str]:
    """Dummy object detector, handles missing labels gracefully.

    Args:
        image: Path or reference to the image (unused by the dummy).
        labels: Labels to "detect". When omitted or empty, a fixed set
            of four placeholder objects is returned instead.

    Returns:
        One ``"Detected <label>"`` string per label, or the placeholder
        list when no labels were given.
    """
    # Fix: the original used a mutable default (`labels=[]`), which is
    # shared across calls; `None` is the safe sentinel with identical
    # observable behavior (normalized before printing).
    labels = labels if labels is not None else []
    print(
        f"object_detector_execute called with image: {image}, labels: {labels}"
    )
    if not labels:
        return [
            "object1",
            "object2",
            "object3",
            "object4",
        ]  # Return default objects if no labels
    return [f"Detected {label}" for label in labels]  # Simplified
|
||||
|
||||
|
||||
def web_search_execute(query: str, **kwargs) -> str:
    """Stand-in web search: logs the query and returns a canned result string."""
    print(f"web_search_execute called with query: {query}")
    canned = f"Search results for '{query}'..."
    return canned
|
||||
|
||||
|
||||
def python_calculator_execute(expression: str, **kwargs) -> str:
    """Evaluate a simple arithmetic expression and describe the result.

    Only digits, the basic operators ``+ - * / ( ) .``, and whitespace
    are accepted; anything else is rejected before evaluation.
    """
    print(f"python_calculator_execute called with: {expression}")
    try:
        # Whitelist check: reject anything beyond plain arithmetic.
        if not re.match(r"^[0-9+\-*/().\s]+$", expression):
            return "Error: Invalid expression for calculator."
        # Builtins are stripped from the eval environment; `math` is
        # exposed, though the character whitelist above means named
        # functions can never actually appear in the expression.
        value = eval(expression, {"__builtins__": {}, "math": math})
        return f"Result of {expression} is {value}"
    except Exception as e:
        return f"Error: {e}"
|
||||
|
||||
|
||||
# Create utility function to get default tools
|
||||
def get_default_tools() -> List[Tool]:
    """Return the standard OctoToolsSwarm toolset: image captioner,
    object detector, web search, and Python calculator."""
    # (name, description, metadata, execute_func) for each default tool.
    specs = [
        (
            "Image_Captioner_Tool",
            "Generates a caption for an image.",
            {
                "input_types": {"image": "str", "prompt": "str"},
                "output_type": "str",
                "limitations": "May struggle with complex scenes or ambiguous objects.",
                "best_practices": "Use with clear, well-lit images. Provide specific prompts for better results.",
            },
            image_captioner_execute,
        ),
        (
            "Object_Detector_Tool",
            "Detects objects in an image.",
            {
                "input_types": {"image": "str", "labels": "list"},
                "output_type": "list",
                "limitations": "Accuracy depends on the quality of the image and the clarity of the objects.",
                "best_practices": "Provide a list of specific object labels to detect. Use high-resolution images.",
            },
            object_detector_execute,
        ),
        (
            "Web_Search_Tool",
            "Performs a web search.",
            {
                "input_types": {"query": "str"},
                "output_type": "str",
                "limitations": "May not find specific or niche information.",
                "best_practices": "Use specific and descriptive keywords for better results.",
            },
            web_search_execute,
        ),
        (
            "Python_Calculator_Tool",
            "Evaluates a Python expression.",
            {
                "input_types": {"expression": "str"},
                "output_type": "str",
                "limitations": "Cannot handle complex mathematical functions or libraries.",
                "best_practices": "Use for basic arithmetic and simple calculations.",
            },
            python_calculator_execute,
        ),
    ]
    return [
        Tool(
            name=name,
            description=description,
            metadata=metadata,
            execute_func=execute_func,
        )
        for name, description, metadata, execute_func in specs
    ]
|
||||
|
||||
|
||||
# Only execute the example when this script is run directly
|
||||
# if __name__ == "__main__":
|
||||
# print("Running OctoToolsSwarm example...")
|
||||
|
||||
# # Create an OctoToolsSwarm agent with default tools
|
||||
# tools = get_default_tools()
|
||||
# agent = OctoToolsSwarm(tools=tools)
|
||||
|
||||
# # Example query
|
||||
# query = "What is the square root of the number of objects in this image?"
|
||||
|
||||
# # Create a dummy image file for testing if it doesn't exist
|
||||
# image_path = "example.png"
|
||||
# if not os.path.exists(image_path):
|
||||
# with open(image_path, "w") as f:
|
||||
# f.write("Dummy image content")
|
||||
# print(f"Created dummy image file: {image_path}")
|
||||
|
||||
# # Run the agent
|
||||
# result = agent.run(query, image=image_path)
|
||||
|
||||
# # Display results
|
||||
# print("\n=== FINAL ANSWER ===")
|
||||
# print(result["final_answer"])
|
||||
|
||||
# print("\n=== TRAJECTORY SUMMARY ===")
|
||||
# for step in result["trajectory"]:
|
||||
# print(f"Step {step.get('step', 'N/A')}: {step.get('component', 'Unknown')}")
|
||||
|
||||
# print("\nOctoToolsSwarm example completed.")
|
@ -1,469 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
import pulsar
|
||||
from cryptography.fernet import Fernet
|
||||
from loguru import logger
|
||||
from prometheus_client import Counter, Histogram, start_http_server
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.v1 import validator
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
# Prometheus metrics shared by the swarm: task throughput, latency
# histogram, and failure/error counters. Registered globally at import
# time (importing this module twice would raise a duplicate-metric error).
TASK_COUNTER = Counter(
    "swarm_tasks_total", "Total number of tasks processed"
)
TASK_LATENCY = Histogram(
    "swarm_task_duration_seconds", "Task processing duration"
)
TASK_FAILURES = Counter(
    "swarm_task_failures_total", "Total number of task failures"
)
AGENT_ERRORS = Counter(
    "swarm_agent_errors_total", "Total number of agent errors"
)

# Closed vocabularies for task lifecycle state and scheduling priority,
# used as field types on the Task model below.
TaskStatus = Literal["pending", "processing", "completed", "failed"]
TaskPriority = Literal["low", "medium", "high", "critical"]
|
||||
|
||||
|
||||
class SecurityConfig(BaseModel):
    """Security configuration for the swarm.

    Bundles payload encryption, transport security (TLS paths),
    authentication, and basic abuse limits (message size, rate limit).
    """

    # Fernet key used by SecurePulsarSwarm to encrypt task payloads.
    encryption_key: str = Field(
        ..., description="Encryption key for sensitive data"
    )
    tls_cert_path: Optional[str] = Field(
        None, description="Path to TLS certificate"
    )
    tls_key_path: Optional[str] = Field(
        None, description="Path to TLS private key"
    )
    auth_token: Optional[str] = Field(
        None, description="Authentication token"
    )
    # 1048576 == 1 MiB default ceiling for a serialized task.
    max_message_size: int = Field(
        default=1048576, description="Maximum message size in bytes"
    )
    rate_limit: int = Field(
        default=100, description="Maximum tasks per minute"
    )

    @validator("encryption_key")
    def validate_encryption_key(cls, v):
        """Reject keys shorter than 32 characters.

        NOTE(review): this measures character count, not decoded byte
        length (despite the error message) — a valid Fernet key is a
        44-char urlsafe-base64 string; confirm the intended check.
        """
        if len(v) < 32:
            raise ValueError(
                "Encryption key must be at least 32 bytes long"
            )
        return v
|
||||
|
||||
|
||||
class Task(BaseModel):
    """Enhanced task model with additional metadata and validation.

    Carries lifecycle timestamps and retry accounting alongside the
    payload; serialized via ``.json()`` for transport over Pulsar.
    """

    task_id: str = Field(
        ..., description="Unique identifier for the task"
    )
    description: str = Field(
        ..., description="Task description or instructions"
    )
    # How the consumer post-processes agent output: raw string,
    # json.loads'ed object, or written to a file on disk.
    output_type: Literal["string", "json", "file"] = Field("string")
    status: TaskStatus = Field(default="pending")
    priority: TaskPriority = Field(default="medium")
    # NOTE(review): datetime.utcnow yields naive timestamps (and is
    # deprecated in Python 3.12) — consider timezone-aware datetimes.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    # Incremented by the consumer on each redelivery attempt.
    retry_count: int = Field(default=0)
    metadata: Dict[str, Any] = Field(default_factory=dict)

    @validator("task_id")
    def validate_task_id(cls, v):
        """Require a non-blank task_id (whitespace-only is rejected)."""
        if not v.strip():
            raise ValueError("task_id cannot be empty")
        return v

    class Config:
        # Serialize datetimes as ISO-8601 strings in .json() output.
        json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
    """Model for task execution results.

    Produced by ``SecurePulsarSwarm._process_task`` for both successful
    and failed runs; ``error_message`` is populated only on failure.
    """

    task_id: str
    status: TaskStatus
    # Agent output: a string, a parsed JSON object, or a
    # {"file_path": ...} dict depending on the task's output_type.
    result: Any
    error_message: Optional[str] = None
    # Wall-clock seconds from task start to completion/failure.
    execution_time: float
    agent_id: str
|
||||
|
||||
|
||||
@contextmanager
def task_timing():
    """Measure the wall-clock duration of the wrapped block and record
    it in the ``TASK_LATENCY`` histogram, even if the block raises."""
    began = time.time()
    try:
        yield
    finally:
        TASK_LATENCY.observe(time.time() - began)
|
||||
|
||||
|
||||
class SecurePulsarSwarm:
    """
    Enhanced secure, scalable swarm system with improved reliability and security features.

    Publishes Fernet-encrypted Task payloads to an Apache Pulsar topic
    and consumes them with round-robin agent dispatch, retries, a
    circuit breaker, rate limiting, and Prometheus metrics. Usable as a
    context manager for clean shutdown.
    """

    def __init__(
        self,
        name: str,
        description: str,
        agents: List[Any],
        pulsar_url: str,
        subscription_name: str,
        topic_name: str,
        security_config: SecurityConfig,
        max_workers: int = 5,
        retry_attempts: int = 3,
        task_timeout: int = 300,
        metrics_port: int = 8000,
    ):
        """Initialize the enhanced Pulsar Swarm.

        Args:
            name: Human-readable swarm name (used in logs).
            description: Free-form description of the swarm.
            agents: Agent-like objects exposing ``run`` and ``agent_name``.
            pulsar_url: Broker URL (``pulsar://`` or ``pulsar+ssl://``).
            subscription_name: Shared subscription for the consumer.
            topic_name: Topic used by both producer and consumer.
            security_config: Encryption/TLS/auth/rate-limit settings.
            max_workers: Thread pool size for agent execution.
            retry_attempts: Redeliveries allowed per failed task.
            task_timeout: Per-task execution timeout in seconds.
            metrics_port: Port for the Prometheus HTTP endpoint.
        """
        self.name = name
        self.description = description
        self.agents = agents
        self.pulsar_url = pulsar_url
        self.subscription_name = subscription_name
        self.topic_name = topic_name
        self.security_config = security_config
        self.max_workers = max_workers
        self.retry_attempts = retry_attempts
        self.task_timeout = task_timeout

        # Initialize encryption (Fernet requires a valid 44-char
        # urlsafe-base64 key; raises otherwise).
        self.cipher_suite = Fernet(
            security_config.encryption_key.encode()
        )

        # Setup metrics server — starts an HTTP listener as a side
        # effect of construction.
        start_http_server(metrics_port)

        # Initialize Pulsar client with security settings
        client_config = {
            "authentication": (
                None
                if not security_config.auth_token
                else pulsar.AuthenticationToken(
                    security_config.auth_token
                )
            ),
            "operation_timeout_seconds": 30,
            "connection_timeout_seconds": 30,
            # TLS is enabled whenever a certificate path was supplied.
            "use_tls": bool(security_config.tls_cert_path),
            "tls_trust_certs_file_path": security_config.tls_cert_path,
            "tls_allow_insecure_connection": False,
        }

        self.client = pulsar.Client(self.pulsar_url, **client_config)
        self.producer = self._create_producer()
        self.consumer = self._create_consumer()
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        # Initialize rate limiting: a per-minute window counter reset
        # in publish_task.
        self.last_execution_time = time.time()
        self.execution_count = 0

        logger.info(
            f"Secure Pulsar Swarm '{self.name}' initialized with enhanced security features"
        )

    def _create_producer(self):
        """Create a batching, LZ4-compressed producer on the swarm topic."""
        return self.client.create_producer(
            self.topic_name,
            max_pending_messages=1000,
            compression_type=pulsar.CompressionType.LZ4,
            # Back-pressure: block the caller rather than drop messages.
            block_if_queue_full=True,
            batching_enabled=True,
            batching_max_publish_delay_ms=10,
        )

    def _create_consumer(self):
        """Create a Shared-subscription consumer on the swarm topic."""
        return self.client.subscribe(
            self.topic_name,
            subscription_name=self.subscription_name,
            # Shared subscription lets multiple swarm processes split load.
            consumer_type=pulsar.ConsumerType.Shared,
            message_listener=None,
            receiver_queue_size=1000,
            max_total_receiver_queue_size_across_partitions=50000,
        )

    def _encrypt_message(self, data: str) -> bytes:
        """Encrypt a UTF-8 string payload with the swarm's Fernet key."""
        return self.cipher_suite.encrypt(data.encode())

    def _decrypt_message(self, data: bytes) -> str:
        """Decrypt a Fernet token back into its UTF-8 string payload."""
        return self.cipher_suite.decrypt(data).decode()

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
    )
    def publish_task(self, task: Task) -> None:
        """Validate, rate-limit, encrypt, and publish *task* to Pulsar.

        Raises:
            ValueError: If the serialized task exceeds
                ``max_message_size`` or the per-minute rate limit is hit.
                NOTE(review): the tenacity decorator retries these
                validation errors too — confirm that is intended.
        """
        try:
            # Validate message size
            task_data = task.json()
            if len(task_data) > self.security_config.max_message_size:
                raise ValueError(
                    "Task data exceeds maximum message size"
                )

            # Rate limiting: reset the counter once the 60s window passes.
            current_time = time.time()
            if current_time - self.last_execution_time >= 60:
                self.execution_count = 0
                self.last_execution_time = current_time

            if (
                self.execution_count
                >= self.security_config.rate_limit
            ):
                raise ValueError("Rate limit exceeded")

            # Encrypt and publish
            encrypted_data = self._encrypt_message(task_data)
            message_id = self.producer.send(encrypted_data)

            self.execution_count += 1
            logger.info(
                f"Task {task.task_id} published successfully with message ID {message_id}"
            )

        except Exception as e:
            TASK_FAILURES.inc()
            logger.error(
                f"Error publishing task {task.task_id}: {str(e)}"
            )
            # Re-raise so tenacity's retry policy can take effect.
            raise

    async def _process_task(self, task: Task) -> TaskResult:
        """Run *task* on the next agent and return a TaskResult.

        Mutates *task* in place (status/timestamps) and records latency
        via the ``task_timing`` context manager. Never raises: timeouts
        and agent errors are converted to failed TaskResults.
        """
        task.status = "processing"
        task.started_at = datetime.utcnow()

        with task_timing():
            try:
                # Select agent using round-robin: rotate the front of the
                # list to the back. NOTE(review): if pop(0) itself raised
                # (empty agent list), `agent` would be unbound in the
                # except handlers below.
                agent = self.agents.pop(0)
                self.agents.append(agent)

                # Execute task with timeout on the shared thread pool.
                future = self.executor.submit(
                    agent.run, task.description
                )
                result = future.result(timeout=self.task_timeout)

                # Handle different output types
                if task.output_type == "json":
                    result = json.loads(result)
                elif task.output_type == "file":
                    file_path = f"output_{task.task_id}_{int(time.time())}.txt"
                    with open(file_path, "w") as f:
                        f.write(result)
                    result = {"file_path": file_path}

                task.status = "completed"
                task.completed_at = datetime.utcnow()
                TASK_COUNTER.inc()

                return TaskResult(
                    task_id=task.task_id,
                    status="completed",
                    result=result,
                    execution_time=time.time()
                    - task.started_at.timestamp(),
                    agent_id=agent.agent_name,
                )

            # TimeoutError here is concurrent.futures.TimeoutError
            # (imported at the top of the file), raised by future.result.
            except TimeoutError:
                TASK_FAILURES.inc()
                error_msg = f"Task {task.task_id} timed out after {self.task_timeout} seconds"
                logger.error(error_msg)
                task.status = "failed"
                return TaskResult(
                    task_id=task.task_id,
                    status="failed",
                    result=None,
                    error_message=error_msg,
                    execution_time=time.time()
                    - task.started_at.timestamp(),
                    agent_id=agent.agent_name,
                )

            except Exception as e:
                TASK_FAILURES.inc()
                AGENT_ERRORS.inc()
                error_msg = (
                    f"Error processing task {task.task_id}: {str(e)}"
                )
                logger.error(error_msg)
                task.status = "failed"
                return TaskResult(
                    task_id=task.task_id,
                    status="failed",
                    result=None,
                    error_message=error_msg,
                    execution_time=time.time()
                    - task.started_at.timestamp(),
                    agent_id=agent.agent_name,
                )

    async def consume_tasks(self):
        """Consume, decrypt, and process tasks forever.

        Applies a simple circuit breaker: after 5 consecutive failures
        the loop sleeps with exponential backoff (capped at 60s) before
        receiving again. Failed tasks are negatively acknowledged for
        redelivery up to ``retry_attempts`` times, then acknowledged and
        dropped with an error log.
        """
        consecutive_failures = 0
        backoff_time = 1

        while True:
            try:
                # Circuit breaker pattern
                if consecutive_failures >= 5:
                    logger.warning(
                        f"Circuit breaker triggered. Waiting {backoff_time} seconds"
                    )
                    await asyncio.sleep(backoff_time)
                    backoff_time = min(backoff_time * 2, 60)
                    continue

                # Receive message with timeout
                message = await self.consumer.receive_async()

                try:
                    # Decrypt and process message
                    decrypted_data = self._decrypt_message(
                        message.data()
                    )
                    task_data = json.loads(decrypted_data)
                    task = Task(**task_data)

                    # Process task
                    result = await self._process_task(task)

                    # Handle result
                    if result.status == "completed":
                        await self.consumer.acknowledge_async(message)
                        # Success resets the breaker state.
                        consecutive_failures = 0
                        backoff_time = 1
                    else:
                        if task.retry_count < self.retry_attempts:
                            task.retry_count += 1
                            # NOTE(review): retry_count is bumped on the
                            # in-memory copy only; the redelivered message
                            # still carries the original count — confirm.
                            await self.consumer.negative_acknowledge(
                                message
                            )
                        else:
                            await self.consumer.acknowledge_async(
                                message
                            )
                            logger.error(
                                f"Task {task.task_id} failed after {self.retry_attempts} attempts"
                            )

                except Exception as e:
                    logger.error(
                        f"Error processing message: {str(e)}"
                    )
                    await self.consumer.negative_acknowledge(message)
                    consecutive_failures += 1

            except Exception as e:
                logger.error(f"Error in consume_tasks: {str(e)}")
                consecutive_failures += 1
                await asyncio.sleep(1)

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Flush and close producer, consumer, client, and thread pool.

        Cleanup errors are logged, never propagated; the original
        exception (if any) is not suppressed (implicit None return).
        """
        try:
            self.producer.flush()
            self.producer.close()
            self.consumer.close()
            self.client.close()
            self.executor.shutdown(wait=True)
        except Exception as e:
            logger.error(f"Error during cleanup: {str(e)}")
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # Example usage with security configuration
|
||||
# security_config = SecurityConfig(
|
||||
# encryption_key=secrets.token_urlsafe(32),
|
||||
# tls_cert_path="/path/to/cert.pem",
|
||||
# tls_key_path="/path/to/key.pem",
|
||||
# auth_token="your-auth-token",
|
||||
# max_message_size=1048576,
|
||||
# rate_limit=100,
|
||||
# )
|
||||
|
||||
# # Agent factory function
|
||||
# def create_financial_agent() -> Agent:
|
||||
# """Factory function to create a financial analysis agent."""
|
||||
# return Agent(
|
||||
# agent_name="Financial-Analysis-Agent",
|
||||
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
|
||||
# model_name="gpt-4o-mini",
|
||||
# max_loops=1,
|
||||
# autosave=True,
|
||||
# dashboard=False,
|
||||
# verbose=True,
|
||||
# dynamic_temperature_enabled=True,
|
||||
# saved_state_path="finance_agent.json",
|
||||
# user_name="swarms_corp",
|
||||
# retry_attempts=1,
|
||||
# context_length=200000,
|
||||
# return_step_meta=False,
|
||||
# output_type="string",
|
||||
# streaming_on=False,
|
||||
# )
|
||||
|
||||
# # Initialize agents (implementation not shown)
|
||||
# agents = [create_financial_agent() for _ in range(3)]
|
||||
|
||||
# # Initialize the secure swarm
|
||||
# with SecurePulsarSwarm(
|
||||
# name="Secure Financial Swarm",
|
||||
# description="Production-grade financial analysis swarm",
|
||||
# agents=agents,
|
||||
# pulsar_url="pulsar+ssl://localhost:6651",
|
||||
# subscription_name="secure_financial_subscription",
|
||||
# topic_name="secure_financial_tasks",
|
||||
# security_config=security_config,
|
||||
# max_workers=5,
|
||||
# retry_attempts=3,
|
||||
# task_timeout=300,
|
||||
# metrics_port=8000,
|
||||
# ) as swarm:
|
||||
# # Example task
|
||||
# task = Task(
|
||||
# task_id=secrets.token_urlsafe(16),
|
||||
# description="Analyze Q4 financial reports",
|
||||
# output_type="json",
|
||||
# priority="high",
|
||||
# metadata={
|
||||
# "department": "finance",
|
||||
# "requester": "john.doe@company.com",
|
||||
# },
|
||||
# )
|
||||
|
||||
# # Run the swarm
|
||||
# swarm.publish_task(task)
|
||||
# asyncio.run(swarm.consume_tasks())
|
@ -1,344 +0,0 @@
|
||||
import random
|
||||
from threading import Lock
|
||||
from time import sleep
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from swarms.structs.agent import Agent
|
||||
from swarms.structs.base_swarm import BaseSwarm
|
||||
from swarms.utils.loguru_logger import initialize_logger
|
||||
|
||||
logger = initialize_logger(log_folder="swarm_load_balancer")
|
||||
|
||||
|
||||
class AgentLoadBalancer(BaseSwarm):
    """
    A load balancer that distributes tasks among a pool of agents,
    tracking per-agent availability and success/failure statistics.

    Args:
        agents (List[Agent]): The list of agents available for task execution.
        max_retries (int, optional): The maximum number of retries for a task if it fails. Defaults to 3.
        max_loops (int, optional): The maximum number of loops to run a task. Defaults to 5.
        cooldown_time (float, optional): The cooldown time between retries. Defaults to 0.

    Attributes:
        agents (List[Agent]): The list of agents available for task execution.
        agent_status (Dict[str, bool]): Availability flag per agent name (True = free).
        max_retries (int): The maximum number of retries for a task if it fails.
        max_loops (int): The maximum number of loops to run a task.
        agent_performance (Dict[str, Dict[str, int]]): Success/failure counters per agent.
        lock (Lock): Guards shared status/performance state across threads.
        cooldown_time (float): The cooldown time between retries.

    Methods:
        get_available_agent: Get an available agent for task execution.
        set_agent_status: Set the status of an agent.
        update_performance: Update the performance statistics of an agent.
        log_performance: Log the performance statistics of all agents.
        run: Run a single task using an available agent.
        run_multiple_tasks: Run multiple tasks using available agents.
        run_task_with_loops: Run a task multiple times using an available agent.
        run_task_with_callback: Run a task with a callback function.
        run_task_with_timeout: Run a task with a timeout.
    """

    def __init__(
        self,
        agents: List[Agent],
        max_retries: int = 3,
        max_loops: int = 5,
        cooldown_time: float = 0,
    ):
        self.agents = agents
        # Every agent starts out available.
        self.agent_status = {
            agent.agent_name: True for agent in agents
        }
        self.max_retries = max_retries
        self.max_loops = max_loops
        self.agent_performance = {
            agent.agent_name: {"success_count": 0, "failure_count": 0}
            for agent in agents
        }
        # Protects agent_status and agent_performance; agents may be
        # used from multiple threads (see run_task_with_timeout).
        self.lock = Lock()
        self.cooldown_time = cooldown_time
        self.swarm_initialization()

    def swarm_initialization(self):
        """Validate the agent pool and log the roster."""
        logger.info(
            "Initializing AgentLoadBalancer with the following agents:"
        )

        # Make sure all the agents exist
        assert self.agents, "No agents provided to the Load Balancer"

        # Assert that all agents are of type Agent
        for agent in self.agents:
            assert isinstance(
                agent, Agent
            ), "All agents should be of type Agent"

        for agent in self.agents:
            logger.info(f"Agent Name: {agent.agent_name}")

        logger.info("Load Balancer Initialized Successfully!")

    def get_available_agent(self) -> Optional[Agent]:
        """
        Get an available agent for task execution.

        Returns:
            Optional[Agent]: A randomly chosen available agent, or None if
                no agents are currently free.
        """
        with self.lock:
            available_agents = [
                agent
                for agent in self.agents
                if self.agent_status[agent.agent_name]
            ]
            logger.info(
                f"Available agents: {[agent.agent_name for agent in available_agents]}"
            )
            if not available_agents:
                return None
            return random.choice(available_agents)

    def set_agent_status(self, agent: Agent, status: bool) -> None:
        """
        Set the status of an agent.

        Args:
            agent (Agent): The agent whose status needs to be set.
            status (bool): True to mark the agent available, False to mark it busy.
        """
        with self.lock:
            self.agent_status[agent.agent_name] = status

    def update_performance(self, agent: Agent, success: bool) -> None:
        """
        Update the performance statistics of an agent.

        Args:
            agent (Agent): The agent whose performance statistics need to be updated.
            success (bool): Whether the task executed by the agent was successful or not.
        """
        with self.lock:
            key = "success_count" if success else "failure_count"
            self.agent_performance[agent.agent_name][key] += 1

    def log_performance(self) -> None:
        """Log the success/failure counters of all agents."""
        logger.info("Agent Performance:")
        for agent_name, stats in self.agent_performance.items():
            logger.info(f"{agent_name}: {stats}")

    def run(self, task: str, *args, **kwargs) -> str:
        """
        Run a single task using an available agent, retrying on failure.

        Args:
            task (str): The task to be executed.

        Returns:
            str: The output of the task execution.

        Raises:
            RuntimeError: If no available agents are found to handle the
                request, or if the task keeps failing after max_retries.
        """
        try:
            retries = 0
            while retries < self.max_retries:
                agent = self.get_available_agent()
                if not agent:
                    raise RuntimeError(
                        "No available agents to handle the request."
                    )

                try:
                    # Mark the agent busy for the duration of the attempt.
                    self.set_agent_status(agent, False)
                    output = agent.run(task, *args, **kwargs)
                    self.update_performance(agent, True)
                    return output
                except Exception as e:
                    logger.error(
                        f"Error with agent {agent.agent_name}: {e}"
                    )
                    self.update_performance(agent, False)
                    retries += 1
                    sleep(self.cooldown_time)
                    if retries >= self.max_retries:
                        raise e
                finally:
                    # Always release the agent, success or failure.
                    self.set_agent_status(agent, True)
        except Exception as e:
            logger.error(
                f"Task failed: {e} try again by optimizing the code."
            )
            raise RuntimeError(f"Task failed: {e}")

    def run_multiple_tasks(self, tasks: List[str]) -> List[str]:
        """
        Run multiple tasks sequentially using available agents.

        Args:
            tasks (List[str]): The list of tasks to be executed.

        Returns:
            List[str]: The list of outputs corresponding to each task execution.
        """
        return [self.run(task) for task in tasks]

    def run_task_with_loops(self, task: str) -> List[str]:
        """
        Run a task max_loops times using available agents.

        Args:
            task (str): The task to be executed.

        Returns:
            List[str]: The list of outputs corresponding to each task execution.
        """
        return [self.run(task) for _ in range(self.max_loops)]

    def run_task_with_callback(
        self, task: str, callback: Callable[[str], None]
    ) -> None:
        """
        Run a task and feed the result (or the error message) to a callback.

        Args:
            task (str): The task to be executed.
            callback (Callable[[str], None]): The callback function to be called with the task result.
        """
        try:
            result = self.run(task)
            callback(result)
        except Exception as e:
            logger.error(f"Task failed: {e}")
            callback(str(e))

    def run_task_with_timeout(self, task: str, timeout: float) -> str:
        """
        Run a task with a timeout.

        Args:
            task (str): The task to be executed.
            timeout (float): The maximum time (in seconds) to wait for the task to complete.

        Returns:
            str: The output of the task execution.

        Raises:
            TimeoutError: If the task execution exceeds the specified timeout.
            Exception: If the task execution raises an exception.
        """
        import threading

        result = [None]
        exception = [None]

        def target():
            # Capture either the result or the exception so the caller's
            # thread can re-raise it after join().
            try:
                result[0] = self.run(task)
            except Exception as e:
                exception[0] = e

        # Bug fix: daemon=True so a task that outlives the timeout cannot
        # keep the interpreter alive at shutdown (the original non-daemon
        # thread kept running in the background forever).
        thread = threading.Thread(target=target, daemon=True)
        thread.start()
        thread.join(timeout)

        if thread.is_alive():
            raise TimeoutError(
                f"Task timed out after {timeout} seconds."
            )

        if exception[0]:
            raise exception[0]

        return result[0]
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# from swarms import llama3Hosted()
|
||||
# # User initializes the agents
|
||||
# agents = [
|
||||
# Agent(
|
||||
# agent_name="Transcript Generator 1",
|
||||
# agent_description="Generate a transcript for a youtube video on what swarms are!",
|
||||
# llm=llama3Hosted(),
|
||||
# max_loops="auto",
|
||||
# autosave=True,
|
||||
# dashboard=False,
|
||||
# streaming_on=True,
|
||||
# verbose=True,
|
||||
# stopping_token="<DONE>",
|
||||
# interactive=True,
|
||||
# state_save_file_type="json",
|
||||
# saved_state_path="transcript_generator_1.json",
|
||||
# ),
|
||||
# Agent(
|
||||
# agent_name="Transcript Generator 2",
|
||||
# agent_description="Generate a transcript for a youtube video on what swarms are!",
|
||||
# llm=llama3Hosted(),
|
||||
# max_loops="auto",
|
||||
# autosave=True,
|
||||
# dashboard=False,
|
||||
# streaming_on=True,
|
||||
# verbose=True,
|
||||
# stopping_token="<DONE>",
|
||||
# interactive=True,
|
||||
# state_save_file_type="json",
|
||||
# saved_state_path="transcript_generator_2.json",
|
||||
# )
|
||||
# # Add more agents as needed
|
||||
# ]
|
||||
|
||||
# load_balancer = LoadBalancer(agents)
|
||||
|
||||
# try:
|
||||
# result = load_balancer.run_task("Generate a transcript for a youtube video on what swarms are!")
|
||||
# print(result)
|
||||
|
||||
# # Running multiple tasks
|
||||
# tasks = [
|
||||
# "Generate a transcript for a youtube video on what swarms are!",
|
||||
# "Generate a transcript for a youtube video on AI advancements!"
|
||||
# ]
|
||||
# results = load_balancer.run_multiple_tasks(tasks)
|
||||
# for res in results:
|
||||
# print(res)
|
||||
|
||||
# # Running task with loops
|
||||
# loop_results = load_balancer.run_task_with_loops("Generate a transcript for a youtube video on what swarms are!")
|
||||
# for res in loop_results:
|
||||
# print(res)
|
||||
|
||||
# except RuntimeError as e:
|
||||
# print(f"Error: {e}")
|
||||
|
||||
# # Log performance
|
||||
# load_balancer.log_performance()
|
@ -1,178 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from swarms.utils.loguru_logger import initialize_logger
|
||||
|
||||
|
||||
logger = initialize_logger("workspace-manager")
|
||||
|
||||
|
||||
class WorkspaceManager:
    """
    Manages the workspace directory and settings for the application.

    Responsible for setting up the workspace directory, persisting the
    WORKSPACE_DIR setting into a .env file, and resolving environment
    variables for telemetry and the API key.
    """

    def __init__(
        self,
        workspace_dir: Optional[str] = "agent_workspace",
        use_telemetry: Optional[bool] = True,
        api_key: Optional[str] = None,
    ):
        """
        Initializes the WorkspaceManager.

        Args:
            workspace_dir (Optional[str]): The path to the workspace directory.
            use_telemetry (Optional[bool]): A flag indicating whether to use telemetry.
            api_key (Optional[str]): The API key for the application.
        """
        self.workspace_dir = workspace_dir
        self.use_telemetry = use_telemetry
        self.api_key = api_key
        # Bug fix: resolve the path eagerly so `get_workspace_path` cannot
        # raise AttributeError when `run()` has not been called yet (the
        # original only assigned `workspace_path` inside `run()`).
        # `run()` re-resolves it after consulting the environment.
        self.workspace_path = Path(workspace_dir or "agent_workspace")

    def _create_env_file(self, env_file_path: Path) -> None:
        """
        Create a new .env file with default WORKSPACE_DIR.

        Args:
            env_file_path (Path): The path to the .env file.
        """
        with env_file_path.open("w") as file:
            file.write(f"WORKSPACE_DIR={self.workspace_dir}\n")
        logger.info(
            "Created a new .env file with default WORKSPACE_DIR."
        )

    def _append_to_env_file(self, env_file_path: Path) -> None:
        """
        Append WORKSPACE_DIR to .env if it doesn't exist.

        Args:
            env_file_path (Path): The path to the .env file.
        """
        with env_file_path.open("r+") as file:
            content = file.read()
            # Substring check, not a parsed lookup: a commented-out
            # WORKSPACE_DIR line would also suppress the append.
            if "WORKSPACE_DIR" not in content:
                file.seek(0, os.SEEK_END)
                file.write(f"WORKSPACE_DIR={self.workspace_dir}\n")
                logger.info("Appended WORKSPACE_DIR to .env file.")

    def _get_workspace_dir(
        self, workspace_dir: Optional[str] = None
    ) -> str:
        """
        Get the workspace directory from the argument, the WORKSPACE_DIR
        environment variable, or the built-in default, in that order.

        Args:
            workspace_dir (Optional[str]): The path to the workspace directory.

        Returns:
            str: The path to the workspace directory.
        """
        return workspace_dir or os.getenv(
            "WORKSPACE_DIR", "agent_workspace"
        )

    def _get_telemetry_status(
        self, use_telemetry: Optional[bool] = None
    ) -> bool:
        """
        Get telemetry status from the argument or the USE_TELEMETRY
        environment variable (default: enabled).

        Args:
            use_telemetry (Optional[bool]): A flag indicating whether to use telemetry.

        Returns:
            bool: The status of telemetry usage.
        """
        return (
            use_telemetry
            if use_telemetry is not None
            else os.getenv("USE_TELEMETRY", "true").lower() == "true"
        )

    def _get_api_key(
        self, api_key: Optional[str] = None
    ) -> Optional[str]:
        """
        Get the API key from the argument or the SWARMS_API_KEY
        environment variable.

        Args:
            api_key (Optional[str]): The API key for the application.

        Returns:
            Optional[str]: The API key or None if not set.
        """
        return api_key or os.getenv("SWARMS_API_KEY")

    def _init_workspace(self) -> None:
        """Create the workspace directory on disk if it doesn't exist."""
        if not self.workspace_path.exists():
            self.workspace_path.mkdir(parents=True, exist_ok=True)
        logger.info("Workspace directory initialized.")

    @property
    def get_workspace_path(self) -> Path:
        """
        Get the workspace path.

        Returns:
            Path: The path to the workspace directory.
        """
        return self.workspace_path

    @property
    def get_telemetry_status(self) -> bool:
        """
        Get telemetry status.

        Returns:
            bool: The status of telemetry usage.
        """
        return self.use_telemetry

    @property
    def get_api_key(self) -> Optional[str]:
        """
        Get API key.

        Returns:
            Optional[str]: The API key or None if not set.
        """
        return self.api_key

    def run(self) -> None:
        """Resolve all settings (env file, workspace dir, telemetry, API
        key) and initialize the workspace directory on disk.

        NOTE(review): errors are logged and swallowed here, apparently so
        that workspace setup can never crash application startup.
        """
        try:
            # Check if .env file exists and create it if it doesn't
            env_file_path = Path(".env")

            if not env_file_path.exists():
                self._create_env_file(env_file_path)
            else:
                # Append WORKSPACE_DIR to .env if it doesn't exist
                self._append_to_env_file(env_file_path)

            # Set workspace directory
            self.workspace_dir = self._get_workspace_dir(
                self.workspace_dir
            )
            self.workspace_path = Path(self.workspace_dir)

            # Set telemetry preference
            self.use_telemetry = self._get_telemetry_status(
                self.use_telemetry
            )

            # Set API key
            self.api_key = self._get_api_key(self.api_key)

            # Initialize workspace
            self._init_workspace()
        except Exception as e:
            logger.error(f"Error initializing WorkspaceManager: {e}")
|
@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
from typing import Literal, Dict, Any, Union
|
||||
from fastmcp import Client
|
||||
from swarms.utils.any_to_str import any_to_str
|
||||
from swarms.utils.str_to_dict import str_to_dict
|
||||
|
||||
|
||||
def parse_agent_output(
    dictionary: Union[str, Dict[Any, Any]]
) -> tuple[str, Dict[Any, Any]]:
    """Extract a ``(tool_name, arguments)`` pair from an agent's output.

    Accepts either a JSON-ish string (decoded via ``str_to_dict``) or a
    dict in one of three shapes: OpenAI ``function_call``, OpenAI
    ``tool_calls``, or a plain ``{"name": ..., "arguments": ...}`` mapping.

    Raises:
        ValueError: If the input is neither str nor dict, or matches none
            of the supported formats.
    """
    if isinstance(dictionary, str):
        dictionary = str_to_dict(dictionary)
    elif not isinstance(dictionary, dict):
        raise ValueError("Invalid dictionary")

    # Legacy OpenAI function-call payload: arguments arrive JSON-encoded.
    if "function_call" in dictionary:
        call = dictionary["function_call"]
        return call["name"], str_to_dict(call["arguments"])

    # Newer OpenAI tool-calls payload: only the first call is considered.
    if "tool_calls" in dictionary:
        fn = dictionary["tool_calls"][0]["function"]
        return fn["name"], str_to_dict(fn["arguments"])

    # Plain dictionary with explicit name/arguments keys.
    if "name" in dictionary:
        return dictionary["name"], dictionary.get("arguments", {})

    raise ValueError("Invalid function call format")
|
||||
|
||||
|
||||
async def _execute_mcp_tool(
    url: str,
    method: Literal["stdio", "sse"] = "sse",
    parameters: Dict[Any, Any] = None,
    output_type: Literal["str", "dict"] = "str",
    *args,
    tool_name: str = None,
    **kwargs,
) -> Dict[Any, Any]:
    """Call a single tool on an MCP server and return its result.

    Args:
        url: Base server URL; must contain "sse" or "stdio".
        method: Transport segment appended to the URL.
        parameters: Agent output naming the tool and its arguments
            (any format accepted by ``parse_agent_output``).
        output_type: "str" to stringify the result, "dict" for raw output.
        tool_name: Optional explicit tool name; overrides the name parsed
            from ``parameters``. Keyword-only so positional callers are
            unaffected.

    Raises:
        ValueError: On an invalid URL or unknown ``output_type``.
    """
    # Bug fix: the original `if "sse" or "stdio" not in url:` is always
    # truthy (non-empty string literal), so every call raised ValueError.
    if "sse" not in url and "stdio" not in url:
        raise ValueError("Invalid URL")

    # Fail fast before opening a connection.
    if output_type not in ("str", "dict"):
        raise ValueError(f"Invalid output type: {output_type}")

    url = f"{url}/{method}"

    name, params = parse_agent_output(parameters)
    if tool_name is not None:
        name = tool_name

    # Single client block; only the post-processing differs by output_type.
    async with Client(url, *args, **kwargs) as client:
        out = await client.call_tool(
            name=name,
            arguments=params,
        )

    return any_to_str(out) if output_type == "str" else out
|
||||
|
||||
|
||||
def execute_mcp_tool(
    url: str,
    tool_name: str = None,
    method: Literal["stdio", "sse"] = "sse",
    parameters: Dict[Any, Any] = None,
    output_type: Literal["str", "dict"] = "str",
) -> Dict[Any, Any]:
    """Synchronous wrapper around ``_execute_mcp_tool``.

    Args:
        url: Base MCP server URL (must contain "sse" or "stdio").
        tool_name: Optional explicit tool name overriding the one parsed
            from ``parameters``.
        method: Transport segment appended to the URL.
        parameters: Agent output naming the tool and its arguments.
        output_type: "str" for a stringified result, "dict" for raw output.

    Returns:
        The tool's result in the requested output form.
    """
    # Bug fix: only forward tool_name when explicitly given — the original
    # always passed it, and the coroutine's **kwargs relayed it into the
    # fastmcp Client constructor, which does not accept it.
    extra = {} if tool_name is None else {"tool_name": tool_name}
    return asyncio.run(
        _execute_mcp_tool(
            url=url,
            method=method,
            parameters=parameters,
            output_type=output_type,
            **extra,
        )
    )
|
@ -1,81 +0,0 @@
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
from swarms.structs.agent import Agent
|
||||
from swarms.utils.loguru_logger import initialize_logger
|
||||
|
||||
logger = initialize_logger(log_folder="swarm_reliability_checks")
|
||||
|
||||
|
||||
def reliability_check(
    agents: List[Union[Agent, Callable]],
    max_loops: int,
    name: Optional[str] = None,
    description: Optional[str] = None,
    flow: Optional[str] = None,
) -> None:
    """
    Performs reliability checks on swarm configuration parameters.

    Args:
        agents: List of Agent objects or callables that will be executed
        max_loops: Maximum number of execution loops
        name: Name identifier for the swarm
        description: Description of the swarm's purpose
        flow: Flow definition string for the swarm

    Raises:
        ValueError: If any parameters fail validation checks
        TypeError: If parameters are of incorrect type
    """

    def _require_text(value: Optional[str], label: str) -> None:
        # Shared required/type validation for the string parameters;
        # messages match the original per-parameter wording exactly.
        if value is None:
            raise ValueError(f"{label} parameter is required")
        if not isinstance(value, str):
            raise TypeError(f"{label} must be a string")

    logger.info("Initializing swarm reliability checks")

    # Type checking
    if not isinstance(agents, list):
        raise TypeError("agents parameter must be a list")

    if not isinstance(max_loops, int):
        raise TypeError("max_loops must be an integer")

    # Validate agents
    if not agents:
        raise ValueError("Agents list cannot be empty")

    for i, agent in enumerate(agents):
        if not isinstance(agent, (Agent, Callable)):
            raise TypeError(
                f"Agent at index {i} must be an Agent instance or Callable"
            )

    # Validate max_loops
    if max_loops <= 0:
        raise ValueError("max_loops must be greater than 0")

    if max_loops > 1000:
        logger.warning(
            "Large max_loops value detected. This may impact performance."
        )

    # Validate name
    _require_text(name, "name")
    if len(name.strip()) == 0:
        raise ValueError("name cannot be empty or just whitespace")

    # Validate description
    _require_text(description, "description")
    if len(description.strip()) == 0:
        raise ValueError(
            "description cannot be empty or just whitespace"
        )

    # Validate flow (blank flow strings are accepted, as in the original)
    _require_text(flow, "flow")

    logger.info("All reliability checks passed successfully")
|
@ -0,0 +1,284 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import psutil
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from swarms.structs.agent import Agent
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class AgentBenchmark:
    """Benchmark harness that repeatedly runs agent queries, recording
    wall-clock time and process CPU/memory metrics for each iteration,
    and writes the aggregated results to JSON files."""

    def __init__(
        self,
        num_iterations: int = 5,
        output_dir: str = "benchmark_results",
    ):
        """
        Args:
            num_iterations: How many times each query is executed.
            output_dir: Directory where JSON result files are written.
        """
        self.num_iterations = num_iterations
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)

        # Process pool reserved for CPU-bound tasks.
        # Bug fix: os.cpu_count() may return None; fall back to 1 so
        # min() does not raise TypeError.
        self.process_pool = concurrent.futures.ProcessPoolExecutor(
            max_workers=min(os.cpu_count() or 1, 4)
        )

        # Thread pool reserved for I/O-bound tasks.
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=min((os.cpu_count() or 1) * 2, 8)
        )

        self.default_queries = [
            "Conduct an analysis of the best real undervalued ETFs",
            "What are the top performing tech stocks this quarter?",
            "Analyze current market trends in renewable energy sector",
            "Compare Bitcoin and Ethereum investment potential",
            "Evaluate the risk factors in emerging markets",
        ]

        self.agent = self._initialize_agent()
        self.process = psutil.Process()

        # Cache keyed by (query, iteration): only exact duplicate
        # iterations are served from cache, so normal runs always execute.
        self._query_cache = {}

    def _initialize_agent(self) -> Agent:
        """Build the financial-analysis agent under test."""
        return Agent(
            agent_name="Financial-Analysis-Agent",
            agent_description="Personal finance advisor agent",
            # system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
            max_loops=1,
            model_name="gpt-4o-mini",
            dynamic_temperature_enabled=True,
            interactive=False,
        )

    def _get_system_metrics(self) -> Dict[str, float]:
        """Snapshot current process CPU percentage and RSS in MB."""
        return {
            "cpu_percent": self.process.cpu_percent(),
            "memory_mb": self.process.memory_info().rss / 1024 / 1024,
        }

    def _calculate_statistics(
        self, values: List[float]
    ) -> Dict[str, float]:
        """Return mean/median/min/max — and population std dev when there
        is more than one sample — rounded to 3 decimals.

        An empty input yields an empty dict.
        """
        if not values:
            return {}

        sorted_values = sorted(values)
        n = len(sorted_values)
        mean_val = sum(values) / n

        stats = {
            "mean": mean_val,
            # Upper-middle element for even n (no interpolation).
            "median": sorted_values[n // 2],
            "min": sorted_values[0],
            "max": sorted_values[-1],
        }

        # Population standard deviation (divides by n, not n - 1);
        # undefined for a single sample.
        if n > 1:
            stats["std_dev"] = (
                sum((x - mean_val) ** 2 for x in values) / n
            ) ** 0.5

        return {k: round(v, 3) for k, v in stats.items()}

    async def process_iteration(
        self, query: str, iteration: int
    ) -> Dict[str, Any]:
        """Run one iteration of ``query``, timing it and capturing
        system metrics before and after.

        Returns a dict with execution_time, success flag, pre/post
        metrics, and a nested ``iteration_data`` record for reporting.
        """
        try:
            # Serve exact duplicate (query, iteration) pairs from cache.
            cache_key = f"{query}_{iteration}"
            if cache_key in self._query_cache:
                return self._query_cache[cache_key]

            iteration_start = datetime.datetime.now()
            pre_metrics = self._get_system_metrics()

            # Run the agent; a failing run is recorded, not raised.
            try:
                self.agent.run(query)
                success = True
            except Exception:
                success = False

            execution_time = (
                datetime.datetime.now() - iteration_start
            ).total_seconds()
            post_metrics = self._get_system_metrics()

            result = {
                "execution_time": execution_time,
                "success": success,
                "pre_metrics": pre_metrics,
                "post_metrics": post_metrics,
                "iteration_data": {
                    "iteration": iteration + 1,
                    "execution_time": round(execution_time, 3),
                    "success": success,
                    "system_metrics": {
                        "pre": pre_metrics,
                        "post": post_metrics,
                    },
                },
            }

            # Cache the result
            self._query_cache[cache_key] = result
            return result

        except Exception as e:
            logger.error(f"Error in iteration {iteration}: {e}")
            raise

    async def run_benchmark(
        self, queries: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Run all queries (default set if none given) concurrently and
        return the aggregated benchmark data with per-query statistics."""
        queries = queries or self.default_queries
        benchmark_data = {
            "metadata": {
                "timestamp": datetime.datetime.now().isoformat(),
                "num_iterations": self.num_iterations,
                "agent_config": {
                    "model_name": self.agent.model_name,
                    "max_loops": self.agent.max_loops,
                },
            },
            "results": {},
        }

        async def process_query(query: str):
            # Fan out all iterations of this query concurrently, then
            # fold the per-iteration records into summary statistics.
            query_results = {
                "execution_times": [],
                "system_metrics": [],
                "iterations": [],
            }

            tasks = [
                self.process_iteration(query, i)
                for i in range(self.num_iterations)
            ]
            iteration_results = await asyncio.gather(*tasks)

            for result in iteration_results:
                query_results["execution_times"].append(
                    result["execution_time"]
                )
                query_results["system_metrics"].append(
                    result["post_metrics"]
                )
                query_results["iterations"].append(
                    result["iteration_data"]
                )

            query_results["statistics"] = {
                "execution_time": self._calculate_statistics(
                    query_results["execution_times"]
                ),
                "memory_usage": self._calculate_statistics(
                    [
                        m["memory_mb"]
                        for m in query_results["system_metrics"]
                    ]
                ),
                "cpu_usage": self._calculate_statistics(
                    [
                        m["cpu_percent"]
                        for m in query_results["system_metrics"]
                    ]
                ),
            }

            return query, query_results

        # Execute all queries concurrently
        query_tasks = [process_query(query) for query in queries]
        query_results = await asyncio.gather(*query_tasks)

        for query, results in query_results:
            benchmark_data["results"][query] = results

        return benchmark_data

    def save_results(self, benchmark_data: Dict[str, Any]) -> str:
        """Write ``benchmark_data`` to a timestamped JSON file and return
        the file path as a string."""
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = (
            self.output_dir / f"benchmark_results_{timestamp}.json"
        )

        # Write results in a single operation
        with open(filename, "w") as f:
            json.dump(benchmark_data, f, indent=2)

        # Bug fix: the original logged a literal placeholder instead of
        # interpolating the actual file path.
        logger.info(f"Benchmark results saved to: {filename}")
        return str(filename)

    def print_summary(self, results: Dict[str, Any]):
        """Print a human-readable summary of the benchmark results."""
        print("\n=== Benchmark Summary ===")
        for query, data in results["results"].items():
            print(f"\nQuery: {query[:50]}...")
            stats = data["statistics"]["execution_time"]
            print(f"Average time: {stats['mean']:.2f}s")
            print(
                f"Memory usage (avg): {data['statistics']['memory_usage']['mean']:.1f}MB"
            )
            print(
                f"CPU usage (avg): {data['statistics']['cpu_usage']['mean']:.1f}%"
            )

    async def run_with_timeout(
        self, timeout: int = 300
    ) -> Dict[str, Any]:
        """Run the full benchmark, raising asyncio.TimeoutError if it
        does not finish within ``timeout`` seconds."""
        try:
            return await asyncio.wait_for(
                self.run_benchmark(), timeout
            )
        except asyncio.TimeoutError:
            logger.error(
                f"Benchmark timed out after {timeout} seconds"
            )
            raise

    def cleanup(self):
        """Shut down the executor pools and drop cached results."""
        self.process_pool.shutdown()
        self.thread_pool.shutdown()
        self._query_cache.clear()
|
||||
|
||||
|
||||
async def main():
    """Build the benchmark, run it with a timeout, save and print results."""
    # Bug fix: bind the name up front — if AgentBenchmark() raised, the
    # original `finally` clause hit an unbound `benchmark` and masked the
    # real error with a NameError.
    benchmark = None
    try:
        # Create and run benchmark
        benchmark = AgentBenchmark(num_iterations=1)

        # Run benchmark with timeout
        results = await benchmark.run_with_timeout(timeout=300)

        # Save results
        benchmark.save_results(results)

        # Print summary
        benchmark.print_summary(results)

    except Exception as e:
        logger.error(f"Benchmark failed: {e}")
    finally:
        # Cleanup resources (only if construction succeeded)
        if benchmark is not None:
            benchmark.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: drive the async benchmark via asyncio.run.
    asyncio.run(main())
|
Loading…
Reference in new issue