From 23666c345ab86b51b591dc1ac7029feab51f2d53 Mon Sep 17 00:00:00 2001 From: harshalmore31 Date: Sat, 1 Mar 2025 09:27:28 +0530 Subject: [PATCH] Add OctoToolsSwarm and init talk_hier in swarms module --- swarms/structs/__init__.py | 9 +- swarms/structs/octotools.py | 705 ++++++++++++++++++++++++++++++++++++ swarms/structs/talktier.py | 580 ----------------------------- 3 files changed, 713 insertions(+), 581 deletions(-) create mode 100644 swarms/structs/octotools.py delete mode 100644 swarms/structs/talktier.py diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index dd0dce3b..9fa7a35e 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -72,7 +72,6 @@ from swarms.structs.swarming_architectures import ( staircase_swarm, star_swarm, ) - from swarms.structs.swarms_api import ( SwarmsAPIClient, SwarmRequest, @@ -81,6 +80,8 @@ from swarms.structs.swarms_api import ( SwarmValidationError, AgentInput, ) +from swarms.structs.talk_hier import TalkHier, AgentRole, CommunicationEvent +from swarms.structs.octotools import OctoToolsSwarm, Tool, ToolType __all__ = [ "Agent", @@ -147,6 +148,12 @@ __all__ = [ "MultiAgentRouter", "MemeAgentGenerator", "ModelRouter", + "OctoToolsSwarm", + "Tool", + "ToolType", + "TalkHier", + "AgentRole", + "CommunicationEvent", "SwarmsAPIClient", "SwarmRequest", "SwarmAuthenticationError", diff --git a/swarms/structs/octotools.py b/swarms/structs/octotools.py new file mode 100644 index 00000000..f7c027e3 --- /dev/null +++ b/swarms/structs/octotools.py @@ -0,0 +1,705 @@ +""" +OctoToolsSwarm: A multi-agent system for complex reasoning. +Implements the OctoTools framework using swarms. +""" + +import json +import logging +import os +import re +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional +import math # Import the math module + +from dotenv import load_dotenv +from swarms import Agent +from swarms.structs.conversation import Conversation +# from exa_search import exa_search as web_search_execute + + +# Load environment variables +load_dotenv() + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class ToolType(Enum): + """Defines the types of tools available.""" + + IMAGE_CAPTIONER = "image_captioner" + OBJECT_DETECTOR = "object_detector" + WEB_SEARCH = "web_search" + PYTHON_CALCULATOR = "python_calculator" + # Add more tool types as needed + + +@dataclass +class Tool: + """ + Represents an external tool. + + Attributes: + name: Unique name of the tool. + description: Description of the tool's function. + metadata: Dictionary containing tool metadata. + execute_func: Callable function that executes the tool's logic. + """ + + name: str + description: str + metadata: Dict[str, Any] + execute_func: Callable + + def execute(self, **kwargs): + """Executes the tool's logic, handling potential errors.""" + try: + return self.execute_func(**kwargs) + except Exception as e: + logger.error(f"Error executing tool {self.name}: {str(e)}") + return {"error": str(e)} + + +class AgentRole(Enum): + """Defines the roles for agents in the OctoTools system.""" + + PLANNER = "planner" + VERIFIER = "verifier" + SUMMARIZER = "summarizer" + + +class OctoToolsSwarm: + """ + A multi-agent system implementing the OctoTools framework. + + Attributes: + model_name: Name of the LLM model to use. + max_iterations: Maximum number of action-execution iterations. + base_path: Path for saving agent states. 
+ tools: List of available Tool objects. + """ + + def __init__( + self, + tools: List[Tool], + model_name: str = "gemini/gemini-2.0-flash", + max_iterations: int = 10, + base_path: Optional[str] = None, + ): + """Initialize the OctoToolsSwarm system.""" + self.model_name = model_name + self.max_iterations = max_iterations + self.base_path = Path(base_path) if base_path else Path("./octotools_states") + self.base_path.mkdir(exist_ok=True) + self.tools = {tool.name: tool for tool in tools} # Store tools in a dictionary + + # Initialize agents + self._init_agents() + + # Create conversation tracker and memory + self.conversation = Conversation() + self.memory = [] # Store the trajectory + + def _init_agents(self) -> None: + """Initialize all agents with their specific roles and prompts.""" + # Planner agent + self.planner = Agent( + agent_name="OctoTools-Planner", + system_prompt=self._get_planner_prompt(), + model_name=self.model_name, + max_loops=3, + saved_state_path=str(self.base_path / "planner.json"), + verbose=True, + ) + + # Verifier agent + self.verifier = Agent( + agent_name="OctoTools-Verifier", + system_prompt=self._get_verifier_prompt(), + model_name=self.model_name, + max_loops=1, + saved_state_path=str(self.base_path / "verifier.json"), + verbose=True, + ) + + # Summarizer agent + self.summarizer = Agent( + agent_name="OctoTools-Summarizer", + system_prompt=self._get_summarizer_prompt(), + model_name=self.model_name, + max_loops=1, + saved_state_path=str(self.base_path / "summarizer.json"), + verbose=True, + ) + + def _get_planner_prompt(self) -> str: + """Get the prompt for the planner agent (Improved with few-shot examples).""" + tool_descriptions = "\n".join( + [ + f"- {tool_name}: {self.tools[tool_name].description}" + for tool_name in self.tools + ] + ) + return f"""You are the Planner in the OctoTools framework. Your role is to analyze the user's query, + identify required skills, suggest relevant tools, and plan the steps to solve the problem. + + 1. **Analyze the user's query:** Understand the requirements and identify the necessary skills and potentially relevant tools. + 2. **Perform high-level planning:** Create a rough outline of how tools might be used to solve the problem. + 3. **Perform low-level planning (action prediction):** At each step, select the best tool to use and formulate a specific sub-goal for that tool, considering the current context. + + Available Tools: + {tool_descriptions} + + Output your response in JSON format. Here are examples for different stages: + + **Query Analysis (High-Level Planning):** + Example Input: + Query: "What is the capital of France?" 
+ + Example Output: + ```json + {{ + "summary": "The user is asking for the capital of France.", + "required_skills": ["knowledge retrieval"], + "relevant_tools": ["Web_Search_Tool"] + }} + ``` + + **Action Prediction (Low-Level Planning):** + Example Input: + Context: {{ "query": "What is the capital of France?", "available_tools": ["Web_Search_Tool"] }} + + Example Output: + ```json + {{ + "justification": "The Web_Search_Tool can be used to directly find the capital of France.", + "context": {{}}, + "sub_goal": "Search the web for 'capital of France'.", + "tool_name": "Web_Search_Tool" + }} + ``` + Another Example: + Context: {{"query": "How many objects are in the image?", "available_tools": ["Image_Captioner_Tool", "Object_Detector_Tool"], "image": "objects.png"}} + + Example Output: + ```json + {{ + "justification": "First, get a general description of the image to understand the context.", + "context": {{ "image": "objects.png" }}, + "sub_goal": "Generate a description of the image.", + "tool_name": "Image_Captioner_Tool" + }} + ``` + + Example for Finding Square Root: + Context: {{"query": "What is the square root of the number of objects in the image?", "available_tools": ["Object_Detector_Tool", "Python_Calculator_Tool"], "image": "objects.png", "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"]}} + + Example Output: + ```json + {{ + "justification": "We have detected 4 objects in the image. Now we need to find the square root of 4.", + "context": {{}}, + "sub_goal": "Calculate the square root of 4", + "tool_name": "Python_Calculator_Tool" + }} + ``` + + Your output MUST be a single, valid JSON object with the following keys: + - justification (string): Your reasoning. + - context (dict): A dictionary containing relevant information. + - sub_goal (string): The specific instruction for the tool. + - tool_name (string): The EXACT name of the tool to use. + + Do NOT include any text outside of the JSON object. + """ + + def _get_verifier_prompt(self) -> str: + """Get the prompt for the verifier agent (Improved with few-shot examples).""" + return """You are the Context Verifier in the OctoTools framework. Your role is to analyze the current context + and memory to determine if the problem is solved, if there are any inconsistencies, or if further steps are needed. + + Output your response in JSON format: + + Expected output structure: + ```json + { + "completeness": "Indicate whether the query is fully, partially, or not answered.", + "inconsistencies": "List any inconsistencies found in the context or memory.", + "verification_needs": "List any information that needs further verification.", + "ambiguities": "List any ambiguities found in the context or memory.", + "stop_signal": true/false + } + ``` + + Example Input: + Context: { "last_result": { "result": "Caption: The image shows a cat." 
} } + Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Image_Captioner_Tool" } } ] + + Example Output: + ```json + { + "completeness": "partial", + "inconsistencies": [], + "verification_needs": ["Object detection to confirm the presence of a cat."], + "ambiguities": [], + "stop_signal": false + } + ``` + + Another Example: + Context: { "last_result": { "result": ["Detected object: cat"] } } + Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } } ] + + Example Output: + ```json + { + "completeness": "yes", + "inconsistencies": [], + "verification_needs": [], + "ambiguities": [], + "stop_signal": true + } + ``` + + Square Root Example: + Context: { + "query": "What is the square root of the number of objects in the image?", + "image": "example.png", + "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"], + "Python_Calculator_Tool_result": "Result of 4**0.5 is 2.0" + } + Memory: [ + { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } }, + { "component": "Action Predictor", "result": { "tool_name": "Python_Calculator_Tool" } } + ] + + Example Output: + ```json + { + "completeness": "yes", + "inconsistencies": [], + "verification_needs": [], + "ambiguities": [], + "stop_signal": true + } + ``` + """ + + def _get_summarizer_prompt(self) -> str: + """Get the prompt for the summarizer agent (Improved with few-shot examples).""" + return """You are the Solution Summarizer in the OctoTools framework. Your role is to synthesize the final + answer to the user's query based on the complete trajectory of actions and results. + + Output your response in JSON format: + + Expected output structure: + ```json + { + "final_answer": "Provide a clear and concise answer to the original query." + } + ``` + Example Input: + Memory: [ + {"component": "Query Analyzer", "result": {"summary": "Find the capital of France."}}, + {"component": "Action Predictor", "result": {"tool_name": "Web_Search_Tool"}}, + {"component": "Tool Execution", "result": {"result": "The capital of France is Paris."}} + ] + + Example Output: + ```json + { + "final_answer": "The capital of France is Paris." + } + ``` + + Square Root Example: + Memory: [ + {"component": "Query Analyzer", "result": {"summary": "Find the square root of the number of objects in the image."}}, + {"component": "Action Predictor", "result": {"tool_name": "Object_Detector_Tool", "sub_goal": "Detect objects in the image"}}, + {"component": "Tool Execution", "result": {"result": ["object1", "object2", "object3", "object4"]}}, + {"component": "Action Predictor", "result": {"tool_name": "Python_Calculator_Tool", "sub_goal": "Calculate the square root of 4"}}, + {"component": "Tool Execution", "result": {"result": "Result of 4**0.5 is 2.0"}} + ] + + Example Output: + ```json + { + "final_answer": "The square root of the number of objects in the image is 2.0. There are 4 objects in the image, and the square root of 4 is 2.0." 
+ } + ``` + """ + + def _safely_parse_json(self, json_str: str) -> Dict[str, Any]: + """Safely parse JSON, handling errors and using recursive descent.""" + try: + return json.loads(json_str) + except json.JSONDecodeError: + logger.warning(f"JSONDecodeError: Attempting to extract JSON from: {json_str}") + try: + # More robust JSON extraction with recursive descent + def extract_json(s): + stack = [] + start = -1 + for i, c in enumerate(s): + if c == '{': + if not stack: + start = i + stack.append(c) + elif c == '}': + if stack: + stack.pop() + if not stack and start != -1: + return s[start:i+1] + return None + + extracted_json = extract_json(json_str) + if extracted_json: + logger.info(f"Extracted JSON: {extracted_json}") + return json.loads(extracted_json) + else: + logger.error("Failed to extract JSON using recursive descent.") + return {"error": "Failed to parse JSON", "content": json_str} + except Exception as e: + logger.exception(f"Error during JSON extraction: {e}") + return {"error": "Failed to parse JSON", "content": json_str} + + def _execute_tool(self, tool_name: str, context: Dict[str, Any]) -> Dict[str, Any]: + """Executes a tool based on its name and provided context.""" + if tool_name not in self.tools: + return {"error": f"Tool '{tool_name}' not found."} + + tool = self.tools[tool_name] + try: + # For Python Calculator tool, handle object counts from Object Detector + if tool_name == "Python_Calculator_Tool": + # Check for object detector results + object_detector_result = context.get("Object_Detector_Tool_result") + if object_detector_result and isinstance(object_detector_result, list): + # Calculate the number of objects + num_objects = len(object_detector_result) + # If sub_goal doesn't already contain an expression, create one + if "sub_goal" in context and "Calculate the square root" in context["sub_goal"]: + context["expression"] = f"{num_objects}**0.5" + elif "expression" not in context: + # Default to square root if no expression is specified + context["expression"] = f"{num_objects}**0.5" + + # Filter context: only pass expected inputs to the tool + valid_inputs = { + k: v for k, v in context.items() if k in tool.metadata.get("input_types", {}) + } + result = tool.execute(**valid_inputs) + return {"result": result} + except Exception as e: + logger.exception(f"Error executing tool {tool_name}: {e}") + return {"error": str(e)} + + def _run_agent(self, agent: Agent, input_prompt: str) -> Dict[str, Any]: + """Runs a swarms agent, handling output and JSON parsing.""" + try: + # Construct the full input, including the system prompt + full_input = f"{agent.system_prompt}\n\n{input_prompt}" + + # Run the agent and capture the output + agent_response = agent.run(full_input) + + logger.info(f"DEBUG: Raw agent response: {agent_response}") + + # Extract the LLM's response (remove conversation history, etc.) 
+ response_text = agent_response # Assuming direct return + + # Try to parse the response as JSON + parsed_response = self._safely_parse_json(response_text) + + return parsed_response + + except Exception as e: + logger.exception(f"Error running agent {agent.agent_name}: {e}") + return {"error": f"Agent {agent.agent_name} failed: {str(e)}"} + + def run(self, query: str, image: Optional[str] = None) -> Dict[str, Any]: + """Execute the task through the multi-agent workflow.""" + logger.info(f"Starting task: {query}") + + try: + # Step 1: Query Analysis (High-Level Planning) + planner_input = ( + f"Analyze the following query and determine the necessary skills and" + f" relevant tools: {query}" + ) + query_analysis = self._run_agent(self.planner, planner_input) + + if "error" in query_analysis: + return { + "error": f"Planner query analysis failed: {query_analysis['error']}", + "trajectory": self.memory, + "conversation": self.conversation.return_history_as_string(), + } + + self.memory.append( + {"step": 0, "component": "Query Analyzer", "result": query_analysis} + ) + self.conversation.add( + role=self.planner.agent_name, content=json.dumps(query_analysis) + ) + + # Initialize context with the query and image (if provided) + context = {"query": query} + if image: + context["image"] = image + + # Add available tools to context + if "relevant_tools" in query_analysis: + context["available_tools"] = query_analysis["relevant_tools"] + else: + # If no relevant tools specified, make all tools available + context["available_tools"] = list(self.tools.keys()) + + step_count = 1 + + # Step 2: Iterative Action-Execution Loop + while step_count <= self.max_iterations: + logger.info(f"Starting iteration {step_count} of {self.max_iterations}") + + # Step 2a: Action Prediction (Low-Level Planning) + action_planner_input = ( + f"Current Context: {json.dumps(context)}\nAvailable Tools:" + f" {', '.join(context.get('available_tools', list(self.tools.keys())))}\nPlan the" + " next step." + ) + action = self._run_agent(self.planner, action_planner_input) + if "error" in action: + logger.error(f"Error in action prediction: {action['error']}") + return { + "error": f"Planner action prediction failed: {action['error']}", + "trajectory": self.memory, + "conversation": self.conversation.return_history_as_string() + } + self.memory.append( + {"step": step_count, "component": "Action Predictor", "result": action} + ) + self.conversation.add(role=self.planner.agent_name, content=json.dumps(action)) + + # Input Validation for Action (Relaxed) + if not isinstance(action, dict) or "tool_name" not in action or "sub_goal" not in action: + error_msg = ( + "Action prediction did not return required fields (tool_name," + " sub_goal) or was not a dictionary." 
+ ) + logger.error(error_msg) + self.memory.append( + {"step": step_count, "component": "Error", "result": error_msg} + ) + break + + # Step 2b: Execute Tool + tool_execution_context = { + **context, + **action.get("context", {}), # Add any additional context + "sub_goal": action["sub_goal"], # Pass sub_goal to tool + } + + tool_result = self._execute_tool(action["tool_name"], tool_execution_context) + + self.memory.append( + { + "step": step_count, + "component": "Tool Execution", + "result": tool_result, + } + ) + + # Step 2c: Context Update - Store result with a descriptive key + if "result" in tool_result: + context[f"{action['tool_name']}_result"] = tool_result["result"] + if "error" in tool_result: + context[f"{action['tool_name']}_error"] = tool_result["error"] + + # Step 2d: Context Verification + verifier_input = ( + f"Current Context: {json.dumps(context)}\nMemory:" + f" {json.dumps(self.memory)}\nQuery: {query}" + ) + verification = self._run_agent(self.verifier, verifier_input) + if "error" in verification: + return { + "error": f"Verifier failed: {verification['error']}", + "trajectory": self.memory, + "conversation": self.conversation.return_history_as_string(), + } + + self.memory.append( + { + "step": step_count, + "component": "Context Verifier", + "result": verification, + } + ) + self.conversation.add(role=self.verifier.agent_name, content=json.dumps(verification)) + + # Check for stop signal from Verifier + if verification.get("stop_signal") is True: + logger.info("Received stop signal from verifier. Stopping iterations.") + break + + # Safety mechanism - if we've executed the same tool multiple times + same_tool_count = sum( + 1 for m in self.memory + if m.get("component") == "Action Predictor" + and m.get("result", {}).get("tool_name") == action.get("tool_name") + ) + + if same_tool_count > 3: + logger.warning(f"Tool {action.get('tool_name')} used more than 3 times. 
+
+
+# --- Example Usage ---
+
+
+# Define dummy tool functions (replace with actual implementations)
+def image_captioner_execute(image: str, prompt: str = "Describe the image", **kwargs) -> str:
+    """Dummy image captioner."""
+    print(f"image_captioner_execute called with image: {image}, prompt: {prompt}")
+    return f"Caption for {image}: A descriptive caption (dummy)."  # Simplified
+
+
+def object_detector_execute(image: str, labels: Optional[List[str]] = None, **kwargs) -> List[str]:
+    """Dummy object detector; avoids a mutable default argument and handles missing labels gracefully."""
+    print(f"object_detector_execute called with image: {image}, labels: {labels}")
+    if not labels:
+        return ["object1", "object2", "object3", "object4"]  # Return default objects if no labels
+    return [f"Detected {label}" for label in labels]  # Simplified
+
+
+def web_search_execute(query: str, **kwargs) -> str:
+    """Dummy web search."""
+    print(f"web_search_execute called with query: {query}")
+    return f"Search results for '{query}'..."  # Simplified
+
+
+def python_calculator_execute(expression: str, **kwargs) -> str:
+    """Python calculator (using the math module)."""
+    print(f"python_calculator_execute called with: {expression}")
+    try:
+        # Safely evaluate only simple expressions involving numbers and basic operations
+        if re.match(r"^[0-9+\-*/().\s]+$", expression):
+            result = eval(expression, {"__builtins__": {}, "math": math})
+            return f"Result of {expression} is {result}"
+        else:
+            return "Error: Invalid expression for calculator."
+    except Exception as e:
+        return f"Error: {e}"
+
+
+# Create Tool instances
+image_captioner = Tool(
+    name="Image_Captioner_Tool",
+    description="Generates a caption for an image.",
+    metadata={
+        "input_types": {"image": "str", "prompt": "str"},
+        "output_type": "str",
+        "limitations": "May struggle with complex scenes or ambiguous objects.",
+        "best_practices": "Use with clear, well-lit images. Provide specific prompts for better results.",
+    },
+    execute_func=image_captioner_execute,
+)
+
+object_detector = Tool(
+    name="Object_Detector_Tool",
+    description="Detects objects in an image.",
+    metadata={
+        "input_types": {"image": "str", "labels": "list"},
+        "output_type": "list",
+        "limitations": "Accuracy depends on the quality of the image and the clarity of the objects.",
+        "best_practices": "Provide a list of specific object labels to detect. Use high-resolution images.",
+    },
+    execute_func=object_detector_execute,
+)
+
+web_search = Tool(
+    name="Web_Search_Tool",
+    description="Performs a web search.",
+    metadata={
+        "input_types": {"query": "str"},
+        "output_type": "str",
+        "limitations": "May not find specific or niche information.",
+        "best_practices": "Use specific and descriptive keywords for better results.",
+    },
+    execute_func=web_search_execute,
+)
+
+calculator = Tool(
+    name="Python_Calculator_Tool",
+    description="Evaluates a Python expression.",
+    metadata={
+        "input_types": {"expression": "str"},
+        "output_type": "str",
+        "limitations": "Cannot handle complex mathematical functions or libraries.",
+        "best_practices": "Use for basic arithmetic and simple calculations.",
+    },
+    execute_func=python_calculator_execute,
+)
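+
+# New tools follow the same pattern: metadata["input_types"] doubles as the
+# argument schema, since _execute_tool only forwards context keys that appear
+# in it. A minimal sketch of a custom tool (hypothetical, not part of this
+# module):
+#
+# word_counter = Tool(
+#     name="Word_Counter_Tool",
+#     description="Counts the words in a text string.",
+#     metadata={
+#         "input_types": {"text": "str"},
+#         "output_type": "str",
+#         "limitations": "Naive whitespace tokenization.",
+#         "best_practices": "Pass plain text only.",
+#     },
+#     execute_func=lambda text, **kwargs: f"Word count: {len(text.split())}",
+# )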
+
+# Create an OctoToolsSwarm agent
+agent = OctoToolsSwarm(tools=[image_captioner, object_detector, web_search, calculator])
+
+# Run the agent with a query
+# query = "What is the square root of the number of objects in the image?"
+# # Create a dummy image file for testing
+# with open("example.png", "w") as f:
+#     f.write("Dummy image content")
+
+# image_path = "example.png"
+# result = agent.run(query, image=image_path)
+
+# print(result["final_answer"])
+# print(result["trajectory"])  # Uncomment to see the full trajectory
+# print(result["conversation"])  # Uncomment to see the agent conversation
\ No newline at end of file
diff --git a/swarms/structs/talktier.py b/swarms/structs/talktier.py
deleted file mode 100644
index f1dc4df2..00000000
--- a/swarms/structs/talktier.py
+++ /dev/null
@@ -1,580 +0,0 @@
-"""
-TalkHier: A hierarchical multi-agent framework for content generation and refinement.
-Implements structured communication and evaluation protocols.
-"""
-
-import json
-import logging
-from dataclasses import dataclass
-from datetime import datetime
-from enum import Enum
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
-
-from swarms import Agent
-from swarms.structs.conversation import Conversation
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-
-class AgentRole(Enum):
-    """Defines the possible roles for agents in the system."""
-
-    SUPERVISOR = "supervisor"
-    GENERATOR = "generator"
-    EVALUATOR = "evaluator"
-    REVISOR = "revisor"
-
-
-@dataclass
-class CommunicationEvent:
-    """Represents a structured communication event between agents."""
-
-    message: str
-    background: Optional[str] = None
-    intermediate_output: Optional[Dict[str, Any]] = None
-
-
-class TalkHier:
-    """
-    A hierarchical multi-agent system for content generation and refinement.
-
-    Implements the TalkHier framework with structured communication protocols
-    and hierarchical refinement processes.
- - Attributes: - max_iterations: Maximum number of refinement iterations - quality_threshold: Minimum score required for content acceptance - model_name: Name of the LLM model to use - base_path: Path for saving agent states - """ - - def __init__( - self, - max_iterations: int = 3, - quality_threshold: float = 0.8, - model_name: str = "gpt-4", - base_path: Optional[str] = None, - return_string: bool = False, - ): - """Initialize the TalkHier system.""" - self.max_iterations = max_iterations - self.quality_threshold = quality_threshold - self.model_name = model_name - self.return_string = return_string - self.base_path = ( - Path(base_path) if base_path else Path("./agent_states") - ) - self.base_path.mkdir(exist_ok=True) - - # Initialize agents - self._init_agents() - - # Create conversation - self.conversation = Conversation() - - def _safely_parse_json(self, json_str: str) -> Dict[str, Any]: - """ - Safely parse JSON string, handling various formats and potential errors. - - Args: - json_str: String containing JSON data - - Returns: - Parsed dictionary - """ - try: - # Try direct JSON parsing - return json.loads(json_str) - except json.JSONDecodeError: - try: - # Try extracting JSON from potential text wrapper - import re - - json_match = re.search(r"\{.*\}", json_str, re.DOTALL) - if json_match: - return json.loads(json_match.group()) - # Try extracting from markdown code blocks - code_block_match = re.search( - r"```(?:json)?\s*(\{.*?\})\s*```", - json_str, - re.DOTALL, - ) - if code_block_match: - return json.loads(code_block_match.group(1)) - except Exception as e: - logger.warning(f"Failed to extract JSON: {str(e)}") - - # Fallback: create structured dict from text - return { - "content": json_str, - "metadata": { - "parsed": False, - "timestamp": str(datetime.now()), - }, - } - - def _init_agents(self) -> None: - """Initialize all agents with their specific roles and prompts.""" - # Main supervisor agent - self.main_supervisor = Agent( - agent_name="Main-Supervisor", - system_prompt=self._get_supervisor_prompt(), - model_name=self.model_name, - max_loops=1, - saved_state_path=str( - self.base_path / "main_supervisor.json" - ), - verbose=True, - ) - - # Generator agent - self.generator = Agent( - agent_name="Content-Generator", - system_prompt=self._get_generator_prompt(), - model_name=self.model_name, - max_loops=1, - saved_state_path=str(self.base_path / "generator.json"), - verbose=True, - ) - - # Evaluators - self.evaluators = [ - Agent( - agent_name=f"Evaluator-{i}", - system_prompt=self._get_evaluator_prompt(i), - model_name=self.model_name, - max_loops=1, - saved_state_path=str( - self.base_path / f"evaluator_{i}.json" - ), - verbose=True, - ) - for i in range(3) - ] - - # Revisor agent - self.revisor = Agent( - agent_name="Content-Revisor", - system_prompt=self._get_revisor_prompt(), - model_name=self.model_name, - max_loops=1, - saved_state_path=str(self.base_path / "revisor.json"), - verbose=True, - ) - - def _get_supervisor_prompt(self) -> str: - """Get the prompt for the supervisor agent.""" - return """You are a Supervisor agent responsible for orchestrating the content generation process. Your role is to analyze tasks, develop strategies, and coordinate other agents effectively. 
- -You must carefully analyze each task to understand: -- The core objectives and requirements -- Target audience and their needs -- Complexity level and scope -- Any constraints or special considerations - -Based on your analysis, develop a clear strategy that: -- Breaks down the task into manageable steps -- Identifies which agents are best suited for each step -- Anticipates potential challenges -- Sets clear success criteria - -Output all responses in strict JSON format: -{ - "thoughts": { - "task_analysis": "Detailed analysis of requirements, audience, scope, and constraints", - "strategy": "Step-by-step plan including agent allocation and success metrics", - "concerns": "Potential challenges, edge cases, and mitigation strategies" - }, - "next_action": { - "agent": "Specific agent to engage (Generator, Evaluator, or Revisor)", - "instruction": "Detailed instructions including context, requirements, and expected output" - } -}""" - - def _get_generator_prompt(self) -> str: - """Get the prompt for the generator agent.""" - return """You are a Generator agent responsible for creating high-quality, original content. Your role is to produce content that is engaging, informative, and tailored to the target audience. - -When generating content: -- Thoroughly research and fact-check all information -- Structure content logically with clear flow -- Use appropriate tone and language for the target audience -- Include relevant examples and explanations -- Ensure content is original and plagiarism-free -- Consider SEO best practices where applicable - -Output all responses in strict JSON format: -{ - "content": { - "main_body": "The complete generated content with proper formatting and structure", - "metadata": { - "word_count": "Accurate word count of main body", - "target_audience": "Detailed audience description", - "key_points": ["List of main points covered"], - "sources": ["List of reference sources if applicable"], - "readability_level": "Estimated reading level", - "tone": "Description of content tone" - } - } -}""" - - def _get_evaluator_prompt(self, evaluator_id: int) -> str: - """Get the prompt for an evaluator agent.""" - return f"""You are Evaluator {evaluator_id}, responsible for critically assessing content quality. Your evaluation must be thorough, objective, and constructive. 
- -Evaluate content across multiple dimensions: -- Accuracy: factual correctness, source reliability -- Clarity: readability, organization, flow -- Coherence: logical consistency, argument structure -- Engagement: interest level, relevance -- Completeness: topic coverage, depth -- Technical quality: grammar, spelling, formatting -- Audience alignment: appropriate level and tone - -Output all responses in strict JSON format: -{{ - "scores": {{ - "overall": "0.0-1.0 composite score", - "categories": {{ - "accuracy": "0.0-1.0 score with evidence", - "clarity": "0.0-1.0 score with examples", - "coherence": "0.0-1.0 score with analysis", - "engagement": "0.0-1.0 score with justification", - "completeness": "0.0-1.0 score with gaps identified", - "technical_quality": "0.0-1.0 score with issues noted", - "audience_alignment": "0.0-1.0 score with reasoning" - }} - }}, - "feedback": [ - "Specific, actionable improvement suggestions", - "Examples of issues found", - "Recommendations for enhancement" - ], - "strengths": ["Notable positive aspects"], - "weaknesses": ["Areas needing improvement"] -}}""" - - def _get_revisor_prompt(self) -> str: - """Get the prompt for the revisor agent.""" - return """You are a Revisor agent responsible for improving content based on evaluator feedback. Your role is to enhance content while maintaining its core message and purpose. - -When revising content: -- Address all evaluator feedback systematically -- Maintain consistency in tone and style -- Preserve accurate information -- Enhance clarity and flow -- Fix technical issues -- Optimize for target audience -- Track all changes made - -Output all responses in strict JSON format: -{ - "revised_content": { - "main_body": "Complete revised content incorporating all improvements", - "metadata": { - "word_count": "Updated word count", - "changes_made": [ - "Detailed list of specific changes and improvements", - "Reasoning for each major revision", - "Feedback points addressed" - ], - "improvement_summary": "Overview of main enhancements", - "preserved_elements": ["Key elements maintained from original"], - "revision_approach": "Strategy used for revisions" - } - } -}""" - - def _evaluate_content( - self, content: Union[str, Dict] - ) -> Dict[str, Any]: - """ - Coordinate the evaluation of content across multiple evaluators. 
- - Args: - content: Content to evaluate (string or dict) - - Returns: - Combined evaluation results - """ - try: - # Ensure content is in correct format - content_dict = ( - self._safely_parse_json(content) - if isinstance(content, str) - else content - ) - - # Collect evaluations - evaluations = [] - for evaluator in self.evaluators: - try: - eval_response = evaluator.run( - json.dumps(content_dict) - ) - - self.conversation.add( - role=evaluator.agent_name, - content=eval_response, - ) - - eval_data = self._safely_parse_json(eval_response) - evaluations.append(eval_data) - except Exception as e: - logger.warning(f"Evaluator error: {str(e)}") - evaluations.append( - self._get_fallback_evaluation() - ) - - # Aggregate results - return self._aggregate_evaluations(evaluations) - - except Exception as e: - logger.error(f"Evaluation error: {str(e)}") - return self._get_fallback_evaluation() - - def _get_fallback_evaluation(self) -> Dict[str, Any]: - """Get a safe fallback evaluation result.""" - return { - "scores": { - "overall": 0.5, - "categories": { - "accuracy": 0.5, - "clarity": 0.5, - "coherence": 0.5, - }, - }, - "feedback": ["Evaluation failed"], - "metadata": { - "timestamp": str(datetime.now()), - "is_fallback": True, - }, - } - - def _aggregate_evaluations( - self, evaluations: List[Dict[str, Any]] - ) -> Dict[str, Any]: - """ - Aggregate multiple evaluation results into a single evaluation. - - Args: - evaluations: List of evaluation results - - Returns: - Combined evaluation - """ - # Calculate average scores - overall_scores = [] - accuracy_scores = [] - clarity_scores = [] - coherence_scores = [] - all_feedback = [] - - for eval_data in evaluations: - try: - scores = eval_data.get("scores", {}) - overall_scores.append(scores.get("overall", 0.5)) - - categories = scores.get("categories", {}) - accuracy_scores.append( - categories.get("accuracy", 0.5) - ) - clarity_scores.append(categories.get("clarity", 0.5)) - coherence_scores.append( - categories.get("coherence", 0.5) - ) - - all_feedback.extend(eval_data.get("feedback", [])) - except Exception as e: - logger.warning( - f"Error aggregating evaluation: {str(e)}" - ) - - def safe_mean(scores: List[float]) -> float: - return sum(scores) / len(scores) if scores else 0.5 - - return { - "scores": { - "overall": safe_mean(overall_scores), - "categories": { - "accuracy": safe_mean(accuracy_scores), - "clarity": safe_mean(clarity_scores), - "coherence": safe_mean(coherence_scores), - }, - }, - "feedback": list(set(all_feedback)), # Remove duplicates - "metadata": { - "evaluator_count": len(evaluations), - "timestamp": str(datetime.now()), - }, - } - - def run(self, task: str) -> Dict[str, Any]: - """ - Generate and iteratively refine content based on the given task. 
- - Args: - task: Content generation task description - - Returns: - Dictionary containing final content and metadata - """ - logger.info(f"Starting content generation for task: {task}") - - try: - # Get initial direction from supervisor - supervisor_response = self.main_supervisor.run(task) - - self.conversation.add( - role=self.main_supervisor.agent_name, - content=supervisor_response, - ) - - supervisor_data = self._safely_parse_json( - supervisor_response - ) - - # Generate initial content - generator_response = self.generator.run( - json.dumps(supervisor_data.get("next_action", {})) - ) - - self.conversation.add( - role=self.generator.agent_name, - content=generator_response, - ) - - current_content = self._safely_parse_json( - generator_response - ) - - for iteration in range(self.max_iterations): - logger.info(f"Starting iteration {iteration + 1}") - - # Evaluate current content - evaluation = self._evaluate_content(current_content) - - # Check if quality threshold is met - if ( - evaluation["scores"]["overall"] - >= self.quality_threshold - ): - logger.info( - "Quality threshold met, returning content" - ) - return { - "content": current_content.get( - "content", {} - ).get("main_body", ""), - "final_score": evaluation["scores"][ - "overall" - ], - "iterations": iteration + 1, - "metadata": { - "content_metadata": current_content.get( - "content", {} - ).get("metadata", {}), - "evaluation": evaluation, - }, - } - - # Revise content if needed - revision_input = { - "content": current_content, - "evaluation": evaluation, - } - - revision_response = self.revisor.run( - json.dumps(revision_input) - ) - current_content = self._safely_parse_json( - revision_response - ) - - self.conversation.add( - role=self.revisor.agent_name, - content=revision_response, - ) - - logger.warning( - "Max iterations reached without meeting quality threshold" - ) - - except Exception as e: - logger.error(f"Error in generate_and_refine: {str(e)}") - current_content = { - "content": {"main_body": f"Error: {str(e)}"} - } - evaluation = self._get_fallback_evaluation() - - if self.return_string: - return self.conversation.return_history_as_string() - else: - return { - "content": current_content.get("content", {}).get( - "main_body", "" - ), - "final_score": evaluation["scores"]["overall"], - "iterations": self.max_iterations, - "metadata": { - "content_metadata": current_content.get( - "content", {} - ).get("metadata", {}), - "evaluation": evaluation, - "error": "Max iterations reached", - }, - } - - def save_state(self) -> None: - """Save the current state of all agents.""" - for agent in [ - self.main_supervisor, - self.generator, - *self.evaluators, - self.revisor, - ]: - try: - agent.save_state() - except Exception as e: - logger.error( - f"Error saving state for {agent.agent_name}: {str(e)}" - ) - - def load_state(self) -> None: - """Load the saved state of all agents.""" - for agent in [ - self.main_supervisor, - self.generator, - *self.evaluators, - self.revisor, - ]: - try: - agent.load_state() - except Exception as e: - logger.error( - f"Error loading state for {agent.agent_name}: {str(e)}" - ) - - -if __name__ == "__main__": - # Example usage - try: - talkhier = TalkHier( - max_iterations=1, - quality_threshold=0.8, - model_name="gpt-4o", - return_string=True, - ) - - task = "Write a comprehensive explanation of quantum computing for beginners" - result = talkhier.run(task) - print(result) - - # print(f"Final content: {result['content']}") - # print(f"Quality score: {result['final_score']}") - # 
print(f"Iterations: {result['iterations']}") - - except Exception as e: - logger.error(f"Error in main execution: {str(e)}")