parent
6d3d760fac
commit
63915cdc46
@ -1,151 +1,144 @@
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Union
|
||||
import logging
|
||||
|
||||
from .. import Flow
|
||||
|
||||
from swarms.workers.worker import Worker
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupChat:
|
||||
"""A group chat with multiple participants with a list of workers and a max number of rounds"""
|
||||
"""A group chat class that contains a list of agents and the maximum number of rounds."""
|
||||
|
||||
workers: List[Worker]
|
||||
agents: List[Flow]
|
||||
messages: List[Dict]
|
||||
max_rounds: int = 10
|
||||
admin_name: str = "Admin" # admin worker
|
||||
max_round: int = 10
|
||||
admin_name: str = "Admin" # the name of the admin agent
|
||||
|
||||
@property
|
||||
def worker_names(self) -> List[str]:
|
||||
"""returns the names of the workers in the group chat"""
|
||||
return [worker.ai_name for worker in self.workers]
|
||||
def agent_names(self) -> List[str]:
|
||||
"""Return the names of the agents in the group chat."""
|
||||
return [agent.name for agent in self.agents]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the group chat."""
|
||||
self.messages.clear()
|
||||
|
||||
def worker_by_name(self, name: str) -> Worker:
|
||||
"""Find the next speaker baed on the message"""
|
||||
return self.workers[self.worker_names.index(name)]
|
||||
def agent_by_name(self, name: str) -> Flow:
|
||||
"""Find the next speaker based on the message."""
|
||||
return self.agents[self.agent_names.index(name)]
|
||||
|
||||
def next_worker(self, worker: Worker) -> Worker:
|
||||
"""Returns the next worker in the list"""
|
||||
return self.workers[
|
||||
(self.workers_names.index(worker.ai_name) + 1) % len(self.workers)
|
||||
]
|
||||
def next_agent(self, agent: Flow) -> Flow:
|
||||
"""Return the next agent in the list."""
|
||||
return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]
|
||||
|
||||
def select_speaker_msg(self):
|
||||
"""Return the message to select the next speaker"""
|
||||
"""Return the message for selecting the next speaker."""
|
||||
return f"""You are in a role play game. The following roles are available:
|
||||
{self._participant_roles()}.
|
||||
|
||||
return f"""
|
||||
You are in a role play game the following rules are available:
|
||||
{self.__participant_roles()}.
|
||||
Read the following conversation.
|
||||
Then select the next role from {self.agent_names} to play. Only return the role."""
|
||||
|
||||
Read the following conversation then select the next role from {self.worker_names}
|
||||
to play and only return the role
|
||||
"""
|
||||
|
||||
def select_speaker(
|
||||
self,
|
||||
last_speaker: Worker,
|
||||
selector: Worker,
|
||||
):
|
||||
"""Selects the next speaker"""
|
||||
def select_speaker(self, last_speaker: Flow, selector: Flow):
|
||||
"""Select the next speaker."""
|
||||
selector.update_system_message(self.select_speaker_msg())
|
||||
|
||||
final, name = selector.run(
|
||||
# Warn if GroupChat is underpopulated, without established changing behavior
|
||||
n_agents = len(self.agent_names)
|
||||
if n_agents < 3:
|
||||
logger.warning(
|
||||
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
|
||||
)
|
||||
|
||||
final, name = selector.generate_oai_reply(
|
||||
self.messages
|
||||
+ [
|
||||
{
|
||||
"role": "system",
|
||||
"context": f"Read the above conversation. Then select the next role from {self.worker_names} to play. Only return the role.",
|
||||
"content": f"Read the above conversation. Then select the next role from {self.agent_names} to play. Only return the role.",
|
||||
}
|
||||
]
|
||||
)
|
||||
if not final:
|
||||
return self.next_worker(last_speaker)
|
||||
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
|
||||
return self.next_agent(last_speaker)
|
||||
try:
|
||||
return self.worker_by_name(name)
|
||||
return self.agent_by_name(name)
|
||||
except ValueError:
|
||||
return self.next_worker(last_speaker)
|
||||
return self.next_agent(last_speaker)
|
||||
|
||||
def _participant_roles(self):
|
||||
return "\n".join(
|
||||
[f"{worker.ai_name}: {worker.system_message}" for worker in self.workers]
|
||||
)
|
||||
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
|
||||
|
||||
|
||||
class GroupChatManager(Flow):
|
||||
"""(In preview) A chat manager agent that can manage a group chat of multiple agents."""
|
||||
|
||||
class GroupChatManager(Worker):
|
||||
def __init__(
|
||||
self,
|
||||
groupchat: GroupChat,
|
||||
ai_name: Optional[str] = "chat_manager",
|
||||
name: Optional[str] = "chat_manager",
|
||||
# unlimited consecutive auto reply by default
|
||||
max_consecutive_auto_reply: Optional[int] = sys.maxsize,
|
||||
human_input_mode: Optional[str] = "NEVER",
|
||||
system_message: Optional[str] = "Group chat manager",
|
||||
system_message: Optional[str] = "Group chat manager.",
|
||||
# seed: Optional[int] = 4,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
ai_name=ai_name,
|
||||
# max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||
# human_input_mode=human_input_mode,
|
||||
# system_message=system_message,
|
||||
name=name,
|
||||
max_consecutive_auto_reply=max_consecutive_auto_reply,
|
||||
human_input_mode=human_input_mode,
|
||||
system_message=system_message,
|
||||
**kwargs,
|
||||
)
|
||||
self.register_reply(
|
||||
Worker, GroupChatManager.run, config=groupchat, reset_config=GroupChat.reset
|
||||
)
|
||||
self.register_reply(Flow, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
|
||||
# self._random = random.Random(seed)
|
||||
|
||||
def run(
|
||||
def run_chat(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Worker] = None,
|
||||
sender: Optional[Flow] = None,
|
||||
config: Optional[GroupChat] = None,
|
||||
) -> Union[str, Dict, None]:
|
||||
# run
|
||||
"""Run a group chat."""
|
||||
if messages is None:
|
||||
messages = []
|
||||
|
||||
messages = self._oai_messages[sender]
|
||||
message = messages[-1]
|
||||
speaker = sender
|
||||
groupchat = config
|
||||
|
||||
for i in range(groupchat.max_rounds):
|
||||
for i in range(groupchat.max_round):
|
||||
# set the name to speaker's name if the role is not function
|
||||
if message["role"] != "function":
|
||||
message["name"] = speaker.ai_name
|
||||
|
||||
message["name"] = speaker.name
|
||||
groupchat.messages.append(message)
|
||||
|
||||
# broadcast the message to all workers except the speaker
|
||||
for worker in groupchat.workers:
|
||||
if worker != speaker:
|
||||
self.send(
|
||||
message,
|
||||
worker,
|
||||
request_reply=False,
|
||||
silent=True,
|
||||
)
|
||||
if i == groupchat.max_rounds - 1:
|
||||
# broadcast the message to all agents except the speaker
|
||||
for agent in groupchat.agents:
|
||||
if agent != speaker:
|
||||
self.send(message, agent, request_reply=False, silent=True)
|
||||
if i == groupchat.max_round - 1:
|
||||
# the last round
|
||||
break
|
||||
|
||||
try:
|
||||
# select next speaker
|
||||
# select the next speaker
|
||||
speaker = groupchat.select_speaker(speaker, self)
|
||||
# let the speaker speak
|
||||
reply = speaker.generate_reply(sender=self)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# let the admin speak if interrupted
|
||||
if groupchat.admin_name in groupchat.worker_names:
|
||||
# admin worker is a particpant
|
||||
speaker = groupchat.worker_by_name(groupchat.admin_name)
|
||||
# let the admin agent speak if interrupted
|
||||
if groupchat.admin_name in groupchat.agent_names:
|
||||
# admin agent is one of the participants
|
||||
speaker = groupchat.agent_by_name(groupchat.admin_name)
|
||||
reply = speaker.generate_reply(sender=self)
|
||||
else:
|
||||
# admin worker is not found in particpants
|
||||
# admin agent is not found in the participants
|
||||
raise
|
||||
if reply is None:
|
||||
break
|
||||
|
||||
# speaker sends message without requesting a reply
|
||||
# The speaker sends the message without requesting a reply
|
||||
speaker.send(reply, self, request_reply=False)
|
||||
message = self.last_message(speaker)
|
||||
message = self.last_messge(speaker)
|
||||
return True, None
|
||||
|
Loading…
Reference in new issue