Initiaal push

Former-commit-id: 3d3c01799a7f8d552496b10b8ec34249af69f735
jojo-group-chat
Sashin 1 year ago
parent 70f5d34369
commit 02f495219e

@ -1,8 +1,9 @@
import logging
from dataclasses import dataclass
from typing import Dict, List
from swarms.structs.flow import Flow
import sys
from typing import Dict, List, Optional, Union
import logging
from swarms import Flow, OpenAI
logger = logging.getLogger(__name__)
@ -38,13 +39,11 @@ class GroupChat:
def select_speaker_msg(self):
"""Return the message for selecting the next speaker."""
return f"""
You are in a role play game. The following roles are available:
return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.
Read the following conversation.
Then select the next role from {self.agent_names} to play. Only return the role.
"""
Then select the next role from {self.agent_names} to play. Only return the role."""
def select_speaker(self, last_speaker: Flow, selector: Flow):
"""Select the next speaker."""
@ -58,51 +57,45 @@ class GroupChat:
)
name = selector.generate_reply(
self.format_history(
self.messages
self.format_history(self.messages
+ [
{
"role": "system",
"content": f"Read the above conversation. Then select the next most suitable role from {self.agent_names} to play. Only return the role.",
}
]
)
])
)
try:
return self.agent_by_name(name["content"])
return self.agent_by_name(name['content'])
except ValueError:
return self.next_agent(last_speaker)
def _participant_roles(self):
return "\n".join(
[f"{agent.name}: {agent.system_message}" for agent in self.agents]
)
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
def format_history(self, messages: List[Dict]) -> str:
formatted_messages = []
for message in messages:
formatted_message = f"'{message['role']}:{message['content']}"
formatted_messages.append(formatted_message)
return "\n".join(formatted_messages)
return '\n'.join(formatted_messages)
class GroupChatManager:
def __init__(self, groupchat: GroupChat, selector: Flow):
self.groupchat = groupchat
self.selector = selector
def __call__(self, task: str):
self.groupchat.messages.append({"role": self.selector.name, "content": task})
self.groupchat.messages.append({'role':self.selector.name, 'content': task})
for i in range(self.groupchat.max_round):
speaker = self.groupchat.select_speaker(
last_speaker=self.selector, selector=self.selector
)
reply = speaker.generate_reply(
self.groupchat.format_history(self.groupchat.messages)
)
speaker = self.groupchat.select_speaker(last_speaker=self.selector, selector=self.selector)
reply = speaker.generate_reply(self.groupchat.format_history(self.groupchat.messages))
self.groupchat.messages.append(reply)
print(reply)
if i == self.groupchat.max_round - 1:
break
return reply

Loading…
Cancel
Save