Merge pull request #783 from harshalmore31/talk_hier

feat: Implement TalkHier framework for hierarchical multi-agent content generation
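A minimal usage sketch (the import path is illustrative, since it depends on where this module lands in the package, and an OpenAI-compatible API key is assumed to be configured for the swarms Agent):

```python
from talkhier import TalkHier  # hypothetical import path

talkhier = TalkHier(max_iterations=2, quality_threshold=0.85, model_name="gpt-4o")
result = talkhier.run("Write a short technical overview of vector databases.")
print(result["content"])       # final main-body text
print(result["final_score"])   # aggregated evaluator score in [0.0, 1.0]
print(result["iterations"])    # number of refinement passes used
```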
Kye Gomez committed 2 months ago (via GitHub)
commit 49b65e4d07

@ -0,0 +1,682 @@
"""
TalkHier: A hierarchical multi-agent framework for content generation and refinement.
Implements structured communication and evaluation protocols.
"""
import json
import logging
import re
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from swarms import Agent
from swarms.structs.conversation import Conversation
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AgentRole(Enum):
"""Defines the possible roles for agents in the system."""
SUPERVISOR = "supervisor"
GENERATOR = "generator"
EVALUATOR = "evaluator"
REVISOR = "revisor"
@dataclass
class CommunicationEvent:
"""Represents a structured communication event between agents."""
message: str
background: Optional[str] = None
intermediate_output: Optional[Dict[str, Any]] = None
sender: str = ""
receiver: str = ""
    # default_factory ensures each event gets its own creation timestamp,
    # rather than one evaluated once at class-definition time
    timestamp: str = field(default_factory=lambda: str(datetime.now()))
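# Illustrative use: events are serialized into the Conversation log with
# dataclasses.asdict, e.g.
#   asdict(CommunicationEvent(message="draft ready",
#                             sender="Content-Generator",
#                             receiver="Main-Supervisor"))
# yielding a dict with message, background, intermediate_output, sender, receiver,
# and a per-instance timestamp string.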
class TalkHier:
"""
A hierarchical multi-agent system for content generation and refinement.
Implements the TalkHier framework with structured communication protocols
and hierarchical refinement processes.
Attributes:
max_iterations: Maximum number of refinement iterations
quality_threshold: Minimum score required for content acceptance
model_name: Name of the LLM model to use
base_path: Path for saving agent states
"""
def __init__(
self,
max_iterations: int = 3,
quality_threshold: float = 0.8,
model_name: str = "gpt-4",
base_path: Optional[str] = None,
return_string: bool = False,
):
"""Initialize the TalkHier system."""
self.max_iterations = max_iterations
self.quality_threshold = quality_threshold
self.model_name = model_name
self.return_string = return_string
self.base_path = (
Path(base_path) if base_path else Path("./agent_states")
)
        self.base_path.mkdir(parents=True, exist_ok=True)
# Initialize agents
self._init_agents()
# Create conversation
self.conversation = Conversation()
def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
"""
Safely parse JSON string, handling various formats and potential errors.
Args:
json_str: String containing JSON data
Returns:
Parsed dictionary
"""
try:
# Try direct JSON parsing
return json.loads(json_str)
except json.JSONDecodeError:
try:
                # Try extracting JSON from a potential text wrapper
                json_match = re.search(r"\{.*\}", json_str, re.DOTALL)
                if json_match:
                    return json.loads(json_match.group())
# Try extracting from markdown code blocks
code_block_match = re.search(
r"```(?:json)?\s*(\{.*?\})\s*```",
json_str,
re.DOTALL,
)
                if code_block_match:
return json.loads(code_block_match.group(1))
except Exception as e:
logger.warning(f"Failed to extract JSON: {str(e)}")
# Fallback: create structured dict from text
return {
"content": json_str,
"metadata": {
"parsed": False,
"timestamp": str(datetime.now()),
},
}
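    # Illustrative parsing behaviour (results shown after the arrow):
    #   '{"a": 1}'                    -> {"a": 1}
    #   'Sure, here it is: {"a": 1}'  -> {"a": 1}   (JSON pulled out of the wrapper text)
    #   '```json\n{"a": 1}\n```'      -> {"a": 1}   (JSON pulled out of a fenced block)
    # Anything unparseable falls back to {"content": <raw text>, "metadata": {"parsed": False, ...}}.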
def _get_criteria_generator_prompt(self) -> str:
"""Get the prompt for the criteria generator agent."""
return """You are a Criteria Generator agent responsible for creating task-specific evaluation criteria.
Analyze the task and generate appropriate evaluation criteria based on:
- Task type and complexity
- Required domain knowledge
- Target audience expectations
- Quality requirements
Output all responses in strict JSON format:
{
"criteria": {
"criterion_name": {
"description": "Detailed description of what this criterion measures",
"importance": "Weight from 0.0-1.0 indicating importance",
"evaluation_guide": "Guidelines for how to evaluate this criterion"
}
},
"metadata": {
"task_type": "Classification of the task type",
"complexity_level": "Assessment of task complexity",
"domain_focus": "Primary domain or field of the task"
}
}"""
def _init_agents(self) -> None:
"""Initialize all agents with their specific roles and prompts."""
# Main supervisor agent
self.main_supervisor = Agent(
agent_name="Main-Supervisor",
system_prompt=self._get_supervisor_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(
self.base_path / "main_supervisor.json"
),
verbose=True,
)
# Generator agent
self.generator = Agent(
agent_name="Content-Generator",
system_prompt=self._get_generator_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(self.base_path / "generator.json"),
verbose=True,
)
# Criteria Generator agent
self.criteria_generator = Agent(
agent_name="Criteria-Generator",
system_prompt=self._get_criteria_generator_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(self.base_path / "criteria_generator.json"),
verbose=True,
)
# Evaluators without criteria (will be set during run)
self.evaluators = []
for i in range(3):
self.evaluators.append(
Agent(
agent_name=f"Evaluator-{i}",
system_prompt=self._get_evaluator_prompt(i),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(
self.base_path / f"evaluator_{i}.json"
),
verbose=True,
)
)
# Revisor agent
self.revisor = Agent(
agent_name="Content-Revisor",
system_prompt=self._get_revisor_prompt(),
model_name=self.model_name,
max_loops=1,
saved_state_path=str(self.base_path / "revisor.json"),
verbose=True,
)
def _generate_dynamic_criteria(self, task: str) -> Dict[str, str]:
"""
Generate dynamic evaluation criteria based on the task.
Args:
task: Content generation task description
Returns:
Dictionary containing dynamic evaluation criteria
"""
# Example dynamic criteria generation logic
if "technical" in task.lower():
return {
"accuracy": "Technical correctness and source reliability",
"clarity": "Readability and logical structure",
"depth": "Comprehensive coverage of technical details",
"engagement": "Interest level and relevance to the audience",
"technical_quality": "Grammar, spelling, and formatting",
}
else:
return {
"accuracy": "Factual correctness and source reliability",
"clarity": "Readability and logical structure",
"coherence": "Logical consistency and argument structure",
"engagement": "Interest level and relevance to the audience",
"completeness": "Coverage of the topic and depth",
"technical_quality": "Grammar, spelling, and formatting",
}
def _get_supervisor_prompt(self) -> str:
"""Get the prompt for the supervisor agent."""
return """You are a Supervisor agent responsible for orchestrating the content generation process and selecting the best evaluation criteria.
You must:
1. Analyze tasks and develop strategies
2. Review multiple evaluator feedback
3. Select the most appropriate evaluation based on:
- Completeness of criteria
- Relevance to task
- Quality of feedback
4. Provide clear instructions for content revision
Output all responses in strict JSON format:
{
"thoughts": {
"task_analysis": "Analysis of requirements, audience, scope",
"strategy": "Step-by-step plan and success metrics",
"evaluation_selection": {
"chosen_evaluator": "ID of selected evaluator",
"reasoning": "Why this evaluation was chosen",
"key_criteria": ["List of most important criteria"]
}
},
"next_action": {
"agent": "Next agent to engage",
"instruction": "Detailed instructions with context"
}
}"""
def _get_generator_prompt(self) -> str:
"""Get the prompt for the generator agent."""
return """You are a Generator agent responsible for creating high-quality, original content. Your role is to produce content that is engaging, informative, and tailored to the target audience.
When generating content:
- Thoroughly research and fact-check all information
- Structure content logically with clear flow
- Use appropriate tone and language for the target audience
- Include relevant examples and explanations
- Ensure content is original and plagiarism-free
- Consider SEO best practices where applicable
Output all responses in strict JSON format:
{
"content": {
"main_body": "The complete generated content with proper formatting and structure",
"metadata": {
"word_count": "Accurate word count of main body",
"target_audience": "Detailed audience description",
"key_points": ["List of main points covered"],
"sources": ["List of reference sources if applicable"],
"readability_level": "Estimated reading level",
"tone": "Description of content tone"
}
}
}"""
def _get_evaluator_prompt(self, evaluator_id: int) -> str:
"""Get the base prompt for an evaluator agent."""
return f"""You are Evaluator {evaluator_id}, responsible for critically assessing content quality. Your evaluation must be thorough, objective, and constructive.
When receiving content to evaluate:
1. First analyze the task description to determine appropriate evaluation criteria
2. Generate specific criteria based on task requirements
3. Evaluate content against these criteria
4. Provide detailed feedback for each criterion
Output all responses in strict JSON format:
{{
"generated_criteria": {{
"criteria_name": "description of what this criterion measures",
// Add more criteria based on task analysis
}},
"scores": {{
"overall": "0.0-1.0 composite score",
"categories": {{
// Scores for each generated criterion
"criterion_name": "0.0-1.0 score with evidence"
}}
}},
"feedback": [
"Specific, actionable improvement suggestions per criterion"
],
"strengths": ["Notable positive aspects"],
"weaknesses": ["Areas needing improvement"]
}}"""
def _get_revisor_prompt(self) -> str:
"""Get the prompt for the revisor agent."""
return """You are a Revisor agent responsible for improving content based on evaluator feedback. Your role is to enhance content while maintaining its core message and purpose.
When revising content:
- Address all evaluator feedback systematically
- Maintain consistency in tone and style
- Preserve accurate information
- Enhance clarity and flow
- Fix technical issues
- Optimize for target audience
- Track all changes made
Output all responses in strict JSON format:
{
"revised_content": {
"main_body": "Complete revised content incorporating all improvements",
"metadata": {
"word_count": "Updated word count",
"changes_made": [
"Detailed list of specific changes and improvements",
"Reasoning for each major revision",
"Feedback points addressed"
],
"improvement_summary": "Overview of main enhancements",
"preserved_elements": ["Key elements maintained from original"],
"revision_approach": "Strategy used for revisions"
}
}
}"""
def _generate_criteria_for_task(self, task: str) -> Dict[str, Any]:
"""Generate evaluation criteria for the given task."""
try:
criteria_input = {
"task": task,
"instruction": "Generate specific evaluation criteria for this task."
}
criteria_response = self.criteria_generator.run(json.dumps(criteria_input))
self.conversation.add(
role="Criteria-Generator",
content=criteria_response
)
return self._safely_parse_json(criteria_response)
except Exception as e:
logger.error(f"Error generating criteria: {str(e)}")
return {"criteria": {}}
def _create_comm_event(self, sender: Agent, receiver: Agent, response: Dict) -> CommunicationEvent:
"""Create a structured communication event between agents."""
return CommunicationEvent(
message=response.get("message", ""),
background=response.get("background", ""),
intermediate_output=response.get("intermediate_output", {}),
sender=sender.agent_name,
receiver=receiver.agent_name,
)
def _evaluate_content(self, content: Union[str, Dict], task: str) -> Dict[str, Any]:
"""Coordinate evaluation process with parallel evaluator execution."""
try:
content_dict = self._safely_parse_json(content) if isinstance(content, str) else content
criteria_data = self._generate_criteria_for_task(task)
def run_evaluator(evaluator, eval_input):
response = evaluator.run(json.dumps(eval_input))
return {
"evaluator_id": evaluator.agent_name,
"evaluation": self._safely_parse_json(response)
}
eval_inputs = [{
"task": task,
"content": content_dict,
"criteria": criteria_data.get("criteria", {})
} for _ in self.evaluators]
with ThreadPoolExecutor() as executor:
evaluations = list(executor.map(
lambda x: run_evaluator(*x),
zip(self.evaluators, eval_inputs)
))
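            # Note: all three evaluators run concurrently; each receives the same
            # task, parsed content, and generated criteria, and their JSON verdicts
            # are synthesized by the Main-Supervisor below.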
supervisor_input = {
"evaluations": evaluations,
"task": task,
"instruction": "Synthesize feedback"
}
supervisor_response = self.main_supervisor.run(json.dumps(supervisor_input))
aggregated_eval = self._safely_parse_json(supervisor_response)
# Track communication
comm_event = self._create_comm_event(
self.main_supervisor,
self.revisor,
aggregated_eval
)
self.conversation.add(
role="Communication",
content=json.dumps(asdict(comm_event))
)
return aggregated_eval
except Exception as e:
logger.error(f"Evaluation error: {str(e)}")
return self._get_fallback_evaluation()
def _get_fallback_evaluation(self) -> Dict[str, Any]:
"""Get a safe fallback evaluation result."""
return {
"scores": {
"overall": 0.5,
"categories": {
"accuracy": 0.5,
"clarity": 0.5,
"coherence": 0.5,
},
},
"feedback": ["Evaluation failed"],
"metadata": {
"timestamp": str(datetime.now()),
"is_fallback": True,
},
}
def _aggregate_evaluations(
self, evaluations: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Aggregate multiple evaluation results into a single evaluation."""
try:
# Collect all unique criteria from evaluators
all_criteria = set()
for eval_data in evaluations:
categories = eval_data.get("scores", {}).get("categories", {})
all_criteria.update(categories.keys())
# Initialize score aggregation
aggregated_scores = {criterion: [] for criterion in all_criteria}
overall_scores = []
all_feedback = []
# Collect scores and feedback
for eval_data in evaluations:
scores = eval_data.get("scores", {})
overall_scores.append(scores.get("overall", 0.5))
categories = scores.get("categories", {})
for criterion in all_criteria:
if criterion in categories:
aggregated_scores[criterion].append(
categories.get(criterion, 0.5)
)
all_feedback.extend(eval_data.get("feedback", []))
# Calculate means
def safe_mean(scores: List[float]) -> float:
return sum(scores) / len(scores) if scores else 0.5
return {
"scores": {
"overall": safe_mean(overall_scores),
"categories": {
criterion: safe_mean(scores)
for criterion, scores in aggregated_scores.items()
},
},
"feedback": list(set(all_feedback)),
"metadata": {
"evaluator_count": len(evaluations),
"criteria_used": list(all_criteria),
"timestamp": str(datetime.now()),
},
}
except Exception as e:
logger.error(f"Error in evaluation aggregation: {str(e)}")
return self._get_fallback_evaluation()
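    # Illustrative aggregation (assuming evaluators return numeric scores): two
    # evaluations with overall scores 0.8 and 0.6 and category scores
    # {"clarity": 0.9} and {"clarity": 0.7} aggregate to overall 0.7 and
    # clarity 0.8, with de-duplicated feedback and evaluator_count metadata attached.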
def _evaluate_and_revise(
self, content: Union[str, Dict], task: str
) -> Dict[str, Any]:
"""Coordinate evaluation and revision process."""
try:
# Get evaluations and supervisor selection
evaluation_result = self._evaluate_content(content, task)
            # Extract the selected evaluation and supervisor reasoning; fall back to
            # the full result / "thoughts" block if the supervisor omits these keys
            selected_evaluation = (
                evaluation_result.get("selected_evaluation") or evaluation_result
            )
            supervisor_reasoning = (
                evaluation_result.get("supervisor_reasoning")
                or evaluation_result.get("thoughts", {})
            )
# Prepare revision input with selected evaluation
revision_input = {
"content": content,
"evaluation": selected_evaluation,
"supervisor_feedback": supervisor_reasoning,
"instruction": "Revise the content based on the selected evaluation feedback"
}
# Get revision from content generator
revision_response = self.generator.run(json.dumps(revision_input))
revised_content = self._safely_parse_json(revision_response)
return {
"content": revised_content,
"evaluation": evaluation_result
}
except Exception as e:
logger.error(f"Evaluation and revision error: {str(e)}")
return {
"content": content,
"evaluation": self._get_fallback_evaluation()
}
    def run(self, task: str) -> Union[str, Dict[str, Any]]:
        """
        Generate and iteratively refine content based on the given task.
        Args:
            task: Content generation task description
        Returns:
            Dictionary containing the final content and metadata, or the
            conversation history as a string when return_string is True
        """
logger.info(f"Starting content generation for task: {task}")
try:
# Get initial direction from supervisor
supervisor_response = self.main_supervisor.run(task)
self.conversation.add(
role=self.main_supervisor.agent_name,
content=supervisor_response,
)
supervisor_data = self._safely_parse_json(
supervisor_response
)
# Generate initial content
generator_response = self.generator.run(
json.dumps(supervisor_data.get("next_action", {}))
)
self.conversation.add(
role=self.generator.agent_name,
content=generator_response,
)
current_content = self._safely_parse_json(
generator_response
)
for iteration in range(self.max_iterations):
logger.info(f"Starting iteration {iteration + 1}")
# Evaluate and revise content
result = self._evaluate_and_revise(current_content, task)
evaluation = result["evaluation"]
current_content = result["content"]
                # Check if quality threshold is met (fall back to the aggregated
                # scores if no "selected_evaluation" block was returned)
                selected_eval = evaluation.get("selected_evaluation") or evaluation
                overall_score = selected_eval.get("scores", {}).get("overall", 0.0)
if overall_score >= self.quality_threshold:
logger.info("Quality threshold met, returning content")
return {
"content": current_content.get("content", {}).get("main_body", ""),
"final_score": overall_score,
"iterations": iteration + 1,
"metadata": {
"content_metadata": current_content.get("content", {}).get("metadata", {}),
"evaluation": evaluation,
},
}
# Add to conversation history
self.conversation.add(
role=self.generator.agent_name,
content=json.dumps(current_content),
)
logger.warning(
"Max iterations reached without meeting quality threshold"
)
except Exception as e:
logger.error(f"Error in generate_and_refine: {str(e)}")
current_content = {
"content": {"main_body": f"Error: {str(e)}"}
}
evaluation = self._get_fallback_evaluation()
if self.return_string:
return self.conversation.return_history_as_string()
else:
return {
"content": current_content.get("content", {}).get(
"main_body", ""
),
"final_score": evaluation["scores"]["overall"],
"iterations": self.max_iterations,
"metadata": {
"content_metadata": current_content.get(
"content", {}
).get("metadata", {}),
"evaluation": evaluation,
"error": "Max iterations reached",
},
}
def save_state(self) -> None:
"""Save the current state of all agents."""
for agent in [
self.main_supervisor,
self.generator,
*self.evaluators,
self.revisor,
]:
try:
agent.save_state()
except Exception as e:
logger.error(
f"Error saving state for {agent.agent_name}: {str(e)}"
)
def load_state(self) -> None:
"""Load the saved state of all agents."""
for agent in [
self.main_supervisor,
self.generator,
*self.evaluators,
self.revisor,
]:
try:
agent.load_state()
except Exception as e:
logger.error(
f"Error loading state for {agent.agent_name}: {str(e)}"
)
if __name__ == "__main__":
try:
talkhier = TalkHier(
max_iterations=1,
quality_threshold=0.8,
model_name="gpt-4o",
return_string=False,
)
        # Ask for user input and print the refined result
        task = input("Enter the content generation task description: ")
        result = talkhier.run(task)
        print(json.dumps(result, indent=2, default=str))
except Exception as e:
logger.error(f"Error in main execution: {str(e)}")