diff --git a/swarms/swarms/groupchat.py b/swarms/swarms/groupchat.py index 6bbe0898..0d9ffff3 100644 --- a/swarms/swarms/groupchat.py +++ b/swarms/swarms/groupchat.py @@ -1,8 +1,9 @@ -import logging from dataclasses import dataclass -from typing import Dict, List -from swarms.structs.flow import Flow +import sys +from typing import Dict, List, Optional, Union +import logging +from swarms import Flow, OpenAI logger = logging.getLogger(__name__) @@ -38,13 +39,11 @@ class GroupChat: def select_speaker_msg(self): """Return the message for selecting the next speaker.""" - return f""" - You are in a role play game. The following roles are available: - {self._participant_roles()}. + return f"""You are in a role play game. The following roles are available: +{self._participant_roles()}. - Read the following conversation. - Then select the next role from {self.agent_names} to play. Only return the role. - """ +Read the following conversation. +Then select the next role from {self.agent_names} to play. Only return the role.""" def select_speaker(self, last_speaker: Flow, selector: Flow): """Select the next speaker.""" @@ -58,51 +57,45 @@ class GroupChat: ) name = selector.generate_reply( - self.format_history( - self.messages - + [ - { - "role": "system", - "content": f"Read the above conversation. Then select the next most suitable role from {self.agent_names} to play. Only return the role.", - } - ] - ) + self.format_history(self.messages + + [ + { + "role": "system", + "content": f"Read the above conversation. Then select the next most suitable role from {self.agent_names} to play. Only return the role.", + } + ]) ) try: - return self.agent_by_name(name["content"]) + return self.agent_by_name(name['content']) except ValueError: return self.next_agent(last_speaker) def _participant_roles(self): - return "\n".join( - [f"{agent.name}: {agent.system_message}" for agent in self.agents] - ) + return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents]) def format_history(self, messages: List[Dict]) -> str: formatted_messages = [] for message in messages: formatted_message = f"'{message['role']}:{message['content']}" formatted_messages.append(formatted_message) - return "\n".join(formatted_messages) - + return '\n'.join(formatted_messages) class GroupChatManager: def __init__(self, groupchat: GroupChat, selector: Flow): self.groupchat = groupchat self.selector = selector + + def __call__(self, task: str): - self.groupchat.messages.append({"role": self.selector.name, "content": task}) + self.groupchat.messages.append({'role':self.selector.name, 'content': task}) for i in range(self.groupchat.max_round): - speaker = self.groupchat.select_speaker( - last_speaker=self.selector, selector=self.selector - ) - reply = speaker.generate_reply( - self.groupchat.format_history(self.groupchat.messages) - ) + speaker = self.groupchat.select_speaker(last_speaker=self.selector, selector=self.selector) + reply = speaker.generate_reply(self.groupchat.format_history(self.groupchat.messages)) self.groupchat.messages.append(reply) print(reply) if i == self.groupchat.max_round - 1: break return reply +