diff --git a/swarms/swarms.py b/swarms/swarms.py index 9270e42a..2be56179 100644 --- a/swarms/swarms.py +++ b/swarms/swarms.py @@ -20,32 +20,25 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %( # TODO: Off class HierarchicalSwarm: - def __init__(self, - model_id: str = None, - openai_api_key="", - use_vectorstore=True, - embedding_size: int = None, - use_async=True, - human_in_the_loop=True, - model_type: str = None, - boss_prompt: str = None, - worker_prompt:str = None, - temperature=None, - max_iterations=None, - log_level: str = 'INFO' - ): - #openai_api_key: the openai key. Default is empty - if not model_id: - logging.error("Model ID is not provided") - raise ValueError("Model ID is required") - if not openai_api_key: - logging.error("OpenAI key is not provided") - raise ValueError("OpenAI API key is required") + def __init__( + self, + model_id: Optional[str] = None, + openai_api_key: Optional[str] = "", + use_vectorstore: Optional[bool] = True, + embedding_size: Optional[int] = None, + use_async: Optional[bool] = True, + human_in_the_loop: Optional[bool] = True, + model_type: Optional[str] = None, + boss_prompt: Optional[str] = None, + worker_prompt: Optional[str] = None, + temperature: Optional[float] = None, + max_iterations: Optional[int] = None, + logging_enabled: Optional[bool] = True): self.model_id = model_id self.openai_api_key = openai_api_key self.use_vectorstore = use_vectorstore - + self.use_async = use_async self.human_in_the_loop = human_in_the_loop self.model_type = model_type @@ -56,10 +49,15 @@ class HierarchicalSwarm: self.temperature = temperature self.max_iterations = max_iterations + self.logging_enabled = logging_enabled + self.logger = logging.getLogger() + if not logging_enabled: + self.logger.disabled = True - def initialize_llm(self, llm_class): + + def initialize_llm(self, llm_class: Optional = None): """ Init LLM @@ -219,7 +217,12 @@ class HierarchicalSwarm: return None # usage-# usage- -def swarm(api_key="", objective="", model_type="", model_id=""): +def swarm( + api_key: Optional[str]="", + objective: Optional[str]="", + model_type: Optional[str]="", + model_id: Optional[str]="" + ): """ Run the swarm with the given API key and objective. @@ -238,7 +241,7 @@ def swarm(api_key="", objective="", model_type="", model_id=""): logging.error("Invalid objective") raise ValueError("A valid objective is required") try: - swarms = HierarchicalSwarm(api_key, model_id=model_type, use_async=False, model_type=model_type) # Turn off async + swarms = HierarchicalSwarm(api_key, model_id=model_type, use_async=False, model_type=model_type, logging_enabled=logging_enabled) # Turn off async result = swarms.run(objective) if result is None: logging.error("Failed to run swarms")