diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..1ce589ae --- /dev/null +++ b/Dockerfile @@ -0,0 +1,31 @@ +# Use an official Python runtime as a parent image +FROM python:3.9-slim + +# Set environment variables to make Python output unbuffered and disable the PIP cache +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 +ENV PIP_NO_CACHE_DIR off +ENV PIP_DISABLE_PIP_VERSION_CHECK on +ENV PIP_DEFAULT_TIMEOUT 100 + +# Set the working directory in the container +WORKDIR /usr/src/app + +# Copy the current directory contents into the container at /usr/src/app +COPY . . + +# Install Poetry +RUN pip install poetry + +# Disable virtualenv creation by poetry and install dependencies +RUN poetry config virtualenvs.create false +RUN poetry install --no-interaction --no-ansi + +# Install the 'swarms' package if it's not included in the poetry.lock +RUN pip install swarms + +# Assuming tests require pytest to run +RUN pip install pytest + +# Run pytest on all tests in the tests directory +CMD find ./tests -name '*.py' -exec pytest {} + diff --git a/demos/ui_software_demo.py b/demos/ui_software_demo.py index 6271d96e..d322f71b 100644 --- a/demos/ui_software_demo.py +++ b/demos/ui_software_demo.py @@ -2,4 +2,4 @@ Autonomous swarm that optimizes UI autonomously GPT4Vision ->> GPT4 ->> UI -""" \ No newline at end of file +""" diff --git a/swarms/__init__.py b/swarms/__init__.py index 71481e16..f45f876f 100644 --- a/swarms/__init__.py +++ b/swarms/__init__.py @@ -9,6 +9,6 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" from swarms.agents import * from swarms.swarms import * from swarms.structs import * -from swarms.models import * +from swarms.models import * from swarms.chunkers import * from swarms.workers import * diff --git a/swarms/agents/__init__.py b/swarms/agents/__init__.py index cd3aa221..52afb476 100644 --- a/swarms/agents/__init__.py +++ b/swarms/agents/__init__.py @@ -8,6 +8,7 @@ from swarms.agents.registry import Registry # from swarms.agents.idea_to_image_agent import Idea2Image from swarms.agents.simple_agent import SimpleAgent + """Agent Infrastructure, models, memory, utils, tools""" __all__ = [ diff --git a/swarms/agents/agent.py b/swarms/agents/agent.py index c16dd780..bad9d3bb 100644 --- a/swarms/agents/agent.py +++ b/swarms/agents/agent.py @@ -8,7 +8,8 @@ from langchain.chains.llm import LLMChain from langchain.chat_models.base import BaseChatModel from langchain.memory import ChatMessageHistory from langchain.prompts.chat import ( - BaseChatPromptTemplate,) + BaseChatPromptTemplate, +) from langchain.schema import ( BaseChatMessageHistory, Document, @@ -70,12 +71,14 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc] send_token_limit: int = 4196 def construct_full_prompt(self, goals: List[str]) -> str: - prompt_start = ("Your decisions must always be made independently " - "without seeking user assistance.\n" - "Play to your strengths as an LLM and pursue simple " - "strategies with no legal complications.\n" - "If you have completed all your tasks, make sure to " - 'use the "finish" command.') + prompt_start = ( + "Your decisions must always be made independently " + "without seeking user assistance.\n" + "Play to your strengths as an LLM and pursue simple " + "strategies with no legal complications.\n" + "If you have completed all your tasks, make sure to " + 'use the "finish" command.' + ) # Construct full prompt full_prompt = ( f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n" @@ -87,23 +90,25 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc] return full_prompt def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - base_prompt = SystemMessage( - content=self.construct_full_prompt(kwargs["goals"])) + base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"])) time_prompt = SystemMessage( - content=f"The current time and date is {time.strftime('%c')}") - used_tokens = self.token_counter( - base_prompt.content) + self.token_counter(time_prompt.content) + content=f"The current time and date is {time.strftime('%c')}" + ) + used_tokens = self.token_counter(base_prompt.content) + self.token_counter( + time_prompt.content + ) memory: VectorStoreRetriever = kwargs["memory"] previous_messages = kwargs["messages"] - relevant_docs = memory.get_relevant_documents( - str(previous_messages[-10:])) + relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:])) relevant_memory = [d.page_content for d in relevant_docs] relevant_memory_tokens = sum( - [self.token_counter(doc) for doc in relevant_memory]) + [self.token_counter(doc) for doc in relevant_memory] + ) while used_tokens + relevant_memory_tokens > 2500: relevant_memory = relevant_memory[:-1] relevant_memory_tokens = sum( - [self.token_counter(doc) for doc in relevant_memory]) + [self.token_counter(doc) for doc in relevant_memory] + ) content_format = ( f"This reminds you of these events from your past:\n{relevant_memory}\n\n" ) @@ -141,23 +146,13 @@ class PromptGenerator: self.performance_evaluation: List[str] = [] self.response_format = { "thoughts": { - "text": - "thought", - "reasoning": - "reasoning", - "plan": - "- short bulleted\n- list that conveys\n- long-term plan", - "criticism": - "constructive self-criticism", - "speak": - "thoughts summary to say to user", - }, - "command": { - "name": "command name", - "args": { - "arg name": "value" - } + "text": "thought", + "reasoning": "reasoning", + "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": "constructive self-criticism", + "speak": "thoughts summary to say to user", }, + "command": {"name": "command name", "args": {"arg name": "value"}}, } def add_constraint(self, constraint: str) -> None: @@ -195,9 +190,7 @@ class PromptGenerator: """ self.performance_evaluation.append(evaluation) - def _generate_numbered_list(self, - items: list, - item_type: str = "list") -> str: + def _generate_numbered_list(self, items: list, item_type: str = "list") -> str: """ Generate a numbered list from given items based on the item_type. @@ -215,11 +208,16 @@ class PromptGenerator: for i, item in enumerate(items) ] finish_description = ( - "use this to signal that you have finished all your objectives") - finish_args = ('"response": "final response to let ' - 'people know you have finished your objectives"') - finish_string = (f"{len(items) + 1}. {FINISH_NAME}: " - f"{finish_description}, args: {finish_args}") + "use this to signal that you have finished all your objectives" + ) + finish_args = ( + '"response": "final response to let ' + 'people know you have finished your objectives"' + ) + finish_string = ( + f"{len(items) + 1}. {FINISH_NAME}: " + f"{finish_description}, args: {finish_args}" + ) return "\n".join(command_strings + [finish_string]) else: return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) @@ -240,7 +238,8 @@ class PromptGenerator: f"{self._generate_numbered_list(self.performance_evaluation)}\n\n" "You should only respond in JSON format as described below " f"\nResponse Format: \n{formatted_response_format} " - "\nEnsure the response can be parsed by Python json.loads") + "\nEnsure the response can be parsed by Python json.loads" + ) return prompt_string @@ -261,11 +260,13 @@ def get_prompt(tools: List[BaseTool]) -> str: prompt_generator.add_constraint( "~16000 word limit for short term memory. " "Your short term memory is short, " - "so immediately save important information to files.") + "so immediately save important information to files." + ) prompt_generator.add_constraint( "If you are unsure how you previously did something " "or want to recall past events, " - "thinking about similar events will help you remember.") + "thinking about similar events will help you remember." + ) prompt_generator.add_constraint("No user assistance") prompt_generator.add_constraint( 'Exclusively use the commands listed in double quotes e.g. "command name"' @@ -277,23 +278,29 @@ def get_prompt(tools: List[BaseTool]) -> str: # Add resources to the PromptGenerator object prompt_generator.add_resource( - "Internet access for searches and information gathering.") + "Internet access for searches and information gathering." + ) prompt_generator.add_resource("Long Term memory management.") prompt_generator.add_resource( - "GPT-3.5 powered Agents for delegation of simple tasks.") + "GPT-3.5 powered Agents for delegation of simple tasks." + ) prompt_generator.add_resource("File output.") # Add performance evaluations to the PromptGenerator object prompt_generator.add_performance_evaluation( "Continuously review and analyze your actions " - "to ensure you are performing to the best of your abilities.") + "to ensure you are performing to the best of your abilities." + ) prompt_generator.add_performance_evaluation( - "Constructively self-criticize your big-picture behavior constantly.") + "Constructively self-criticize your big-picture behavior constantly." + ) prompt_generator.add_performance_evaluation( - "Reflect on past decisions and strategies to refine your approach.") + "Reflect on past decisions and strategies to refine your approach." + ) prompt_generator.add_performance_evaluation( "Every command has a cost, so be smart and efficient. " - "Aim to complete tasks in the least number of steps.") + "Aim to complete tasks in the least number of steps." + ) # Generate the prompt string prompt_string = prompt_generator.generate_prompt_string() @@ -364,8 +371,10 @@ class AutoGPT: ) def run(self, goals: List[str]) -> str: - user_input = ("Determine which next command to use, " - "and respond using the format specified above:") + user_input = ( + "Determine which next command to use, " + "and respond using the format specified above:" + ) # Interaction Loop loop_count = 0 while True: @@ -382,10 +391,8 @@ class AutoGPT: # Print Assistant thoughts print(assistant_reply) - self.chat_history_memory.add_message( - HumanMessage(content=user_input)) - self.chat_history_memory.add_message( - AIMessage(content=assistant_reply)) + self.chat_history_memory.add_message(HumanMessage(content=user_input)) + self.chat_history_memory.add_message(AIMessage(content=assistant_reply)) # Get command name and arguments action = self.output_parser.parse(assistant_reply) @@ -411,7 +418,8 @@ class AutoGPT: result = ( f"Unknown command '{action.name}'. " "Please refer to the 'COMMANDS' list for available " - "commands and only respond in the specified JSON format.") + "commands and only respond in the specified JSON format." + ) memory_to_add = f"Assistant Reply: {assistant_reply} \nResult: {result} " if self.feedback_tool is not None: diff --git a/swarms/agents/aot.py b/swarms/agents/aot.py index 123f5591..b36fb43c 100644 --- a/swarms/agents/aot.py +++ b/swarms/agents/aot.py @@ -4,13 +4,13 @@ import time import openai_model -logging.basicConfig(level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s") +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) class OpenAI: - def __init__( self, api_key, @@ -68,13 +68,16 @@ class OpenAI: temperature=temperature, ) with open("openai.logs", "a") as log_file: - log_file.write("\n" + "-----------" + "\n" + "Prompt : " + - prompt + "\n") + log_file.write( + "\n" + "-----------" + "\n" + "Prompt : " + prompt + "\n" + ) return response except openai_model.error.RateLimitError as e: sleep_duratoin = os.environ.get("OPENAI_RATE_TIMEOUT", 30) - print(f"{str(e)}, sleep for {sleep_duratoin}s, set it by env" - " OPENAI_RATE_TIMEOUT") + print( + f"{str(e)}, sleep for {sleep_duratoin}s, set it by env" + " OPENAI_RATE_TIMEOUT" + ) time.sleep(sleep_duratoin) def openai_choice2text_handler(self, choice): @@ -97,16 +100,11 @@ class OpenAI: else: response = self.run(prompt, 300, 0.5, k) thoughts = [ - self.openai_choice2text_handler(choice) - for choice in response.choices + self.openai_choice2text_handler(choice) for choice in response.choices ] return thoughts - def generate_thoughts(self, - state, - k, - initial_prompt, - rejected_solutions=None): + def generate_thoughts(self, state, k, initial_prompt, rejected_solutions=None): if isinstance(state, str): pass else: @@ -179,8 +177,7 @@ class OpenAI: """ response = self.run(prompt, 10, 1) try: - value_text = self.openai_choice2text_handler( - response.choices[0]) + value_text = self.openai_choice2text_handler(response.choices[0]) # print(f'state: {value_text}') value = float(value_text) print(f"Evaluated Thought Value: {value}") @@ -190,12 +187,10 @@ class OpenAI: return state_values else: - raise ValueError( - "Invalid evaluation strategy. Choose 'value' or 'vote'.") + raise ValueError("Invalid evaluation strategy. Choose 'value' or 'vote'.") class AoTAgent: - def __init__( self, num_thoughts: int = None, @@ -227,8 +222,7 @@ class AoTAgent: return None best_state, _ = max(self.output, key=lambda x: x[1]) - solution = self.model.generate_solution(self.initial_prompt, - best_state) + solution = self.model.generate_solution(self.initial_prompt, best_state) print(f"Solution is {solution}") return solution if solution else best_state except Exception as error: @@ -245,8 +239,11 @@ class AoTAgent: for next_state in thoughts: state_value = self.evaluated_thoughts[next_state] if state_value > self.value_threshold: - child = ((state, next_state) if isinstance(state, str) else - (*state, next_state)) + child = ( + (state, next_state) + if isinstance(state, str) + else (*state, next_state) + ) self.dfs(child, step + 1) # backtracking @@ -256,14 +253,17 @@ class AoTAgent: continue def generate_and_filter_thoughts(self, state): - thoughts = self.model.generate_thoughts(state, self.num_thoughts, - self.initial_prompt) + thoughts = self.model.generate_thoughts( + state, self.num_thoughts, self.initial_prompt + ) self.evaluated_thoughts = self.model.evaluate_states( - thoughts, self.initial_prompt) + thoughts, self.initial_prompt + ) filtered_thoughts = [ - thought for thought in thoughts + thought + for thought in thoughts if self.evaluated_thoughts[thought] >= self.pruning_threshold ] print(f"filtered_thoughts: {filtered_thoughts}") diff --git a/swarms/agents/browser_agent.py b/swarms/agents/browser_agent.py index 3a274468..02c4ef0d 100644 --- a/swarms/agents/browser_agent.py +++ b/swarms/agents/browser_agent.py @@ -38,8 +38,7 @@ def record(agent_name: str, autotab_ext_path: Optional[str] = None): if not os.path.exists("agents"): os.makedirs("agents") - if os.path.exists( - f"agents/{agent_name}.py") and config.environment != "local": + if os.path.exists(f"agents/{agent_name}.py") and config.environment != "local": if not _is_blank_agent(agent_name=agent_name): raise Exception(f"Agent with name {agent_name} already exists") driver = get_driver( # noqa: F841 @@ -55,10 +54,12 @@ def record(agent_name: str, autotab_ext_path: Optional[str] = None): print( "\033[34mYou have the Python debugger open, you can run commands in it like you" - " would in a normal Python shell.\033[0m") + " would in a normal Python shell.\033[0m" + ) print( "\033[34mTo exit, type 'q' and press enter. For a list of commands type '?' and" - " press enter.\033[0m") + " press enter.\033[0m" + ) breakpoint() @@ -78,13 +79,12 @@ def extract_domain_from_url(url: str): class AutotabChromeDriver(uc.Chrome): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def find_element_with_retry(self, - by=By.ID, - value: Optional[str] = None) -> WebElement: + def find_element_with_retry( + self, by=By.ID, value: Optional[str] = None + ) -> WebElement: try: return super().find_element(by, value) except Exception as e: @@ -102,8 +102,11 @@ def open_plugin(driver: AutotabChromeDriver): def open_plugin_and_login(driver: AutotabChromeDriver): if config.autotab_api_key is not None: - backend_url = ("http://localhost:8000" if config.environment == "local" - else "https://api.autotab.com") + backend_url = ( + "http://localhost:8000" + if config.environment == "local" + else "https://api.autotab.com" + ) driver.get(f"{backend_url}/auth/signin-api-key-page") response = requests.post( f"{backend_url}/auth/signin-api-key", @@ -116,7 +119,8 @@ def open_plugin_and_login(driver: AutotabChromeDriver): else: raise Exception( f"Error {response.status_code} from backend while logging you in" - f" with your API key: {response.text}") + f" with your API key: {response.text}" + ) cookie["name"] = cookie["key"] del cookie["key"] driver.add_cookie(cookie) @@ -126,21 +130,26 @@ def open_plugin_and_login(driver: AutotabChromeDriver): else: print("No autotab API key found, heading to autotab.com to sign up") - url = ("http://localhost:3000/dashboard" if config.environment - == "local" else "https://autotab.com/dashboard") + url = ( + "http://localhost:3000/dashboard" + if config.environment == "local" + else "https://autotab.com/dashboard" + ) driver.get(url) time.sleep(0.5) open_plugin(driver) -def get_driver(autotab_ext_path: Optional[str] = None, - record_mode: bool = False) -> AutotabChromeDriver: +def get_driver( + autotab_ext_path: Optional[str] = None, record_mode: bool = False +) -> AutotabChromeDriver: options = webdriver.ChromeOptions() options.add_argument("--no-sandbox") # Necessary for running options.add_argument( "--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" - " (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36") + " (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" + ) options.add_argument("--enable-webgl") options.add_argument("--enable-3d-apis") options.add_argument("--enable-clipboard-read-write") @@ -229,8 +238,7 @@ class Config(BaseModel): return cls( autotab_api_key=autotab_api_key, credentials=_credentials, - google_credentials=GoogleCredentials( - credentials=google_credentials), + google_credentials=GoogleCredentials(credentials=google_credentials), chrome_binary_location=config.get("chrome_binary_location"), environment=config.get("environment", "prod"), ) @@ -248,9 +256,9 @@ def is_signed_in_to_google(driver): return len([c for c in cookies if c["name"] == "SAPISID"]) != 0 -def google_login(driver, - credentials: Optional[SiteCredentials] = None, - navigate: bool = True): +def google_login( + driver, credentials: Optional[SiteCredentials] = None, navigate: bool = True +): print("Logging in to Google") if navigate: driver.get("https://accounts.google.com/") @@ -282,7 +290,8 @@ def google_login(driver, email_input.send_keys(credentials.email) email_input.send_keys(Keys.ENTER) WebDriverWait(driver, 10).until( - EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']"))) + EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']")) + ) password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']") password_input.send_keys(credentials.password) @@ -305,20 +314,21 @@ def google_login(driver, cookies = driver.get_cookies() cookie_names = ["__Host-GAPS", "SMSV", "NID", "ACCOUNT_CHOOSER"] google_cookies = [ - cookie for cookie in cookies - if cookie["domain"] in [".google.com", "accounts.google.com"] and - cookie["name"] in cookie_names + cookie + for cookie in cookies + if cookie["domain"] in [".google.com", "accounts.google.com"] + and cookie["name"] in cookie_names ] with open("google_cookies.json", "w") as f: json.dump(google_cookies, f) # Log back in login_button = driver.find_element( - By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']") + By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']" + ) login_button.click() time.sleep(1) - password_input = driver.find_element(By.CSS_SELECTOR, - "[type='password']") + password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']") password_input.send_keys(credentials.password) password_input.send_keys(Keys.ENTER) @@ -333,7 +343,8 @@ def login(driver, url: str): login_url = credentials.login_url if credentials.login_with_google_account: google_credentials = config.google_credentials.credentials[ - credentials.login_with_google_account] + credentials.login_with_google_account + ] _login_with_google(driver, login_url, google_credentials) else: _login(driver, login_url, credentials=credentials) @@ -360,15 +371,16 @@ def _login_with_google(driver, url: str, google_credentials: SiteCredentials): driver.get(url) WebDriverWait(driver, 10).until( - EC.presence_of_element_located((By.TAG_NAME, "body"))) + EC.presence_of_element_located((By.TAG_NAME, "body")) + ) main_window = driver.current_window_handle xpath = ( "//*[contains(text(), 'Continue with Google') or contains(text(), 'Sign in with" - " Google') or contains(@title, 'Sign in with Google')]") + " Google') or contains(@title, 'Sign in with Google')]" + ) - WebDriverWait(driver, - 10).until(EC.presence_of_element_located((By.XPATH, xpath))) + WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.XPATH, xpath))) driver.find_element( By.XPATH, xpath, @@ -376,8 +388,8 @@ def _login_with_google(driver, url: str, google_credentials: SiteCredentials): driver.switch_to.window(driver.window_handles[-1]) driver.find_element( - By.XPATH, - f"//*[contains(text(), '{google_credentials.email}')]").click() + By.XPATH, f"//*[contains(text(), '{google_credentials.email}')]" + ).click() driver.switch_to.window(main_window) @@ -430,11 +442,8 @@ def should_update(): # Parse the XML file root = ET.fromstring(xml_content) - namespaces = { - "ns": "http://www.google.com/update2/response" - } # add namespaces - xml_version = root.find(".//ns:app/ns:updatecheck", - namespaces).get("version") + namespaces = {"ns": "http://www.google.com/update2/response"} # add namespaces + xml_version = root.find(".//ns:app/ns:updatecheck", namespaces).get("version") # Load the local JSON file with open("src/extension/autotab/manifest.json", "r") as f: diff --git a/swarms/agents/hf_agents.py b/swarms/agents/hf_agents.py index e13d3462..4e186e3a 100644 --- a/swarms/agents/hf_agents.py +++ b/swarms/agents/hf_agents.py @@ -56,24 +56,23 @@ HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [ def get_remote_tools(organization="huggingface-tools"): if is_offline_mode(): - logger.info( - "You are in offline mode, so remote tools are not available.") + logger.info("You are in offline mode, so remote tools are not available.") return {} spaces = list_spaces(author=organization) tools = {} for space_info in spaces: repo_id = space_info.id - resolved_config_file = hf_hub_download(repo_id, - TOOL_CONFIG_FILE, - repo_type="space") + resolved_config_file = hf_hub_download( + repo_id, TOOL_CONFIG_FILE, repo_type="space" + ) with open(resolved_config_file, encoding="utf-8") as reader: config = json.load(reader) task = repo_id.split("/")[-1] - tools[config["name"]] = PreTool(task=task, - description=config["description"], - repo_id=repo_id) + tools[config["name"]] = PreTool( + task=task, description=config["description"], repo_id=repo_id + ) return tools @@ -93,7 +92,8 @@ def _setup_default_tools(): tool_class = getattr(tools_module, tool_class_name) description = tool_class.description HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool( - task=task_name, description=description, repo_id=None) + task=task_name, description=description, repo_id=None + ) if not is_offline_mode(): for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB: @@ -197,19 +197,18 @@ class Agent: one of the default tools, that default tool will be overridden. """ - def __init__(self, - chat_prompt_template=None, - run_prompt_template=None, - additional_tools=None): + def __init__( + self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None + ): _setup_default_tools() agent_name = self.__class__.__name__ - self.chat_prompt_template = download_prompt(chat_prompt_template, - agent_name, - mode="chat") - self.run_prompt_template = download_prompt(run_prompt_template, - agent_name, - mode="run") + self.chat_prompt_template = download_prompt( + chat_prompt_template, agent_name, mode="chat" + ) + self.run_prompt_template = download_prompt( + run_prompt_template, agent_name, mode="run" + ) self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy() self.log = print if additional_tools is not None: @@ -225,16 +224,17 @@ class Agent: } self._toolbox.update(additional_tools) if len(replacements) > 1: - names = "\n".join( - [f"- {n}: {t}" for n, t in replacements.items()]) + names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()]) logger.warning( "The following tools have been replaced by the ones provided in" - f" `additional_tools`:\n{names}.") + f" `additional_tools`:\n{names}." + ) elif len(replacements) == 1: name = list(replacements.keys())[0] logger.warning( f"{name} has been replaced by {replacements[name]} as provided in" - " `additional_tools`.") + " `additional_tools`." + ) self.prepare_for_new_chat() @@ -244,20 +244,17 @@ class Agent: return self._toolbox def format_prompt(self, task, chat_mode=False): - description = "\n".join([ - f"- {name}: {tool.description}" - for name, tool in self.toolbox.items() - ]) + description = "\n".join( + [f"- {name}: {tool.description}" for name, tool in self.toolbox.items()] + ) if chat_mode: if self.chat_history is None: - prompt = self.chat_prompt_template.replace( - "<>", description) + prompt = self.chat_prompt_template.replace("<>", description) else: prompt = self.chat_history prompt += CHAT_MESSAGE_PROMPT.replace("<>", task) else: - prompt = self.run_prompt_template.replace("<>", - description) + prompt = self.run_prompt_template.replace("<>", description) prompt = prompt.replace("<>", task) return prompt @@ -306,19 +303,14 @@ class Agent: if not return_code: self.log("\n\n==Result==") self.cached_tools = resolve_tools( - code, - self.toolbox, - remote=remote, - cached_tools=self.cached_tools) + code, self.toolbox, remote=remote, cached_tools=self.cached_tools + ) self.chat_state.update(kwargs) - return evaluate(code, - self.cached_tools, - self.chat_state, - chat_mode=True) + return evaluate( + code, self.cached_tools, self.chat_state, chat_mode=True + ) else: - tool_code = get_tool_creation_code(code, - self.toolbox, - remote=remote) + tool_code = get_tool_creation_code(code, self.toolbox, remote=remote) return f"{tool_code}\n{code}" def prepare_for_new_chat(self): @@ -360,15 +352,12 @@ class Agent: self.log(f"\n\n==Code generated by the agent==\n{code}") if not return_code: self.log("\n\n==Result==") - self.cached_tools = resolve_tools(code, - self.toolbox, - remote=remote, - cached_tools=self.cached_tools) + self.cached_tools = resolve_tools( + code, self.toolbox, remote=remote, cached_tools=self.cached_tools + ) return evaluate(code, self.cached_tools, state=kwargs.copy()) else: - tool_code = get_tool_creation_code(code, - self.toolbox, - remote=remote) + tool_code = get_tool_creation_code(code, self.toolbox, remote=remote) return f"{tool_code}\n{code}" def generate_one(self, prompt, stop): @@ -428,7 +417,8 @@ class HFAgent(Agent): ): if not is_openai_available(): raise ImportError( - "Using `OpenAiAgent` requires `openai`: `pip install openai`.") + "Using `OpenAiAgent` requires `openai`: `pip install openai`." + ) if api_key is None: api_key = os.environ.get("OPENAI_API_KEY", None) @@ -436,7 +426,8 @@ class HFAgent(Agent): raise ValueError( "You need an openai key to use `OpenAIAgent`. You can get one here: Get" " one here https://openai.com/api/`. If you have one, set it in your" - " env with `os.environ['OPENAI_API_KEY'] = xxx.") + " env with `os.environ['OPENAI_API_KEY'] = xxx." + ) else: openai.api_key = api_key self.model = model @@ -461,10 +452,7 @@ class HFAgent(Agent): def _chat_generate(self, prompt, stop): result = openai.ChatCompletion.create( model=self.model, - messages=[{ - "role": "user", - "content": prompt - }], + messages=[{"role": "user", "content": prompt}], temperature=0, stop=stop, ) @@ -542,7 +530,8 @@ class AzureOpenAI(Agent): ): if not is_openai_available(): raise ImportError( - "Using `OpenAiAgent` requires `openai`: `pip install openai`.") + "Using `OpenAiAgent` requires `openai`: `pip install openai`." + ) self.deployment_id = deployment_id openai.api_type = "azure" @@ -552,7 +541,8 @@ class AzureOpenAI(Agent): raise ValueError( "You need an Azure openAI key to use `AzureOpenAIAgent`. If you have" " one, set it in your env with `os.environ['AZURE_OPENAI_API_KEY'] =" - " xxx.") + " xxx." + ) else: openai.api_key = api_key if resource_name is None: @@ -561,7 +551,8 @@ class AzureOpenAI(Agent): raise ValueError( "You need a resource_name to use `AzureOpenAIAgent`. If you have one," " set it in your env with `os.environ['AZURE_OPENAI_RESOURCE_NAME'] =" - " xxx.") + " xxx." + ) else: openai.api_base = f"https://{resource_name}.openai.azure.com" openai.api_version = api_version @@ -591,10 +582,7 @@ class AzureOpenAI(Agent): def _chat_generate(self, prompt, stop): result = openai.ChatCompletion.create( engine=self.deployment_id, - messages=[{ - "role": "user", - "content": prompt - }], + messages=[{"role": "user", "content": prompt}], temperature=0, stop=stop, ) diff --git a/swarms/agents/meta_prompter.py b/swarms/agents/meta_prompter.py index f744e38e..aeee9878 100644 --- a/swarms/agents/meta_prompter.py +++ b/swarms/agents/meta_prompter.py @@ -88,8 +88,9 @@ class MetaPrompterAgent: Assistant: """ - prompt = PromptTemplate(input_variables=["history", "human_input"], - template=template) + prompt = PromptTemplate( + input_variables=["history", "human_input"], template=template + ) self.chain = LLMChain( llm=self.llm(), @@ -101,15 +102,13 @@ class MetaPrompterAgent: def get_chat_history(self, chain_memory): """Get Chat History from the memory""" memory_key = chain_memory.memory_key - chat_history = chain_memory.load_memory_variables( - memory_key)[memory_key] + chat_history = chain_memory.load_memory_variables(memory_key)[memory_key] return chat_history def get_new_instructions(self, meta_output): """Get New Instructions from the meta_output""" delimiter = "Instructions: " - new_instructions = meta_output[meta_output.find(delimiter) + - len(delimiter):] + new_instructions = meta_output[meta_output.find(delimiter) + len(delimiter) :] return new_instructions def run(self, task: str): @@ -150,7 +149,8 @@ class MetaPrompterAgent: meta_chain = self.initialize_meta_chain() meta_output = meta_chain.predict( - chat_history=self.get_chat_history(chain.memory)) + chat_history=self.get_chat_history(chain.memory) + ) print(f"Feedback: {meta_output}") self.instructions = self.get_new_instructions(meta_output) diff --git a/swarms/agents/multi_modal_visual_agent.py b/swarms/agents/multi_modal_visual_agent.py index 72b6c50e..34780594 100644 --- a/swarms/agents/multi_modal_visual_agent.py +++ b/swarms/agents/multi_modal_visual_agent.py @@ -150,7 +150,6 @@ def seed_everything(seed): def prompts(name, description): - def decorator(func): func.name = name func.description = description @@ -172,12 +171,9 @@ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100): kernel = np.multiply(kernel_h, np.transpose(kernel_w)) kernel[steps:-steps, steps:-steps] = 1 - kernel[:steps, :steps] = kernel[:steps, :steps] / kernel[steps - 1, - steps - 1] - kernel[:steps, - -steps:] = kernel[:steps, -steps:] / kernel[steps - 1, -(steps)] - kernel[-steps:, :steps] = kernel[-steps:, :steps] / kernel[-steps, - steps - 1] + kernel[:steps, :steps] = kernel[:steps, :steps] / kernel[steps - 1, steps - 1] + kernel[:steps, -steps:] = kernel[:steps, -steps:] / kernel[steps - 1, -(steps)] + kernel[-steps:, :steps] = kernel[-steps:, :steps] / kernel[-steps, steps - 1] kernel[-steps:, -steps:] = kernel[-steps:, -steps:] / kernel[-steps, -steps] kernel = np.expand_dims(kernel, 2) kernel = np.repeat(kernel, 3, 2) @@ -211,12 +207,12 @@ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100): kernel[steps:-steps, :steps] = left kernel[steps:-steps, -steps:] = right - pt_gt_img = easy_img[pos_h:pos_h + old_size[1], pos_w:pos_w + old_size[0]] - gaussian_gt_img = (kernel * gt_img_array + (1 - kernel) * pt_gt_img - ) # gt img with blur img + pt_gt_img = easy_img[pos_h : pos_h + old_size[1], pos_w : pos_w + old_size[0]] + gaussian_gt_img = ( + kernel * gt_img_array + (1 - kernel) * pt_gt_img + ) # gt img with blur img gaussian_gt_img = gaussian_gt_img.astype(np.int64) - easy_img[pos_h:pos_h + old_size[1], - pos_w:pos_w + old_size[0]] = gaussian_gt_img + easy_img[pos_h : pos_h + old_size[1], pos_w : pos_w + old_size[0]] = gaussian_gt_img gaussian_img = Image.fromarray(easy_img) return gaussian_img @@ -256,7 +252,6 @@ def get_new_image_name(org_img_name, func_name="update"): class InstructPix2Pix: - def __init__(self, device): print(f"Initializing InstructPix2Pix to {device}") self.device = device @@ -265,102 +260,110 @@ class InstructPix2Pix: self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( "timbrooks/instruct-pix2pix", safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ).to(device) self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) @prompts( name="Instruct Image Using Text", - description= - ("useful when you want to the style of the image to be like the text. " - "like: make it look like a painting. or make it like a robot. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the text. "), + description=( + "useful when you want to the style of the image to be like the text. " + "like: make it look like a painting. or make it like a robot. " + "The input to this tool should be a comma separated string of two, " + "representing the image_path and the text. " + ), ) def inference(self, inputs): """Change style of image.""" print("===>Starting InstructPix2Pix Inference") image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) original_image = Image.open(image_path) - image = self.pipe(text, - image=original_image, - num_inference_steps=40, - image_guidance_scale=1.2).images[0] + image = self.pipe( + text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2 + ).images[0] updated_image_path = get_new_image_name(image_path, func_name="pix2pix") image.save(updated_image_path) print( f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:" - f" {text}, Output Image: {updated_image_path}") + f" {text}, Output Image: {updated_image_path}" + ) return updated_image_path class Text2Image: - def __init__(self, device): print(f"Initializing Text2Image to {device}") self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.pipe = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype) + "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype + ) self.pipe.to(device) self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality") + "fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image From User Input Text", - description= - ("useful when you want to generate an image from a user input text and save" - " it to a file. like: generate an image of an object or something, or" - " generate an image that includes some objects. The input to this tool" - " should be a string, representing the text used to generate image. "), + description=( + "useful when you want to generate an image from a user input text and save" + " it to a file. like: generate an image of an object or something, or" + " generate an image that includes some objects. The input to this tool" + " should be a string, representing the text used to generate image. " + ), ) def inference(self, text): image_filename = os.path.join("image", f"{str(uuid.uuid4())[:8]}.png") prompt = text + ", " + self.a_prompt image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0] image.save(image_filename) - print(f"\nProcessed Text2Image, Input Text: {text}, Output Image:" - f" {image_filename}") + print( + f"\nProcessed Text2Image, Input Text: {text}, Output Image:" + f" {image_filename}" + ) return image_filename class ImageCaptioning: - def __init__(self, device): print(f"Initializing ImageCaptioning to {device}") self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.processor = BlipProcessor.from_pretrained( - "Salesforce/blip-image-captioning-base") + "Salesforce/blip-image-captioning-base" + ) self.model = BlipForConditionalGeneration.from_pretrained( - "Salesforce/blip-image-captioning-base", - torch_dtype=self.torch_dtype).to(self.device) + "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype + ).to(self.device) @prompts( name="Get Photo Description", - description= - ("useful when you want to know what is inside the photo. receives image_path" - " as input. The input to this tool should be a string, representing the" - " image_path. "), + description=( + "useful when you want to know what is inside the photo. receives image_path" + " as input. The input to this tool should be a string, representing the" + " image_path. " + ), ) def inference(self, image_path): - inputs = self.processor(Image.open(image_path), - return_tensors="pt").to(self.device, - self.torch_dtype) + inputs = self.processor(Image.open(image_path), return_tensors="pt").to( + self.device, self.torch_dtype + ) out = self.model.generate(**inputs) captions = self.processor.decode(out[0], skip_special_tokens=True) print( f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text:" - f" {captions}") + f" {captions}" + ) return captions class Image2Canny: - def __init__(self, device): print("Initializing Image2Canny") self.low_threshold = 100 @@ -368,11 +371,12 @@ class Image2Canny: @prompts( name="Edge Detection On Image", - description= - ("useful when you want to detect the edge of the image. like: detect the" - " edges of this image, or canny detection on image, or perform edge" - " detection on this image, or detect the canny image of this image. The" - " input to this tool should be a string, representing the image_path"), + description=( + "useful when you want to detect the edge of the image. like: detect the" + " edges of this image, or canny detection on image, or perform edge" + " detection on this image, or detect the canny image of this image. The" + " input to this tool should be a string, representing the image_path" + ), ) def inference(self, inputs): image = Image.open(inputs) @@ -383,13 +387,14 @@ class Image2Canny: canny = Image.fromarray(canny) updated_image_path = get_new_image_name(inputs, func_name="edge") canny.save(updated_image_path) - print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text:" - f" {updated_image_path}") + print( + f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text:" + f" {updated_image_path}" + ) return updated_image_path class CannyText2Image: - def __init__(self, device): print(f"Initializing CannyText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -401,31 +406,36 @@ class CannyText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality") + "fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image Condition On Canny Image", - description= - ("useful when you want to generate a new real image from both the user" - " description and a canny image. like: generate a real image of a object or" - " something from this canny image, or generate a new real image of a object" - " or something from this edge image. The input to this tool should be a" - " comma separated string of two, representing the image_path and the user" - " description. "), + description=( + "useful when you want to generate a new real image from both the user" + " description and a canny image. like: generate a real image of a object or" + " something from this canny image, or generate a new real image of a object" + " or something from this edge image. The input to this tool should be a" + " comma separated string of two, representing the image_path and the user" + " description. " + ), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -438,77 +448,83 @@ class CannyText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="canny2image") + updated_image_path = get_new_image_name(image_path, func_name="canny2image") image.save(updated_image_path) print( f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text:" - f" {instruct_text}, Output Text: {updated_image_path}") + f" {instruct_text}, Output Text: {updated_image_path}" + ) return updated_image_path class Image2Line: - def __init__(self, device): print("Initializing Image2Line") self.detector = MLSDdetector.from_pretrained("lllyasviel/ControlNet") @prompts( name="Line Detection On Image", - description= - ("useful when you want to detect the straight line of the image. like:" - " detect the straight lines of this image, or straight line detection on" - " image, or perform straight line detection on this image, or detect the" - " straight line image of this image. The input to this tool should be a" - " string, representing the image_path"), + description=( + "useful when you want to detect the straight line of the image. like:" + " detect the straight lines of this image, or straight line detection on" + " image, or perform straight line detection on this image, or detect the" + " straight line image of this image. The input to this tool should be a" + " string, representing the image_path" + ), ) def inference(self, inputs): image = Image.open(inputs) mlsd = self.detector(image) updated_image_path = get_new_image_name(inputs, func_name="line-of") mlsd.save(updated_image_path) - print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line:" - f" {updated_image_path}") + print( + f"\nProcessed Image2Line, Input Image: {inputs}, Output Line:" + f" {updated_image_path}" + ) return updated_image_path class LineText2Image: - def __init__(self, device): print(f"Initializing LineText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.controlnet = ControlNetModel.from_pretrained( - "fusing/stable-diffusion-v1-5-controlnet-mlsd", - torch_dtype=self.torch_dtype) + "fusing/stable-diffusion-v1-5-controlnet-mlsd", torch_dtype=self.torch_dtype + ) self.pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality") + "fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image Condition On Line Image", - description= - ("useful when you want to generate a new real image from both the user" - " description and a straight line image. like: generate a real image of a" - " object or something from this straight line image, or generate a new real" - " image of a object or something from this straight lines. The input to" - " this tool should be a comma separated string of two, representing the" - " image_path and the user description. "), + description=( + "useful when you want to generate a new real image from both the user" + " description and a straight line image. like: generate a real image of a" + " object or something from this straight line image, or generate a new real" + " image of a object or something from this straight lines. The input to" + " this tool should be a comma separated string of two, representing the" + " image_path and the user description. " + ), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -521,78 +537,83 @@ class LineText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="line2image") + updated_image_path = get_new_image_name(image_path, func_name="line2image") image.save(updated_image_path) print( f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text:" - f" {instruct_text}, Output Text: {updated_image_path}") + f" {instruct_text}, Output Text: {updated_image_path}" + ) return updated_image_path class Image2Hed: - def __init__(self, device): print("Initializing Image2Hed") self.detector = HEDdetector.from_pretrained("lllyasviel/ControlNet") @prompts( name="Hed Detection On Image", - description= - ("useful when you want to detect the soft hed boundary of the image. like:" - " detect the soft hed boundary of this image, or hed boundary detection on" - " image, or perform hed boundary detection on this image, or detect soft" - " hed boundary image of this image. The input to this tool should be a" - " string, representing the image_path"), + description=( + "useful when you want to detect the soft hed boundary of the image. like:" + " detect the soft hed boundary of this image, or hed boundary detection on" + " image, or perform hed boundary detection on this image, or detect soft" + " hed boundary image of this image. The input to this tool should be a" + " string, representing the image_path" + ), ) def inference(self, inputs): image = Image.open(inputs) hed = self.detector(image) - updated_image_path = get_new_image_name(inputs, - func_name="hed-boundary") + updated_image_path = get_new_image_name(inputs, func_name="hed-boundary") hed.save(updated_image_path) - print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed:" - f" {updated_image_path}") + print( + f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed:" + f" {updated_image_path}" + ) return updated_image_path class HedText2Image: - def __init__(self, device): print(f"Initializing HedText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.controlnet = ControlNetModel.from_pretrained( - "fusing/stable-diffusion-v1-5-controlnet-hed", - torch_dtype=self.torch_dtype) + "fusing/stable-diffusion-v1-5-controlnet-hed", torch_dtype=self.torch_dtype + ) self.pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality") + "fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image Condition On Soft Hed Boundary Image", - description= - ("useful when you want to generate a new real image from both the user" - " description and a soft hed boundary image. like: generate a real image of" - " a object or something from this soft hed boundary image, or generate a" - " new real image of a object or something from this hed boundary. The input" - " to this tool should be a comma separated string of two, representing the" - " image_path and the user description"), + description=( + "useful when you want to generate a new real image from both the user" + " description and a soft hed boundary image. like: generate a real image of" + " a object or something from this soft hed boundary image, or generate a" + " new real image of a object or something from this hed boundary. The input" + " to this tool should be a comma separated string of two, representing the" + " image_path and the user description" + ), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -605,27 +626,28 @@ class HedText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="hed2image") + updated_image_path = get_new_image_name(image_path, func_name="hed2image") image.save(updated_image_path) - print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}") + print( + f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text:" + f" {instruct_text}, Output Image: {updated_image_path}" + ) return updated_image_path class Image2Scribble: - def __init__(self, device): print("Initializing Image2Scribble") self.detector = HEDdetector.from_pretrained("lllyasviel/ControlNet") @prompts( name="Sketch Detection On Image", - description= - ("useful when you want to generate a scribble of the image. like: generate a" - " scribble of this image, or generate a sketch from this image, detect the" - " sketch from this image. The input to this tool should be a string," - " representing the image_path"), + description=( + "useful when you want to generate a scribble of the image. like: generate a" + " scribble of this image, or generate a sketch from this image, detect the" + " sketch from this image. The input to this tool should be a string," + " representing the image_path" + ), ) def inference(self, inputs): image = Image.open(inputs) @@ -634,12 +656,12 @@ class Image2Scribble: scribble.save(updated_image_path) print( f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble:" - f" {updated_image_path}") + f" {updated_image_path}" + ) return updated_image_path class ScribbleText2Image: - def __init__(self, device): print(f"Initializing ScribbleText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -651,29 +673,34 @@ class ScribbleText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality") + "fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image Condition On Sketch Image", - description= - ("useful when you want to generate a new real image from both the user" - " description and a scribble image or a sketch image. The input to this" - " tool should be a comma separated string of two, representing the" - " image_path and the user description"), + description=( + "useful when you want to generate a new real image from both the user" + " description and a scribble image or a sketch image. The input to this" + " tool should be a comma separated string of two, representing the" + " image_path and the user description" + ), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -686,41 +713,41 @@ class ScribbleText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="scribble2image") + updated_image_path = get_new_image_name(image_path, func_name="scribble2image") image.save(updated_image_path) print( f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}") + f" {instruct_text}, Output Image: {updated_image_path}" + ) return updated_image_path class Image2Pose: - def __init__(self, device): print("Initializing Image2Pose") - self.detector = OpenposeDetector.from_pretrained( - "lllyasviel/ControlNet") + self.detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") @prompts( name="Pose Detection On Image", - description= - ("useful when you want to detect the human pose of the image. like: generate" - " human poses of this image, or generate a pose image from this image. The" - " input to this tool should be a string, representing the image_path"), + description=( + "useful when you want to detect the human pose of the image. like: generate" + " human poses of this image, or generate a pose image from this image. The" + " input to this tool should be a string, representing the image_path" + ), ) def inference(self, inputs): image = Image.open(inputs) pose = self.detector(image) updated_image_path = get_new_image_name(inputs, func_name="human-pose") pose.save(updated_image_path) - print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose:" - f" {updated_image_path}") + print( + f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose:" + f" {updated_image_path}" + ) return updated_image_path class PoseText2Image: - def __init__(self, device): print(f"Initializing PoseText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -732,11 +759,13 @@ class PoseText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) self.pipe.to(device) self.num_inference_steps = 20 self.seed = -1 @@ -744,20 +773,23 @@ class PoseText2Image: self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit," - " fewer digits, cropped, worst quality, low quality") + " fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image Condition On Pose Image", - description= - ("useful when you want to generate a new real image from both the user" - " description and a human pose image. like: generate a real image of a" - " human from this human pose image, or generate a new real image of a human" - " from this pose. The input to this tool should be a comma separated string" - " of two, representing the image_path and the user description"), + description=( + "useful when you want to generate a new real image from both the user" + " description and a human pose image. like: generate a real image of a" + " human from this human pose image, or generate a new real image of a human" + " from this pose. The input to this tool should be a comma separated string" + " of two, representing the image_path and the user description" + ), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -770,52 +802,56 @@ class PoseText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="pose2image") + updated_image_path = get_new_image_name(image_path, func_name="pose2image") image.save(updated_image_path) print( f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}") + f" {instruct_text}, Output Image: {updated_image_path}" + ) return updated_image_path class SegText2Image: - def __init__(self, device): print(f"Initializing SegText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.controlnet = ControlNetModel.from_pretrained( - "fusing/stable-diffusion-v1-5-controlnet-seg", - torch_dtype=self.torch_dtype) + "fusing/stable-diffusion-v1-5-controlnet-seg", torch_dtype=self.torch_dtype + ) self.pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit," - " fewer digits, cropped, worst quality, low quality") + " fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image Condition On Segmentations", - description= - ("useful when you want to generate a new real image from both the user" - " description and segmentations. like: generate a real image of a object or" - " something from this segmentation image, or generate a new real image of a" - " object or something from these segmentations. The input to this tool" - " should be a comma separated string of two, representing the image_path" - " and the user description"), + description=( + "useful when you want to generate a new real image from both the user" + " description and segmentations. like: generate a real image of a object or" + " something from this segmentation image, or generate a new real image of a" + " object or something from these segmentations. The input to this tool" + " should be a comma separated string of two, representing the image_path" + " and the user description" + ), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -828,27 +864,28 @@ class SegText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="segment2image") + updated_image_path = get_new_image_name(image_path, func_name="segment2image") image.save(updated_image_path) - print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}") + print( + f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text:" + f" {instruct_text}, Output Image: {updated_image_path}" + ) return updated_image_path class Image2Depth: - def __init__(self, device): print("Initializing Image2Depth") self.depth_estimator = pipeline("depth-estimation") @prompts( name="Predict Depth On Image", - description= - ("useful when you want to detect depth of the image. like: generate the" - " depth from this image, or detect the depth map on this image, or predict" - " the depth for this image. The input to this tool should be a string," - " representing the image_path"), + description=( + "useful when you want to detect depth of the image. like: generate the" + " depth from this image, or detect the depth map on this image, or predict" + " the depth for this image. The input to this tool should be a string," + " representing the image_path" + ), ) def inference(self, inputs): image = Image.open(inputs) @@ -859,13 +896,14 @@ class Image2Depth: depth = Image.fromarray(depth) updated_image_path = get_new_image_name(inputs, func_name="depth") depth.save(updated_image_path) - print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth:" - f" {updated_image_path}") + print( + f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth:" + f" {updated_image_path}" + ) return updated_image_path class DepthText2Image: - def __init__(self, device): print(f"Initializing DepthText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -877,31 +915,36 @@ class DepthText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit," - " fewer digits, cropped, worst quality, low quality") + " fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image Condition On Depth", - description= - ("useful when you want to generate a new real image from both the user" - " description and depth image. like: generate a real image of a object or" - " something from this depth image, or generate a new real image of a object" - " or something from the depth map. The input to this tool should be a comma" - " separated string of two, representing the image_path and the user" - " description"), + description=( + "useful when you want to generate a new real image from both the user" + " description and depth image. like: generate a real image of a object or" + " something from this depth image, or generate a new real image of a object" + " or something from the depth map. The input to this tool should be a comma" + " separated string of two, representing the image_path and the user" + " description" + ), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -914,29 +957,30 @@ class DepthText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="depth2image") + updated_image_path = get_new_image_name(image_path, func_name="depth2image") image.save(updated_image_path) print( f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}") + f" {instruct_text}, Output Image: {updated_image_path}" + ) return updated_image_path class Image2Normal: - def __init__(self, device): print("Initializing Image2Normal") - self.depth_estimator = pipeline("depth-estimation", - model="Intel/dpt-hybrid-midas") + self.depth_estimator = pipeline( + "depth-estimation", model="Intel/dpt-hybrid-midas" + ) self.bg_threhold = 0.4 @prompts( name="Predict Normal Map On Image", - description= - ("useful when you want to detect norm map of the image. like: generate" - " normal map from this image, or predict normal map of this image. The" - " input to this tool should be a string, representing the image_path"), + description=( + "useful when you want to detect norm map of the image. like: generate" + " normal map from this image, or predict normal map of this image. The" + " input to this tool should be a string, representing the image_path" + ), ) def inference(self, inputs): image = Image.open(inputs) @@ -952,19 +996,20 @@ class Image2Normal: y[image_depth < self.bg_threhold] = 0 z = np.ones_like(x) * np.pi * 2.0 image = np.stack([x, y, z], axis=2) - image /= np.sum(image**2.0, axis=2, keepdims=True)**0.5 + image /= np.sum(image**2.0, axis=2, keepdims=True) ** 0.5 image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8) image = Image.fromarray(image) image = image.resize(original_size) updated_image_path = get_new_image_name(inputs, func_name="normal-map") image.save(updated_image_path) - print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth:" - f" {updated_image_path}") + print( + f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth:" + f" {updated_image_path}" + ) return updated_image_path class NormalText2Image: - def __init__(self, device): print(f"Initializing NormalText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -976,31 +1021,36 @@ class NormalText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit," - " fewer digits, cropped, worst quality, low quality") + " fewer digits, cropped, worst quality, low quality" + ) @prompts( name="Generate Image Condition On Normal Map", - description= - ("useful when you want to generate a new real image from both the user" - " description and normal map. like: generate a real image of a object or" - " something from this normal map, or generate a new real image of a object" - " or something from the normal map. The input to this tool should be a" - " comma separated string of two, representing the image_path and the user" - " description"), + description=( + "useful when you want to generate a new real image from both the user" + " description and normal map. like: generate a real image of a object or" + " something from this normal map, or generate a new real image of a object" + " or something from the normal map. The input to this tool should be a" + " comma separated string of two, representing the image_path and the user" + " description" + ), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -1013,53 +1063,50 @@ class NormalText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="normal2image") + updated_image_path = get_new_image_name(image_path, func_name="normal2image") image.save(updated_image_path) print( f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}") + f" {instruct_text}, Output Image: {updated_image_path}" + ) return updated_image_path class VisualQuestionAnswering: - def __init__(self, device): print(f"Initializing VisualQuestionAnswering to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.device = device - self.processor = BlipProcessor.from_pretrained( - "Salesforce/blip-vqa-base") + self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") self.model = BlipForQuestionAnswering.from_pretrained( - "Salesforce/blip-vqa-base", - torch_dtype=self.torch_dtype).to(self.device) + "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype + ).to(self.device) @prompts( name="Answer Question About The Image", - description= - ("useful when you need an answer for a question based on an image. like:" - " what is the background color of the last image, how many cats in this" - " figure, what is in this figure. The input to this tool should be a comma" - " separated string of two, representing the image_path and the question" + description=( + "useful when you need an answer for a question based on an image. like:" + " what is the background color of the last image, how many cats in this" + " figure, what is in this figure. The input to this tool should be a comma" + " separated string of two, representing the image_path and the question" ), ) def inference(self, inputs): - image_path, question = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + image_path, question = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) raw_image = Image.open(image_path).convert("RGB") - inputs = self.processor(raw_image, question, - return_tensors="pt").to(self.device, - self.torch_dtype) + inputs = self.processor(raw_image, question, return_tensors="pt").to( + self.device, self.torch_dtype + ) out = self.model.generate(**inputs) answer = self.processor.decode(out[0], skip_special_tokens=True) print( f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" - f" Question: {question}, Output Answer: {answer}") + f" Question: {question}, Output Answer: {answer}" + ) return answer class Segmenting: - def __init__(self, device): print(f"Inintializing Segmentation to {device}") self.device = device @@ -1104,8 +1151,7 @@ class Segmenting: h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255 - image = cv2.addWeighted(image, 0.7, mask_image.astype("uint8"), - transparency, 0) + image = cv2.addWeighted(image, 0.7, mask_image.astype("uint8"), transparency, 0) return image @@ -1113,12 +1159,10 @@ class Segmenting: x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] ax.add_patch( - plt.Rectangle((x0, y0), - w, - h, - edgecolor="green", - facecolor=(0, 0, 0, 0), - lw=2)) + plt.Rectangle( + (x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2 + ) + ) ax.text(x0, y0, label) def get_mask_with_boxes(self, image_pil, image, boxes_filt): @@ -1131,7 +1175,8 @@ class Segmenting: boxes_filt = boxes_filt.cpu() transformed_boxes = self.sam_predictor.transform.apply_boxes_torch( - boxes_filt, image.shape[:2]).to(self.device) + boxes_filt, image.shape[:2] + ).to(self.device) masks, _, _ = self.sam_predictor.predict_torch( point_coords=None, @@ -1141,8 +1186,7 @@ class Segmenting: ) return masks - def segment_image_with_boxes(self, image_pil, image_path, boxes_filt, - pred_phrases): + def segment_image_with_boxes(self, image_pil, image_path, boxes_filt, pred_phrases): image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) self.sam_predictor.set_image(image) @@ -1152,13 +1196,11 @@ class Segmenting: # draw output image for mask in masks: - image = self.show_mask(mask[0].cpu().numpy(), - image, - random_color=True, - transparency=0.3) + image = self.show_mask( + mask[0].cpu().numpy(), image, random_color=True, transparency=0.3 + ) - updated_image_path = get_new_image_name(image_path, - func_name="segmentation") + updated_image_path = get_new_image_name(image_path, func_name="segmentation") new_image = Image.fromarray(image) new_image.save(updated_image_path) @@ -1170,8 +1212,9 @@ class Segmenting: with torch.cuda.amp.autocast(): self.sam_predictor.set_image(img) - def show_points(self, coords: np.ndarray, labels: np.ndarray, - image: np.ndarray) -> np.ndarray: + def show_points( + self, coords: np.ndarray, labels: np.ndarray, image: np.ndarray + ) -> np.ndarray: """Visualize points on top of an image. Args: @@ -1185,17 +1228,13 @@ class Segmenting: pos_points = coords[labels == 1] neg_points = coords[labels == 0] for p in pos_points: - image = cv2.circle(image, - p.astype(int), - radius=3, - color=(0, 255, 0), - thickness=-1) + image = cv2.circle( + image, p.astype(int), radius=3, color=(0, 255, 0), thickness=-1 + ) for p in neg_points: - image = cv2.circle(image, - p.astype(int), - radius=3, - color=(255, 0, 0), - thickness=-1) + image = cv2.circle( + image, p.astype(int), radius=3, color=(255, 0, 0), thickness=-1 + ) return image def segment_image_with_click(self, img, is_positive: bool): @@ -1213,17 +1252,13 @@ class Segmenting: multimask_output=False, ) - img = self.show_mask(masks[0], - img, - random_color=False, - transparency=0.3) + img = self.show_mask(masks[0], img, random_color=False, transparency=0.3) img = self.show_points(input_point, input_label, img) return img - def segment_image_with_coordinate(self, img, is_positive: bool, - coordinate: tuple): + def segment_image_with_coordinate(self, img, is_positive: bool, coordinate: tuple): """ Args: img (numpy.ndarray): the given image, shape: H x W x 3. @@ -1254,10 +1289,7 @@ class Segmenting: multimask_output=False, ) - img = self.show_mask(masks[0], - img, - random_color=False, - transparency=0.3) + img = self.show_mask(masks[0], img, random_color=False, transparency=0.3) img = self.show_points(input_point, input_label, img) @@ -1269,12 +1301,13 @@ class Segmenting: @prompts( name="Segment the Image", - description= - ("useful when you want to segment all the part of the image, but not segment" - " a certain object.like: segment all the object in this image, or generate" - " segmentations on this image, or segment the image,or perform segmentation" - " on this image, or segment all the object in this image.The input to this" - " tool should be a string, representing the image_path"), + description=( + "useful when you want to segment all the part of the image, but not segment" + " a certain object.like: segment all the object in this image, or generate" + " segmentations on this image, or segment the image,or perform segmentation" + " on this image, or segment all the object in this image.The input to this" + " tool should be a string, representing the image_path" + ), ) def inference_all(self, image_path): image = cv2.imread(image_path) @@ -1295,26 +1328,19 @@ class Segmenting: img[:, :, i] = color_mask[i] ax.imshow(np.dstack((img, m))) - updated_image_path = get_new_image_name(image_path, - func_name="segment-image") + updated_image_path = get_new_image_name(image_path, func_name="segment-image") plt.axis("off") - plt.savefig(updated_image_path, - bbox_inches="tight", - dpi=300, - pad_inches=0.0) + plt.savefig(updated_image_path, bbox_inches="tight", dpi=300, pad_inches=0.0) return updated_image_path class Text2Box: - def __init__(self, device): print(f"Initializing ObjectDetection to {device}") self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 - self.model_checkpoint_path = os.path.join("checkpoints", - "groundingdino") - self.model_config_path = os.path.join("checkpoints", - "grounding_config.py") + self.model_checkpoint_path = os.path.join("checkpoints", "groundingdino") + self.model_config_path = os.path.join("checkpoints", "grounding_config.py") self.download_parameters() self.box_threshold = 0.3 self.text_threshold = 0.25 @@ -1332,11 +1358,13 @@ class Text2Box: # load image image_pil = Image.open(image_path).convert("RGB") # load image - transform = T.Compose([ - T.RandomResize([512], max_size=1333), - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ]) + transform = T.Compose( + [ + T.RandomResize([512], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) image, _ = transform(image_pil, None) # 3, h, w return image_pil, image @@ -1345,8 +1373,9 @@ class Text2Box: args.device = self.device model = build_model(args) checkpoint = torch.load(self.model_checkpoint_path, map_location="cpu") - load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), - strict=False) + load_res = model.load_state_dict( + clean_state_dict(checkpoint["model"]), strict=False + ) print(load_res) _ = model.eval() return model @@ -1377,11 +1406,11 @@ class Text2Box: # build pred pred_phrases = [] for logit, box in zip(logits_filt, boxes_filt): - pred_phrase = get_phrases_from_posmap(logit > self.text_threshold, - tokenized, tokenlizer) + pred_phrase = get_phrases_from_posmap( + logit > self.text_threshold, tokenized, tokenlizer + ) if with_logits: - pred_phrases.append(pred_phrase + - f"({str(logit.max().item())[:4]})") + pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") else: pred_phrases.append(pred_phrase) @@ -1391,8 +1420,7 @@ class Text2Box: H, W = tgt["size"] boxes = tgt["boxes"] labels = tgt["labels"] - assert len(boxes) == len( - labels), "boxes and labels must have same length" + assert len(boxes) == len(labels), "boxes and labels must have same length" draw = ImageDraw.Draw(image_pil) mask = Image.new("L", image_pil.size, 0) @@ -1430,11 +1458,12 @@ class Text2Box: @prompts( name="Detect the Give Object", - description= - ("useful when you only want to detect or find out given objects in the" - " pictureThe input to this tool should be a comma separated string of two," - " representing the image_path, the text description of the object to be" - " found"), + description=( + "useful when you only want to detect or find out given objects in the" + " pictureThe input to this tool should be a comma separated string of two," + " representing the image_path, the text description of the object to be" + " found" + ), ) def inference(self, inputs): image_path, det_prompt = inputs.split(",") @@ -1452,18 +1481,19 @@ class Text2Box: image_with_box = self.plot_boxes_to_image(image_pil, pred_dict)[0] - updated_image_path = get_new_image_name(image_path, - func_name="detect-something") + updated_image_path = get_new_image_name( + image_path, func_name="detect-something" + ) updated_image = image_with_box.resize(size) updated_image.save(updated_image_path) print( f"\nProcessed ObejectDetecting, Input Image: {image_path}, Object to be" - f" Detect {det_prompt}, Output Image: {updated_image_path}") + f" Detect {det_prompt}, Output Image: {updated_image_path}" + ) return updated_image_path class Inpainting: - def __init__(self, device): self.device = device self.revision = "fp16" if "cuda" in self.device else None @@ -1474,16 +1504,13 @@ class Inpainting: revision=self.revision, torch_dtype=self.torch_dtype, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker"), + "CompVis/stable-diffusion-safety-checker" + ), ).to(device) - def __call__(self, - prompt, - image, - mask_image, - height=512, - width=512, - num_inference_steps=50): + def __call__( + self, prompt, image, mask_image, height=512, width=512, num_inference_steps=50 + ): update_image = self.inpaint( prompt=prompt, image=image.resize((width, height)), @@ -1506,27 +1533,29 @@ class InfinityOutPainting: self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality") + "fewer digits, cropped, worst quality, low quality" + ) def get_BLIP_vqa(self, image, question): - inputs = self.ImageVQA.processor(image, question, - return_tensors="pt").to( - self.ImageVQA.device, - self.ImageVQA.torch_dtype) + inputs = self.ImageVQA.processor(image, question, return_tensors="pt").to( + self.ImageVQA.device, self.ImageVQA.torch_dtype + ) out = self.ImageVQA.model.generate(**inputs) - answer = self.ImageVQA.processor.decode(out[0], - skip_special_tokens=True) + answer = self.ImageVQA.processor.decode(out[0], skip_special_tokens=True) print( f"\nProcessed VisualQuestionAnswering, Input Question: {question}, Output" - f" Answer: {answer}") + f" Answer: {answer}" + ) return answer def get_BLIP_caption(self, image): inputs = self.ImageCaption.processor(image, return_tensors="pt").to( - self.ImageCaption.device, self.ImageCaption.torch_dtype) + self.ImageCaption.device, self.ImageCaption.torch_dtype + ) out = self.ImageCaption.model.generate(**inputs) BLIP_caption = self.ImageCaption.processor.decode( - out[0], skip_special_tokens=True) + out[0], skip_special_tokens=True + ) return BLIP_caption def check_prompt(self, prompt): @@ -1540,7 +1569,8 @@ class InfinityOutPainting: def get_imagine_caption(self, image, imagine): BLIP_caption = self.get_BLIP_caption(image) background_color = self.get_BLIP_vqa( - image, "what is the background color of this image") + image, "what is the background color of this image" + ) style = self.get_BLIP_vqa(image, "what is the style of this image") imagine_prompt = ( "let's pretend you are an excellent painter and now there is an incomplete" @@ -1548,47 +1578,54 @@ class InfinityOutPainting: " painting and describe ityou should consider the background color is" f" {background_color}, the style is {style}You should make the painting as" " vivid and realistic as possibleYou can not use words like painting or" - " pictureand you should use no more than 50 words to describe it") + " pictureand you should use no more than 50 words to describe it" + ) caption = self.llm(imagine_prompt) if imagine else BLIP_caption caption = self.check_prompt(caption) - print(f"BLIP observation: {BLIP_caption}, ChatGPT imagine to {caption}" - ) if imagine else print(f"Prompt: {caption}") + print( + f"BLIP observation: {BLIP_caption}, ChatGPT imagine to {caption}" + ) if imagine else print(f"Prompt: {caption}") return caption def resize_image(self, image, max_size=1000000, multiple=8): aspect_ratio = image.size[0] / image.size[1] new_width = int(math.sqrt(max_size * aspect_ratio)) new_height = int(new_width / aspect_ratio) - new_width, new_height = new_width - ( - new_width % multiple), new_height - (new_height % multiple) + new_width, new_height = new_width - (new_width % multiple), new_height - ( + new_height % multiple + ) return image.resize((new_width, new_height)) def dowhile(self, original_img, tosize, expand_ratio, imagine, usr_prompt): old_img = original_img while old_img.size != tosize: - prompt = (self.check_prompt(usr_prompt) if usr_prompt else - self.get_imagine_caption(old_img, imagine)) + prompt = ( + self.check_prompt(usr_prompt) + if usr_prompt + else self.get_imagine_caption(old_img, imagine) + ) crop_w = 15 if old_img.size[0] != tosize[0] else 0 crop_h = 15 if old_img.size[1] != tosize[1] else 0 old_img = ImageOps.crop(old_img, (crop_w, crop_h, crop_w, crop_h)) temp_canvas_size = ( expand_ratio * old_img.width - if expand_ratio * old_img.width < tosize[0] else tosize[0], + if expand_ratio * old_img.width < tosize[0] + else tosize[0], expand_ratio * old_img.height - if expand_ratio * old_img.height < tosize[1] else tosize[1], + if expand_ratio * old_img.height < tosize[1] + else tosize[1], ) - temp_canvas, temp_mask = Image.new("RGB", - temp_canvas_size, - color="white"), Image.new( - "L", - temp_canvas_size, - color="white") + temp_canvas, temp_mask = Image.new( + "RGB", temp_canvas_size, color="white" + ), Image.new("L", temp_canvas_size, color="white") x, y = (temp_canvas.width - old_img.width) // 2, ( - temp_canvas.height - old_img.height) // 2 + temp_canvas.height - old_img.height + ) // 2 temp_canvas.paste(old_img, (x, y)) temp_mask.paste(0, (x, y, x + old_img.width, y + old_img.height)) resized_temp_canvas, resized_temp_mask = self.resize_image( - temp_canvas), self.resize_image(temp_mask) + temp_canvas + ), self.resize_image(temp_mask) image = self.inpaint( prompt=prompt, image=resized_temp_canvas, @@ -1603,11 +1640,11 @@ class InfinityOutPainting: @prompts( name="Extend An Image", - description= - ("useful when you need to extend an image into a larger image.like: extend" - " the image into a resolution of 2048x1024, extend the image into" - " 2048x1024. The input to this tool should be a comma separated string of" - " two, representing the image_path and the resolution of widthxheight" + description=( + "useful when you need to extend an image into a larger image.like: extend" + " the image into a resolution of 2048x1024, extend the image into" + " 2048x1024. The input to this tool should be a comma separated string of" + " two, representing the image_path and the resolution of widthxheight" ), ) def inference(self, inputs): @@ -1617,12 +1654,12 @@ class InfinityOutPainting: image = Image.open(image_path) image = ImageOps.crop(image, (10, 10, 10, 10)) out_painted_image = self.dowhile(image, tosize, 4, True, False) - updated_image_path = get_new_image_name(image_path, - func_name="outpainting") + updated_image_path = get_new_image_name(image_path, func_name="outpainting") out_painted_image.save(updated_image_path) print( f"\nProcessed InfinityOutPainting, Input Image: {image_path}, Input" - f" Resolution: {resolution}, Output Image: {updated_image_path}") + f" Resolution: {resolution}, Output Image: {updated_image_path}" + ) return updated_image_path @@ -1641,20 +1678,22 @@ class ObjectSegmenting: " pictureaccording to the given textlike: segment the cat,or can you" " segment an obeject for meThe input to this tool should be a comma" " separated string of two, representing the image_path, the text" - " description of the object to be found"), + " description of the object to be found" + ), ) def inference(self, inputs): image_path, det_prompt = inputs.split(",") print(f"image_path={image_path}, text_prompt={det_prompt}") image_pil, image = self.grounding.load_image(image_path) - boxes_filt, pred_phrases = self.grounding.get_grounding_boxes( - image, det_prompt) + boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, det_prompt) updated_image_path = self.sam.segment_image_with_boxes( - image_pil, image_path, boxes_filt, pred_phrases) + image_pil, image_path, boxes_filt, pred_phrases + ) print( f"\nProcessed ObejectSegmenting, Input Image: {image_path}, Object to be" - f" Segment {det_prompt}, Output Image: {updated_image_path}") + f" Segment {det_prompt}, Output Image: {updated_image_path}" + ) return updated_image_path def merge_masks(self, masks): @@ -1685,7 +1724,8 @@ class ObjectSegmenting: image_pil, image = self.grounding.load_image(image_path) boxes_filt, pred_phrases = self.grounding.get_grounding_boxes( - image, text_prompt) + image, text_prompt + ) image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) self.sam.sam_predictor.set_image(image) @@ -1698,10 +1738,9 @@ class ObjectSegmenting: # draw output image for mask in masks: - image = self.sam.show_mask(mask[0].cpu().numpy(), - image, - random_color=True, - transparency=0.3) + image = self.sam.show_mask( + mask[0].cpu().numpy(), image, random_color=True, transparency=0.3 + ) Image.fromarray(merged_mask) @@ -1711,8 +1750,9 @@ class ObjectSegmenting: class ImageEditing: template_model = True - def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting, - Inpainting: Inpainting): + def __init__( + self, Text2Box: Text2Box, Segmenting: Segmenting, Inpainting: Inpainting + ): print("Initializing ImageEditing") self.sam = Segmenting self.grounding = Text2Box @@ -1725,7 +1765,8 @@ class ImageEditing: mask_array = np.zeros_like(mask, dtype=bool) for idx in true_indices: padded_slice = tuple( - slice(max(0, i - padding), i + padding + 1) for i in idx) + slice(max(0, i - padding), i + padding + 1) for i in idx + ) mask_array[padded_slice] = True new_mask = (mask_array * 255).astype(np.uint8) # new_mask @@ -1733,34 +1774,38 @@ class ImageEditing: @prompts( name="Remove Something From The Photo", - description= - ("useful when you want to remove and object or something from the photo " - "from its description or location. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the object need to be removed. "), + description=( + "useful when you want to remove and object or something from the photo " + "from its description or location. " + "The input to this tool should be a comma separated string of two, " + "representing the image_path and the object need to be removed. " + ), ) def inference_remove(self, inputs): image_path, to_be_removed_txt = inputs.split(",")[0], ",".join( - inputs.split(",")[1:]) + inputs.split(",")[1:] + ) return self.inference_replace_sam( - f"{image_path},{to_be_removed_txt},background") + f"{image_path},{to_be_removed_txt},background" + ) @prompts( name="Replace Something From The Photo", - description= - ("useful when you want to replace an object from the object description or" - " location with another object from its description. The input to this tool" - " should be a comma separated string of three, representing the image_path," - " the object to be replaced, the object to be replaced with "), + description=( + "useful when you want to replace an object from the object description or" + " location with another object from its description. The input to this tool" + " should be a comma separated string of three, representing the image_path," + " the object to be replaced, the object to be replaced with " + ), ) def inference_replace_sam(self, inputs): image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",") - print( - f"image_path={image_path}, to_be_replaced_txt={to_be_replaced_txt}") + print(f"image_path={image_path}, to_be_replaced_txt={to_be_replaced_txt}") image_pil, image = self.grounding.load_image(image_path) boxes_filt, pred_phrases = self.grounding.get_grounding_boxes( - image, to_be_replaced_txt) + image, to_be_replaced_txt + ) image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) self.sam.sam_predictor.set_image(image) @@ -1772,16 +1817,19 @@ class ImageEditing: mask = self.pad_edge(mask, padding=20) # numpy mask_image = Image.fromarray(mask) - updated_image = self.inpaint(prompt=replace_with_txt, - image=image_pil, - mask_image=mask_image) - updated_image_path = get_new_image_name(image_path, - func_name="replace-something") + updated_image = self.inpaint( + prompt=replace_with_txt, image=image_pil, mask_image=mask_image + ) + updated_image_path = get_new_image_name( + image_path, func_name="replace-something" + ) updated_image = updated_image.resize(image_pil.size) updated_image.save(updated_image_path) - print(f"\nProcessed ImageEditing, Input Image: {image_path}, Replace" - f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:" - f" {updated_image_path}") + print( + f"\nProcessed ImageEditing, Input Image: {image_path}, Replace" + f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:" + f" {updated_image_path}" + ) return updated_image_path @@ -1803,9 +1851,10 @@ class BackgroundRemoving: @prompts( name="Remove the background", - description= - ("useful when you want to extract the object or remove the background," - "the input should be a string image_path"), + description=( + "useful when you want to extract the object or remove the background," + "the input should be a string image_path" + ), ) def inference(self, image_path): """ @@ -1819,8 +1868,9 @@ class BackgroundRemoving: mask = Image.fromarray(mask) image.putalpha(mask) - updated_image_path = get_new_image_name(image_path, - func_name="detect-something") + updated_image_path = get_new_image_name( + image_path, func_name="detect-something" + ) image.save(updated_image_path) return updated_image_path @@ -1843,7 +1893,6 @@ class BackgroundRemoving: class MultiModalVisualAgent: - def __init__( self, load_dict, @@ -1856,7 +1905,8 @@ class MultiModalVisualAgent: if "ImageCaptioning" not in load_dict: raise ValueError( "You have to load ImageCaptioning as a basic function for" - " MultiModalVisualAgent") + " MultiModalVisualAgent" + ) self.models = {} @@ -1866,18 +1916,17 @@ class MultiModalVisualAgent: for class_name, module in globals().items(): if getattr(module, "template_model", False): template_required_names = { - k for k in inspect.signature( - module.__init__).parameters.keys() if k != "self" + k + for k in inspect.signature(module.__init__).parameters.keys() + if k != "self" } - loaded_names = set( - [type(e).__name__ for e in self.models.values()]) + loaded_names = set([type(e).__name__ for e in self.models.values()]) if template_required_names.issubset(loaded_names): - self.models[class_name] = globals()[class_name](**{ - name: self.models[name] - for name in template_required_names - }) + self.models[class_name] = globals()[class_name]( + **{name: self.models[name] for name in template_required_names} + ) print(f"All the Available Functions: {self.models}") @@ -1887,13 +1936,13 @@ class MultiModalVisualAgent: if e.startswith("inference"): func = getattr(instance, e) self.tools.append( - Tool(name=func.name, - description=func.description, - func=func)) + Tool(name=func.name, description=func.description, func=func) + ) self.llm = OpenAI(temperature=0) - self.memory = ConversationBufferMemory(memory_key="chat_history", - output_key="output") + self.memory = ConversationBufferMemory( + memory_key="chat_history", output_key="output" + ) def init_agent(self, lang): self.memory.clear() @@ -1931,7 +1980,8 @@ class MultiModalVisualAgent: def run_text(self, text): self.agent.memory.buffer = cut_dialogue_history( - self.agent.memory.buffer, keep_last_n_words=500) + self.agent.memory.buffer, keep_last_n_words=500 + ) res = self.agent({"input": text.strip()}) res["output"] = res["output"].replace("\\", "/") @@ -1941,8 +1991,10 @@ class MultiModalVisualAgent: res["output"], ) - print(f"\nProcessed run_text, Input text: {text}\n" - f"Current Memory: {self.agent.memory.buffer}") + print( + f"\nProcessed run_text, Input text: {text}\n" + f"Current Memory: {self.agent.memory.buffer}" + ) return response @@ -1964,10 +2016,12 @@ class MultiModalVisualAgent: description = self.models["ImageCaptioning"].inference(image_filename) if lang == "Chinese": - Human_prompt = (f"\nHuman: 提供一张名为 {image_filename}的图片。它的描述是:" - f" {description}。 这些信息帮助你理解这个图像," - "但是你应该使用工具来完成下面的任务,而不是直接从我的描述中想象。" - ' 如果你明白了, 说 "收到". \n') + Human_prompt = ( + f"\nHuman: 提供一张名为 {image_filename}的图片。它的描述是:" + f" {description}。 这些信息帮助你理解这个图像," + "但是你应该使用工具来完成下面的任务,而不是直接从我的描述中想象。" + ' 如果你明白了, 说 "收到". \n' + ) AI_prompt = "收到。 " else: Human_prompt = ( @@ -1975,14 +2029,18 @@ class MultiModalVisualAgent: f" {description}. This information helps you to understand this image," " but you should use tools to finish following tasks, rather than" " directly imagine from my description. If you understand, say" - ' "Received". \n') + ' "Received". \n' + ) AI_prompt = "Received. " - self.agent.memory.buffer = (self.agent.memory.buffer + Human_prompt + - "AI: " + AI_prompt) + self.agent.memory.buffer = ( + self.agent.memory.buffer + Human_prompt + "AI: " + AI_prompt + ) - print(f"\nProcessed run_image, Input image: {image_filename}\n" - f"Current Memory: {self.agent.memory.buffer}") + print( + f"\nProcessed run_image, Input image: {image_filename}\n" + f"Current Memory: {self.agent.memory.buffer}" + ) return AI_prompt @@ -2029,10 +2087,7 @@ class MultiModalAgent: """ - def __init__(self, - load_dict, - temperature: int = 0.1, - language: str = "english"): + def __init__(self, load_dict, temperature: int = 0.1, language: str = "english"): self.load_dict = load_dict self.temperature = temperature self.langigage = language @@ -2068,10 +2123,7 @@ class MultiModalAgent: except Exception as error: return f"Error processing image: {str(error)}" - def chat(self, - msg: str = None, - language: str = "english", - streaming: bool = False): + def chat(self, msg: str = None, language: str = "english", streaming: bool = False): """ Run chat with the multi-modal agent diff --git a/swarms/agents/neural_architecture_search_worker.py b/swarms/agents/neural_architecture_search_worker.py index 3bfd8323..fd253b95 100644 --- a/swarms/agents/neural_architecture_search_worker.py +++ b/swarms/agents/neural_architecture_search_worker.py @@ -2,7 +2,6 @@ class Replicator: - def __init__( self, model_name, diff --git a/swarms/agents/omni_modal_agent.py b/swarms/agents/omni_modal_agent.py index b6fdfbdc..007a2219 100644 --- a/swarms/agents/omni_modal_agent.py +++ b/swarms/agents/omni_modal_agent.py @@ -3,20 +3,23 @@ from typing import Dict, List from langchain.base_language import BaseLanguageModel from langchain.tools.base import BaseTool from langchain_experimental.autonomous_agents.hugginggpt.repsonse_generator import ( - load_response_generator,) + load_response_generator, +) from langchain_experimental.autonomous_agents.hugginggpt.task_executor import ( - TaskExecutor,) + TaskExecutor, +) from langchain_experimental.autonomous_agents.hugginggpt.task_planner import ( - load_chat_planner,) + load_chat_planner, +) from transformers import load_tool from swarms.agents.message import Message class Step: - - def __init__(self, task: str, id: int, dep: List[int], args: Dict[str, str], - tool: BaseTool): + def __init__( + self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool + ): self.task = task self.id = id self.dep = dep @@ -25,7 +28,6 @@ class Step: class Plan: - def __init__(self, steps: List[Step]): self.steps = steps @@ -71,7 +73,8 @@ class OmniModalAgent: print("Loading tools...") self.tools = [ - load_tool(tool_name) for tool_name in [ + load_tool(tool_name) + for tool_name in [ "document-question-answering", "image-captioning", "image-question-answering", @@ -96,15 +99,18 @@ class OmniModalAgent: def run(self, input: str) -> str: """Run the OmniAgent""" - plan = self.chat_planner.plan(inputs={ - "input": input, - "hf_tools": self.tools, - }) + plan = self.chat_planner.plan( + inputs={ + "input": input, + "hf_tools": self.tools, + } + ) self.task_executor = TaskExecutor(plan) self.task_executor.run() response = self.response_generator.generate( - {"task_execution": self.task_executor}) + {"task_execution": self.task_executor} + ) return response diff --git a/swarms/agents/profitpilot.py b/swarms/agents/profitpilot.py index a4ff13a5..6858dc72 100644 --- a/swarms/agents/profitpilot.py +++ b/swarms/agents/profitpilot.py @@ -145,12 +145,13 @@ def setup_knowledge_base(product_catalog: str = None): llm = OpenAI(temperature=0) embeddings = OpenAIEmbeddings() - docsearch = Chroma.from_texts(texts, - embeddings, - collection_name="product-knowledge-base") + docsearch = Chroma.from_texts( + texts, embeddings, collection_name="product-knowledge-base" + ) knowledge_base = RetrievalQA.from_chain_type( - llm=llm, chain_type="stuff", retriever=docsearch.as_retriever()) + llm=llm, chain_type="stuff", retriever=docsearch.as_retriever() + ) return knowledge_base @@ -162,8 +163,8 @@ def get_tools(product_catalog): Tool( name="ProductSearch", func=knowledge_base.run, - description= - ("useful for when you need to answer questions about product information" + description=( + "useful for when you need to answer questions about product information" ), ), # omnimodal agent @@ -193,7 +194,8 @@ class CustomPromptTemplateForTools(StringPromptTemplate): tools = self.tools_getter(kwargs["input"]) # Create a tools variable from the list of tools provided kwargs["tools"] = "\n".join( - [f"{tool.name}: {tool.description}" for tool in tools]) + [f"{tool.name}: {tool.description}" for tool in tools] + ) # Create a list of tool names for the tools provided kwargs["tool_names"] = ", ".join([tool.name for tool in tools]) return self.template.format(**kwargs) @@ -216,7 +218,8 @@ class SalesConvoOutputParser(AgentOutputParser): print("-------") if f"{self.ai_prefix}:" in text: return AgentFinish( - {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text) + {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text + ) regex = r"Action: (.*?)[\n]*Action Input: (.*)" match = re.search(regex, text) if not match: @@ -225,15 +228,15 @@ class SalesConvoOutputParser(AgentOutputParser): { "output": ( "I apologize, I was unable to find the answer to your question." - " Is there anything else I can help with?") + " Is there anything else I can help with?" + ) }, text, ) # raise OutputParserException(f"Could not parse LLM output: `{text}`") action = match.group(1) action_input = match.group(2) - return AgentAction(action.strip(), - action_input.strip(" ").strip('"'), text) + return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text) @property def _type(self) -> str: @@ -261,11 +264,13 @@ class ProfitPilot(Chain, BaseModel): "2": ( "Qualification: Qualify the prospect by confirming if they are the right" " person to talk to regarding your product/service. Ensure that they have" - " the authority to make purchasing decisions."), + " the authority to make purchasing decisions." + ), "3": ( "Value proposition: Briefly explain how your product/service can benefit" " the prospect. Focus on the unique selling points and value proposition of" - " your product/service that sets it apart from competitors."), + " your product/service that sets it apart from competitors." + ), "4": ( "Needs analysis: Ask open-ended questions to uncover the prospect's needs" " and pain points. Listen carefully to their responses and take notes." @@ -277,11 +282,13 @@ class ProfitPilot(Chain, BaseModel): "6": ( "Objection handling: Address any objections that the prospect may have" " regarding your product/service. Be prepared to provide evidence or" - " testimonials to support your claims."), + " testimonials to support your claims." + ), "7": ( "Close: Ask for the sale by proposing a next step. This could be a demo, a" " trial or a meeting with decision-makers. Ensure to summarize what has" - " been discussed and reiterate the benefits."), + " been discussed and reiterate the benefits." + ), } salesperson_name: str = "Ted Lasso" @@ -291,16 +298,19 @@ class ProfitPilot(Chain, BaseModel): "Sleep Haven is a premium mattress company that provides customers with the" " most comfortable and supportive sleeping experience possible. We offer a" " range of high-quality mattresses, pillows, and bedding accessories that are" - " designed to meet the unique needs of our customers.") + " designed to meet the unique needs of our customers." + ) company_values: str = ( "Our mission at Sleep Haven is to help people achieve a better night's sleep by" " providing them with the best possible sleep solutions. We believe that" " quality sleep is essential to overall health and well-being, and we are" " committed to helping our customers achieve optimal sleep by offering" - " exceptional products and customer service.") + " exceptional products and customer service." + ) conversation_purpose: str = ( "find out whether they are looking to achieve better sleep via buying a premier" - " mattress.") + " mattress." + ) conversation_type: str = "call" def retrieve_conversation_stage(self, key): @@ -326,7 +336,8 @@ class ProfitPilot(Chain, BaseModel): ) self.current_conversation_stage = self.retrieve_conversation_stage( - conversation_stage_id) + conversation_stage_id + ) print(f"Conversation Stage: {self.current_conversation_stage}") @@ -380,15 +391,13 @@ class ProfitPilot(Chain, BaseModel): return {} @classmethod - def from_llm(cls, - llm: BaseLLM, - verbose: bool = False, - **kwargs): # noqa: F821 + def from_llm(cls, llm: BaseLLM, verbose: bool = False, **kwargs): # noqa: F821 """Initialize the SalesGPT Controller.""" stage_analyzer_chain = StageAnalyzerChain.from_llm(llm, verbose=verbose) sales_conversation_utterance_chain = SalesConversationChain.from_llm( - llm, verbose=verbose) + llm, verbose=verbose + ) if "use_tools" in kwargs.keys() and kwargs["use_tools"] is False: sales_agent_executor = None @@ -421,8 +430,7 @@ class ProfitPilot(Chain, BaseModel): # WARNING: this output parser is NOT reliable yet # It makes assumptions about output from LLM which can break and throw an error - output_parser = SalesConvoOutputParser( - ai_prefix=kwargs["salesperson_name"]) + output_parser = SalesConvoOutputParser(ai_prefix=kwargs["salesperson_name"]) sales_agent_with_tools = LLMSingleActionAgent( llm_chain=llm_chain, @@ -433,12 +441,12 @@ class ProfitPilot(Chain, BaseModel): ) sales_agent_executor = AgentExecutor.from_agent_and_tools( - agent=sales_agent_with_tools, tools=tools, verbose=verbose) + agent=sales_agent_with_tools, tools=tools, verbose=verbose + ) return cls( stage_analyzer_chain=stage_analyzer_chain, - sales_conversation_utterance_chain= - sales_conversation_utterance_chain, + sales_conversation_utterance_chain=sales_conversation_utterance_chain, sales_agent_executor=sales_agent_executor, verbose=verbose, **kwargs, @@ -450,27 +458,32 @@ config = dict( salesperson_name="Ted Lasso", salesperson_role="Business Development Representative", company_name="Sleep Haven", - company_business= - ("Sleep Haven is a premium mattress company that provides customers with the" - " most comfortable and supportive sleeping experience possible. We offer a" - " range of high-quality mattresses, pillows, and bedding accessories that are" - " designed to meet the unique needs of our customers."), - company_values= - ("Our mission at Sleep Haven is to help people achieve a better night's sleep by" - " providing them with the best possible sleep solutions. We believe that" - " quality sleep is essential to overall health and well-being, and we are" - " committed to helping our customers achieve optimal sleep by offering" - " exceptional products and customer service."), - conversation_purpose= - ("find out whether they are looking to achieve better sleep via buying a premier" - " mattress."), + company_business=( + "Sleep Haven is a premium mattress company that provides customers with the" + " most comfortable and supportive sleeping experience possible. We offer a" + " range of high-quality mattresses, pillows, and bedding accessories that are" + " designed to meet the unique needs of our customers." + ), + company_values=( + "Our mission at Sleep Haven is to help people achieve a better night's sleep by" + " providing them with the best possible sleep solutions. We believe that" + " quality sleep is essential to overall health and well-being, and we are" + " committed to helping our customers achieve optimal sleep by offering" + " exceptional products and customer service." + ), + conversation_purpose=( + "find out whether they are looking to achieve better sleep via buying a premier" + " mattress." + ), conversation_history=[], conversation_type="call", conversation_stage=conversation_stages.get( "1", - ("Introduction: Start the conversation by introducing yourself and your" - " company. Be polite and respectful while keeping the tone of the" - " conversation professional."), + ( + "Introduction: Start the conversation by introducing yourself and your" + " company. Be polite and respectful while keeping the tone of the" + " conversation professional." + ), ), use_tools=True, product_catalog="sample_product_catalog.txt", diff --git a/swarms/agents/refiner_agent.py b/swarms/agents/refiner_agent.py index 509484e3..2a1383e9 100644 --- a/swarms/agents/refiner_agent.py +++ b/swarms/agents/refiner_agent.py @@ -1,11 +1,9 @@ class PromptRefiner: - def __init__(self, system_prompt: str, llm): super().__init__() self.system_prompt = system_prompt self.llm = llm def run(self, task: str): - refine = self.llm( - f"System Prompt: {self.system_prompt} Current task: {task}") + refine = self.llm(f"System Prompt: {self.system_prompt} Current task: {task}") return refine diff --git a/swarms/agents/registry.py b/swarms/agents/registry.py index 5cf2c0d5..aa1f1375 100644 --- a/swarms/agents/registry.py +++ b/swarms/agents/registry.py @@ -10,7 +10,6 @@ class Registry(BaseModel): entries: Dict = {} def register(self, key: str): - def decorator(class_builder): self.entries[key] = class_builder return class_builder @@ -21,7 +20,8 @@ class Registry(BaseModel): if type not in self.entries: raise ValueError( f"{type} is not registered. Please register with the" - f' .register("{type}") method provided in {self.name} registry') + f' .register("{type}") method provided in {self.name} registry' + ) return self.entries[type](**kwargs) def get_all_entries(self): diff --git a/swarms/agents/simple_agent.py b/swarms/agents/simple_agent.py index 847cbc67..88327095 100644 --- a/swarms/agents/simple_agent.py +++ b/swarms/agents/simple_agent.py @@ -29,8 +29,7 @@ class SimpleAgent: def run(self, task: str) -> str: """Run method""" - metrics = print( - colored(f"Agent {self.name} is running task: {task}", "red")) + metrics = print(colored(f"Agent {self.name} is running task: {task}", "red")) print(metrics) response = self.flow.run(task) diff --git a/swarms/artifacts/base.py b/swarms/artifacts/base.py index 1357a86b..dac7a523 100644 --- a/swarms/artifacts/base.py +++ b/swarms/artifacts/base.py @@ -10,8 +10,9 @@ from marshmallow.exceptions import RegistryError @define class BaseArtifact(ABC): id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) - name: str = field(default=Factory(lambda self: self.id, takes_self=True), - kw_only=True) + name: str = field( + default=Factory(lambda self: self.id, takes_self=True), kw_only=True + ) value: any = field() type: str = field( default=Factory(lambda self: self.__class__.__name__, takes_self=True), @@ -53,8 +54,7 @@ class BaseArtifact(ABC): class_registry.register("ListArtifact", ListArtifactSchema) try: - return class_registry.get_class( - artifact_dict["type"])().load(artifact_dict) + return class_registry.get_class(artifact_dict["type"])().load(artifact_dict) except RegistryError: raise ValueError("Unsupported artifact type") diff --git a/swarms/artifacts/main.py b/swarms/artifacts/main.py index 8845ada3..4b240b22 100644 --- a/swarms/artifacts/main.py +++ b/swarms/artifacts/main.py @@ -15,7 +15,8 @@ class Artifact(BaseModel): artifact_id: StrictStr = Field(..., description="ID of the artifact") file_name: StrictStr = Field(..., description="Filename of the artifact") relative_path: Optional[StrictStr] = Field( - None, description="Relative path of the artifact") + None, description="Relative path of the artifact" + ) __properties = ["artifact_id", "file_name", "relative_path"] class Config: @@ -48,10 +49,12 @@ class Artifact(BaseModel): if not isinstance(obj, dict): return Artifact.parse_obj(obj) - _obj = Artifact.parse_obj({ - "artifact_id": obj.get("artifact_id"), - "file_name": obj.get("file_name"), - "relative_path": obj.get("relative_path"), - }) + _obj = Artifact.parse_obj( + { + "artifact_id": obj.get("artifact_id"), + "file_name": obj.get("file_name"), + "relative_path": obj.get("relative_path"), + } + ) return _obj diff --git a/swarms/chunkers/base.py b/swarms/chunkers/base.py index d243bd0d..0fabdcef 100644 --- a/swarms/chunkers/base.py +++ b/swarms/chunkers/base.py @@ -48,13 +48,15 @@ class BaseChunker(ABC): kw_only=True, ) tokenizer: OpenAITokenizer = field( - default=Factory(lambda: OpenAITokenizer( - model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), + default=Factory( + lambda: OpenAITokenizer( + model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL + ) + ), kw_only=True, ) max_tokens: int = field( - default=Factory(lambda self: self.tokenizer.max_tokens, - takes_self=True), + default=Factory(lambda self: self.tokenizer.max_tokens, takes_self=True), kw_only=True, ) @@ -64,9 +66,8 @@ class BaseChunker(ABC): return [TextArtifact(c) for c in self._chunk_recursively(text)] def _chunk_recursively( - self, - chunk: str, - current_separator: Optional[ChunkSeparator] = None) -> list[str]: + self, chunk: str, current_separator: Optional[ChunkSeparator] = None + ) -> list[str]: token_count = self.tokenizer.count_tokens(chunk) if token_count <= self.max_tokens: @@ -78,8 +79,7 @@ class BaseChunker(ABC): half_token_count = token_count // 2 if current_separator: - separators = self.separators[self.separators. - index(current_separator):] + separators = self.separators[self.separators.index(current_separator) :] else: separators = self.separators @@ -102,19 +102,26 @@ class BaseChunker(ABC): if separator.is_prefix: first_subchunk = separator.value + separator.value.join( - subchanks[:balance_index + 1]) + subchanks[: balance_index + 1] + ) second_subchunk = separator.value + separator.value.join( - subchanks[balance_index + 1:]) + subchanks[balance_index + 1 :] + ) else: - first_subchunk = (separator.value.join( - subchanks[:balance_index + 1]) + separator.value) + first_subchunk = ( + separator.value.join(subchanks[: balance_index + 1]) + + separator.value + ) second_subchunk = separator.value.join( - subchanks[balance_index + 1:]) + subchanks[balance_index + 1 :] + ) first_subchunk_rec = self._chunk_recursively( - first_subchunk.strip(), separator) + first_subchunk.strip(), separator + ) second_subchunk_rec = self._chunk_recursively( - second_subchunk.strip(), separator) + second_subchunk.strip(), separator + ) if first_subchunk_rec and second_subchunk_rec: return first_subchunk_rec + second_subchunk_rec diff --git a/swarms/chunkers/omni_chunker.py b/swarms/chunkers/omni_chunker.py index c4870e2b..70a11380 100644 --- a/swarms/chunkers/omni_chunker.py +++ b/swarms/chunkers/omni_chunker.py @@ -76,7 +76,8 @@ class OmniChunker: colored( f"Could not decode file with extension {file_extension}: {e}", "yellow", - )) + ) + ) return "" def chunk_content(self, content: str) -> List[str]: @@ -90,7 +91,7 @@ class OmniChunker: List[str]: The list of chunks. """ return [ - content[i:i + self.chunk_size] + content[i : i + self.chunk_size] for i in range(0, len(content), self.chunk_size) ] @@ -112,4 +113,5 @@ class OmniChunker: {self.metrics()} """, "cyan", - )) + ) + ) diff --git a/swarms/loaders/asana.py b/swarms/loaders/asana.py index 022b685b..dd14cff4 100644 --- a/swarms/loaders/asana.py +++ b/swarms/loaders/asana.py @@ -18,9 +18,9 @@ class AsanaReader(BaseReader): self.client = asana.Client.access_token(asana_token) - def load_data(self, - workspace_id: Optional[str] = None, - project_id: Optional[str] = None) -> List[Document]: + def load_data( + self, workspace_id: Optional[str] = None, project_id: Optional[str] = None + ) -> List[Document]: """Load data from the workspace. Args: @@ -31,20 +31,18 @@ class AsanaReader(BaseReader): """ if workspace_id is None and project_id is None: - raise ValueError( - "Either workspace_id or project_id must be provided") + raise ValueError("Either workspace_id or project_id must be provided") if workspace_id is not None and project_id is not None: raise ValueError( - "Only one of workspace_id or project_id should be provided") + "Only one of workspace_id or project_id should be provided" + ) results = [] if workspace_id is not None: - workspace_name = self.client.workspaces.find_by_id( - workspace_id)["name"] - projects = self.client.projects.find_all( - {"workspace": workspace_id}) + workspace_name = self.client.workspaces.find_by_id(workspace_id)["name"] + projects = self.client.projects.find_all({"workspace": workspace_id}) # Case: Only project_id is provided else: # since we've handled the other cases, this means project_id is not None @@ -52,58 +50,54 @@ class AsanaReader(BaseReader): workspace_name = projects[0]["workspace"]["name"] for project in projects: - tasks = self.client.tasks.find_all({ - "project": - project["gid"], - "opt_fields": - "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields", - }) + tasks = self.client.tasks.find_all( + { + "project": project["gid"], + "opt_fields": "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields", + } + ) for task in tasks: - stories = self.client.tasks.stories(task["gid"], - opt_fields="type,text") - comments = "\n".join([ - story["text"] - for story in stories - if story.get("type") == "comment" and "text" in story - ]) + stories = self.client.tasks.stories(task["gid"], opt_fields="type,text") + comments = "\n".join( + [ + story["text"] + for story in stories + if story.get("type") == "comment" and "text" in story + ] + ) task_metadata = { - "task_id": - task.get("gid", ""), - "name": - task.get("name", ""), + "task_id": task.get("gid", ""), + "name": task.get("name", ""), "assignee": (task.get("assignee") or {}).get("name", ""), - "completed_on": - task.get("completed_at", ""), - "completed_by": (task.get("completed_by") or - {}).get("name", ""), - "project_name": - project.get("name", ""), + "completed_on": task.get("completed_at", ""), + "completed_by": (task.get("completed_by") or {}).get("name", ""), + "project_name": project.get("name", ""), "custom_fields": [ i["display_value"] for i in task.get("custom_fields") if task.get("custom_fields") is not None ], - "workspace_name": - workspace_name, - "url": - f"https://app.asana.com/0/{project['gid']}/{task['gid']}", + "workspace_name": workspace_name, + "url": f"https://app.asana.com/0/{project['gid']}/{task['gid']}", } if task.get("followers") is not None: task_metadata["followers"] = [ - i.get("name") - for i in task.get("followers") - if "name" in i + i.get("name") for i in task.get("followers") if "name" in i ] else: task_metadata["followers"] = [] results.append( Document( - text=task.get("name", "") + " " + - task.get("notes", "") + " " + comments, + text=task.get("name", "") + + " " + + task.get("notes", "") + + " " + + comments, extra_info=task_metadata, - )) + ) + ) return results diff --git a/swarms/loaders/base.py b/swarms/loaders/base.py index 2d5c7cdb..afeeb231 100644 --- a/swarms/loaders/base.py +++ b/swarms/loaders/base.py @@ -47,8 +47,7 @@ class BaseComponent(BaseModel): # TODO: return type here not supported by current mypy version @classmethod - def from_dict(cls, data: Dict[str, Any], - **kwargs: Any) -> Self: # type: ignore + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore if isinstance(kwargs, dict): data.update(kwargs) @@ -119,10 +118,12 @@ class BaseNode(BaseComponent): class Config: allow_population_by_field_name = True - id_: str = Field(default_factory=lambda: str(uuid.uuid4()), - description="Unique ID of the node.") + id_: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node." + ) embedding: Optional[List[float]] = Field( - default=None, description="Embedding of the node.") + default=None, description="Embedding of the node." + ) """" metadata fields - injected as part of the text shown to LLMs as context @@ -137,8 +138,7 @@ class BaseNode(BaseComponent): ) excluded_embed_metadata_keys: List[str] = Field( default_factory=list, - description= - "Metadata keys that are excluded from text for the embed model.", + description="Metadata keys that are excluded from text for the embed model.", ) excluded_llm_metadata_keys: List[str] = Field( default_factory=list, @@ -156,8 +156,7 @@ class BaseNode(BaseComponent): """Get Object type.""" @abstractmethod - def get_content(self, - metadata_mode: MetadataMode = MetadataMode.ALL) -> str: + def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: """Get object content.""" @abstractmethod @@ -188,8 +187,7 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.SOURCE] if isinstance(relation, list): - raise ValueError( - "Source object must be a single RelatedNodeInfo object") + raise ValueError("Source object must be a single RelatedNodeInfo object") return relation @property @@ -200,8 +198,7 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.PREVIOUS] if not isinstance(relation, RelatedNodeInfo): - raise ValueError( - "Previous object must be a single RelatedNodeInfo object") + raise ValueError("Previous object must be a single RelatedNodeInfo object") return relation @property @@ -212,8 +209,7 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.NEXT] if not isinstance(relation, RelatedNodeInfo): - raise ValueError( - "Next object must be a single RelatedNodeInfo object") + raise ValueError("Next object must be a single RelatedNodeInfo object") return relation @property @@ -224,8 +220,7 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.PARENT] if not isinstance(relation, RelatedNodeInfo): - raise ValueError( - "Parent object must be a single RelatedNodeInfo object") + raise ValueError("Parent object must be a single RelatedNodeInfo object") return relation @property @@ -236,8 +231,7 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.CHILD] if not isinstance(relation, list): - raise ValueError( - "Child objects must be a list of RelatedNodeInfo objects.") + raise ValueError("Child objects must be a list of RelatedNodeInfo objects.") return relation @property @@ -254,10 +248,12 @@ class BaseNode(BaseComponent): return self.metadata def __str__(self) -> str: - source_text_truncated = truncate_text(self.get_content().strip(), - TRUNCATE_LENGTH) - source_text_wrapped = textwrap.fill(f"Text: {source_text_truncated}\n", - width=WRAP_WIDTH) + source_text_truncated = truncate_text( + self.get_content().strip(), TRUNCATE_LENGTH + ) + source_text_wrapped = textwrap.fill( + f"Text: {source_text_truncated}\n", width=WRAP_WIDTH + ) return f"Node ID: {self.node_id}\n{source_text_wrapped}" def get_embedding(self) -> List[float]: @@ -283,23 +279,28 @@ class BaseNode(BaseComponent): class TextNode(BaseNode): text: str = Field(default="", description="Text content of the node.") start_char_idx: Optional[int] = Field( - default=None, description="Start char index of the node.") + default=None, description="Start char index of the node." + ) end_char_idx: Optional[int] = Field( - default=None, description="End char index of the node.") + default=None, description="End char index of the node." + ) text_template: str = Field( default=DEFAULT_TEXT_NODE_TMPL, - description=("Template for how text is formatted, with {content} and " - "{metadata_str} placeholders."), + description=( + "Template for how text is formatted, with {content} and " + "{metadata_str} placeholders." + ), ) metadata_template: str = Field( default=DEFAULT_METADATA_TMPL, - description=("Template for how metadata is formatted, with {key} and " - "{value} placeholders."), + description=( + "Template for how metadata is formatted, with {key} and " + "{value} placeholders." + ), ) metadata_seperator: str = Field( default="\n", - description= - "Separator between metadata fields when converting to string.", + description="Separator between metadata fields when converting to string.", ) @classmethod @@ -313,7 +314,8 @@ class TextNode(BaseNode): metadata = values.get("metadata", {}) doc_identity = str(text) + str(metadata) values["hash"] = str( - sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()) + sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest() + ) return values @classmethod @@ -321,15 +323,15 @@ class TextNode(BaseNode): """Get Object type.""" return ObjectType.TEXT - def get_content(self, - metadata_mode: MetadataMode = MetadataMode.NONE) -> str: + def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: """Get object content.""" metadata_str = self.get_metadata_str(mode=metadata_mode).strip() if not metadata_str: return self.text - return self.text_template.format(content=self.text, - metadata_str=metadata_str).strip() + return self.text_template.format( + content=self.text, metadata_str=metadata_str + ).strip() def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: """Metadata info string.""" @@ -346,11 +348,13 @@ class TextNode(BaseNode): if key in usable_metadata_keys: usable_metadata_keys.remove(key) - return self.metadata_seperator.join([ - self.metadata_template.format(key=key, value=str(value)) - for key, value in self.metadata.items() - if key in usable_metadata_keys - ]) + return self.metadata_seperator.join( + [ + self.metadata_template.format(key=key, value=str(value)) + for key, value in self.metadata.items() + if key in usable_metadata_keys + ] + ) def set_content(self, value: str) -> None: """Set the content of the node.""" @@ -474,8 +478,7 @@ class NodeWithScore(BaseComponent): else: raise ValueError("Node must be a TextNode to get text.") - def get_content(self, - metadata_mode: MetadataMode = MetadataMode.NONE) -> str: + def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: return self.node.get_content(metadata_mode=metadata_mode) def get_embedding(self) -> List[float]: @@ -512,10 +515,12 @@ class Document(TextNode): return self.id_ def __str__(self) -> str: - source_text_truncated = truncate_text(self.get_content().strip(), - TRUNCATE_LENGTH) - source_text_wrapped = textwrap.fill(f"Text: {source_text_truncated}\n", - width=WRAP_WIDTH) + source_text_truncated = truncate_text( + self.get_content().strip(), TRUNCATE_LENGTH + ) + source_text_wrapped = textwrap.fill( + f"Text: {source_text_truncated}\n", width=WRAP_WIDTH + ) return f"Doc ID: {self.doc_id}\n{source_text_wrapped}" def get_doc_id(self) -> str: @@ -531,27 +536,22 @@ class Document(TextNode): """Convert struct to Haystack document format.""" from haystack.schema import Document as HaystackDocument - return HaystackDocument(content=self.text, - meta=self.metadata, - embedding=self.embedding, - id=self.id_) + return HaystackDocument( + content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_ + ) @classmethod def from_haystack_format(cls, doc: "HaystackDocument") -> "Document": """Convert struct from Haystack document format.""" - return cls(text=doc.content, - metadata=doc.meta, - embedding=doc.embedding, - id_=doc.id) + return cls( + text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id + ) def to_embedchain_format(self) -> Dict[str, Any]: """Convert struct to EmbedChain document format.""" return { "doc_id": self.id_, - "data": { - "content": self.text, - "meta_data": self.metadata - }, + "data": {"content": self.text, "meta_data": self.metadata}, } @classmethod @@ -581,8 +581,7 @@ class Document(TextNode): return cls( text=doc._text, metadata={"additional_metadata": doc._additional_metadata}, - embedding=doc._embedding.tolist() - if doc._embedding is not None else None, + embedding=doc._embedding.tolist() if doc._embedding is not None else None, id_=doc._id, ) @@ -590,10 +589,7 @@ class Document(TextNode): def example(cls) -> "Document": return Document( text=SAMPLE_TEXT, - metadata={ - "filename": "README.md", - "category": "codebase" - }, + metadata={"filename": "README.md", "category": "codebase"}, ) @classmethod diff --git a/swarms/memory/base.py b/swarms/memory/base.py index 7c08af6f..7f71c4b9 100644 --- a/swarms/memory/base.py +++ b/swarms/memory/base.py @@ -30,25 +30,32 @@ class BaseVectorStore(ABC): embedding_driver: Any futures_executor: futures.Executor = field( - default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) - - def upsert_text_artifacts(self, - artifacts: dict[str, list[TextArtifact]], - meta: Optional[dict] = None, - **kwargs) -> None: - execute_futures_dict({ - namespace: - self.futures_executor.submit(self.upsert_text_artifact, a, - namespace, meta, **kwargs) - for namespace, artifact_list in artifacts.items() - for a in artifact_list - }) - - def upsert_text_artifact(self, - artifact: TextArtifact, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs) -> str: + default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True + ) + + def upsert_text_artifacts( + self, + artifacts: dict[str, list[TextArtifact]], + meta: Optional[dict] = None, + **kwargs + ) -> None: + execute_futures_dict( + { + namespace: self.futures_executor.submit( + self.upsert_text_artifact, a, namespace, meta, **kwargs + ) + for namespace, artifact_list in artifacts.items() + for a in artifact_list + } + ) + + def upsert_text_artifact( + self, + artifact: TextArtifact, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: if not meta: meta = {} @@ -59,37 +66,39 @@ class BaseVectorStore(ABC): else: vector = artifact.generate_embedding(self.embedding_driver) - return self.upsert_vector(vector, - vector_id=artifact.id, - namespace=namespace, - meta=meta, - **kwargs) - - def upsert_text(self, - string: str, - vector_id: Optional[str] = None, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs) -> str: - return self.upsert_vector(self.embedding_driver.embed_string(string), - vector_id=vector_id, - namespace=namespace, - meta=meta if meta else {}, - **kwargs) + return self.upsert_vector( + vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs + ) + + def upsert_text( + self, + string: str, + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: + return self.upsert_vector( + self.embedding_driver.embed_string(string), + vector_id=vector_id, + namespace=namespace, + meta=meta if meta else {}, + **kwargs + ) @abstractmethod - def upsert_vector(self, - vector: list[float], - vector_id: Optional[str] = None, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs) -> str: + def upsert_vector( + self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: ... @abstractmethod - def load_entry(self, - vector_id: str, - namespace: Optional[str] = None) -> Entry: + def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry: ... @abstractmethod @@ -97,10 +106,12 @@ class BaseVectorStore(ABC): ... @abstractmethod - def query(self, - query: str, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - **kwargs) -> list[QueryResult]: + def query( + self, + query: str, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs + ) -> list[QueryResult]: ... diff --git a/swarms/memory/chroma.py b/swarms/memory/chroma.py index 080245fb..67ba4cb2 100644 --- a/swarms/memory/chroma.py +++ b/swarms/memory/chroma.py @@ -80,8 +80,10 @@ class Chroma(VectorStore): import chromadb import chromadb.config except ImportError: - raise ImportError("Could not import chromadb python package. " - "Please install it with `pip install chromadb`.") + raise ImportError( + "Could not import chromadb python package. " + "Please install it with `pip install chromadb`." + ) if client is not None: self._client_settings = client_settings @@ -92,7 +94,8 @@ class Chroma(VectorStore): # If client_settings is provided with persist_directory specified, # then it is "in-memory and persisting to disk" mode. client_settings.persist_directory = ( - persist_directory or client_settings.persist_directory) + persist_directory or client_settings.persist_directory + ) if client_settings.persist_directory is not None: # Maintain backwards compatibility with chromadb < 0.4.0 major, minor, _ = chromadb.__version__.split(".") @@ -105,23 +108,25 @@ class Chroma(VectorStore): major, minor, _ = chromadb.__version__.split(".") if int(major) == 0 and int(minor) < 4: _client_settings = chromadb.config.Settings( - chroma_db_impl="duckdb+parquet",) + chroma_db_impl="duckdb+parquet", + ) else: - _client_settings = chromadb.config.Settings( - is_persistent=True) + _client_settings = chromadb.config.Settings(is_persistent=True) _client_settings.persist_directory = persist_directory else: _client_settings = chromadb.config.Settings() self._client_settings = _client_settings self._client = chromadb.Client(_client_settings) - self._persist_directory = (_client_settings.persist_directory or - persist_directory) + self._persist_directory = ( + _client_settings.persist_directory or persist_directory + ) self._embedding_function = embedding_function self._collection = self._client.get_or_create_collection( name=collection_name, embedding_function=self._embedding_function.embed_documents - if self._embedding_function is not None else None, + if self._embedding_function is not None + else None, metadata=collection_metadata, ) self.override_relevance_score_fn = relevance_score_fn @@ -144,8 +149,10 @@ class Chroma(VectorStore): try: import chromadb # noqa: F401 except ImportError: - raise ValueError("Could not import chromadb python package. " - "Please install it with `pip install chromadb`.") + raise ValueError( + "Could not import chromadb python package. " + "Please install it with `pip install chromadb`." + ) return self._collection.query( query_texts=query_texts, query_embeddings=query_embeddings, @@ -195,9 +202,9 @@ class Chroma(VectorStore): if non_empty_ids: metadatas = [metadatas[idx] for idx in non_empty_ids] texts_with_metadatas = [texts[idx] for idx in non_empty_ids] - embeddings_with_metadatas = ([ - embeddings[idx] for idx in non_empty_ids - ] if embeddings else None) + embeddings_with_metadatas = ( + [embeddings[idx] for idx in non_empty_ids] if embeddings else None + ) ids_with_metadata = [ids[idx] for idx in non_empty_ids] try: self._collection.upsert( @@ -218,7 +225,8 @@ class Chroma(VectorStore): if empty_ids: texts_without_metadatas = [texts[j] for j in empty_ids] embeddings_without_metadatas = ( - [embeddings[j] for j in empty_ids] if embeddings else None) + [embeddings[j] for j in empty_ids] if embeddings else None + ) ids_without_metadatas = [ids[j] for j in empty_ids] self._collection.upsert( embeddings=embeddings_without_metadatas, @@ -250,9 +258,7 @@ class Chroma(VectorStore): Returns: List[Document]: List of documents most similar to the query text. """ - docs_and_scores = self.similarity_search_with_score(query, - k, - filter=filter) + docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) return [doc for doc, _ in docs_and_scores] def similarity_search_by_vector( @@ -375,7 +381,8 @@ class Chroma(VectorStore): raise ValueError( "No supported normalization function" f" for distance metric of type: {distance}." - "Consider providing relevance_score_fn to Chroma constructor.") + "Consider providing relevance_score_fn to Chroma constructor." + ) def max_marginal_relevance_search_by_vector( self, @@ -421,9 +428,7 @@ class Chroma(VectorStore): candidates = _results_to_docs(results) - selected_results = [ - r for i, r in enumerate(candidates) if i in mmr_selected - ] + selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] return selected_results def max_marginal_relevance_search( @@ -518,8 +523,10 @@ class Chroma(VectorStore): It will also be called automatically when the object is destroyed. """ if self._persist_directory is None: - raise ValueError("You must specify a persist_directory on" - "creation to persist the collection.") + raise ValueError( + "You must specify a persist_directory on" + "creation to persist the collection." + ) import chromadb # Maintain backwards compatibility with chromadb < 0.4.0 @@ -536,8 +543,7 @@ class Chroma(VectorStore): """ return self.update_documents([document_id], [document]) - def update_documents(self, ids: List[str], - documents: List[Document]) -> None: + def update_documents(self, ids: List[str], documents: List[Document]) -> None: """Update a document in the collection. Args: @@ -552,16 +558,17 @@ class Chroma(VectorStore): ) embeddings = self._embedding_function.embed_documents(text) - if hasattr(self._collection._client, - "max_batch_size"): # for Chroma 0.4.10 and above + if hasattr( + self._collection._client, "max_batch_size" + ): # for Chroma 0.4.10 and above from chromadb.utils.batch_utils import create_batches for batch in create_batches( - api=self._collection._client, - ids=ids, - metadatas=metadata, - documents=text, - embeddings=embeddings, + api=self._collection._client, + ids=ids, + metadatas=metadata, + documents=text, + embeddings=embeddings, ): self._collection.update( ids=batch[0], @@ -621,15 +628,16 @@ class Chroma(VectorStore): ) if ids is None: ids = [str(uuid.uuid1()) for _ in texts] - if hasattr(chroma_collection._client, - "max_batch_size"): # for Chroma 0.4.10 and above + if hasattr( + chroma_collection._client, "max_batch_size" + ): # for Chroma 0.4.10 and above from chromadb.utils.batch_utils import create_batches for batch in create_batches( - api=chroma_collection._client, - ids=ids, - metadatas=metadatas, - documents=texts, + api=chroma_collection._client, + ids=ids, + metadatas=metadatas, + documents=texts, ): chroma_collection.add_texts( texts=batch[3] if batch[3] else [], @@ -637,9 +645,7 @@ class Chroma(VectorStore): ids=batch[0], ) else: - chroma_collection.add_texts(texts=texts, - metadatas=metadatas, - ids=ids) + chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) return chroma_collection @classmethod diff --git a/swarms/memory/cosine_similarity.py b/swarms/memory/cosine_similarity.py index 9b183834..99d47368 100644 --- a/swarms/memory/cosine_similarity.py +++ b/swarms/memory/cosine_similarity.py @@ -19,7 +19,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: if X.shape[1] != Y.shape[1]: raise ValueError( f"Number of columns in X and Y must be the same. X has shape {X.shape} " - f"and Y has shape {Y.shape}.") + f"and Y has shape {Y.shape}." + ) try: import simsimd as simd @@ -32,7 +33,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: except ImportError: logger.info( "Unable to import simsimd, defaulting to NumPy implementation. If you want " - "to use simsimd please install with `pip install simsimd`.") + "to use simsimd please install with `pip install simsimd`." + ) X_norm = np.linalg.norm(X, axis=1) Y_norm = np.linalg.norm(Y, axis=1) # Ignore divide by zero errors run time warnings as those are handled below. diff --git a/swarms/memory/db.py b/swarms/memory/db.py index 8e6bad12..9f23b59f 100644 --- a/swarms/memory/db.py +++ b/swarms/memory/db.py @@ -27,7 +27,6 @@ class NotFoundException(Exception): class TaskDB(ABC): - async def create_task( self, input: Optional[str], @@ -68,9 +67,9 @@ class TaskDB(ABC): async def list_tasks(self) -> List[Task]: raise NotImplementedError - async def list_steps(self, - task_id: str, - status: Optional[Status] = None) -> List[Step]: + async def list_steps( + self, task_id: str, status: Optional[Status] = None + ) -> List[Step]: raise NotImplementedError @@ -137,8 +136,8 @@ class InMemoryTaskDB(TaskDB): async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: task = await self.get_task(task_id) artifact = next( - filter(lambda a: a.artifact_id == artifact_id, task.artifacts), - None) + filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None + ) if not artifact: raise NotFoundException("Artifact", artifact_id) return artifact @@ -151,9 +150,9 @@ class InMemoryTaskDB(TaskDB): step_id: Optional[str] = None, ) -> Artifact: artifact_id = str(uuid.uuid4()) - artifact = Artifact(artifact_id=artifact_id, - file_name=file_name, - relative_path=relative_path) + artifact = Artifact( + artifact_id=artifact_id, file_name=file_name, relative_path=relative_path + ) task = await self.get_task(task_id) task.artifacts.append(artifact) @@ -166,9 +165,9 @@ class InMemoryTaskDB(TaskDB): async def list_tasks(self) -> List[Task]: return [task for task in self._tasks.values()] - async def list_steps(self, - task_id: str, - status: Optional[Status] = None) -> List[Step]: + async def list_steps( + self, task_id: str, status: Optional[Status] = None + ) -> List[Step]: task = await self.get_task(task_id) steps = task.steps if status: diff --git a/swarms/memory/ocean.py b/swarms/memory/ocean.py index 339c3596..da58c81c 100644 --- a/swarms/memory/ocean.py +++ b/swarms/memory/ocean.py @@ -63,7 +63,8 @@ class OceanDB: try: embedding_function = MultiModalEmbeddingFunction(modality=modality) collection = self.client.create_collection( - collection_name, embedding_function=embedding_function) + collection_name, embedding_function=embedding_function + ) return collection except Exception as e: logging.error(f"Failed to create collection. Error {e}") @@ -90,8 +91,7 @@ class OceanDB: try: return collection.add(documents=[document], ids=[id]) except Exception as e: - logging.error( - f"Failed to append document to the collection. Error {e}") + logging.error(f"Failed to append document to the collection. Error {e}") raise def add_documents(self, collection, documents: List[str], ids: List[str]): @@ -137,8 +137,7 @@ class OceanDB: the results of the query """ try: - results = collection.query(query_texts=query_texts, - n_results=n_results) + results = collection.query(query_texts=query_texts, n_results=n_results) return results except Exception as e: logging.error(f"Failed to query the collection. Error {e}") diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index 09534cac..bd768459 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -88,12 +88,12 @@ class PgVectorVectorStore(BaseVectorStore): create_engine_params: dict = field(factory=dict, kw_only=True) engine: Optional[Engine] = field(default=None, kw_only=True) table_name: str = field(kw_only=True) - _model: any = field(default=Factory( - lambda self: self.default_vector_model(), takes_self=True)) + _model: any = field( + default=Factory(lambda self: self.default_vector_model(), takes_self=True) + ) @connection_string.validator - def validate_connection_string(self, _, - connection_string: Optional[str]) -> None: + def validate_connection_string(self, _, connection_string: Optional[str]) -> None: # If an engine is provided, the connection string is not used. if self.engine is not None: return @@ -122,8 +122,9 @@ class PgVectorVectorStore(BaseVectorStore): If not, a connection string is used to create a new database connection here. """ if self.engine is None: - self.engine = create_engine(self.connection_string, - **self.create_engine_params) + self.engine = create_engine( + self.connection_string, **self.create_engine_params + ) def setup( self, @@ -141,12 +142,14 @@ class PgVectorVectorStore(BaseVectorStore): if create_schema: self._model.metadata.create_all(self.engine) - def upsert_vector(self, - vector: list[float], - vector_id: Optional[str] = None, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs) -> str: + def upsert_vector( + self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: """Inserts or updates a vector in the collection.""" with Session(self.engine) as session: obj = self._model( @@ -161,9 +164,9 @@ class PgVectorVectorStore(BaseVectorStore): return str(obj.id) - def load_entry(self, - vector_id: str, - namespace: Optional[str] = None) -> BaseVectorStore.Entry: + def load_entry( + self, vector_id: str, namespace: Optional[str] = None + ) -> BaseVectorStore.Entry: """Retrieves a specific vector entry from the collection based on its identifier and optional namespace.""" with Session(self.engine) as session: result = session.get(self._model, vector_id) @@ -176,8 +179,8 @@ class PgVectorVectorStore(BaseVectorStore): ) def load_entries( - self, - namespace: Optional[str] = None) -> list[BaseVectorStore.Entry]: + self, namespace: Optional[str] = None + ) -> list[BaseVectorStore.Entry]: """Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace. """ @@ -194,16 +197,19 @@ class PgVectorVectorStore(BaseVectorStore): vector=result.vector, namespace=result.namespace, meta=result.meta, - ) for result in results + ) + for result in results ] - def query(self, - query: str, - count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT, - namespace: Optional[str] = None, - include_vectors: bool = False, - distance_metric: str = "cosine_distance", - **kwargs) -> list[BaseVectorStore.QueryResult]: + def query( + self, + query: str, + count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT, + namespace: Optional[str] = None, + include_vectors: bool = False, + distance_metric: str = "cosine_distance", + **kwargs + ) -> list[BaseVectorStore.QueryResult]: """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace. """ @@ -239,7 +245,8 @@ class PgVectorVectorStore(BaseVectorStore): score=result[1], meta=result[0].meta, namespace=result[0].namespace, - ) for result in results + ) + for result in results ] def default_vector_model(self) -> any: diff --git a/swarms/memory/pinecone.py b/swarms/memory/pinecone.py index 0269aa38..2374f12a 100644 --- a/swarms/memory/pinecone.py +++ b/swarms/memory/pinecone.py @@ -102,12 +102,14 @@ class PineconeVectorStoreStore(BaseVector): self.index = pinecone.Index(self.index_name) - def upsert_vector(self, - vector: list[float], - vector_id: Optional[str] = None, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs) -> str: + def upsert_vector( + self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: """Upsert vector""" vector_id = vector_id if vector_id else str_to_hash(str(vector)) @@ -118,12 +120,10 @@ class PineconeVectorStoreStore(BaseVector): return vector_id def load_entry( - self, - vector_id: str, - namespace: Optional[str] = None) -> Optional[BaseVector.Entry]: + self, vector_id: str, namespace: Optional[str] = None + ) -> Optional[BaseVector.Entry]: """Load entry""" - result = self.index.fetch(ids=[vector_id], - namespace=namespace).to_dict() + result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() vectors = list(result["vectors"].values()) if len(vectors) > 0: @@ -138,8 +138,7 @@ class PineconeVectorStoreStore(BaseVector): else: return None - def load_entries(self, - namespace: Optional[str] = None) -> list[BaseVector.Entry]: + def load_entries(self, namespace: Optional[str] = None) -> list[BaseVector.Entry]: """Load entries""" # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching # all values from a namespace: @@ -158,18 +157,20 @@ class PineconeVectorStoreStore(BaseVector): vector=r["values"], meta=r["metadata"], namespace=results["namespace"], - ) for r in results["matches"] + ) + for r in results["matches"] ] def query( - self, - query: str, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - # PineconeVectorStoreStorageDriver-specific params: - include_metadata=True, - **kwargs) -> list[BaseVector.QueryResult]: + self, + query: str, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + # PineconeVectorStoreStorageDriver-specific params: + include_metadata=True, + **kwargs + ) -> list[BaseVector.QueryResult]: """Query vectors""" vector = self.embedding_driver.embed_string(query) @@ -189,14 +190,12 @@ class PineconeVectorStoreStore(BaseVector): score=r["score"], meta=r["metadata"], namespace=results["namespace"], - ) for r in results["matches"] + ) + for r in results["matches"] ] def create_index(self, name: str, **kwargs) -> None: """Create index""" - params = { - "name": name, - "dimension": self.embedding_driver.dimensions - } | kwargs + params = {"name": name, "dimension": self.embedding_driver.dimensions} | kwargs pinecone.create_index(**params) diff --git a/swarms/memory/schemas.py b/swarms/memory/schemas.py index ce54208d..bbc71bc2 100644 --- a/swarms/memory/schemas.py +++ b/swarms/memory/schemas.py @@ -20,9 +20,9 @@ class Artifact(BaseModel): description="Id of the artifact", example="b225e278-8b4c-4f99-a696-8facf19f0e56", ) - file_name: str = Field(..., - description="Filename of the artifact", - example="main.py") + file_name: str = Field( + ..., description="Filename of the artifact", example="main.py" + ) relative_path: Optional[str] = Field( None, description="Relative path of the artifact in the agent's workspace", @@ -50,8 +50,7 @@ class StepInput(BaseModel): class StepOutput(BaseModel): __root__: Any = Field( ..., - description= - "Output that the task step has produced. Any value is allowed.", + description="Output that the task step has produced. Any value is allowed.", example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', ) @@ -82,9 +81,9 @@ class Task(TaskRequestBody): class StepRequestBody(BaseModel): - input: Optional[str] = Field(None, - description="Input prompt for the step.", - example="Washington") + input: Optional[str] = Field( + None, description="Input prompt for the step.", example="Washington" + ) additional_input: Optional[StepInput] = None @@ -105,19 +104,22 @@ class Step(StepRequestBody): description="The ID of the task step.", example="6bb1801a-fd80-45e8-899a-4dd723cc602e", ) - name: Optional[str] = Field(None, - description="The name of the task step.", - example="Write to file") + name: Optional[str] = Field( + None, description="The name of the task step.", example="Write to file" + ) status: Status = Field(..., description="The status of the task step.") output: Optional[str] = Field( None, description="Output of the task step.", - example= - ("I am going to use the write_to_file command and write Washington to a file" - " called output.txt best_score: best_score = equation_score idx_to_add = i @@ -56,8 +57,8 @@ def maximal_marginal_relevance( def filter_complex_metadata( documents: List[Document], *, - allowed_types: Tuple[Type, - ...] = (str, bool, int, float)) -> List[Document]: + allowed_types: Tuple[Type, ...] = (str, bool, int, float) +) -> List[Document]: """Filter out metadata types that are not supported for a vector store.""" updated_documents = [] for document in documents: diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py index 634fa030..30ec22ce 100644 --- a/swarms/models/anthropic.py +++ b/swarms/models/anthropic.py @@ -41,24 +41,21 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: """Validate specified keyword args are mutually exclusive.""" def decorator(func: Callable) -> Callable: - @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: """Validate exactly one arg in each group is not None.""" counts = [ - sum(1 - for arg in arg_group - if kwargs.get(arg) is not None) + sum(1 for arg in arg_group if kwargs.get(arg) is not None) for arg_group in arg_groups ] invalid_groups = [i for i, count in enumerate(counts) if count != 1] if invalid_groups: - invalid_group_names = [ - ", ".join(arg_groups[i]) for i in invalid_groups - ] - raise ValueError("Exactly one argument in each of the following" - " groups must be defined:" - f" {', '.join(invalid_group_names)}") + invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups] + raise ValueError( + "Exactly one argument in each of the following" + " groups must be defined:" + f" {', '.join(invalid_group_names)}" + ) return func(*args, **kwargs) return wrapper @@ -108,10 +105,9 @@ def mock_now(dt_value): # type: ignore datetime.datetime = real_datetime -def guard_import(module_name: str, - *, - pip_name: Optional[str] = None, - package: Optional[str] = None) -> Any: +def guard_import( + module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None +) -> Any: """Dynamically imports a module and raises a helpful exception if the module is not installed.""" try: @@ -119,7 +115,8 @@ def guard_import(module_name: str, except ImportError: raise ImportError( f"Could not import {module_name} python package. " - f"Please install it with `pip install {pip_name or module_name}`.") + f"Please install it with `pip install {pip_name or module_name}`." + ) return module @@ -135,19 +132,23 @@ def check_package_version( if lt_version is not None and imported_version >= parse(lt_version): raise ValueError( f"Expected {package} version to be < {lt_version}. Received " - f"{imported_version}.") + f"{imported_version}." + ) if lte_version is not None and imported_version > parse(lte_version): raise ValueError( f"Expected {package} version to be <= {lte_version}. Received " - f"{imported_version}.") + f"{imported_version}." + ) if gt_version is not None and imported_version <= parse(gt_version): raise ValueError( f"Expected {package} version to be > {gt_version}. Received " - f"{imported_version}.") + f"{imported_version}." + ) if gte_version is not None and imported_version < parse(gte_version): raise ValueError( f"Expected {package} version to be >= {gte_version}. Received " - f"{imported_version}.") + f"{imported_version}." + ) def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: @@ -179,17 +180,19 @@ def build_extra_kwargs( if field_name in extra_kwargs: raise ValueError(f"Found {field_name} supplied twice.") if field_name not in all_required_field_names: - warnings.warn(f"""WARNING! {field_name} is not default parameter. + warnings.warn( + f"""WARNING! {field_name} is not default parameter. {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""") + Please confirm that {field_name} is what you intended.""" + ) extra_kwargs[field_name] = values.pop(field_name) - invalid_model_kwargs = all_required_field_names.intersection( - extra_kwargs.keys()) + invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys()) if invalid_model_kwargs: raise ValueError( f"Parameters {invalid_model_kwargs} should be specified explicitly. " - "Instead they were passed in as part of `model_kwargs` parameter.") + "Instead they were passed in as part of `model_kwargs` parameter." + ) return extra_kwargs @@ -238,16 +241,17 @@ class _AnthropicCommon(BaseLanguageModel): def build_extra(cls, values: Dict) -> Dict: extra = values.get("model_kwargs", {}) all_required_field_names = get_pydantic_field_names(cls) - values["model_kwargs"] = build_extra_kwargs(extra, values, - all_required_field_names) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) return values @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["anthropic_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "anthropic_api_key", - "ANTHROPIC_API_KEY")) + get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") + ) # Get custom api url from environment. values["anthropic_api_url"] = get_from_dict_or_env( values, @@ -277,7 +281,8 @@ class _AnthropicCommon(BaseLanguageModel): except ImportError: raise ImportError( "Could not import anthropic python package. " - "Please it install it with `pip install anthropic`.") + "Please it install it with `pip install anthropic`." + ) return values @property @@ -300,8 +305,7 @@ class _AnthropicCommon(BaseLanguageModel): """Get the identifying parameters.""" return {**{}, **self._default_params} - def _get_anthropic_stop(self, - stop: Optional[List[str]] = None) -> List[str]: + def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]: if not self.HUMAN_PROMPT or not self.AI_PROMPT: raise NameError("Please ensure the anthropic package is loaded") @@ -368,8 +372,7 @@ class Anthropic(LLM, _AnthropicCommon): return prompt # Already wrapped. # Guard against common errors in specifying wrong number of newlines. - corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, - prompt) + corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt) if n_subs == 1: return corrected_prompt @@ -402,10 +405,9 @@ class Anthropic(LLM, _AnthropicCommon): """ if self.streaming: completion = "" - for chunk in self._stream(prompt=prompt, - stop=stop, - run_manager=run_manager, - **kwargs): + for chunk in self._stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): completion += chunk.text return completion @@ -431,10 +433,9 @@ class Anthropic(LLM, _AnthropicCommon): """Call out to Anthropic's completion endpoint asynchronously.""" if self.streaming: completion = "" - async for chunk in self._astream(prompt=prompt, - stop=stop, - run_manager=run_manager, - **kwargs): + async for chunk in self._astream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): completion += chunk.text return completion @@ -475,10 +476,8 @@ class Anthropic(LLM, _AnthropicCommon): params = {**self._default_params, **kwargs} for token in self.client.completions.create( - prompt=self._wrap_prompt(prompt), - stop_sequences=stop, - stream=True, - **params): + prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params + ): chunk = GenerationChunk(text=token.completion) yield chunk if run_manager: @@ -510,10 +509,10 @@ class Anthropic(LLM, _AnthropicCommon): params = {**self._default_params, **kwargs} async for token in await self.async_client.completions.create( - prompt=self._wrap_prompt(prompt), - stop_sequences=stop, - stream=True, - **params, + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + stream=True, + **params, ): chunk = GenerationChunk(text=token.completion) yield chunk diff --git a/swarms/models/bioclip.py b/swarms/models/bioclip.py index d7052ef3..c2b4bfa5 100644 --- a/swarms/models/bioclip.py +++ b/swarms/models/bioclip.py @@ -97,8 +97,9 @@ class BioClip: self.preprocess_val, ) = open_clip.create_model_and_transforms(model_path) self.tokenizer = open_clip.get_tokenizer(model_path) - self.device = (torch.device("cuda") - if torch.cuda.is_available() else torch.device("cpu")) + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) self.model.to(self.device) self.model.eval() @@ -109,17 +110,18 @@ class BioClip: template: str = "this is a photo of ", context_length: int = 256, ): - image = torch.stack([self.preprocess_val(Image.open(img_path)) - ]).to(self.device) - texts = self.tokenizer([template + l for l in labels], - context_length=context_length).to(self.device) + image = torch.stack([self.preprocess_val(Image.open(img_path))]).to(self.device) + texts = self.tokenizer( + [template + l for l in labels], context_length=context_length + ).to(self.device) with torch.no_grad(): - image_features, text_features, logit_scale = self.model( - image, texts) - logits = ((logit_scale * - image_features @ text_features.t()).detach().softmax( - dim=-1)) + image_features, text_features, logit_scale = self.model(image, texts) + logits = ( + (logit_scale * image_features @ text_features.t()) + .detach() + .softmax(dim=-1) + ) sorted_indices = torch.argsort(logits, dim=-1, descending=True) logits = logits.cpu().numpy() sorted_indices = sorted_indices.cpu().numpy() @@ -137,8 +139,11 @@ class BioClip: fig, ax = plt.subplots(figsize=(5, 5)) ax.imshow(img) ax.axis("off") - title = (metadata["filename"] + "\n" + "\n".join( - [f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()])) + title = ( + metadata["filename"] + + "\n" + + "\n".join([f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()]) + ) ax.set_title(title, fontsize=14) plt.tight_layout() plt.show() diff --git a/swarms/models/biogpt.py b/swarms/models/biogpt.py index ebec10b9..83c31e55 100644 --- a/swarms/models/biogpt.py +++ b/swarms/models/biogpt.py @@ -102,9 +102,9 @@ class BioGPT: list[dict]: A list of generated texts. """ set_seed(42) - generator = pipeline("text-generation", - model=self.model, - tokenizer=self.tokenizer) + generator = pipeline( + "text-generation", model=self.model, tokenizer=self.tokenizer + ) out = generator( text, max_length=self.max_length, @@ -149,11 +149,13 @@ class BioGPT: inputs = self.tokenizer(sentence, return_tensors="pt") set_seed(42) with torch.no_grad(): - beam_output = self.model.generate(**inputs, - min_length=self.min_length, - max_length=self.max_length, - num_beams=num_beams, - early_stopping=early_stopping) + beam_output = self.model.generate( + **inputs, + min_length=self.min_length, + max_length=self.max_length, + num_beams=num_beams, + early_stopping=early_stopping + ) return self.tokenizer.decode(beam_output[0], skip_special_tokens=True) # Feature 1: Set a new tokenizer and model diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index 788bae62..c24f262d 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -124,10 +124,13 @@ class Dalle3: # Handling exceptions and printing the errors details print( colored( - (f"Error running Dalle3: {error} try optimizing your api key and" - " or try again"), + ( + f"Error running Dalle3: {error} try optimizing your api key and" + " or try again" + ), "red", - )) + ) + ) raise error def create_variations(self, img: str): @@ -154,19 +157,22 @@ class Dalle3: """ try: - response = self.client.images.create_variation(img=open(img, "rb"), - n=self.n, - size=self.size) + response = self.client.images.create_variation( + img=open(img, "rb"), n=self.n, size=self.size + ) img = response.data[0].url return img except (Exception, openai.OpenAIError) as error: print( colored( - (f"Error running Dalle3: {error} try optimizing your api key and" - " or try again"), + ( + f"Error running Dalle3: {error} try optimizing your api key and" + " or try again" + ), "red", - )) + ) + ) print(colored(f"Error running Dalle3: {error.http_status}", "red")) print(colored(f"Error running Dalle3: {error.error}", "red")) raise error diff --git a/swarms/models/distilled_whisperx.py b/swarms/models/distilled_whisperx.py index 8fc5b99a..0a60aaac 100644 --- a/swarms/models/distilled_whisperx.py +++ b/swarms/models/distilled_whisperx.py @@ -18,7 +18,6 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1): """ def decorator(func): - @wraps(func) async def wrapper(*args, **kwargs): retries = max_retries @@ -29,9 +28,7 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1): retries -= 1 if retries <= 0: raise - print( - f"Retry after exception: {e}, Attempts remaining: {retries}" - ) + print(f"Retry after exception: {e}, Attempts remaining: {retries}") await asyncio.sleep(delay) return wrapper @@ -65,8 +62,7 @@ class DistilWhisperModel: def __init__(self, model_id="distil-whisper/distil-large-v2"): self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - self.torch_dtype = torch.float16 if torch.cuda.is_available( - ) else torch.float32 + self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 self.model_id = model_id self.model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, @@ -123,14 +119,14 @@ class DistilWhisperModel: try: with torch.no_grad(): # Load the whole audio file, but process and transcribe it in chunks - audio_input = self.processor.audio_file_to_array( - audio_file_path) + audio_input = self.processor.audio_file_to_array(audio_file_path) sample_rate = audio_input.sampling_rate total_duration = len(audio_input.array) / sample_rate chunks = [ - audio_input.array[i:i + sample_rate * chunk_duration] - for i in range(0, len(audio_input.array), sample_rate * - chunk_duration) + audio_input.array[i : i + sample_rate * chunk_duration] + for i in range( + 0, len(audio_input.array), sample_rate * chunk_duration + ) ] print(colored("Starting real-time transcription...", "green")) @@ -143,22 +139,22 @@ class DistilWhisperModel: return_tensors="pt", padding=True, ) - processed_inputs = processed_inputs.input_values.to( - self.device) + processed_inputs = processed_inputs.input_values.to(self.device) # Generate transcription for the chunk logits = self.model.generate(processed_inputs) transcription = self.processor.batch_decode( - logits, skip_special_tokens=True)[0] + logits, skip_special_tokens=True + )[0] # Print the chunk's transcription print( - colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") + - transcription) + colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") + + transcription + ) # Wait for the chunk's duration to simulate real-time processing time.sleep(chunk_duration) except Exception as e: - print(colored(f"An error occurred during transcription: {e}", - "red")) + print(colored(f"An error occurred during transcription: {e}", "red")) diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py index 370569fb..a2d6bc0a 100644 --- a/swarms/models/fastvit.py +++ b/swarms/models/fastvit.py @@ -11,8 +11,7 @@ from pydantic import BaseModel, StrictFloat, StrictInt, validator DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load the classes for image classification -with open(os.path.join(os.path.dirname(__file__), - "fast_vit_classes.json")) as f: +with open(os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")) as f: FASTVIT_IMAGENET_1K_CLASSES = json.load(f) @@ -22,8 +21,7 @@ class ClassificationResult(BaseModel): @validator("class_id", "confidence", pre=True, each_item=True) def check_list_contents(cls, v): - assert isinstance(v, int) or isinstance( - v, float), "must be integer or float" + assert isinstance(v, int) or isinstance(v, float), "must be integer or float" return v @@ -49,16 +47,16 @@ class FastViT: """ def __init__(self): - self.model = timm.create_model("hf_hub:timm/fastvit_s12.apple_in1k", - pretrained=True).to(DEVICE) + self.model = timm.create_model( + "hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True + ).to(DEVICE) data_config = timm.data.resolve_model_data_config(self.model) - self.transforms = timm.data.create_transform(**data_config, - is_training=False) + self.transforms = timm.data.create_transform(**data_config, is_training=False) self.model.eval() - def __call__(self, - img: str, - confidence_threshold: float = 0.5) -> ClassificationResult: + def __call__( + self, img: str, confidence_threshold: float = 0.5 + ) -> ClassificationResult: """classifies the input image and returns the top k classes and their probabilities""" img = Image.open(img).convert("RGB") img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE) @@ -67,8 +65,9 @@ class FastViT: probabilities = torch.nn.functional.softmax(output, dim=1) # Get top k classes and their probabilities - top_probs, top_classes = torch.topk(probabilities, - k=FASTVIT_IMAGENET_1K_CLASSES) + top_probs, top_classes = torch.topk( + probabilities, k=FASTVIT_IMAGENET_1K_CLASSES + ) # Filter by confidence threshold mask = top_probs > confidence_threshold diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py index 63108835..dd664f51 100644 --- a/swarms/models/fuyu.py +++ b/swarms/models/fuyu.py @@ -45,9 +45,9 @@ class Fuyu: self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path) self.image_processor = FuyuImageProcessor() - self.processor = FuyuProcessor(image_processor=self.image_processor, - tokenizer=self.tokenizer, - **kwargs) + self.processor = FuyuProcessor( + image_processor=self.image_processor, tokenizer=self.tokenizer, **kwargs + ) self.model = FuyuForCausalLM.from_pretrained( pretrained_path, device_map=device_map, @@ -62,17 +62,15 @@ class Fuyu: def __call__(self, text: str, img: str): """Call the model with text and img paths""" image_pil = Image.open(img) - model_inputs = self.processor(text=text, - images=[image_pil], - device=self.device_map) + model_inputs = self.processor( + text=text, images=[image_pil], device=self.device_map + ) for k, v in model_inputs.items(): model_inputs[k] = v.to(self.device_map) - output = self.model.generate(**model_inputs, - max_new_tokens=self.max_new_tokens) - text = self.processor.batch_decode(output[:, -7:], - skip_special_tokens=True) + output = self.model.generate(**model_inputs, max_new_tokens=self.max_new_tokens) + text = self.processor.batch_decode(output[:, -7:], skip_special_tokens=True) return print(str(text)) def get_img_from_web(self, img_url: str): diff --git a/swarms/models/gpt4v.py b/swarms/models/gpt4v.py index 251744e8..d1d5ce1f 100644 --- a/swarms/models/gpt4v.py +++ b/swarms/models/gpt4v.py @@ -69,7 +69,9 @@ class GPT4Vision: quality: str = "low" # Max tokens to use for the API request, the maximum might be 3,000 but we don't know max_tokens: int = 200 - client = OpenAI(api_key=openai_api_key,) + client = OpenAI( + api_key=openai_api_key, + ) dashboard: bool = True call_limit: int = 1 period_seconds: int = 60 @@ -88,8 +90,9 @@ class GPT4Vision: return base64.b64encode(image_file.read()).decode("utf-8") @sleep_and_retry - @limits(calls=call_limit, - period=period_seconds) # Rate limit of 10 calls per minute + @limits( + calls=call_limit, period=period_seconds + ) # Rate limit of 10 calls per minute def run(self, task: str, img: str): """ Run the GPT-4 Vision model @@ -105,22 +108,20 @@ class GPT4Vision: try: response = self.client.chat.completions.create( model="gpt-4-vision-preview", - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": task - }, - { - "type": "image_url", - "image_url": { - "url": str(img), + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": task}, + { + "type": "image_url", + "image_url": { + "url": str(img), + }, }, - }, - ], - }], + ], + } + ], max_tokens=self.max_tokens, ) @@ -160,22 +161,20 @@ class GPT4Vision: try: response = await self.client.chat.completions.create( model="gpt-4-vision-preview", - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": task - }, - { - "type": "image_url", - "image_url": { - "url": img, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": task}, + { + "type": "image_url", + "image_url": { + "url": img, + }, }, - }, - ], - }], + ], + } + ], max_tokens=self.max_tokens, ) @@ -190,14 +189,12 @@ class GPT4Vision: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ - executor.submit(self.run, task, img) - for task, img in tasks_images + executor.submit(self.run, task, img) for task, img in tasks_images ] results = [future.result() for future in futures] return results - async def run_batch_async(self, - tasks_images: List[Tuple[str, str]]) -> List[str]: + async def run_batch_async(self, tasks_images: List[Tuple[str, str]]) -> List[str]: """Process a batch of tasks and images asynchronously""" loop = asyncio.get_event_loop() futures = [ @@ -207,7 +204,8 @@ class GPT4Vision: return await asyncio.gather(*futures) async def run_batch_async_with_retries( - self, tasks_images: List[Tuple[str, str]]) -> List[str]: + self, tasks_images: List[Tuple[str, str]] + ) -> List[str]: """Process a batch of tasks and images asynchronously with retries""" loop = asyncio.get_event_loop() futures = [ @@ -231,7 +229,8 @@ class GPT4Vision: """, "green", - )) + ) + ) return dashboard def health_check(self): diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py index a84cc960..9279fea4 100644 --- a/swarms/models/huggingface.py +++ b/swarms/models/huggingface.py @@ -47,8 +47,9 @@ class HuggingfaceLLM: **kwargs, ): self.logger = logging.getLogger(__name__) - self.device = (device if device else - ("cuda" if torch.cuda.is_available() else "cpu")) + self.device = ( + device if device else ("cuda" if torch.cuda.is_available() else "cpu") + ) self.model_id = model_id self.max_length = max_length self.verbose = verbose @@ -57,8 +58,9 @@ class HuggingfaceLLM: self.model, self.tokenizer = None, None if self.distributed: - assert (torch.cuda.device_count() > - 1), "You need more than 1 gpu for distributed processing" + assert ( + torch.cuda.device_count() > 1 + ), "You need more than 1 gpu for distributed processing" bnb_config = None if quantize: @@ -73,17 +75,17 @@ class HuggingfaceLLM: try: self.tokenizer = AutoTokenizer.from_pretrained( - self.model_id, *args, **kwargs) + self.model_id, *args, **kwargs + ) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config, *args, **kwargs) + self.model_id, quantization_config=bnb_config, *args, **kwargs + ) self.model # .to(self.device) except Exception as e: # self.logger.error(f"Failed to load the model or the tokenizer: {e}") # raise - print( - colored(f"Failed to load the model and or the tokenizer: {e}", - "red")) + print(colored(f"Failed to load the model and or the tokenizer: {e}", "red")) def print_error(self, error: str): """Print error""" @@ -95,18 +97,20 @@ class HuggingfaceLLM: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - bnb_config = (BitsAndBytesConfig(**self.quantization_config) - if self.quantization_config else None) + bnb_config = ( + BitsAndBytesConfig(**self.quantization_config) + if self.quantization_config + else None + ) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, - quantization_config=bnb_config).to(self.device) + self.model_id, quantization_config=bnb_config + ).to(self.device) if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error( - f"Failed to load the model or the tokenizer: {error}") + self.logger.error(f"Failed to load the model or the tokenizer: {error}") raise def run(self, task: str): @@ -127,8 +131,7 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, - return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) # self.log.start() @@ -137,36 +140,39 @@ class HuggingfaceLLM: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate(inputs, - max_length=len(inputs) + - 1, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=len(inputs) + 1, do_sample=True + ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode([output_tokens], - skip_special_tokens=True), + self.tokenizer.decode( + [output_tokens], skip_special_tokens=True + ), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate(inputs, - max_length=max_length, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=max_length, do_sample=True + ) del inputs return self.tokenizer.decode(outputs[0], skip_special_tokens=True) except Exception as e: print( colored( - (f"HuggingfaceLLM could not generate text because of error: {e}," - " try optimizing your arguments"), + ( + f"HuggingfaceLLM could not generate text because of error: {e}," + " try optimizing your arguments" + ), "red", - )) + ) + ) raise async def run_async(self, task: str, *args, **kwargs) -> str: @@ -210,8 +216,7 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, - return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) # self.log.start() @@ -220,26 +225,26 @@ class HuggingfaceLLM: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate(inputs, - max_length=len(inputs) + - 1, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=len(inputs) + 1, do_sample=True + ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode([output_tokens], - skip_special_tokens=True), + self.tokenizer.decode( + [output_tokens], skip_special_tokens=True + ), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate(inputs, - max_length=max_length, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=max_length, do_sample=True + ) del inputs @@ -300,7 +305,8 @@ class HuggingfaceLLM: """, "red", - )) + ) + ) print(dashboard) diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index 41b8823d..73cb4991 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -65,8 +65,9 @@ class Idefics: torch_dtype=torch.bfloat16, max_length=100, ): - self.device = (device if device else - ("cuda" if torch.cuda.is_available() else "cpu")) + self.device = ( + device if device else ("cuda" if torch.cuda.is_available() else "cpu") + ) self.model = IdeficsForVisionText2Text.from_pretrained( checkpoint, torch_dtype=torch_dtype, @@ -95,17 +96,21 @@ class Idefics: list A list of generated text strings. """ - inputs = (self.processor( - prompts, add_end_of_utterance_token=False, return_tensors="pt").to( - self.device) if batched_mode else self.processor( - prompts[0], return_tensors="pt").to(self.device)) + inputs = ( + self.processor( + prompts, add_end_of_utterance_token=False, return_tensors="pt" + ).to(self.device) + if batched_mode + else self.processor(prompts[0], return_tensors="pt").to(self.device) + ) exit_condition = self.processor.tokenizer( - "", add_special_tokens=False).input_ids + "", add_special_tokens=False + ).input_ids bad_words_ids = self.processor.tokenizer( - ["", "", "", add_special_tokens=False).input_ids + "", add_special_tokens=False + ).input_ids bad_words_ids = self.processor.tokenizer( - ["", "", " - 1), "You need more than 1 gpu for distributed processing" + assert ( + torch.cuda.device_count() > 1 + ), "You need more than 1 gpu for distributed processing" bnb_config = None if quantize: @@ -81,9 +83,8 @@ class JinaEmbeddings: try: self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, - quantization_config=bnb_config, - trust_remote_code=True) + self.model_id, quantization_config=bnb_config, trust_remote_code=True + ) self.model # .to(self.device) except Exception as e: @@ -96,8 +97,11 @@ class JinaEmbeddings: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - bnb_config = (BitsAndBytesConfig(**self.quantization_config) - if self.quantization_config else None) + bnb_config = ( + BitsAndBytesConfig(**self.quantization_config) + if self.quantization_config + else None + ) self.model = AutoModelForCausalLM.from_pretrained( self.model_id, @@ -108,8 +112,7 @@ class JinaEmbeddings: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error( - f"Failed to load the model or the tokenizer: {error}") + self.logger.error(f"Failed to load the model or the tokenizer: {error}") raise def run(self, task: str): diff --git a/swarms/models/kosmos2.py b/swarms/models/kosmos2.py index 9a1eafba..12d5638a 100644 --- a/swarms/models/kosmos2.py +++ b/swarms/models/kosmos2.py @@ -14,8 +14,11 @@ class Detections(BaseModel): @root_validator def check_length(cls, values): - assert (len(values.get("xyxy")) == len(values.get("class_id")) == len( - values.get("confidence"))), "All fields must have the same length." + assert ( + len(values.get("xyxy")) + == len(values.get("class_id")) + == len(values.get("confidence")) + ), "All fields must have the same length." return values @validator("xyxy", "class_id", "confidence", pre=True, each_item=True) @@ -36,9 +39,11 @@ class Kosmos2(BaseModel): @classmethod def initialize(cls): model = AutoModelForVision2Seq.from_pretrained( - "ydshieh/kosmos-2-patch14-224", trust_remote_code=True) + "ydshieh/kosmos-2-patch14-224", trust_remote_code=True + ) processor = AutoProcessor.from_pretrained( - "ydshieh/kosmos-2-patch14-224", trust_remote_code=True) + "ydshieh/kosmos-2-patch14-224", trust_remote_code=True + ) return cls(model=model, processor=processor) def __call__(self, img: str) -> Detections: @@ -46,12 +51,11 @@ class Kosmos2(BaseModel): prompt = "An image of" inputs = self.processor(text=prompt, images=image, return_tensors="pt") - outputs = self.model.generate(**inputs, - use_cache=True, - max_new_tokens=64) + outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64) - generated_text = self.processor.batch_decode( - outputs, skip_special_tokens=True)[0] + generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[ + 0 + ] # The actual processing of generated_text to entities would go here # For the purpose of this example, assume a mock function 'extract_entities' exists: @@ -62,8 +66,8 @@ class Kosmos2(BaseModel): return detections def extract_entities( - self, - text: str) -> List[Tuple[str, Tuple[float, float, float, float]]]: + self, text: str + ) -> List[Tuple[str, Tuple[float, float, float, float]]]: # Placeholder function for entity extraction # This should be replaced with the actual method of extracting entities return [] @@ -76,19 +80,19 @@ class Kosmos2(BaseModel): if not entities: return Detections.empty() - class_ids = [0] * len( - entities) # Replace with actual class ID extraction logic - xyxys = [( - e[1][0] * image.width, - e[1][1] * image.height, - e[1][2] * image.width, - e[1][3] * image.height, - ) for e in entities] + class_ids = [0] * len(entities) # Replace with actual class ID extraction logic + xyxys = [ + ( + e[1][0] * image.width, + e[1][1] * image.height, + e[1][2] * image.width, + e[1][3] * image.height, + ) + for e in entities + ] confidences = [1.0] * len(entities) # Placeholder confidence - return Detections(xyxy=xyxys, - class_id=class_ids, - confidence=confidences) + return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences) # Usage: diff --git a/swarms/models/kosmos_two.py b/swarms/models/kosmos_two.py index 402ad73d..596886f3 100644 --- a/swarms/models/kosmos_two.py +++ b/swarms/models/kosmos_two.py @@ -46,9 +46,11 @@ class Kosmos: model_name="ydshieh/kosmos-2-patch14-224", ): self.model = AutoModelForVision2Seq.from_pretrained( - model_name, trust_remote_code=True) - self.processor = AutoProcessor.from_pretrained(model_name, - trust_remote_code=True) + model_name, trust_remote_code=True + ) + self.processor = AutoProcessor.from_pretrained( + model_name, trust_remote_code=True + ) def get_image(self, url): """Image""" @@ -71,7 +73,8 @@ class Kosmos: skip_special_tokens=True, )[0] processed_text, entities = self.processor.post_process_generation( - generated_texts) + generated_texts + ) def __call__(self, prompt, image): """Run call""" @@ -90,7 +93,8 @@ class Kosmos: skip_special_tokens=True, )[0] processed_text, entities = self.processor.post_process_generation( - generated_texts) + generated_texts + ) # tasks def multimodal_grounding(self, phrase, image_url): @@ -141,10 +145,12 @@ class Kosmos: elif isinstance(image, torch.Tensor): # pdb.set_trace() image_tensor = image.cpu() - reverse_norm_mean = torch.tensor( - [0.48145466, 0.4578275, 0.40821073])[:, None, None] - reverse_norm_std = torch.tensor( - [0.26862954, 0.26130258, 0.27577711])[:, None, None] + reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[ + :, None, None + ] + reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[ + :, None, None + ] image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean pil_img = T.ToPILImage()(image_tensor) image_h = pil_img.height @@ -163,9 +169,9 @@ class Kosmos: # thickness of text text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1)) box_line = 3 - (c_width, text_height), _ = cv2.getTextSize("F", - cv2.FONT_HERSHEY_COMPLEX, - text_size, text_line) + (c_width, text_height), _ = cv2.getTextSize( + "F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line + ) base_height = int(text_height * 0.675) text_offset_original = text_height - base_height text_spaces = 3 @@ -181,8 +187,9 @@ class Kosmos: # draw bbox # random color color = tuple(np.random.randint(0, 255, size=3).tolist()) - new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), - (orig_x2, orig_y2), color, box_line) + new_image = cv2.rectangle( + new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line + ) l_o, r_o = ( box_line // 2 + box_line % 2, @@ -193,15 +200,19 @@ class Kosmos: y1 = orig_y1 - l_o if y1 < text_height + text_offset_original + 2 * text_spaces: - y1 = (orig_y1 + r_o + text_height + text_offset_original + - 2 * text_spaces) + y1 = ( + orig_y1 + + r_o + + text_height + + text_offset_original + + 2 * text_spaces + ) x1 = orig_x1 + r_o # add text background - (text_width, - text_height), _ = cv2.getTextSize(f" {entity_name}", - cv2.FONT_HERSHEY_COMPLEX, - text_size, text_line) + (text_width, text_height), _ = cv2.getTextSize( + f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line + ) text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = ( x1, y1 - (text_height + text_offset_original + 2 * text_spaces), @@ -211,19 +222,23 @@ class Kosmos: for prev_bbox in previous_bboxes: while is_overlapping( - (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), - prev_bbox): - text_bg_y1 += (text_height + text_offset_original + - 2 * text_spaces) - text_bg_y2 += (text_height + text_offset_original + - 2 * text_spaces) + (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox + ): + text_bg_y1 += ( + text_height + text_offset_original + 2 * text_spaces + ) + text_bg_y2 += ( + text_height + text_offset_original + 2 * text_spaces + ) y1 += text_height + text_offset_original + 2 * text_spaces if text_bg_y2 >= image_h: text_bg_y1 = max( 0, - image_h - (text_height + text_offset_original + - 2 * text_spaces), + image_h + - ( + text_height + text_offset_original + 2 * text_spaces + ), ) text_bg_y2 = image_h y1 = image_h @@ -240,9 +255,9 @@ class Kosmos: # white bg_color = [255, 255, 255] new_image[i, j] = ( - alpha * new_image[i, j] + - (1 - alpha) * np.array(bg_color)).astype( - np.uint8) + alpha * new_image[i, j] + + (1 - alpha) * np.array(bg_color) + ).astype(np.uint8) cv2.putText( new_image, @@ -255,8 +270,7 @@ class Kosmos: cv2.LINE_AA, ) # previous_locations.append((x1, y1)) - previous_bboxes.append( - (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2)) + previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2)) pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]]) if save_path: diff --git a/swarms/models/llava.py b/swarms/models/llava.py index 7f49ad4a..6f8019bc 100644 --- a/swarms/models/llava.py +++ b/swarms/models/llava.py @@ -48,8 +48,9 @@ class MultiModalLlava: revision=revision, ).to(self.device) - self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, - use_fast=True) + self.tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, use_fast=True + ) self.pipe = pipeline( "text-generation", model=self.model, diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py index f14d9e39..7f48a0d6 100644 --- a/swarms/models/mistral.py +++ b/swarms/models/mistral.py @@ -49,8 +49,7 @@ class Mistral: # Check if the specified device is available if not torch.cuda.is_available() and device == "cuda": - raise ValueError( - "CUDA is not available. Please choose a different device.") + raise ValueError("CUDA is not available. Please choose a different device.") # Load the model and tokenizer self.model = None @@ -71,8 +70,7 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], - return_tensors="pt").to(self.device) + model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, @@ -89,8 +87,7 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], - return_tensors="pt").to(self.device) + model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, diff --git a/swarms/models/mpt.py b/swarms/models/mpt.py index 9fb6c90b..035e2b54 100644 --- a/swarms/models/mpt.py +++ b/swarms/models/mpt.py @@ -26,10 +26,7 @@ class MPT7B: """ - def __init__(self, - model_name: str, - tokenizer_name: str, - max_tokens: int = 100): + def __init__(self, model_name: str, tokenizer_name: str, max_tokens: int = 100): # Loading model and tokenizer details self.model_name = model_name self.tokenizer_name = tokenizer_name @@ -40,9 +37,11 @@ class MPT7B: self.logger = logging.getLogger(__name__) config = AutoModelForCausalLM.from_pretrained( - model_name, trust_remote_code=True).config + model_name, trust_remote_code=True + ).config self.model = AutoModelForCausalLM.from_pretrained( - model_name, config=config, trust_remote_code=True) + model_name, config=config, trust_remote_code=True + ) # Initializing a text-generation pipeline self.pipe = pipeline( @@ -115,10 +114,9 @@ class MPT7B: """ with torch.autocast("cuda", dtype=torch.bfloat16): - return self.pipe(prompt, - max_new_tokens=self.max_tokens, - do_sample=True, - use_cache=True)[0]["generated_text"] + return self.pipe( + prompt, max_new_tokens=self.max_tokens, do_sample=True, use_cache=True + )[0]["generated_text"] async def generate_async(self, prompt: str) -> str: """Generate Async""" diff --git a/swarms/models/nougat.py b/swarms/models/nougat.py index 4de1d952..f156981c 100644 --- a/swarms/models/nougat.py +++ b/swarms/models/nougat.py @@ -41,10 +41,8 @@ class Nougat: self.min_length = min_length self.max_new_tokens = max_new_tokens - self.processor = NougatProcessor.from_pretrained( - self.model_name_or_path) - self.model = VisionEncoderDecoderModel.from_pretrained( - self.model_name_or_path) + self.processor = NougatProcessor.from_pretrained(self.model_name_or_path) + self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) @@ -65,10 +63,8 @@ class Nougat: max_new_tokens=self.max_new_tokens, ) - sequence = self.processor.batch_decode(outputs, - skip_special_tokens=True)[0] - sequence = self.processor.post_process_generation(sequence, - fix_markdown=False) + sequence = self.processor.batch_decode(outputs, skip_special_tokens=True)[0] + sequence = self.processor.post_process_generation(sequence, fix_markdown=False) out = print(sequence) return out @@ -76,7 +72,8 @@ class Nougat: def clean_nougat_output(raw_output): # Define the pattern to extract the relevant data daily_balance_pattern = ( - r"\*\*(\d{2}/\d{2}/\d{4})\*\*\n\n\*\*([\d,]+\.\d{2})\*\*") + r"\*\*(\d{2}/\d{2}/\d{4})\*\*\n\n\*\*([\d,]+\.\d{2})\*\*" + ) # Find all matches of the pattern matches = re.findall(daily_balance_pattern, raw_output) diff --git a/swarms/models/openai_assistant.py b/swarms/models/openai_assistant.py index 37b41191..6d0c518f 100644 --- a/swarms/models/openai_assistant.py +++ b/swarms/models/openai_assistant.py @@ -55,9 +55,9 @@ class OpenAIAssistant: return thread def add_message_to_thread(self, thread_id: str, message: str): - message = self.client.beta.threads.add_message(thread_id=thread_id, - role=self.user, - content=message) + message = self.client.beta.threads.add_message( + thread_id=thread_id, role=self.user, content=message + ) return message def run(self, task: str): @@ -67,7 +67,8 @@ class OpenAIAssistant: instructions=self.instructions, ) - out = self.client.beta.threads.runs.retrieve(thread_id=run.thread_id, - run_id=run.id) + out = self.client.beta.threads.runs.retrieve( + thread_id=run.thread_id, run_id=run.id + ) return out diff --git a/swarms/models/openai_embeddings.py b/swarms/models/openai_embeddings.py index 8eeb009d..81dea550 100644 --- a/swarms/models/openai_embeddings.py +++ b/swarms/models/openai_embeddings.py @@ -28,10 +28,9 @@ from tenacity import ( from swarms.models.embeddings_base import Embeddings -def get_from_dict_or_env(values: dict, - key: str, - env_key: str, - default: Any = None) -> Any: +def get_from_dict_or_env( + values: dict, key: str, env_key: str, default: Any = None +) -> Any: import os return values.get(key) or os.getenv(env_key) or default @@ -44,8 +43,7 @@ def get_pydantic_field_names(cls: Any) -> Set[str]: logger = logging.getLogger(__name__) -def _create_retry_decorator( - embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]: +def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]: import llm min_seconds = 4 @@ -56,11 +54,13 @@ def _create_retry_decorator( reraise=True, stop=stop_after_attempt(embeddings.max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=(retry_if_exception_type(llm.error.Timeout) | - retry_if_exception_type(llm.error.APIError) | - retry_if_exception_type(llm.error.APIConnectionError) | - retry_if_exception_type(llm.error.RateLimitError) | - retry_if_exception_type(llm.error.ServiceUnavailableError)), + retry=( + retry_if_exception_type(llm.error.Timeout) + | retry_if_exception_type(llm.error.APIError) + | retry_if_exception_type(llm.error.APIConnectionError) + | retry_if_exception_type(llm.error.RateLimitError) + | retry_if_exception_type(llm.error.ServiceUnavailableError) + ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -76,16 +76,17 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any: reraise=True, stop=stop_after_attempt(embeddings.max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=(retry_if_exception_type(llm.error.Timeout) | - retry_if_exception_type(llm.error.APIError) | - retry_if_exception_type(llm.error.APIConnectionError) | - retry_if_exception_type(llm.error.RateLimitError) | - retry_if_exception_type(llm.error.ServiceUnavailableError)), + retry=( + retry_if_exception_type(llm.error.Timeout) + | retry_if_exception_type(llm.error.APIError) + | retry_if_exception_type(llm.error.APIConnectionError) + | retry_if_exception_type(llm.error.RateLimitError) + | retry_if_exception_type(llm.error.ServiceUnavailableError) + ), before_sleep=before_sleep_log(logger, logging.WARNING), ) def wrap(func: Callable) -> Callable: - async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: async for _ in async_retrying: return await func(*args, **kwargs) @@ -117,8 +118,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: return _embed_with_retry(**kwargs) -async def async_embed_with_retry(embeddings: OpenAIEmbeddings, - **kwargs: Any) -> Any: +async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: """Use tenacity to retry the embedding call.""" @_async_retry_decorator(embeddings) @@ -225,11 +225,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): warnings.warn( f"""WARNING! {field_name} is not default parameter. {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""") + Please confirm that {field_name} is what you intended.""" + ) extra[field_name] = values.pop(field_name) - invalid_model_kwargs = all_required_field_names.intersection( - extra.keys()) + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) if invalid_model_kwargs: raise ValueError( f"Parameters {invalid_model_kwargs} should be specified explicitly. " @@ -242,9 +242,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["openai_api_key"] = get_from_dict_or_env(values, - "openai_api_key", - "OPENAI_API_KEY") + values["openai_api_key"] = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) values["openai_api_base"] = get_from_dict_or_env( values, "openai_api_base", @@ -284,8 +284,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings): values["client"] = llm.Embedding except ImportError: - raise ImportError("Could not import openai python package. " - "Please install it with `pip install openai`.") + raise ImportError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) return values @property @@ -313,11 +315,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return openai_args def _get_len_safe_embeddings( - self, - texts: List[str], - *, - engine: str, - chunk_size: Optional[int] = None) -> List[List[float]]: + self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + ) -> List[List[float]]: embeddings: List[List[float]] = [[] for _ in range(len(texts))] try: import tiktoken @@ -325,7 +324,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): raise ImportError( "Could not import tiktoken python package. " "This is needed in order to for OpenAIEmbeddings. " - "Please install it with `pip install tiktoken`.") + "Please install it with `pip install tiktoken`." + ) tokens = [] indices = [] @@ -333,8 +333,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning( - "Warning: model not found. Using cl100k_base encoding.") + logger.warning("Warning: model not found. Using cl100k_base encoding.") model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, text in enumerate(texts): @@ -348,7 +347,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append(token[j:j + self.embedding_ctx_length]) + tokens.append(token[j : j + self.embedding_ctx_length]) indices.append(i) batched_embeddings: List[List[float]] = [] @@ -367,7 +366,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in _iter: response = embed_with_retry( self, - input=tokens[i:i + _chunk_size], + input=tokens[i : i + _chunk_size], **self._invocation_params, ) batched_embeddings.extend(r["embedding"] for r in response["data"]) @@ -385,11 +384,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): self, input="", **self._invocation_params, - )["data"][0]["embedding"] + )[ + "data" + ][0]["embedding"] else: - average = np.average(_result, - axis=0, - weights=num_tokens_in_batch[i]) + average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings @@ -397,11 +396,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb async def _aget_len_safe_embeddings( - self, - texts: List[str], - *, - engine: str, - chunk_size: Optional[int] = None) -> List[List[float]]: + self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + ) -> List[List[float]]: embeddings: List[List[float]] = [[] for _ in range(len(texts))] try: import tiktoken @@ -409,7 +405,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): raise ImportError( "Could not import tiktoken python package. " "This is needed in order to for OpenAIEmbeddings. " - "Please install it with `pip install tiktoken`.") + "Please install it with `pip install tiktoken`." + ) tokens = [] indices = [] @@ -417,8 +414,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning( - "Warning: model not found. Using cl100k_base encoding.") + logger.warning("Warning: model not found. Using cl100k_base encoding.") model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, text in enumerate(texts): @@ -432,7 +428,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append(token[j:j + self.embedding_ctx_length]) + tokens.append(token[j : j + self.embedding_ctx_length]) indices.append(i) batched_embeddings: List[List[float]] = [] @@ -440,7 +436,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in range(0, len(tokens), _chunk_size): response = await async_embed_with_retry( self, - input=tokens[i:i + _chunk_size], + input=tokens[i : i + _chunk_size], **self._invocation_params, ) batched_embeddings.extend(r["embedding"] for r in response["data"]) @@ -454,22 +450,22 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in range(len(texts)): _result = results[i] if len(_result) == 0: - average = (await async_embed_with_retry( - self, - input="", - **self._invocation_params, - ))["data"][0]["embedding"] + average = ( + await async_embed_with_retry( + self, + input="", + **self._invocation_params, + ) + )["data"][0]["embedding"] else: - average = np.average(_result, - axis=0, - weights=num_tokens_in_batch[i]) + average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings - def embed_documents(self, - texts: List[str], - chunk_size: Optional[int] = 0) -> List[List[float]]: + def embed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: """Call out to OpenAI's embedding endpoint for embedding search docs. Args: @@ -485,9 +481,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return self._get_len_safe_embeddings(texts, engine=self.deployment) async def aembed_documents( - self, - texts: List[str], - chunk_size: Optional[int] = 0) -> List[List[float]]: + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: """Call out to OpenAI's embedding endpoint async for embedding search docs. Args: @@ -500,8 +495,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. - return await self._aget_len_safe_embeddings(texts, - engine=self.deployment) + return await self._aget_len_safe_embeddings(texts, engine=self.deployment) def embed_query(self, text: str) -> List[float]: """Call out to OpenAI's embedding endpoint for embedding query text. diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index 128169a3..4b0cc91d 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -33,8 +33,9 @@ from langchain.utils.utils import build_extra_kwargs logger = logging.getLogger(__name__) -def update_token_usage(keys: Set[str], response: Dict[str, Any], - token_usage: Dict[str, Any]) -> None: +def update_token_usage( + keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any] +) -> None: """Update token usage.""" _keys_to_use = keys.intersection(response["usage"]) for _key in _keys_to_use: @@ -45,42 +46,44 @@ def update_token_usage(keys: Set[str], response: Dict[str, Any], def _stream_response_to_generation_chunk( - stream_response: Dict[str, Any],) -> GenerationChunk: + stream_response: Dict[str, Any], +) -> GenerationChunk: """Convert a stream response to a generation chunk.""" return GenerationChunk( text=stream_response["choices"][0]["text"], generation_info=dict( - finish_reason=stream_response["choices"][0].get( - "finish_reason", None), + finish_reason=stream_response["choices"][0].get("finish_reason", None), logprobs=stream_response["choices"][0].get("logprobs", None), ), ) -def _update_response(response: Dict[str, Any], - stream_response: Dict[str, Any]) -> None: +def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None: """Update response from the stream response.""" response["choices"][0]["text"] += stream_response["choices"][0]["text"] response["choices"][0]["finish_reason"] = stream_response["choices"][0].get( - "finish_reason", None) - response["choices"][0]["logprobs"] = stream_response["choices"][0][ - "logprobs"] + "finish_reason", None + ) + response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"] def _streaming_response_template() -> Dict[str, Any]: return { - "choices": [{ - "text": "", - "finish_reason": None, - "logprobs": None, - }] + "choices": [ + { + "text": "", + "finish_reason": None, + "logprobs": None, + } + ] } def _create_retry_decorator( llm: Union[BaseOpenAI, OpenAIChat], - run_manager: Optional[Union[AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun]] = None, + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, ) -> Callable[[Any], Any]: import openai @@ -91,9 +94,9 @@ def _create_retry_decorator( openai.error.RateLimitError, openai.error.ServiceUnavailableError, ] - return create_base_retry_decorator(error_types=errors, - max_retries=llm.max_retries, - run_manager=run_manager) + return create_base_retry_decorator( + error_types=errors, max_retries=llm.max_retries, run_manager=run_manager + ) def completion_with_retry( @@ -203,8 +206,7 @@ class BaseOpenAI(BaseLLM): API but with different models. In those cases, in order to avoid erroring when tiktoken is called, you can specify a model name to use here.""" - def __new__(cls, - **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore + def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore """Initialize the OpenAI object.""" data.get("model_name", "") return super().__new__(cls) @@ -219,16 +221,17 @@ class BaseOpenAI(BaseLLM): """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) extra = values.get("model_kwargs", {}) - values["model_kwargs"] = build_extra_kwargs(extra, values, - all_required_field_names) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) return values @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["openai_api_key"] = get_from_dict_or_env(values, - "openai_api_key", - "OPENAI_API_KEY") + values["openai_api_key"] = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) values["openai_api_base"] = get_from_dict_or_env( values, "openai_api_base", @@ -252,8 +255,10 @@ class BaseOpenAI(BaseLLM): values["client"] = openai.Completion except ImportError: - raise ImportError("Could not import openai python package. " - "Please install it with `pip install openai`.") + raise ImportError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) if values["streaming"] and values["n"] > 1: raise ValueError("Cannot stream results when n > 1.") if values["streaming"] and values["best_of"] > 1: @@ -290,10 +295,9 @@ class BaseOpenAI(BaseLLM): ) -> Iterator[GenerationChunk]: params = {**self._invocation_params, **kwargs, "stream": True} self.get_sub_prompts(params, [prompt], stop) # this mutates params - for stream_resp in completion_with_retry(self, - prompt=prompt, - run_manager=run_manager, - **params): + for stream_resp in completion_with_retry( + self, prompt=prompt, run_manager=run_manager, **params + ): chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk if run_manager: @@ -302,7 +306,8 @@ class BaseOpenAI(BaseLLM): chunk=chunk, verbose=self.verbose, logprobs=chunk.generation_info["logprobs"] - if chunk.generation_info else None, + if chunk.generation_info + else None, ) async def _astream( @@ -315,7 +320,8 @@ class BaseOpenAI(BaseLLM): params = {**self._invocation_params, **kwargs, "stream": True} self.get_sub_prompts(params, [prompt], stop) # this mutate params async for stream_resp in await acompletion_with_retry( - self, prompt=prompt, run_manager=run_manager, **params): + self, prompt=prompt, run_manager=run_manager, **params + ): chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk if run_manager: @@ -324,7 +330,8 @@ class BaseOpenAI(BaseLLM): chunk=chunk, verbose=self.verbose, logprobs=chunk.generation_info["logprobs"] - if chunk.generation_info else None, + if chunk.generation_info + else None, ) def _generate( @@ -360,32 +367,30 @@ class BaseOpenAI(BaseLLM): for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: - raise ValueError( - "Cannot stream results with multiple prompts.") + raise ValueError("Cannot stream results with multiple prompts.") generation: Optional[GenerationChunk] = None - for chunk in self._stream(_prompts[0], stop, run_manager, - **kwargs): + for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs): if generation is None: generation = chunk else: generation += chunk assert generation is not None - choices.append({ - "text": - generation.text, - "finish_reason": - generation.generation_info.get("finish_reason") - if generation.generation_info else None, - "logprobs": - generation.generation_info.get("logprobs") - if generation.generation_info else None, - }) + choices.append( + { + "text": generation.text, + "finish_reason": generation.generation_info.get("finish_reason") + if generation.generation_info + else None, + "logprobs": generation.generation_info.get("logprobs") + if generation.generation_info + else None, + } + ) else: - response = completion_with_retry(self, - prompt=_prompts, - run_manager=run_manager, - **params) + response = completion_with_retry( + self, prompt=_prompts, run_manager=run_manager, **params + ) choices.extend(response["choices"]) update_token_usage(_keys, response, token_usage) return self.create_llm_result(choices, prompts, token_usage) @@ -409,32 +414,32 @@ class BaseOpenAI(BaseLLM): for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: - raise ValueError( - "Cannot stream results with multiple prompts.") + raise ValueError("Cannot stream results with multiple prompts.") generation: Optional[GenerationChunk] = None - async for chunk in self._astream(_prompts[0], stop, run_manager, - **kwargs): + async for chunk in self._astream( + _prompts[0], stop, run_manager, **kwargs + ): if generation is None: generation = chunk else: generation += chunk assert generation is not None - choices.append({ - "text": - generation.text, - "finish_reason": - generation.generation_info.get("finish_reason") - if generation.generation_info else None, - "logprobs": - generation.generation_info.get("logprobs") - if generation.generation_info else None, - }) + choices.append( + { + "text": generation.text, + "finish_reason": generation.generation_info.get("finish_reason") + if generation.generation_info + else None, + "logprobs": generation.generation_info.get("logprobs") + if generation.generation_info + else None, + } + ) else: - response = await acompletion_with_retry(self, - prompt=_prompts, - run_manager=run_manager, - **params) + response = await acompletion_with_retry( + self, prompt=_prompts, run_manager=run_manager, **params + ) choices.extend(response["choices"]) update_token_usage(_keys, response, token_usage) return self.create_llm_result(choices, prompts, token_usage) @@ -448,35 +453,39 @@ class BaseOpenAI(BaseLLM): """Get the sub prompts for llm call.""" if stop is not None: if "stop" in params: - raise ValueError( - "`stop` found in both the input and default params.") + raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop if params["max_tokens"] == -1: if len(prompts) != 1: raise ValueError( - "max_tokens set to -1 not supported for multiple inputs.") + "max_tokens set to -1 not supported for multiple inputs." + ) params["max_tokens"] = self.max_tokens_for_prompt(prompts[0]) sub_prompts = [ - prompts[i:i + self.batch_size] + prompts[i : i + self.batch_size] for i in range(0, len(prompts), self.batch_size) ] return sub_prompts - def create_llm_result(self, choices: Any, prompts: List[str], - token_usage: Dict[str, int]) -> LLMResult: + def create_llm_result( + self, choices: Any, prompts: List[str], token_usage: Dict[str, int] + ) -> LLMResult: """Create the LLMResult from the choices and prompts.""" generations = [] for i, _ in enumerate(prompts): - sub_choices = choices[i * self.n:(i + 1) * self.n] - generations.append([ - Generation( - text=choice["text"], - generation_info=dict( - finish_reason=choice.get("finish_reason"), - logprobs=choice.get("logprobs"), - ), - ) for choice in sub_choices - ]) + sub_choices = choices[i * self.n : (i + 1) * self.n] + generations.append( + [ + Generation( + text=choice["text"], + generation_info=dict( + finish_reason=choice.get("finish_reason"), + logprobs=choice.get("logprobs"), + ), + ) + for choice in sub_choices + ] + ) llm_output = {"token_usage": token_usage, "model_name": self.model_name} return LLMResult(generations=generations, llm_output=llm_output) @@ -518,14 +527,14 @@ class BaseOpenAI(BaseLLM): raise ImportError( "Could not import tiktoken python package. " "This is needed in order to calculate get_num_tokens. " - "Please install it with `pip install tiktoken`.") + "Please install it with `pip install tiktoken`." + ) model_name = self.tiktoken_model_name or self.model_name try: enc = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning( - "Warning: model not found. Using cl100k_base encoding.") + logger.warning("Warning: model not found. Using cl100k_base encoding.") model = "cl100k_base" enc = tiktoken.get_encoding(model) @@ -587,7 +596,8 @@ class BaseOpenAI(BaseLLM): if context_size is None: raise ValueError( f"Unknown model: {modelname}. Please provide a valid OpenAI model name." - "Known models are: " + ", ".join(model_token_mapping.keys())) + "Known models are: " + ", ".join(model_token_mapping.keys()) + ) return context_size @@ -665,15 +675,14 @@ class AzureOpenAI(BaseOpenAI): "OPENAI_API_VERSION", ) values["openai_api_type"] = get_from_dict_or_env( - values, "openai_api_type", "OPENAI_API_TYPE", "azure") + values, "openai_api_type", "OPENAI_API_TYPE", "azure" + ) return values @property def _identifying_params(self) -> Mapping[str, Any]: return { - **{ - "deployment_name": self.deployment_name - }, + **{"deployment_name": self.deployment_name}, **super()._identifying_params, } @@ -738,9 +747,7 @@ class OpenAIChat(BaseLLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = { - field.alias for field in cls.__fields__.values() - } + all_required_field_names = {field.alias for field in cls.__fields__.values()} extra = values.get("model_kwargs", {}) for field_name in list(values): @@ -754,8 +761,9 @@ class OpenAIChat(BaseLLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = get_from_dict_or_env(values, "openai_api_key", - "OPENAI_API_KEY") + openai_api_key = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) openai_api_base = get_from_dict_or_env( values, "openai_api_base", @@ -768,10 +776,9 @@ class OpenAIChat(BaseLLM): "OPENAI_PROXY", default="", ) - openai_organization = get_from_dict_or_env(values, - "openai_organization", - "OPENAI_ORGANIZATION", - default="") + openai_organization = get_from_dict_or_env( + values, "openai_organization", "OPENAI_ORGANIZATION", default="" + ) try: import openai @@ -786,15 +793,18 @@ class OpenAIChat(BaseLLM): "https": openai_proxy, } # type: ignore[assignment] # noqa: E501 except ImportError: - raise ImportError("Could not import openai python package. " - "Please install it with `pip install openai`.") + raise ImportError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) try: values["client"] = openai.ChatCompletion except AttributeError: raise ValueError( "`openai` has no `ChatCompletion` attribute, this is likely " "due to an old version of the openai package. Try upgrading it " - "with `pip install --upgrade openai`.") + "with `pip install --upgrade openai`." + ) return values @property @@ -802,27 +812,18 @@ class OpenAIChat(BaseLLM): """Get the default parameters for calling OpenAI API.""" return self.model_kwargs - def _get_chat_params(self, - prompts: List[str], - stop: Optional[List[str]] = None) -> Tuple: + def _get_chat_params( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> Tuple: if len(prompts) > 1: raise ValueError( f"OpenAIChat currently only supports single prompt, got {prompts}" ) - messages = self.prefix_messages + [{ - "role": "user", - "content": prompts[0] - }] - params: Dict[str, Any] = { - **{ - "model": self.model_name - }, - **self._default_params - } + messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] + params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} if stop is not None: if "stop" in params: - raise ValueError( - "`stop` found in both the input and default params.") + raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop if params.get("max_tokens") == -1: # for ChatGPT api, omitting max_tokens is equivalent to having no limit @@ -838,10 +839,9 @@ class OpenAIChat(BaseLLM): ) -> Iterator[GenerationChunk]: messages, params = self._get_chat_params([prompt], stop) params = {**params, **kwargs, "stream": True} - for stream_resp in completion_with_retry(self, - messages=messages, - run_manager=run_manager, - **params): + for stream_resp in completion_with_retry( + self, messages=messages, run_manager=run_manager, **params + ): token = stream_resp["choices"][0]["delta"].get("content", "") chunk = GenerationChunk(text=token) yield chunk @@ -858,7 +858,8 @@ class OpenAIChat(BaseLLM): messages, params = self._get_chat_params([prompt], stop) params = {**params, **kwargs, "stream": True} async for stream_resp in await acompletion_with_retry( - self, messages=messages, run_manager=run_manager, **params): + self, messages=messages, run_manager=run_manager, **params + ): token = stream_resp["choices"][0]["delta"].get("content", "") chunk = GenerationChunk(text=token) yield chunk @@ -884,19 +885,17 @@ class OpenAIChat(BaseLLM): messages, params = self._get_chat_params(prompts, stop) params = {**params, **kwargs} - full_response = completion_with_retry(self, - messages=messages, - run_manager=run_manager, - **params) + full_response = completion_with_retry( + self, messages=messages, run_manager=run_manager, **params + ) llm_output = { "token_usage": full_response["usage"], "model_name": self.model_name, } return LLMResult( - generations=[[ - Generation( - text=full_response["choices"][0]["message"]["content"]) - ]], + generations=[ + [Generation(text=full_response["choices"][0]["message"]["content"])] + ], llm_output=llm_output, ) @@ -909,8 +908,7 @@ class OpenAIChat(BaseLLM): ) -> LLMResult: if self.streaming: generation: Optional[GenerationChunk] = None - async for chunk in self._astream(prompts[0], stop, run_manager, - **kwargs): + async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs): if generation is None: generation = chunk else: @@ -920,19 +918,17 @@ class OpenAIChat(BaseLLM): messages, params = self._get_chat_params(prompts, stop) params = {**params, **kwargs} - full_response = await acompletion_with_retry(self, - messages=messages, - run_manager=run_manager, - **params) + full_response = await acompletion_with_retry( + self, messages=messages, run_manager=run_manager, **params + ) llm_output = { "token_usage": full_response["usage"], "model_name": self.model_name, } return LLMResult( - generations=[[ - Generation( - text=full_response["choices"][0]["message"]["content"]) - ]], + generations=[ + [Generation(text=full_response["choices"][0]["message"]["content"])] + ], llm_output=llm_output, ) @@ -957,7 +953,8 @@ class OpenAIChat(BaseLLM): raise ImportError( "Could not import tiktoken python package. " "This is needed in order to calculate get_num_tokens. " - "Please install it with `pip install tiktoken`.") + "Please install it with `pip install tiktoken`." + ) enc = tiktoken.encoding_for_model(self.model_name) return enc.encode( diff --git a/swarms/models/openai_tokenizer.py b/swarms/models/openai_tokenizer.py index 26ec9221..9ff1fa08 100644 --- a/swarms/models/openai_tokenizer.py +++ b/swarms/models/openai_tokenizer.py @@ -71,15 +71,16 @@ class OpenAITokenizer(BaseTokenizer): @property def max_tokens(self) -> int: - tokens = next(v for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items() - if self.model.startswith(k)) + tokens = next( + v + for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items() + if self.model.startswith(k) + ) offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset - def count_tokens(self, - text: str | list, - model: Optional[str] = None) -> int: + def count_tokens(self, text: str | list, model: Optional[str] = None) -> int: """ Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb @@ -95,12 +96,12 @@ class OpenAITokenizer(BaseTokenizer): encoding = tiktoken.get_encoding("cl100k_base") if model in { - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4-0314", - "gpt-4-32k-0314", - "gpt-4-0613", - "gpt-4-32k-0613", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", }: tokens_per_message = 3 tokens_per_name = 1 @@ -112,18 +113,21 @@ class OpenAITokenizer(BaseTokenizer): elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model: logging.info( "gpt-3.5-turbo may update over time. Returning num tokens assuming" - " gpt-3.5-turbo-0613.") + " gpt-3.5-turbo-0613." + ) return self.count_tokens(text, model="gpt-3.5-turbo-0613") elif "gpt-4" in model: logging.info( "gpt-4 may update over time. Returning num tokens assuming" - " gpt-4-0613.") + " gpt-4-0613." + ) return self.count_tokens(text, model="gpt-4-0613") else: raise NotImplementedError( f"""token_count() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for - information on how messages are converted to tokens.""") + information on how messages are converted to tokens.""" + ) num_tokens = 0 @@ -140,5 +144,5 @@ class OpenAITokenizer(BaseTokenizer): return num_tokens else: return len( - self.encoding.encode(text, - allowed_special=set(self.stop_sequences))) + self.encoding.encode(text, allowed_special=set(self.stop_sequences)) + ) diff --git a/swarms/models/palm.py b/swarms/models/palm.py index c551c288..ec8aafd6 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -26,7 +26,8 @@ def _create_retry_decorator() -> Callable[[Any], Any]: except ImportError: raise ImportError( "Could not import google-api-core python package. " - "Please install it with `pip install google-api-core`.") + "Please install it with `pip install google-api-core`." + ) multiplier = 2 min_seconds = 1 @@ -36,15 +37,12 @@ def _create_retry_decorator() -> Callable[[Any], Any]: return retry( reraise=True, stop=stop_after_attempt(max_retries), - wait=wait_exponential(multiplier=multiplier, - min=min_seconds, - max=max_seconds), - retry=(retry_if_exception_type( - google.api_core.exceptions.ResourceExhausted) | - retry_if_exception_type( - google.api_core.exceptions.ServiceUnavailable) | - retry_if_exception_type( - google.api_core.exceptions.GoogleAPIError)), + wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(google.api_core.exceptions.ResourceExhausted) + | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable) + | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError) + ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -66,8 +64,7 @@ def _strip_erroneous_leading_spaces(text: str) -> str: The PaLM API will sometimes erroneously return a single leading space in all lines > 1. This function strips that space. """ - has_leading_space = all( - not line or line[0] == " " for line in text.split("\n")[1:]) + has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:]) if has_leading_space: return text.replace("\n ", "\n") else: @@ -100,8 +97,9 @@ class GooglePalm(BaseLLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists.""" - google_api_key = get_from_dict_or_env(values, "google_api_key", - "GOOGLE_API_KEY") + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) try: import google.generativeai as genai @@ -109,12 +107,12 @@ class GooglePalm(BaseLLM, BaseModel): except ImportError: raise ImportError( "Could not import google-generativeai python package. " - "Please install it with `pip install google-generativeai`.") + "Please install it with `pip install google-generativeai`." + ) values["client"] = genai - if values["temperature"] is not None and not 0 <= values[ - "temperature"] <= 1: + if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: raise ValueError("temperature must be in the range [0.0, 1.0]") if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: @@ -123,8 +121,7 @@ class GooglePalm(BaseLLM, BaseModel): if values["top_k"] is not None and values["top_k"] <= 0: raise ValueError("top_k must be positive") - if values["max_output_tokens"] is not None and values[ - "max_output_tokens"] <= 0: + if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0: raise ValueError("max_output_tokens must be greater than zero") return values diff --git a/swarms/models/pegasus.py b/swarms/models/pegasus.py index c2571f72..e388d40c 100644 --- a/swarms/models/pegasus.py +++ b/swarms/models/pegasus.py @@ -33,10 +33,9 @@ class PegasusEmbedding: """ - def __init__(self, - modality: str, - multi_process: bool = False, - n_processes: int = 4): + def __init__( + self, modality: str, multi_process: bool = False, n_processes: int = 4 + ): self.modality = modality self.multi_process = multi_process self.n_processes = n_processes @@ -44,7 +43,8 @@ class PegasusEmbedding: self.pegasus = Pegasus(modality, multi_process, n_processes) except Exception as e: logging.error( - f"Failed to initialize Pegasus with modality: {modality}: {e}") + f"Failed to initialize Pegasus with modality: {modality}: {e}" + ) raise def embed(self, data: Union[str, list[str]]): diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index fbb7c066..7eb923b4 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -21,4 +21,6 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): return openai.Embedding.create( input=[text], model=model, - )["data"][0]["embedding"] + )["data"][ + 0 + ]["embedding"] diff --git a/swarms/models/speecht5.py b/swarms/models/speecht5.py index d1b476b9..e98036ac 100644 --- a/swarms/models/speecht5.py +++ b/swarms/models/speecht5.py @@ -90,17 +90,17 @@ class SpeechT5: self.processor = SpeechT5Processor.from_pretrained(self.model_name) self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name) self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) - self.embeddings_dataset = load_dataset(self.dataset_name, - split="validation") + self.embeddings_dataset = load_dataset(self.dataset_name, split="validation") def __call__(self, text: str, speaker_id: float = 7306): """Call the model on some text and return the speech.""" speaker_embedding = torch.tensor( - self.embeddings_dataset[speaker_id]["xvector"]).unsqueeze(0) + self.embeddings_dataset[speaker_id]["xvector"] + ).unsqueeze(0) inputs = self.processor(text=text, return_tensors="pt") - speech = self.model.generate_speech(inputs["input_ids"], - speaker_embedding, - vocoder=self.vocoder) + speech = self.model.generate_speech( + inputs["input_ids"], speaker_embedding, vocoder=self.vocoder + ) return speech def save_speech(self, speech, filename="speech.wav"): @@ -121,8 +121,7 @@ class SpeechT5: def set_embeddings_dataset(self, dataset_name): """Set the embeddings dataset to a new dataset.""" self.dataset_name = dataset_name - self.embeddings_dataset = load_dataset(self.dataset_name, - split="validation") + self.embeddings_dataset = load_dataset(self.dataset_name, split="validation") # Feature 1: Get sampling rate def get_sampling_rate(self): diff --git a/swarms/models/timm.py b/swarms/models/timm.py index 5b17c76c..5d9b965a 100644 --- a/swarms/models/timm.py +++ b/swarms/models/timm.py @@ -50,8 +50,9 @@ class TimmModel: in_chans=model_info.in_chans, ) - def __call__(self, model_info: TimmModelInfo, - input_tensor: torch.Tensor) -> torch.Size: + def __call__( + self, model_info: TimmModelInfo, input_tensor: torch.Tensor + ) -> torch.Size: """ Create and run a model specified by `model_info` on `input_tensor`. diff --git a/swarms/models/trocr.py b/swarms/models/trocr.py index 1b9e72e7..f4a4156d 100644 --- a/swarms/models/trocr.py +++ b/swarms/models/trocr.py @@ -10,8 +10,9 @@ import requests class TrOCR: - - def __init__(self,): + def __init__( + self, + ): pass def __call__(self): diff --git a/swarms/models/vilt.py b/swarms/models/vilt.py index 4725a317..f95d265c 100644 --- a/swarms/models/vilt.py +++ b/swarms/models/vilt.py @@ -23,9 +23,11 @@ class Vilt: def __init__(self): self.processor = ViltProcessor.from_pretrained( - "dandelin/vilt-b32-finetuned-vqa") + "dandelin/vilt-b32-finetuned-vqa" + ) self.model = ViltForQuestionAnswering.from_pretrained( - "dandelin/vilt-b32-finetuned-vqa") + "dandelin/vilt-b32-finetuned-vqa" + ) def __call__(self, text: str, image_url: str): """ diff --git a/swarms/models/wizard_storytelling.py b/swarms/models/wizard_storytelling.py index 929fe10e..49ffb70d 100644 --- a/swarms/models/wizard_storytelling.py +++ b/swarms/models/wizard_storytelling.py @@ -33,8 +33,7 @@ class WizardLLMStoryTeller: def __init__( self, - model_id: - str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF", + model_id: str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF", device: str = None, max_length: int = 500, quantize: bool = False, @@ -45,8 +44,9 @@ class WizardLLMStoryTeller: decoding=False, ): self.logger = logging.getLogger(__name__) - self.device = (device if device else - ("cuda" if torch.cuda.is_available() else "cpu")) + self.device = ( + device if device else ("cuda" if torch.cuda.is_available() else "cpu") + ) self.model_id = model_id self.max_length = max_length self.verbose = verbose @@ -56,8 +56,9 @@ class WizardLLMStoryTeller: # self.log = Logging() if self.distributed: - assert (torch.cuda.device_count() > - 1), "You need more than 1 gpu for distributed processing" + assert ( + torch.cuda.device_count() > 1 + ), "You need more than 1 gpu for distributed processing" bnb_config = None if quantize: @@ -73,7 +74,8 @@ class WizardLLMStoryTeller: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config) + self.model_id, quantization_config=bnb_config + ) self.model # .to(self.device) except Exception as e: @@ -86,18 +88,20 @@ class WizardLLMStoryTeller: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - bnb_config = (BitsAndBytesConfig(**self.quantization_config) - if self.quantization_config else None) + bnb_config = ( + BitsAndBytesConfig(**self.quantization_config) + if self.quantization_config + else None + ) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, - quantization_config=bnb_config).to(self.device) + self.model_id, quantization_config=bnb_config + ).to(self.device) if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error( - f"Failed to load the model or the tokenizer: {error}") + self.logger.error(f"Failed to load the model or the tokenizer: {error}") raise def run(self, prompt_text: str): @@ -116,8 +120,9 @@ class WizardLLMStoryTeller: max_length = self.max_length try: - inputs = self.tokenizer.encode(prompt_text, - return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( + self.device + ) # self.log.start() @@ -126,26 +131,26 @@ class WizardLLMStoryTeller: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate(inputs, - max_length=len(inputs) + - 1, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=len(inputs) + 1, do_sample=True + ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode([output_tokens], - skip_special_tokens=True), + self.tokenizer.decode( + [output_tokens], skip_special_tokens=True + ), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate(inputs, - max_length=max_length, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=max_length, do_sample=True + ) del inputs return self.tokenizer.decode(outputs[0], skip_special_tokens=True) @@ -169,8 +174,9 @@ class WizardLLMStoryTeller: max_length = self.max_ try: - inputs = self.tokenizer.encode(prompt_text, - return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( + self.device + ) # self.log.start() @@ -179,26 +185,26 @@ class WizardLLMStoryTeller: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate(inputs, - max_length=len(inputs) + - 1, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=len(inputs) + 1, do_sample=True + ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode([output_tokens], - skip_special_tokens=True), + self.tokenizer.decode( + [output_tokens], skip_special_tokens=True + ), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate(inputs, - max_length=max_length, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=max_length, do_sample=True + ) del inputs diff --git a/swarms/models/yarn_mistral.py b/swarms/models/yarn_mistral.py index e3120e20..ebe107a2 100644 --- a/swarms/models/yarn_mistral.py +++ b/swarms/models/yarn_mistral.py @@ -44,8 +44,9 @@ class YarnMistral128: decoding=False, ): self.logger = logging.getLogger(__name__) - self.device = (device if device else - ("cuda" if torch.cuda.is_available() else "cpu")) + self.device = ( + device if device else ("cuda" if torch.cuda.is_available() else "cpu") + ) self.model_id = model_id self.max_length = max_length self.verbose = verbose @@ -55,8 +56,9 @@ class YarnMistral128: # self.log = Logging() if self.distributed: - assert (torch.cuda.device_count() > - 1), "You need more than 1 gpu for distributed processing" + assert ( + torch.cuda.device_count() > 1 + ), "You need more than 1 gpu for distributed processing" bnb_config = None if quantize: @@ -91,18 +93,20 @@ class YarnMistral128: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - bnb_config = (BitsAndBytesConfig(**self.quantization_config) - if self.quantization_config else None) + bnb_config = ( + BitsAndBytesConfig(**self.quantization_config) + if self.quantization_config + else None + ) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, - quantization_config=bnb_config).to(self.device) + self.model_id, quantization_config=bnb_config + ).to(self.device) if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error( - f"Failed to load the model or the tokenizer: {error}") + self.logger.error(f"Failed to load the model or the tokenizer: {error}") raise def run(self, prompt_text: str): @@ -121,8 +125,9 @@ class YarnMistral128: max_length = self.max_length try: - inputs = self.tokenizer.encode(prompt_text, - return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( + self.device + ) # self.log.start() @@ -131,26 +136,26 @@ class YarnMistral128: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate(inputs, - max_length=len(inputs) + - 1, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=len(inputs) + 1, do_sample=True + ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode([output_tokens], - skip_special_tokens=True), + self.tokenizer.decode( + [output_tokens], skip_special_tokens=True + ), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate(inputs, - max_length=max_length, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=max_length, do_sample=True + ) del inputs return self.tokenizer.decode(outputs[0], skip_special_tokens=True) @@ -197,8 +202,9 @@ class YarnMistral128: max_length = self.max_ try: - inputs = self.tokenizer.encode(prompt_text, - return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( + self.device + ) # self.log.start() @@ -207,26 +213,26 @@ class YarnMistral128: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate(inputs, - max_length=len(inputs) + - 1, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=len(inputs) + 1, do_sample=True + ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode([output_tokens], - skip_special_tokens=True), + self.tokenizer.decode( + [output_tokens], skip_special_tokens=True + ), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate(inputs, - max_length=max_length, - do_sample=True) + outputs = self.model.generate( + inputs, max_length=max_length, do_sample=True + ) del inputs diff --git a/swarms/models/zephyr.py b/swarms/models/zephyr.py index 0ed23f19..4fca5211 100644 --- a/swarms/models/zephyr.py +++ b/swarms/models/zephyr.py @@ -28,8 +28,7 @@ class Zephyr: model_name: str = "HuggingFaceH4/zephyr-7b-alpha", tokenize: bool = False, add_generation_prompt: bool = True, - system_prompt: - str = "You are a friendly chatbot who always responds in the style of a pirate", + system_prompt: str = "You are a friendly chatbot who always responds in the style of a pirate", max_new_tokens: int = 300, temperature: float = 0.5, top_k: float = 50, diff --git a/swarms/prompts/agent_output_parser.py b/swarms/prompts/agent_output_parser.py index e00db22d..27f8ac24 100644 --- a/swarms/prompts/agent_output_parser.py +++ b/swarms/prompts/agent_output_parser.py @@ -24,8 +24,9 @@ class AgentOutputParser(BaseAgentOutputParser): @staticmethod def _preprocess_json_input(input_str: str) -> str: - corrected_str = re.sub(r'(? dict: diff --git a/swarms/prompts/agent_prompt.py b/swarms/prompts/agent_prompt.py index aa84ebf8..c4897193 100644 --- a/swarms/prompts/agent_prompt.py +++ b/swarms/prompts/agent_prompt.py @@ -13,23 +13,13 @@ class PromptGenerator: self.performance_evaluation: List[str] = [] self.response_format = { "thoughts": { - "text": - "thought", - "reasoning": - "reasoning", - "plan": - "- short bulleted\n- list that conveys\n- long-term plan", - "criticism": - "constructive self-criticism", - "speak": - "thoughts summary to say to user", - }, - "command": { - "name": "command name", - "args": { - "arg name": "value" - } + "text": "thought", + "reasoning": "reasoning", + "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": "constructive self-criticism", + "speak": "thoughts summary to say to user", }, + "command": {"name": "command name", "args": {"arg name": "value"}}, } def add_constraint(self, constraint: str) -> None: @@ -82,6 +72,7 @@ class PromptGenerator: f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n" "You should only respond in JSON format as described below " f"\nResponse Format: \n{formatted_response_format} " - "\nEnsure the response can be parsed by Python json.loads") + "\nEnsure the response can be parsed by Python json.loads" + ) return prompt_string diff --git a/swarms/prompts/agent_prompts.py b/swarms/prompts/agent_prompts.py index 3de5bcb2..8d145fc0 100644 --- a/swarms/prompts/agent_prompts.py +++ b/swarms/prompts/agent_prompts.py @@ -7,21 +7,25 @@ def generate_agent_role_prompt(agent): "Finance Agent": ( "You are a seasoned finance analyst AI assistant. Your primary goal is to" " compose comprehensive, astute, impartial, and methodically arranged" - " financial reports based on provided data and trends."), + " financial reports based on provided data and trends." + ), "Travel Agent": ( "You are a world-travelled AI tour guide assistant. Your main purpose is to" " draft engaging, insightful, unbiased, and well-structured travel reports" " on given locations, including history, attractions, and cultural" - " insights."), + " insights." + ), "Academic Research Agent": ( "You are an AI academic research assistant. Your primary responsibility is" " to create thorough, academically rigorous, unbiased, and systematically" " organized reports on a given research topic, following the standards of" - " scholarly work."), + " scholarly work." + ), "Default Agent": ( "You are an AI critical thinker research assistant. Your sole purpose is to" " write well written, critically acclaimed, objective and structured" - " reports on given text."), + " reports on given text." + ), } return prompts.get(agent, "No such agent") @@ -40,7 +44,8 @@ def generate_report_prompt(question, research_summary): " focus on the answer to the question, should be well structured, informative," " in depth, with facts and numbers if available, a minimum of 1,200 words and" " with markdown syntax and apa format. Write all source urls at the end of the" - " report in apa format") + " report in apa format" + ) def generate_search_queries_prompt(question): @@ -52,7 +57,8 @@ def generate_search_queries_prompt(question): return ( "Write 4 google search queries to search online that form an objective opinion" f' from the following: "{question}"You must respond with a list of strings in' - ' the following format: ["query 1", "query 2", "query 3", "query 4"]') + ' the following format: ["query 1", "query 2", "query 3", "query 4"]' + ) def generate_resource_report_prompt(question, research_summary): @@ -74,7 +80,8 @@ def generate_resource_report_prompt(question, research_summary): " significance of each source. Ensure that the report is well-structured," " informative, in-depth, and follows Markdown syntax. Include relevant facts," " figures, and numbers whenever available. The report should have a minimum" - " length of 1,200 words.") + " length of 1,200 words." + ) def generate_outline_report_prompt(question, research_summary): @@ -91,7 +98,8 @@ def generate_outline_report_prompt(question, research_summary): " research report, including the main sections, subsections, and key points to" " be covered. The research report should be detailed, informative, in-depth," " and a minimum of 1,200 words. Use appropriate Markdown syntax to format the" - " outline and ensure readability.") + " outline and ensure readability." + ) def generate_concepts_prompt(question, research_summary): @@ -106,7 +114,8 @@ def generate_concepts_prompt(question, research_summary): " main concepts to learn for a research report on the following question or" f' topic: "{question}". The outline should provide a well-structured' " frameworkYou must respond with a list of strings in the following format:" - ' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]') + ' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]' + ) def generate_lesson_prompt(concept): @@ -122,7 +131,8 @@ def generate_lesson_prompt(concept): f"generate a comprehensive lesson about {concept} in Markdown syntax. This" f" should include the definitionof {concept}, its historical background and" " development, its applications or uses in differentfields, and notable events" - f" or facts related to {concept}.") + f" or facts related to {concept}." + ) return prompt diff --git a/swarms/prompts/base.py b/swarms/prompts/base.py index 8bb77236..54a0bc3f 100644 --- a/swarms/prompts/base.py +++ b/swarms/prompts/base.py @@ -11,9 +11,9 @@ if TYPE_CHECKING: from langchain.prompts.chat import ChatPromptTemplate -def get_buffer_string(messages: Sequence[BaseMessage], - human_prefix: str = "Human", - ai_prefix: str = "AI") -> str: +def get_buffer_string( + messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" +) -> str: """Convert sequence of Messages to strings and concatenate them into one string. Args: @@ -88,9 +88,9 @@ class BaseMessage(Serializable): class BaseMessageChunk(BaseMessage): - - def _merge_kwargs_dict(self, left: Dict[str, Any], - right: Dict[str, Any]) -> Dict[str, Any]: + def _merge_kwargs_dict( + self, left: Dict[str, Any], right: Dict[str, Any] + ) -> Dict[str, Any]: """Merge additional_kwargs from another BaseMessageChunk into this one.""" merged = left.copy() for k, v in right.items(): @@ -99,7 +99,8 @@ class BaseMessageChunk(BaseMessage): elif not isinstance(merged[k], type(v)): raise ValueError( f'additional_kwargs["{k}"] already exists in this message,' - " but with a different type.") + " but with a different type." + ) elif isinstance(merged[k], str): merged[k] += v elif isinstance(merged[k], dict): @@ -118,12 +119,15 @@ class BaseMessageChunk(BaseMessage): return self.__class__( content=self.content + other.content, additional_kwargs=self._merge_kwargs_dict( - self.additional_kwargs, other.additional_kwargs), + self.additional_kwargs, other.additional_kwargs + ), ) else: - raise TypeError('unsupported operand type(s) for +: "' - f"{self.__class__.__name__}" - f'" and "{other.__class__.__name__}"') + raise TypeError( + 'unsupported operand type(s) for +: "' + f"{self.__class__.__name__}" + f'" and "{other.__class__.__name__}"' + ) class HumanMessage(BaseMessage): diff --git a/swarms/prompts/chat_prompt.py b/swarms/prompts/chat_prompt.py index 5f48488f..b0330e24 100644 --- a/swarms/prompts/chat_prompt.py +++ b/swarms/prompts/chat_prompt.py @@ -66,10 +66,9 @@ class SystemMessage(Message): of input messages. """ - def __init__(self, - content: str, - role: str = "System", - additional_kwargs: Dict = None): + def __init__( + self, content: str, role: str = "System", additional_kwargs: Dict = None + ): super().__init__(content, role, additional_kwargs) def get_type(self) -> str: @@ -107,9 +106,9 @@ class ChatMessage(Message): return "chat" -def get_buffer_string(messages: Sequence[Message], - human_prefix: str = "Human", - ai_prefix: str = "AI") -> str: +def get_buffer_string( + messages: Sequence[Message], human_prefix: str = "Human", ai_prefix: str = "AI" +) -> str: string_messages = [] for m in messages: message = f"{m.role}: {m.content}" diff --git a/swarms/prompts/debate.py b/swarms/prompts/debate.py index 5a6be762..a11c7af4 100644 --- a/swarms/prompts/debate.py +++ b/swarms/prompts/debate.py @@ -38,6 +38,7 @@ def debate_monitor(game_description, word_limit, character_names): return prompt -def generate_character_header(game_description, topic, character_name, - character_description): +def generate_character_header( + game_description, topic, character_name, character_description +): pass diff --git a/swarms/prompts/multi_modal_prompts.py b/swarms/prompts/multi_modal_prompts.py index dc2bccd5..b552b68d 100644 --- a/swarms/prompts/multi_modal_prompts.py +++ b/swarms/prompts/multi_modal_prompts.py @@ -1,6 +1,7 @@ ERROR_PROMPT = ( "An error has occurred for the following text: \n{promptedQuery} Please explain" - " this error.\n {e}") + " this error.\n {e}" +) IMAGE_PROMPT = """ provide a figure named {filename}. The description is: {description}. diff --git a/swarms/prompts/python.py b/swarms/prompts/python.py index cd34e9bd..9d1f4a1e 100644 --- a/swarms/prompts/python.py +++ b/swarms/prompts/python.py @@ -3,25 +3,30 @@ PY_REFLEXION_COMPLETION_INSTRUCTION = ( "You are a Python writing assistant. You will be given your past function" " implementation, a series of unit tests, and a hint to change the implementation" " appropriately. Write your full implementation (restate the function" - " signature).\n\n-----") + " signature).\n\n-----" +) PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = ( "You are a Python writing assistant. You will be given a function implementation" " and a series of unit tests. Your goal is to write a few sentences to explain why" " your implementation is wrong as indicated by the tests. You will need this as a" " hint when you try again later. Only provide the few sentence description in your" - " answer, not the implementation.\n\n-----") + " answer, not the implementation.\n\n-----" +) USE_PYTHON_CODEBLOCK_INSTRUCTION = ( "Use a Python code block to write your response. For" - " example:\n```python\nprint('Hello world!')\n```") + " example:\n```python\nprint('Hello world!')\n```" +) PY_SIMPLE_CHAT_INSTRUCTION = ( "You are an AI that only responds with python code, NOT ENGLISH. You will be given" " a function signature and its docstring by the user. Write your full" - " implementation (restate the function signature).") + " implementation (restate the function signature)." +) PY_SIMPLE_CHAT_INSTRUCTION_V2 = ( "You are an AI that only responds with only python code. You will be given a" " function signature and its docstring by the user. Write your full implementation" - " (restate the function signature).") + " (restate the function signature)." +) PY_REFLEXION_CHAT_INSTRUCTION = ( "You are an AI Python assistant. You will be given your past function" " implementation, a series of unit tests, and a hint to change the implementation" @@ -31,7 +36,8 @@ PY_REFLEXION_CHAT_INSTRUCTION_V2 = ( "You are an AI Python assistant. You will be given your previous implementation of" " a function, a series of unit tests results, and your self-reflection on your" " previous implementation. Write your full implementation (restate the function" - " signature).") + " signature)." +) PY_REFLEXION_FEW_SHOT_ADD = '''Example 1: [previous impl]: ```python @@ -169,14 +175,16 @@ PY_SELF_REFLECTION_CHAT_INSTRUCTION = ( " implementation and a series of unit tests. Your goal is to write a few sentences" " to explain why your implementation is wrong as indicated by the tests. You will" " need this as a hint when you try again later. Only provide the few sentence" - " description in your answer, not the implementation.") + " description in your answer, not the implementation." +) PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = ( "You are a Python programming assistant. You will be given a function" " implementation and a series of unit test results. Your goal is to write a few" " sentences to explain why your implementation is wrong as indicated by the tests." " You will need this as guidance when you try again later. Only provide the few" " sentence description in your answer, not the implementation. You will be given a" - " few examples by the user.") + " few examples by the user." +) PY_SELF_REFLECTION_FEW_SHOT = """Example 1: [function impl]: ```python diff --git a/swarms/prompts/sales.py b/swarms/prompts/sales.py index 6660e084..4f04f7fc 100644 --- a/swarms/prompts/sales.py +++ b/swarms/prompts/sales.py @@ -3,29 +3,36 @@ conversation_stages = { "Introduction: Start the conversation by introducing yourself and your company." " Be polite and respectful while keeping the tone of the conversation" " professional. Your greeting should be welcoming. Always clarify in your" - " greeting the reason why you are contacting the prospect."), + " greeting the reason why you are contacting the prospect." + ), "2": ( "Qualification: Qualify the prospect by confirming if they are the right person" " to talk to regarding your product/service. Ensure that they have the" - " authority to make purchasing decisions."), + " authority to make purchasing decisions." + ), "3": ( "Value proposition: Briefly explain how your product/service can benefit the" " prospect. Focus on the unique selling points and value proposition of your" - " product/service that sets it apart from competitors."), + " product/service that sets it apart from competitors." + ), "4": ( "Needs analysis: Ask open-ended questions to uncover the prospect's needs and" - " pain points. Listen carefully to their responses and take notes."), - "5": ("Solution presentation: Based on the prospect's needs, present your" - " product/service as the solution that can address their pain points." - ), - "6": - ("Objection handling: Address any objections that the prospect may have" - " regarding your product/service. Be prepared to provide evidence or" - " testimonials to support your claims."), + " pain points. Listen carefully to their responses and take notes." + ), + "5": ( + "Solution presentation: Based on the prospect's needs, present your" + " product/service as the solution that can address their pain points." + ), + "6": ( + "Objection handling: Address any objections that the prospect may have" + " regarding your product/service. Be prepared to provide evidence or" + " testimonials to support your claims." + ), "7": ( "Close: Ask for the sale by proposing a next step. This could be a demo, a" " trial or a meeting with decision-makers. Ensure to summarize what has been" - " discussed and reiterate the benefits."), + " discussed and reiterate the benefits." + ), } SALES_AGENT_TOOLS_PROMPT = """ diff --git a/swarms/prompts/sales_prompts.py b/swarms/prompts/sales_prompts.py index ce5303b3..3f2b9f2b 100644 --- a/swarms/prompts/sales_prompts.py +++ b/swarms/prompts/sales_prompts.py @@ -49,27 +49,34 @@ conversation_stages = { "Introduction: Start the conversation by introducing yourself and your company." " Be polite and respectful while keeping the tone of the conversation" " professional. Your greeting should be welcoming. Always clarify in your" - " greeting the reason why you are contacting the prospect."), + " greeting the reason why you are contacting the prospect." + ), "2": ( "Qualification: Qualify the prospect by confirming if they are the right person" " to talk to regarding your product/service. Ensure that they have the" - " authority to make purchasing decisions."), + " authority to make purchasing decisions." + ), "3": ( "Value proposition: Briefly explain how your product/service can benefit the" " prospect. Focus on the unique selling points and value proposition of your" - " product/service that sets it apart from competitors."), + " product/service that sets it apart from competitors." + ), "4": ( "Needs analysis: Ask open-ended questions to uncover the prospect's needs and" - " pain points. Listen carefully to their responses and take notes."), - "5": ("Solution presentation: Based on the prospect's needs, present your" - " product/service as the solution that can address their pain points." - ), - "6": - ("Objection handling: Address any objections that the prospect may have" - " regarding your product/service. Be prepared to provide evidence or" - " testimonials to support your claims."), + " pain points. Listen carefully to their responses and take notes." + ), + "5": ( + "Solution presentation: Based on the prospect's needs, present your" + " product/service as the solution that can address their pain points." + ), + "6": ( + "Objection handling: Address any objections that the prospect may have" + " regarding your product/service. Be prepared to provide evidence or" + " testimonials to support your claims." + ), "7": ( "Close: Ask for the sale by proposing a next step. This could be a demo, a" " trial or a meeting with decision-makers. Ensure to summarize what has been" - " discussed and reiterate the benefits."), + " discussed and reiterate the benefits." + ), } diff --git a/swarms/schemas/typings.py b/swarms/schemas/typings.py index f59b16f7..2d848736 100644 --- a/swarms/schemas/typings.py +++ b/swarms/schemas/typings.py @@ -18,11 +18,13 @@ class ChatbotError(Exception): def __init__(self, *args: object) -> None: if SUPPORT_ADD_NOTES: - super().add_note(( - "Please check that the input is correct, or you can resolve this" - " issue by filing an issue"),) super().add_note( - "Project URL: https://github.com/acheong08/ChatGPT") + ( + "Please check that the input is correct, or you can resolve this" + " issue by filing an issue" + ), + ) + super().add_note("Project URL: https://github.com/acheong08/ChatGPT") super().__init__(*args) diff --git a/swarms/structs/document.py b/swarms/structs/document.py index 505df6ae..b87d3d91 100644 --- a/swarms/structs/document.py +++ b/swarms/structs/document.py @@ -63,8 +63,9 @@ class BaseDocumentTransformer(ABC): """ # noqa: E501 @abstractmethod - def transform_documents(self, documents: Sequence[Document], - **kwargs: Any) -> Sequence[Document]: + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: """Transform a list of documents. Args: @@ -74,8 +75,9 @@ class BaseDocumentTransformer(ABC): A list of transformed Documents. """ - async def atransform_documents(self, documents: Sequence[Document], - **kwargs: Any) -> Sequence[Document]: + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: """Asynchronously transform a list of documents. Args: @@ -85,4 +87,5 @@ class BaseDocumentTransformer(ABC): A list of transformed Documents. """ return await asyncio.get_running_loop().run_in_executor( - None, partial(self.transform_documents, **kwargs), documents) + None, partial(self.transform_documents, **kwargs), documents + ) diff --git a/swarms/structs/flow.py b/swarms/structs/flow.py index a3633a2c..8d89fd89 100644 --- a/swarms/structs/flow.py +++ b/swarms/structs/flow.py @@ -100,7 +100,7 @@ class Flow: self, llm: Any, # template: str, - max_loops = 5, + max_loops=5, stopping_condition: Optional[Callable[[str], bool]] = None, loop_interval: int = 1, retry_attempts: int = 3, @@ -188,7 +188,8 @@ class Flow: value = self.llm.__dict__.get(name, "Unknown") params_str_list.append( - f" {name.capitalize().replace('_', ' ')}: {value}") + f" {name.capitalize().replace('_', ' ')}: {value}" + ) return "\n".join(params_str_list) @@ -196,7 +197,7 @@ class Flow: """ Take the history and truncate it to fit into the model context length """ - truncated_history = self.memory[-1][-self.context_length:] + truncated_history = self.memory[-1][-self.context_length :] self.memory[-1] = truncated_history def add_task_to_memory(self, task: str): @@ -246,7 +247,8 @@ class Flow: ---------------------------------------- """, "green", - )) + ) + ) # print(dashboard) @@ -256,17 +258,18 @@ class Flow: print(colored("Initializing Autonomous Agent...", "yellow")) # print(colored("Loading modules...", "yellow")) # print(colored("Modules loaded successfully.", "green")) - print(colored("Autonomous Agent Activated.", "cyan", - attrs=["bold"])) - print(colored("All systems operational. Executing task...", - "green")) + print(colored("Autonomous Agent Activated.", "cyan", attrs=["bold"])) + print(colored("All systems operational. Executing task...", "green")) except Exception as error: print( colored( - ("Error activating autonomous agent. Try optimizing your" - " parameters..."), + ( + "Error activating autonomous agent. Try optimizing your" + " parameters..." + ), "red", - )) + ) + ) print(error) def run(self, task: str, **kwargs): @@ -296,7 +299,7 @@ class Flow: loop_count = 0 # for i in range(self.max_loops): - while self.max_loops == 'auto' or loop_count < self.max_loops: + while self.max_loops == "auto" or loop_count < self.max_loops: loop_count += 1 print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) print("\n") @@ -315,8 +318,7 @@ class Flow: while attempt < self.retry_attempts: try: response = self.llm( - task - **kwargs, + task**kwargs, ) if self.interactive: print(f"AI: {response}") @@ -344,7 +346,7 @@ class Flow: if self.return_history: return response, history - return response + return response async def arun(self, task: str, **kwargs): """ @@ -373,7 +375,7 @@ class Flow: loop_count = 0 # for i in range(self.max_loops): - while self.max_loops == 'auto' or loop_count < self.max_loops: + while self.max_loops == "auto" or loop_count < self.max_loops: loop_count += 1 print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) print("\n") @@ -392,8 +394,7 @@ class Flow: while attempt < self.retry_attempts: try: response = self.llm( - task - **kwargs, + task**kwargs, ) if self.interactive: print(f"AI: {response}") @@ -421,7 +422,7 @@ class Flow: if self.return_history: return response, history - return response + return response def _run(self, **kwargs: Any) -> str: """Generate a result using the provided keyword args.""" @@ -460,9 +461,7 @@ class Flow: Args: tasks (List[str]): A list of tasks to run. """ - task_coroutines = [ - self.run_async(task, **kwargs) for task in tasks - ] + task_coroutines = [self.run_async(task, **kwargs) for task in tasks] completed_tasks = await asyncio.gather(*task_coroutines) return completed_tasks @@ -575,9 +574,7 @@ class Flow: import boto3 s3 = boto3.client("s3") - s3.put_object(Bucket=bucket_name, - Key=object_name, - Body=json.dumps(self.memory)) + s3.put_object(Bucket=bucket_name, Key=object_name, Body=json.dumps(self.memory)) print(f"Backed up memory to S3: {bucket_name}/{object_name}") def analyze_feedback(self): @@ -681,7 +678,7 @@ class Flow: def get_llm_params(self): """ Extracts and returns the parameters of the llm object for serialization. - It assumes that the llm object has an __init__ method + It assumes that the llm object has an __init__ method with parameters that can be used to recreate it. """ if not hasattr(self.llm, "__init__"): @@ -697,8 +694,8 @@ class Flow: if hasattr(self.llm, name): value = getattr(self.llm, name) if isinstance( - value, - (str, int, float, bool, list, dict, tuple, type(None))): + value, (str, int, float, bool, list, dict, tuple, type(None)) + ): llm_params[name] = value else: llm_params[name] = str( @@ -758,10 +755,7 @@ class Flow: print(f"Flow state loaded from {file_path}") - def retry_on_failure(self, - function, - retries: int = 3, - retry_delay: int = 1): + def retry_on_failure(self, function, retries: int = 3, retry_delay: int = 1): """Retry wrapper for LLM calls.""" attempt = 0 while attempt < retries: diff --git a/swarms/structs/nonlinear_workflow.py b/swarms/structs/nonlinear_workflow.py index 140c0d7b..2357f614 100644 --- a/swarms/structs/nonlinear_workflow.py +++ b/swarms/structs/nonlinear_workflow.py @@ -8,10 +8,9 @@ class Task: Task is a unit of work that can be executed by an agent """ - def __init__(self, - id: str, - parents: List["Task"] = None, - children: List["Task"] = None): + def __init__( + self, id: str, parents: List["Task"] = None, children: List["Task"] = None + ): self.id = id self.parents = parents self.children = children @@ -80,8 +79,7 @@ class NonLinearWorkflow: for task in ordered_tasks: if task.can_execute: - future = self.executor.submit(self.agents.run, - task.task_string) + future = self.executor.submit(self.agents.run, task.task_string) futures_list[future] = task for future in as_completed(futures_list): @@ -97,8 +95,7 @@ class NonLinearWorkflow: def to_graph(self) -> Dict[str, set[str]]: """Convert the workflow to a graph""" graph = { - task.id: set(child.id for child in task.children) - for task in self.tasks + task.id: set(child.id for child in task.children) for task in self.tasks } return graph diff --git a/swarms/structs/sequential_workflow.py b/swarms/structs/sequential_workflow.py index 8dd5abbd..8c7d9760 100644 --- a/swarms/structs/sequential_workflow.py +++ b/swarms/structs/sequential_workflow.py @@ -61,12 +61,13 @@ class Task: if isinstance(self.flow, Flow): # Add a prompt to notify the Flow of the sequential workflow if "prompt" in self.kwargs: - self.kwargs["prompt"] += (f"\n\nPrevious output: {self.result}" - if self.result else "") + self.kwargs["prompt"] += ( + f"\n\nPrevious output: {self.result}" if self.result else "" + ) else: self.kwargs["prompt"] = f"Main task: {self.description}" + ( - f"\n\nPrevious output: {self.result}" - if self.result else "") + f"\n\nPrevious output: {self.result}" if self.result else "" + ) self.result = self.flow.run(*self.args, **self.kwargs) else: self.result = self.flow(*self.args, **self.kwargs) @@ -110,8 +111,7 @@ class SequentialWorkflow: restore_state_filepath: Optional[str] = None dashboard: bool = False - def add(self, task: str, flow: Union[Callable, Flow], *args, - **kwargs) -> None: + def add(self, task: str, flow: Union[Callable, Flow], *args, **kwargs) -> None: """ Add a task to the workflow. @@ -127,7 +127,8 @@ class SequentialWorkflow: # Append the task to the tasks list self.tasks.append( - Task(description=task, flow=flow, args=list(args), kwargs=kwargs)) + Task(description=task, flow=flow, args=list(args), kwargs=kwargs) + ) def reset_workflow(self) -> None: """Resets the workflow by clearing the results of each task.""" @@ -179,9 +180,8 @@ class SequentialWorkflow: raise ValueError(f"Task {task_description} not found in workflow.") def save_workflow_state( - self, - filepath: Optional[str] = "sequential_workflow_state.json", - **kwargs) -> None: + self, filepath: Optional[str] = "sequential_workflow_state.json", **kwargs + ) -> None: """ Saves the workflow state to a json file. @@ -202,13 +202,16 @@ class SequentialWorkflow: with open(filepath, "w") as f: # Saving the state as a json for simplicuty state = { - "tasks": [{ - "description": task.description, - "args": task.args, - "kwargs": task.kwargs, - "result": task.result, - "history": task.history, - } for task in self.tasks], + "tasks": [ + { + "description": task.description, + "args": task.args, + "kwargs": task.kwargs, + "result": task.result, + "history": task.history, + } + for task in self.tasks + ], "max_loops": self.max_loops, } json.dump(state, f, indent=4) @@ -220,7 +223,8 @@ class SequentialWorkflow: Sequential Workflow Initializing...""", "green", attrs=["bold", "underline"], - )) + ) + ) def workflow_dashboard(self, **kwargs) -> None: """ @@ -259,7 +263,8 @@ class SequentialWorkflow: """, "cyan", attrs=["bold", "underline"], - )) + ) + ) def workflow_shutdown(self, **kwargs) -> None: print( @@ -268,7 +273,8 @@ class SequentialWorkflow: Sequential Workflow Shutdown...""", "red", attrs=["bold", "underline"], - )) + ) + ) def add_objective_to_workflow(self, task: str, **kwargs) -> None: print( @@ -277,7 +283,8 @@ class SequentialWorkflow: Adding Objective to Workflow...""", "green", attrs=["bold", "underline"], - )) + ) + ) task = Task( description=task, @@ -342,12 +349,13 @@ class SequentialWorkflow: if "task" not in task.kwargs: raise ValueError( "The 'task' argument is required for the Flow flow" - f" execution in '{task.description}'") + f" execution in '{task.description}'" + ) # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") - task.result = task.flow.run(flow_task_arg, - *task.args, - **task.kwargs) + task.result = task.flow.run( + flow_task_arg, *task.args, **task.kwargs + ) else: # If it's not a Flow instance, call the flow directly task.result = task.flow(*task.args, **task.kwargs) @@ -365,17 +373,19 @@ class SequentialWorkflow: # Autosave the workflow state if self.autosave: - self.save_workflow_state( - "sequential_workflow_state.json") + self.save_workflow_state("sequential_workflow_state.json") except Exception as e: print( colored( - (f"Error initializing the Sequential workflow: {e} try" - " optimizing your inputs like the flow class and task" - " description"), + ( + f"Error initializing the Sequential workflow: {e} try" + " optimizing your inputs like the flow class and task" + " description" + ), "red", attrs=["bold", "underline"], - )) + ) + ) async def arun(self) -> None: """ @@ -395,11 +405,13 @@ class SequentialWorkflow: if "task" not in task.kwargs: raise ValueError( "The 'task' argument is required for the Flow flow" - f" execution in '{task.description}'") + f" execution in '{task.description}'" + ) # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") task.result = await task.flow.arun( - flow_task_arg, *task.args, **task.kwargs) + flow_task_arg, *task.args, **task.kwargs + ) else: # If it's not a Flow instance, call the flow directly task.result = await task.flow(*task.args, **task.kwargs) @@ -417,5 +429,4 @@ class SequentialWorkflow: # Autosave the workflow state if self.autosave: - self.save_workflow_state( - "sequential_workflow_state.json") + self.save_workflow_state("sequential_workflow_state.json") diff --git a/swarms/structs/task.py b/swarms/structs/task.py index 6824bf0e..80f95d4d 100644 --- a/swarms/structs/task.py +++ b/swarms/structs/task.py @@ -13,7 +13,6 @@ from swarms.artifacts.error_artifact import ErrorArtifact class BaseTask(ABC): - class State(Enum): PENDING = 1 EXECUTING = 2 @@ -34,15 +33,11 @@ class BaseTask(ABC): @property def parents(self) -> List[BaseTask]: - return [ - self.structure.find_task(parent_id) for parent_id in self.parent_ids - ] + return [self.structure.find_task(parent_id) for parent_id in self.parent_ids] @property def children(self) -> List[BaseTask]: - return [ - self.structure.find_task(child_id) for child_id in self.child_ids - ] + return [self.structure.find_task(child_id) for child_id in self.child_ids] def __rshift__(self, child: BaseTask) -> BaseTask: return self.add_child(child) @@ -123,7 +118,8 @@ class BaseTask(ABC): def can_execute(self) -> bool: return self.state == self.State.PENDING and all( - parent.is_finished() for parent in self.parents) + parent.is_finished() for parent in self.parents + ) def reset(self) -> BaseTask: self.state = self.State.PENDING @@ -136,10 +132,10 @@ class BaseTask(ABC): class Task(BaseModel): - input: Optional[StrictStr] = Field(None, - description="Input prompt for the task") + input: Optional[StrictStr] = Field(None, description="Input prompt for the task") additional_input: Optional[Any] = Field( - None, description="Input parameters for the task. Any value is allowed") + None, description="Input parameters for the task. Any value is allowed" + ) task_id: StrictStr = Field(..., description="ID of the task") class Config: diff --git a/swarms/structs/workflow.py b/swarms/structs/workflow.py index e4a841ed..762ee6cc 100644 --- a/swarms/structs/workflow.py +++ b/swarms/structs/workflow.py @@ -65,13 +65,11 @@ class Workflow: def context(self, task: Task) -> Dict[str, Any]: """Context in tasks""" return { - "parent_output": - task.parents[0].output - if task.parents and task.parents[0].output else None, - "parent": - task.parents[0] if task.parents else None, - "child": - task.children[0] if task.children else None, + "parent_output": task.parents[0].output + if task.parents and task.parents[0].output + else None, + "parent": task.parents[0] if task.parents else None, + "child": task.children[0] if task.children else None, } def __run_from_task(self, task: Optional[Task]) -> None: diff --git a/swarms/swarms/autoscaler.py b/swarms/swarms/autoscaler.py index d0aaa598..5f6bedde 100644 --- a/swarms/swarms/autoscaler.py +++ b/swarms/swarms/autoscaler.py @@ -87,8 +87,7 @@ class AutoScaler: while True: sleep(60) # check minute pending_tasks = self.task_queue.qsize() - active_agents = sum( - [1 for agent in self.agents_pool if agent.is_busy()]) + active_agents = sum([1 for agent in self.agents_pool if agent.is_busy()]) if pending_tasks / len(self.agents_pool) > self.busy_threshold: self.scale_up() diff --git a/swarms/swarms/base.py b/swarms/swarms/base.py index 6d8e0163..e99c9b38 100644 --- a/swarms/swarms/base.py +++ b/swarms/swarms/base.py @@ -117,9 +117,7 @@ class AbstractSwarm(ABC): pass @abstractmethod - def broadcast(self, - message: str, - sender: Optional["AbstractWorker"] = None): + def broadcast(self, message: str, sender: Optional["AbstractWorker"] = None): """Broadcast a message to all workers""" pass diff --git a/swarms/swarms/battle_royal.py b/swarms/swarms/battle_royal.py index 7b5c2a99..2a02186e 100644 --- a/swarms/swarms/battle_royal.py +++ b/swarms/swarms/battle_royal.py @@ -77,15 +77,19 @@ class BattleRoyalSwarm: # Check for clashes and handle them for i, worker1 in enumerate(self.workers): for j, worker2 in enumerate(self.workers): - if (i != j and worker1.is_within_proximity(worker2) and - set(worker1.teams) != set(worker2.teams)): + if ( + i != j + and worker1.is_within_proximity(worker2) + and set(worker1.teams) != set(worker2.teams) + ): winner, loser = self.clash(worker1, worker2, question) print(f"Worker {winner.id} won over Worker {loser.id}") def communicate(self, sender: Worker, reciever: Worker, message: str): """Communicate a message from one worker to another.""" if sender.is_within_proximity(reciever) or any( - team in sender.teams for team in reciever.teams): + team in sender.teams for team in reciever.teams + ): pass def clash(self, worker1: Worker, worker2: Worker, question: str): diff --git a/swarms/swarms/god_mode.py b/swarms/swarms/god_mode.py index 7f302318..fe842f0a 100644 --- a/swarms/swarms/god_mode.py +++ b/swarms/swarms/god_mode.py @@ -49,8 +49,9 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), - "cyan")) + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + ) + ) def run_all(self, task): """Run the task on all LLMs""" @@ -73,15 +74,18 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), - "cyan")) + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + ) + ) # New Features def save_responses_to_file(self, filename): """Save responses to file""" with open(filename, "w") as file: - table = [[f"LLM {i+1}", response] - for i, response in enumerate(self.last_responses)] + table = [ + [f"LLM {i+1}", response] + for i, response in enumerate(self.last_responses) + ] file.write(tabulate(table, headers=["LLM", "Response"])) @classmethod @@ -101,9 +105,11 @@ class GodMode: for i, task in enumerate(self.task_history): print(f"{i + 1}. {task}") print("\nLast Responses:") - table = [[f"LLM {i+1}", response] - for i, response in enumerate(self.last_responses)] + table = [ + [f"LLM {i+1}", response] for i, response in enumerate(self.last_responses) + ] print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), - "cyan")) + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + ) + ) diff --git a/swarms/swarms/groupchat.py b/swarms/swarms/groupchat.py index 842ebac9..6be43a89 100644 --- a/swarms/swarms/groupchat.py +++ b/swarms/swarms/groupchat.py @@ -33,8 +33,7 @@ class GroupChat: def next_agent(self, agent: Flow) -> Flow: """Return the next agent in the list.""" - return self.agents[(self.agent_names.index(agent.name) + 1) % - len(self.agents)] + return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)] def select_speaker_msg(self): """Return the message for selecting the next speaker.""" @@ -55,17 +54,24 @@ class GroupChat: if n_agents < 3: logger.warning( f"GroupChat is underpopulated with {n_agents} agents. Direct" - " communication would be more efficient.") + " communication would be more efficient." + ) name = selector.generate_reply( - self.format_history(self.messages + [{ - "role": - "system", - "content": - ("Read the above conversation. Then select the next most" - f" suitable role from {self.agent_names} to play. Only" - " return the role."), - }])) + self.format_history( + self.messages + + [ + { + "role": "system", + "content": ( + "Read the above conversation. Then select the next most" + f" suitable role from {self.agent_names} to play. Only" + " return the role." + ), + } + ] + ) + ) try: return self.agent_by_name(name["content"]) except ValueError: @@ -73,7 +79,8 @@ class GroupChat: def _participant_roles(self): return "\n".join( - [f"{agent.name}: {agent.system_message}" for agent in self.agents]) + [f"{agent.name}: {agent.system_message}" for agent in self.agents] + ) def format_history(self, messages: List[Dict]) -> str: formatted_messages = [] @@ -84,21 +91,19 @@ class GroupChat: class GroupChatManager: - def __init__(self, groupchat: GroupChat, selector: Flow): self.groupchat = groupchat self.selector = selector def __call__(self, task: str): - self.groupchat.messages.append({ - "role": self.selector.name, - "content": task - }) + self.groupchat.messages.append({"role": self.selector.name, "content": task}) for i in range(self.groupchat.max_round): - speaker = self.groupchat.select_speaker(last_speaker=self.selector, - selector=self.selector) + speaker = self.groupchat.select_speaker( + last_speaker=self.selector, selector=self.selector + ) reply = speaker.generate_reply( - self.groupchat.format_history(self.groupchat.messages)) + self.groupchat.format_history(self.groupchat.messages) + ) self.groupchat.messages.append(reply) print(reply) if i == self.groupchat.max_round - 1: diff --git a/swarms/swarms/multi_agent_collab.py b/swarms/swarms/multi_agent_collab.py index a3b79d7f..9a5f27bc 100644 --- a/swarms/swarms/multi_agent_collab.py +++ b/swarms/swarms/multi_agent_collab.py @@ -5,16 +5,16 @@ from langchain.output_parsers import RegexParser # utils class BidOutputParser(RegexParser): - def get_format_instructions(self) -> str: return ( "Your response should be an integrater delimited by angled brackets like" - " this: ") + " this: " + ) -bid_parser = BidOutputParser(regex=r"<(\d+)>", - output_keys=["bid"], - default_output_key="bid") +bid_parser = BidOutputParser( + regex=r"<(\d+)>", output_keys=["bid"], default_output_key="bid" +) def select_next_speaker(step: int, agents, director) -> int: @@ -29,7 +29,6 @@ def select_next_speaker(step: int, agents, director) -> int: # main class MultiAgentCollaboration: - def __init__( self, agents, diff --git a/swarms/swarms/multi_agent_debate.py b/swarms/swarms/multi_agent_debate.py index 1c7ebdf9..4bba3619 100644 --- a/swarms/swarms/multi_agent_debate.py +++ b/swarms/swarms/multi_agent_debate.py @@ -46,6 +46,7 @@ class MultiAgentDebate: def format_results(self, results): formatted_results = "\n".join( - [f"Agent responded: {result['response']}" for result in results]) + [f"Agent responded: {result['response']}" for result in results] + ) return formatted_results diff --git a/swarms/swarms/orchestrate.py b/swarms/swarms/orchestrate.py index d47771ab..f522911b 100644 --- a/swarms/swarms/orchestrate.py +++ b/swarms/swarms/orchestrate.py @@ -111,8 +111,7 @@ class Orchestrator: self.chroma_client = chromadb.Client() - self.collection = self.chroma_client.create_collection( - name=collection_name) + self.collection = self.chroma_client.create_collection(name=collection_name) self.current_tasks = {} @@ -138,8 +137,9 @@ class Orchestrator: result = self.worker.run(task["content"]) # using the embed method to get the vector representation of the result - vector_representation = self.embed(result, self.api_key, - self.model_name) + vector_representation = self.embed( + result, self.api_key, self.model_name + ) self.collection.add( embeddings=[vector_representation], @@ -154,7 +154,8 @@ class Orchestrator: except Exception as error: logging.error( f"Failed to process task {id(task)} by agent {id(agent)}. Error:" - f" {error}") + f" {error}" + ) finally: with self.condition: self.agents.put(agent) @@ -162,7 +163,8 @@ class Orchestrator: def embed(self, input, api_key, model_name): openai = embedding_functions.OpenAIEmbeddingFunction( - api_key=api_key, model_name=model_name) + api_key=api_key, model_name=model_name + ) embedding = openai(input) return embedding @@ -173,13 +175,13 @@ class Orchestrator: try: # Query the vector database for documents created by the agents - results = self.collection.query(query_texts=[str(agent_id)], - n_results=10) + results = self.collection.query(query_texts=[str(agent_id)], n_results=10) return results except Exception as e: logging.error( - f"Failed to retrieve results from agent {agent_id}. Error {e}") + f"Failed to retrieve results from agent {agent_id}. Error {e}" + ) raise # @abstractmethod @@ -210,8 +212,7 @@ class Orchestrator: self.collection.add(documents=[result], ids=[str(id(result))]) except Exception as e: - logging.error( - f"Failed to append the agent output to database. Error: {e}") + logging.error(f"Failed to append the agent output to database. Error: {e}") raise def run(self, objective: str): @@ -224,8 +225,8 @@ class Orchestrator: self.task_queue.append(objective) results = [ - self.assign_task(agent_id, task) for agent_id, task in zip( - range(len(self.agents)), self.task_queue) + self.assign_task(agent_id, task) + for agent_id, task in zip(range(len(self.agents)), self.task_queue) ] for result in results: diff --git a/swarms/swarms/simple_swarm.py b/swarms/swarms/simple_swarm.py index a382c0d7..7e806215 100644 --- a/swarms/swarms/simple_swarm.py +++ b/swarms/swarms/simple_swarm.py @@ -2,7 +2,6 @@ from queue import Queue, PriorityQueue class SimpleSwarm: - def __init__( self, llm, diff --git a/swarms/tools/autogpt.py b/swarms/tools/autogpt.py index 270504aa..cf5450e6 100644 --- a/swarms/tools/autogpt.py +++ b/swarms/tools/autogpt.py @@ -8,7 +8,8 @@ import torch from langchain.agents import tool from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent from langchain.chains.qa_with_sources.loading import ( - BaseCombineDocumentsChain,) + BaseCombineDocumentsChain, +) from langchain.docstore.document import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.tools import BaseTool @@ -36,10 +37,9 @@ def pushd(new_dir): @tool -def process_csv(llm, - csv_file_path: str, - instructions: str, - output_path: Optional[str] = None) -> str: +def process_csv( + llm, csv_file_path: str, instructions: str, output_path: Optional[str] = None +) -> str: """Process a CSV by with pandas in a limited REPL.\ Only use this after writing data to disk as a csv file.\ Any figures must be saved to disk to be viewed by the human.\ @@ -49,10 +49,7 @@ def process_csv(llm, df = pd.read_csv(csv_file_path) except Exception as e: return f"Error: {e}" - agent = create_pandas_dataframe_agent(llm, - df, - max_iterations=30, - verbose=False) + agent = create_pandas_dataframe_agent(llm, df, max_iterations=30, verbose=False) if output_path is not None: instructions += f" Save output to disk at {output_path}" try: @@ -82,8 +79,7 @@ async def async_load_playwright(url: str) -> str: text = soup.get_text() lines = (line.strip() for line in text.splitlines()) - chunks = ( - phrase.strip() for line in lines for phrase in line.split(" ")) + chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) results = "\n".join(chunk for chunk in chunks if chunk) except Exception as e: results = f"Error: {e}" @@ -117,7 +113,8 @@ class WebpageQATool(BaseTool): "Browse a webpage and retrieve the information relevant to the question." ) text_splitter: RecursiveCharacterTextSplitter = Field( - default_factory=_get_text_splitter) + default_factory=_get_text_splitter + ) qa_chain: BaseCombineDocumentsChain def _run(self, url: str, question: str) -> str: @@ -128,12 +125,9 @@ class WebpageQATool(BaseTool): results = [] # TODO: Handle this with a MapReduceChain for i in range(0, len(web_docs), 4): - input_docs = web_docs[i:i + 4] + input_docs = web_docs[i : i + 4] window_result = self.qa_chain( - { - "input_documents": input_docs, - "question": question - }, + {"input_documents": input_docs, "question": question}, return_only_outputs=True, ) results.append(f"Response from window {i} - {window_result}") @@ -141,10 +135,7 @@ class WebpageQATool(BaseTool): Document(page_content="\n".join(results), metadata={"source": url}) ] return self.qa_chain( - { - "input_documents": results_docs, - "question": question - }, + {"input_documents": results_docs, "question": question}, return_only_outputs=True, ) @@ -180,17 +171,18 @@ def VQAinference(self, inputs): torch_dtype = torch.float16 if "cuda" in device else torch.float32 processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") model = BlipForQuestionAnswering.from_pretrained( - "Salesforce/blip-vqa-base", torch_dtype=torch_dtype).to(device) + "Salesforce/blip-vqa-base", torch_dtype=torch_dtype + ).to(device) image_path, question = inputs.split(",") raw_image = Image.open(image_path).convert("RGB") - inputs = processor(raw_image, question, - return_tensors="pt").to(device, torch_dtype) + inputs = processor(raw_image, question, return_tensors="pt").to(device, torch_dtype) out = model.generate(**inputs) answer = processor.decode(out[0], skip_special_tokens=True) logger.debug( f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" - f" Question: {question}, Output Answer: {answer}") + f" Question: {question}, Output Answer: {answer}" + ) return answer diff --git a/swarms/tools/mm_models.py b/swarms/tools/mm_models.py index fd115bd6..58fe11e5 100644 --- a/swarms/tools/mm_models.py +++ b/swarms/tools/mm_models.py @@ -25,14 +25,13 @@ from swarms.utils.main import BaseHandler, get_new_image_name class MaskFormer: - def __init__(self, device): print("Initializing MaskFormer to %s" % device) self.device = device - self.processor = CLIPSegProcessor.from_pretrained( - "CIDAS/clipseg-rd64-refined") + self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") self.model = CLIPSegForImageSegmentation.from_pretrained( - "CIDAS/clipseg-rd64-refined").to(device) + "CIDAS/clipseg-rd64-refined" + ).to(device) def inference(self, image_path, text): threshold = 0.5 @@ -40,10 +39,9 @@ class MaskFormer: padding = 20 original_image = Image.open(image_path) image = original_image.resize((512, 512)) - inputs = self.processor(text=text, - images=image, - padding="max_length", - return_tensors="pt").to(self.device) + inputs = self.processor( + text=text, images=image, padding="max_length", return_tensors="pt" + ).to(self.device) with torch.no_grad(): outputs = self.model(**inputs) mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold @@ -54,7 +52,8 @@ class MaskFormer: mask_array = np.zeros_like(mask, dtype=bool) for idx in true_indices: padded_slice = tuple( - slice(max(0, i - padding), i + padding + 1) for i in idx) + slice(max(0, i - padding), i + padding + 1) for i in idx + ) mask_array[padded_slice] = True visual_mask = (mask_array * 255).astype(np.uint8) image_mask = Image.fromarray(visual_mask) @@ -62,7 +61,6 @@ class MaskFormer: class ImageEditing: - def __init__(self, device): print("Initializing ImageEditing to %s" % device) self.device = device @@ -77,24 +75,25 @@ class ImageEditing: @tool( name="Remove Something From The Photo", - description= - ("useful when you want to remove and object or something from the photo " - "from its description or location. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the object need to be removed. "), + description=( + "useful when you want to remove and object or something from the photo " + "from its description or location. " + "The input to this tool should be a comma separated string of two, " + "representing the image_path and the object need to be removed. " + ), ) def inference_remove(self, inputs): image_path, to_be_removed_txt = inputs.split(",") - return self.inference_replace( - f"{image_path},{to_be_removed_txt},background") + return self.inference_replace(f"{image_path},{to_be_removed_txt},background") @tool( name="Replace Something From The Photo", - description= - ("useful when you want to replace an object from the object description or" - " location with another object from its description. The input to this tool" - " should be a comma separated string of three, representing the image_path," - " the object to be replaced, the object to be replaced with "), + description=( + "useful when you want to replace an object from the object description or" + " location with another object from its description. The input to this tool" + " should be a comma separated string of three, representing the image_path," + " the object to be replaced, the object to be replaced with " + ), ) def inference_replace(self, inputs): image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",") @@ -106,21 +105,22 @@ class ImageEditing: image=original_image.resize((512, 512)), mask_image=mask_image.resize((512, 512)), ).images[0] - updated_image_path = get_new_image_name(image_path, - func_name="replace-something") + updated_image_path = get_new_image_name( + image_path, func_name="replace-something" + ) updated_image = updated_image.resize(original_size) updated_image.save(updated_image_path) logger.debug( f"\nProcessed ImageEditing, Input Image: {image_path}, Replace" f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:" - f" {updated_image_path}") + f" {updated_image_path}" + ) return updated_image_path class InstructPix2Pix: - def __init__(self, device): print("Initializing InstructPix2Pix to %s" % device) self.device = device @@ -131,56 +131,60 @@ class InstructPix2Pix: torch_dtype=self.torch_dtype, ).to(device) self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( - self.pipe.scheduler.config) + self.pipe.scheduler.config + ) @tool( name="Instruct Image Using Text", - description= - ("useful when you want to the style of the image to be like the text. " - "like: make it look like a painting. or make it like a robot. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the text. "), + description=( + "useful when you want to the style of the image to be like the text. " + "like: make it look like a painting. or make it like a robot. " + "The input to this tool should be a comma separated string of two, " + "representing the image_path and the text. " + ), ) def inference(self, inputs): """Change style of image.""" logger.debug("===> Starting InstructPix2Pix Inference") image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) original_image = Image.open(image_path) - image = self.pipe(text, - image=original_image, - num_inference_steps=40, - image_guidance_scale=1.2).images[0] + image = self.pipe( + text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2 + ).images[0] updated_image_path = get_new_image_name(image_path, func_name="pix2pix") image.save(updated_image_path) logger.debug( f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:" - f" {text}, Output Image: {updated_image_path}") + f" {text}, Output Image: {updated_image_path}" + ) return updated_image_path class Text2Image: - def __init__(self, device): print("Initializing Text2Image to %s" % device) self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.pipe = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype) + "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype + ) self.pipe.to(device) self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality") + "fewer digits, cropped, worst quality, low quality" + ) @tool( name="Generate Image From User Input Text", - description= - ("useful when you want to generate an image from a user input text and save" - " it to a file. like: generate an image of an object or something, or" - " generate an image that includes some objects. The input to this tool" - " should be a string, representing the text used to generate image. "), + description=( + "useful when you want to generate an image from a user input text and save" + " it to a file. like: generate an image of an object or something, or" + " generate an image that includes some objects. The input to this tool" + " should be a string, representing the text used to generate image. " + ), ) def inference(self, text): image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png") @@ -190,59 +194,59 @@ class Text2Image: logger.debug( f"\nProcessed Text2Image, Input Text: {text}, Output Image:" - f" {image_filename}") + f" {image_filename}" + ) return image_filename class VisualQuestionAnswering: - def __init__(self, device): print("Initializing VisualQuestionAnswering to %s" % device) self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.device = device - self.processor = BlipProcessor.from_pretrained( - "Salesforce/blip-vqa-base") + self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") self.model = BlipForQuestionAnswering.from_pretrained( - "Salesforce/blip-vqa-base", - torch_dtype=self.torch_dtype).to(self.device) + "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype + ).to(self.device) @tool( name="Answer Question About The Image", - description= - ("useful when you need an answer for a question based on an image. like:" - " what is the background color of the last image, how many cats in this" - " figure, what is in this figure. The input to this tool should be a comma" - " separated string of two, representing the image_path and the question" + description=( + "useful when you need an answer for a question based on an image. like:" + " what is the background color of the last image, how many cats in this" + " figure, what is in this figure. The input to this tool should be a comma" + " separated string of two, representing the image_path and the question" ), ) def inference(self, inputs): image_path, question = inputs.split(",") raw_image = Image.open(image_path).convert("RGB") - inputs = self.processor(raw_image, question, - return_tensors="pt").to(self.device, - self.torch_dtype) + inputs = self.processor(raw_image, question, return_tensors="pt").to( + self.device, self.torch_dtype + ) out = self.model.generate(**inputs) answer = self.processor.decode(out[0], skip_special_tokens=True) logger.debug( f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" - f" Question: {question}, Output Answer: {answer}") + f" Question: {question}, Output Answer: {answer}" + ) return answer class ImageCaptioning(BaseHandler): - def __init__(self, device): print("Initializing ImageCaptioning to %s" % device) self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.processor = BlipProcessor.from_pretrained( - "Salesforce/blip-image-captioning-base") + "Salesforce/blip-image-captioning-base" + ) self.model = BlipForConditionalGeneration.from_pretrained( - "Salesforce/blip-image-captioning-base", - torch_dtype=self.torch_dtype).to(self.device) + "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype + ).to(self.device) def handle(self, filename: str): img = Image.open(filename) @@ -254,13 +258,14 @@ class ImageCaptioning(BaseHandler): img.save(filename, "PNG") print(f"Resize image form {width}x{height} to {width_new}x{height_new}") - inputs = self.processor(Image.open(filename), - return_tensors="pt").to(self.device, - self.torch_dtype) + inputs = self.processor(Image.open(filename), return_tensors="pt").to( + self.device, self.torch_dtype + ) out = self.model.generate(**inputs) description = self.processor.decode(out[0], skip_special_tokens=True) print( f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text:" - f" {description}") + f" {description}" + ) return IMAGE_PROMPT.format(filename=filename, description=description) diff --git a/swarms/tools/stt.py b/swarms/tools/stt.py index da9d7f27..cfe3e656 100644 --- a/swarms/tools/stt.py +++ b/swarms/tools/stt.py @@ -9,7 +9,6 @@ from pytube import YouTube class SpeechToText: - def __init__( self, video_url, @@ -62,15 +61,14 @@ class SpeechToText: compute_type = "float16" # 1. Transcribe with original Whisper (batched) 🗣️ - model = whisperx.load_model("large-v2", - device, - compute_type=compute_type) + model = whisperx.load_model("large-v2", device, compute_type=compute_type) audio = whisperx.load_audio(audio_file) result = model.transcribe(audio, batch_size=batch_size) # 2. Align Whisper output 🔍 model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=device) + language_code=result["language"], device=device + ) result = whisperx.align( result["segments"], model_a, @@ -82,7 +80,8 @@ class SpeechToText: # 3. Assign speaker labels 🏷️ diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=device) + use_auth_token=self.hf_api_key, device=device + ) diarize_model(audio_file) try: @@ -99,7 +98,8 @@ class SpeechToText: # 2. Align Whisper output 🔍 model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=self.device) + language_code=result["language"], device=self.device + ) result = whisperx.align( result["segments"], @@ -112,7 +112,8 @@ class SpeechToText: # 3. Assign speaker labels 🏷️ diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=self.device) + use_auth_token=self.hf_api_key, device=self.device + ) diarize_model(audio_file) diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py index 29b0f5de..f7e85204 100644 --- a/swarms/tools/tool.py +++ b/swarms/tools/tool.py @@ -34,8 +34,9 @@ class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" -def _create_subset_model(name: str, model: BaseModel, - field_names: list) -> Type[BaseModel]: +def _create_subset_model( + name: str, model: BaseModel, field_names: list +) -> Type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" fields = {} for field_name in field_names: @@ -51,11 +52,7 @@ def _get_filtered_args( """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return { - k: schema[k] - for k in valid_keys - if k not in ("run_manager", "callbacks") - } + return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} class _SchemaConfig: @@ -85,8 +82,9 @@ def create_schema_from_function( del inferred_model.__fields__["callbacks"] # Pydantic adds placeholder virtual fields we need to strip valid_properties = _get_filtered_args(inferred_model, func) - return _create_subset_model(f"{model_name}Schema", inferred_model, - list(valid_properties)) + return _create_subset_model( + f"{model_name}Schema", inferred_model, list(valid_properties) + ) class ToolException(Exception): @@ -127,7 +125,8 @@ class ChildTool(BaseTool): "Expected annotation of 'Type[BaseModel]'" f" but got '{args_schema_type}'.\n" "Expected class looks like:\n" - f"{typehint_mandate}") + f"{typehint_mandate}" + ) name: str """The unique name of the tool that clearly communicates its purpose.""" @@ -148,8 +147,7 @@ class ChildTool(BaseTool): callbacks: Callbacks = Field(default=None, exclude=True) """Callbacks to be called during tool execution.""" - callback_manager: Optional[BaseCallbackManager] = Field(default=None, - exclude=True) + callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """Deprecated. Please use callbacks instead.""" tags: Optional[List[str]] = None """Optional list of tags associated with the tool. Defaults to None @@ -164,8 +162,9 @@ class ChildTool(BaseTool): You can use these to eg identify a specific instance of a tool with its use case. """ - handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], - str]]] = False + handle_tool_error: Optional[ + Union[bool, str, Callable[[ToolException], str]] + ] = False """Handle the content of the ToolException thrown.""" class Config(Serializable.Config): @@ -245,9 +244,7 @@ class ChildTool(BaseTool): else: if input_args is not None: result = input_args.parse_obj(tool_input) - return { - k: v for k, v in result.dict().items() if k in tool_input - } + return {k: v for k, v in result.dict().items() if k in tool_input} return tool_input @root_validator() @@ -289,8 +286,7 @@ class ChildTool(BaseTool): *args, ) - def _to_args_and_kwargs(self, - tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: # For backwards compatibility, if run_input is a string, # pass as a positional argument. if isinstance(tool_input, str): @@ -329,10 +325,7 @@ class ChildTool(BaseTool): # TODO: maybe also pass through run_manager is _run supports kwargs new_arg_supported = signature(self._run).parameters.get("run_manager") run_manager = callback_manager.on_tool_start( - { - "name": self.name, - "description": self.description - }, + {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, name=run_name, @@ -342,7 +335,9 @@ class ChildTool(BaseTool): tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) observation = ( self._run(*tool_args, run_manager=run_manager, **tool_kwargs) - if new_arg_supported else self._run(*tool_args, **tool_kwargs)) + if new_arg_supported + else self._run(*tool_args, **tool_kwargs) + ) except ToolException as e: if not self.handle_tool_error: run_manager.on_tool_error(e) @@ -359,20 +354,19 @@ class ChildTool(BaseTool): else: raise ValueError( "Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}") - run_manager.on_tool_end(str(observation), - color="red", - name=self.name, - **kwargs) + f"or callable. Received: {self.handle_tool_error}" + ) + run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) return observation except (Exception, KeyboardInterrupt) as e: run_manager.on_tool_error(e) raise e else: - run_manager.on_tool_end(str(observation), - color=color, - name=self.name, - **kwargs) + run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) return observation async def arun( @@ -405,10 +399,7 @@ class ChildTool(BaseTool): ) new_arg_supported = signature(self._arun).parameters.get("run_manager") run_manager = await callback_manager.on_tool_start( - { - "name": self.name, - "description": self.description - }, + {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, name=run_name, @@ -417,10 +408,11 @@ class ChildTool(BaseTool): try: # We then call the tool on the tool input to get an observation tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) - observation = (await self._arun(*tool_args, - run_manager=run_manager, - **tool_kwargs) if new_arg_supported - else await self._arun(*tool_args, **tool_kwargs)) + observation = ( + await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) + if new_arg_supported + else await self._arun(*tool_args, **tool_kwargs) + ) except ToolException as e: if not self.handle_tool_error: await run_manager.on_tool_error(e) @@ -437,20 +429,19 @@ class ChildTool(BaseTool): else: raise ValueError( "Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}") - await run_manager.on_tool_end(str(observation), - color="red", - name=self.name, - **kwargs) + f"or callable. Received: {self.handle_tool_error}" + ) + await run_manager.on_tool_end( + str(observation), color="red", name=self.name, **kwargs + ) return observation except (Exception, KeyboardInterrupt) as e: await run_manager.on_tool_error(e) raise e else: - await run_manager.on_tool_end(str(observation), - color=color, - name=self.name, - **kwargs) + await run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) return observation def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: @@ -477,7 +468,8 @@ class Tool(BaseTool): if not self.coroutine: # If the tool does not implement async, fall back to default implementation return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs)) + None, partial(self.invoke, input, config, **kwargs) + ) return await super().ainvoke(input, config, **kwargs) @@ -492,8 +484,7 @@ class Tool(BaseTool): # assume it takes a single string input. return {"tool_input": {"type": "string"}} - def _to_args_and_kwargs(self, - tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: """Convert tool input to pydantic model.""" args, kwargs = super()._to_args_and_kwargs(tool_input) # For backwards compatibility. The tool must be run with a single input @@ -512,13 +503,16 @@ class Tool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature( - self.func).parameters.get("callbacks") - return (self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) if new_argument_supported else self.func(*args, **kwargs)) + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) raise NotImplementedError("Tool does not support sync") async def _arun( @@ -529,27 +523,31 @@ class Tool(BaseTool): ) -> Any: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature( - self.coroutine).parameters.get("callbacks") - return (await self.coroutine( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) if new_argument_supported else await self.coroutine( - *args, **kwargs)) + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) else: return await asyncio.get_running_loop().run_in_executor( - None, partial(self._run, run_manager=run_manager, **kwargs), - *args) + None, partial(self._run, run_manager=run_manager, **kwargs), *args + ) # TODO: this is for backwards compatibility, remove in future - def __init__(self, name: str, func: Optional[Callable], description: str, - **kwargs: Any) -> None: + def __init__( + self, name: str, func: Optional[Callable], description: str, **kwargs: Any + ) -> None: """Initialize tool.""" - super(Tool, self).__init__(name=name, - func=func, - description=description, - **kwargs) + super(Tool, self).__init__( + name=name, func=func, description=description, **kwargs + ) @classmethod def from_function( @@ -559,8 +557,9 @@ class Tool(BaseTool): description: str, return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, - coroutine: Optional[Callable[..., Awaitable[ - Any]]] = None, # This is last for compatibility, but should be after func + coroutine: Optional[ + Callable[..., Awaitable[Any]] + ] = None, # This is last for compatibility, but should be after func **kwargs: Any, ) -> Tool: """Initialize tool from a function.""" @@ -598,7 +597,8 @@ class StructuredTool(BaseTool): if not self.coroutine: # If the tool does not implement async, fall back to default implementation return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs)) + None, partial(self.invoke, input, config, **kwargs) + ) return await super().ainvoke(input, config, **kwargs) @@ -617,13 +617,16 @@ class StructuredTool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature( - self.func).parameters.get("callbacks") - return (self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) if new_argument_supported else self.func(*args, **kwargs)) + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) raise NotImplementedError("Tool does not support sync") async def _arun( @@ -634,14 +637,18 @@ class StructuredTool(BaseTool): ) -> str: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature( - self.coroutine).parameters.get("callbacks") - return (await self.coroutine( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) if new_argument_supported else await self.coroutine( - *args, **kwargs)) + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) return await asyncio.get_running_loop().run_in_executor( None, partial(self._run, run_manager=run_manager, **kwargs), @@ -698,7 +705,8 @@ class StructuredTool(BaseTool): description = description or source_function.__doc__ if description is None: raise ValueError( - "Function must have a docstring if description not provided.") + "Function must have a docstring if description not provided." + ) # Description example: # search_api(query: str) - Searches the API for the query. @@ -706,8 +714,7 @@ class StructuredTool(BaseTool): description = f"{name}{sig} - {description.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{name}Schema", - source_function) + _args_schema = create_schema_from_function(f"{name}Schema", source_function) return cls( name=name, func=func, @@ -755,7 +762,6 @@ def tool( """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: if isinstance(dec_func, Runnable): runnable = dec_func @@ -763,13 +769,14 @@ def tool( if runnable.input_schema.schema().get("type") != "object": raise ValueError("Runnable must have an object schema.") - async def ainvoke_wrapper(callbacks: Optional[Callbacks] = None, - **kwargs: Any) -> Any: - return await runnable.ainvoke(kwargs, - {"callbacks": callbacks}) + async def ainvoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: + return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) - def invoke_wrapper(callbacks: Optional[Callbacks] = None, - **kwargs: Any) -> Any: + def invoke_wrapper( + callbacks: Optional[Callbacks] = None, **kwargs: Any + ) -> Any: return runnable.invoke(kwargs, {"callbacks": callbacks}) coroutine = ainvoke_wrapper @@ -802,7 +809,8 @@ def tool( if func.__doc__ is None: raise ValueError( "Function must have a docstring if " - "description not provided and infer_schema is False.") + "description not provided and infer_schema is False." + ) return Tool( name=tool_name, func=func, @@ -813,8 +821,7 @@ def tool( return _make_tool - if len(args) == 2 and isinstance(args[0], str) and isinstance( - args[1], Runnable): + if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): return _make_with_name(args[0])(args[1]) elif len(args) == 1 and isinstance(args[0], str): # if the argument is a string, then we use the string as the tool name diff --git a/swarms/tools/tool_registry.py b/swarms/tools/tool_registry.py index 3354646a..5aa544e9 100644 --- a/swarms/tools/tool_registry.py +++ b/swarms/tools/tool_registry.py @@ -6,7 +6,6 @@ FuncToolBuilder = Callable[[], ToolBuilder] class ToolsRegistry: - def __init__(self) -> None: self.tools: Dict[str, FuncToolBuilder] = {} @@ -19,7 +18,8 @@ class ToolsRegistry: if isinstance(ret, tool): return ret raise ValueError( - "Tool builder {} did not return a Tool instance".format(tool_name)) + "Tool builder {} did not return a Tool instance".format(tool_name) + ) def list_tools(self) -> List[str]: return list(self.tools.keys()) @@ -29,7 +29,6 @@ tools_registry = ToolsRegistry() def register(tool_name): - def decorator(tool: FuncToolBuilder): tools_registry.register(tool_name, tool) return tool diff --git a/swarms/utils/code_interpreter.py b/swarms/utils/code_interpreter.py index c89ac7a7..80eb6700 100644 --- a/swarms/utils/code_interpreter.py +++ b/swarms/utils/code_interpreter.py @@ -118,19 +118,14 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): # Most of the time it doesn't matter, but we should figure out why it happens frequently with: # applescript yield {"output": traceback.format_exc()} - yield { - "output": f"Retrying... ({retry_count}/{max_retries})" - } + yield {"output": f"Retrying... ({retry_count}/{max_retries})"} yield {"output": "Restarting process."} self.start_process() retry_count += 1 if retry_count > max_retries: - yield { - "output": - "Maximum retries reached. Could not execute code." - } + yield {"output": "Maximum retries reached. Could not execute code."} return while True: @@ -139,8 +134,7 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): else: time.sleep(0.1) try: - output = self.output_queue.get( - timeout=0.3) # Waits for 0.3 seconds + output = self.output_queue.get(timeout=0.3) # Waits for 0.3 seconds yield output except queue.Empty: if self.done.is_set(): diff --git a/swarms/utils/decorators.py b/swarms/utils/decorators.py index 2f22528b..8a5a5d56 100644 --- a/swarms/utils/decorators.py +++ b/swarms/utils/decorators.py @@ -6,7 +6,6 @@ import warnings def log_decorator(func): - def wrapper(*args, **kwargs): logging.info(f"Entering {func.__name__}") result = func(*args, **kwargs) @@ -17,7 +16,6 @@ def log_decorator(func): def error_decorator(func): - def wrapper(*args, **kwargs): try: return func(*args, **kwargs) @@ -29,22 +27,18 @@ def error_decorator(func): def timing_decorator(func): - def wrapper(*args, **kwargs): start_time = time.time() result = func(*args, **kwargs) end_time = time.time() - logging.info( - f"{func.__name__} executed in {end_time - start_time} seconds") + logging.info(f"{func.__name__} executed in {end_time - start_time} seconds") return result return wrapper def retry_decorator(max_retries=5): - def decorator(func): - @functools.wraps(func) def wrapper(*args, **kwargs): for _ in range(max_retries): @@ -83,20 +77,16 @@ def synchronized_decorator(func): def deprecated_decorator(func): - @functools.wraps(func) def wrapper(*args, **kwargs): - warnings.warn(f"{func.__name__} is deprecated", - category=DeprecationWarning) + warnings.warn(f"{func.__name__} is deprecated", category=DeprecationWarning) return func(*args, **kwargs) return wrapper def validate_inputs_decorator(validator): - def decorator(func): - @functools.wraps(func) def wrapper(*args, **kwargs): if not validator(*args, **kwargs): diff --git a/swarms/utils/futures.py b/swarms/utils/futures.py index 5c2dfdcd..55a4e5d5 100644 --- a/swarms/utils/futures.py +++ b/swarms/utils/futures.py @@ -5,8 +5,6 @@ T = TypeVar("T") def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]: - futures.wait(fs_dict.values(), - timeout=None, - return_when=futures.ALL_COMPLETED) + futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED) return {key: future.result() for key, future in fs_dict.items()} diff --git a/swarms/utils/hash.py b/swarms/utils/hash.py index 458fc147..725cc6ba 100644 --- a/swarms/utils/hash.py +++ b/swarms/utils/hash.py @@ -4,7 +4,8 @@ import hashlib def dataframe_to_hash(dataframe: pd.DataFrame) -> str: return hashlib.sha256( - pd.util.hash_pandas_object(dataframe, index=True).values).hexdigest() + pd.util.hash_pandas_object(dataframe, index=True).values + ).hexdigest() def str_to_hash(text: str, hash_algorithm: str = "sha256") -> str: diff --git a/swarms/utils/main.py b/swarms/utils/main.py index 9d5eefdf..63cb0e4a 100644 --- a/swarms/utils/main.py +++ b/swarms/utils/main.py @@ -51,16 +51,16 @@ def get_new_image_name(org_img_name, func_name="update"): if len(name_split) == 1: most_org_file_name = name_split[0] recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.png".format(this_new_uuid, func_name, - recent_prev_file_name, - most_org_file_name) + new_file_name = "{}_{}_{}_{}.png".format( + this_new_uuid, func_name, recent_prev_file_name, most_org_file_name + ) else: assert len(name_split) == 4 most_org_file_name = name_split[3] recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.png".format(this_new_uuid, func_name, - recent_prev_file_name, - most_org_file_name) + new_file_name = "{}_{}_{}_{}.png".format( + this_new_uuid, func_name, recent_prev_file_name, most_org_file_name + ) return os.path.join(head, new_file_name) @@ -73,16 +73,16 @@ def get_new_dataframe_name(org_img_name, func_name="update"): if len(name_split) == 1: most_org_file_name = name_split[0] recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.csv".format(this_new_uuid, func_name, - recent_prev_file_name, - most_org_file_name) + new_file_name = "{}_{}_{}_{}.csv".format( + this_new_uuid, func_name, recent_prev_file_name, most_org_file_name + ) else: assert len(name_split) == 4 most_org_file_name = name_split[3] recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.csv".format(this_new_uuid, func_name, - recent_prev_file_name, - most_org_file_name) + new_file_name = "{}_{}_{}_{}.csv".format( + this_new_uuid, func_name, recent_prev_file_name, most_org_file_name + ) return os.path.join(head, new_file_name) @@ -92,7 +92,6 @@ def get_new_dataframe_name(org_img_name, func_name="update"): class Code: - def __init__(self, value: int): self.value = value @@ -101,7 +100,6 @@ class Code: class Color(Code): - def bg(self) -> "Color": self.value += 10 return self @@ -148,7 +146,6 @@ class Color(Code): class Style(Code): - @staticmethod def reset() -> "Style": return Style(0) @@ -205,8 +202,7 @@ def dim_multiline(message: str) -> str: lines = message.split("\n") if len(lines) <= 1: return lines[0] - return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to( - Color.black().bright()) + return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to(Color.black().bright()) # +=============================> ANSI Ending @@ -217,7 +213,6 @@ STATIC_DIR = "static" class AbstractUploader(ABC): - @abstractmethod def upload(self, filepath: str) -> str: pass @@ -233,9 +228,7 @@ class AbstractUploader(ABC): class S3Uploader(AbstractUploader): - - def __init__(self, accessKey: str, secretKey: str, region: str, - bucket: str): + def __init__(self, accessKey: str, secretKey: str, region: str, bucket: str): self.accessKey = accessKey self.secretKey = secretKey self.region = region @@ -270,7 +263,6 @@ class S3Uploader(AbstractUploader): class StaticUploader(AbstractUploader): - def __init__(self, server: str, path: Path, endpoint: str): self.server = server self.path = path @@ -338,19 +330,16 @@ class FileType(Enum): class BaseHandler: - def handle(self, filename: str) -> str: raise NotImplementedError class FileHandler: - def __init__(self, handlers: Dict[FileType, BaseHandler], path: Path): self.handlers = handlers self.path = path - def register(self, filetype: FileType, - handler: BaseHandler) -> "FileHandler": + def register(self, filetype: FileType, handler: BaseHandler) -> "FileHandler": self.handlers[filetype] = handler return self @@ -358,8 +347,8 @@ class FileHandler: filetype = FileType.from_url(url) data = requests.get(url).content local_filename = os.path.join( - "file", - str(uuid.uuid4())[0:8] + filetype.to_extension()) + "file", str(uuid.uuid4())[0:8] + filetype.to_extension() + ) os.makedirs(os.path.dirname(local_filename), exist_ok=True) with open(local_filename, "wb") as f: size = f.write(data) @@ -368,15 +357,17 @@ class FileHandler: def handle(self, url: str) -> str: try: - if url.startswith(os.environ.get("SERVER", - "http://localhost:8000")): + if url.startswith(os.environ.get("SERVER", "http://localhost:8000")): local_filepath = url[ - len(os.environ.get("SERVER", "http://localhost:8000")) + 1:] + len(os.environ.get("SERVER", "http://localhost:8000")) + 1 : + ] local_filename = Path("file") / local_filepath.split("/")[-1] src = self.path / local_filepath - dst = (self.path / - os.environ.get("PLAYGROUND_DIR", "./playground") / - local_filename) + dst = ( + self.path + / os.environ.get("PLAYGROUND_DIR", "./playground") + / local_filename + ) os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(src, dst) else: @@ -386,7 +377,8 @@ class FileHandler: if FileType.from_url(url) == FileType.IMAGE: raise Exception( f"No handler for {FileType.from_url(url)}. " - "Please set USE_GPU to True in env/settings.py") + "Please set USE_GPU to True in env/settings.py" + ) else: raise Exception(f"No handler for {FileType.from_url(url)}") return handler.handle(local_filename) @@ -400,17 +392,17 @@ class FileHandler: class CsvToDataframe(BaseHandler): - def handle(self, filename: str): df = pd.read_csv(filename) description = ( f"Dataframe with {len(df)} rows and {len(df.columns)} columns. " "Columns are: " - f"{', '.join(df.columns)}") + f"{', '.join(df.columns)}" + ) print( f"\nProcessed CsvToDataframe, Input CSV: {filename}, Output Description:" - f" {description}") + f" {description}" + ) - return DATAFRAME_PROMPT.format(filename=filename, - description=description) + return DATAFRAME_PROMPT.format(filename=filename, description=description) diff --git a/swarms/utils/parse_code.py b/swarms/utils/parse_code.py index 020c9bef..a2f346ea 100644 --- a/swarms/utils/parse_code.py +++ b/swarms/utils/parse_code.py @@ -7,6 +7,5 @@ def extract_code_in_backticks_in_string(message: str) -> str: """ pattern = r"`` ``(.*?)`` " # Non-greedy match between six backticks - match = re.search(pattern, message, - re.DOTALL) # re.DOTALL to match newline chars + match = re.search(pattern, message, re.DOTALL) # re.DOTALL to match newline chars return match.group(1).strip() if match else None diff --git a/swarms/utils/revutils.py b/swarms/utils/revutils.py index 9db1e123..7868ae44 100644 --- a/swarms/utils/revutils.py +++ b/swarms/utils/revutils.py @@ -49,12 +49,16 @@ def get_input( """ Multiline input function. """ - return (session.prompt( - completer=completer, - multiline=True, - auto_suggest=AutoSuggestFromHistory(), - key_bindings=key_bindings, - ) if session else prompt(multiline=True)) + return ( + session.prompt( + completer=completer, + multiline=True, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=key_bindings, + ) + if session + else prompt(multiline=True) + ) async def get_input_async( @@ -64,11 +68,15 @@ async def get_input_async( """ Multiline input function. """ - return (await session.prompt_async( - completer=completer, - multiline=True, - auto_suggest=AutoSuggestFromHistory(), - ) if session else prompt(multiline=True)) + return ( + await session.prompt_async( + completer=completer, + multiline=True, + auto_suggest=AutoSuggestFromHistory(), + ) + if session + else prompt(multiline=True) + ) def get_filtered_keys_from_object(obj: object, *keys: str) -> any: @@ -86,7 +94,9 @@ def get_filtered_keys_from_object(obj: object, *keys: str) -> any: return {key for key in class_keys if key not in keys[1:]} # Check if all passed keys are valid if invalid_keys := set(keys) - class_keys: - raise ValueError(f"Invalid keys: {invalid_keys}",) + raise ValueError( + f"Invalid keys: {invalid_keys}", + ) # Only return specified keys that are in class_keys return {key for key in keys if key in class_keys} @@ -114,8 +124,8 @@ def random_int(min: int, max: int) -> int: if __name__ == "__main__": logging.basicConfig( - format= - "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s",) + format="%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s", + ) log = logging.getLogger(__name__) diff --git a/swarms/utils/serializable.py b/swarms/utils/serializable.py index 47cc815f..8f0e5ccf 100644 --- a/swarms/utils/serializable.py +++ b/swarms/utils/serializable.py @@ -106,22 +106,21 @@ class Serializable(BaseModel, ABC): lc_kwargs.update({key: secret_value}) return { - "lc": - 1, - "type": - "constructor", + "lc": 1, + "type": "constructor", "id": [*self.lc_namespace, self.__class__.__name__], - "kwargs": - lc_kwargs if not secrets else _replace_secrets( - lc_kwargs, secrets), + "kwargs": lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets), } def to_json_not_implemented(self) -> SerializedNotImplemented: return to_json_not_implemented(self) -def _replace_secrets(root: Dict[Any, Any], - secrets_map: Dict[str, str]) -> Dict[Any, Any]: +def _replace_secrets( + root: Dict[Any, Any], secrets_map: Dict[str, str] +) -> Dict[Any, Any]: result = root.copy() for path, secret_id in secrets_map.items(): [*parts, last] = path.split(".") diff --git a/swarms/utils/static.py b/swarms/utils/static.py index 23f13996..3b8a276d 100644 --- a/swarms/utils/static.py +++ b/swarms/utils/static.py @@ -8,7 +8,6 @@ from swarms.utils.main import AbstractUploader class StaticUploader(AbstractUploader): - def __init__(self, server: str, path: Path, endpoint: str): self.server = server self.path = path diff --git a/swarms/workers/worker.py b/swarms/workers/worker.py index bef9682a..9986666a 100644 --- a/swarms/workers/worker.py +++ b/swarms/workers/worker.py @@ -4,7 +4,8 @@ from typing import Dict, Union import faiss from langchain.chains.qa_with_sources.loading import ( - load_qa_with_sources_chain,) + load_qa_with_sources_chain, +) from langchain.docstore import InMemoryDocstore from langchain.embeddings import OpenAIEmbeddings from langchain.tools import ReadFileTool, WriteFileTool @@ -131,7 +132,8 @@ class Worker: ``` """ query_website_tool = WebpageQATool( - qa_chain=load_qa_with_sources_chain(self.llm)) + qa_chain=load_qa_with_sources_chain(self.llm) + ) self.tools = [ WriteFileTool(root_dir=ROOT_DIR), @@ -155,13 +157,15 @@ class Worker: embedding_size = 1536 index = faiss.IndexFlatL2(embedding_size) - self.vectorstore = FAISS(embeddings_model.embed_query, index, - InMemoryDocstore({}), {}) + self.vectorstore = FAISS( + embeddings_model.embed_query, index, InMemoryDocstore({}), {} + ) except Exception as error: raise RuntimeError( "Error setting up memory perhaps try try tuning the embedding size:" - f" {error}") + f" {error}" + ) def setup_agent(self): """ @@ -290,6 +294,8 @@ class Worker: def is_within_proximity(self, other_worker): """Using Euclidean distance for proximity check""" - distance = ((self.coordinates[0] - other_worker.coordinates[0])**2 + - (self.coordinates[1] - other_worker.coordinates[1])**2)**0.5 + distance = ( + (self.coordinates[0] - other_worker.coordinates[0]) ** 2 + + (self.coordinates[1] - other_worker.coordinates[1]) ** 2 + ) ** 0.5 return distance < 10 # threshold for proximity