Merge pull request #787 from harshalmore31/octotools
Implemented OctoTools Framework with Swarms (pull/790/head)
commit ddb4119c04
@@ -0,0 +1,725 @@

"""
OctoToolsSwarm: A multi-agent system for complex reasoning.
Implements the OctoTools framework using swarms.
"""

import json
import logging
import math
import os
import re
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv

from swarms import Agent
from swarms.structs.conversation import Conversation

# from exa_search import exa_search as web_search_execute

# Load environment variables
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ToolType(Enum):
    """Defines the types of tools available."""

    IMAGE_CAPTIONER = "image_captioner"
    OBJECT_DETECTOR = "object_detector"
    WEB_SEARCH = "web_search"
    PYTHON_CALCULATOR = "python_calculator"
    # Add more tool types as needed


@dataclass
class Tool:
    """
    Represents an external tool.

    Attributes:
        name: Unique name of the tool.
        description: Description of the tool's function.
        metadata: Dictionary containing tool metadata.
        execute_func: Callable function that executes the tool's logic.
    """

    name: str
    description: str
    metadata: Dict[str, Any]
    execute_func: Callable

    def execute(self, **kwargs):
        """Executes the tool's logic, handling potential errors."""
        try:
            return self.execute_func(**kwargs)
        except Exception as e:
            logger.error(f"Error executing tool {self.name}: {str(e)}")
            return {"error": str(e)}


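# A minimal usage sketch (the Echo_Tool below is hypothetical, for
# illustration only): any callable can be wrapped as a Tool, and
# Tool.execute surfaces exceptions as an {"error": ...} dict instead of
# raising.
#
# echo_tool = Tool(
#     name="Echo_Tool",
#     description="Returns its input unchanged.",
#     metadata={"input_types": {"text": "str"}, "output_type": "str"},
#     execute_func=lambda text: text,
# )
# echo_tool.execute(text="hello")   # -> "hello"
# echo_tool.execute(missing="arg")  # -> {"error": "..."} (TypeError captured)

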
class AgentRole(Enum):
    """Defines the roles for agents in the OctoTools system."""

    PLANNER = "planner"
    VERIFIER = "verifier"
    SUMMARIZER = "summarizer"


class OctoToolsSwarm:
    """
    A multi-agent system implementing the OctoTools framework.

    Attributes:
        model_name: Name of the LLM model to use.
        max_iterations: Maximum number of action-execution iterations.
        base_path: Path for saving agent states.
        tools: Dictionary of available Tool objects, keyed by tool name.
    """

    def __init__(
        self,
        tools: List[Tool],
        model_name: str = "gemini/gemini-2.0-flash",
        max_iterations: int = 10,
        base_path: Optional[str] = None,
    ):
        """Initialize the OctoToolsSwarm system."""
        self.model_name = model_name
        self.max_iterations = max_iterations
        self.base_path = Path(base_path) if base_path else Path("./octotools_states")
        self.base_path.mkdir(exist_ok=True)
        self.tools = {tool.name: tool for tool in tools}  # Store tools in a dictionary

        # Initialize agents
        self._init_agents()

        # Create conversation tracker and memory
        self.conversation = Conversation()
        self.memory = []  # Store the trajectory

    def _init_agents(self) -> None:
        """Initialize all agents with their specific roles and prompts."""
        # Planner agent
        self.planner = Agent(
            agent_name="OctoTools-Planner",
            system_prompt=self._get_planner_prompt(),
            model_name=self.model_name,
            max_loops=3,
            saved_state_path=str(self.base_path / "planner.json"),
            verbose=True,
        )

        # Verifier agent
        self.verifier = Agent(
            agent_name="OctoTools-Verifier",
            system_prompt=self._get_verifier_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "verifier.json"),
            verbose=True,
        )

        # Summarizer agent
        self.summarizer = Agent(
            agent_name="OctoTools-Summarizer",
            system_prompt=self._get_summarizer_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "summarizer.json"),
            verbose=True,
        )

    def _get_planner_prompt(self) -> str:
        """Get the prompt for the planner agent (improved with few-shot examples)."""
        tool_descriptions = "\n".join(
            [
                f"- {tool_name}: {self.tools[tool_name].description}"
                for tool_name in self.tools
            ]
        )
        return f"""You are the Planner in the OctoTools framework. Your role is to analyze the user's query,
identify required skills, suggest relevant tools, and plan the steps to solve the problem.

1. **Analyze the user's query:** Understand the requirements and identify the necessary skills and potentially relevant tools.
2. **Perform high-level planning:** Create a rough outline of how tools might be used to solve the problem.
3. **Perform low-level planning (action prediction):** At each step, select the best tool to use and formulate a specific sub-goal for that tool, considering the current context.

Available Tools:
{tool_descriptions}

Output your response in JSON format. Here are examples for different stages:

**Query Analysis (High-Level Planning):**
Example Input:
Query: "What is the capital of France?"

Example Output:
```json
{{
    "summary": "The user is asking for the capital of France.",
    "required_skills": ["knowledge retrieval"],
    "relevant_tools": ["Web_Search_Tool"]
}}
```

**Action Prediction (Low-Level Planning):**
Example Input:
Context: {{ "query": "What is the capital of France?", "available_tools": ["Web_Search_Tool"] }}

Example Output:
```json
{{
    "justification": "The Web_Search_Tool can be used to directly find the capital of France.",
    "context": {{}},
    "sub_goal": "Search the web for 'capital of France'.",
    "tool_name": "Web_Search_Tool"
}}
```
Another Example:
Context: {{"query": "How many objects are in the image?", "available_tools": ["Image_Captioner_Tool", "Object_Detector_Tool"], "image": "objects.png"}}

Example Output:
```json
{{
    "justification": "First, get a general description of the image to understand the context.",
    "context": {{ "image": "objects.png" }},
    "sub_goal": "Generate a description of the image.",
    "tool_name": "Image_Captioner_Tool"
}}
```

Example for Finding Square Root:
Context: {{"query": "What is the square root of the number of objects in the image?", "available_tools": ["Object_Detector_Tool", "Python_Calculator_Tool"], "image": "objects.png", "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"]}}

Example Output:
```json
{{
    "justification": "We have detected 4 objects in the image. Now we need to find the square root of 4.",
    "context": {{}},
    "sub_goal": "Calculate the square root of 4",
    "tool_name": "Python_Calculator_Tool"
}}
```

Your output MUST be a single, valid JSON object with the following keys:
- justification (string): Your reasoning.
- context (dict): A dictionary containing relevant information.
- sub_goal (string): The specific instruction for the tool.
- tool_name (string): The EXACT name of the tool to use.

Do NOT include any text outside of the JSON object.
"""

    def _get_verifier_prompt(self) -> str:
        """Get the prompt for the verifier agent (improved with few-shot examples)."""
        return """You are the Context Verifier in the OctoTools framework. Your role is to analyze the current context
and memory to determine if the problem is solved, if there are any inconsistencies, or if further steps are needed.

Output your response in JSON format:

Expected output structure:
```json
{
    "completeness": "Indicate whether the query is fully, partially, or not answered.",
    "inconsistencies": "List any inconsistencies found in the context or memory.",
    "verification_needs": "List any information that needs further verification.",
    "ambiguities": "List any ambiguities found in the context or memory.",
    "stop_signal": true/false
}
```

Example Input:
Context: { "last_result": { "result": "Caption: The image shows a cat." } }
Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Image_Captioner_Tool" } } ]

Example Output:
```json
{
    "completeness": "partial",
    "inconsistencies": [],
    "verification_needs": ["Object detection to confirm the presence of a cat."],
    "ambiguities": [],
    "stop_signal": false
}
```

Another Example:
Context: { "last_result": { "result": ["Detected object: cat"] } }
Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } } ]

Example Output:
```json
{
    "completeness": "yes",
    "inconsistencies": [],
    "verification_needs": [],
    "ambiguities": [],
    "stop_signal": true
}
```

Square Root Example:
Context: {
    "query": "What is the square root of the number of objects in the image?",
    "image": "example.png",
    "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"],
    "Python_Calculator_Tool_result": "Result of 4**0.5 is 2.0"
}
Memory: [
    { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } },
    { "component": "Action Predictor", "result": { "tool_name": "Python_Calculator_Tool" } }
]

Example Output:
```json
{
    "completeness": "yes",
    "inconsistencies": [],
    "verification_needs": [],
    "ambiguities": [],
    "stop_signal": true
}
```
"""

    def _get_summarizer_prompt(self) -> str:
        """Get the prompt for the summarizer agent (improved with few-shot examples)."""
        return """You are the Solution Summarizer in the OctoTools framework. Your role is to synthesize the final
answer to the user's query based on the complete trajectory of actions and results.

Output your response in JSON format:

Expected output structure:
```json
{
    "final_answer": "Provide a clear and concise answer to the original query."
}
```
Example Input:
Memory: [
    {"component": "Query Analyzer", "result": {"summary": "Find the capital of France."}},
    {"component": "Action Predictor", "result": {"tool_name": "Web_Search_Tool"}},
    {"component": "Tool Execution", "result": {"result": "The capital of France is Paris."}}
]

Example Output:
```json
{
    "final_answer": "The capital of France is Paris."
}
```

Square Root Example:
Memory: [
    {"component": "Query Analyzer", "result": {"summary": "Find the square root of the number of objects in the image."}},
    {"component": "Action Predictor", "result": {"tool_name": "Object_Detector_Tool", "sub_goal": "Detect objects in the image"}},
    {"component": "Tool Execution", "result": {"result": ["object1", "object2", "object3", "object4"]}},
    {"component": "Action Predictor", "result": {"tool_name": "Python_Calculator_Tool", "sub_goal": "Calculate the square root of 4"}},
    {"component": "Tool Execution", "result": {"result": "Result of 4**0.5 is 2.0"}}
]

Example Output:
```json
{
    "final_answer": "The square root of the number of objects in the image is 2.0. There are 4 objects in the image, and the square root of 4 is 2.0."
}
```
"""

    def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
        """Safely parse JSON, falling back to brace matching to extract an object."""
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            logger.warning(f"JSONDecodeError: Attempting to extract JSON from: {json_str}")
            try:
                # More robust JSON extraction: scan for the first balanced brace span
                def extract_json(s):
                    """Return the first balanced {...} block in s, or None."""
                    stack = []
                    start = -1
                    for i, c in enumerate(s):
                        if c == "{":
                            if not stack:
                                start = i
                            stack.append(c)
                        elif c == "}":
                            if stack:
                                stack.pop()
                                if not stack and start != -1:
                                    return s[start : i + 1]
                    return None

                extracted_json = extract_json(json_str)
                if extracted_json:
                    logger.info(f"Extracted JSON: {extracted_json}")
                    return json.loads(extracted_json)
                else:
                    logger.error("Failed to extract JSON via brace matching.")
                    return {"error": "Failed to parse JSON", "content": json_str}
            except Exception as e:
                logger.exception(f"Error during JSON extraction: {e}")
                return {"error": "Failed to parse JSON", "content": json_str}

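    # A minimal sketch of the fallback extraction (the `swarm` instance and
    # inputs below are hypothetical): LLMs often wrap JSON in prose, so the
    # brace scanner recovers the first balanced object and re-parses it.
    #
    # swarm._safely_parse_json('Here you go: {"a": {"b": 1}} Thanks!')
    # # -> {"a": {"b": 1}}
    # swarm._safely_parse_json("no json here")
    # # -> {"error": "Failed to parse JSON", "content": "no json here"}
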
    def _execute_tool(self, tool_name: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Executes a tool based on its name and provided context."""
        if tool_name not in self.tools:
            return {"error": f"Tool '{tool_name}' not found."}

        tool = self.tools[tool_name]
        try:
            # For the Python Calculator tool, derive an expression from
            # Object Detector results when one is not already supplied
            if tool_name == "Python_Calculator_Tool":
                # Check for object detector results
                object_detector_result = context.get("Object_Detector_Tool_result")
                if object_detector_result and isinstance(object_detector_result, list):
                    # Calculate the number of objects
                    num_objects = len(object_detector_result)
                    if "sub_goal" in context and "Calculate the square root" in context["sub_goal"]:
                        context["expression"] = f"{num_objects}**0.5"
                    elif "expression" not in context:
                        # Default to the square root if no expression is specified
                        context["expression"] = f"{num_objects}**0.5"

            # Filter context: only pass inputs the tool declares in its metadata
            valid_inputs = {
                k: v for k, v in context.items() if k in tool.metadata.get("input_types", {})
            }
            result = tool.execute(**valid_inputs)
            return {"result": result}
        except Exception as e:
            logger.exception(f"Error executing tool {tool_name}: {e}")
            return {"error": str(e)}

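    # Sketch of the context filtering above (values are illustrative): only
    # keys declared in the tool's metadata["input_types"] reach execute_func,
    # so loop bookkeeping like "sub_goal" or "query" never leaks into a tool
    # that does not declare it.
    #
    # context = {"query": "...", "sub_goal": "...", "expression": "4**0.5"}
    # input_types = {"expression": "str"}   # Python_Calculator_Tool metadata
    # valid_inputs = {k: v for k, v in context.items() if k in input_types}
    # # valid_inputs == {"expression": "4**0.5"}
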
    def _run_agent(self, agent: Agent, input_prompt: str) -> Dict[str, Any]:
        """Runs a swarms agent and parses its output as JSON."""
        try:
            # Construct the full input, including the system prompt
            full_input = f"{agent.system_prompt}\n\n{input_prompt}"

            # Run the agent and capture the output
            agent_response = agent.run(full_input)

            logger.info(f"DEBUG: Raw agent response: {agent_response}")

            # Extract the LLM's response (assuming the agent returns it directly)
            response_text = agent_response

            # Try to parse the response as JSON
            parsed_response = self._safely_parse_json(response_text)

            return parsed_response

        except Exception as e:
            logger.exception(f"Error running agent {agent.agent_name}: {e}")
            return {"error": f"Agent {agent.agent_name} failed: {str(e)}"}

    def run(self, query: str, image: Optional[str] = None) -> Dict[str, Any]:
        """Execute the task through the multi-agent workflow."""
        logger.info(f"Starting task: {query}")

        try:
            # Step 1: Query Analysis (High-Level Planning)
            planner_input = (
                "Analyze the following query and determine the necessary skills and"
                f" relevant tools: {query}"
            )
            query_analysis = self._run_agent(self.planner, planner_input)

            if "error" in query_analysis:
                return {
                    "error": f"Planner query analysis failed: {query_analysis['error']}",
                    "trajectory": self.memory,
                    "conversation": self.conversation.return_history_as_string(),
                }

            self.memory.append(
                {"step": 0, "component": "Query Analyzer", "result": query_analysis}
            )
            self.conversation.add(
                role=self.planner.agent_name, content=json.dumps(query_analysis)
            )

            # Initialize context with the query and image (if provided)
            context = {"query": query}
            if image:
                context["image"] = image

            # Add available tools to context
            if "relevant_tools" in query_analysis:
                context["available_tools"] = query_analysis["relevant_tools"]
            else:
                # If no relevant tools are specified, make all tools available
                context["available_tools"] = list(self.tools.keys())

            step_count = 1

            # Step 2: Iterative Action-Execution Loop
            while step_count <= self.max_iterations:
                logger.info(f"Starting iteration {step_count} of {self.max_iterations}")

                # Step 2a: Action Prediction (Low-Level Planning)
                action_planner_input = (
                    f"Current Context: {json.dumps(context)}\nAvailable Tools:"
                    f" {', '.join(context.get('available_tools', list(self.tools.keys())))}\nPlan the"
                    " next step."
                )
                action = self._run_agent(self.planner, action_planner_input)
                if "error" in action:
                    logger.error(f"Error in action prediction: {action['error']}")
                    return {
                        "error": f"Planner action prediction failed: {action['error']}",
                        "trajectory": self.memory,
                        "conversation": self.conversation.return_history_as_string(),
                    }
                self.memory.append(
                    {"step": step_count, "component": "Action Predictor", "result": action}
                )
                self.conversation.add(role=self.planner.agent_name, content=json.dumps(action))

                # Input validation for the action (relaxed)
                if not isinstance(action, dict) or "tool_name" not in action or "sub_goal" not in action:
                    error_msg = (
                        "Action prediction did not return the required fields (tool_name,"
                        " sub_goal) or was not a dictionary."
                    )
                    logger.error(error_msg)
                    self.memory.append(
                        {"step": step_count, "component": "Error", "result": error_msg}
                    )
                    break

                # Step 2b: Execute Tool
                tool_execution_context = {
                    **context,
                    **action.get("context", {}),  # Add any additional context
                    "sub_goal": action["sub_goal"],  # Pass sub_goal to the tool
                }

                tool_result = self._execute_tool(action["tool_name"], tool_execution_context)

                self.memory.append(
                    {
                        "step": step_count,
                        "component": "Tool Execution",
                        "result": tool_result,
                    }
                )

                # Step 2c: Context Update - store the result under a descriptive key
                if "result" in tool_result:
                    context[f"{action['tool_name']}_result"] = tool_result["result"]
                if "error" in tool_result:
                    context[f"{action['tool_name']}_error"] = tool_result["error"]

                # Step 2d: Context Verification
                verifier_input = (
                    f"Current Context: {json.dumps(context)}\nMemory:"
                    f" {json.dumps(self.memory)}\nQuery: {query}"
                )
                verification = self._run_agent(self.verifier, verifier_input)
                if "error" in verification:
                    return {
                        "error": f"Verifier failed: {verification['error']}",
                        "trajectory": self.memory,
                        "conversation": self.conversation.return_history_as_string(),
                    }

                self.memory.append(
                    {
                        "step": step_count,
                        "component": "Context Verifier",
                        "result": verification,
                    }
                )
                self.conversation.add(role=self.verifier.agent_name, content=json.dumps(verification))

                # Check for a stop signal from the Verifier
                if verification.get("stop_signal") is True:
                    logger.info("Received stop signal from verifier. Stopping iterations.")
                    break

                # Safety mechanism - stop if the same tool keeps being selected
                same_tool_count = sum(
                    1 for m in self.memory
                    if m.get("component") == "Action Predictor"
                    and m.get("result", {}).get("tool_name") == action.get("tool_name")
                )

                if same_tool_count > 3:
                    logger.warning(f"Tool {action.get('tool_name')} used more than 3 times. Forcing stop.")
                    break

                step_count += 1

            # Step 3: Solution Summarization
            summarizer_input = f"Complete Trajectory: {json.dumps(self.memory)}\nOriginal Query: {query}"

            summarization = self._run_agent(self.summarizer, summarizer_input)
            if "error" in summarization:
                return {
                    "error": f"Summarizer failed: {summarization['error']}",
                    "trajectory": self.memory,
                    "conversation": self.conversation.return_history_as_string(),
                }
            self.conversation.add(role=self.summarizer.agent_name, content=json.dumps(summarization))

            return {
                "final_answer": summarization.get("final_answer", "No answer found."),
                "trajectory": self.memory,
                "conversation": self.conversation.return_history_as_string(),
            }

        except Exception as e:
            logger.exception(f"Unexpected error in run method: {e}")
            return {
                "error": str(e),
                "trajectory": self.memory,
                "conversation": self.conversation.return_history_as_string(),
            }

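    # Shape of the trajectory collected by run() (values are illustrative):
    # each loop iteration appends an Action Predictor, Tool Execution, and
    # Context Verifier entry, which the summarizer then consumes as a whole.
    #
    # [
    #     {"step": 0, "component": "Query Analyzer", "result": {...}},
    #     {"step": 1, "component": "Action Predictor", "result": {"tool_name": ..., "sub_goal": ...}},
    #     {"step": 1, "component": "Tool Execution", "result": {"result": ...}},
    #     {"step": 1, "component": "Context Verifier", "result": {"stop_signal": ...}},
    # ]
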
    def save_state(self) -> None:
        """Save the current state of all agents."""
        for agent in [self.planner, self.verifier, self.summarizer]:
            try:
                agent.save_state()
            except Exception as e:
                logger.error(f"Error saving state for {agent.agent_name}: {str(e)}")

    def load_state(self) -> None:
        """Load the saved state of all agents."""
        for agent in [self.planner, self.verifier, self.summarizer]:
            try:
                agent.load_state()
            except Exception as e:
                logger.error(f"Error loading state for {agent.agent_name}: {str(e)}")


# --- Example Usage ---


# Define dummy tool functions (replace with actual implementations)
def image_captioner_execute(image: str, prompt: str = "Describe the image", **kwargs) -> str:
    """Dummy image captioner."""
    print(f"image_captioner_execute called with image: {image}, prompt: {prompt}")
    return f"Caption for {image}: A descriptive caption (dummy)."


def object_detector_execute(image: str, labels: Optional[List[str]] = None, **kwargs) -> List[str]:
    """Dummy object detector; handles missing labels gracefully."""
    print(f"object_detector_execute called with image: {image}, labels: {labels}")
    if not labels:
        # Return default objects if no labels are given
        return ["object1", "object2", "object3", "object4"]
    return [f"Detected {label}" for label in labels]


def web_search_execute(query: str, **kwargs) -> str:
    """Dummy web search."""
    print(f"web_search_execute called with query: {query}")
    return f"Search results for '{query}'..."


def python_calculator_execute(expression: str, **kwargs) -> str:
    """Python calculator for basic arithmetic expressions."""
    print(f"python_calculator_execute called with: {expression}")
    try:
        # Safely evaluate only simple expressions built from numbers and basic
        # operators; the whitelist rejects names, so no builtins are reachable
        # ("math" is exposed for future use but blocked by the pattern).
        if re.match(r"^[0-9+\-*/().\s]+$", expression):
            result = eval(expression, {"__builtins__": {}, "math": math})
            return f"Result of {expression} is {result}"
        else:
            return "Error: Invalid expression for calculator."
    except Exception as e:
        return f"Error: {e}"


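# Illustrative behavior of the expression whitelist above (expected outputs
# shown as comments; "math.sqrt(4)" is rejected because letters are not in
# the pattern):
#
# python_calculator_execute("4**0.5")        # -> "Result of 4**0.5 is 2.0"
# python_calculator_execute("(3 + 5) / 2")   # -> "Result of (3 + 5) / 2 is 4.0"
# python_calculator_execute("math.sqrt(4)")  # -> "Error: Invalid expression for calculator."

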
# Create utility function to get default tools
def get_default_tools() -> List[Tool]:
    """Returns a list of default tools that can be used with OctoToolsSwarm."""
    image_captioner = Tool(
        name="Image_Captioner_Tool",
        description="Generates a caption for an image.",
        metadata={
            "input_types": {"image": "str", "prompt": "str"},
            "output_type": "str",
            "limitations": "May struggle with complex scenes or ambiguous objects.",
            "best_practices": "Use with clear, well-lit images. Provide specific prompts for better results.",
        },
        execute_func=image_captioner_execute,
    )

    object_detector = Tool(
        name="Object_Detector_Tool",
        description="Detects objects in an image.",
        metadata={
            "input_types": {"image": "str", "labels": "list"},
            "output_type": "list",
            "limitations": "Accuracy depends on the quality of the image and the clarity of the objects.",
            "best_practices": "Provide a list of specific object labels to detect. Use high-resolution images.",
        },
        execute_func=object_detector_execute,
    )

    web_search = Tool(
        name="Web_Search_Tool",
        description="Performs a web search.",
        metadata={
            "input_types": {"query": "str"},
            "output_type": "str",
            "limitations": "May not find specific or niche information.",
            "best_practices": "Use specific and descriptive keywords for better results.",
        },
        execute_func=web_search_execute,
    )

    calculator = Tool(
        name="Python_Calculator_Tool",
        description="Evaluates a Python expression.",
        metadata={
            "input_types": {"expression": "str"},
            "output_type": "str",
            "limitations": "Cannot handle complex mathematical functions or libraries.",
            "best_practices": "Use for basic arithmetic and simple calculations.",
        },
        execute_func=python_calculator_execute,
    )

    return [image_captioner, object_detector, web_search, calculator]


# Only execute the example when this script is run directly
# if __name__ == "__main__":
#     print("Running OctoToolsSwarm example...")

#     # Create an OctoToolsSwarm agent with default tools
#     tools = get_default_tools()
#     agent = OctoToolsSwarm(tools=tools)

#     # Example query
#     query = "What is the square root of the number of objects in this image?"

#     # Create a dummy image file for testing if it doesn't exist
#     image_path = "example.png"
#     if not os.path.exists(image_path):
#         with open(image_path, "w") as f:
#             f.write("Dummy image content")
#         print(f"Created dummy image file: {image_path}")

#     # Run the agent
#     result = agent.run(query, image=image_path)

#     # Display results
#     print("\n=== FINAL ANSWER ===")
#     print(result["final_answer"])

#     print("\n=== TRAJECTORY SUMMARY ===")
#     for step in result["trajectory"]:
#         print(f"Step {step.get('step', 'N/A')}: {step.get('component', 'Unknown')}")

#     print("\nOctoToolsSwarm example completed.")

@@ -1,580 +0,0 @@

"""
|
|
||||||
TalkHier: A hierarchical multi-agent framework for content generation and refinement.
|
|
||||||
Implements structured communication and evaluation protocols.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from swarms import Agent
|
|
||||||
from swarms.structs.conversation import Conversation
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRole(Enum):
    """Defines the possible roles for agents in the system."""

    SUPERVISOR = "supervisor"
    GENERATOR = "generator"
    EVALUATOR = "evaluator"
    REVISOR = "revisor"


@dataclass
class CommunicationEvent:
    """Represents a structured communication event between agents."""

    message: str
    background: Optional[str] = None
    intermediate_output: Optional[Dict[str, Any]] = None


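# A usage sketch (values are hypothetical): CommunicationEvent packages a
# message with optional background and intermediate output so inter-agent
# messages carry structure rather than bare strings. Note it is not yet
# wired into the TalkHier flow below.
#
# event = CommunicationEvent(
#     message="Revise the draft for clarity.",
#     background="Evaluators flagged readability issues.",
#     intermediate_output={"overall_score": 0.62},
# )

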
class TalkHier:
    """
    A hierarchical multi-agent system for content generation and refinement.

    Implements the TalkHier framework with structured communication protocols
    and hierarchical refinement processes.

    Attributes:
        max_iterations: Maximum number of refinement iterations
        quality_threshold: Minimum score required for content acceptance
        model_name: Name of the LLM model to use
        base_path: Path for saving agent states
    """

    def __init__(
        self,
        max_iterations: int = 3,
        quality_threshold: float = 0.8,
        model_name: str = "gpt-4",
        base_path: Optional[str] = None,
        return_string: bool = False,
    ):
        """Initialize the TalkHier system."""
        self.max_iterations = max_iterations
        self.quality_threshold = quality_threshold
        self.model_name = model_name
        self.return_string = return_string
        self.base_path = Path(base_path) if base_path else Path("./agent_states")
        self.base_path.mkdir(exist_ok=True)

        # Initialize agents
        self._init_agents()

        # Create conversation
        self.conversation = Conversation()

    def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
        """
        Safely parse a JSON string, handling various formats and potential errors.

        Args:
            json_str: String containing JSON data

        Returns:
            Parsed dictionary
        """
        try:
            # Try direct JSON parsing
            return json.loads(json_str)
        except json.JSONDecodeError:
            try:
                # Try extracting JSON from a potential text wrapper
                json_match = re.search(r"\{.*\}", json_str, re.DOTALL)
                if json_match:
                    return json.loads(json_match.group())
                # Try extracting from markdown code blocks
                code_block_match = re.search(
                    r"```(?:json)?\s*(\{.*?\})\s*```",
                    json_str,
                    re.DOTALL,
                )
                if code_block_match:
                    return json.loads(code_block_match.group(1))
            except Exception as e:
                logger.warning(f"Failed to extract JSON: {str(e)}")

        # Fallback: create a structured dict from the raw text
        return {
            "content": json_str,
            "metadata": {
                "parsed": False,
                "timestamp": str(datetime.now()),
            },
        }

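    # Illustrative behavior (the `talkhier` instance and strings below are
    # hypothetical): the greedy {.*} pass recovers a JSON object from
    # surrounding prose; because it also matches inside markdown fences, the
    # fenced-block pass only fires when the first search finds no braces.
    #
    # talkhier._safely_parse_json('Sure! {"scores": {"overall": 0.9}} Done.')
    # # -> {"scores": {"overall": 0.9}}
    # talkhier._safely_parse_json("plain text, no braces")
    # # -> {"content": "plain text, no braces", "metadata": {"parsed": False, ...}}
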
    def _init_agents(self) -> None:
        """Initialize all agents with their specific roles and prompts."""
        # Main supervisor agent
        self.main_supervisor = Agent(
            agent_name="Main-Supervisor",
            system_prompt=self._get_supervisor_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "main_supervisor.json"),
            verbose=True,
        )

        # Generator agent
        self.generator = Agent(
            agent_name="Content-Generator",
            system_prompt=self._get_generator_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "generator.json"),
            verbose=True,
        )

        # Evaluators
        self.evaluators = [
            Agent(
                agent_name=f"Evaluator-{i}",
                system_prompt=self._get_evaluator_prompt(i),
                model_name=self.model_name,
                max_loops=1,
                saved_state_path=str(self.base_path / f"evaluator_{i}.json"),
                verbose=True,
            )
            for i in range(3)
        ]

        # Revisor agent
        self.revisor = Agent(
            agent_name="Content-Revisor",
            system_prompt=self._get_revisor_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "revisor.json"),
            verbose=True,
        )

    def _get_supervisor_prompt(self) -> str:
        """Get the prompt for the supervisor agent."""
        return """You are a Supervisor agent responsible for orchestrating the content generation process. Your role is to analyze tasks, develop strategies, and coordinate other agents effectively.

You must carefully analyze each task to understand:
- The core objectives and requirements
- Target audience and their needs
- Complexity level and scope
- Any constraints or special considerations

Based on your analysis, develop a clear strategy that:
- Breaks down the task into manageable steps
- Identifies which agents are best suited for each step
- Anticipates potential challenges
- Sets clear success criteria

Output all responses in strict JSON format:
{
    "thoughts": {
        "task_analysis": "Detailed analysis of requirements, audience, scope, and constraints",
        "strategy": "Step-by-step plan including agent allocation and success metrics",
        "concerns": "Potential challenges, edge cases, and mitigation strategies"
    },
    "next_action": {
        "agent": "Specific agent to engage (Generator, Evaluator, or Revisor)",
        "instruction": "Detailed instructions including context, requirements, and expected output"
    }
}"""

    def _get_generator_prompt(self) -> str:
        """Get the prompt for the generator agent."""
        return """You are a Generator agent responsible for creating high-quality, original content. Your role is to produce content that is engaging, informative, and tailored to the target audience.

When generating content:
- Thoroughly research and fact-check all information
- Structure content logically with clear flow
- Use appropriate tone and language for the target audience
- Include relevant examples and explanations
- Ensure content is original and plagiarism-free
- Consider SEO best practices where applicable

Output all responses in strict JSON format:
{
    "content": {
        "main_body": "The complete generated content with proper formatting and structure",
        "metadata": {
            "word_count": "Accurate word count of main body",
            "target_audience": "Detailed audience description",
            "key_points": ["List of main points covered"],
            "sources": ["List of reference sources if applicable"],
            "readability_level": "Estimated reading level",
            "tone": "Description of content tone"
        }
    }
}"""

    def _get_evaluator_prompt(self, evaluator_id: int) -> str:
        """Get the prompt for an evaluator agent."""
        return f"""You are Evaluator {evaluator_id}, responsible for critically assessing content quality. Your evaluation must be thorough, objective, and constructive.

Evaluate content across multiple dimensions:
- Accuracy: factual correctness, source reliability
- Clarity: readability, organization, flow
- Coherence: logical consistency, argument structure
- Engagement: interest level, relevance
- Completeness: topic coverage, depth
- Technical quality: grammar, spelling, formatting
- Audience alignment: appropriate level and tone

Output all responses in strict JSON format:
{{
    "scores": {{
        "overall": "0.0-1.0 composite score",
        "categories": {{
            "accuracy": "0.0-1.0 score with evidence",
            "clarity": "0.0-1.0 score with examples",
            "coherence": "0.0-1.0 score with analysis",
            "engagement": "0.0-1.0 score with justification",
            "completeness": "0.0-1.0 score with gaps identified",
            "technical_quality": "0.0-1.0 score with issues noted",
            "audience_alignment": "0.0-1.0 score with reasoning"
        }}
    }},
    "feedback": [
        "Specific, actionable improvement suggestions",
        "Examples of issues found",
        "Recommendations for enhancement"
    ],
    "strengths": ["Notable positive aspects"],
    "weaknesses": ["Areas needing improvement"]
}}"""

    def _get_revisor_prompt(self) -> str:
        """Get the prompt for the revisor agent."""
        return """You are a Revisor agent responsible for improving content based on evaluator feedback. Your role is to enhance content while maintaining its core message and purpose.

When revising content:
- Address all evaluator feedback systematically
- Maintain consistency in tone and style
- Preserve accurate information
- Enhance clarity and flow
- Fix technical issues
- Optimize for target audience
- Track all changes made

Output all responses in strict JSON format:
{
    "revised_content": {
        "main_body": "Complete revised content incorporating all improvements",
        "metadata": {
            "word_count": "Updated word count",
            "changes_made": [
                "Detailed list of specific changes and improvements",
                "Reasoning for each major revision",
                "Feedback points addressed"
            ],
            "improvement_summary": "Overview of main enhancements",
            "preserved_elements": ["Key elements maintained from original"],
            "revision_approach": "Strategy used for revisions"
        }
    }
}"""

    def _evaluate_content(self, content: Union[str, Dict]) -> Dict[str, Any]:
        """
        Coordinate the evaluation of content across multiple evaluators.

        Args:
            content: Content to evaluate (string or dict)

        Returns:
            Combined evaluation results
        """
        try:
            # Ensure content is in the correct format
            content_dict = (
                self._safely_parse_json(content)
                if isinstance(content, str)
                else content
            )

            # Collect evaluations
            evaluations = []
            for evaluator in self.evaluators:
                try:
                    eval_response = evaluator.run(json.dumps(content_dict))

                    self.conversation.add(
                        role=evaluator.agent_name,
                        content=eval_response,
                    )

                    eval_data = self._safely_parse_json(eval_response)
                    evaluations.append(eval_data)
                except Exception as e:
                    logger.warning(f"Evaluator error: {str(e)}")
                    evaluations.append(self._get_fallback_evaluation())

            # Aggregate results
            return self._aggregate_evaluations(evaluations)

        except Exception as e:
            logger.error(f"Evaluation error: {str(e)}")
            return self._get_fallback_evaluation()

    def _get_fallback_evaluation(self) -> Dict[str, Any]:
        """Get a safe fallback evaluation result."""
        return {
            "scores": {
                "overall": 0.5,
                "categories": {
                    "accuracy": 0.5,
                    "clarity": 0.5,
                    "coherence": 0.5,
                },
            },
            "feedback": ["Evaluation failed"],
            "metadata": {
                "timestamp": str(datetime.now()),
                "is_fallback": True,
            },
        }

    def _aggregate_evaluations(self, evaluations: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Aggregate multiple evaluation results into a single evaluation.

        Args:
            evaluations: List of evaluation results

        Returns:
            Combined evaluation
        """
        # Calculate average scores
        overall_scores = []
        accuracy_scores = []
        clarity_scores = []
        coherence_scores = []
        all_feedback = []

        for eval_data in evaluations:
            try:
                scores = eval_data.get("scores", {})
                overall_scores.append(scores.get("overall", 0.5))

                categories = scores.get("categories", {})
                accuracy_scores.append(categories.get("accuracy", 0.5))
                clarity_scores.append(categories.get("clarity", 0.5))
                coherence_scores.append(categories.get("coherence", 0.5))

                all_feedback.extend(eval_data.get("feedback", []))
            except Exception as e:
                logger.warning(f"Error aggregating evaluation: {str(e)}")

        def safe_mean(scores: List[float]) -> float:
            return sum(scores) / len(scores) if scores else 0.5

        return {
            "scores": {
                "overall": safe_mean(overall_scores),
                "categories": {
                    "accuracy": safe_mean(accuracy_scores),
                    "clarity": safe_mean(clarity_scores),
                    "coherence": safe_mean(coherence_scores),
                },
            },
            "feedback": list(set(all_feedback)),  # Remove duplicates
            "metadata": {
                "evaluator_count": len(evaluations),
                "timestamp": str(datetime.now()),
            },
        }

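    # Worked example of the aggregation above (numbers are illustrative):
    # three evaluators returning overall scores of 0.9, 0.7, and 0.5
    # aggregate to safe_mean([0.9, 0.7, 0.5]) == 0.7, while an empty score
    # list (no usable evaluations) falls back to the neutral 0.5. This
    # assumes evaluators emit numeric scores; non-numeric values would make
    # sum() raise and be caught upstream in _evaluate_content.
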
    def run(self, task: str) -> Union[Dict[str, Any], str]:
        """
        Generate and iteratively refine content based on the given task.

        Args:
            task: Content generation task description

        Returns:
            Dictionary containing final content and metadata, or the
            conversation history as a string if return_string is set
        """
        logger.info(f"Starting content generation for task: {task}")

        try:
            # Get initial direction from the supervisor
            supervisor_response = self.main_supervisor.run(task)

            self.conversation.add(
                role=self.main_supervisor.agent_name,
                content=supervisor_response,
            )

            supervisor_data = self._safely_parse_json(supervisor_response)

            # Generate initial content
            generator_response = self.generator.run(
                json.dumps(supervisor_data.get("next_action", {}))
            )

            self.conversation.add(
                role=self.generator.agent_name,
                content=generator_response,
            )

            current_content = self._safely_parse_json(generator_response)

            for iteration in range(self.max_iterations):
                logger.info(f"Starting iteration {iteration + 1}")

                # Evaluate the current content
                evaluation = self._evaluate_content(current_content)

                # Check if the quality threshold is met
                if evaluation["scores"]["overall"] >= self.quality_threshold:
                    logger.info("Quality threshold met, returning content")
                    return {
                        "content": current_content.get("content", {}).get("main_body", ""),
                        "final_score": evaluation["scores"]["overall"],
                        "iterations": iteration + 1,
                        "metadata": {
                            "content_metadata": current_content.get("content", {}).get("metadata", {}),
                            "evaluation": evaluation,
                        },
                    }

                # Revise content if needed
                revision_input = {
                    "content": current_content,
                    "evaluation": evaluation,
                }

                revision_response = self.revisor.run(json.dumps(revision_input))
                revised = self._safely_parse_json(revision_response)
                # The revisor nests its output under "revised_content";
                # normalize it back to the generator's "content" shape so the
                # reads above and below keep working
                if "revised_content" in revised:
                    revised = {"content": revised["revised_content"]}
                current_content = revised

                self.conversation.add(
                    role=self.revisor.agent_name,
                    content=revision_response,
                )

            logger.warning("Max iterations reached without meeting quality threshold")

        except Exception as e:
            logger.error(f"Error in run: {str(e)}")
            current_content = {"content": {"main_body": f"Error: {str(e)}"}}
            evaluation = self._get_fallback_evaluation()

        if self.return_string:
            return self.conversation.return_history_as_string()
        else:
            return {
                "content": current_content.get("content", {}).get("main_body", ""),
                "final_score": evaluation["scores"]["overall"],
                "iterations": self.max_iterations,
                "metadata": {
                    "content_metadata": current_content.get("content", {}).get("metadata", {}),
                    "evaluation": evaluation,
                    "error": "Max iterations reached",
                },
            }

    def save_state(self) -> None:
        """Save the current state of all agents."""
        for agent in [
            self.main_supervisor,
            self.generator,
            *self.evaluators,
            self.revisor,
        ]:
            try:
                agent.save_state()
            except Exception as e:
                logger.error(f"Error saving state for {agent.agent_name}: {str(e)}")

    def load_state(self) -> None:
        """Load the saved state of all agents."""
        for agent in [
            self.main_supervisor,
            self.generator,
            *self.evaluators,
            self.revisor,
        ]:
            try:
                agent.load_state()
            except Exception as e:
                logger.error(f"Error loading state for {agent.agent_name}: {str(e)}")


if __name__ == "__main__":
    # Example usage
    try:
        talkhier = TalkHier(
            max_iterations=1,
            quality_threshold=0.8,
            model_name="gpt-4o",
            return_string=True,
        )

        task = "Write a comprehensive explanation of quantum computing for beginners"
        result = talkhier.run(task)
        print(result)

        # print(f"Final content: {result['content']}")
        # print(f"Quality score: {result['final_score']}")
        # print(f"Iterations: {result['iterations']}")

    except Exception as e:
        logger.error(f"Error in main execution: {str(e)}")