You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/swarms/swarms/groupchat.py

156 lines
4.6 KiB

import logging
from dataclasses import dataclass, field
from typing import Dict, List

from swarms.structs.agent import Agent
# Module-level logger, named after this module so callers can tune its level.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class GroupChat:
    """
    A group chat holding the participating agents and the shared history.

    Args:
        agents: The agents taking part in the conversation.
        messages: Conversation history. Each entry is a dict with "role" and
            "content" keys (the shape ``format_history`` reads). Defaults to
            a fresh empty list per instance.
        max_round: Maximum number of speaking rounds. This class only stores
            it; ``GroupChatManager`` reads it to bound its loop.
        admin_name: Name reserved for the admin agent (not used by this class).

    Example:
        chat = GroupChat(agents=[researcher, writer, critic], max_round=5)
        next_up = chat.select_speaker(last_speaker=writer, selector=manager)
    """

    agents: List["Agent"]
    messages: List[Dict] = field(default_factory=list)
    max_round: int = 10
    admin_name: str = "Admin"  # the name of the admin agent

    @property
    def agent_names(self) -> List[str]:
        """Return the names of the agents, in membership order."""
        return [agent.name for agent in self.agents]

    def reset(self) -> None:
        """Clear the conversation history in place."""
        self.messages.clear()

    def agent_by_name(self, name: str) -> "Agent":
        """Return the agent whose name best matches the free-form ``name``.

        Every agent whose name occurs as a substring of ``name`` is a
        candidate. The longest such name wins, so a reply of "Bobby" picks
        "Bobby" rather than "Bob". Ties keep membership order.

        Args:
            name: Text expected to mention an agent's name, e.g. an LLM reply.

        Returns:
            The best-matching agent.

        Raises:
            ValueError: If no agent's name is contained in ``name``.
        """
        matches = [agent for agent in self.agents if agent.name in name]
        if not matches:
            raise ValueError(
                f"No agent found with a name contained in '{name}'."
            )
        # max() returns the first maximal element, preserving list order on ties.
        return max(matches, key=lambda agent: len(agent.name))

    def next_agent(self, agent: "Agent") -> "Agent":
        """Return the agent after ``agent`` in round-robin order.

        If ``agent`` is not a member of this chat (e.g. the selector agent),
        the first agent is returned instead of raising.

        Raises:
            IndexError: If the chat has no agents.
        """
        names = self.agent_names
        if agent.name not in names:
            return self.agents[0]
        return self.agents[(names.index(agent.name) + 1) % len(self.agents)]

    def select_speaker_msg(self) -> str:
        """Return the role-play prompt given to the selector agent."""
        return (
            "\n"
            "You are in a role play game. The following roles are available:\n"
            f"{self._participant_roles()}.\n"
            "Read the following conversation.\n"
            f"Then select the next role from {self.agent_names} to play."
            " Only return the role.\n"
        )

    def select_speaker(self, last_speaker: "Agent", selector: "Agent") -> "Agent":
        """Ask ``selector`` to choose who speaks next.

        ``selector`` gets the role-play prompt from ``select_speaker_msg``
        via ``update_system_message``. It is then asked, through
        ``generate_reply``, to pick a role given the formatted history plus a
        closing instruction. The reply may be a dict carrying a ``"content"``
        key or a plain string.

        If the reply names no agent, or has an unusable shape, the agent after
        ``last_speaker`` in round-robin order is chosen instead.

        Args:
            last_speaker: The agent that spoke most recently; used only for
                the round-robin fallback.
            selector: The agent that picks the next speaker.

        Returns:
            The chosen agent.
        """
        selector.update_system_message(self.select_speaker_msg())

        # Warn, but carry on, when there are too few agents for speaker
        # selection to add much over direct communication.
        n_agents = len(self.agent_names)
        if n_agents < 3:
            logger.warning(
                "GroupChat is underpopulated with %d agents."
                " Direct communication would be more efficient.",
                n_agents,
            )

        reply = selector.generate_reply(
            self.format_history(
                self.messages
                + [
                    {
                        "role": "system",
                        "content": (
                            "Read the above conversation. Then"
                            " select the next most suitable role"
                            f" from {self.agent_names} to play. Only"
                            " return the role."
                        ),
                    }
                ]
            )
        )

        # Accept either a message dict or a bare string from generate_reply.
        content = reply.get("content") if isinstance(reply, dict) else reply
        if isinstance(content, str):
            try:
                return self.agent_by_name(content)
            except ValueError:
                pass  # no agent named in the reply; use round-robin below
        return self.next_agent(last_speaker)

    def _participant_roles(self) -> str:
        """Return one "name: system_message" line per agent."""
        return "\n".join(
            f"{agent.name}: {agent.system_message}" for agent in self.agents
        )

    def format_history(self, messages: List[Dict]) -> str:
        """Render ``messages`` as newline-separated ``role:content`` lines.

        Args:
            messages: Dicts with "role" and "content" keys.

        Returns:
            The joined transcript, or "" for no messages.
        """
        return "\n".join(
            f"{message['role']}:{message['content']}" for message in messages
        )
class GroupChatManager:
    """
    Drive a ``GroupChat`` conversation, letting ``selector`` pick each speaker.

    Args:
        groupchat: The chat whose agents, message history and ``max_round``
            are used.
        selector: The agent asked to choose the next speaker each round. It
            is also recorded as the author of the initial task message.

    Example:
        manager = GroupChatManager(groupchat=chat, selector=manager_agent)
        last_reply = manager("Plan a product launch.")
    """

    def __init__(self, groupchat: "GroupChat", selector: "Agent"):
        self.groupchat = groupchat
        self.selector = selector

    def __call__(self, task: str):
        """Run up to ``groupchat.max_round`` speaking rounds for ``task``.

        The task is appended to the history under the selector's name. Each
        round then does the following:

        - The selector picks a speaker.
        - The speaker replies to the formatted history.
        - The reply is appended to the history and printed.

        Args:
            task: The opening message of the conversation.

        Returns:
            The last reply produced, or ``None`` if no rounds ran
            (``max_round`` <= 0).
        """
        self.groupchat.messages.append(
            {"role": self.selector.name, "content": task}
        )

        reply = None
        # The selector "spoke" first (it posted the task); after that, pass
        # the agent that actually replied so the round-robin fallback in
        # select_speaker advances instead of restarting every round.
        last_speaker = self.selector
        for _ in range(self.groupchat.max_round):
            speaker = self.groupchat.select_speaker(
                last_speaker=last_speaker, selector=self.selector
            )
            reply = speaker.generate_reply(
                self.groupchat.format_history(self.groupchat.messages)
            )
            self.groupchat.messages.append(reply)
            print(reply)
            last_speaker = speaker
        return reply