"""
Group chat orchestration for multiple agents.

Contains the pydantic models that record a conversation (AgentResponse,
ChatTurn, ChatHistory), a set of speaker-selection functions, and the
GroupChat class that drives the conversation loop.
"""
import concurrent.futures
import random
from datetime import datetime
from typing import Callable, List, Optional

from loguru import logger
from pydantic import BaseModel, Field

from swarms.structs.agent import Agent


class AgentResponse(BaseModel):
    """A single message produced by one agent during a chat turn."""

    # Name of the agent that produced the message.
    agent_name: str
    # Role text for the agent; GroupChat fills this with the agent's system prompt.
    role: str
    # The text returned by the agent (or an error description if the call failed).
    message: str
    # Creation time of the record; defaults to the local time at instantiation.
    timestamp: datetime = Field(default_factory=datetime.now)
    # Index of the turn in which the message was produced.
    turn_number: int
    # Recent "<agent_name>: <message>" strings that preceded this response.
    preceding_context: List[str] = Field(default_factory=list)


class ChatTurn(BaseModel):
    """All responses collected during one pass over the agents."""

    # Index of this turn within a run.
    turn_number: int
    # Responses from the agents that spoke during this turn, in speaking order.
    responses: List[AgentResponse]
    # The task text the agents were prompted with.
    task: str
    # Creation time of the record; defaults to the local time at instantiation.
    timestamp: datetime = Field(default_factory=datetime.now)


class ChatHistory(BaseModel):
    """The accumulated record of a group chat."""

    # Every completed turn, oldest first.
    turns: List[ChatTurn]
    # Running count of agent responses across all turns.
    total_messages: int
    # Name of the chat this history belongs to.
    name: str
    # Description of the chat this history belongs to.
    description: str
    # Creation time of the record; defaults to the local time at instantiation.
    start_time: datetime = Field(default_factory=datetime.now)


# Signature of a speaker-selection function: it receives the recent chat
# messages (as "<agent_name>: <message>" strings) and a candidate agent, and
# returns True when that agent should speak.
SpeakerFunction = Callable[[List[str], "Agent"], bool]


def round_robin(history: List[str], agent: "Agent") -> bool:
    """
    Round robin speaker function.

    Always allows the agent to speak, so a caller that iterates over its
    agents in order gets every agent speaking once per pass.

    Args:
        history: Previous conversation messages (unused).
        agent: Current agent being evaluated (unused).

    Returns:
        bool: Always True.
    """
    return True


def expertise_based(history: List[str], agent: "Agent") -> bool:
    """
    Expertise based speaker function.

    An agent speaks when its whole system prompt occurs (case-insensitively)
    as a substring of the last message. With no history every agent speaks.

    Args:
        history: Previous conversation messages.
        agent: Current agent being evaluated; its ``system_prompt`` is read.

    Returns:
        bool: Whether the agent should speak.
    """
    # Nothing to match against yet - let everyone speak.
    if not history:
        return True

    return agent.system_prompt.lower() in history[-1].lower()


def random_selection(history: List[str], agent: "Agent") -> bool:
    """
    Random selection speaker function.

    Decides with equal probability whether the agent speaks; neither
    argument influences the outcome.

    Args:
        history: Previous conversation messages (unused).
        agent: Current agent being evaluated (unused).

    Returns:
        bool: True or False, chosen at random.
    """
    return random.choice([True, False])


def custom_speaker(history: List[str], agent: "Agent") -> bool:
    """
    Custom speaker function with complex logic.

    The agent speaks when any of the following holds:
    - there is no history yet;
    - a word of the agent's description occurs in the last message;
    - the agent's name occurs in the last message;
    - the agent's name does not occur in any of the last three messages.

    Args:
        history: Previous conversation messages
        agent: Current agent being evaluated; its ``agent_name`` and
            ``description`` are read.

    Returns:
        bool: Whether agent should speak
    """
    # No history - let everyone speak
    if not history:
        return True

    last_message = history[-1].lower()

    # Check for agent expertise keywords. A missing or None description is
    # treated as "no keywords" instead of raising AttributeError.
    description = getattr(agent, "description", None) or ""
    expertise_relevant = any(
        keyword in last_message
        for keyword in description.lower().split()
    )

    # Check for direct mentions (case-insensitive)
    mentioned = agent.agent_name.lower() in last_message

    # Check if agent hasn't spoken recently (case-sensitive name match over
    # the last three messages)
    not_recent_speaker = not any(
        agent.agent_name in msg for msg in history[-3:]
    )

    return expertise_relevant or mentioned or not_recent_speaker


