Merge pull request #787 from harshalmore31/octotools
	
		
	
				
					
				
			Implemented OctoTools Framework with Swarmspull/790/head
						commit
						ddb4119c04
					
				| @ -0,0 +1,725 @@ | |||||||
|  | """ | ||||||
|  | OctoToolsSwarm: A multi-agent system for complex reasoning. | ||||||
|  | Implements the OctoTools framework using swarms. | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | import json | ||||||
|  | import logging | ||||||
|  | import os | ||||||
|  | import re | ||||||
|  | from dataclasses import dataclass | ||||||
|  | from enum import Enum | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import Any, Callable, Dict, List, Optional | ||||||
|  | import math  # Import the math module | ||||||
|  | 
 | ||||||
|  | from dotenv import load_dotenv | ||||||
|  | from swarms import Agent | ||||||
|  | from swarms.structs.conversation import Conversation | ||||||
|  | # from exa_search import exa_search as web_search_execute | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Load environment variables | ||||||
|  | load_dotenv() | ||||||
|  | 
 | ||||||
|  | # Setup logging | ||||||
|  | logging.basicConfig(level=logging.INFO) | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ToolType(Enum): | ||||||
|  |     """Defines the types of tools available.""" | ||||||
|  | 
 | ||||||
|  |     IMAGE_CAPTIONER = "image_captioner" | ||||||
|  |     OBJECT_DETECTOR = "object_detector" | ||||||
|  |     WEB_SEARCH = "web_search" | ||||||
|  |     PYTHON_CALCULATOR = "python_calculator" | ||||||
|  |     # Add more tool types as needed | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @dataclass | ||||||
|  | class Tool: | ||||||
|  |     """ | ||||||
|  |     Represents an external tool. | ||||||
|  | 
 | ||||||
|  |     Attributes: | ||||||
|  |         name: Unique name of the tool. | ||||||
|  |         description: Description of the tool's function. | ||||||
|  |         metadata: Dictionary containing tool metadata. | ||||||
|  |         execute_func: Callable function that executes the tool's logic. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     name: str | ||||||
|  |     description: str | ||||||
|  |     metadata: Dict[str, Any] | ||||||
|  |     execute_func: Callable | ||||||
|  | 
 | ||||||
|  |     def execute(self, **kwargs): | ||||||
|  |         """Executes the tool's logic, handling potential errors.""" | ||||||
|  |         try: | ||||||
|  |             return self.execute_func(**kwargs) | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Error executing tool {self.name}: {str(e)}") | ||||||
|  |             return {"error": str(e)} | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class AgentRole(Enum): | ||||||
|  |     """Defines the roles for agents in the OctoTools system.""" | ||||||
|  | 
 | ||||||
|  |     PLANNER = "planner" | ||||||
|  |     VERIFIER = "verifier" | ||||||
|  |     SUMMARIZER = "summarizer" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class OctoToolsSwarm: | ||||||
|  |     """ | ||||||
|  |     A multi-agent system implementing the OctoTools framework. | ||||||
|  | 
 | ||||||
|  |     Attributes: | ||||||
|  |         model_name: Name of the LLM model to use. | ||||||
|  |         max_iterations: Maximum number of action-execution iterations. | ||||||
|  |         base_path: Path for saving agent states. | ||||||
|  |         tools: List of available Tool objects. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         tools: List[Tool], | ||||||
|  |         model_name: str = "gemini/gemini-2.0-flash", | ||||||
|  |         max_iterations: int = 10, | ||||||
|  |         base_path: Optional[str] = None, | ||||||
|  |     ): | ||||||
|  |         """Initialize the OctoToolsSwarm system.""" | ||||||
|  |         self.model_name = model_name | ||||||
|  |         self.max_iterations = max_iterations | ||||||
|  |         self.base_path = Path(base_path) if base_path else Path("./octotools_states") | ||||||
|  |         self.base_path.mkdir(exist_ok=True) | ||||||
|  |         self.tools = {tool.name: tool for tool in tools}  # Store tools in a dictionary | ||||||
|  | 
 | ||||||
|  |         # Initialize agents | ||||||
|  |         self._init_agents() | ||||||
|  | 
 | ||||||
|  |         # Create conversation tracker and memory | ||||||
|  |         self.conversation = Conversation() | ||||||
|  |         self.memory = []  # Store the trajectory | ||||||
|  | 
 | ||||||
|  |     def _init_agents(self) -> None: | ||||||
|  |         """Initialize all agents with their specific roles and prompts.""" | ||||||
|  |         # Planner agent | ||||||
|  |         self.planner = Agent( | ||||||
|  |             agent_name="OctoTools-Planner", | ||||||
|  |             system_prompt=self._get_planner_prompt(), | ||||||
|  |             model_name=self.model_name, | ||||||
|  |             max_loops=3,   | ||||||
|  |             saved_state_path=str(self.base_path / "planner.json"), | ||||||
|  |             verbose=True,   | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Verifier agent | ||||||
|  |         self.verifier = Agent( | ||||||
|  |             agent_name="OctoTools-Verifier", | ||||||
|  |             system_prompt=self._get_verifier_prompt(), | ||||||
|  |             model_name=self.model_name, | ||||||
|  |             max_loops=1,   | ||||||
|  |             saved_state_path=str(self.base_path / "verifier.json"), | ||||||
|  |             verbose=True, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Summarizer agent | ||||||
|  |         self.summarizer = Agent( | ||||||
|  |             agent_name="OctoTools-Summarizer", | ||||||
|  |             system_prompt=self._get_summarizer_prompt(), | ||||||
|  |             model_name=self.model_name, | ||||||
|  |             max_loops=1,  | ||||||
|  |             saved_state_path=str(self.base_path / "summarizer.json"), | ||||||
|  |             verbose=True, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def _get_planner_prompt(self) -> str: | ||||||
|  |         """Get the prompt for the planner agent (Improved with few-shot examples).""" | ||||||
|  |         tool_descriptions = "\n".join( | ||||||
|  |             [ | ||||||
|  |                 f"- {tool_name}: {self.tools[tool_name].description}" | ||||||
|  |                 for tool_name in self.tools | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |         return f"""You are the Planner in the OctoTools framework. Your role is to analyze the user's query, | ||||||
|  |         identify required skills, suggest relevant tools, and plan the steps to solve the problem. | ||||||
|  | 
 | ||||||
|  |         1. **Analyze the user's query:** Understand the requirements and identify the necessary skills and potentially relevant tools. | ||||||
|  |         2. **Perform high-level planning:**  Create a rough outline of how tools might be used to solve the problem. | ||||||
|  |         3. **Perform low-level planning (action prediction):**  At each step, select the best tool to use and formulate a specific sub-goal for that tool, considering the current context. | ||||||
|  | 
 | ||||||
|  |         Available Tools: | ||||||
|  |         {tool_descriptions} | ||||||
|  | 
 | ||||||
|  |         Output your response in JSON format.  Here are examples for different stages: | ||||||
|  | 
 | ||||||
