@@ -1,13 +1,15 @@
-from typing import List
+from typing import List, Dict, Optional
 
 from swarms.prompts.agent_judge_prompt import AGENT_JUDGE_PROMPT
 from swarms.structs.agent import Agent
 from swarms.structs.conversation import Conversation
 from swarms.utils.any_to_str import any_to_str
 from loguru import logger
 
 
 class AgentJudge:
     """
     A class to represent an agent judge that processes tasks and generates responses.
@@ -19,13 +21,7 @@ class AgentJudge:
         conversation (Conversation): An instance of the Conversation class to manage conversation history.
         max_loops (int): The maximum number of iterations to run the tasks.
         agent (Agent): An instance of the Agent class that performs the task execution.
-
-    Methods:
-        step(tasks: List[str]) -> str:
-            Processes a list of tasks and returns the agent's response.
-
-        run(tasks: List[str]) -> List[str]:
-            Executes the tasks in a loop, updating context and collecting responses.
+        evaluation_criteria (Dict[str, float]): Dictionary of evaluation criteria and their weights.
     """
 
     def __init__(
@@ -34,6 +30,7 @@ class AgentJudge:
         system_prompt: str = AGENT_JUDGE_PROMPT,
         model_name: str = "openai/o1",
         max_loops: int = 1,
+        evaluation_criteria: Optional[Dict[str, float]] = None,
     ) -> None:
         """
         Initializes the AgentJudge with the specified parameters.
@@ -43,17 +40,29 @@ class AgentJudge:
             system_prompt (str): The system prompt for the agent.
             model_name (str): The model name used for generating responses.
             max_loops (int): The maximum number of iterations to run the tasks.
+            evaluation_criteria (Optional[Dict[str, float]]): Dictionary of evaluation criteria
+                and their weights. Keys are criteria names, values are weights.
+                Example: {"correctness": 0.4, "efficiency": 0.3, "clarity": 0.3}
         """
         self.agent_name = agent_name
         self.system_prompt = system_prompt
         self.model_name = model_name
         self.conversation = Conversation(time_enabled=False)
         self.max_loops = max_loops
 
+        self.evaluation_criteria = evaluation_criteria or {}
+
+        # Enhance system prompt with evaluation criteria if provided
+        enhanced_prompt = system_prompt
+        if self.evaluation_criteria:
+            criteria_str = "\n\nEvaluation Criteria:\n"
+            for criterion, weight in self.evaluation_criteria.items():
+                criteria_str += f"- {criterion}: weight = {weight}\n"
+            enhanced_prompt += criteria_str
+
         self.agent = Agent(
             agent_name=agent_name,
             agent_description="You're the agent judge",
-            system_prompt=AGENT_JUDGE_PROMPT,
+            system_prompt=enhanced_prompt,
             model_name=model_name,
             max_loops=1,
         )
@@ -70,14 +79,22 @@ class AgentJudge:
         """
         prompt = any_to_str(tasks)
         logger.debug(f"Running step with prompt: {prompt}")
 
         print(prompt)
 
+        task_instruction = "Evaluate the following output or outputs"
+        if self.evaluation_criteria:
+            criteria_names = list(self.evaluation_criteria.keys())
+            if len(criteria_names) == 1:
+                task_instruction += f" based on {criteria_names[0]}"
+            else:
+                formatted_criteria = ", ".join(criteria_names[:-1]) + f" and {criteria_names[-1]}"
+                task_instruction += f" based on the criteria: {formatted_criteria}"
         response = self.agent.run(
-            task=f"Evaluate the following output or outputs: {prompt}"
+            task=f"{task_instruction}: {prompt}"
         )
         logger.debug(f"Received response: {response}")
 
         return response
 
     def run(self, tasks: List[str]) -> List[str]:
@@ -112,8 +129,7 @@ class AgentJudge:
             # Update context for next iteration
             context = current_response
 
             # Add to conversation history
             logger.debug("Added message to conversation history.")
 
-        return responses
         return responses
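
A minimal usage sketch of the new evaluation_criteria parameter, for reviewers trying the change locally. The module path and the agent_name value are assumptions for illustration; the criteria dict reuses the example from the docstring above.

    from swarms.structs.agent_judge import AgentJudge  # assumed module path

    judge = AgentJudge(
        agent_name="agent-judge-01",  # illustrative name
        model_name="openai/o1",
        max_loops=1,
        # Weights from the docstring example; any name -> weight mapping works.
        evaluation_criteria={
            "correctness": 0.4,
            "efficiency": 0.3,
            "clarity": 0.3,
        },
    )

    # With this dict, step() asks the underlying Agent to "Evaluate the
    # following output or outputs based on the criteria: correctness,
    # efficiency and clarity", since dicts preserve insertion order.
    verdict = judge.step(["Output A: ...", "Output B: ..."])
    print(verdict)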