Merge pull request #787 from harshalmore31/octotools

Implemented OctoTools Framework with Swarms
Kye Gomez 2 months ago committed by GitHub
commit ddb4119c04

@@ -72,7 +72,6 @@ from swarms.structs.swarming_architectures import (
staircase_swarm,
star_swarm,
)
from swarms.structs.swarms_api import (
SwarmsAPIClient,
SwarmRequest,
@@ -81,6 +80,8 @@ from swarms.structs.swarms_api import (
SwarmValidationError,
AgentInput,
)
from swarms.structs.talk_hier import TalkHier, AgentRole, CommunicationEvent
from swarms.structs.octotools import OctoToolsSwarm, Tool, ToolType, get_default_tools
__all__ = [
"Agent",
@@ -147,6 +148,13 @@ __all__ = [
"MultiAgentRouter",
"MemeAgentGenerator",
"ModelRouter",
"OctoToolsSwarm",
"Tool",
"ToolType",
"get_default_tools",
"TalkHier",
"AgentRole",
"CommunicationEvent",
"SwarmsAPIClient",
"SwarmRequest",
"SwarmAuthenticationError",

@@ -0,0 +1,725 @@
"""
OctoToolsSwarm: A multi-agent system for complex reasoning.
Implements the OctoTools framework using swarms.
"""
import json
import logging
import os
import re
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import math  # passed into the calculator tool's restricted eval environment
from dotenv import load_dotenv
from swarms import Agent
from swarms.structs.conversation import Conversation
# from exa_search import exa_search as web_search_execute
# Load environment variables
load_dotenv()
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ToolType(Enum):
"""Defines the types of tools available."""
IMAGE_CAPTIONER = "image_captioner"
OBJECT_DETECTOR = "object_detector"
WEB_SEARCH = "web_search"
PYTHON_CALCULATOR = "python_calculator"
# Add more tool types as needed
@dataclass
class Tool:
"""
Represents an external tool.
Attributes:
name: Unique name of the tool.
description: Description of the tool's function.
metadata: Dictionary containing tool metadata.
execute_func: Callable function that executes the tool's logic.
"""
name: str
description: str
metadata: Dict[str, Any]
execute_func: Callable
def execute(self, **kwargs):
"""Executes the tool's logic, handling potential errors."""
try:
return self.execute_func(**kwargs)
except Exception as e:
logger.error(f"Error executing tool {self.name}: {str(e)}")
return {"error": str(e)}
class AgentRole(Enum):
"""Defines the roles for agents in the OctoTools system."""
PLANNER = "planner"
VERIFIER = "verifier"
SUMMARIZER = "summarizer"
class OctoToolsSwarm:
"""
A multi-agent system implementing the OctoTools framework.
Attributes:
model_name: Name of the LLM model to use.
max_iterations: Maximum number of action-execution iterations.
base_path: Path for saving agent states.
tools: List of available Tool objects.
"""
def __init__(
self,
tools: List[Tool],
model_name: str = "gemini/gemini-2.0-flash",
max_iterations: int = 10,
base_path: Optional[str] = None,
):
"""Initialize the OctoToolsSwarm system."""
self.model_name = model_name
self.max_iterations = max_iterations
self.base_path = Path(base_path) if base_path else Path("./octotools_states")
self.base_path.mkdir(exist_ok=True)
self.tools = {tool.name: tool for tool in tools} # Store tools in a dictionary
# Initialize agents
self._init_agents()
# Create conversation tracker and memory
self.conversation = Conversation()
self.memory = [] # Store the trajectory
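# Each memory entry records one trajectory step, e.g.
# {"step": 1, "component": "Action Predictor", "result": {...}}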
def _init_agents(self) -> None:
"""Initialize all agents with their specific roles and prompts."""
# Planner agent
self.planner = Agent(
agent_name="OctoTools-Planner",
system_prompt=self._get_planner_prompt(),
model_name=self.model_name,
max_loops=3,
saved_state_path=str(self.base_path / "planner.json"),
verbose=True,
)
# Verifier agent
self.verifier = Agent(
agent_name="OctoTools-Verifier",
system_prompt=self._get_verifier_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(self.base_path / "verifier.json"),
verbose=True,
)
# Summarizer agent
self.summarizer = Agent(
agent_name="OctoTools-Summarizer",
system_prompt=self._get_summarizer_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(self.base_path / "summarizer.json"),
verbose=True,
)
def _get_planner_prompt(self) -> str:
"""Get the prompt for the planner agent (Improved with few-shot examples)."""
tool_descriptions = "\n".join(
[
f"- {tool_name}: {self.tools[tool_name].description}"
for tool_name in self.tools
]
)
return f"""You are the Planner in the OctoTools framework. Your role is to analyze the user's query,
identify required skills, suggest relevant tools, and plan the steps to solve the problem.
1. **Analyze the user's query:** Understand the requirements and identify the necessary skills and potentially relevant tools.
2. **Perform high-level planning:** Create a rough outline of how tools might be used to solve the problem.
3. **Perform low-level planning (action prediction):** At each step, select the best tool to use and formulate a specific sub-goal for that tool, considering the current context.
Available Tools:
{tool_descriptions}
Output your response in JSON format. Here are examples for different stages:
**Query Analysis (High-Level Planning):**
Example Input:
Query: "What is the capital of France?"
Example Output:
```json
{{
"summary": "The user is asking for the capital of France.",
"required_skills": ["knowledge retrieval"],
"relevant_tools": ["Web_Search_Tool"]
}}
```
**Action Prediction (Low-Level Planning):**
Example Input:
Context: {{ "query": "What is the capital of France?", "available_tools": ["Web_Search_Tool"] }}
Example Output:
```json
{{
"justification": "The Web_Search_Tool can be used to directly find the capital of France.",
"context": {{}},
"sub_goal": "Search the web for 'capital of France'.",
"tool_name": "Web_Search_Tool"
}}
```
Another Example:
Context: {{"query": "How many objects are in the image?", "available_tools": ["Image_Captioner_Tool", "Object_Detector_Tool"], "image": "objects.png"}}
Example Output:
```json
{{
"justification": "First, get a general description of the image to understand the context.",
"context": {{ "image": "objects.png" }},
"sub_goal": "Generate a description of the image.",
"tool_name": "Image_Captioner_Tool"
}}
```
Example for Finding Square Root:
Context: {{"query": "What is the square root of the number of objects in the image?", "available_tools": ["Object_Detector_Tool", "Python_Calculator_Tool"], "image": "objects.png", "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"]}}
Example Output:
```json
{{
"justification": "We have detected 4 objects in the image. Now we need to find the square root of 4.",
"context": {{}},
"sub_goal": "Calculate the square root of 4",
"tool_name": "Python_Calculator_Tool"
}}
```
Your output MUST be a single, valid JSON object with the following keys:
- justification (string): Your reasoning.
- context (dict): A dictionary containing relevant information.
- sub_goal (string): The specific instruction for the tool.
- tool_name (string): The EXACT name of the tool to use.
Do NOT include any text outside of the JSON object.
"""
def _get_verifier_prompt(self) -> str:
"""Get the prompt for the verifier agent (Improved with few-shot examples)."""
return """You are the Context Verifier in the OctoTools framework. Your role is to analyze the current context
and memory to determine if the problem is solved, if there are any inconsistencies, or if further steps are needed.
Output your response in JSON format:
Expected output structure:
```json
{
"completeness": "Indicate whether the query is fully, partially, or not answered.",
"inconsistencies": "List any inconsistencies found in the context or memory.",
"verification_needs": "List any information that needs further verification.",
"ambiguities": "List any ambiguities found in the context or memory.",
"stop_signal": true/false
}
```
Example Input:
Context: { "last_result": { "result": "Caption: The image shows a cat." } }
Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Image_Captioner_Tool" } } ]
Example Output:
```json
{
"completeness": "partial",
"inconsistencies": [],
"verification_needs": ["Object detection to confirm the presence of a cat."],
"ambiguities": [],
"stop_signal": false
}
```
Another Example:
Context: { "last_result": { "result": ["Detected object: cat"] } }
Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } } ]
Example Output:
```json
{
"completeness": "yes",
"inconsistencies": [],
"verification_needs": [],
"ambiguities": [],
"stop_signal": true
}
```
Square Root Example:
Context: {
"query": "What is the square root of the number of objects in the image?",
"image": "example.png",
"Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"],
"Python_Calculator_Tool_result": "Result of 4**0.5 is 2.0"
}
Memory: [
{ "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } },
{ "component": "Action Predictor", "result": { "tool_name": "Python_Calculator_Tool" } }
]
Example Output:
```json
{
"completeness": "yes",
"inconsistencies": [],
"verification_needs": [],
"ambiguities": [],
"stop_signal": true
}
```
"""
def _get_summarizer_prompt(self) -> str:
"""Get the prompt for the summarizer agent (Improved with few-shot examples)."""
return """You are the Solution Summarizer in the OctoTools framework. Your role is to synthesize the final
answer to the user's query based on the complete trajectory of actions and results.
Output your response in JSON format:
Expected output structure:
```json
{
"final_answer": "Provide a clear and concise answer to the original query."
}
```
Example Input:
Memory: [
{"component": "Query Analyzer", "result": {"summary": "Find the capital of France."}},
{"component": "Action Predictor", "result": {"tool_name": "Web_Search_Tool"}},
{"component": "Tool Execution", "result": {"result": "The capital of France is Paris."}}
]
Example Output:
```json
{
"final_answer": "The capital of France is Paris."
}
```
Square Root Example:
Memory: [
{"component": "Query Analyzer", "result": {"summary": "Find the square root of the number of objects in the image."}},
{"component": "Action Predictor", "result": {"tool_name": "Object_Detector_Tool", "sub_goal": "Detect objects in the image"}},
{"component": "Tool Execution", "result": {"result": ["object1", "object2", "object3", "object4"]}},
{"component": "Action Predictor", "result": {"tool_name": "Python_Calculator_Tool", "sub_goal": "Calculate the square root of 4"}},
{"component": "Tool Execution", "result": {"result": "Result of 4**0.5 is 2.0"}}
]
Example Output:
```json
{
"final_answer": "The square root of the number of objects in the image is 2.0. There are 4 objects in the image, and the square root of 4 is 2.0."
}
```
"""
def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
"""Safely parse JSON, handling errors and using recursive descent."""
try:
return json.loads(json_str)
except json.JSONDecodeError:
logger.warning(f"JSONDecodeError: Attempting to extract JSON from: {json_str}")
try:
# More robust JSON extraction: scan for the first balanced {...} block
def extract_json(s):
stack = []
start = -1
for i, c in enumerate(s):
if c == '{':
if not stack:
start = i
stack.append(c)
elif c == '}':
if stack:
stack.pop()
if not stack and start != -1:
return s[start:i+1]
return None
extracted_json = extract_json(json_str)
if extracted_json:
logger.info(f"Extracted JSON: {extracted_json}")
return json.loads(extracted_json)
else:
logger.error("Failed to extract JSON using recursive descent.")
return {"error": "Failed to parse JSON", "content": json_str}
except Exception as e:
logger.exception(f"Error during JSON extraction: {e}")
return {"error": "Failed to parse JSON", "content": json_str}
def _execute_tool(self, tool_name: str, context: Dict[str, Any]) -> Dict[str, Any]:
"""Executes a tool based on its name and provided context."""
if tool_name not in self.tools:
return {"error": f"Tool '{tool_name}' not found."}
tool = self.tools[tool_name]
try:
# For Python Calculator tool, handle object counts from Object Detector
if tool_name == "Python_Calculator_Tool":
# Check for object detector results
object_detector_result = context.get("Object_Detector_Tool_result")
if object_detector_result and isinstance(object_detector_result, list):
# Calculate the number of objects
num_objects = len(object_detector_result)
# If the sub-goal asks for a square root, derive the expression
# from the object count (overriding any existing expression)
if "sub_goal" in context and "Calculate the square root" in context["sub_goal"]:
context["expression"] = f"{num_objects}**0.5"
elif "expression" not in context:
# Otherwise default to the square root of the object count
context["expression"] = f"{num_objects}**0.5"
# Filter context: only pass expected inputs to the tool
valid_inputs = {
k: v for k, v in context.items() if k in tool.metadata.get("input_types", {})
}
result = tool.execute(**valid_inputs)
return {"result": result}
except Exception as e:
logger.exception(f"Error executing tool {tool_name}: {e}")
return {"error": str(e)}
def _run_agent(self, agent: Agent, input_prompt: str) -> Dict[str, Any]:
"""Runs a swarms agent, handling output and JSON parsing."""
try:
# Construct the full input, including the system prompt
full_input = f"{agent.system_prompt}\n\n{input_prompt}"
# Run the agent and capture the output
agent_response = agent.run(full_input)
logger.info(f"DEBUG: Raw agent response: {agent_response}")
# Extract the LLM's response (remove conversation history, etc.)
response_text = agent_response # Assuming direct return
# Try to parse the response as JSON
parsed_response = self._safely_parse_json(response_text)
return parsed_response
except Exception as e:
logger.exception(f"Error running agent {agent.agent_name}: {e}")
return {"error": f"Agent {agent.agent_name} failed: {str(e)}"}
def run(self, query: str, image: Optional[str] = None) -> Dict[str, Any]:
"""Execute the task through the multi-agent workflow."""
logger.info(f"Starting task: {query}")
try:
# Step 1: Query Analysis (High-Level Planning)
planner_input = (
f"Analyze the following query and determine the necessary skills and"
f" relevant tools: {query}"
)
query_analysis = self._run_agent(self.planner, planner_input)
if "error" in query_analysis:
return {
"error": f"Planner query analysis failed: {query_analysis['error']}",
"trajectory": self.memory,
"conversation": self.conversation.return_history_as_string(),
}
self.memory.append(
{"step": 0, "component": "Query Analyzer", "result": query_analysis}
)
self.conversation.add(
role=self.planner.agent_name, content=json.dumps(query_analysis)
)
# Initialize context with the query and image (if provided)
context = {"query": query}
if image:
context["image"] = image
# Add available tools to context
if "relevant_tools" in query_analysis:
context["available_tools"] = query_analysis["relevant_tools"]
else:
# If no relevant tools specified, make all tools available
context["available_tools"] = list(self.tools.keys())
step_count = 1
# Step 2: Iterative Action-Execution Loop
while step_count <= self.max_iterations:
logger.info(f"Starting iteration {step_count} of {self.max_iterations}")
# Step 2a: Action Prediction (Low-Level Planning)
action_planner_input = (
f"Current Context: {json.dumps(context)}\nAvailable Tools:"
f" {', '.join(context.get('available_tools', list(self.tools.keys())))}\nPlan the"
" next step."
)
action = self._run_agent(self.planner, action_planner_input)
if "error" in action:
logger.error(f"Error in action prediction: {action['error']}")
return {
"error": f"Planner action prediction failed: {action['error']}",
"trajectory": self.memory,
"conversation": self.conversation.return_history_as_string()
}
self.memory.append(
{"step": step_count, "component": "Action Predictor", "result": action}
)
self.conversation.add(role=self.planner.agent_name, content=json.dumps(action))
# Input Validation for Action (Relaxed)
if not isinstance(action, dict) or "tool_name" not in action or "sub_goal" not in action:
error_msg = (
"Action prediction did not return required fields (tool_name,"
" sub_goal) or was not a dictionary."
)
logger.error(error_msg)
self.memory.append(
{"step": step_count, "component": "Error", "result": error_msg}
)
break
# Step 2b: Execute Tool
tool_execution_context = {
**context,
**action.get("context", {}), # Add any additional context
"sub_goal": action["sub_goal"], # Pass sub_goal to tool
}
tool_result = self._execute_tool(action["tool_name"], tool_execution_context)
self.memory.append(
{
"step": step_count,
"component": "Tool Execution",
"result": tool_result,
}
)
# Step 2c: Context Update - Store result with a descriptive key
if "result" in tool_result:
context[f"{action['tool_name']}_result"] = tool_result["result"]
if "error" in tool_result:
context[f"{action['tool_name']}_error"] = tool_result["error"]
# Step 2d: Context Verification
verifier_input = (
f"Current Context: {json.dumps(context)}\nMemory:"
f" {json.dumps(self.memory)}\nQuery: {query}"
)
verification = self._run_agent(self.verifier, verifier_input)
if "error" in verification:
return {
"error": f"Verifier failed: {verification['error']}",
"trajectory": self.memory,
"conversation": self.conversation.return_history_as_string(),
}
self.memory.append(
{
"step": step_count,
"component": "Context Verifier",
"result": verification,
}
)
self.conversation.add(role=self.verifier.agent_name, content=json.dumps(verification))
# Check for stop signal from Verifier
if verification.get("stop_signal") is True:
logger.info("Received stop signal from verifier. Stopping iterations.")
break
# Safety mechanism - if we've executed the same tool multiple times
same_tool_count = sum(
1 for m in self.memory
if m.get("component") == "Action Predictor"
and m.get("result", {}).get("tool_name") == action.get("tool_name")
)
if same_tool_count > 3:
logger.warning(f"Tool {action.get('tool_name')} used more than 3 times. Forcing stop.")
break
step_count += 1
# Step 3: Solution Summarization
summarizer_input = f"Complete Trajectory: {json.dumps(self.memory)}\nOriginal Query: {query}"
summarization = self._run_agent(self.summarizer, summarizer_input)
if "error" in summarization:
return {
"error": f"Summarizer failed: {summarization['error']}",
"trajectory": self.memory,
"conversation": self.conversation.return_history_as_string()
}
self.conversation.add(role=self.summarizer.agent_name, content=json.dumps(summarization))
return {
"final_answer": summarization.get("final_answer", "No answer found."),
"trajectory": self.memory,
"conversation": self.conversation.return_history_as_string(),
}
except Exception as e:
logger.exception(f"Unexpected error in run method: {e}") # More detailed
return {
"error": str(e),
"trajectory": self.memory,
"conversation": self.conversation.return_history_as_string(),
}
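# A successful run() returns {"final_answer": str, "trajectory": list,
# "conversation": str}; failure paths replace "final_answer" with "error".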
def save_state(self) -> None:
"""Save the current state of all agents."""
for agent in [self.planner, self.verifier, self.summarizer]:
try:
agent.save_state()
except Exception as e:
logger.error(f"Error saving state for {agent.agent_name}: {str(e)}")
def load_state(self) -> None:
"""Load the saved state of all agents."""
for agent in [self.planner, self.verifier, self.summarizer]:
try:
agent.load_state()
except Exception as e:
logger.error(f"Error loading state for {agent.agent_name}: {str(e)}")
# --- Example Usage ---
# Define dummy tool functions (replace with actual implementations)
def image_captioner_execute(image: str, prompt: str = "Describe the image", **kwargs) -> str:
"""Dummy image captioner."""
print(f"image_captioner_execute called with image: {image}, prompt: {prompt}")
return f"Caption for {image}: A descriptive caption (dummy)." # Simplified
def object_detector_execute(image: str, labels: Optional[List[str]] = None, **kwargs) -> List[str]:
"""Dummy object detector; handles missing labels gracefully."""
print(f"object_detector_execute called with image: {image}, labels: {labels}")
if not labels:
return ["object1", "object2", "object3", "object4"]  # default objects when no labels are given
return [f"Detected {label}" for label in labels]  # simplified detection
def web_search_execute(query: str, **kwargs) -> str:
"""Dummy web search."""
print(f"web_search_execute called with query: {query}")
return f"Search results for '{query}'..." # Simplified
def python_calculator_execute(expression: str, **kwargs) -> str:
"""Python calculator (using math module)."""
print(f"python_calculator_execute called with: {expression}")
try:
# Whitelist digits and basic operators; names are rejected, so eval
# cannot reach builtins (the math module stays available but unreachable)
if re.match(r"^[0-9+\-*/().\s]+$", expression):
result = eval(expression, {"__builtins__": {}, "math": math})
return f"Result of {expression} is {result}"
else:
return "Error: Invalid expression for calculator."
except Exception as e:
return f"Error: {e}"
# Create utility function to get default tools
def get_default_tools() -> List[Tool]:
"""Returns a list of default tools that can be used with OctoToolsSwarm."""
image_captioner = Tool(
name="Image_Captioner_Tool",
description="Generates a caption for an image.",
metadata={
"input_types": {"image": "str", "prompt": "str"},
"output_type": "str",
"limitations": "May struggle with complex scenes or ambiguous objects.",
"best_practices": "Use with clear, well-lit images. Provide specific prompts for better results.",
},
execute_func=image_captioner_execute,
)
object_detector = Tool(
name="Object_Detector_Tool",
description="Detects objects in an image.",
metadata={
"input_types": {"image": "str", "labels": "list"},
"output_type": "list",
"limitations": "Accuracy depends on the quality of the image and the clarity of the objects.",
"best_practices": "Provide a list of specific object labels to detect. Use high-resolution images.",
},
execute_func=object_detector_execute,
)
web_search = Tool(
name="Web_Search_Tool",
description="Performs a web search.",
metadata={
"input_types": {"query": "str"},
"output_type": "str",
"limitations": "May not find specific or niche information.",
"best_practices": "Use specific and descriptive keywords for better results.",
},
execute_func=web_search_execute,
)
calculator = Tool(
name="Python_Calculator_Tool",
description="Evaluates a Python expression.",
metadata={
"input_types": {"expression": "str"},
"output_type": "str",
"limitations": "Cannot handle complex mathematical functions or libraries.",
"best_practices": "Use for basic arithmetic and simple calculations.",
},
execute_func=python_calculator_execute,
)
return [image_captioner, object_detector, web_search, calculator]
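# A custom tool can follow the same pattern (hypothetical sketch; the
# Word_Counter_Tool below is illustrative and not part of this module):
# def word_counter_execute(text: str, **kwargs) -> str:
#     return f"Word count: {len(text.split())}"
#
# word_counter = Tool(
#     name="Word_Counter_Tool",
#     description="Counts the words in a text string.",
#     metadata={"input_types": {"text": "str"}, "output_type": "str"},
#     execute_func=word_counter_execute,
# )
# swarm = OctoToolsSwarm(tools=get_default_tools() + [word_counter])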
# Example usage (commented out; uncomment the block below to run this script directly)
# if __name__ == "__main__":
# print("Running OctoToolsSwarm example...")
# # Create an OctoToolsSwarm agent with default tools
# tools = get_default_tools()
# agent = OctoToolsSwarm(tools=tools)
# # Example query
# query = "What is the square root of the number of objects in this image?"
# # Create a dummy image file for testing if it doesn't exist
# image_path = "example.png"
# if not os.path.exists(image_path):
# with open(image_path, "w") as f:
# f.write("Dummy image content")
# print(f"Created dummy image file: {image_path}")
# # Run the agent
# result = agent.run(query, image=image_path)
# # Display results
# print("\n=== FINAL ANSWER ===")
# print(result["final_answer"])
# print("\n=== TRAJECTORY SUMMARY ===")
# for step in result["trajectory"]:
# print(f"Step {step.get('step', 'N/A')}: {step.get('component', 'Unknown')}")
# print("\nOctoToolsSwarm example completed.")

@@ -1,580 +0,0 @@
"""
TalkHier: A hierarchical multi-agent framework for content generation and refinement.
Implements structured communication and evaluation protocols.
"""
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from swarms import Agent
from swarms.structs.conversation import Conversation
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AgentRole(Enum):
"""Defines the possible roles for agents in the system."""
SUPERVISOR = "supervisor"
GENERATOR = "generator"
EVALUATOR = "evaluator"
REVISOR = "revisor"
@dataclass
class CommunicationEvent:
"""Represents a structured communication event between agents."""
message: str
background: Optional[str] = None
intermediate_output: Optional[Dict[str, Any]] = None
class TalkHier:
"""
A hierarchical multi-agent system for content generation and refinement.
Implements the TalkHier framework with structured communication protocols
and hierarchical refinement processes.
Attributes:
max_iterations: Maximum number of refinement iterations
quality_threshold: Minimum score required for content acceptance
model_name: Name of the LLM model to use
base_path: Path for saving agent states
"""
def __init__(
self,
max_iterations: int = 3,
quality_threshold: float = 0.8,
model_name: str = "gpt-4",
base_path: Optional[str] = None,
return_string: bool = False,
):
"""Initialize the TalkHier system."""
self.max_iterations = max_iterations
self.quality_threshold = quality_threshold
self.model_name = model_name
self.return_string = return_string
self.base_path = (
Path(base_path) if base_path else Path("./agent_states")
)
self.base_path.mkdir(exist_ok=True)
# Initialize agents
self._init_agents()
# Create conversation
self.conversation = Conversation()
def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
"""
Safely parse JSON string, handling various formats and potential errors.
Args:
json_str: String containing JSON data
Returns:
Parsed dictionary
"""
try:
# Try direct JSON parsing
return json.loads(json_str)
except json.JSONDecodeError:
try:
# Try extracting JSON from potential text wrapper
import re
json_match = re.search(r"\{.*\}", json_str, re.DOTALL)
if json_match:
return json.loads(json_match.group())
# Try extracting from markdown code blocks
code_block_match = re.search(
r"```(?:json)?\s*(\{.*?\})\s*```",
json_str,
re.DOTALL,
)
if code_block_match:
return json.loads(code_block_match.group(1))
except Exception as e:
logger.warning(f"Failed to extract JSON: {str(e)}")
# Fallback: create structured dict from text
return {
"content": json_str,
"metadata": {
"parsed": False,
"timestamp": str(datetime.now()),
},
}
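# Example (illustrative): a response wrapped in a ```json ... ``` fence is
# recovered by the code-block branch above; plain prose falls through to the
# structured {"content": ..., "metadata": ...} fallback.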
def _init_agents(self) -> None:
"""Initialize all agents with their specific roles and prompts."""
# Main supervisor agent
self.main_supervisor = Agent(
agent_name="Main-Supervisor",
system_prompt=self._get_supervisor_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(
self.base_path / "main_supervisor.json"
),
verbose=True,
)
# Generator agent
self.generator = Agent(
agent_name="Content-Generator",
system_prompt=self._get_generator_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(self.base_path / "generator.json"),
verbose=True,
)
# Evaluators
self.evaluators = [
Agent(
agent_name=f"Evaluator-{i}",
system_prompt=self._get_evaluator_prompt(i),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(
self.base_path / f"evaluator_{i}.json"
),
verbose=True,
)
for i in range(3)
]
# Revisor agent
self.revisor = Agent(
agent_name="Content-Revisor",
system_prompt=self._get_revisor_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(self.base_path / "revisor.json"),
verbose=True,
)
def _get_supervisor_prompt(self) -> str:
"""Get the prompt for the supervisor agent."""
return """You are a Supervisor agent responsible for orchestrating the content generation process. Your role is to analyze tasks, develop strategies, and coordinate other agents effectively.
You must carefully analyze each task to understand:
- The core objectives and requirements
- Target audience and their needs
- Complexity level and scope
- Any constraints or special considerations
Based on your analysis, develop a clear strategy that:
- Breaks down the task into manageable steps
- Identifies which agents are best suited for each step
- Anticipates potential challenges
- Sets clear success criteria
Output all responses in strict JSON format:
{
"thoughts": {
"task_analysis": "Detailed analysis of requirements, audience, scope, and constraints",
"strategy": "Step-by-step plan including agent allocation and success metrics",
"concerns": "Potential challenges, edge cases, and mitigation strategies"
},
"next_action": {
"agent": "Specific agent to engage (Generator, Evaluator, or Revisor)",
"instruction": "Detailed instructions including context, requirements, and expected output"
}
}"""
def _get_generator_prompt(self) -> str:
"""Get the prompt for the generator agent."""
return """You are a Generator agent responsible for creating high-quality, original content. Your role is to produce content that is engaging, informative, and tailored to the target audience.
When generating content:
- Thoroughly research and fact-check all information
- Structure content logically with clear flow
- Use appropriate tone and language for the target audience
- Include relevant examples and explanations
- Ensure content is original and plagiarism-free
- Consider SEO best practices where applicable
Output all responses in strict JSON format:
{
"content": {
"main_body": "The complete generated content with proper formatting and structure",
"metadata": {
"word_count": "Accurate word count of main body",
"target_audience": "Detailed audience description",
"key_points": ["List of main points covered"],
"sources": ["List of reference sources if applicable"],
"readability_level": "Estimated reading level",
"tone": "Description of content tone"
}
}
}"""
def _get_evaluator_prompt(self, evaluator_id: int) -> str:
"""Get the prompt for an evaluator agent."""
return f"""You are Evaluator {evaluator_id}, responsible for critically assessing content quality. Your evaluation must be thorough, objective, and constructive.
Evaluate content across multiple dimensions:
- Accuracy: factual correctness, source reliability
- Clarity: readability, organization, flow
- Coherence: logical consistency, argument structure
- Engagement: interest level, relevance
- Completeness: topic coverage, depth
- Technical quality: grammar, spelling, formatting
- Audience alignment: appropriate level and tone
Output all responses in strict JSON format:
{{
"scores": {{
"overall": "0.0-1.0 composite score",
"categories": {{
"accuracy": "0.0-1.0 score with evidence",
"clarity": "0.0-1.0 score with examples",
"coherence": "0.0-1.0 score with analysis",
"engagement": "0.0-1.0 score with justification",
"completeness": "0.0-1.0 score with gaps identified",
"technical_quality": "0.0-1.0 score with issues noted",
"audience_alignment": "0.0-1.0 score with reasoning"
}}
}},
"feedback": [
"Specific, actionable improvement suggestions",
"Examples of issues found",
"Recommendations for enhancement"
],
"strengths": ["Notable positive aspects"],
"weaknesses": ["Areas needing improvement"]
}}"""
def _get_revisor_prompt(self) -> str:
"""Get the prompt for the revisor agent."""
return """You are a Revisor agent responsible for improving content based on evaluator feedback. Your role is to enhance content while maintaining its core message and purpose.
When revising content:
- Address all evaluator feedback systematically
- Maintain consistency in tone and style
- Preserve accurate information
- Enhance clarity and flow
- Fix technical issues
- Optimize for target audience
- Track all changes made
Output all responses in strict JSON format:
{
"revised_content": {
"main_body": "Complete revised content incorporating all improvements",
"metadata": {
"word_count": "Updated word count",
"changes_made": [
"Detailed list of specific changes and improvements",
"Reasoning for each major revision",
"Feedback points addressed"
],
"improvement_summary": "Overview of main enhancements",
"preserved_elements": ["Key elements maintained from original"],
"revision_approach": "Strategy used for revisions"
}
}
}"""
def _evaluate_content(
self, content: Union[str, Dict]
) -> Dict[str, Any]:
"""
Coordinate the evaluation of content across multiple evaluators.
Args:
content: Content to evaluate (string or dict)
Returns:
Combined evaluation results
"""
try:
# Ensure content is in correct format
content_dict = (
self._safely_parse_json(content)
if isinstance(content, str)
else content
)
# Collect evaluations
evaluations = []
for evaluator in self.evaluators:
try:
eval_response = evaluator.run(
json.dumps(content_dict)
)
self.conversation.add(
role=evaluator.agent_name,
content=eval_response,
)
eval_data = self._safely_parse_json(eval_response)
evaluations.append(eval_data)
except Exception as e:
logger.warning(f"Evaluator error: {str(e)}")
evaluations.append(
self._get_fallback_evaluation()
)
# Aggregate results
return self._aggregate_evaluations(evaluations)
except Exception as e:
logger.error(f"Evaluation error: {str(e)}")
return self._get_fallback_evaluation()
def _get_fallback_evaluation(self) -> Dict[str, Any]:
"""Get a safe fallback evaluation result."""
return {
"scores": {
"overall": 0.5,
"categories": {
"accuracy": 0.5,
"clarity": 0.5,
"coherence": 0.5,
},
},
"feedback": ["Evaluation failed"],
"metadata": {
"timestamp": str(datetime.now()),
"is_fallback": True,
},
}
def _aggregate_evaluations(
self, evaluations: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Aggregate multiple evaluation results into a single evaluation.
Args:
evaluations: List of evaluation results
Returns:
Combined evaluation
"""
# Calculate average scores
overall_scores = []
accuracy_scores = []
clarity_scores = []
coherence_scores = []
all_feedback = []
for eval_data in evaluations:
try:
scores = eval_data.get("scores", {})
overall_scores.append(scores.get("overall", 0.5))
categories = scores.get("categories", {})
accuracy_scores.append(
categories.get("accuracy", 0.5)
)
clarity_scores.append(categories.get("clarity", 0.5))
coherence_scores.append(
categories.get("coherence", 0.5)
)
all_feedback.extend(eval_data.get("feedback", []))
except Exception as e:
logger.warning(
f"Error aggregating evaluation: {str(e)}"
)
def safe_mean(scores: List[float]) -> float:
return sum(scores) / len(scores) if scores else 0.5
return {
"scores": {
"overall": safe_mean(overall_scores),
"categories": {
"accuracy": safe_mean(accuracy_scores),
"clarity": safe_mean(clarity_scores),
"coherence": safe_mean(coherence_scores),
},
},
"feedback": list(set(all_feedback)), # Remove duplicates
"metadata": {
"evaluator_count": len(evaluations),
"timestamp": str(datetime.now()),
},
}
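# Example (illustrative): overall scores [0.6, 0.8, 0.7] from three evaluators
# aggregate to (0.6 + 0.8 + 0.7) / 3 = 0.7; duplicate feedback strings are
# collapsed by the set() above.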
def run(self, task: str) -> Dict[str, Any]:
"""
Generate and iteratively refine content based on the given task.
Args:
task: Content generation task description
Returns:
Dictionary containing final content and metadata
"""
logger.info(f"Starting content generation for task: {task}")
try:
# Get initial direction from supervisor
supervisor_response = self.main_supervisor.run(task)
self.conversation.add(
role=self.main_supervisor.agent_name,
content=supervisor_response,
)
supervisor_data = self._safely_parse_json(
supervisor_response
)
# Generate initial content
generator_response = self.generator.run(
json.dumps(supervisor_data.get("next_action", {}))
)
self.conversation.add(
role=self.generator.agent_name,
content=generator_response,
)
current_content = self._safely_parse_json(
generator_response
)
for iteration in range(self.max_iterations):
logger.info(f"Starting iteration {iteration + 1}")
# Evaluate current content
evaluation = self._evaluate_content(current_content)
# Check if quality threshold is met
if (
evaluation["scores"]["overall"]
>= self.quality_threshold
):
logger.info(
"Quality threshold met, returning content"
)
return {
"content": current_content.get(
"content", {}
).get("main_body", ""),
"final_score": evaluation["scores"][
"overall"
],
"iterations": iteration + 1,
"metadata": {
"content_metadata": current_content.get(
"content", {}
).get("metadata", {}),
"evaluation": evaluation,
},
}
# Revise content if needed
revision_input = {
"content": current_content,
"evaluation": evaluation,
}
revision_response = self.revisor.run(
json.dumps(revision_input)
)
current_content = self._safely_parse_json(
revision_response
)
self.conversation.add(
role=self.revisor.agent_name,
content=revision_response,
)
logger.warning(
"Max iterations reached without meeting quality threshold"
)
except Exception as e:
logger.error(f"Error in generate_and_refine: {str(e)}")
current_content = {
"content": {"main_body": f"Error: {str(e)}"}
}
evaluation = self._get_fallback_evaluation()
if self.return_string:
return self.conversation.return_history_as_string()
else:
return {
"content": current_content.get("content", {}).get(
"main_body", ""
),
"final_score": evaluation["scores"]["overall"],
"iterations": self.max_iterations,
"metadata": {
"content_metadata": current_content.get(
"content", {}
).get("metadata", {}),
"evaluation": evaluation,
"error": "Max iterations reached",
},
}
def save_state(self) -> None:
"""Save the current state of all agents."""
for agent in [
self.main_supervisor,
self.generator,
*self.evaluators,
self.revisor,
]:
try:
agent.save_state()
except Exception as e:
logger.error(
f"Error saving state for {agent.agent_name}: {str(e)}"
)
def load_state(self) -> None:
"""Load the saved state of all agents."""
for agent in [
self.main_supervisor,
self.generator,
*self.evaluators,
self.revisor,
]:
try:
agent.load_state()
except Exception as e:
logger.error(
f"Error loading state for {agent.agent_name}: {str(e)}"
)
if __name__ == "__main__":
# Example usage
try:
talkhier = TalkHier(
max_iterations=1,
quality_threshold=0.8,
model_name="gpt-4o",
return_string=True,
)
task = "Write a comprehensive explanation of quantum computing for beginners"
result = talkhier.run(task)
print(result)
# print(f"Final content: {result['content']}")
# print(f"Quality score: {result['final_score']}")
# print(f"Iterations: {result['iterations']}")
except Exception as e:
logger.error(f"Error in main execution: {str(e)}")