From 31856154bf5206ff216932fa7e21cd7efe9f3da6 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 25 Jul 2023 17:46:08 -0400 Subject: [PATCH] clean up --- swarms/swarms.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/swarms/swarms.py b/swarms/swarms.py index 38698c05..43a61df8 100644 --- a/swarms/swarms.py +++ b/swarms/swarms.py @@ -21,12 +21,16 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %( # TODO: Add RLHF Data collection, ask user how the swarm is performing class HierarchicalSwarm: - def __init__(self, openai_api_key="", use_vectorstore=True, use_async=True, human_in_the_loop=True): + def __init__(self, model_id: str = None, openai_api_key="", use_vectorstore=True, use_async=True, human_in_the_loop=True, model_type: str = None): #openai_api_key: the openai key. Default is empty + if not model_id: + logging.error("Model ID is not provided") + raise ValueError("Model ID is required") if not openai_api_key: logging.error("OpenAI key is not provided") raise ValueError("OpenAI API key is required") + self.model_id = model_id self.openai_api_key = openai_api_key self.use_vectorstore = use_vectorstore self.use_async = use_async @@ -44,6 +48,8 @@ class HierarchicalSwarm: # Initialize language model if self.llm_class == OpenAI: return llm_class(openai_api_key=self.openai_api_key, temperature=temperature) + elif self.model_type == "huggingface": + return HuggingFaceLLM(model_id=self.model_id, temperature=temperature) else: return self.llm_class(model_id="gpt-2", temperature=temperature) except Exception as e: @@ -199,7 +205,7 @@ class HierarchicalSwarm: return None # usage-# usage- -def swarm(api_key="", objective=""): +def swarm(api_key="", objective="", model_type=""): """ Run the swarm with the given API key and objective. @@ -218,7 +224,7 @@ def swarm(api_key="", objective=""): logging.error("Invalid objective") raise ValueError("A valid objective is required") try: - swarms = HierarchicalSwarm(api_key, use_async=False) # Turn off async + swarms = HierarchicalSwarm(api_key, use_async=False, model_type=model_type) # Turn off async result = swarms.run(objective) if result is None: logging.error("Failed to run swarms")