diff --git a/dialogue_simulator.py b/dialogue_simulator.py index ec3ebcd3..f5a5401f 100644 --- a/dialogue_simulator.py +++ b/dialogue_simulator.py @@ -3,17 +3,13 @@ from swarms import DialogueSimulator, Worker worker1 = Worker(ai_name="Plinus", openai_api_key="") worker2 = Worker(ai_name="Optimus Prime", openai_api_key="") -collab = DialogueSimulator([worker1, worker2], DialogueSimulator.select_next_speaker) -collab.reset() -collab.inject(name=worker1.ai_name, "what is your name") - -# collab.start("My name is Plinus and I am a worker", "How are you?") - -max_iters = 6 -n = 0 - -while n < max_iters: - name, message = simulator.step() - print(f"({name}): {message}") - print("\n") - n += 1 \ No newline at end of file +collab = DialogueSimulator( + [worker1, worker2], + DialogueSimulator.select_next_speaker +) + +collab.run( + max_iters = 4, + name = "plinus", + message = "how can we enable multi agent collaboration", +) \ No newline at end of file diff --git a/swarms/swarms/dialogue_simulator.py b/swarms/swarms/dialogue_simulator.py index 9b7ea1f4..b216bb96 100644 --- a/swarms/swarms/dialogue_simulator.py +++ b/swarms/swarms/dialogue_simulator.py @@ -1,57 +1,24 @@ -from typing import List, Callable +from typing import List from swarms.workers.worker import Worker - class DialogueSimulator: - def __init__( - self, - agents: List[Worker], - selection_func: Callable[[int, List[Worker]], int], - ): + def __init__(self, agents: List[Worker]): self.agents = agents - self._step = 0 - self.select_next_speaker = selection_func - - def reset(self): - for agent in self.agents: - agent.reset() - - def start(self, name: str, message: str): - #init conv with a message from name - prompt = f"Name {name} and message: {message}" - - for agent in self.agents: - agent.run(prompt) - - #increment time - self._step += 1 - - def inject(self, name: str, message: str): - for agent in self.agents: - agent.receieve(name, message) - - self._step += 1 - - def step(self) -> tuple[str, str]: - #choose next speaker - speaker_idx = self.select_next_speaker( - self._step, - self.agents - ) - speaker = self.agents[speaker_idx] - - #2. next speaker ends message - message = speaker.run() - #everyone receives messages - for receiver in self.agents: - receiver.receive(speaker.name, message) - - #increment time - self._step += 1 + def run(self, max_iters: int, name: str = None, message: str = None): + step = 0 + if name and message: + prompt = f"Name {name} and message: {message}" + for agent in self.agents: + agent.run(prompt) + step += 1 - return speaker.name, message - - def select_next_speaker(step: int, agents) -> int: - idx = (step) % len(agents) - return idx \ No newline at end of file + while step < max_iters: + speaker_idx = step % len(self.agents) + speaker = self.agents[speaker_idx] + speaker_message = speaker.run() + for receiver in self.agents: + receiver.receive(speaker.name, speaker_message) + print(f"({speaker.name}): {speaker_message}") + print("\n") + step += 1 \ No newline at end of file