|  |         **Query Analysis (High-Level Planning):** | ||||||
|  |         Example Input: | ||||||
|  |         Query: "What is the capital of France?" | ||||||
|  | 
 | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         {{ | ||||||
|  |             "summary": "The user is asking for the capital of France.", | ||||||
|  |             "required_skills": ["knowledge retrieval"], | ||||||
|  |             "relevant_tools": ["Web_Search_Tool"] | ||||||
|  |         }} | ||||||
|  |         ``` | ||||||
|  | 
 | ||||||
|  |         **Action Prediction (Low-Level Planning):** | ||||||
|  |         Example Input: | ||||||
|  |         Context: {{ "query": "What is the capital of France?", "available_tools": ["Web_Search_Tool"] }} | ||||||
|  | 
 | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         {{ | ||||||
|  |             "justification": "The Web_Search_Tool can be used to directly find the capital of France.", | ||||||
|  |             "context": {{}}, | ||||||
|  |             "sub_goal": "Search the web for 'capital of France'.", | ||||||
|  |             "tool_name": "Web_Search_Tool" | ||||||
|  |         }} | ||||||
|  |         ``` | ||||||
|  |         Another Example: | ||||||
|  |         Context: {{"query": "How many objects are in the image?", "available_tools": ["Image_Captioner_Tool", "Object_Detector_Tool"], "image": "objects.png"}} | ||||||
|  |          | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         {{ | ||||||
|  |             "justification": "First, get a general description of the image to understand the context.", | ||||||
|  |             "context": {{ "image": "objects.png" }}, | ||||||
|  |             "sub_goal": "Generate a description of the image.", | ||||||
|  |             "tool_name": "Image_Captioner_Tool" | ||||||
|  |         }} | ||||||
|  |         ``` | ||||||
|  |          | ||||||
|  |         Example for Finding Square Root: | ||||||
|  |         Context: {{"query": "What is the square root of the number of objects in the image?", "available_tools": ["Object_Detector_Tool", "Python_Calculator_Tool"], "image": "objects.png", "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"]}} | ||||||
|  |          | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         {{ | ||||||
|  |             "justification": "We have detected 4 objects in the image. Now we need to find the square root of 4.", | ||||||
|  |             "context": {{}}, | ||||||
|  |             "sub_goal": "Calculate the square root of 4", | ||||||
|  |             "tool_name": "Python_Calculator_Tool" | ||||||
|  |         }} | ||||||
|  |         ``` | ||||||
|  |          | ||||||
|  |         Your output MUST be a single, valid JSON object with the following keys: | ||||||
|  |             - justification (string): Your reasoning. | ||||||
|  |             - context (dict):  A dictionary containing relevant information. | ||||||
|  |             - sub_goal (string): The specific instruction for the tool. | ||||||
|  |             - tool_name (string): The EXACT name of the tool to use. | ||||||
|  | 
 | ||||||
|  |             Do NOT include any text outside of the JSON object. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |     def _get_verifier_prompt(self) -> str: | ||||||
|  |         """Get the prompt for the verifier agent (Improved with few-shot examples).""" | ||||||
|  |         return """You are the Context Verifier in the OctoTools framework. Your role is to analyze the current context | ||||||
|  |         and memory to determine if the problem is solved, if there are any inconsistencies, or if further steps are needed. | ||||||
|  | 
 | ||||||
|  |         Output your response in JSON format: | ||||||
|  |          | ||||||
|  |         Expected output structure: | ||||||
|  |         ```json | ||||||
|  |         { | ||||||
|  |             "completeness": "Indicate whether the query is fully, partially, or not answered.", | ||||||
|  |             "inconsistencies": "List any inconsistencies found in the context or memory.", | ||||||
|  |             "verification_needs": "List any information that needs further verification.", | ||||||
|  |             "ambiguities": "List any ambiguities found in the context or memory.", | ||||||
|  |             "stop_signal": true/false | ||||||
|  |         } | ||||||
|  |         ``` | ||||||
|  | 
 | ||||||
|  |         Example Input: | ||||||
|  |         Context: { "last_result": { "result": "Caption: The image shows a cat." } } | ||||||
|  |         Memory: [ { "component": "Action Predictor", "result": { "tool_name": "Image_Captioner_Tool" } } ] | ||||||
|  | 
 | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         { | ||||||
|  |             "completeness": "partial", | ||||||
|  |             "inconsistencies": [], | ||||||
|  |             "verification_needs": ["Object detection to confirm the presence of a cat."], | ||||||
|  |             "ambiguities": [], | ||||||
|  |             "stop_signal": false | ||||||
|  |         } | ||||||
|  |         ``` | ||||||
|  | 
 | ||||||
|  |         Another Example: | ||||||
|  |         Context: { "last_result": { "result": ["Detected object: cat"] } } | ||||||
|  |         Memory:  [ { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } } ] | ||||||
|  |          | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         { | ||||||
|  |             "completeness": "yes", | ||||||
|  |             "inconsistencies": [], | ||||||
|  |             "verification_needs": [], | ||||||
|  |             "ambiguities": [], | ||||||
|  |             "stop_signal": true | ||||||
|  |         } | ||||||
|  |         ``` | ||||||
|  |          | ||||||
|  |         Square Root Example: | ||||||
|  |         Context: {  | ||||||
|  |             "query": "What is the square root of the number of objects in the image?",  | ||||||
|  |             "image": "example.png", | ||||||
|  |             "Object_Detector_Tool_result": ["object1", "object2", "object3", "object4"], | ||||||
|  |             "Python_Calculator_Tool_result": "Result of 4**0.5 is 2.0" | ||||||
|  |         } | ||||||
|  |         Memory: [ | ||||||
|  |             { "component": "Action Predictor", "result": { "tool_name": "Object_Detector_Tool" } }, | ||||||
|  |             { "component": "Action Predictor", "result": { "tool_name": "Python_Calculator_Tool" } } | ||||||
|  |         ] | ||||||
|  |          | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         { | ||||||
|  |             "completeness": "yes", | ||||||
|  |             "inconsistencies": [], | ||||||
|  |             "verification_needs": [], | ||||||
|  |             "ambiguities": [], | ||||||
|  |             "stop_signal": true | ||||||
|  |         } | ||||||
|  |         ``` | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |     def _get_summarizer_prompt(self) -> str: | ||||||
|  |         """Get the prompt for the summarizer agent (Improved with few-shot examples).""" | ||||||
|  |         return """You are the Solution Summarizer in the OctoTools framework.  Your role is to synthesize the final | ||||||
|  |         answer to the user's query based on the complete trajectory of actions and results. | ||||||
|  | 
 | ||||||
|  |         Output your response in JSON format: | ||||||
|  | 
 | ||||||