def most_recent(history: List[str], agent: "Agent") -> bool:
    """
    Most recent speaker function.

    An agent speaks if they are the last speaker, i.e. the text before the
    first ":" of the last message (whitespace-stripped) equals the agent's
    name. With no history every agent speaks.

    Args:
        history: Previous conversation messages.
        agent: Current agent being evaluated; its ``agent_name`` is read.

    Returns:
        bool: Whether the agent should speak.
    """
    if not history:
        return True

    last_speaker = history[-1].split(":")[0].strip()
    return agent.agent_name == last_speaker


class GroupChat:
    """
    GroupChat class to enable multiple agents to communicate in a synchronous group chat.

    On each turn every agent is offered to the speaker function; agents that
    are selected are prompted with the chat name and description, the names
    of the other agents, and the transcript of all completed turns.
    """

    def __init__(
        self,
        name: str = "GroupChat",
        description: str = "A group chat for multiple agents",
        agents: Optional[List[Agent]] = None,
        speaker_fn: SpeakerFunction = round_robin,
        max_loops: int = 10,
    ):
        """
        Initialize the GroupChat.

        Args:
            name (str): Name of the group chat.
            description (str): Description of the purpose of the group chat.
            agents (List[Agent]): A list of agents participating in the chat.
                Defaults to a new empty list.
            speaker_fn (SpeakerFunction): The function to determine which agent should speak next.
            max_loops (int): Maximum number of turns in the chat.
        """
        self.name = name
        self.description = description
        # A fresh list per instance: a shared mutable default would leak
        # agents between GroupChat objects.
        self.agents = agents if agents is not None else []
        self.speaker_fn = speaker_fn
        self.max_loops = max_loops
        self.chat_history = ChatHistory(
            turns=[],
            total_messages=0,
            name=name,
            description=description,
        )

    @staticmethod
    def _format_turns(turns: List[ChatTurn]) -> List[str]:
        """Render every response in ``turns`` as "<agent_name>: <message>"."""
        return [
            f"{response.agent_name}: {response.message}"
            for turn in turns
            for response in turn.responses
        ]

    def _get_response_sync(
        self, agent: Agent, prompt: str, turn_number: int
    ) -> AgentResponse:
        """
        Get the response from an agent synchronously.

        Any exception raised while building the prompt, running the agent or
        validating the response is logged and converted into an AgentResponse
        whose message describes the error.

        Args:
            agent (Agent): The agent responding.
            prompt (str): The message triggering the response.
            turn_number (int): The current turn number.

        Returns:
            AgentResponse: The agent's response captured in a structured format.
        """
        try:
            # Provide the agent with information about the chat and other agents
            chat_info = (
                f"Chat Name: {self.name}\n"
                f"Chat Description: {self.description}\n"
                f"Agents in Chat: {[a.agent_name for a in self.agents]}"
            )
            other_agents = [
                a.agent_name for a in self.agents if a != agent
            ]
            # Built line by line so the prompt text does not depend on how
            # this source file happens to be indented.
            context = (
                f"You are {agent.agent_name}\n"
                "Conversation History:\n"
                f"\n{chat_info}\n"
                f"Other agents: {other_agents}\n"
                f"Previous messages: {self.get_full_chat_history()}\n"
            )

            message = agent.run(context + prompt)
            return AgentResponse(
                agent_name=agent.agent_name,
                role=agent.system_prompt,
                message=message,
                turn_number=turn_number,
                preceding_context=self.get_recent_messages(3),
            )
        except Exception as e:
            logger.error(f"Error from {agent.agent_name}: {e}")
            return AgentResponse(
                agent_name=agent.agent_name,
                role=agent.system_prompt,
                message=f"Error generating response: {str(e)}",
                turn_number=turn_number,
                preceding_context=[],
            )

    def get_full_chat_history(self) -> str:
        """
        Get the full chat history formatted for agent context.

        Returns:
            str: The full chat history with sender names, one
                "<agent_name>: <message>" entry per line.
        """
        return "\n".join(self._format_turns(self.chat_history.turns))

    def get_recent_messages(self, n: int = 3) -> List[str]:
        """
        Get the most recent messages in the chat.

        Args:
            n (int): The number of most recent *turns* whose messages are
                returned. A non-positive value yields an empty list.

        Returns:
            List[str]: The messages of the last ``n`` turns, oldest first.
        """
        # turns[-0:] would select the whole history, so handle n <= 0 here.
        if n <= 0:
            return []
        return self._format_turns(self.chat_history.turns[-n:])

    def run(self, task: str) -> ChatHistory:
        """
        Run the group chat.

        Performs ``max_loops`` turns. Within a turn each agent is evaluated
        with the speaker function against the messages of previously
        completed turns; the turn is appended to the history only after all
        agents have been processed.

        Args:
            task (str): The initial message to start the chat.

        Returns:
            ChatHistory: The history of the chat (this instance's own
                ``chat_history`` object, which keeps growing across calls).

        Raises:
            Exception: Any error raised during the loop is logged and
                re-raised.
        """
        try:
            logger.info(
                f"Starting chat '{self.name}' with task: {task}"
            )

            for turn in range(self.max_loops):
                current_turn = ChatTurn(
                    turn_number=turn, responses=[], task=task
                )

                for agent in self.agents:
                    if not self.speaker_fn(
                        self.get_recent_messages(), agent
                    ):
                        continue

                    response = self._get_response_sync(
                        agent, task, turn
                    )
                    current_turn.responses.append(response)
                    self.chat_history.total_messages += 1
                    logger.debug(
                        f"Turn {turn}, {agent.agent_name} responded"
                    )

                self.chat_history.turns.append(current_turn)

            return self.chat_history
        except Exception as e:
            logger.error(f"Error in chat: {e}")
            raise

    def batched_run(self, tasks: List[str], *args, **kwargs):
        """
        Run the group chat with a batch of tasks.

        Tasks are run one after another on this instance. Because run()
        returns this instance's single ``chat_history``, every element of
        the returned list is that same object.

        Args:
            tasks (List[str]): The list of tasks to run in the chat.
            *args: Forwarded to run().
            **kwargs: Forwarded to run().

        Returns:
            List[ChatHistory]: The history of each chat.
        """
        return [self.run(task, *args, **kwargs) for task in tasks]

    def concurrent_run(self, tasks: List[str], *args, **kwargs):
        """
        Run the group chat with a batch of tasks concurrently using a thread pool.

        NOTE(review): all worker threads call run() on this same instance and
        therefore append to the same ``chat_history`` without any locking;
        confirm that interleaved turns are acceptable for callers.

        Args:
            tasks (List[str]): The list of tasks to run in the chat.
            *args: Forwarded to run().
            **kwargs: Forwarded to run().

        Returns:
            List[ChatHistory]: The history of each chat.
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return list(
                executor.map(
                    lambda task: self.run(task, *args, **kwargs),
                    tasks,
                )
            )


# Example usage (kept commented out; it needs load_dotenv, os and OpenAIChat,
# none of which are imported by this module).
#
# if __name__ == "__main__":
#
#     load_dotenv()
#
#     # Get the OpenAI API key from the environment variable
#     api_key = os.getenv("OPENAI_API_KEY")
#
#     # Create an instance of the OpenAIChat class
#     model = OpenAIChat(
#         openai_api_key=api_key,
#         model_name="gpt-4o-mini",
#         temperature=0.1,
#     )
#
#     # Example agents
#     agent1 = Agent(
#         agent_name="Financial-Analysis-Agent",
#         system_prompt="You are a financial analyst specializing in investment strategies.",
#         llm=model,
#         max_loops=1,
#         autosave=False,
#         dashboard=False,
#         verbose=True,
#         dynamic_temperature_enabled=True,
#         user_name="swarms_corp",
#         retry_attempts=1,
#         context_length=200000,
#         output_type="string",
#         streaming_on=False,
#     )
#
#     agent2 = Agent(
#         agent_name="Tax-Adviser-Agent",
#         system_prompt="You are a tax adviser who provides clear and concise guidance on tax-related queries.",
#         llm=model,
#         max_loops=1,
#         autosave=False,
#         dashboard=False,
#         verbose=True,
#         dynamic_temperature_enabled=True,
#         user_name="swarms_corp",
#         retry_attempts=1,
#         context_length=200000,
#         output_type="string",
#         streaming_on=False,
#     )
#
#     agents = [agent1, agent2]
#
#     chat = GroupChat(
#         name="Investment Advisory",
#         description="Financial and tax analysis group",
#         agents=agents,
#         speaker_fn=expertise_based,
#     )
#
#     history = chat.run(
#         "How to optimize tax strategy for investments?"
#     )
#
#     print(history.model_dump_json(indent=2))