Merge pull request #783 from harshalmore31/talk_hier
feat: Implement TalkHier framework for hierarchical multi-agent content generation

commit 49b65e4d07

@@ -0,0 +1,682 @@
"""
TalkHier: A hierarchical multi-agent framework for content generation and refinement.
Implements structured communication and evaluation protocols.
"""

import json
import logging
import re
from dataclasses import dataclass, asdict, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from concurrent.futures import ThreadPoolExecutor

from swarms import Agent
from swarms.structs.conversation import Conversation

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AgentRole(Enum):
    """Defines the possible roles for agents in the system."""

    SUPERVISOR = "supervisor"
    GENERATOR = "generator"
    EVALUATOR = "evaluator"
    REVISOR = "revisor"


@dataclass
class CommunicationEvent:
    """Represents a structured communication event between agents."""

    message: str
    background: Optional[str] = None
    intermediate_output: Optional[Dict[str, Any]] = None
    sender: str = ""
    receiver: str = ""
    # Use a default_factory so each event gets its own creation time; a plain
    # str(datetime.now()) default is evaluated once at class-definition time
    # and would stamp every event with the same timestamp.
    timestamp: str = field(default_factory=lambda: str(datetime.now()))

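
# Illustrative only: a CommunicationEvent round-trips through asdict/json.dumps,
# which is how _create_comm_event below records agent-to-agent handoffs in the
# conversation log. The field values here are hypothetical.
#
#   event = CommunicationEvent(
#       message="Revise per the selected evaluation",
#       sender="Main-Supervisor",
#       receiver="Content-Revisor",
#   )
#   json.dumps(asdict(event))  # serializable because all fields are plain data

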
class TalkHier:
    """
    A hierarchical multi-agent system for content generation and refinement.

    Implements the TalkHier framework with structured communication protocols
    and hierarchical refinement processes.

    Attributes:
        max_iterations: Maximum number of refinement iterations
        quality_threshold: Minimum score required for content acceptance
        model_name: Name of the LLM model to use
        base_path: Path for saving agent states
        return_string: Whether run() returns the conversation history as a string
    """

    def __init__(
        self,
        max_iterations: int = 3,
        quality_threshold: float = 0.8,
        model_name: str = "gpt-4",
        base_path: Optional[str] = None,
        return_string: bool = False,
    ):
        """Initialize the TalkHier system."""
        self.max_iterations = max_iterations
        self.quality_threshold = quality_threshold
        self.model_name = model_name
        self.return_string = return_string
        self.base_path = (
            Path(base_path) if base_path else Path("./agent_states")
        )
        self.base_path.mkdir(exist_ok=True)

        # Initialize agents
        self._init_agents()

        # Create conversation
        self.conversation = Conversation()

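    # Illustrative usage (hypothetical task string; requires credentials for
    # the configured model to be available in the environment):
    #
    #   talkhier = TalkHier(max_iterations=2, quality_threshold=0.8)
    #   result = talkhier.run("Explain how HTTPS certificate pinning works")
    #   result["content"], result["final_score"], result["iterations"]
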
    def _safely_parse_json(self, json_str: str) -> Dict[str, Any]:
        """
        Safely parse a JSON string, handling various formats and potential errors.

        Args:
            json_str: String containing JSON data

        Returns:
            Parsed dictionary
        """
        try:
            # Try direct JSON parsing
            return json.loads(json_str)
        except json.JSONDecodeError:
            try:
                # Try extracting from markdown code blocks first: a fenced
                # block is a more precise signal than a bare brace match, and
                # checking it second would leave it unreachable whenever the
                # greedy brace match found invalid JSON
                code_block_match = re.search(
                    r"```(?:json)?\s*(\{.*?\})\s*```",
                    json_str,
                    re.DOTALL,
                )
                if code_block_match:
                    return json.loads(code_block_match.group(1))
                # Fall back to extracting JSON from a plain text wrapper
                json_match = re.search(r"\{.*\}", json_str, re.DOTALL)
                if json_match:
                    return json.loads(json_match.group())
            except Exception as e:
                logger.warning(f"Failed to extract JSON: {str(e)}")

        # Fallback: wrap the raw text in a structured dict
        return {
            "content": json_str,
            "metadata": {
                "parsed": False,
                "timestamp": str(datetime.now()),
            },
        }

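    # Illustrative only (hypothetical inputs): the parser accepts raw JSON,
    # fenced ```json blocks, and JSON embedded in prose; anything else falls
    # back to {"content": <raw text>, "metadata": {"parsed": False, ...}}.
    #
    #   self._safely_parse_json('{"a": 1}')                    # -> {"a": 1}
    #   self._safely_parse_json('```json\n{"a": 1}\n```')      # -> {"a": 1}
    #   self._safely_parse_json('Result: {"a": 1}, as asked')  # -> {"a": 1}
    #   self._safely_parse_json('no json here')                # -> fallback dict
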
    def _get_criteria_generator_prompt(self) -> str:
        """Get the prompt for the criteria generator agent."""
        return """You are a Criteria Generator agent responsible for creating task-specific evaluation criteria.
Analyze the task and generate appropriate evaluation criteria based on:
- Task type and complexity
- Required domain knowledge
- Target audience expectations
- Quality requirements

Output all responses in strict JSON format:
{
    "criteria": {
        "criterion_name": {
            "description": "Detailed description of what this criterion measures",
            "importance": "Weight from 0.0-1.0 indicating importance",
            "evaluation_guide": "Guidelines for how to evaluate this criterion"
        }
    },
    "metadata": {
        "task_type": "Classification of the task type",
        "complexity_level": "Assessment of task complexity",
        "domain_focus": "Primary domain or field of the task"
    }
}"""

    def _init_agents(self) -> None:
        """Initialize all agents with their specific roles and prompts."""
        # Main supervisor agent
        self.main_supervisor = Agent(
            agent_name="Main-Supervisor",
            system_prompt=self._get_supervisor_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(
                self.base_path / "main_supervisor.json"
            ),
            verbose=True,
        )

        # Generator agent
        self.generator = Agent(
            agent_name="Content-Generator",
            system_prompt=self._get_generator_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "generator.json"),
            verbose=True,
        )

        # Criteria Generator agent
        self.criteria_generator = Agent(
            agent_name="Criteria-Generator",
            system_prompt=self._get_criteria_generator_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "criteria_generator.json"),
            verbose=True,
        )

        # Evaluators without criteria (criteria will be supplied during run)
        self.evaluators = []
        for i in range(3):
            self.evaluators.append(
                Agent(
                    agent_name=f"Evaluator-{i}",
                    system_prompt=self._get_evaluator_prompt(i),
                    model_name=self.model_name,
                    max_loops=1,
                    saved_state_path=str(
                        self.base_path / f"evaluator_{i}.json"
                    ),
                    verbose=True,
                )
            )

        # Revisor agent
        self.revisor = Agent(
            agent_name="Content-Revisor",
            system_prompt=self._get_revisor_prompt(),
            model_name=self.model_name,
            max_loops=1,
            saved_state_path=str(self.base_path / "revisor.json"),
            verbose=True,
        )

    def _generate_dynamic_criteria(self, task: str) -> Dict[str, str]:
        """
        Generate dynamic evaluation criteria based on the task.

        Args:
            task: Content generation task description

        Returns:
            Dictionary containing dynamic evaluation criteria
        """
        # Example dynamic criteria generation logic
        if "technical" in task.lower():
            return {
                "accuracy": "Technical correctness and source reliability",
                "clarity": "Readability and logical structure",
                "depth": "Comprehensive coverage of technical details",
                "engagement": "Interest level and relevance to the audience",
                "technical_quality": "Grammar, spelling, and formatting",
            }
        else:
            return {
                "accuracy": "Factual correctness and source reliability",
                "clarity": "Readability and logical structure",
                "coherence": "Logical consistency and argument structure",
                "engagement": "Interest level and relevance to the audience",
                "completeness": "Coverage of the topic and depth",
                "technical_quality": "Grammar, spelling, and formatting",
            }

    def _get_supervisor_prompt(self) -> str:
        """Get the prompt for the supervisor agent."""
        return """You are a Supervisor agent responsible for orchestrating the content generation process and selecting the best evaluation criteria.

You must:
1. Analyze tasks and develop strategies
2. Review feedback from multiple evaluators
3. Select the most appropriate evaluation based on:
   - Completeness of criteria
   - Relevance to task
   - Quality of feedback
4. Provide clear instructions for content revision

Output all responses in strict JSON format:
{
    "thoughts": {
        "task_analysis": "Analysis of requirements, audience, scope",
        "strategy": "Step-by-step plan and success metrics",
        "evaluation_selection": {
            "chosen_evaluator": "ID of selected evaluator",
            "reasoning": "Why this evaluation was chosen",
            "key_criteria": ["List of most important criteria"]
        }
    },
    "next_action": {
        "agent": "Next agent to engage",
        "instruction": "Detailed instructions with context"
    }
}"""

    def _get_generator_prompt(self) -> str:
        """Get the prompt for the generator agent."""
        return """You are a Generator agent responsible for creating high-quality, original content. Your role is to produce content that is engaging, informative, and tailored to the target audience.

When generating content:
- Thoroughly research and fact-check all information
- Structure content logically with clear flow
- Use appropriate tone and language for the target audience
- Include relevant examples and explanations
- Ensure content is original and plagiarism-free
- Consider SEO best practices where applicable

Output all responses in strict JSON format:
{
    "content": {
        "main_body": "The complete generated content with proper formatting and structure",
        "metadata": {
            "word_count": "Accurate word count of main body",
            "target_audience": "Detailed audience description",
            "key_points": ["List of main points covered"],
            "sources": ["List of reference sources if applicable"],
            "readability_level": "Estimated reading level",
            "tone": "Description of content tone"
        }
    }
}"""

    def _get_evaluator_prompt(self, evaluator_id: int) -> str:
        """Get the base prompt for an evaluator agent."""
        return f"""You are Evaluator {evaluator_id}, responsible for critically assessing content quality. Your evaluation must be thorough, objective, and constructive.

When receiving content to evaluate:
1. First analyze the task description to determine appropriate evaluation criteria
2. Generate specific criteria based on task requirements
3. Evaluate content against these criteria
4. Provide detailed feedback for each criterion

Output all responses in strict JSON format:
{{
    "generated_criteria": {{
        "criteria_name": "description of what this criterion measures",
        // Add more criteria based on task analysis
    }},
    "scores": {{
        "overall": "0.0-1.0 composite score",
        "categories": {{
            // Scores for each generated criterion
            "criterion_name": "0.0-1.0 score with evidence"
        }}
    }},
    "feedback": [
        "Specific, actionable improvement suggestions per criterion"
    ],
    "strengths": ["Notable positive aspects"],
    "weaknesses": ["Areas needing improvement"]
}}"""

    def _get_revisor_prompt(self) -> str:
        """Get the prompt for the revisor agent."""
        return """You are a Revisor agent responsible for improving content based on evaluator feedback. Your role is to enhance content while maintaining its core message and purpose.

When revising content:
- Address all evaluator feedback systematically
- Maintain consistency in tone and style
- Preserve accurate information
- Enhance clarity and flow
- Fix technical issues
- Optimize for target audience
- Track all changes made

Output all responses in strict JSON format:
{
    "revised_content": {
        "main_body": "Complete revised content incorporating all improvements",
        "metadata": {
            "word_count": "Updated word count",
            "changes_made": [
                "Detailed list of specific changes and improvements",
                "Reasoning for each major revision",
                "Feedback points addressed"
            ],
            "improvement_summary": "Overview of main enhancements",
            "preserved_elements": ["Key elements maintained from original"],
            "revision_approach": "Strategy used for revisions"
        }
    }
}"""

    def _generate_criteria_for_task(self, task: str) -> Dict[str, Any]:
        """Generate evaluation criteria for the given task."""
        try:
            criteria_input = {
                "task": task,
                "instruction": "Generate specific evaluation criteria for this task.",
            }

            criteria_response = self.criteria_generator.run(json.dumps(criteria_input))
            self.conversation.add(
                role="Criteria-Generator",
                content=criteria_response,
            )

            return self._safely_parse_json(criteria_response)
        except Exception as e:
            logger.error(f"Error generating criteria: {str(e)}")
            return {"criteria": {}}

    def _create_comm_event(
        self, sender: Agent, receiver: Agent, response: Dict
    ) -> CommunicationEvent:
        """Create a structured communication event between agents."""
        return CommunicationEvent(
            message=response.get("message", ""),
            background=response.get("background", ""),
            intermediate_output=response.get("intermediate_output", {}),
            sender=sender.agent_name,
            receiver=receiver.agent_name,
        )

    def _evaluate_content(self, content: Union[str, Dict], task: str) -> Dict[str, Any]:
        """Coordinate the evaluation process with parallel evaluator execution."""
        try:
            content_dict = (
                self._safely_parse_json(content)
                if isinstance(content, str)
                else content
            )
            criteria_data = self._generate_criteria_for_task(task)

            def run_evaluator(evaluator, eval_input):
                response = evaluator.run(json.dumps(eval_input))
                return {
                    "evaluator_id": evaluator.agent_name,
                    "evaluation": self._safely_parse_json(response),
                }

            # Every evaluator receives the same task, content, and criteria
            eval_inputs = [
                {
                    "task": task,
                    "content": content_dict,
                    "criteria": criteria_data.get("criteria", {}),
                }
                for _ in self.evaluators
            ]

            # Run the evaluators concurrently
            with ThreadPoolExecutor() as executor:
                evaluations = list(
                    executor.map(
                        lambda x: run_evaluator(*x),
                        zip(self.evaluators, eval_inputs),
                    )
                )

            # Ask the supervisor to synthesize the evaluators' feedback
            supervisor_input = {
                "evaluations": evaluations,
                "task": task,
                "instruction": "Synthesize feedback",
            }
            supervisor_response = self.main_supervisor.run(json.dumps(supervisor_input))
            aggregated_eval = self._safely_parse_json(supervisor_response)

            # Track communication
            comm_event = self._create_comm_event(
                self.main_supervisor,
                self.revisor,
                aggregated_eval,
            )
            self.conversation.add(
                role="Communication",
                content=json.dumps(asdict(comm_event)),
            )

            return aggregated_eval

        except Exception as e:
            logger.error(f"Evaluation error: {str(e)}")
            return self._get_fallback_evaluation()

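    # Illustrative only: each entry handed to the supervisor has the shape
    #   {"evaluator_id": "Evaluator-0", "evaluation": {"scores": {...}, "feedback": [...]}}
    # and the supervisor's JSON reply is expected to carry the selected
    # evaluation and its reasoning, which _evaluate_and_revise consumes below.
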
    def _get_fallback_evaluation(self) -> Dict[str, Any]:
        """Get a safe fallback evaluation result."""
        return {
            "scores": {
                "overall": 0.5,
                "categories": {
                    "accuracy": 0.5,
                    "clarity": 0.5,
                    "coherence": 0.5,
                },
            },
            "feedback": ["Evaluation failed"],
            "metadata": {
                "timestamp": str(datetime.now()),
                "is_fallback": True,
            },
        }

    def _aggregate_evaluations(
        self, evaluations: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate multiple evaluation results into a single evaluation."""
        try:
            # Collect all unique criteria from evaluators
            all_criteria = set()
            for eval_data in evaluations:
                categories = eval_data.get("scores", {}).get("categories", {})
                all_criteria.update(categories.keys())

            # Initialize score aggregation
            aggregated_scores = {criterion: [] for criterion in all_criteria}
            overall_scores = []
            all_feedback = []

            # Collect scores and feedback
            for eval_data in evaluations:
                scores = eval_data.get("scores", {})
                overall_scores.append(scores.get("overall", 0.5))

                categories = scores.get("categories", {})
                for criterion in all_criteria:
                    if criterion in categories:
                        aggregated_scores[criterion].append(
                            categories.get(criterion, 0.5)
                        )

                all_feedback.extend(eval_data.get("feedback", []))

            # Average each criterion across evaluators, defaulting to 0.5 when empty
            def safe_mean(scores: List[float]) -> float:
                return sum(scores) / len(scores) if scores else 0.5

            return {
                "scores": {
                    "overall": safe_mean(overall_scores),
                    "categories": {
                        criterion: safe_mean(scores)
                        for criterion, scores in aggregated_scores.items()
                    },
                },
                "feedback": list(set(all_feedback)),
                "metadata": {
                    "evaluator_count": len(evaluations),
                    "criteria_used": list(all_criteria),
                    "timestamp": str(datetime.now()),
                },
            }

        except Exception as e:
            logger.error(f"Error in evaluation aggregation: {str(e)}")
            return self._get_fallback_evaluation()

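    # Illustrative only (hypothetical numbers): two evaluations scoring
    # {"overall": 0.9, "categories": {"clarity": 0.8}} and
    # {"overall": 0.7, "categories": {"clarity": 0.6}} aggregate to
    # {"overall": 0.8, "categories": {"clarity": 0.7}}; the combined
    # feedback list is deduplicated via set().
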
    def _evaluate_and_revise(
        self, content: Union[str, Dict], task: str
    ) -> Dict[str, Any]:
        """Coordinate the evaluation and revision process."""
        try:
            # Get evaluations and supervisor selection
            evaluation_result = self._evaluate_content(content, task)

            # Extract the selected evaluation and supervisor reasoning
            selected_evaluation = evaluation_result.get("selected_evaluation", {})
            supervisor_reasoning = evaluation_result.get("supervisor_reasoning", {})

            # Prepare revision input with the selected evaluation
            revision_input = {
                "content": content,
                "evaluation": selected_evaluation,
                "supervisor_feedback": supervisor_reasoning,
                "instruction": "Revise the content based on the selected evaluation feedback",
            }

            # Get the revision from the content generator (its output schema
            # is what run() consumes downstream)
            revision_response = self.generator.run(json.dumps(revision_input))
            revised_content = self._safely_parse_json(revision_response)

            return {
                "content": revised_content,
                "evaluation": evaluation_result,
            }
        except Exception as e:
            logger.error(f"Evaluation and revision error: {str(e)}")
            return {
                "content": content,
                "evaluation": self._get_fallback_evaluation(),
            }

    def run(self, task: str) -> Union[str, Dict[str, Any]]:
        """
        Generate and iteratively refine content based on the given task.

        Args:
            task: Content generation task description

        Returns:
            Dictionary containing the final content and metadata, or the
            conversation history as a string when return_string is True
        """
        logger.info(f"Starting content generation for task: {task}")

        try:
            # Get initial direction from the supervisor
            supervisor_response = self.main_supervisor.run(task)

            self.conversation.add(
                role=self.main_supervisor.agent_name,
                content=supervisor_response,
            )

            supervisor_data = self._safely_parse_json(
                supervisor_response
            )

            # Generate initial content
            generator_response = self.generator.run(
                json.dumps(supervisor_data.get("next_action", {}))
            )

            self.conversation.add(
                role=self.generator.agent_name,
                content=generator_response,
            )

            current_content = self._safely_parse_json(
                generator_response
            )

            for iteration in range(self.max_iterations):
                logger.info(f"Starting iteration {iteration + 1}")

                # Evaluate and revise content
                result = self._evaluate_and_revise(current_content, task)
                evaluation = result["evaluation"]
                current_content = result["content"]

                # Check whether the quality threshold is met
                selected_eval = evaluation.get("selected_evaluation", {})
                overall_score = selected_eval.get("scores", {}).get("overall", 0.0)

                if overall_score >= self.quality_threshold:
                    logger.info("Quality threshold met, returning content")
                    return {
                        "content": current_content.get("content", {}).get("main_body", ""),
                        "final_score": overall_score,
                        "iterations": iteration + 1,
                        "metadata": {
                            "content_metadata": current_content.get("content", {}).get("metadata", {}),
                            "evaluation": evaluation,
                        },
                    }

                # Add to conversation history
                self.conversation.add(
                    role=self.generator.agent_name,
                    content=json.dumps(current_content),
                )

            logger.warning(
                "Max iterations reached without meeting quality threshold"
            )

        except Exception as e:
            logger.error(f"Error in run: {str(e)}")
            current_content = {
                "content": {"main_body": f"Error: {str(e)}"}
            }
            evaluation = self._get_fallback_evaluation()

        if self.return_string:
            return self.conversation.return_history_as_string()
        else:
            return {
                "content": current_content.get("content", {}).get(
                    "main_body", ""
                ),
                "final_score": evaluation.get("scores", {}).get("overall", 0.0),
                "iterations": self.max_iterations,
                "metadata": {
                    "content_metadata": current_content.get(
                        "content", {}
                    ).get("metadata", {}),
                    "evaluation": evaluation,
                    "error": "Max iterations reached",
                },
            }

    def save_state(self) -> None:
        """Save the current state of all agents."""
        for agent in [
            self.main_supervisor,
            self.generator,
            self.criteria_generator,
            *self.evaluators,
            self.revisor,
        ]:
            try:
                agent.save_state()
            except Exception as e:
                logger.error(
                    f"Error saving state for {agent.agent_name}: {str(e)}"
                )

    def load_state(self) -> None:
        """Load the saved state of all agents."""
        for agent in [
            self.main_supervisor,
            self.generator,
            self.criteria_generator,
            *self.evaluators,
            self.revisor,
        ]:
            try:
                agent.load_state()
            except Exception as e:
                logger.error(
                    f"Error loading state for {agent.agent_name}: {str(e)}"
                )


if __name__ == "__main__":
    try:
        talkhier = TalkHier(
            max_iterations=1,
            quality_threshold=0.8,
            model_name="gpt-4o",
            return_string=False,
        )

        # Ask for user input
        task = input("Enter the content generation task description: ")
        result = talkhier.run(task)
        # Surface the outcome; run() returns a dict unless return_string=True
        print(result)

    except Exception as e:
        logger.error(f"Error in main execution: {str(e)}")
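
# Illustrative invocation (the module filename is hypothetical, and the
# configured "gpt-4o" model needs an OPENAI_API_KEY or equivalent credential
# in the environment):
#
#   $ python talk_hier.py
#   Enter the content generation task description: <your task here>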