|  |         Expected output structure: | ||||||
|  |          ```json | ||||||
|  |         { | ||||||
|  |             "final_answer": "Provide a clear and concise answer to the original query." | ||||||
|  |         } | ||||||
|  |         ``` | ||||||
|  |         Example Input: | ||||||
|  |         Memory: [ | ||||||
|  |             {"component": "Query Analyzer", "result": {"summary": "Find the capital of France."}}, | ||||||
|  |             {"component": "Action Predictor", "result": {"tool_name": "Web_Search_Tool"}}, | ||||||
|  |             {"component": "Tool Execution", "result": {"result": "The capital of France is Paris."}} | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         { | ||||||
|  |             "final_answer": "The capital of France is Paris." | ||||||
|  |         } | ||||||
|  |         ``` | ||||||
|  |          | ||||||
|  |         Square Root Example: | ||||||
|  |         Memory: [ | ||||||
|  |             {"component": "Query Analyzer", "result": {"summary": "Find the square root of the number of objects in the image."}}, | ||||||
|  |             {"component": "Action Predictor", "result": {"tool_name": "Object_Detector_Tool", "sub_goal": "Detect objects in the image"}}, | ||||||
|  |             {"component": "Tool Execution", "result": {"result": ["object1", "object2", "object3", "object4"]}}, | ||||||
|  |             {"component": "Action Predictor", "result": {"tool_name": "Python_Calculator_Tool", "sub_goal": "Calculate the square root of 4"}}, | ||||||
|  |             {"component": "Tool Execution", "result": {"result": "Result of 4**0.5 is 2.0"}} | ||||||
|  |         ] | ||||||
|  |          | ||||||
|  |         Example Output: | ||||||
|  |         ```json | ||||||
|  |         { | ||||||
|  |             "final_answer": "The square root of the number of objects in the image is 2.0. There are 4 objects in the image, and the square root of 4 is 2.0." | ||||||
|  |         } | ||||||
|  |         ``` | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |     def _safely_parse_json(self, json_str: str) -> Dict[str, Any]: | ||||||
|  |         """Safely parse JSON, handling errors and using recursive descent.""" | ||||||
|  |         try: | ||||||
|  |             return json.loads(json_str) | ||||||
|  |         except json.JSONDecodeError: | ||||||
|  |             logger.warning(f"JSONDecodeError: Attempting to extract JSON from: {json_str}") | ||||||
|  |             try: | ||||||
|  |                 # More robust JSON extraction with recursive descent | ||||||
|  |                 def extract_json(s): | ||||||
|  |                     stack = [] | ||||||
|  |                     start = -1 | ||||||
|  |                     for i, c in enumerate(s): | ||||||
|  |                         if c == '{': | ||||||
|  |                             if not stack: | ||||||
|  |                                 start = i | ||||||
|  |                             stack.append(c) | ||||||
|  |                         elif c == '}': | ||||||
|  |                             if stack: | ||||||
|  |                                 stack.pop() | ||||||
|  |                                 if not stack and start != -1: | ||||||
|  |                                     return s[start:i+1] | ||||||
|  |                     return None | ||||||
|  | 
 | ||||||
|  |                 extracted_json = extract_json(json_str) | ||||||
|  |                 if extracted_json: | ||||||
|  |                     logger.info(f"Extracted JSON: {extracted_json}") | ||||||
|  |                     return json.loads(extracted_json) | ||||||
|  |                 else: | ||||||
|  |                     logger.error("Failed to extract JSON using recursive descent.") | ||||||
|  |                     return {"error": "Failed to parse JSON", "content": json_str} | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.exception(f"Error during JSON extraction: {e}") | ||||||
|  |                 return {"error": "Failed to parse JSON", "content": json_str} | ||||||
|  | 
 | ||||||
|  |     def _execute_tool(self, tool_name: str, context: Dict[str, Any]) -> Dict[str, Any]: | ||||||
|  |         """Executes a tool based on its name and provided context.""" | ||||||
|  |         if tool_name not in self.tools: | ||||||
|  |             return {"error": f"Tool '{tool_name}' not found."} | ||||||
|  | 
 | ||||||
|  |         tool = self.tools[tool_name] | ||||||
|  |         try: | ||||||
|  |             # For Python Calculator tool, handle object counts from Object Detector | ||||||
|  |             if tool_name == "Python_Calculator_Tool": | ||||||
|  |                 # Check for object detector results | ||||||
|  |                 object_detector_result = context.get("Object_Detector_Tool_result") | ||||||
|  |                 if object_detector_result and isinstance(object_detector_result, list): | ||||||
|  |                     # Calculate the number of objects | ||||||
|  |                     num_objects = len(object_detector_result) | ||||||
|  |                     # If sub_goal doesn't already contain an expression, create one | ||||||
|  |                     if "sub_goal" in context and "Calculate the square root" in context["sub_goal"]: | ||||||
|  |                         context["expression"] = f"{num_objects}**0.5" | ||||||
|  |                     elif "expression" not in context: | ||||||
|  |                         # Default to square root if no expression is specified | ||||||
|  |                         context["expression"] = f"{num_objects}**0.5" | ||||||
|  |                  | ||||||
|  |             # Filter context: only pass expected inputs to the tool | ||||||
|  |             valid_inputs = { | ||||||
|  |                 k: v for k, v in context.items() if k in tool.metadata.get("input_types", {}) | ||||||
|  |             } | ||||||
|  |             result = tool.execute(**valid_inputs) | ||||||
|  |             return {"result": result} | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.exception(f"Error executing tool {tool_name}: {e}") | ||||||
|  |             return {"error": str(e)} | ||||||
|  | 
 | ||||||
|  |     def _run_agent(self, agent: Agent, input_prompt: str) -> Dict[str, Any]: | ||||||
|  |         """Runs a swarms agent, handling output and JSON parsing.""" | ||||||
|  |         try: | ||||||
|  |             # Construct the full input, including the system prompt | ||||||
|  |             full_input = f"{agent.system_prompt}\n\n{input_prompt}" | ||||||
|  | 
 | ||||||
|  |             # Run the agent and capture the output | ||||||
|  |             agent_response = agent.run(full_input) | ||||||
|  | 
 | ||||||
|  |             logger.info(f"DEBUG: Raw agent response: {agent_response}") | ||||||
|  | 
 | ||||||
|  |             # Extract the LLM's response (remove conversation history, etc.) | ||||||
|  |             response_text = agent_response # Assuming direct return | ||||||
|  | 
 | ||||||
|  |             # Try to parse the response as JSON | ||||||
|  |             parsed_response = self._safely_parse_json(response_text) | ||||||
|  | 
 | ||||||
|  |             return parsed_response | ||||||
|  | 
 | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.exception(f"Error running agent {agent.agent_name}: {e}") | ||||||
|  |             return {"error": f"Agent {agent.agent_name} failed: {str(e)}"} | ||||||
|  | 
 | ||||||
|  |     def run(self, query: str, image: Optional[str] = None) -> Dict[str, Any]: | ||||||
|  |         """Execute the task through the multi-agent workflow.""" | ||||||
|  |         logger.info(f"Starting task: {query}") | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             # Step 1: Query Analysis (High-Level Planning) | ||||||
|  |             planner_input = ( | ||||||
|  |                 f"Analyze the following query and determine the necessary skills and" | ||||||
|  |                 f" relevant tools: {query}" | ||||||
|  |             ) | ||||||
|  |             query_analysis = self._run_agent(self.planner, planner_input) | ||||||
|  | 
 | ||||||
|  |             if "error" in query_analysis: | ||||||
|  |                 return { | ||||||
|  |                     "error": f"Planner query analysis failed: {query_analysis['error']}", | ||||||
|  |                     "trajectory": self.memory, | ||||||
|  |                     "conversation": self.conversation.return_history_as_string(), | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |             self.memory.append( | ||||||
|  |                 {"step": 0, "component": "Query Analyzer", "result": query_analysis} | ||||||
|  |             ) | ||||||
|  |             self.conversation.add( | ||||||
|  |                 role=self.planner.agent_name, content=json.dumps(query_analysis) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             # Initialize context with the query and image (if provided) | ||||||
|  |             context = {"query": query} | ||||||
|  |             if image: | ||||||
|  |                 context["image"] = image | ||||||
|  |                  | ||||||
|  |             # Add available tools to context | ||||||
|  |             if "relevant_tools" in query_analysis: | ||||||
|  |                 context["available_tools"] = query_analysis["relevant_tools"] | ||||||
|  |             else: | ||||||
|  |                 # If no relevant tools specified, make all tools available | ||||||
|  |                 context["available_tools"] = list(self.tools.keys()) | ||||||
|  | 
 | ||||||
|  |             step_count = 1 | ||||||
|  | 
 | ||||||
|  |             # Step 2: Iterative Action-Execution Loop | ||||||
|  |             while step_count <= self.max_iterations: | ||||||
|  |                 logger.info(f"Starting iteration {step_count} of {self.max_iterations}") | ||||||
|  |                  | ||||||
|  |                 # Step 2a: Action Prediction (Low-Level Planning) | ||||||
|  |                 action_planner_input = ( | ||||||
|  |                     f"Current Context: {json.dumps(context)}\nAvailable Tools:" | ||||||
|  |                     f" {', '.join(context.get('available_tools', list(self.tools.keys())))}\nPlan the" | ||||||
|  |                     " next step." | ||||||
|  |                 ) | ||||||
|  |                 action = self._run_agent(self.planner, action_planner_input) | ||||||
|  |                 if "error" in action: | ||||||
|  |                     logger.error(f"Error in action prediction: {action['error']}") | ||||||
|  |                     return { | ||||||
|  |                         "error": f"Planner action prediction failed: {action['error']}", | ||||||
|  |                         "trajectory": self.memory, | ||||||
|  |                         "conversation": self.conversation.return_history_as_string() | ||||||
|  |                     } | ||||||
|  |                 self.memory.append( | ||||||
|  |                     {"step": step_count, "component": "Action Predictor", "result": action} | ||||||
|  |                 ) | ||||||
|  |                 self.conversation.add(role=self.planner.agent_name, content=json.dumps(action)) | ||||||
|  | 
 | ||||||
|  |                 # Input Validation for Action (Relaxed) | ||||||
|  |                 if not isinstance(action, dict) or "tool_name" not in action or "sub_goal" not in action: | ||||||
|  |                     error_msg = ( | ||||||
|  |                         "Action prediction did not return required fields (tool_name," | ||||||
|  |                         " sub_goal) or was not a dictionary." | ||||||
|  |                     ) | ||||||
|  |                     logger.error(error_msg) | ||||||
|  |                     self.memory.append( | ||||||
|  |                         {"step": step_count, "component": "Error", "result": error_msg} | ||||||
|  |                     ) | ||||||
|  |                     break | ||||||
|  | 
 | ||||||
|  |                 # Step 2b: Execute Tool | ||||||
|  |                 tool_execution_context = { | ||||||
|  |                     **context, | ||||||
|  |                     **action.get("context", {}),  # Add any additional context | ||||||
|  |                     "sub_goal": action["sub_goal"],  # Pass sub_goal to tool | ||||||
|  |                 } | ||||||
|  |                  | ||||||
|  |                 tool_result = self._execute_tool(action["tool_name"], tool_execution_context) | ||||||
|  |                  | ||||||
|  |                 self.memory.append( | ||||||
|  |                     { | ||||||
|  |                         "step": step_count, | ||||||
|  |                         "component": "Tool Execution", | ||||||
|  |                         "result": tool_result, | ||||||
|  |                     } | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |                 # Step 2c: Context Update - Store result with a descriptive key | ||||||
|  |                 if "result" in tool_result: | ||||||
|  |                     context[f"{action['tool_name']}_result"] = tool_result["result"] | ||||||
|  |                 if "error" in tool_result: | ||||||
|  |                     context[f"{action['tool_name']}_error"] = tool_result["error"] | ||||||
|  | 
 | ||||||
|  |                 # Step 2d: Context Verification | ||||||
|  |                 verifier_input = ( | ||||||
|  |                     f"Current Context: {json.dumps(context)}\nMemory:" | ||||||
|  |                     f" {json.dumps(self.memory)}\nQuery: {query}" | ||||||
|  |                 ) | ||||||
|  |                 verification = self._run_agent(self.verifier, verifier_input) | ||||||
|  |                 if "error" in verification: | ||||||
|  |                     return { | ||||||
|  |                         "error": f"Verifier failed: {verification['error']}", | ||||||
|  |                         "trajectory": self.memory, | ||||||
|  |                         "conversation": self.conversation.return_history_as_string(), | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                 self.memory.append( | ||||||
|  |                     { | ||||||
|  |                         "step": step_count, | ||||||
|  |                         "component": "Context Verifier", | ||||||
|  |                         "result": verification, | ||||||
|  |                     } | ||||||
|  |                 ) | ||||||
|  |                 self.conversation.add(role=self.verifier.agent_name, content=json.dumps(verification)) | ||||||
|  | 
 | ||||||
|  |                 # Check for stop signal from Verifier | ||||||
|  |                 if verification.get("stop_signal") is True: | ||||||
|  |                     logger.info("Received stop signal from verifier. Stopping iterations.") | ||||||
|  |                     break | ||||||
|  | 
 | ||||||
|  |                 # Safety mechanism - if we've executed the same tool multiple times | ||||||
|  |                 same_tool_count = sum( | ||||||
|  |                     1 for m in self.memory  | ||||||
|  |                     if m.get("component") == "Action Predictor"  | ||||||
|  |                     and m.get("result", {}).get("tool_name") == action.get("tool_name") | ||||||
|  |                 ) | ||||||
|  |                  | ||||||
|  |                 if same_tool_count > 3: | ||||||
|  |                     logger.warning(f"Tool {action.get('tool_name')} used more than 3 times. Forcing stop.") | ||||||
|  |                     break | ||||||
|  | 
 | ||||||
|  |                 step_count += 1 | ||||||
|  | 
 | ||||||
|  |             # Step 3: Solution Summarization | ||||||
|  |             summarizer_input = f"Complete Trajectory: {json.dumps(self.memory)}\nOriginal Query: {query}" | ||||||
|  | 
 | ||||||
|  |             summarization = self._run_agent(self.summarizer, summarizer_input) | ||||||
|  |             if "error" in summarization: | ||||||
|  |                 return { | ||||||
|  |                     "error": f"Summarizer failed: {summarization['error']}", | ||||||
|  |                     "trajectory": self.memory, | ||||||
|  |                     "conversation": self.conversation.return_history_as_string() | ||||||
|  |                 } | ||||||
|  |             self.conversation.add(role=self.summarizer.agent_name, content=json.dumps(summarization)) | ||||||
|  | 
 | ||||||
|  |             return { | ||||||
|  |                 "final_answer": summarization.get("final_answer", "No answer found."), | ||||||
|  |                 "trajectory": self.memory, | ||||||
|  |                 "conversation": self.conversation.return_history_as_string(), | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.exception(f"Unexpected error in run method: {e}")  # More detailed | ||||||
|  |             return { | ||||||
|  |                 "error": str(e), | ||||||
|  |                 "trajectory": self.memory, | ||||||
|  |                 "conversation": self.conversation.return_history_as_string(), | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |     def save_state(self) -> None: | ||||||
|  |         """Save the current state of all agents.""" | ||||||
|  |         for agent in [self.planner, self.verifier, self.summarizer]: | ||||||
|  |             try: | ||||||
|  |                 agent.save_state() | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error(f"Error saving state for {agent.agent_name}: {str(e)}") | ||||||
|  | 
 | ||||||
|  |     def load_state(self) -> None: | ||||||
|  |         """Load the saved state of all agents.""" | ||||||
|  |         for agent in [self.planner, self.verifier, self.summarizer]: | ||||||
|  |             try: | ||||||
|  |                 agent.load_state() | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error(f"Error loading state for {agent.agent_name}: {str(e)}") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # --- Example Usage --- | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Define dummy tool functions (replace with actual implementations) | ||||||
|  | def image_captioner_execute(image: str, prompt: str = "Describe the image", **kwargs) -> str: | ||||||
|  |     """Dummy image captioner.""" | ||||||
|  |     print(f"image_captioner_execute called with image: {image}, prompt: {prompt}") | ||||||
|  |     return f"Caption for {image}: A descriptive caption (dummy)."  # Simplified | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def object_detector_execute(image: str, labels: List[str] = [], **kwargs) -> List[str]: | ||||||
|  |     """Dummy object detector, handles missing labels gracefully.""" | ||||||
|  |     print(f"object_detector_execute called with image: {image}, labels: {labels}") | ||||||
|  |     if not labels: | ||||||
|  |         return ["object1", "object2", "object3", "object4"]  # Return default objects if no labels | ||||||
|  |     return [f"Detected {label}" for label in labels]  # Simplified | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def web_search_execute(query: str, **kwargs) -> str: | ||||||
|  |     """Dummy web search.""" | ||||||
|  |     print(f"web_search_execute called with query: {query}") | ||||||
|  |     return f"Search results for '{query}'..."  # Simplified | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def python_calculator_execute(expression: str, **kwargs) -> str: | ||||||
|  |     """Python calculator (using math module).""" | ||||||
|  |     print(f"python_calculator_execute called with: {expression}") | ||||||
|  |     try: | ||||||
|  |         # Safely evaluate only simple expressions involving numbers and basic operations | ||||||
|  |         if re.match(r"^[0-9+\-*/().\s]+$", expression): | ||||||
|  |             result = eval(expression, {"__builtins__": {}, "math": math}) | ||||||
|  |             return f"Result of {expression} is {result}" | ||||||
|  |         else: | ||||||
|  |             return "Error: Invalid expression for calculator." | ||||||
|  |     except Exception as e: | ||||||
|  |         return f"Error: {e}" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Create utility function to get default tools | ||||||
|  | def get_default_tools() -> List[Tool]: | ||||||
|  |     """Returns a list of default tools that can be used with OctoToolsSwarm.""" | ||||||
|  |     image_captioner = Tool( | ||||||
|  |         name="Image_Captioner_Tool", | ||||||
|  |         description="Generates a caption for an image.", | ||||||
|  |         metadata={ | ||||||
|  |             "input_types": {"image": "str", "prompt": "str"}, | ||||||
|  |             "output_type": "str", | ||||||
|  |             "limitations": "May struggle with complex scenes or ambiguous objects.", | ||||||
|  |             "best_practices": "Use with clear, well-lit images. Provide specific prompts for better results.", | ||||||
|  |         }, | ||||||
|  |         execute_func=image_captioner_execute, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     object_detector = Tool( | ||||||
|  |         name="Object_Detector_Tool", | ||||||
|  |         description="Detects objects in an image.", | ||||||
|  |         metadata={ | ||||||
|  |             "input_types": {"image": "str", "labels": "list"}, | ||||||
|  |             "output_type": "list", | ||||||
|  |             "limitations": "Accuracy depends on the quality of the image and the clarity of the objects.", | ||||||
|  |             "best_practices": "Provide a list of specific object labels to detect. Use high-resolution images.", | ||||||
|  |         }, | ||||||
|  |         execute_func=object_detector_execute, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     web_search = Tool( | ||||||
|  |         name="Web_Search_Tool", | ||||||
|  |         description="Performs a web search.", | ||||||
|  |         metadata={ | ||||||
|  |             "input_types": {"query": "str"},  | ||||||
|  |             "output_type": "str", | ||||||
|  |             "limitations": "May not find specific or niche information.", | ||||||
|  |             "best_practices": "Use specific and descriptive keywords for better results.", | ||||||
|  |         }, | ||||||
|  |         execute_func=web_search_execute, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     calculator = Tool( | ||||||
|  |         name="Python_Calculator_Tool", | ||||||
|  |         description="Evaluates a Python expression.", | ||||||
|  |         metadata={ | ||||||
|  |             "input_types": {"expression": "str"},  | ||||||
|  |             "output_type": "str", | ||||||
|  |             "limitations": "Cannot handle complex mathematical functions or libraries.", | ||||||
|  |             "best_practices": "Use for basic arithmetic and simple calculations.", | ||||||
|  |         }, | ||||||
|  |         execute_func=python_calculator_execute, | ||||||
|  |     ) | ||||||
|  |      | ||||||
|  |     return [image_captioner, object_detector, web_search, calculator] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Only execute the example when this script is run directly | ||||||
|  | # if __name__ == "__main__": | ||||||
|  | #     print("Running OctoToolsSwarm example...") | ||||||
|  |      | ||||||
|  | #     # Create an OctoToolsSwarm agent with default tools | ||||||
|  | #     tools = get_default_tools() | ||||||
|  | #     agent = OctoToolsSwarm(tools=tools) | ||||||
|  | 
 | ||||||
|  | #     # Example query | ||||||
|  | #     query = "What is the square root of the number of objects in this image?" | ||||||
|  |      | ||||||
|  | #     # Create a dummy image file for testing if it doesn't exist | ||||||
|  | #     image_path = "example.png" | ||||||
|  | #     if not os.path.exists(image_path): | ||||||
|  | #         with open(image_path, "w") as f: | ||||||
|  | #             f.write("Dummy image content") | ||||||
|  | #         print(f"Created dummy image file: {image_path}") | ||||||
|  |      | ||||||
|  | #     # Run the agent | ||||||
|  | #     result = agent.run(query, image=image_path) | ||||||
|  | 
 | ||||||
|  | #     # Display results | ||||||
|  | #     print("\n=== FINAL ANSWER ===") | ||||||
|  | #     print(result["final_answer"]) | ||||||
|  |      | ||||||
|  | #     print("\n=== TRAJECTORY SUMMARY ===") | ||||||
|  | #     for step in result["trajectory"]: | ||||||
|  | #         print(f"Step {step.get('step', 'N/A')}: {step.get('component', 'Unknown')}") | ||||||
|  |      | ||||||
|  | #     print("\nOctoToolsSwarm example completed.") | ||||||
| @ -1,580 +0,0 @@ | |||||||
| """ |  | ||||||
| TalkHier: A hierarchical multi-agent framework for content generation and refinement. |  | ||||||
| Implements structured communication and evaluation protocols. |  | ||||||
| """ |  | ||||||
| 
 |  | ||||||
| import json |  | ||||||
| import logging |  | ||||||
| from dataclasses import dataclass |  | ||||||
| from datetime import datetime |  | ||||||
| from enum import Enum |  | ||||||
| from pathlib import Path |  | ||||||
| from typing import Any, Dict, List, Optional, Union |  | ||||||
| 
 |  | ||||||
| from swarms import Agent |  | ||||||
| from swarms.structs.conversation import Conversation |  | ||||||
| 
 |  | ||||||
| logging.basicConfig(level=logging.INFO) |  | ||||||
| logger = logging.getLogger(__name__) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class AgentRole(Enum): |  | ||||||
|     """Defines the possible roles for agents in the system.""" |  | ||||||
| 
 |  | ||||||
|     SUPERVISOR = "supervisor" |  | ||||||
|     GENERATOR = "generator" |  | ||||||
|     EVALUATOR = "evaluator" |  | ||||||
|     REVISOR = "revisor" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @dataclass |  | ||||||
| class CommunicationEvent: |  | ||||||
|     """Represents a structured communication event between agents.""" |  | ||||||
| 
 |  | ||||||
|     message: str |  | ||||||
|     background: Optional[str] = None |  | ||||||
|     intermediate_output: Optional[Dict[str, Any]] = None |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class TalkHier: |  | ||||||
|     """ |  | ||||||
|     A hierarchical multi-agent system for content generation and refinement. |  | ||||||
| 
 |  | ||||||
|     Implements the TalkHier framework with structured communication protocols |  | ||||||
|     and hierarchical refinement processes. |  | ||||||
| 
 |  | ||||||
|     Attributes: |  | ||||||
|         max_iterations: Maximum number of refinement iterations |  | ||||||
|         quality_threshold: Minimum score required for content acceptance |  | ||||||
|         model_name: Name of the LLM model to use |  | ||||||
|         base_path: Path for saving agent states |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         max_iterations: int = 3, |  | ||||||
|         quality_threshold: float = 0.8, |  | ||||||
|         model_name: str = "gpt-4", |  | ||||||
|         base_path: Optional[str] = None, |  | ||||||
|         return_string: bool = False, |  | ||||||
|     ): |  | ||||||
|         """Initialize the TalkHier system.""" |  | ||||||
|         self.max_iterations = max_iterations |  | ||||||
|         self.quality_threshold = quality_threshold |  | ||||||
|         self.model_name = model_name |  | ||||||
|         self.return_string = return_string |  | ||||||
|         self.base_path = ( |  | ||||||
|             Path(base_path) if base_path else Path("./agent_states") |  | ||||||
|         ) |  | ||||||
|         self.base_path.mkdir(exist_ok=True) |  | ||||||
| 
 |  | ||||||
|         # Initialize agents |  | ||||||
|         self._init_agents() |  | ||||||
| 
 |  | ||||||
|         # Create conversation |  | ||||||
|         self.conversation = Conversation() |  | ||||||
| 
 |  | ||||||
|     def _safely_parse_json(self, json_str: str) -> Dict[str, Any]: |  | ||||||
|         """ |  | ||||||
|         Safely parse JSON string, handling various formats and potential errors. |  | ||||||
| 
 |  | ||||||
|         Args: |  | ||||||
|             json_str: String containing JSON data |  | ||||||
| 
 |  | ||||||
|         Returns: |  | ||||||
|             Parsed dictionary |  | ||||||
|         """ |  | ||||||
|         try: |  | ||||||
|             # Try direct JSON parsing |  | ||||||
|             return json.loads(json_str) |  | ||||||
|         except json.JSONDecodeError: |  | ||||||
|             try: |  | ||||||
|                 # Try extracting JSON from potential text wrapper |  | ||||||
|                 import re |  | ||||||
| 
 |  | ||||||
|                 json_match = re.search(r"\{.*\}", json_str, re.DOTALL) |  | ||||||
|                 if json_match: |  | ||||||
|                     return json.loads(json_match.group()) |  | ||||||
|                 # Try extracting from markdown code blocks |  | ||||||
|                 code_block_match = re.search( |  | ||||||
|                     r"```(?:json)?\s*(\{.*?\})\s*```", |  | ||||||
|                     json_str, |  | ||||||
|                     re.DOTALL, |  | ||||||
|                 ) |  | ||||||
|                 if code_block_match: |  | ||||||
|                     return json.loads(code_block_match.group(1)) |  | ||||||
|             except Exception as e: |  | ||||||
|                 logger.warning(f"Failed to extract JSON: {str(e)}") |  | ||||||
| 
 |  | ||||||
|             # Fallback: create structured dict from text |  | ||||||
|             return { |  | ||||||
|                 "content": json_str, |  | ||||||
|                 "metadata": { |  | ||||||
|                     "parsed": False, |  | ||||||
|                     "timestamp": str(datetime.now()), |  | ||||||
|                 }, |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|     def _init_agents(self) -> None: |  | ||||||
|         """Initialize all agents with their specific roles and prompts.""" |  | ||||||
|         # Main supervisor agent |  | ||||||
|         self.main_supervisor = Agent( |  | ||||||
|             agent_name="Main-Supervisor", |  | ||||||
|             system_prompt=self._get_supervisor_prompt(), |  | ||||||
|             model_name=self.model_name, |  | ||||||
|             max_loops=1, |  | ||||||
|             saved_state_path=str( |  | ||||||
|                 self.base_path / "main_supervisor.json" |  | ||||||
|             ), |  | ||||||
|             verbose=True, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         # Generator agent |  | ||||||
|         self.generator = Agent( |  | ||||||
|             agent_name="Content-Generator", |  | ||||||
|             system_prompt=self._get_generator_prompt(), |  | ||||||
|             model_name=self.model_name, |  | ||||||
|             max_loops=1, |  | ||||||
|             saved_state_path=str(self.base_path / "generator.json"), |  | ||||||
|             verbose=True, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         # Evaluators |  | ||||||
|         self.evaluators = [ |  | ||||||
|             Agent( |  | ||||||
|                 agent_name=f"Evaluator-{i}", |  | ||||||
|                 system_prompt=self._get_evaluator_prompt(i), |  | ||||||
|                 model_name=self.model_name, |  | ||||||
|                 max_loops=1, |  | ||||||
|                 saved_state_path=str( |  | ||||||
|                     self.base_path / f"evaluator_{i}.json" |  | ||||||
|                 ), |  | ||||||
|                 verbose=True, |  | ||||||
|             ) |  | ||||||
|             for i in range(3) |  | ||||||
|         ] |  | ||||||
| 
 |  | ||||||
|         # Revisor agent |  | ||||||
|         self.revisor = Agent( |  | ||||||
|             agent_name="Content-Revisor", |  | ||||||
|             system_prompt=self._get_revisor_prompt(), |  | ||||||
|             model_name=self.model_name, |  | ||||||
|             max_loops=1, |  | ||||||
|             saved_state_path=str(self.base_path / "revisor.json"), |  | ||||||
|             verbose=True, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|     def _get_supervisor_prompt(self) -> str: |  | ||||||
|         """Get the prompt for the supervisor agent.""" |  | ||||||
|         return """You are a Supervisor agent responsible for orchestrating the content generation process. Your role is to analyze tasks, develop strategies, and coordinate other agents effectively. |  | ||||||
| 
 |  | ||||||
| You must carefully analyze each task to understand: |  | ||||||
| - The core objectives and requirements |  | ||||||
| - Target audience and their needs |  | ||||||
| - Complexity level and scope |  | ||||||
| - Any constraints or special considerations |  | ||||||
| 
 |  | ||||||
| Based on your analysis, develop a clear strategy that: |  | ||||||
| - Breaks down the task into manageable steps |  | ||||||
| - Identifies which agents are best suited for each step |  | ||||||
| - Anticipates potential challenges |  | ||||||
| - Sets clear success criteria |  | ||||||
| 
 |  | ||||||
| Output all responses in strict JSON format: |  | ||||||
| { |  | ||||||
|     "thoughts": { |  | ||||||
|         "task_analysis": "Detailed analysis of requirements, audience, scope, and constraints", |  | ||||||
|         "strategy": "Step-by-step plan including agent allocation and success metrics", |  | ||||||
|         "concerns": "Potential challenges, edge cases, and mitigation strategies" |  | ||||||
|     }, |  | ||||||
|     "next_action": { |  | ||||||
|         "agent": "Specific agent to engage (Generator, Evaluator, or Revisor)", |  | ||||||
|         "instruction": "Detailed instructions including context, requirements, and expected output" |  | ||||||
|     } |  | ||||||
| }""" |  | ||||||
| 
 |  | ||||||
|     def _get_generator_prompt(self) -> str: |  | ||||||
|         """Get the prompt for the generator agent.""" |  | ||||||
|         return """You are a Generator agent responsible for creating high-quality, original content. Your role is to produce content that is engaging, informative, and tailored to the target audience. |  | ||||||
| 
 |  | ||||||
| When generating content: |  | ||||||
| - Thoroughly research and fact-check all information |  | ||||||
| - Structure content logically with clear flow |  | ||||||
| - Use appropriate tone and language for the target audience |  | ||||||
| - Include relevant examples and explanations |  | ||||||
| - Ensure content is original and plagiarism-free |  | ||||||
| - Consider SEO best practices where applicable |  | ||||||
| 
 |  | ||||||
| Output all responses in strict JSON format: |  | ||||||
| { |  | ||||||
|     "content": { |  | ||||||
|         "main_body": "The complete generated content with proper formatting and structure", |  | ||||||
|         "metadata": { |  | ||||||
|             "word_count": "Accurate word count of main body", |  | ||||||
|             "target_audience": "Detailed audience description", |  | ||||||
|             "key_points": ["List of main points covered"], |  | ||||||
|             "sources": ["List of reference sources if applicable"], |  | ||||||
|             "readability_level": "Estimated reading level", |  | ||||||
|             "tone": "Description of content tone" |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| }""" |  | ||||||
| 
 |  | ||||||
|     def _get_evaluator_prompt(self, evaluator_id: int) -> str: |  | ||||||
|         """Get the prompt for an evaluator agent.""" |  | ||||||
|         return f"""You are Evaluator {evaluator_id}, responsible for critically assessing content quality. Your evaluation must be thorough, objective, and constructive. |  | ||||||
| 
 |  | ||||||
| Evaluate content across multiple dimensions: |  | ||||||
| - Accuracy: factual correctness, source reliability |  | ||||||
| - Clarity: readability, organization, flow |  | ||||||
| - Coherence: logical consistency, argument structure |  | ||||||
| - Engagement: interest level, relevance |  | ||||||
| - Completeness: topic coverage, depth |  | ||||||
| - Technical quality: grammar, spelling, formatting |  | ||||||
| - Audience alignment: appropriate level and tone |  | ||||||
| 
 |  | ||||||
| Output all responses in strict JSON format: |  | ||||||
| {{ |  | ||||||
|     "scores": {{ |  | ||||||
|         "overall": "0.0-1.0 composite score", |  | ||||||
|         "categories": {{ |  | ||||||
|             "accuracy": "0.0-1.0 score with evidence", |  | ||||||
|             "clarity": "0.0-1.0 score with examples", |  | ||||||
|             "coherence": "0.0-1.0 score with analysis", |  | ||||||
|             "engagement": "0.0-1.0 score with justification", |  | ||||||
|             "completeness": "0.0-1.0 score with gaps identified", |  | ||||||
|             "technical_quality": "0.0-1.0 score with issues noted", |  | ||||||
|             "audience_alignment": "0.0-1.0 score with reasoning" |  | ||||||
|         }} |  | ||||||
|     }}, |  | ||||||
|     "feedback": [ |  | ||||||
|         "Specific, actionable improvement suggestions", |  | ||||||
|         "Examples of issues found", |  | ||||||
|         "Recommendations for enhancement" |  | ||||||
|     ], |  | ||||||
|     "strengths": ["Notable positive aspects"], |  | ||||||
|     "weaknesses": ["Areas needing improvement"] |  | ||||||
| }}""" |  | ||||||
| 
 |  | ||||||
|     def _get_revisor_prompt(self) -> str: |  | ||||||
|         """Get the prompt for the revisor agent.""" |  | ||||||
|         return """You are a Revisor agent responsible for improving content based on evaluator feedback. Your role is to enhance content while maintaining its core message and purpose. |  | ||||||
| 
 |  | ||||||
| When revising content: |  | ||||||
| - Address all evaluator feedback systematically |  | ||||||
| - Maintain consistency in tone and style |  | ||||||
| - Preserve accurate information |  | ||||||
| - Enhance clarity and flow |  | ||||||
| - Fix technical issues |  | ||||||
| - Optimize for target audience |  | ||||||
| - Track all changes made |  | ||||||
| 
 |  | ||||||
| Output all responses in strict JSON format: |  | ||||||
| { |  | ||||||
|     "revised_content": { |  | ||||||
|         "main_body": "Complete revised content incorporating all improvements", |  | ||||||
|         "metadata": { |  | ||||||
|             "word_count": "Updated word count", |  | ||||||
|             "changes_made": [ |  | ||||||
|                 "Detailed list of specific changes and improvements", |  | ||||||
|                 "Reasoning for each major revision", |  | ||||||
|                 "Feedback points addressed" |  | ||||||
|             ], |  | ||||||
|             "improvement_summary": "Overview of main enhancements", |  | ||||||
|             "preserved_elements": ["Key elements maintained from original"], |  | ||||||
|             "revision_approach": "Strategy used for revisions" |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| }""" |  | ||||||
| 
 |  | ||||||
|     def _evaluate_content( |  | ||||||
|         self, content: Union[str, Dict] |  | ||||||
|     ) -> Dict[str, Any]: |  | ||||||
|         """ |  | ||||||
|         Coordinate the evaluation of content across multiple evaluators. |  | ||||||
| 
 |  | ||||||
|         Args: |  | ||||||
|             content: Content to evaluate (string or dict) |  | ||||||
| 
 |  | ||||||
|         Returns: |  | ||||||
|             Combined evaluation results |  | ||||||
|         """ |  | ||||||
|         try: |  | ||||||
|             # Ensure content is in correct format |  | ||||||
|             content_dict = ( |  | ||||||
|                 self._safely_parse_json(content) |  | ||||||
|                 if isinstance(content, str) |  | ||||||
|                 else content |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             # Collect evaluations |  | ||||||
|             evaluations = [] |  | ||||||
|             for evaluator in self.evaluators: |  | ||||||
|                 try: |  | ||||||
|                     eval_response = evaluator.run( |  | ||||||
|                         json.dumps(content_dict) |  | ||||||
|                     ) |  | ||||||
| 
 |  | ||||||
|                     self.conversation.add( |  | ||||||
|                         role=evaluator.agent_name, |  | ||||||
|                         content=eval_response, |  | ||||||
|                     ) |  | ||||||
| 
 |  | ||||||
|                     eval_data = self._safely_parse_json(eval_response) |  | ||||||
|                     evaluations.append(eval_data) |  | ||||||
|                 except Exception as e: |  | ||||||
|                     logger.warning(f"Evaluator error: {str(e)}") |  | ||||||
|                     evaluations.append( |  | ||||||
|                         self._get_fallback_evaluation() |  | ||||||
|                     ) |  | ||||||
| 
 |  | ||||||
|             # Aggregate results |  | ||||||
|             return self._aggregate_evaluations(evaluations) |  | ||||||
| 
 |  | ||||||
|         except Exception as e: |  | ||||||
|             logger.error(f"Evaluation error: {str(e)}") |  | ||||||
|             return self._get_fallback_evaluation() |  | ||||||
| 
 |  | ||||||
|     def _get_fallback_evaluation(self) -> Dict[str, Any]: |  | ||||||
|         """Get a safe fallback evaluation result.""" |  | ||||||
|         return { |  | ||||||
|             "scores": { |  | ||||||
|                 "overall": 0.5, |  | ||||||
|                 "categories": { |  | ||||||
|                     "accuracy": 0.5, |  | ||||||
|                     "clarity": 0.5, |  | ||||||
|                     "coherence": 0.5, |  | ||||||
|                 }, |  | ||||||
|             }, |  | ||||||
|             "feedback": ["Evaluation failed"], |  | ||||||
|             "metadata": { |  | ||||||
|                 "timestamp": str(datetime.now()), |  | ||||||
|                 "is_fallback": True, |  | ||||||
|             }, |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     def _aggregate_evaluations( |  | ||||||
|         self, evaluations: List[Dict[str, Any]] |  | ||||||
|     ) -> Dict[str, Any]: |  | ||||||
|         """ |  | ||||||
|         Aggregate multiple evaluation results into a single evaluation. |  | ||||||
| 
 |  | ||||||
|         Args: |  | ||||||
|             evaluations: List of evaluation results |  | ||||||
| 
 |  | ||||||
|         Returns: |  | ||||||
|             Combined evaluation |  | ||||||
|         """ |  | ||||||
|         # Calculate average scores |  | ||||||
|         overall_scores = [] |  | ||||||
|         accuracy_scores = [] |  | ||||||
|         clarity_scores = [] |  | ||||||
|         coherence_scores = [] |  | ||||||
|         all_feedback = [] |  | ||||||
| 
 |  | ||||||
|         for eval_data in evaluations: |  | ||||||
|             try: |  | ||||||
|                 scores = eval_data.get("scores", {}) |  | ||||||
|                 overall_scores.append(scores.get("overall", 0.5)) |  | ||||||
| 
 |  | ||||||
|                 categories = scores.get("categories", {}) |  | ||||||
|                 accuracy_scores.append( |  | ||||||
|                     categories.get("accuracy", 0.5) |  | ||||||
|                 ) |  | ||||||
|                 clarity_scores.append(categories.get("clarity", 0.5)) |  | ||||||
|                 coherence_scores.append( |  | ||||||
|                     categories.get("coherence", 0.5) |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|                 all_feedback.extend(eval_data.get("feedback", [])) |  | ||||||
|             except Exception as e: |  | ||||||
|                 logger.warning( |  | ||||||
|                     f"Error aggregating evaluation: {str(e)}" |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|         def safe_mean(scores: List[float]) -> float: |  | ||||||
|             return sum(scores) / len(scores) if scores else 0.5 |  | ||||||
| 
 |  | ||||||
|         return { |  | ||||||
|             "scores": { |  | ||||||
|                 "overall": safe_mean(overall_scores), |  | ||||||
|                 "categories": { |  | ||||||
|                     "accuracy": safe_mean(accuracy_scores), |  | ||||||
|                     "clarity": safe_mean(clarity_scores), |  | ||||||
|                     "coherence": safe_mean(coherence_scores), |  | ||||||
|                 }, |  | ||||||
|             }, |  | ||||||
|             "feedback": list(set(all_feedback)),  # Remove duplicates |  | ||||||
|             "metadata": { |  | ||||||
|                 "evaluator_count": len(evaluations), |  | ||||||
|                 "timestamp": str(datetime.now()), |  | ||||||
|             }, |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     def run(self, task: str) -> Dict[str, Any]: |  | ||||||
|         """ |  | ||||||
|         Generate and iteratively refine content based on the given task. |  | ||||||
| 
 |  | ||||||
|         Args: |  | ||||||
|             task: Content generation task description |  | ||||||
| 
 |  | ||||||
|         Returns: |  | ||||||
|             Dictionary containing final content and metadata |  | ||||||
|         """ |  | ||||||
|         logger.info(f"Starting content generation for task: {task}") |  | ||||||
| 
 |  | ||||||
|         try: |  | ||||||
|             # Get initial direction from supervisor |  | ||||||
|             supervisor_response = self.main_supervisor.run(task) |  | ||||||
| 
 |  | ||||||
|             self.conversation.add( |  | ||||||
|                 role=self.main_supervisor.agent_name, |  | ||||||
|                 content=supervisor_response, |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             supervisor_data = self._safely_parse_json( |  | ||||||
|                 supervisor_response |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             # Generate initial content |  | ||||||
|             generator_response = self.generator.run( |  | ||||||
|                 json.dumps(supervisor_data.get("next_action", {})) |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             self.conversation.add( |  | ||||||
|                 role=self.generator.agent_name, |  | ||||||
|                 content=generator_response, |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             current_content = self._safely_parse_json( |  | ||||||
|                 generator_response |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|             for iteration in range(self.max_iterations): |  | ||||||
|                 logger.info(f"Starting iteration {iteration + 1}") |  | ||||||
| 
 |  | ||||||
|                 # Evaluate current content |  | ||||||
|                 evaluation = self._evaluate_content(current_content) |  | ||||||
| 
 |  | ||||||
|                 # Check if quality threshold is met |  | ||||||
|                 if ( |  | ||||||
|                     evaluation["scores"]["overall"] |  | ||||||
|                     >= self.quality_threshold |  | ||||||
|                 ): |  | ||||||
|                     logger.info( |  | ||||||
|                         "Quality threshold met, returning content" |  | ||||||
|                     ) |  | ||||||
|                     return { |  | ||||||
|                         "content": current_content.get( |  | ||||||
|                             "content", {} |  | ||||||
|                         ).get("main_body", ""), |  | ||||||
|                         "final_score": evaluation["scores"][ |  | ||||||
|                             "overall" |  | ||||||
|                         ], |  | ||||||
|                         "iterations": iteration + 1, |  | ||||||
|                         "metadata": { |  | ||||||
|                             "content_metadata": current_content.get( |  | ||||||
|                                 "content", {} |  | ||||||
|                             ).get("metadata", {}), |  | ||||||
|                             "evaluation": evaluation, |  | ||||||
|                         }, |  | ||||||
|                     } |  | ||||||
| 
 |  | ||||||
|                 # Revise content if needed |  | ||||||
|                 revision_input = { |  | ||||||
|                     "content": current_content, |  | ||||||
|                     "evaluation": evaluation, |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 revision_response = self.revisor.run( |  | ||||||
|                     json.dumps(revision_input) |  | ||||||
|                 ) |  | ||||||
|                 current_content = self._safely_parse_json( |  | ||||||
|                     revision_response |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|                 self.conversation.add( |  | ||||||
|                     role=self.revisor.agent_name, |  | ||||||
|                     content=revision_response, |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|             logger.warning( |  | ||||||
|                 "Max iterations reached without meeting quality threshold" |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         except Exception as e: |  | ||||||
|             logger.error(f"Error in generate_and_refine: {str(e)}") |  | ||||||
|             current_content = { |  | ||||||
|                 "content": {"main_body": f"Error: {str(e)}"} |  | ||||||
|             } |  | ||||||
|             evaluation = self._get_fallback_evaluation() |  | ||||||
| 
 |  | ||||||
|         if self.return_string: |  | ||||||
|             return self.conversation.return_history_as_string() |  | ||||||
|         else: |  | ||||||
|             return { |  | ||||||
|                 "content": current_content.get("content", {}).get( |  | ||||||
|                     "main_body", "" |  | ||||||
|                 ), |  | ||||||
|                 "final_score": evaluation["scores"]["overall"], |  | ||||||
|                 "iterations": self.max_iterations, |  | ||||||
|                 "metadata": { |  | ||||||
|                     "content_metadata": current_content.get( |  | ||||||
|                         "content", {} |  | ||||||
|                     ).get("metadata", {}), |  | ||||||
|                     "evaluation": evaluation, |  | ||||||
|                     "error": "Max iterations reached", |  | ||||||
|                 }, |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|     def save_state(self) -> None: |  | ||||||
|         """Save the current state of all agents.""" |  | ||||||
|         for agent in [ |  | ||||||
|             self.main_supervisor, |  | ||||||
|             self.generator, |  | ||||||
|             *self.evaluators, |  | ||||||
|             self.revisor, |  | ||||||
|         ]: |  | ||||||
|             try: |  | ||||||
|                 agent.save_state() |  | ||||||
|             except Exception as e: |  | ||||||
|                 logger.error( |  | ||||||
|                     f"Error saving state for {agent.agent_name}: {str(e)}" |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|     def load_state(self) -> None: |  | ||||||
|         """Load the saved state of all agents.""" |  | ||||||
|         for agent in [ |  | ||||||
|             self.main_supervisor, |  | ||||||
|             self.generator, |  | ||||||
|             *self.evaluators, |  | ||||||
|             self.revisor, |  | ||||||
|         ]: |  | ||||||
|             try: |  | ||||||
|                 agent.load_state() |  | ||||||
|             except Exception as e: |  | ||||||
|                 logger.error( |  | ||||||
|                     f"Error loading state for {agent.agent_name}: {str(e)}" |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     # Example usage |  | ||||||
|     try: |  | ||||||
|         talkhier = TalkHier( |  | ||||||
|             max_iterations=1, |  | ||||||
|             quality_threshold=0.8, |  | ||||||
|             model_name="gpt-4o", |  | ||||||
|             return_string=True, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         task = "Write a comprehensive explanation of quantum computing for beginners" |  | ||||||
|         result = talkhier.run(task) |  | ||||||
|         print(result) |  | ||||||
| 
 |  | ||||||
|         # print(f"Final content: {result['content']}") |  | ||||||
|         # print(f"Quality score: {result['final_score']}") |  | ||||||
|         # print(f"Iterations: {result['iterations']}") |  | ||||||
| 
 |  | ||||||
|     except Exception as e: |  | ||||||
|         logger.error(f"Error in main execution: {str(e)}") |  | ||||||
					Loading…
					
					
				
		Reference in new issue