From 8b4d593de51ecae841cb9820f343011e341be366 Mon Sep 17 00:00:00 2001
From: harshalmore31 <86048671+harshalmore31@users.noreply.github.com>
Date: Thu, 20 Feb 2025 23:58:43 +0530
Subject: [PATCH] Enhanced Talk hier :  A hierarchical multi-agent framework

---
 swarms/structs/Talk_Hier.py | 682 ++++++++++++++++++++++++++++++++++++
 1 file changed, 682 insertions(+)
 create mode 100644 swarms/structs/Talk_Hier.py

diff --git a/swarms/structs/Talk_Hier.py b/swarms/structs/Talk_Hier.py
new file mode 100644
index 00000000..23048ce8
--- /dev/null
+++ b/swarms/structs/Talk_Hier.py
@@ -0,0 +1,682 @@
+"""
+TalkHier: A hierarchical multi-agent framework for content generation and refinement.
+Implements structured communication and evaluation protocols.
+"""
+
+import json
+import logging
+from dataclasses import dataclass, asdict
+from datetime import datetime
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+from concurrent.futures import ThreadPoolExecutor
+
+from swarms import Agent
+from swarms.structs.conversation import Conversation
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class AgentRole(Enum):
+    """Defines the possible roles for agents in the system."""
+
+    SUPERVISOR = "supervisor"
+    GENERATOR = "generator"
+    EVALUATOR = "evaluator"
+    REVISOR = "revisor"
+
+
+@dataclass
+class CommunicationEvent:
+    """Represents a structured communication event between agents."""
+    message: str 
+    background: Optional[str] = None
+    intermediate_output: Optional[Dict[str, Any]] = None
+    sender: str = ""
+    receiver: str = ""
+    timestamp: str = str(datetime.now())
+
+
+class TalkHier:
+    """
+    A hierarchical multi-agent system for content generation and refinement.
+
+    Implements the TalkHier framework with structured communication protocols
+    and hierarchical refinement processes.
+
+    Attributes:
+        max_iterations: Maximum number of refinement iterations
+        quality_threshold: Minimum score required for content acceptance
+        model_name: Name of the LLM model to use
+        base_path: Path for saving agent states
+    """
+
+    def __init__(
+        self,
+        max_iterations: int = 3,
+        quality_threshold: float = 0.8,
+        model_name: str = "gpt-4",
+        base_path: Optional[str] = None,
+        return_string: bool = False,
+    ):
+        """Initialize the TalkHier system."""
+        self.max_iterations = max_iterations
+        self.quality_threshold = quality_threshold
+        self.model_name = model_name
+        self.return_string = return_string
+        self.base_path = (
+            Path(base_path) if base_path else Path("./agent_states")
+        )
+        self.base_path.mkdir(exist_ok=True)
+
+        # Initialize agents
+        self._init_agents()
+
+        # Create conversation
+        self.conversation = Conversation()
+
+    def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
+        """
+        Safely parse JSON string, handling various formats and potential errors.
+
+        Args:
+            json_str: String containing JSON data
+
+        Returns:
+            Parsed dictionary
+        """
+        try:
+            # Try direct JSON parsing
+            return json.loads(json_str)
+        except json.JSONDecodeError:
+            try:
+                # Try extracting JSON from potential text wrapper
+                import re
+
+                json_match = re.search(r"\{.*\}", json_str, re.DOTALL)
+                if (json_match):
+                    return json.loads(json_match.group())
+                # Try extracting from markdown code blocks
+                code_block_match = re.search(
+                    r"```(?:json)?\s*(\{.*?\})\s*```",
+                    json_str,
+                    re.DOTALL,
+                )
+                if (code_block_match):
+                    return json.loads(code_block_match.group(1))
+            except Exception as e:
+                logger.warning(f"Failed to extract JSON: {str(e)}")
+
+            # Fallback: create structured dict from text
+            return {
+                "content": json_str,
+                "metadata": {
+                    "parsed": False,
+                    "timestamp": str(datetime.now()),
+                },
+            }
+
+    def _get_criteria_generator_prompt(self) -> str:
+        """Get the prompt for the criteria generator agent."""
+        return """You are a Criteria Generator agent responsible for creating task-specific evaluation criteria.
+Analyze the task and generate appropriate evaluation criteria based on:
+- Task type and complexity
+- Required domain knowledge
+- Target audience expectations
+- Quality requirements
+
+Output all responses in strict JSON format:
+{
+    "criteria": {
+        "criterion_name": {
+            "description": "Detailed description of what this criterion measures",
+            "importance": "Weight from 0.0-1.0 indicating importance",
+            "evaluation_guide": "Guidelines for how to evaluate this criterion"
+        }
+    },
+    "metadata": {
+        "task_type": "Classification of the task type",
+        "complexity_level": "Assessment of task complexity",
+        "domain_focus": "Primary domain or field of the task"
+    }
+}"""
+
+    def _init_agents(self) -> None:
+        """Initialize all agents with their specific roles and prompts."""
+        # Main supervisor agent
+        self.main_supervisor = Agent(
+            agent_name="Main-Supervisor",
+            system_prompt=self._get_supervisor_prompt(),
+            model_name=self.model_name,
+            max_loops=1,
+            saved_state_path=str(
+                self.base_path / "main_supervisor.json"
+            ),
+            verbose=True,
+        )
+
+        # Generator agent
+        self.generator = Agent(
+            agent_name="Content-Generator",
+            system_prompt=self._get_generator_prompt(),
+            model_name=self.model_name,
+            max_loops=1,
+            saved_state_path=str(self.base_path / "generator.json"),
+            verbose=True,
+        )
+
+        # Criteria Generator agent
+        self.criteria_generator = Agent(
+            agent_name="Criteria-Generator",
+            system_prompt=self._get_criteria_generator_prompt(),
+            model_name=self.model_name,
+            max_loops=1,
+            saved_state_path=str(self.base_path / "criteria_generator.json"),
+            verbose=True,
+        )
+
+        # Evaluators without criteria (will be set during run)
+        self.evaluators = []
+        for i in range(3):
+            self.evaluators.append(
+                Agent(
+                    agent_name=f"Evaluator-{i}",
+                    system_prompt=self._get_evaluator_prompt(i),
+                    model_name=self.model_name,
+                    max_loops=1,
+                    saved_state_path=str(
+                        self.base_path / f"evaluator_{i}.json"
+                    ),
+                    verbose=True,
+                )
+            )
+
+        # Revisor agent
+        self.revisor = Agent(
+            agent_name="Content-Revisor",
+            system_prompt=self._get_revisor_prompt(),
+            model_name=self.model_name,
+            max_loops=1,
+            saved_state_path=str(self.base_path / "revisor.json"),
+            verbose=True,
+        )
+
+    def _generate_dynamic_criteria(self, task: str) -> Dict[str, str]:
+        """
+        Generate dynamic evaluation criteria based on the task.
+
+        Args:
+            task: Content generation task description
+
+        Returns:
+            Dictionary containing dynamic evaluation criteria
+        """
+        # Example dynamic criteria generation logic
+        if "technical" in task.lower():
+            return {
+                "accuracy": "Technical correctness and source reliability",
+                "clarity": "Readability and logical structure",
+                "depth": "Comprehensive coverage of technical details",
+                "engagement": "Interest level and relevance to the audience",
+                "technical_quality": "Grammar, spelling, and formatting",
+            }
+        else:
+            return {
+                "accuracy": "Factual correctness and source reliability",
+                "clarity": "Readability and logical structure",
+                "coherence": "Logical consistency and argument structure",
+                "engagement": "Interest level and relevance to the audience",
+                "completeness": "Coverage of the topic and depth",
+                "technical_quality": "Grammar, spelling, and formatting",
+            }
+
+    def _get_supervisor_prompt(self) -> str:
+        """Get the prompt for the supervisor agent."""
+        return """You are a Supervisor agent responsible for orchestrating the content generation process and selecting the best evaluation criteria.
+
+You must:
+1. Analyze tasks and develop strategies
+2. Review multiple evaluator feedback
+3. Select the most appropriate evaluation based on:
+   - Completeness of criteria
+   - Relevance to task
+   - Quality of feedback
+4. Provide clear instructions for content revision
+
+Output all responses in strict JSON format:
+{
+    "thoughts": {
+        "task_analysis": "Analysis of requirements, audience, scope",
+        "strategy": "Step-by-step plan and success metrics",
+        "evaluation_selection": {
+            "chosen_evaluator": "ID of selected evaluator",
+            "reasoning": "Why this evaluation was chosen",
+            "key_criteria": ["List of most important criteria"]
+        }
+    },
+    "next_action": {
+        "agent": "Next agent to engage",
+        "instruction": "Detailed instructions with context"
+    }
+}"""
+
+    def _get_generator_prompt(self) -> str:
+        """Get the prompt for the generator agent."""
+        return """You are a Generator agent responsible for creating high-quality, original content. Your role is to produce content that is engaging, informative, and tailored to the target audience.
+
+When generating content:
+- Thoroughly research and fact-check all information
+- Structure content logically with clear flow
+- Use appropriate tone and language for the target audience
+- Include relevant examples and explanations
+- Ensure content is original and plagiarism-free
+- Consider SEO best practices where applicable
+
+Output all responses in strict JSON format:
+{
+    "content": {
+        "main_body": "The complete generated content with proper formatting and structure",
+        "metadata": {
+            "word_count": "Accurate word count of main body",
+            "target_audience": "Detailed audience description",
+            "key_points": ["List of main points covered"],
+            "sources": ["List of reference sources if applicable"],
+            "readability_level": "Estimated reading level",
+            "tone": "Description of content tone"
+        }
+    }
+}"""
+
+    def _get_evaluator_prompt(self, evaluator_id: int) -> str:
+        """Get the base prompt for an evaluator agent."""
+        return f"""You are Evaluator {evaluator_id}, responsible for critically assessing content quality. Your evaluation must be thorough, objective, and constructive.
+
+When receiving content to evaluate:
+1. First analyze the task description to determine appropriate evaluation criteria
+2. Generate specific criteria based on task requirements
+3. Evaluate content against these criteria
+4. Provide detailed feedback for each criterion
+
+Output all responses in strict JSON format:
+{{
+    "generated_criteria": {{
+        "criteria_name": "description of what this criterion measures",
+        // Add more criteria based on task analysis
+    }},
+    "scores": {{
+        "overall": "0.0-1.0 composite score",
+        "categories": {{
+            // Scores for each generated criterion
+            "criterion_name": "0.0-1.0 score with evidence"
+        }}
+    }},
+    "feedback": [
+        "Specific, actionable improvement suggestions per criterion"
+    ],
+    "strengths": ["Notable positive aspects"],
+    "weaknesses": ["Areas needing improvement"]
+}}"""
+
+    def _get_revisor_prompt(self) -> str:
+        """Get the prompt for the revisor agent."""
+        return """You are a Revisor agent responsible for improving content based on evaluator feedback. Your role is to enhance content while maintaining its core message and purpose.
+
+When revising content:
+- Address all evaluator feedback systematically
+- Maintain consistency in tone and style
+- Preserve accurate information
+- Enhance clarity and flow
+- Fix technical issues
+- Optimize for target audience
+- Track all changes made
+
+Output all responses in strict JSON format:
+{
+    "revised_content": {
+        "main_body": "Complete revised content incorporating all improvements",
+        "metadata": {
+            "word_count": "Updated word count",
+            "changes_made": [
+                "Detailed list of specific changes and improvements",
+                "Reasoning for each major revision",
+                "Feedback points addressed"
+            ],
+            "improvement_summary": "Overview of main enhancements",
+            "preserved_elements": ["Key elements maintained from original"],
+            "revision_approach": "Strategy used for revisions"
+        }
+    }
+}"""
+
+    def _generate_criteria_for_task(self, task: str) -> Dict[str, Any]:
+        """Generate evaluation criteria for the given task."""
+        try:
+            criteria_input = {
+                "task": task,
+                "instruction": "Generate specific evaluation criteria for this task."
+            }
+            
+            criteria_response = self.criteria_generator.run(json.dumps(criteria_input))
+            self.conversation.add(
+                role="Criteria-Generator",
+                content=criteria_response
+            )
+            
+            return self._safely_parse_json(criteria_response)
+        except Exception as e:
+            logger.error(f"Error generating criteria: {str(e)}")
+            return {"criteria": {}}
+
+    def _create_comm_event(self, sender: Agent, receiver: Agent, response: Dict) -> CommunicationEvent:
+        """Create a structured communication event between agents."""
+        return CommunicationEvent(
+            message=response.get("message", ""),
+            background=response.get("background", ""),
+            intermediate_output=response.get("intermediate_output", {}),
+            sender=sender.agent_name,
+            receiver=receiver.agent_name,
+        )
+
+    def _evaluate_content(self, content: Union[str, Dict], task: str) -> Dict[str, Any]:
+        """Coordinate evaluation process with parallel evaluator execution."""
+        try:
+            content_dict = self._safely_parse_json(content) if isinstance(content, str) else content
+            criteria_data = self._generate_criteria_for_task(task)
+
+            def run_evaluator(evaluator, eval_input):
+                response = evaluator.run(json.dumps(eval_input))
+                return {
+                    "evaluator_id": evaluator.agent_name, 
+                    "evaluation": self._safely_parse_json(response)
+                }
+
+            eval_inputs = [{
+                "task": task,
+                "content": content_dict,
+                "criteria": criteria_data.get("criteria", {})
+            } for _ in self.evaluators]
+
+            with ThreadPoolExecutor() as executor:
+                evaluations = list(executor.map(
+                    lambda x: run_evaluator(*x), 
+                    zip(self.evaluators, eval_inputs)
+                ))
+
+            supervisor_input = {
+                "evaluations": evaluations,
+                "task": task,
+                "instruction": "Synthesize feedback"
+            }
+            supervisor_response = self.main_supervisor.run(json.dumps(supervisor_input))
+            aggregated_eval = self._safely_parse_json(supervisor_response)
+
+            # Track communication
+            comm_event = self._create_comm_event(
+                self.main_supervisor, 
+                self.revisor, 
+                aggregated_eval
+            )
+            self.conversation.add(
+                role="Communication",
+                content=json.dumps(asdict(comm_event))
+            )
+
+            return aggregated_eval
+
+        except Exception as e:
+            logger.error(f"Evaluation error: {str(e)}")
+            return self._get_fallback_evaluation()
+
+    def _get_fallback_evaluation(self) -> Dict[str, Any]:
+        """Get a safe fallback evaluation result."""
+        return {
+            "scores": {
+                "overall": 0.5,
+                "categories": {
+                    "accuracy": 0.5,
+                    "clarity": 0.5,
+                    "coherence": 0.5,
+                },
+            },
+            "feedback": ["Evaluation failed"],
+            "metadata": {
+                "timestamp": str(datetime.now()),
+                "is_fallback": True,
+            },
+        }
+
+    def _aggregate_evaluations(
+        self, evaluations: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        """Aggregate multiple evaluation results into a single evaluation."""
+        try:
+            # Collect all unique criteria from evaluators
+            all_criteria = set()
+            for eval_data in evaluations:
+                categories = eval_data.get("scores", {}).get("categories", {})
+                all_criteria.update(categories.keys())
+
+            # Initialize score aggregation
+            aggregated_scores = {criterion: [] for criterion in all_criteria}
+            overall_scores = []
+            all_feedback = []
+
+            # Collect scores and feedback
+            for eval_data in evaluations:
+                scores = eval_data.get("scores", {})
+                overall_scores.append(scores.get("overall", 0.5))
+                
+                categories = scores.get("categories", {})
+                for criterion in all_criteria:
+                    if criterion in categories:
+                        aggregated_scores[criterion].append(
+                            categories.get(criterion, 0.5)
+                        )
+
+                all_feedback.extend(eval_data.get("feedback", []))
+
+            # Calculate means
+            def safe_mean(scores: List[float]) -> float:
+                return sum(scores) / len(scores) if scores else 0.5
+
+            return {
+                "scores": {
+                    "overall": safe_mean(overall_scores),
+                    "categories": {
+                        criterion: safe_mean(scores)
+                        for criterion, scores in aggregated_scores.items()
+                    },
+                },
+                "feedback": list(set(all_feedback)),
+                "metadata": {
+                    "evaluator_count": len(evaluations),
+                    "criteria_used": list(all_criteria),
+                    "timestamp": str(datetime.now()),
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Error in evaluation aggregation: {str(e)}")
+            return self._get_fallback_evaluation()
+
+    def _evaluate_and_revise(
+        self, content: Union[str, Dict], task: str
+    ) -> Dict[str, Any]:
+        """Coordinate evaluation and revision process."""
+        try:
+            # Get evaluations and supervisor selection
+            evaluation_result = self._evaluate_content(content, task)
+            
+            # Extract selected evaluation and supervisor reasoning
+            selected_evaluation = evaluation_result.get("selected_evaluation", {})
+            supervisor_reasoning = evaluation_result.get("supervisor_reasoning", {})
+            
+            # Prepare revision input with selected evaluation
+            revision_input = {
+                "content": content,
+                "evaluation": selected_evaluation,
+                "supervisor_feedback": supervisor_reasoning,
+                "instruction": "Revise the content based on the selected evaluation feedback"
+            }
+
+            # Get revision from content generator
+            revision_response = self.generator.run(json.dumps(revision_input))
+            revised_content = self._safely_parse_json(revision_response)
+
+            return {
+                "content": revised_content,
+                "evaluation": evaluation_result
+            }
+        except Exception as e:
+            logger.error(f"Evaluation and revision error: {str(e)}")
+            return {
+                "content": content,
+                "evaluation": self._get_fallback_evaluation()
+            }
+
+    def run(self, task: str) -> Dict[str, Any]:
+        """
+        Generate and iteratively refine content based on the given task.
+
+        Args:
+            task: Content generation task description
+
+        Returns:
+            Dictionary containing final content and metadata
+        """
+        logger.info(f"Starting content generation for task: {task}")
+
+        try:
+            # Get initial direction from supervisor
+            supervisor_response = self.main_supervisor.run(task)
+
+            self.conversation.add(
+                role=self.main_supervisor.agent_name,
+                content=supervisor_response,
+            )
+
+            supervisor_data = self._safely_parse_json(
+                supervisor_response
+            )
+
+            # Generate initial content
+            generator_response = self.generator.run(
+                json.dumps(supervisor_data.get("next_action", {}))
+            )
+
+            self.conversation.add(
+                role=self.generator.agent_name,
+                content=generator_response,
+            )
+
+            current_content = self._safely_parse_json(
+                generator_response
+            )
+
+            for iteration in range(self.max_iterations):
+                logger.info(f"Starting iteration {iteration + 1}")
+
+                # Evaluate and revise content
+                result = self._evaluate_and_revise(current_content, task)
+                evaluation = result["evaluation"]
+                current_content = result["content"]
+
+                # Check if quality threshold is met
+                selected_eval = evaluation.get("selected_evaluation", {})
+                overall_score = selected_eval.get("scores", {}).get("overall", 0.0)
+                
+                if overall_score >= self.quality_threshold:
+                    logger.info("Quality threshold met, returning content")
+                    return {
+                        "content": current_content.get("content", {}).get("main_body", ""),
+                        "final_score": overall_score,
+                        "iterations": iteration + 1,
+                        "metadata": {
+                            "content_metadata": current_content.get("content", {}).get("metadata", {}),
+                            "evaluation": evaluation,
+                        },
+                    }
+
+                # Add to conversation history
+                self.conversation.add(
+                    role=self.generator.agent_name,
+                    content=json.dumps(current_content),
+                )
+
+            logger.warning(
+                "Max iterations reached without meeting quality threshold"
+            )
+
+        except Exception as e:
+            logger.error(f"Error in generate_and_refine: {str(e)}")
+            current_content = {
+                "content": {"main_body": f"Error: {str(e)}"}
+            }
+            evaluation = self._get_fallback_evaluation()
+
+        if self.return_string:
+            return self.conversation.return_history_as_string()
+        else:
+            return {
+                "content": current_content.get("content", {}).get(
+                    "main_body", ""
+                ),
+                "final_score": evaluation["scores"]["overall"],
+                "iterations": self.max_iterations,
+                "metadata": {
+                    "content_metadata": current_content.get(
+                        "content", {}
+                    ).get("metadata", {}),
+                    "evaluation": evaluation,
+                    "error": "Max iterations reached",
+                },
+            }
+
+    def save_state(self) -> None:
+        """Save the current state of all agents."""
+        for agent in [
+            self.main_supervisor,
+            self.generator,
+            *self.evaluators,
+            self.revisor,
+        ]:
+            try:
+                agent.save_state()
+            except Exception as e:
+                logger.error(
+                    f"Error saving state for {agent.agent_name}: {str(e)}"
+                )
+
+    def load_state(self) -> None:
+        """Load the saved state of all agents."""
+        for agent in [
+            self.main_supervisor,
+            self.generator,
+            *self.evaluators,
+            self.revisor,
+        ]:
+            try:
+                agent.load_state()
+            except Exception as e:
+                logger.error(
+                    f"Error loading state for {agent.agent_name}: {str(e)}"
+                )
+
+
+if __name__ == "__main__":
+    try:
+        talkhier = TalkHier(
+            max_iterations=1,
+            quality_threshold=0.8,
+            model_name="gpt-4o",
+            return_string=False,
+        )
+
+        # Ask for user input
+        task = input("Enter the content generation task description: ")
+        result = talkhier.run(task)
+
+    except Exception as e:
+        logger.error(f"Error in main execution: {str(e)}")
\ No newline at end of file