From f06198fd649e724a911e12036ed252b84dc58e0b Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 19 May 2024 08:17:14 -0400 Subject: [PATCH] [CLEANUP] --- example.py | 5 +- pyproject.toml | 2 +- swarms/models/fire_function.py | 89 -------------- swarms/models/llama3_hosted.py | 2 +- swarms/models/speecht5.py | 172 --------------------------- swarms/structs/__init__.py | 13 +- swarms/structs/async_workflow.py | 3 +- swarms/structs/auto_swarm.py | 30 ++++- swarms/structs/company.py | 4 +- swarms/structs/debate.py | 3 +- swarms/structs/groupchat.py | 3 +- swarms/structs/hiearchical_swarm.py | 12 +- swarms/structs/message_pool.py | 3 +- swarms/structs/model_parallizer.py | 10 +- swarms/structs/rearrange.py | 61 +--------- swarms/structs/recursive_workflow.py | 23 +++- swarms/structs/sermon_swarm.py | 7 +- swarms/structs/task.py | 3 +- 18 files changed, 90 insertions(+), 355 deletions(-) delete mode 100644 swarms/models/fire_function.py delete mode 100644 swarms/models/speecht5.py diff --git a/example.py b/example.py index 1887ce63..ee0461d2 100644 --- a/example.py +++ b/example.py @@ -1,4 +1,5 @@ -from swarms import Agent, OpenAIChat +from swarms import Agent +from swarms.models.llama3_hosted import llama3Hosted # Initialize the agent @@ -7,7 +8,7 @@ agent = Agent( agent_description=( "Generate a transcript for a youtube video on what swarms" " are!" ), - llm=OpenAIChat(), + llm=llama3Hosted(), max_loops="auto", autosave=True, dashboard=False, diff --git a/pyproject.toml b/pyproject.toml index eea38356..47664b22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "5.0.3" +version = "5.0.4" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/models/fire_function.py b/swarms/models/fire_function.py deleted file mode 100644 index 88381888..00000000 --- a/swarms/models/fire_function.py +++ /dev/null @@ -1,89 +0,0 @@ -import json -from typing import Any - -from transformers import AutoModelForCausalLM, AutoTokenizer - -from swarms.models.base_llm import BaseLLM - - -class FireFunctionCaller(BaseLLM): - """ - A class that represents a caller for the FireFunction model. - - Args: - model_name (str): The name of the model to be used. - device (str): The device to be used. - function_spec (Any): The specification of the function. - max_tokens (int): The maximum number of tokens. - system_prompt (str): The system prompt. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Methods: - run(self, task: str, *args, **kwargs) -> None: Run the function with the given task and arguments. - - Examples: - >>> fire_function_caller = FireFunctionCaller() - >>> fire_function_caller.run("Add 2 and 3") - """ - - def __init__( - self, - model_name: str = "fireworks-ai/firefunction-v1", - device: str = "cuda", - function_spec: Any = None, - max_tokens: int = 3000, - system_prompt: str = "You are a helpful assistant with access to functions. Use them if required.", - *args, - **kwargs, - ): - super().__init__(model_name, device) - self.model_name = model_name - self.device = device - self.fucntion_spec = function_spec - self.max_tokens = max_tokens - self.system_prompt = system_prompt - - self.model = AutoModelForCausalLM.from_pretrained( - model_name, device_map="auto", *args, **kwargs - ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - - self.functions = json.dumps(function_spec, indent=4) - - def run(self, task: str, *args, **kwargs): - """ - Run the function with the given task and arguments. - - Args: - task (str): The task to be performed. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - None - """ - messages = [ - {"role": "functions", "content": self.functions}, - { - "role": "system", - "content": self.system_prompt, - }, - { - "role": "user", - "content": task, - }, - ] - - model_inputs = self.tokenizer.apply_chat_template( - messages, return_tensors="pt" - ).to(self.model.device) - - generated_ids = self.model.generate( - model_inputs, - max_new_tokens=self.max_tokens, - *args, - **kwargs, - ) - decoded = self.tokenizer.batch_decode(generated_ids) - print(decoded[0]) diff --git a/swarms/models/llama3_hosted.py b/swarms/models/llama3_hosted.py index 9b6e0d6b..0cc9862e 100644 --- a/swarms/models/llama3_hosted.py +++ b/swarms/models/llama3_hosted.py @@ -1,6 +1,6 @@ import requests import json -from swarms import BaseLLM +from swarms.models.base_llm import BaseLLM class llama3Hosted(BaseLLM): diff --git a/swarms/models/speecht5.py b/swarms/models/speecht5.py deleted file mode 100644 index 5cd9bc9e..00000000 --- a/swarms/models/speecht5.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -SpeechT5 (TTS task) -SpeechT5 model fine-tuned for speech synthesis (text-to-speech) on LibriTTS. - -This model was introduced in SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. - -SpeechT5 was first released in this repository, original weights. The license used is MIT. - -Model Description -Motivated by the success of T5 (Text-To-Text Transfer Transformer) in pre-trained natural language processing models, we propose a unified-modal SpeechT5 framework that explores the encoder-decoder pre-training for self-supervised speech/text representation learning. The SpeechT5 framework consists of a shared encoder-decoder network and six modal-specific (speech/text) pre/post-nets. After preprocessing the input speech/text through the pre-nets, the shared encoder-decoder network models the sequence-to-sequence transformation, and then the post-nets generate the output in the speech/text modality based on the output of the decoder. - -Leveraging large-scale unlabeled speech and text data, we pre-train SpeechT5 to learn a unified-modal representation, hoping to improve the modeling capability for both speech and text. To align the textual and speech information into this unified semantic space, we propose a cross-modal vector quantization approach that randomly mixes up speech/text states with latent units as the interface between encoder and decoder. - -Extensive evaluations show the superiority of the proposed SpeechT5 framework on a wide variety of spoken language processing tasks, including automatic speech recognition, speech synthesis, speech translation, voice conversion, speech enhancement, and speaker identification. - -Developed by: Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei. -Shared by [optional]: Matthijs Hollemans -Model type: text-to-speech -Language(s) (NLP): [More Information Needed] -License: MIT -Finetuned from model [optional]: [More Information Needed] -Model Sources [optional] -Repository: [https://github.com/microsoft/SpeechT5/] -Paper: [https://arxiv.org/pdf/2110.07205.pdf] -Blog Post: [https://huggingface.co/blog/speecht5] -Demo: [https://huggingface.co/spaces/Matthijs/speecht5-tts-demo] - -""" - -import soundfile as sf -import torch -from datasets import load_dataset -from transformers import ( - SpeechT5ForTextToSpeech, - SpeechT5HifiGan, - SpeechT5Processor, - pipeline, -) - - -class SpeechT5: - """ - SpeechT5Wrapper - - - Args: - model_name (str, optional): Model name or path. Defaults to "microsoft/speecht5_tts". - vocoder_name (str, optional): Vocoder name or path. Defaults to "microsoft/speecht5_hifigan". - dataset_name (str, optional): Dataset name or path. Defaults to "Matthijs/cmu-arctic-xvectors". - - Attributes: - model_name (str): Model name or path. - vocoder_name (str): Vocoder name or path. - dataset_name (str): Dataset name or path. - processor (SpeechT5Processor): Processor for the SpeechT5 model. - model (SpeechT5ForTextToSpeech): SpeechT5 model. - vocoder (SpeechT5HifiGan): SpeechT5 vocoder. - embeddings_dataset (datasets.Dataset): Dataset containing speaker embeddings. - - Methods - __call__: Synthesize speech from text. - save_speech: Save speech to a file. - set_model: Change the model. - set_vocoder: Change the vocoder. - set_embeddings_dataset: Change the embeddings dataset. - get_sampling_rate: Get the sampling rate of the model. - print_model_details: Print details of the model. - quick_synthesize: Customize pipeline method for quick synthesis. - change_dataset_split: Change dataset split (train, validation, test). - load_custom_embedding: Load a custom speaker embedding (xvector) for the text. - - Usage: - >>> speechT5 = SpeechT5Wrapper() - >>> result = speechT5("Hello, how are you?") - >>> speechT5.save_speech(result) - >>> print("Speech saved successfully!") - - - - """ - - def __init__( - self, - model_name="microsoft/speecht5_tts", - vocoder_name="microsoft/speecht5_hifigan", - dataset_name="Matthijs/cmu-arctic-xvectors", - ): - self.model_name = model_name - self.vocoder_name = vocoder_name - self.dataset_name = dataset_name - self.processor = SpeechT5Processor.from_pretrained(self.model_name) - self.model = SpeechT5ForTextToSpeech.from_pretrained( - self.model_name - ) - self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) - self.embeddings_dataset = load_dataset( - self.dataset_name, split="validation" - ) - - def __call__(self, text: str, speaker_id: float = 7306): - """Call the model on some text and return the speech.""" - speaker_embedding = torch.tensor( - self.embeddings_dataset[speaker_id]["xvector"] - ).unsqueeze(0) - inputs = self.processor(text=text, return_tensors="pt") - speech = self.model.generate_speech( - inputs["input_ids"], - speaker_embedding, - vocoder=self.vocoder, - ) - return speech - - def save_speech(self, speech, filename="speech.wav"): - """Save Speech to a file.""" - sf.write(filename, speech.numpy(), samplerate=16000) - - def set_model(self, model_name: str): - """Set the model to a new model.""" - self.model_name = model_name - self.processor = SpeechT5Processor.from_pretrained(self.model_name) - self.model = SpeechT5ForTextToSpeech.from_pretrained( - self.model_name - ) - - def set_vocoder(self, vocoder_name): - """Set the vocoder to a new vocoder.""" - self.vocoder_name = vocoder_name - self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) - - def set_embeddings_dataset(self, dataset_name): - """Set the embeddings dataset to a new dataset.""" - self.dataset_name = dataset_name - self.embeddings_dataset = load_dataset( - self.dataset_name, split="validation" - ) - - # Feature 1: Get sampling rate - def get_sampling_rate(self): - """Get sampling rate of the model.""" - return 16000 - - # Feature 2: Print details of the model - def print_model_details(self): - """Print details of the model.""" - print(f"Model Name: {self.model_name}") - print(f"Vocoder Name: {self.vocoder_name}") - - # Feature 3: Customize pipeline method for quick synthesis - def quick_synthesize(self, text): - """Customize pipeline method for quick synthesis.""" - synthesiser = pipeline("text-to-speech", self.model_name) - speech = synthesiser(text) - return speech - - # Feature 4: Change dataset split (train, validation, test) - def change_dataset_split(self, split="train"): - """Change dataset split (train, validation, test).""" - self.embeddings_dataset = load_dataset( - self.dataset_name, split=split - ) - - # Feature 5: Load a custom speaker embedding (xvector) for the text - def load_custom_embedding(self, xvector): - """Load a custom speaker embedding (xvector) for the text.""" - return torch.tensor(xvector).unsqueeze(0) - - -# if __name__ == "__main__": -# speechT5 = SpeechT5Wrapper() -# result = speechT5("Hello, how are you?") -# speechT5.save_speech(result) -# print("Speech saved successfully!") diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index 138afe26..7106d8ce 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -12,6 +12,7 @@ from swarms.structs.block_wrapper import block from swarms.structs.concurrent_workflow import ConcurrentWorkflow from swarms.structs.conversation import Conversation from swarms.structs.groupchat import GroupChat, GroupChatManager +from swarms.structs.hiearchical_swarm import HiearchicalSwarm from swarms.structs.majority_voting import ( MajorityVoting, majority_voting, @@ -19,6 +20,7 @@ from swarms.structs.majority_voting import ( parse_code_completion, ) from swarms.structs.message import Message +from swarms.structs.message_pool import MessagePool from swarms.structs.model_parallizer import ModelParallelizer from swarms.structs.multi_agent_collab import MultiAgentCollaboration from swarms.structs.multi_process_workflow import ( @@ -28,7 +30,9 @@ from swarms.structs.multi_threaded_workflow import ( MultiThreadedWorkflow, ) from swarms.structs.plan import Plan +from swarms.structs.rearrange import AgentRearrange, rearrange from swarms.structs.recursive_workflow import RecursiveWorkflow +from swarms.structs.round_robin import RoundRobinSwarm from swarms.structs.schemas import ( Artifact, ArtifactUpload, @@ -75,16 +79,12 @@ from swarms.structs.utils import ( find_token_in_text, parse_tasks, ) -from swarms.structs.rearrange import AgentRearrange, rearrange - from swarms.structs.yaml_model import ( - get_type_name, + YamlModel, create_yaml_schema_from_dict, + get_type_name, pydantic_type_to_yaml_schema, - YamlModel, ) -from swarms.structs.message_pool import MessagePool -from swarms.structs.round_robin import RoundRobinSwarm __all__ = [ "Agent", @@ -158,4 +158,5 @@ __all__ = [ "MessagePool", "rearrange", "RoundRobinSwarm", + "HiearchicalSwarm", ] diff --git a/swarms/structs/async_workflow.py b/swarms/structs/async_workflow.py index 9ac9018a..b307bf61 100644 --- a/swarms/structs/async_workflow.py +++ b/swarms/structs/async_workflow.py @@ -5,10 +5,11 @@ from typing import Any, Callable, List, Optional from swarms.structs.agent import Agent from swarms.structs.task import Task from swarms.utils.logger import logger +from swarms.structs.base_swarm import BaseSwarm @dataclass -class AsyncWorkflow: +class AsyncWorkflow(BaseSwarm): """ Represents an asynchronous workflow to run tasks. diff --git a/swarms/structs/auto_swarm.py b/swarms/structs/auto_swarm.py index c8f05b08..c7061ba4 100644 --- a/swarms/structs/auto_swarm.py +++ b/swarms/structs/auto_swarm.py @@ -116,6 +116,7 @@ class AutoSwarm(BaseSwarm): custom_preprocess: Optional[Callable] = None, custom_postprocess: Optional[Callable] = None, custom_router: Optional[Callable] = None, + max_loops: int = 1, *args, **kwargs, ): @@ -126,6 +127,8 @@ class AutoSwarm(BaseSwarm): self.custom_params = custom_params self.custom_preprocess = custom_preprocess self.custom_postprocess = custom_postprocess + self.custom_router = custom_router + self.max_loops = max_loops self.router = AutoSwarmRouter( name=name, description=description, @@ -141,7 +144,32 @@ class AutoSwarm(BaseSwarm): def run(self, task: str = None, *args, **kwargs): """Run the swarm simulation.""" try: - return self.router.run(task, *args, **kwargs) + loop = 0 + + while loop < self.max_loops: + if self.custom_preprocess: + # If custom preprocess function is provided then run it + logger.info("Running custom preprocess function.") + task, args, kwargs = self.custom_preprocess( + task, args, kwargs + ) + + if self.custom_router: + # If custom router function is provided then use it to route the task + logger.info("Running custom router function.") + out = self.custom_router(self, task, *args, **kwargs) + + else: + out = self.router.run(task, *args, **kwargs) + + if self.custom_postprocess: + # If custom postprocess function is provided then run it + out = self.custom_postprocess(out) + + # LOOP + loop += 1 + + return out except Exception as e: logger.error( f"Error: {e} try optimizing the inputs and try again." diff --git a/swarms/structs/company.py b/swarms/structs/company.py index 3f304891..db18b857 100644 --- a/swarms/structs/company.py +++ b/swarms/structs/company.py @@ -4,10 +4,10 @@ from typing import Dict, List, Optional, Union from swarms.structs.agent import Agent from swarms.structs.conversation import Conversation from swarms.utils.logger import logger - +from swarms.structs.base_swarm import BaseSwarm @dataclass -class Company: +class Company(BaseSwarm): """ Represents a company with a hierarchical organizational structure. """ diff --git a/swarms/structs/debate.py b/swarms/structs/debate.py index 5a80265a..0e3a915a 100644 --- a/swarms/structs/debate.py +++ b/swarms/structs/debate.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import List from swarms.structs.agent import Agent +from swarms.structs.base_swarm import BaseSwarm NAME_LIST = [ "Affirmative side", @@ -26,7 +27,7 @@ class DebatePlayer(Agent): super().__init__(llm=llm, agent_name=name, *args, **kwargs) -class Debate: +class Debate(BaseSwarm): """Create a debate Args: diff --git a/swarms/structs/groupchat.py b/swarms/structs/groupchat.py index 501e15db..dbf4e78f 100644 --- a/swarms/structs/groupchat.py +++ b/swarms/structs/groupchat.py @@ -3,10 +3,11 @@ from typing import List from swarms.structs.conversation import Conversation from swarms.utils.loguru_logger import logger from swarms.structs.agent import Agent +from swarms.structs.base_swarm import BaseSwarm @dataclass -class GroupChat: +class GroupChat(BaseSwarm): """ A group chat class that contains a list of agents and the maximum number of rounds. diff --git a/swarms/structs/hiearchical_swarm.py b/swarms/structs/hiearchical_swarm.py index 3814f781..454e45a5 100644 --- a/swarms/structs/hiearchical_swarm.py +++ b/swarms/structs/hiearchical_swarm.py @@ -9,7 +9,6 @@ from swarms.utils.loguru_logger import logger class HiearchicalSwarm(BaseSwarm): - @beartype def __init__( self, @@ -25,16 +24,15 @@ class HiearchicalSwarm(BaseSwarm): self.agents = agents self.max_loops = max_loops self.long_term_memory_system = long_term_memory_system - + # Set the director to max_one loop - self.director.max_loops = 1 - + if self.director.max_loops > 1: + self.director.max_loops = 1 + # Set the long term memory system of every agent to long term memory system if long_term_memory_system is True: for agent in agents: agent.long_term_memory = long_term_memory_system - - def parse_function_activate_agent( self, json_data: str = None, *args, **kwargs @@ -115,7 +113,7 @@ class HiearchicalSwarm(BaseSwarm): """ try: loop = 0 - + # While the loop is less than max loops while loop < self.max_loops: # Run the director diff --git a/swarms/structs/message_pool.py b/swarms/structs/message_pool.py index 40010bee..c30251ae 100644 --- a/swarms/structs/message_pool.py +++ b/swarms/structs/message_pool.py @@ -4,6 +4,7 @@ from typing import Callable, List, Optional, Sequence, Union from swarms.structs.agent import Agent from swarms.utils.loguru_logger import logger +from swarms.structs.base_swarm import BaseSwarm def _hash(input: str): @@ -42,7 +43,7 @@ def msg_hash( ) -class MessagePool: +class MessagePool(BaseSwarm): """ A class representing a message pool for agents in a swarm. diff --git a/swarms/structs/model_parallizer.py b/swarms/structs/model_parallizer.py index b3c75b09..76e2fe55 100644 --- a/swarms/structs/model_parallizer.py +++ b/swarms/structs/model_parallizer.py @@ -19,12 +19,12 @@ class ModelParallelizer: Args: llms (List[Callable]): A list of language models. retry_attempts (int): The number of retry attempts. - iters (int): The number of iterations to run the task. + max_loops (int): The number of iterations to run the task. Attributes: llms (List[Callable]): A list of language models. retry_attempts (int): The number of retry attempts. - iters (int): The number of iterations to run the task. + max_loops (int): The number of iterations to run the task. last_responses (List[str]): The last responses from the language models. task_history (List[str]): The task history. @@ -52,20 +52,20 @@ class ModelParallelizer: self, llms: List[Callable] = None, retry_attempts: int = 3, - iters: int = None, + max_loops: int = None, *args, **kwargs, ): self.llms = llms self.retry_attempts = retry_attempts - self.iters = iters + self.max_loops = max_loops self.last_responses = None self.task_history = [] def run(self, task: str): """Run the task string""" try: - for i in range(self.iters): + for i in range(self.max_loops): with ThreadPoolExecutor() as executor: responses = executor.map( lambda llm: llm(task), self.llms diff --git a/swarms/structs/rearrange.py b/swarms/structs/rearrange.py index e685d552..47c1d21c 100644 --- a/swarms/structs/rearrange.py +++ b/swarms/structs/rearrange.py @@ -1,10 +1,10 @@ -from typing import List +from typing import Callable, Dict, List, Optional + from swarms.memory.base_vectordb import BaseVectorDatabase +from swarms.structs.agent import Agent from swarms.structs.base_swarm import BaseSwarm -from swarms.utils.loguru_logger import logger -from typing import Optional, Callable, Dict from swarms.structs.omni_agent_types import Agent -from swarms.structs.agent import Agent +from swarms.utils.loguru_logger import logger class AgentRearrange(BaseSwarm): @@ -283,56 +283,3 @@ def rearrange( agents=agents, flow=flow, *args, **kwargs ) return agent_system.run(task, *args, **kwargs) - - -# # Initialize the director agent -# director = Agent( -# agent_name="Director", -# system_prompt="Directs the tasks for the workers", -# llm=Anthropic(), -# max_loops=1, -# dashboard=False, -# streaming_on=True, -# verbose=True, -# stopping_token="", -# state_save_file_type="json", -# saved_state_path="director.json", -# ) - -# # Initialize worker 1 -# worker1 = Agent( -# agent_name="Worker1", -# system_prompt="Generates a transcript for a youtube video on what swarms are", -# llm=Anthropic(), -# max_loops=1, -# dashboard=False, -# streaming_on=True, -# verbose=True, -# stopping_token="", -# state_save_file_type="json", -# saved_state_path="worker1.json", -# ) - -# # Initialize worker 2 -# worker2 = Agent( -# agent_name="Worker2", -# system_prompt="Summarizes the transcript generated by Worker1", -# llm=Anthropic(), -# max_loops=1, -# dashboard=False, -# streaming_on=True, -# verbose=True, -# stopping_token="", -# state_save_file_type="json", -# saved_state_path="worker2.json", -# ) - - -# flow = "Director -> Worker1 -> Worker2" -# agent_system = AgentRearrange( -# agents=[director, worker1, worker2], flow=flow -# ) -# # Run the system -# output = agent_system.run( -# "Create a format to express and communicate swarms of llms in a structured manner for youtube" -# ) diff --git a/swarms/structs/recursive_workflow.py b/swarms/structs/recursive_workflow.py index 60c471a5..4f0eb400 100644 --- a/swarms/structs/recursive_workflow.py +++ b/swarms/structs/recursive_workflow.py @@ -34,6 +34,9 @@ class RecursiveWorkflow(BaseStructure): self, stop_token: str = "", stopping_conditions: callable = None, + max_loops: int = 1, + *args, + **kwargs, ): self.stop_token = stop_token self.stopping_conditions = stopping_conditions @@ -75,12 +78,20 @@ class RecursiveWorkflow(BaseStructure): None """ try: - for task in self.task_pool: - while True: - result = task.run() - if result is not None and self.stop_token in result: - break - print(f"{result}") + loop = 0 + while loop < self.max_loops: + for task in self.task_pool: + while True: + result = task.run() + if ( + result is not None + and self.stop_token in result + ): + break + print(f"{result}") + loop += 1 + + return result except Exception as error: logger.warning(f"[ERROR][RecursiveWorkflow] {error}") raise error diff --git a/swarms/structs/sermon_swarm.py b/swarms/structs/sermon_swarm.py index 59522b9a..fad468d4 100644 --- a/swarms/structs/sermon_swarm.py +++ b/swarms/structs/sermon_swarm.py @@ -50,7 +50,9 @@ class SermonSwarm(BaseSwarm): agent.add_message_to_memory(sermon) # Then run the agents - for _ in range(self.max_loops): + loop = 0 + # for _ in range(self.max_loops): + while loop < self.max_loops: for agent in self.agents: preach = agent.run(task, *args, **kwargs) @@ -63,3 +65,6 @@ class SermonSwarm(BaseSwarm): elif self.stop_condition in preach: break + + loop += 1 + return preach diff --git a/swarms/structs/task.py b/swarms/structs/task.py index bcdf42c6..e146ce87 100644 --- a/swarms/structs/task.py +++ b/swarms/structs/task.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Union from swarms.structs.agent import Agent from swarms.structs.conversation import Conversation from swarms.utils.logger import logger +from swarms.structs.omni_agent_types import AgentType @dataclass @@ -51,7 +52,7 @@ class Task: """ - agent: Union[Callable, Agent] + agent: Union[Callable, Agent, AgentType] = None description: str = None result: Any = None history: List[Any] = field(default_factory=list)