groupchat cleanup

pull/58/head
Kye 2 years ago
parent 602b479511
commit 275ac5be71

@ -0,0 +1,4 @@
from swarms.models import OpenAIChat
llm = OpenAIChat(openai_api_key="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")
out = llm("Hello, I am a robot and I like to talk about robots.")

@ -1 +1,3 @@
from swarms.swarms import GroupChatManager from swarms.swarms import GroupChatManager
from swarms.agents.base import AbstractAgent

@ -1,3 +1,4 @@
class AbsractAgent: class AbsractAgent:
def __init__( def __init__(
self, self,

@ -2,14 +2,14 @@ import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from swarms.workers.worker import Worker from swarms.agents.base import AbstractAgent
@dataclass @dataclass
class GroupChat: class GroupChat:
"""A group chat with multiple participants with a list of workers and a max number of rounds""" """A group chat with multiple participants with a list of workers and a max number of rounds"""
workers: List[Worker] workers: List[AbstractAgent]
messages: List[Dict] messages: List[Dict]
max_rounds: int = 10 max_rounds: int = 10
admin_name: str = "Admin" #admin worker admin_name: str = "Admin" #admin worker
@ -22,11 +22,11 @@ class GroupChat:
def reset(self): def reset(self):
self.messages.clear() self.messages.clear()
def worker_by_name(self, name: str) -> Worker: def worker_by_name(self, name: str) -> AbstractAgent:
"""Find the next speaker baed on the message""" """Find the next speaker baed on the message"""
return self.workers[self.worker_names.index(name)] return self.workers[self.worker_names.index(name)]
def next_worker(self, worker: Worker) -> Worker: def next_worker(self, worker: AbstractAgent) -> AbstractAgent:
"""Returns the next worker in the list""" """Returns the next worker in the list"""
return self.workers[ return self.workers[
(self.workers_names.index(worker.ai_name) + 1) % len(self.workers) (self.workers_names.index(worker.ai_name) + 1) % len(self.workers)
@ -45,8 +45,8 @@ class GroupChat:
def select_speaker( def select_speaker(
self, self,
last_speaker: Worker, last_speaker: AbstractAgent,
selector: Worker, selector: AbstractAgent,
): ):
"""Selects the next speaker""" """Selects the next speaker"""
selector.update_system_message(self.select_speaker_msg()) selector.update_system_message(self.select_speaker_msg())
@ -73,7 +73,7 @@ class GroupChat:
class GroupChatManager(Worker): class GroupChatManager(AbstractAgent):
def __init__( def __init__(
self, self,
groupchat: GroupChat, groupchat: GroupChat,
@ -92,7 +92,7 @@ class GroupChatManager(Worker):
**kwargs **kwargs
) )
self.register_reply( self.register_reply(
Worker, AbstractAgent,
GroupChatManager.run_chat, GroupChatManager.run_chat,
config=groupchat, config=groupchat,
reset_config=GroupChat.reset reset_config=GroupChat.reset
@ -101,7 +101,7 @@ class GroupChatManager(Worker):
def run( def run(
self, self,
messages: Optional[List[Dict]] = None, messages: Optional[List[Dict]] = None,
sender: Optional[Worker] = None, sender: Optional[AbstractAgent] = None,
config: Optional[GroupChat] = None, config: Optional[GroupChat] = None,
) -> Union[str, Dict, None]: ) -> Union[str, Dict, None]:
#run #run
@ -161,9 +161,9 @@ class GroupChatManager(Worker):
# model = GroupChatManager( # model = GroupChatManager(
# groupchat=GroupChat( # groupchat=GroupChat(
# workers=[ # workers=[
# Worker(name="A", system_message="I am worker A"), # AbstractAgent(name="A", system_message="I am worker A"),
# Worker(name="B", system_message="I am worker B"), # AbstractAgent(name="B", system_message="I am worker B"),
# Worker(name="C", system_message="I am worker C"), # AbstractAgent(name="C", system_message="I am worker C"),
# ] # ]
# ) # )
# ) # )

Loading…
Cancel
Save