fixed reflexion agent death spiral

pull/1266/head
Steve-Dusty 2 days ago
parent c317eb98e3
commit 748037a12d

@@ -1,4 +1,5 @@
from typing import List, Dict, Any, Tuple
import re
import time
from datetime import datetime
@@ -7,6 +8,10 @@ from swarms.structs.conversation import Conversation
from loguru import logger
# Configuration constants for ReflexionAgent
EARLY_TERMINATION_THRESHOLD = 0.8  # Lower than 0.9 for more realistic early termination
DEFAULT_SCORE = 0.6  # Higher than 0.5 to increase the chance of early termination
SCORE_IMPROVEMENT_THRESHOLD = 0.05  # Minimum improvement required to continue iterating
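
# Illustrative sketch, not part of this diff: how the three constants above are
# meant to combine in the reflection loop further down. The helper name
# `_should_stop` is hypothetical; in the diff the same checks appear inline.
def _should_stop(current_score: float, prev_score: float, iteration: int) -> bool:
    # Stop once the score clears the more realistic 0.8 bar...
    if current_score >= EARLY_TERMINATION_THRESHOLD:
        return True
    # ...or once an iteration improves the score by less than 0.05.
    if iteration > 0 and (current_score - prev_score) < SCORE_IMPROVEMENT_THRESHOLD:
        return True
    return False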
# Define Reflexion prompt with detailed instructions
REFLEXION_PROMPT = """
@@ -298,6 +303,73 @@ Focus on extracting lasting insights that will be valuable for improving future
f"Initialized {self.agent_name} with model {self.model_name}"
)
def _extract_score_robust(self, evaluation: str) -> float:
"""
Robustly extract a score from evaluation response using multiple strategies.
Args:
evaluation: The evaluation text from the evaluator agent
Returns:
float: Extracted score between 0.0 and 1.0, or DEFAULT_SCORE if extraction fails
"""
# Strategy 1: Look for "final score: X/10" or "overall score: X" (existing pattern)
# Handle markdown formatting by removing asterisks
eval_clean = evaluation.replace('*', '').lower()
score_patterns = [
r"(?:final|overall)\s+score:?\s*(\d+(?:\.\d+)?)",
r"score:?\s*(\d+(?:\.\d+)?)\s*/\s*10",
r"(?:rating|grade):?\s*(\d+(?:\.\d+)?)\s*/\s*10",
r"(?:rating|grade):?\s*(\d+(?:\.\d+)?)",
]
for pattern in score_patterns:
matches = re.findall(pattern, eval_clean)
if matches:
try:
score = float(matches[-1])
# Normalize to 0-1 range if needed
normalized = score / 10.0 if score > 1.0 else score
return max(0.0, min(1.0, normalized))
except (ValueError, IndexError):
continue
# Strategy 2: Look for any number between 0-10 with context
context_patterns = [
r"(\d+(?:\.\d+)?)\s*/\s*10",
r"(\d+(?:\.\d+)?)\s*out of\s*10",
]
for pattern in context_patterns:
matches = re.findall(pattern, eval_clean)
if matches:
try:
score = float(matches[-1])
return max(0.0, min(1.0, score / 10.0))
except (ValueError, IndexError):
continue
# Strategy 3: Sentiment analysis fallback
positive_keywords = ['excellent', 'great', 'strong', 'good', 'well done', 'impressive', 'thorough', 'comprehensive']
negative_keywords = ['poor', 'weak', 'lacking', 'insufficient', 'needs improvement', 'unclear', 'incomplete']
eval_lower = evaluation.lower()
positive_count = sum(1 for kw in positive_keywords if kw in eval_lower)
negative_count = sum(1 for kw in negative_keywords if kw in eval_lower)
if positive_count > negative_count and positive_count > 0:
logger.debug(f"Score extraction via sentiment: positive ({positive_count} keywords)")
return 0.75 # Likely good
elif negative_count > positive_count and negative_count > 0:
logger.debug(f"Score extraction via sentiment: negative ({negative_count} keywords)")
return 0.4 # Likely poor
# Default fallback
logger.warning(f"Could not extract score from evaluation, using default: {DEFAULT_SCORE}")
logger.debug(f"Evaluation text (first 200 chars): {evaluation[:200]}")
return DEFAULT_SCORE
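
    # Illustrative examples, not part of this diff: what _extract_score_robust is
    # expected to return for a few representative evaluator outputs, given the
    # strategies and the divide-by-10-then-clamp normalization above.
    #   "**Final Score: 8/10** - solid answer"    -> 0.8   (Strategy 1, markdown stripped, /10 normalized)
    #   "I'd rate this 7 out of 10 overall."      -> 0.7   (Strategy 2, "out of 10" context)
    #   "Excellent, thorough and comprehensive."  -> 0.75  (Strategy 3, positive-sentiment fallback)
    #   "No numeric verdict given."               -> 0.6   (DEFAULT_SCORE fallback)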
    def act(
        self,
        task: str,
@@ -364,24 +436,10 @@ Evaluate this response thoroughly according to the criteria in your instructions
        evaluation = self.evaluator.run(task=prompt)
-        # Extract numerical score from evaluation (in a production system, you'd want a more
-        # robust parsing method here, potentially using structured output)
-        try:
-            # Look for a final score in the format "Final Score: X/10" or similar
-            import re
-            score_matches = re.findall(
-                r"(?:final|overall)\s+score:?\s*(\d+(?:\.\d+)?)",
-                evaluation.lower(),
-            )
-            score = float(score_matches[-1]) if score_matches else 5.0
-            # Normalize to 0-1 range
-            normalized_score = score / 10.0
-        except Exception as e:
-            logger.error(f"Failed to extract score: {e}")
-            normalized_score = 0.5  # Default mid-range score
+        # Use robust score extraction with multiple fallback strategies
+        normalized_score = self._extract_score_robust(evaluation)
-        logger.debug(
+        logger.info(
            f"Evaluation complete. Score: {normalized_score:.2f}"
        )
@@ -542,16 +600,20 @@ Based on the original response, evaluation, and reflection, provide an improved
        all_results = []
        for task_idx, task in enumerate(tasks):
-            logger.info(f"Processing task {task_idx+1}/{len(tasks)}")
+            logger.info(
+                f"\n{'='*60}\nProcessing task {task_idx+1}/{len(tasks)}\n{'='*60}"
+            )
+            logger.info(f"Task: {task[:100]}...")
            iterations = []
            best_response = None
            best_score = -1
+            prev_score = -1
            # Run through multiple iterations of reflection
            for iteration in range(self.max_loops):
-                logger.debug(
-                    f"Starting iteration {iteration+1}/{self.max_loops}"
+                logger.info(
+                    f"\n--- Iteration {iteration+1}/{self.max_loops} ---"
                )
                # In first iteration, generate new response
@@ -581,13 +643,36 @@ Based on the original response, evaluation, and reflection, provide an improved
                    best_response = step_result["response"]
                    best_score = step_result["score"]
-                # If score is very high, we can stop early
-                if step_result["score"] > 0.9:
-                    logger.debug(
-                        f"Score {step_result['score']} exceeds threshold. Stopping early."
+                current_score = step_result["score"]
+                logger.info(
+                    f"Iteration {iteration+1} complete | Score: {current_score:.2f} | Best: {best_score:.2f}"
                )
+                # Early termination condition 1: Score is high enough
+                if current_score >= EARLY_TERMINATION_THRESHOLD:
+                    logger.info(
+                        f"✓ High score achieved ({current_score:.2f} >= {EARLY_TERMINATION_THRESHOLD}). Stopping early."
+                    )
                    break
+                # Early termination condition 2: Score not improving
+                if iteration > 0 and (current_score - prev_score) < SCORE_IMPROVEMENT_THRESHOLD:
+                    logger.info(
+                        f"✓ Score improvement minimal ({current_score:.2f} vs {prev_score:.2f}). Stopping early."
+                    )
+                    break
+                prev_score = current_score
+            # Summary logging
+            iterations_used = len(iterations)
+            logger.info(
+                f"\n{'='*60}\n"
+                f"Task complete | Iterations used: {iterations_used}/{self.max_loops} | "
+                f"Best score: {best_score:.2f}\n"
+                f"{'='*60}"
+            )
            # Add to conversation history (simplified)
            self.conversation.add("user", task)
            self.conversation.add("assistant", best_response)
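
# Illustrative standalone sketch, not part of this diff: why the new constants break
# the "death spiral". If the evaluator never yields a parseable score, every iteration
# falls back to DEFAULT_SCORE, and the improvement check stops the loop on the second
# pass instead of running all max_loops iterations as the old `score > 0.9` check did.
# Reuses the module-level constants defined at the top of this file.
prev_score = -1.0
for iteration, current_score in enumerate([DEFAULT_SCORE] * 10):
    if current_score >= EARLY_TERMINATION_THRESHOLD:
        break
    if iteration > 0 and (current_score - prev_score) < SCORE_IMPROVEMENT_THRESHOLD:
        break  # exits here with iteration == 1
    prev_score = current_score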
