From 2e7905db461fe5116023aa34a4b5affdd3a6cbf9 Mon Sep 17 00:00:00 2001
From: Kye
Date: Wed, 8 Nov 2023 17:44:31 -0500
Subject: [PATCH] yapf code quality

---
 quality.sh | 5 +-
 swarms/agents/__init__.py | 2 -
 swarms/agents/agent.py | 121 ++-
 swarms/agents/aot.py | 54 +-
 swarms/agents/browser_agent.py | 91 +-
 swarms/agents/hf_agents.py | 115 ++-
 swarms/agents/meta_prompter.py | 14 +-
 swarms/agents/multi_modal_visual_agent.py | 950 +++++++++---------
 .../neural_architecture_search_worker.py | 1 +
 swarms/agents/omni_modal_agent.py | 32 +-
 swarms/agents/profitpilot.py | 105 +-
 swarms/agents/refiner_agent.py | 4 +-
 swarms/agents/registry.py | 4 +-
 swarms/agents/simple_agent.py | 3 +-
 swarms/artifacts/base.py | 8 +-
 swarms/artifacts/main.py | 15 +-
 swarms/chunkers/__init__.py | 1 -
 swarms/chunkers/base.py | 39 +-
 swarms/chunkers/omni_chunker.py | 8 +-
 swarms/loaders/asana.py | 80 +-
 swarms/loaders/base.py | 128 +--
 swarms/memory/base.py | 113 +--
 swarms/memory/chroma.py | 90 +-
 swarms/memory/cosine_similarity.py | 6 +-
 swarms/memory/db.py | 23 +-
 swarms/memory/ocean.py | 9 +-
 swarms/memory/pg.py | 59 +-
 swarms/memory/pinecone.py | 53 +-
 swarms/memory/schemas.py | 34 +-
 swarms/memory/utils.py | 9 +-
 swarms/models/__init__.py | 2 -
 swarms/models/anthropic.py | 97 +-
 swarms/models/bioclip.py | 31 +-
 swarms/models/biogpt.py | 18 +-
 swarms/models/dalle3.py | 24 +-
 swarms/models/distilled_whisperx.py | 32 +-
 swarms/models/fastvit.py | 25 +-
 swarms/models/fuyu.py | 18 +-
 swarms/models/gpt4v.py | 106 +-
 swarms/models/huggingface.py | 88 +-
 swarms/models/idefics.py | 54 +-
 swarms/models/jina_embeds.py | 25 +-
 swarms/models/kosmos2.py | 48 +-
 swarms/models/kosmos_two.py | 80 +-
 swarms/models/llava.py | 5 +-
 swarms/models/mistral.py | 9 +-
 swarms/models/mpt.py | 18 +-
 swarms/models/nougat.py | 12 +-
 swarms/models/openai_assistant.py | 11 +-
 swarms/models/openai_embeddings.py | 128 +--
 swarms/models/openai_models.py | 302 +++---
 swarms/models/openai_tokenizer.py | 36 +-
 swarms/models/palm.py | 35 +-
 swarms/models/pegasus.py | 10 +-
 swarms/models/simple_ada.py | 4 +-
 swarms/models/speecht5.py | 15 +-
 swarms/models/timm.py | 5 +-
 swarms/models/trocr.py | 5 +-
 swarms/models/vilt.py | 6 +-
 swarms/models/wizard_storytelling.py | 76 +-
 swarms/models/yarn_mistral.py | 70 +-
 swarms/models/zephyr.py | 5 +-
 swarms/prompts/agent_output_parser.py | 5 +-
 swarms/prompts/agent_prompt.py | 25 +-
 swarms/prompts/agent_prompts.py | 30 +-
 swarms/prompts/base.py | 26 +-
 swarms/prompts/chat_prompt.py | 13 +-
 swarms/prompts/debate.py | 5 +-
 swarms/prompts/multi_modal_prompts.py | 5 +-
 swarms/prompts/python.py | 24 +-
 swarms/prompts/sales.py | 32 +-
 swarms/prompts/sales_prompts.py | 32 +-
 swarms/prompts/summaries_prompts.py | 4 -
 swarms/schemas/typings.py | 11 +-
 swarms/structs/document.py | 13 +-
 swarms/structs/flow.py | 46 +-
 swarms/structs/nonlinear_workflow.py | 13 +-
 swarms/structs/sequential_workflow.py | 81 +-
 swarms/structs/task.py | 18 +-
 swarms/structs/workflow.py | 12 +-
 swarms/swarms/autoscaler.py | 3 +-
 swarms/swarms/base.py | 4 +-
 swarms/swarms/battle_royal.py | 10 +-
 swarms/swarms/god_mode.py | 26 +-
 swarms/swarms/groupchat.py | 46 +-
 swarms/swarms/multi_agent_collab.py | 11 +-
 swarms/swarms/multi_agent_debate.py | 3 +-
 swarms/swarms/orchestrate.py | 27 +-
 swarms/swarms/simple_swarm.py | 1 +
 swarms/tools/autogpt.py | 42 +-
 swarms/tools/mm_models.py | 141 ++-
 swarms/tools/stt.py | 17 +-
 swarms/tools/tool.py | 223 ++--
 swarms/tools/tool_registry.py | 5 +-
 swarms/utils/code_interpreter.py | 12 +-
swarms/utils/decorators.py | 14 +- swarms/utils/futures.py | 4 +- swarms/utils/hash.py | 3 +- swarms/utils/main.py | 79 +- swarms/utils/parse_code.py | 3 +- swarms/utils/revutils.py | 38 +- swarms/utils/serializable.py | 17 +- swarms/utils/static.py | 1 + swarms/workers/worker.py | 20 +- 104 files changed, 2256 insertions(+), 2465 deletions(-) diff --git a/quality.sh b/quality.sh index bf167079..032085ca 100644 --- a/quality.sh +++ b/quality.sh @@ -5,7 +5,7 @@ # Run autopep8 with max aggressiveness (-aaa) and in-place modification (-i) # on all Python files (*.py) under the 'swarms' directory. -autopep8 --in-place --aggressive --aggressive --recursive --experimental swarms/ +autopep8 --in-place --aggressive --aggressive --recursive --experimental --list-fixes swarms/ # Run black with default settings, since black does not have an aggressiveness level. # Black will format all Python files it finds in the 'swarms' directory. @@ -15,4 +15,5 @@ black --experimental-string-processing swarms/ # Add any additional flags if needed according to your version of ruff. ruff swarms/ -# If you want to ensure the script stops if any command fails, add 'set -e' at the top. +# YAPF +yapf --recursive --in-place --verbose --style=google --parallel swarms diff --git a/swarms/agents/__init__.py b/swarms/agents/__init__.py index 355f0ad1..cd3aa221 100644 --- a/swarms/agents/__init__.py +++ b/swarms/agents/__init__.py @@ -8,8 +8,6 @@ from swarms.agents.registry import Registry # from swarms.agents.idea_to_image_agent import Idea2Image from swarms.agents.simple_agent import SimpleAgent - - """Agent Infrastructure, models, memory, utils, tools""" __all__ = [ diff --git a/swarms/agents/agent.py b/swarms/agents/agent.py index 34d6315f..c16dd780 100644 --- a/swarms/agents/agent.py +++ b/swarms/agents/agent.py @@ -8,8 +8,7 @@ from langchain.chains.llm import LLMChain from langchain.chat_models.base import BaseChatModel from langchain.memory import ChatMessageHistory from langchain.prompts.chat import ( - BaseChatPromptTemplate, -) + BaseChatPromptTemplate,) from langchain.schema import ( BaseChatMessageHistory, Document, @@ -34,7 +33,6 @@ from langchain_experimental.autonomous_agents.autogpt.prompt_generator import ( ) from langchain_experimental.pydantic_v1 import BaseModel, ValidationError - # PROMPT FINISH_NAME = "finish" @@ -72,14 +70,12 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc] send_token_limit: int = 4196 def construct_full_prompt(self, goals: List[str]) -> str: - prompt_start = ( - "Your decisions must always be made independently " - "without seeking user assistance.\n" - "Play to your strengths as an LLM and pursue simple " - "strategies with no legal complications.\n" - "If you have completed all your tasks, make sure to " - 'use the "finish" command.' 
- ) + prompt_start = ("Your decisions must always be made independently " + "without seeking user assistance.\n" + "Play to your strengths as an LLM and pursue simple " + "strategies with no legal complications.\n" + "If you have completed all your tasks, make sure to " + 'use the "finish" command.') # Construct full prompt full_prompt = ( f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n" @@ -91,25 +87,23 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc] return full_prompt def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"])) + base_prompt = SystemMessage( + content=self.construct_full_prompt(kwargs["goals"])) time_prompt = SystemMessage( - content=f"The current time and date is {time.strftime('%c')}" - ) - used_tokens = self.token_counter(base_prompt.content) + self.token_counter( - time_prompt.content - ) + content=f"The current time and date is {time.strftime('%c')}") + used_tokens = self.token_counter( + base_prompt.content) + self.token_counter(time_prompt.content) memory: VectorStoreRetriever = kwargs["memory"] previous_messages = kwargs["messages"] - relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:])) + relevant_docs = memory.get_relevant_documents( + str(previous_messages[-10:])) relevant_memory = [d.page_content for d in relevant_docs] relevant_memory_tokens = sum( - [self.token_counter(doc) for doc in relevant_memory] - ) + [self.token_counter(doc) for doc in relevant_memory]) while used_tokens + relevant_memory_tokens > 2500: relevant_memory = relevant_memory[:-1] relevant_memory_tokens = sum( - [self.token_counter(doc) for doc in relevant_memory] - ) + [self.token_counter(doc) for doc in relevant_memory]) content_format = ( f"This reminds you of these events from your past:\n{relevant_memory}\n\n" ) @@ -147,13 +141,23 @@ class PromptGenerator: self.performance_evaluation: List[str] = [] self.response_format = { "thoughts": { - "text": "thought", - "reasoning": "reasoning", - "plan": "- short bulleted\n- list that conveys\n- long-term plan", - "criticism": "constructive self-criticism", - "speak": "thoughts summary to say to user", + "text": + "thought", + "reasoning": + "reasoning", + "plan": + "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": + "constructive self-criticism", + "speak": + "thoughts summary to say to user", + }, + "command": { + "name": "command name", + "args": { + "arg name": "value" + } }, - "command": {"name": "command name", "args": {"arg name": "value"}}, } def add_constraint(self, constraint: str) -> None: @@ -191,7 +195,9 @@ class PromptGenerator: """ self.performance_evaluation.append(evaluation) - def _generate_numbered_list(self, items: list, item_type: str = "list") -> str: + def _generate_numbered_list(self, + items: list, + item_type: str = "list") -> str: """ Generate a numbered list from given items based on the item_type. @@ -209,16 +215,11 @@ class PromptGenerator: for i, item in enumerate(items) ] finish_description = ( - "use this to signal that you have finished all your objectives" - ) - finish_args = ( - '"response": "final response to let ' - 'people know you have finished your objectives"' - ) - finish_string = ( - f"{len(items) + 1}. 
{FINISH_NAME}: " - f"{finish_description}, args: {finish_args}" - ) + "use this to signal that you have finished all your objectives") + finish_args = ('"response": "final response to let ' + 'people know you have finished your objectives"') + finish_string = (f"{len(items) + 1}. {FINISH_NAME}: " + f"{finish_description}, args: {finish_args}") return "\n".join(command_strings + [finish_string]) else: return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) @@ -239,8 +240,7 @@ class PromptGenerator: f"{self._generate_numbered_list(self.performance_evaluation)}\n\n" "You should only respond in JSON format as described below " f"\nResponse Format: \n{formatted_response_format} " - "\nEnsure the response can be parsed by Python json.loads" - ) + "\nEnsure the response can be parsed by Python json.loads") return prompt_string @@ -261,13 +261,11 @@ def get_prompt(tools: List[BaseTool]) -> str: prompt_generator.add_constraint( "~16000 word limit for short term memory. " "Your short term memory is short, " - "so immediately save important information to files." - ) + "so immediately save important information to files.") prompt_generator.add_constraint( "If you are unsure how you previously did something " "or want to recall past events, " - "thinking about similar events will help you remember." - ) + "thinking about similar events will help you remember.") prompt_generator.add_constraint("No user assistance") prompt_generator.add_constraint( 'Exclusively use the commands listed in double quotes e.g. "command name"' @@ -279,29 +277,23 @@ def get_prompt(tools: List[BaseTool]) -> str: # Add resources to the PromptGenerator object prompt_generator.add_resource( - "Internet access for searches and information gathering." - ) + "Internet access for searches and information gathering.") prompt_generator.add_resource("Long Term memory management.") prompt_generator.add_resource( - "GPT-3.5 powered Agents for delegation of simple tasks." - ) + "GPT-3.5 powered Agents for delegation of simple tasks.") prompt_generator.add_resource("File output.") # Add performance evaluations to the PromptGenerator object prompt_generator.add_performance_evaluation( "Continuously review and analyze your actions " - "to ensure you are performing to the best of your abilities." - ) + "to ensure you are performing to the best of your abilities.") prompt_generator.add_performance_evaluation( - "Constructively self-criticize your big-picture behavior constantly." - ) + "Constructively self-criticize your big-picture behavior constantly.") prompt_generator.add_performance_evaluation( - "Reflect on past decisions and strategies to refine your approach." - ) + "Reflect on past decisions and strategies to refine your approach.") prompt_generator.add_performance_evaluation( "Every command has a cost, so be smart and efficient. " - "Aim to complete tasks in the least number of steps." 
- ) + "Aim to complete tasks in the least number of steps.") # Generate the prompt string prompt_string = prompt_generator.generate_prompt_string() @@ -372,10 +364,8 @@ class AutoGPT: ) def run(self, goals: List[str]) -> str: - user_input = ( - "Determine which next command to use, " - "and respond using the format specified above:" - ) + user_input = ("Determine which next command to use, " + "and respond using the format specified above:") # Interaction Loop loop_count = 0 while True: @@ -392,8 +382,10 @@ class AutoGPT: # Print Assistant thoughts print(assistant_reply) - self.chat_history_memory.add_message(HumanMessage(content=user_input)) - self.chat_history_memory.add_message(AIMessage(content=assistant_reply)) + self.chat_history_memory.add_message( + HumanMessage(content=user_input)) + self.chat_history_memory.add_message( + AIMessage(content=assistant_reply)) # Get command name and arguments action = self.output_parser.parse(assistant_reply) @@ -419,8 +411,7 @@ class AutoGPT: result = ( f"Unknown command '{action.name}'. " "Please refer to the 'COMMANDS' list for available " - "commands and only respond in the specified JSON format." - ) + "commands and only respond in the specified JSON format.") memory_to_add = f"Assistant Reply: {assistant_reply} \nResult: {result} " if self.feedback_tool is not None: diff --git a/swarms/agents/aot.py b/swarms/agents/aot.py index b36fb43c..123f5591 100644 --- a/swarms/agents/aot.py +++ b/swarms/agents/aot.py @@ -4,13 +4,13 @@ import time import openai_model -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) +logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) class OpenAI: + def __init__( self, api_key, @@ -68,16 +68,13 @@ class OpenAI: temperature=temperature, ) with open("openai.logs", "a") as log_file: - log_file.write( - "\n" + "-----------" + "\n" + "Prompt : " + prompt + "\n" - ) + log_file.write("\n" + "-----------" + "\n" + "Prompt : " + + prompt + "\n") return response except openai_model.error.RateLimitError as e: sleep_duratoin = os.environ.get("OPENAI_RATE_TIMEOUT", 30) - print( - f"{str(e)}, sleep for {sleep_duratoin}s, set it by env" - " OPENAI_RATE_TIMEOUT" - ) + print(f"{str(e)}, sleep for {sleep_duratoin}s, set it by env" + " OPENAI_RATE_TIMEOUT") time.sleep(sleep_duratoin) def openai_choice2text_handler(self, choice): @@ -100,11 +97,16 @@ class OpenAI: else: response = self.run(prompt, 300, 0.5, k) thoughts = [ - self.openai_choice2text_handler(choice) for choice in response.choices + self.openai_choice2text_handler(choice) + for choice in response.choices ] return thoughts - def generate_thoughts(self, state, k, initial_prompt, rejected_solutions=None): + def generate_thoughts(self, + state, + k, + initial_prompt, + rejected_solutions=None): if isinstance(state, str): pass else: @@ -177,7 +179,8 @@ class OpenAI: """ response = self.run(prompt, 10, 1) try: - value_text = self.openai_choice2text_handler(response.choices[0]) + value_text = self.openai_choice2text_handler( + response.choices[0]) # print(f'state: {value_text}') value = float(value_text) print(f"Evaluated Thought Value: {value}") @@ -187,10 +190,12 @@ class OpenAI: return state_values else: - raise ValueError("Invalid evaluation strategy. Choose 'value' or 'vote'.") + raise ValueError( + "Invalid evaluation strategy. 
Choose 'value' or 'vote'.") class AoTAgent: + def __init__( self, num_thoughts: int = None, @@ -222,7 +227,8 @@ class AoTAgent: return None best_state, _ = max(self.output, key=lambda x: x[1]) - solution = self.model.generate_solution(self.initial_prompt, best_state) + solution = self.model.generate_solution(self.initial_prompt, + best_state) print(f"Solution is {solution}") return solution if solution else best_state except Exception as error: @@ -239,11 +245,8 @@ class AoTAgent: for next_state in thoughts: state_value = self.evaluated_thoughts[next_state] if state_value > self.value_threshold: - child = ( - (state, next_state) - if isinstance(state, str) - else (*state, next_state) - ) + child = ((state, next_state) if isinstance(state, str) else + (*state, next_state)) self.dfs(child, step + 1) # backtracking @@ -253,17 +256,14 @@ class AoTAgent: continue def generate_and_filter_thoughts(self, state): - thoughts = self.model.generate_thoughts( - state, self.num_thoughts, self.initial_prompt - ) + thoughts = self.model.generate_thoughts(state, self.num_thoughts, + self.initial_prompt) self.evaluated_thoughts = self.model.evaluate_states( - thoughts, self.initial_prompt - ) + thoughts, self.initial_prompt) filtered_thoughts = [ - thought - for thought in thoughts + thought for thought in thoughts if self.evaluated_thoughts[thought] >= self.pruning_threshold ] print(f"filtered_thoughts: {filtered_thoughts}") diff --git a/swarms/agents/browser_agent.py b/swarms/agents/browser_agent.py index 1f4ff12e..3a274468 100644 --- a/swarms/agents/browser_agent.py +++ b/swarms/agents/browser_agent.py @@ -38,7 +38,8 @@ def record(agent_name: str, autotab_ext_path: Optional[str] = None): if not os.path.exists("agents"): os.makedirs("agents") - if os.path.exists(f"agents/{agent_name}.py") and config.environment != "local": + if os.path.exists( + f"agents/{agent_name}.py") and config.environment != "local": if not _is_blank_agent(agent_name=agent_name): raise Exception(f"Agent with name {agent_name} already exists") driver = get_driver( # noqa: F841 @@ -54,12 +55,10 @@ def record(agent_name: str, autotab_ext_path: Optional[str] = None): print( "\033[34mYou have the Python debugger open, you can run commands in it like you" - " would in a normal Python shell.\033[0m" - ) + " would in a normal Python shell.\033[0m") print( "\033[34mTo exit, type 'q' and press enter. For a list of commands type '?' 
and" - " press enter.\033[0m" - ) + " press enter.\033[0m") breakpoint() @@ -79,12 +78,13 @@ def extract_domain_from_url(url: str): class AutotabChromeDriver(uc.Chrome): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def find_element_with_retry( - self, by=By.ID, value: Optional[str] = None - ) -> WebElement: + def find_element_with_retry(self, + by=By.ID, + value: Optional[str] = None) -> WebElement: try: return super().find_element(by, value) except Exception as e: @@ -102,11 +102,8 @@ def open_plugin(driver: AutotabChromeDriver): def open_plugin_and_login(driver: AutotabChromeDriver): if config.autotab_api_key is not None: - backend_url = ( - "http://localhost:8000" - if config.environment == "local" - else "https://api.autotab.com" - ) + backend_url = ("http://localhost:8000" if config.environment == "local" + else "https://api.autotab.com") driver.get(f"{backend_url}/auth/signin-api-key-page") response = requests.post( f"{backend_url}/auth/signin-api-key", @@ -119,8 +116,7 @@ def open_plugin_and_login(driver: AutotabChromeDriver): else: raise Exception( f"Error {response.status_code} from backend while logging you in" - f" with your API key: {response.text}" - ) + f" with your API key: {response.text}") cookie["name"] = cookie["key"] del cookie["key"] driver.add_cookie(cookie) @@ -130,26 +126,21 @@ def open_plugin_and_login(driver: AutotabChromeDriver): else: print("No autotab API key found, heading to autotab.com to sign up") - url = ( - "http://localhost:3000/dashboard" - if config.environment == "local" - else "https://autotab.com/dashboard" - ) + url = ("http://localhost:3000/dashboard" if config.environment + == "local" else "https://autotab.com/dashboard") driver.get(url) time.sleep(0.5) open_plugin(driver) -def get_driver( - autotab_ext_path: Optional[str] = None, record_mode: bool = False -) -> AutotabChromeDriver: +def get_driver(autotab_ext_path: Optional[str] = None, + record_mode: bool = False) -> AutotabChromeDriver: options = webdriver.ChromeOptions() options.add_argument("--no-sandbox") # Necessary for running options.add_argument( "--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" - " (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" - ) + " (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36") options.add_argument("--enable-webgl") options.add_argument("--enable-3d-apis") options.add_argument("--enable-clipboard-read-write") @@ -238,7 +229,8 @@ class Config(BaseModel): return cls( autotab_api_key=autotab_api_key, credentials=_credentials, - google_credentials=GoogleCredentials(credentials=google_credentials), + google_credentials=GoogleCredentials( + credentials=google_credentials), chrome_binary_location=config.get("chrome_binary_location"), environment=config.get("environment", "prod"), ) @@ -256,9 +248,9 @@ def is_signed_in_to_google(driver): return len([c for c in cookies if c["name"] == "SAPISID"]) != 0 -def google_login( - driver, credentials: Optional[SiteCredentials] = None, navigate: bool = True -): +def google_login(driver, + credentials: Optional[SiteCredentials] = None, + navigate: bool = True): print("Logging in to Google") if navigate: driver.get("https://accounts.google.com/") @@ -290,8 +282,7 @@ def google_login( email_input.send_keys(credentials.email) email_input.send_keys(Keys.ENTER) WebDriverWait(driver, 10).until( - EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']")) - ) + EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']"))) password_input = 
driver.find_element(By.CSS_SELECTOR, "[type='password']") password_input.send_keys(credentials.password) @@ -314,21 +305,20 @@ def google_login( cookies = driver.get_cookies() cookie_names = ["__Host-GAPS", "SMSV", "NID", "ACCOUNT_CHOOSER"] google_cookies = [ - cookie - for cookie in cookies - if cookie["domain"] in [".google.com", "accounts.google.com"] - and cookie["name"] in cookie_names + cookie for cookie in cookies + if cookie["domain"] in [".google.com", "accounts.google.com"] and + cookie["name"] in cookie_names ] with open("google_cookies.json", "w") as f: json.dump(google_cookies, f) # Log back in login_button = driver.find_element( - By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']" - ) + By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']") login_button.click() time.sleep(1) - password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']") + password_input = driver.find_element(By.CSS_SELECTOR, + "[type='password']") password_input.send_keys(credentials.password) password_input.send_keys(Keys.ENTER) @@ -343,8 +333,7 @@ def login(driver, url: str): login_url = credentials.login_url if credentials.login_with_google_account: google_credentials = config.google_credentials.credentials[ - credentials.login_with_google_account - ] + credentials.login_with_google_account] _login_with_google(driver, login_url, google_credentials) else: _login(driver, login_url, credentials=credentials) @@ -371,16 +360,15 @@ def _login_with_google(driver, url: str, google_credentials: SiteCredentials): driver.get(url) WebDriverWait(driver, 10).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + EC.presence_of_element_located((By.TAG_NAME, "body"))) main_window = driver.current_window_handle xpath = ( "//*[contains(text(), 'Continue with Google') or contains(text(), 'Sign in with" - " Google') or contains(@title, 'Sign in with Google')]" - ) + " Google') or contains(@title, 'Sign in with Google')]") - WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.XPATH, xpath))) + WebDriverWait(driver, + 10).until(EC.presence_of_element_located((By.XPATH, xpath))) driver.find_element( By.XPATH, xpath, @@ -388,8 +376,8 @@ def _login_with_google(driver, url: str, google_credentials: SiteCredentials): driver.switch_to.window(driver.window_handles[-1]) driver.find_element( - By.XPATH, f"//*[contains(text(), '{google_credentials.email}')]" - ).click() + By.XPATH, + f"//*[contains(text(), '{google_credentials.email}')]").click() driver.switch_to.window(main_window) @@ -442,8 +430,11 @@ def should_update(): # Parse the XML file root = ET.fromstring(xml_content) - namespaces = {"ns": "http://www.google.com/update2/response"} # add namespaces - xml_version = root.find(".//ns:app/ns:updatecheck", namespaces).get("version") + namespaces = { + "ns": "http://www.google.com/update2/response" + } # add namespaces + xml_version = root.find(".//ns:app/ns:updatecheck", + namespaces).get("version") # Load the local JSON file with open("src/extension/autotab/manifest.json", "r") as f: @@ -484,8 +475,6 @@ def play(agent_name: Optional[str] = None): if __name__ == "__main__": play() - - """ diff --git a/swarms/agents/hf_agents.py b/swarms/agents/hf_agents.py index 7614b1aa..e13d3462 100644 --- a/swarms/agents/hf_agents.py +++ b/swarms/agents/hf_agents.py @@ -19,7 +19,6 @@ from transformers.utils import is_offline_mode, is_openai_available, logging # utils logger = logging.get_logger(__name__) - if is_openai_available(): import openai @@ -28,7 +27,6 @@ else: 
_tools_are_initialized = False - BASE_PYTHON_TOOLS = { "print": print, "range": range, @@ -48,7 +46,6 @@ class PreTool: HUGGINGFACE_DEFAULT_TOOLS = {} - HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [ "image-transformation", "text-download", @@ -59,23 +56,24 @@ HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [ def get_remote_tools(organization="huggingface-tools"): if is_offline_mode(): - logger.info("You are in offline mode, so remote tools are not available.") + logger.info( + "You are in offline mode, so remote tools are not available.") return {} spaces = list_spaces(author=organization) tools = {} for space_info in spaces: repo_id = space_info.id - resolved_config_file = hf_hub_download( - repo_id, TOOL_CONFIG_FILE, repo_type="space" - ) + resolved_config_file = hf_hub_download(repo_id, + TOOL_CONFIG_FILE, + repo_type="space") with open(resolved_config_file, encoding="utf-8") as reader: config = json.load(reader) task = repo_id.split("/")[-1] - tools[config["name"]] = PreTool( - task=task, description=config["description"], repo_id=repo_id - ) + tools[config["name"]] = PreTool(task=task, + description=config["description"], + repo_id=repo_id) return tools @@ -95,8 +93,7 @@ def _setup_default_tools(): tool_class = getattr(tools_module, tool_class_name) description = tool_class.description HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool( - task=task_name, description=description, repo_id=None - ) + task=task_name, description=description, repo_id=None) if not is_offline_mode(): for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB: @@ -200,18 +197,19 @@ class Agent: one of the default tools, that default tool will be overridden. """ - def __init__( - self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None - ): + def __init__(self, + chat_prompt_template=None, + run_prompt_template=None, + additional_tools=None): _setup_default_tools() agent_name = self.__class__.__name__ - self.chat_prompt_template = download_prompt( - chat_prompt_template, agent_name, mode="chat" - ) - self.run_prompt_template = download_prompt( - run_prompt_template, agent_name, mode="run" - ) + self.chat_prompt_template = download_prompt(chat_prompt_template, + agent_name, + mode="chat") + self.run_prompt_template = download_prompt(run_prompt_template, + agent_name, + mode="run") self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy() self.log = print if additional_tools is not None: @@ -227,17 +225,16 @@ class Agent: } self._toolbox.update(additional_tools) if len(replacements) > 1: - names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()]) + names = "\n".join( + [f"- {n}: {t}" for n, t in replacements.items()]) logger.warning( "The following tools have been replaced by the ones provided in" - f" `additional_tools`:\n{names}." - ) + f" `additional_tools`:\n{names}.") elif len(replacements) == 1: name = list(replacements.keys())[0] logger.warning( f"{name} has been replaced by {replacements[name]} as provided in" - " `additional_tools`." 
- ) + " `additional_tools`.") self.prepare_for_new_chat() @@ -247,17 +244,20 @@ class Agent: return self._toolbox def format_prompt(self, task, chat_mode=False): - description = "\n".join( - [f"- {name}: {tool.description}" for name, tool in self.toolbox.items()] - ) + description = "\n".join([ + f"- {name}: {tool.description}" + for name, tool in self.toolbox.items() + ]) if chat_mode: if self.chat_history is None: - prompt = self.chat_prompt_template.replace("<>", description) + prompt = self.chat_prompt_template.replace( + "<>", description) else: prompt = self.chat_history prompt += CHAT_MESSAGE_PROMPT.replace("<>", task) else: - prompt = self.run_prompt_template.replace("<>", description) + prompt = self.run_prompt_template.replace("<>", + description) prompt = prompt.replace("<>", task) return prompt @@ -306,14 +306,19 @@ class Agent: if not return_code: self.log("\n\n==Result==") self.cached_tools = resolve_tools( - code, self.toolbox, remote=remote, cached_tools=self.cached_tools - ) + code, + self.toolbox, + remote=remote, + cached_tools=self.cached_tools) self.chat_state.update(kwargs) - return evaluate( - code, self.cached_tools, self.chat_state, chat_mode=True - ) + return evaluate(code, + self.cached_tools, + self.chat_state, + chat_mode=True) else: - tool_code = get_tool_creation_code(code, self.toolbox, remote=remote) + tool_code = get_tool_creation_code(code, + self.toolbox, + remote=remote) return f"{tool_code}\n{code}" def prepare_for_new_chat(self): @@ -355,12 +360,15 @@ class Agent: self.log(f"\n\n==Code generated by the agent==\n{code}") if not return_code: self.log("\n\n==Result==") - self.cached_tools = resolve_tools( - code, self.toolbox, remote=remote, cached_tools=self.cached_tools - ) + self.cached_tools = resolve_tools(code, + self.toolbox, + remote=remote, + cached_tools=self.cached_tools) return evaluate(code, self.cached_tools, state=kwargs.copy()) else: - tool_code = get_tool_creation_code(code, self.toolbox, remote=remote) + tool_code = get_tool_creation_code(code, + self.toolbox, + remote=remote) return f"{tool_code}\n{code}" def generate_one(self, prompt, stop): @@ -420,8 +428,7 @@ class HFAgent(Agent): ): if not is_openai_available(): raise ImportError( - "Using `OpenAiAgent` requires `openai`: `pip install openai`." - ) + "Using `OpenAiAgent` requires `openai`: `pip install openai`.") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY", None) @@ -429,8 +436,7 @@ class HFAgent(Agent): raise ValueError( "You need an openai key to use `OpenAIAgent`. You can get one here: Get" " one here https://openai.com/api/`. If you have one, set it in your" - " env with `os.environ['OPENAI_API_KEY'] = xxx." - ) + " env with `os.environ['OPENAI_API_KEY'] = xxx.") else: openai.api_key = api_key self.model = model @@ -455,7 +461,10 @@ class HFAgent(Agent): def _chat_generate(self, prompt, stop): result = openai.ChatCompletion.create( model=self.model, - messages=[{"role": "user", "content": prompt}], + messages=[{ + "role": "user", + "content": prompt + }], temperature=0, stop=stop, ) @@ -533,8 +542,7 @@ class AzureOpenAI(Agent): ): if not is_openai_available(): raise ImportError( - "Using `OpenAiAgent` requires `openai`: `pip install openai`." - ) + "Using `OpenAiAgent` requires `openai`: `pip install openai`.") self.deployment_id = deployment_id openai.api_type = "azure" @@ -544,8 +552,7 @@ class AzureOpenAI(Agent): raise ValueError( "You need an Azure openAI key to use `AzureOpenAIAgent`. 
If you have" " one, set it in your env with `os.environ['AZURE_OPENAI_API_KEY'] =" - " xxx." - ) + " xxx.") else: openai.api_key = api_key if resource_name is None: @@ -554,8 +561,7 @@ class AzureOpenAI(Agent): raise ValueError( "You need a resource_name to use `AzureOpenAIAgent`. If you have one," " set it in your env with `os.environ['AZURE_OPENAI_RESOURCE_NAME'] =" - " xxx." - ) + " xxx.") else: openai.api_base = f"https://{resource_name}.openai.azure.com" openai.api_version = api_version @@ -585,7 +591,10 @@ class AzureOpenAI(Agent): def _chat_generate(self, prompt, stop): result = openai.ChatCompletion.create( engine=self.deployment_id, - messages=[{"role": "user", "content": prompt}], + messages=[{ + "role": "user", + "content": prompt + }], temperature=0, stop=stop, ) diff --git a/swarms/agents/meta_prompter.py b/swarms/agents/meta_prompter.py index aeee9878..f744e38e 100644 --- a/swarms/agents/meta_prompter.py +++ b/swarms/agents/meta_prompter.py @@ -88,9 +88,8 @@ class MetaPrompterAgent: Assistant: """ - prompt = PromptTemplate( - input_variables=["history", "human_input"], template=template - ) + prompt = PromptTemplate(input_variables=["history", "human_input"], + template=template) self.chain = LLMChain( llm=self.llm(), @@ -102,13 +101,15 @@ class MetaPrompterAgent: def get_chat_history(self, chain_memory): """Get Chat History from the memory""" memory_key = chain_memory.memory_key - chat_history = chain_memory.load_memory_variables(memory_key)[memory_key] + chat_history = chain_memory.load_memory_variables( + memory_key)[memory_key] return chat_history def get_new_instructions(self, meta_output): """Get New Instructions from the meta_output""" delimiter = "Instructions: " - new_instructions = meta_output[meta_output.find(delimiter) + len(delimiter) :] + new_instructions = meta_output[meta_output.find(delimiter) + + len(delimiter):] return new_instructions def run(self, task: str): @@ -149,8 +150,7 @@ class MetaPrompterAgent: meta_chain = self.initialize_meta_chain() meta_output = meta_chain.predict( - chat_history=self.get_chat_history(chain.memory) - ) + chat_history=self.get_chat_history(chain.memory)) print(f"Feedback: {meta_output}") self.instructions = self.get_new_instructions(meta_output) diff --git a/swarms/agents/multi_modal_visual_agent.py b/swarms/agents/multi_modal_visual_agent.py index 34780594..72b6c50e 100644 --- a/swarms/agents/multi_modal_visual_agent.py +++ b/swarms/agents/multi_modal_visual_agent.py @@ -150,6 +150,7 @@ def seed_everything(seed): def prompts(name, description): + def decorator(func): func.name = name func.description = description @@ -171,9 +172,12 @@ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100): kernel = np.multiply(kernel_h, np.transpose(kernel_w)) kernel[steps:-steps, steps:-steps] = 1 - kernel[:steps, :steps] = kernel[:steps, :steps] / kernel[steps - 1, steps - 1] - kernel[:steps, -steps:] = kernel[:steps, -steps:] / kernel[steps - 1, -(steps)] - kernel[-steps:, :steps] = kernel[-steps:, :steps] / kernel[-steps, steps - 1] + kernel[:steps, :steps] = kernel[:steps, :steps] / kernel[steps - 1, + steps - 1] + kernel[:steps, + -steps:] = kernel[:steps, -steps:] / kernel[steps - 1, -(steps)] + kernel[-steps:, :steps] = kernel[-steps:, :steps] / kernel[-steps, + steps - 1] kernel[-steps:, -steps:] = kernel[-steps:, -steps:] / kernel[-steps, -steps] kernel = np.expand_dims(kernel, 2) kernel = np.repeat(kernel, 3, 2) @@ -207,12 +211,12 @@ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100): kernel[steps:-steps, :steps] = 
left kernel[steps:-steps, -steps:] = right - pt_gt_img = easy_img[pos_h : pos_h + old_size[1], pos_w : pos_w + old_size[0]] - gaussian_gt_img = ( - kernel * gt_img_array + (1 - kernel) * pt_gt_img - ) # gt img with blur img + pt_gt_img = easy_img[pos_h:pos_h + old_size[1], pos_w:pos_w + old_size[0]] + gaussian_gt_img = (kernel * gt_img_array + (1 - kernel) * pt_gt_img + ) # gt img with blur img gaussian_gt_img = gaussian_gt_img.astype(np.int64) - easy_img[pos_h : pos_h + old_size[1], pos_w : pos_w + old_size[0]] = gaussian_gt_img + easy_img[pos_h:pos_h + old_size[1], + pos_w:pos_w + old_size[0]] = gaussian_gt_img gaussian_img = Image.fromarray(easy_img) return gaussian_img @@ -252,6 +256,7 @@ def get_new_image_name(org_img_name, func_name="update"): class InstructPix2Pix: + def __init__(self, device): print(f"Initializing InstructPix2Pix to {device}") self.device = device @@ -260,110 +265,102 @@ class InstructPix2Pix: self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( "timbrooks/instruct-pix2pix", safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ).to(device) self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) @prompts( name="Instruct Image Using Text", - description=( - "useful when you want to the style of the image to be like the text. " - "like: make it look like a painting. or make it like a robot. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the text. " - ), + description= + ("useful when you want to the style of the image to be like the text. " + "like: make it look like a painting. or make it like a robot. " + "The input to this tool should be a comma separated string of two, " + "representing the image_path and the text. "), ) def inference(self, inputs): """Change style of image.""" print("===>Starting InstructPix2Pix Inference") image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) original_image = Image.open(image_path) - image = self.pipe( - text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2 - ).images[0] + image = self.pipe(text, + image=original_image, + num_inference_steps=40, + image_guidance_scale=1.2).images[0] updated_image_path = get_new_image_name(image_path, func_name="pix2pix") image.save(updated_image_path) print( f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:" - f" {text}, Output Image: {updated_image_path}" - ) + f" {text}, Output Image: {updated_image_path}") return updated_image_path class Text2Image: + def __init__(self, device): print(f"Initializing Text2Image to {device}") self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.pipe = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype - ) + "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype) self.pipe.to(device) self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" - ) + "fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image From User Input Text", - description=( - "useful when you want to generate an image from a user input text and save" - " it to a file. 
like: generate an image of an object or something, or" - " generate an image that includes some objects. The input to this tool" - " should be a string, representing the text used to generate image. " - ), + description= + ("useful when you want to generate an image from a user input text and save" + " it to a file. like: generate an image of an object or something, or" + " generate an image that includes some objects. The input to this tool" + " should be a string, representing the text used to generate image. "), ) def inference(self, text): image_filename = os.path.join("image", f"{str(uuid.uuid4())[:8]}.png") prompt = text + ", " + self.a_prompt image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0] image.save(image_filename) - print( - f"\nProcessed Text2Image, Input Text: {text}, Output Image:" - f" {image_filename}" - ) + print(f"\nProcessed Text2Image, Input Text: {text}, Output Image:" + f" {image_filename}") return image_filename class ImageCaptioning: + def __init__(self, device): print(f"Initializing ImageCaptioning to {device}") self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.processor = BlipProcessor.from_pretrained( - "Salesforce/blip-image-captioning-base" - ) + "Salesforce/blip-image-captioning-base") self.model = BlipForConditionalGeneration.from_pretrained( - "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype - ).to(self.device) + "Salesforce/blip-image-captioning-base", + torch_dtype=self.torch_dtype).to(self.device) @prompts( name="Get Photo Description", - description=( - "useful when you want to know what is inside the photo. receives image_path" - " as input. The input to this tool should be a string, representing the" - " image_path. " - ), + description= + ("useful when you want to know what is inside the photo. receives image_path" + " as input. The input to this tool should be a string, representing the" + " image_path. "), ) def inference(self, image_path): - inputs = self.processor(Image.open(image_path), return_tensors="pt").to( - self.device, self.torch_dtype - ) + inputs = self.processor(Image.open(image_path), + return_tensors="pt").to(self.device, + self.torch_dtype) out = self.model.generate(**inputs) captions = self.processor.decode(out[0], skip_special_tokens=True) print( f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text:" - f" {captions}" - ) + f" {captions}") return captions class Image2Canny: + def __init__(self, device): print("Initializing Image2Canny") self.low_threshold = 100 @@ -371,12 +368,11 @@ class Image2Canny: @prompts( name="Edge Detection On Image", - description=( - "useful when you want to detect the edge of the image. like: detect the" - " edges of this image, or canny detection on image, or perform edge" - " detection on this image, or detect the canny image of this image. The" - " input to this tool should be a string, representing the image_path" - ), + description= + ("useful when you want to detect the edge of the image. like: detect the" + " edges of this image, or canny detection on image, or perform edge" + " detection on this image, or detect the canny image of this image. 
The" + " input to this tool should be a string, representing the image_path"), ) def inference(self, inputs): image = Image.open(inputs) @@ -387,14 +383,13 @@ class Image2Canny: canny = Image.fromarray(canny) updated_image_path = get_new_image_name(inputs, func_name="edge") canny.save(updated_image_path) - print( - f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text:" - f" {updated_image_path}" - ) + print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text:" + f" {updated_image_path}") return updated_image_path class CannyText2Image: + def __init__(self, device): print(f"Initializing CannyText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -406,36 +401,31 @@ class CannyText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" - ) + "fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image Condition On Canny Image", - description=( - "useful when you want to generate a new real image from both the user" - " description and a canny image. like: generate a real image of a object or" - " something from this canny image, or generate a new real image of a object" - " or something from this edge image. The input to this tool should be a" - " comma separated string of two, representing the image_path and the user" - " description. " - ), + description= + ("useful when you want to generate a new real image from both the user" + " description and a canny image. like: generate a real image of a object or" + " something from this canny image, or generate a new real image of a object" + " or something from this edge image. The input to this tool should be a" + " comma separated string of two, representing the image_path and the user" + " description. "), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -448,83 +438,77 @@ class CannyText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, func_name="canny2image") + updated_image_path = get_new_image_name(image_path, + func_name="canny2image") image.save(updated_image_path) print( f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text:" - f" {instruct_text}, Output Text: {updated_image_path}" - ) + f" {instruct_text}, Output Text: {updated_image_path}") return updated_image_path class Image2Line: + def __init__(self, device): print("Initializing Image2Line") self.detector = MLSDdetector.from_pretrained("lllyasviel/ControlNet") @prompts( name="Line Detection On Image", - description=( - "useful when you want to detect the straight line of the image. 
like:" - " detect the straight lines of this image, or straight line detection on" - " image, or perform straight line detection on this image, or detect the" - " straight line image of this image. The input to this tool should be a" - " string, representing the image_path" - ), + description= + ("useful when you want to detect the straight line of the image. like:" + " detect the straight lines of this image, or straight line detection on" + " image, or perform straight line detection on this image, or detect the" + " straight line image of this image. The input to this tool should be a" + " string, representing the image_path"), ) def inference(self, inputs): image = Image.open(inputs) mlsd = self.detector(image) updated_image_path = get_new_image_name(inputs, func_name="line-of") mlsd.save(updated_image_path) - print( - f"\nProcessed Image2Line, Input Image: {inputs}, Output Line:" - f" {updated_image_path}" - ) + print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line:" + f" {updated_image_path}") return updated_image_path class LineText2Image: + def __init__(self, device): print(f"Initializing LineText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.controlnet = ControlNetModel.from_pretrained( - "fusing/stable-diffusion-v1-5-controlnet-mlsd", torch_dtype=self.torch_dtype - ) + "fusing/stable-diffusion-v1-5-controlnet-mlsd", + torch_dtype=self.torch_dtype) self.pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" - ) + "fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image Condition On Line Image", - description=( - "useful when you want to generate a new real image from both the user" - " description and a straight line image. like: generate a real image of a" - " object or something from this straight line image, or generate a new real" - " image of a object or something from this straight lines. The input to" - " this tool should be a comma separated string of two, representing the" - " image_path and the user description. " - ), + description= + ("useful when you want to generate a new real image from both the user" + " description and a straight line image. like: generate a real image of a" + " object or something from this straight line image, or generate a new real" + " image of a object or something from this straight lines. The input to" + " this tool should be a comma separated string of two, representing the" + " image_path and the user description. 
"), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -537,83 +521,78 @@ class LineText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, func_name="line2image") + updated_image_path = get_new_image_name(image_path, + func_name="line2image") image.save(updated_image_path) print( f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text:" - f" {instruct_text}, Output Text: {updated_image_path}" - ) + f" {instruct_text}, Output Text: {updated_image_path}") return updated_image_path class Image2Hed: + def __init__(self, device): print("Initializing Image2Hed") self.detector = HEDdetector.from_pretrained("lllyasviel/ControlNet") @prompts( name="Hed Detection On Image", - description=( - "useful when you want to detect the soft hed boundary of the image. like:" - " detect the soft hed boundary of this image, or hed boundary detection on" - " image, or perform hed boundary detection on this image, or detect soft" - " hed boundary image of this image. The input to this tool should be a" - " string, representing the image_path" - ), + description= + ("useful when you want to detect the soft hed boundary of the image. like:" + " detect the soft hed boundary of this image, or hed boundary detection on" + " image, or perform hed boundary detection on this image, or detect soft" + " hed boundary image of this image. The input to this tool should be a" + " string, representing the image_path"), ) def inference(self, inputs): image = Image.open(inputs) hed = self.detector(image) - updated_image_path = get_new_image_name(inputs, func_name="hed-boundary") + updated_image_path = get_new_image_name(inputs, + func_name="hed-boundary") hed.save(updated_image_path) - print( - f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed:" - f" {updated_image_path}" - ) + print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed:" + f" {updated_image_path}") return updated_image_path class HedText2Image: + def __init__(self, device): print(f"Initializing HedText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.controlnet = ControlNetModel.from_pretrained( - "fusing/stable-diffusion-v1-5-controlnet-hed", torch_dtype=self.torch_dtype - ) + "fusing/stable-diffusion-v1-5-controlnet-hed", + torch_dtype=self.torch_dtype) self.pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" - ) + "fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image Condition On Soft Hed Boundary Image", - description=( - "useful when you want to generate a new real image from both the user" - " description and a soft hed boundary image. 
like: generate a real image of" - " a object or something from this soft hed boundary image, or generate a" - " new real image of a object or something from this hed boundary. The input" - " to this tool should be a comma separated string of two, representing the" - " image_path and the user description" - ), + description= + ("useful when you want to generate a new real image from both the user" + " description and a soft hed boundary image. like: generate a real image of" + " a object or something from this soft hed boundary image, or generate a" + " new real image of a object or something from this hed boundary. The input" + " to this tool should be a comma separated string of two, representing the" + " image_path and the user description"), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -626,28 +605,27 @@ class HedText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, func_name="hed2image") + updated_image_path = get_new_image_name(image_path, + func_name="hed2image") image.save(updated_image_path) - print( - f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}" - ) + print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text:" + f" {instruct_text}, Output Image: {updated_image_path}") return updated_image_path class Image2Scribble: + def __init__(self, device): print("Initializing Image2Scribble") self.detector = HEDdetector.from_pretrained("lllyasviel/ControlNet") @prompts( name="Sketch Detection On Image", - description=( - "useful when you want to generate a scribble of the image. like: generate a" - " scribble of this image, or generate a sketch from this image, detect the" - " sketch from this image. The input to this tool should be a string," - " representing the image_path" - ), + description= + ("useful when you want to generate a scribble of the image. like: generate a" + " scribble of this image, or generate a sketch from this image, detect the" + " sketch from this image. 
The input to this tool should be a string," + " representing the image_path"), ) def inference(self, inputs): image = Image.open(inputs) @@ -656,12 +634,12 @@ class Image2Scribble: scribble.save(updated_image_path) print( f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble:" - f" {updated_image_path}" - ) + f" {updated_image_path}") return updated_image_path class ScribbleText2Image: + def __init__(self, device): print(f"Initializing ScribbleText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -673,34 +651,29 @@ class ScribbleText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" - ) + "fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image Condition On Sketch Image", - description=( - "useful when you want to generate a new real image from both the user" - " description and a scribble image or a sketch image. The input to this" - " tool should be a comma separated string of two, representing the" - " image_path and the user description" - ), + description= + ("useful when you want to generate a new real image from both the user" + " description and a scribble image or a sketch image. The input to this" + " tool should be a comma separated string of two, representing the" + " image_path and the user description"), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -713,41 +686,41 @@ class ScribbleText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, func_name="scribble2image") + updated_image_path = get_new_image_name(image_path, + func_name="scribble2image") image.save(updated_image_path) print( f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}" - ) + f" {instruct_text}, Output Image: {updated_image_path}") return updated_image_path class Image2Pose: + def __init__(self, device): print("Initializing Image2Pose") - self.detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") + self.detector = OpenposeDetector.from_pretrained( + "lllyasviel/ControlNet") @prompts( name="Pose Detection On Image", - description=( - "useful when you want to detect the human pose of the image. like: generate" - " human poses of this image, or generate a pose image from this image. The" - " input to this tool should be a string, representing the image_path" - ), + description= + ("useful when you want to detect the human pose of the image. like: generate" + " human poses of this image, or generate a pose image from this image. 
The" + " input to this tool should be a string, representing the image_path"), ) def inference(self, inputs): image = Image.open(inputs) pose = self.detector(image) updated_image_path = get_new_image_name(inputs, func_name="human-pose") pose.save(updated_image_path) - print( - f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose:" - f" {updated_image_path}" - ) + print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose:" + f" {updated_image_path}") return updated_image_path class PoseText2Image: + def __init__(self, device): print(f"Initializing PoseText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -759,13 +732,11 @@ class PoseText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) self.pipe.to(device) self.num_inference_steps = 20 self.seed = -1 @@ -773,23 +744,20 @@ class PoseText2Image: self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit," - " fewer digits, cropped, worst quality, low quality" - ) + " fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image Condition On Pose Image", - description=( - "useful when you want to generate a new real image from both the user" - " description and a human pose image. like: generate a real image of a" - " human from this human pose image, or generate a new real image of a human" - " from this pose. The input to this tool should be a comma separated string" - " of two, representing the image_path and the user description" - ), + description= + ("useful when you want to generate a new real image from both the user" + " description and a human pose image. like: generate a real image of a" + " human from this human pose image, or generate a new real image of a human" + " from this pose. 
The input to this tool should be a comma separated string" + " of two, representing the image_path and the user description"), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -802,56 +770,52 @@ class PoseText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, func_name="pose2image") + updated_image_path = get_new_image_name(image_path, + func_name="pose2image") image.save(updated_image_path) print( f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}" - ) + f" {instruct_text}, Output Image: {updated_image_path}") return updated_image_path class SegText2Image: + def __init__(self, device): print(f"Initializing SegText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.controlnet = ControlNetModel.from_pretrained( - "fusing/stable-diffusion-v1-5-controlnet-seg", torch_dtype=self.torch_dtype - ) + "fusing/stable-diffusion-v1-5-controlnet-seg", + torch_dtype=self.torch_dtype) self.pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit," - " fewer digits, cropped, worst quality, low quality" - ) + " fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image Condition On Segmentations", - description=( - "useful when you want to generate a new real image from both the user" - " description and segmentations. like: generate a real image of a object or" - " something from this segmentation image, or generate a new real image of a" - " object or something from these segmentations. The input to this tool" - " should be a comma separated string of two, representing the image_path" - " and the user description" - ), + description= + ("useful when you want to generate a new real image from both the user" + " description and segmentations. like: generate a real image of a object or" + " something from this segmentation image, or generate a new real image of a" + " object or something from these segmentations. 
The input to this tool" + " should be a comma separated string of two, representing the image_path" + " and the user description"), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -864,28 +828,27 @@ class SegText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, func_name="segment2image") + updated_image_path = get_new_image_name(image_path, + func_name="segment2image") image.save(updated_image_path) - print( - f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}" - ) + print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text:" + f" {instruct_text}, Output Image: {updated_image_path}") return updated_image_path class Image2Depth: + def __init__(self, device): print("Initializing Image2Depth") self.depth_estimator = pipeline("depth-estimation") @prompts( name="Predict Depth On Image", - description=( - "useful when you want to detect depth of the image. like: generate the" - " depth from this image, or detect the depth map on this image, or predict" - " the depth for this image. The input to this tool should be a string," - " representing the image_path" - ), + description= + ("useful when you want to detect depth of the image. like: generate the" + " depth from this image, or detect the depth map on this image, or predict" + " the depth for this image. The input to this tool should be a string," + " representing the image_path"), ) def inference(self, inputs): image = Image.open(inputs) @@ -896,14 +859,13 @@ class Image2Depth: depth = Image.fromarray(depth) updated_image_path = get_new_image_name(inputs, func_name="depth") depth.save(updated_image_path) - print( - f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth:" - f" {updated_image_path}" - ) + print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth:" + f" {updated_image_path}") return updated_image_path class DepthText2Image: + def __init__(self, device): print(f"Initializing DepthText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -915,36 +877,31 @@ class DepthText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit," - " fewer digits, cropped, worst quality, low quality" - ) + " fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image Condition On Depth", - description=( - "useful when you want to generate a new real image from both the user" - " description and depth image. like: generate a real image of a object or" - " something from this depth image, or generate a new real image of a object" - " or something from the depth map. 
The input to this tool should be a comma" - " separated string of two, representing the image_path and the user" - " description" - ), + description= + ("useful when you want to generate a new real image from both the user" + " description and depth image. like: generate a real image of a object or" + " something from this depth image, or generate a new real image of a object" + " or something from the depth map. The input to this tool should be a comma" + " separated string of two, representing the image_path and the user" + " description"), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -957,30 +914,29 @@ class DepthText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, func_name="depth2image") + updated_image_path = get_new_image_name(image_path, + func_name="depth2image") image.save(updated_image_path) print( f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}" - ) + f" {instruct_text}, Output Image: {updated_image_path}") return updated_image_path class Image2Normal: + def __init__(self, device): print("Initializing Image2Normal") - self.depth_estimator = pipeline( - "depth-estimation", model="Intel/dpt-hybrid-midas" - ) + self.depth_estimator = pipeline("depth-estimation", + model="Intel/dpt-hybrid-midas") self.bg_threhold = 0.4 @prompts( name="Predict Normal Map On Image", - description=( - "useful when you want to detect norm map of the image. like: generate" - " normal map from this image, or predict normal map of this image. The" - " input to this tool should be a string, representing the image_path" - ), + description= + ("useful when you want to detect norm map of the image. like: generate" + " normal map from this image, or predict normal map of this image. 
The" + " input to this tool should be a string, representing the image_path"), ) def inference(self, inputs): image = Image.open(inputs) @@ -996,20 +952,19 @@ class Image2Normal: y[image_depth < self.bg_threhold] = 0 z = np.ones_like(x) * np.pi * 2.0 image = np.stack([x, y, z], axis=2) - image /= np.sum(image**2.0, axis=2, keepdims=True) ** 0.5 + image /= np.sum(image**2.0, axis=2, keepdims=True)**0.5 image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8) image = Image.fromarray(image) image = image.resize(original_size) updated_image_path = get_new_image_name(inputs, func_name="normal-map") image.save(updated_image_path) - print( - f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth:" - f" {updated_image_path}" - ) + print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth:" + f" {updated_image_path}") return updated_image_path class NormalText2Image: + def __init__(self, device): print(f"Initializing NormalText2Image to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 @@ -1021,36 +976,31 @@ class NormalText2Image: "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), torch_dtype=self.torch_dtype, ) self.pipe.scheduler = UniPCMultistepScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) self.pipe.to(device) self.seed = -1 self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit," - " fewer digits, cropped, worst quality, low quality" - ) + " fewer digits, cropped, worst quality, low quality") @prompts( name="Generate Image Condition On Normal Map", - description=( - "useful when you want to generate a new real image from both the user" - " description and normal map. like: generate a real image of a object or" - " something from this normal map, or generate a new real image of a object" - " or something from the normal map. The input to this tool should be a" - " comma separated string of two, representing the image_path and the user" - " description" - ), + description= + ("useful when you want to generate a new real image from both the user" + " description and normal map. like: generate a real image of a object or" + " something from this normal map, or generate a new real image of a object" + " or something from the normal map. 
The input to this tool should be a" + " comma separated string of two, representing the image_path and the user" + " description"), ) def inference(self, inputs): image_path, instruct_text = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) image = Image.open(image_path) self.seed = random.randint(0, 65535) seed_everything(self.seed) @@ -1063,50 +1013,53 @@ class NormalText2Image: negative_prompt=self.n_prompt, guidance_scale=9.0, ).images[0] - updated_image_path = get_new_image_name(image_path, func_name="normal2image") + updated_image_path = get_new_image_name(image_path, + func_name="normal2image") image.save(updated_image_path) print( f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text:" - f" {instruct_text}, Output Image: {updated_image_path}" - ) + f" {instruct_text}, Output Image: {updated_image_path}") return updated_image_path class VisualQuestionAnswering: + def __init__(self, device): print(f"Initializing VisualQuestionAnswering to {device}") self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.device = device - self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") + self.processor = BlipProcessor.from_pretrained( + "Salesforce/blip-vqa-base") self.model = BlipForQuestionAnswering.from_pretrained( - "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype - ).to(self.device) + "Salesforce/blip-vqa-base", + torch_dtype=self.torch_dtype).to(self.device) @prompts( name="Answer Question About The Image", - description=( - "useful when you need an answer for a question based on an image. like:" - " what is the background color of the last image, how many cats in this" - " figure, what is in this figure. The input to this tool should be a comma" - " separated string of two, representing the image_path and the question" + description= + ("useful when you need an answer for a question based on an image. like:" + " what is the background color of the last image, how many cats in this" + " figure, what is in this figure. 
The input to this tool should be a comma" + " separated string of two, representing the image_path and the question" ), ) def inference(self, inputs): - image_path, question = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) + image_path, question = inputs.split(",")[0], ",".join( + inputs.split(",")[1:]) raw_image = Image.open(image_path).convert("RGB") - inputs = self.processor(raw_image, question, return_tensors="pt").to( - self.device, self.torch_dtype - ) + inputs = self.processor(raw_image, question, + return_tensors="pt").to(self.device, + self.torch_dtype) out = self.model.generate(**inputs) answer = self.processor.decode(out[0], skip_special_tokens=True) print( f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" - f" Question: {question}, Output Answer: {answer}" - ) + f" Question: {question}, Output Answer: {answer}") return answer class Segmenting: + def __init__(self, device): print(f"Inintializing Segmentation to {device}") self.device = device @@ -1151,7 +1104,8 @@ class Segmenting: h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255 - image = cv2.addWeighted(image, 0.7, mask_image.astype("uint8"), transparency, 0) + image = cv2.addWeighted(image, 0.7, mask_image.astype("uint8"), + transparency, 0) return image @@ -1159,10 +1113,12 @@ class Segmenting: x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] ax.add_patch( - plt.Rectangle( - (x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2 - ) - ) + plt.Rectangle((x0, y0), + w, + h, + edgecolor="green", + facecolor=(0, 0, 0, 0), + lw=2)) ax.text(x0, y0, label) def get_mask_with_boxes(self, image_pil, image, boxes_filt): @@ -1175,8 +1131,7 @@ class Segmenting: boxes_filt = boxes_filt.cpu() transformed_boxes = self.sam_predictor.transform.apply_boxes_torch( - boxes_filt, image.shape[:2] - ).to(self.device) + boxes_filt, image.shape[:2]).to(self.device) masks, _, _ = self.sam_predictor.predict_torch( point_coords=None, @@ -1186,7 +1141,8 @@ class Segmenting: ) return masks - def segment_image_with_boxes(self, image_pil, image_path, boxes_filt, pred_phrases): + def segment_image_with_boxes(self, image_pil, image_path, boxes_filt, + pred_phrases): image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) self.sam_predictor.set_image(image) @@ -1196,11 +1152,13 @@ class Segmenting: # draw output image for mask in masks: - image = self.show_mask( - mask[0].cpu().numpy(), image, random_color=True, transparency=0.3 - ) + image = self.show_mask(mask[0].cpu().numpy(), + image, + random_color=True, + transparency=0.3) - updated_image_path = get_new_image_name(image_path, func_name="segmentation") + updated_image_path = get_new_image_name(image_path, + func_name="segmentation") new_image = Image.fromarray(image) new_image.save(updated_image_path) @@ -1212,9 +1170,8 @@ class Segmenting: with torch.cuda.amp.autocast(): self.sam_predictor.set_image(img) - def show_points( - self, coords: np.ndarray, labels: np.ndarray, image: np.ndarray - ) -> np.ndarray: + def show_points(self, coords: np.ndarray, labels: np.ndarray, + image: np.ndarray) -> np.ndarray: """Visualize points on top of an image. 
Args: @@ -1228,13 +1185,17 @@ class Segmenting: pos_points = coords[labels == 1] neg_points = coords[labels == 0] for p in pos_points: - image = cv2.circle( - image, p.astype(int), radius=3, color=(0, 255, 0), thickness=-1 - ) + image = cv2.circle(image, + p.astype(int), + radius=3, + color=(0, 255, 0), + thickness=-1) for p in neg_points: - image = cv2.circle( - image, p.astype(int), radius=3, color=(255, 0, 0), thickness=-1 - ) + image = cv2.circle(image, + p.astype(int), + radius=3, + color=(255, 0, 0), + thickness=-1) return image def segment_image_with_click(self, img, is_positive: bool): @@ -1252,13 +1213,17 @@ class Segmenting: multimask_output=False, ) - img = self.show_mask(masks[0], img, random_color=False, transparency=0.3) + img = self.show_mask(masks[0], + img, + random_color=False, + transparency=0.3) img = self.show_points(input_point, input_label, img) return img - def segment_image_with_coordinate(self, img, is_positive: bool, coordinate: tuple): + def segment_image_with_coordinate(self, img, is_positive: bool, + coordinate: tuple): """ Args: img (numpy.ndarray): the given image, shape: H x W x 3. @@ -1289,7 +1254,10 @@ class Segmenting: multimask_output=False, ) - img = self.show_mask(masks[0], img, random_color=False, transparency=0.3) + img = self.show_mask(masks[0], + img, + random_color=False, + transparency=0.3) img = self.show_points(input_point, input_label, img) @@ -1301,13 +1269,12 @@ class Segmenting: @prompts( name="Segment the Image", - description=( - "useful when you want to segment all the part of the image, but not segment" - " a certain object.like: segment all the object in this image, or generate" - " segmentations on this image, or segment the image,or perform segmentation" - " on this image, or segment all the object in this image.The input to this" - " tool should be a string, representing the image_path" - ), + description= + ("useful when you want to segment all the part of the image, but not segment" + " a certain object.like: segment all the object in this image, or generate" + " segmentations on this image, or segment the image,or perform segmentation" + " on this image, or segment all the object in this image.The input to this" + " tool should be a string, representing the image_path"), ) def inference_all(self, image_path): image = cv2.imread(image_path) @@ -1328,19 +1295,26 @@ class Segmenting: img[:, :, i] = color_mask[i] ax.imshow(np.dstack((img, m))) - updated_image_path = get_new_image_name(image_path, func_name="segment-image") + updated_image_path = get_new_image_name(image_path, + func_name="segment-image") plt.axis("off") - plt.savefig(updated_image_path, bbox_inches="tight", dpi=300, pad_inches=0.0) + plt.savefig(updated_image_path, + bbox_inches="tight", + dpi=300, + pad_inches=0.0) return updated_image_path class Text2Box: + def __init__(self, device): print(f"Initializing ObjectDetection to {device}") self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 - self.model_checkpoint_path = os.path.join("checkpoints", "groundingdino") - self.model_config_path = os.path.join("checkpoints", "grounding_config.py") + self.model_checkpoint_path = os.path.join("checkpoints", + "groundingdino") + self.model_config_path = os.path.join("checkpoints", + "grounding_config.py") self.download_parameters() self.box_threshold = 0.3 self.text_threshold = 0.25 @@ -1358,13 +1332,11 @@ class Text2Box: # load image image_pil = Image.open(image_path).convert("RGB") # load image - transform = T.Compose( - [ - 
T.RandomResize([512], max_size=1333), - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) + transform = T.Compose([ + T.RandomResize([512], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) image, _ = transform(image_pil, None) # 3, h, w return image_pil, image @@ -1373,9 +1345,8 @@ class Text2Box: args.device = self.device model = build_model(args) checkpoint = torch.load(self.model_checkpoint_path, map_location="cpu") - load_res = model.load_state_dict( - clean_state_dict(checkpoint["model"]), strict=False - ) + load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), + strict=False) print(load_res) _ = model.eval() return model @@ -1406,11 +1377,11 @@ class Text2Box: # build pred pred_phrases = [] for logit, box in zip(logits_filt, boxes_filt): - pred_phrase = get_phrases_from_posmap( - logit > self.text_threshold, tokenized, tokenlizer - ) + pred_phrase = get_phrases_from_posmap(logit > self.text_threshold, + tokenized, tokenlizer) if with_logits: - pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") + pred_phrases.append(pred_phrase + + f"({str(logit.max().item())[:4]})") else: pred_phrases.append(pred_phrase) @@ -1420,7 +1391,8 @@ class Text2Box: H, W = tgt["size"] boxes = tgt["boxes"] labels = tgt["labels"] - assert len(boxes) == len(labels), "boxes and labels must have same length" + assert len(boxes) == len( + labels), "boxes and labels must have same length" draw = ImageDraw.Draw(image_pil) mask = Image.new("L", image_pil.size, 0) @@ -1458,12 +1430,11 @@ class Text2Box: @prompts( name="Detect the Give Object", - description=( - "useful when you only want to detect or find out given objects in the" - " pictureThe input to this tool should be a comma separated string of two," - " representing the image_path, the text description of the object to be" - " found" - ), + description= + ("useful when you only want to detect or find out given objects in the" + " pictureThe input to this tool should be a comma separated string of two," + " representing the image_path, the text description of the object to be" + " found"), ) def inference(self, inputs): image_path, det_prompt = inputs.split(",") @@ -1481,19 +1452,18 @@ class Text2Box: image_with_box = self.plot_boxes_to_image(image_pil, pred_dict)[0] - updated_image_path = get_new_image_name( - image_path, func_name="detect-something" - ) + updated_image_path = get_new_image_name(image_path, + func_name="detect-something") updated_image = image_with_box.resize(size) updated_image.save(updated_image_path) print( f"\nProcessed ObejectDetecting, Input Image: {image_path}, Object to be" - f" Detect {det_prompt}, Output Image: {updated_image_path}" - ) + f" Detect {det_prompt}, Output Image: {updated_image_path}") return updated_image_path class Inpainting: + def __init__(self, device): self.device = device self.revision = "fp16" if "cuda" in self.device else None @@ -1504,13 +1474,16 @@ class Inpainting: revision=self.revision, torch_dtype=self.torch_dtype, safety_checker=StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ), + "CompVis/stable-diffusion-safety-checker"), ).to(device) - def __call__( - self, prompt, image, mask_image, height=512, width=512, num_inference_steps=50 - ): + def __call__(self, + prompt, + image, + mask_image, + height=512, + width=512, + num_inference_steps=50): update_image = self.inpaint( prompt=prompt, image=image.resize((width, height)), @@ -1533,29 +1506,27 @@ 
class InfinityOutPainting: self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" - ) + "fewer digits, cropped, worst quality, low quality") def get_BLIP_vqa(self, image, question): - inputs = self.ImageVQA.processor(image, question, return_tensors="pt").to( - self.ImageVQA.device, self.ImageVQA.torch_dtype - ) + inputs = self.ImageVQA.processor(image, question, + return_tensors="pt").to( + self.ImageVQA.device, + self.ImageVQA.torch_dtype) out = self.ImageVQA.model.generate(**inputs) - answer = self.ImageVQA.processor.decode(out[0], skip_special_tokens=True) + answer = self.ImageVQA.processor.decode(out[0], + skip_special_tokens=True) print( f"\nProcessed VisualQuestionAnswering, Input Question: {question}, Output" - f" Answer: {answer}" - ) + f" Answer: {answer}") return answer def get_BLIP_caption(self, image): inputs = self.ImageCaption.processor(image, return_tensors="pt").to( - self.ImageCaption.device, self.ImageCaption.torch_dtype - ) + self.ImageCaption.device, self.ImageCaption.torch_dtype) out = self.ImageCaption.model.generate(**inputs) BLIP_caption = self.ImageCaption.processor.decode( - out[0], skip_special_tokens=True - ) + out[0], skip_special_tokens=True) return BLIP_caption def check_prompt(self, prompt): @@ -1569,8 +1540,7 @@ class InfinityOutPainting: def get_imagine_caption(self, image, imagine): BLIP_caption = self.get_BLIP_caption(image) background_color = self.get_BLIP_vqa( - image, "what is the background color of this image" - ) + image, "what is the background color of this image") style = self.get_BLIP_vqa(image, "what is the style of this image") imagine_prompt = ( "let's pretend you are an excellent painter and now there is an incomplete" @@ -1578,54 +1548,47 @@ class InfinityOutPainting: " painting and describe ityou should consider the background color is" f" {background_color}, the style is {style}You should make the painting as" " vivid and realistic as possibleYou can not use words like painting or" - " pictureand you should use no more than 50 words to describe it" - ) + " pictureand you should use no more than 50 words to describe it") caption = self.llm(imagine_prompt) if imagine else BLIP_caption caption = self.check_prompt(caption) - print( - f"BLIP observation: {BLIP_caption}, ChatGPT imagine to {caption}" - ) if imagine else print(f"Prompt: {caption}") + print(f"BLIP observation: {BLIP_caption}, ChatGPT imagine to {caption}" + ) if imagine else print(f"Prompt: {caption}") return caption def resize_image(self, image, max_size=1000000, multiple=8): aspect_ratio = image.size[0] / image.size[1] new_width = int(math.sqrt(max_size * aspect_ratio)) new_height = int(new_width / aspect_ratio) - new_width, new_height = new_width - (new_width % multiple), new_height - ( - new_height % multiple - ) + new_width, new_height = new_width - ( + new_width % multiple), new_height - (new_height % multiple) return image.resize((new_width, new_height)) def dowhile(self, original_img, tosize, expand_ratio, imagine, usr_prompt): old_img = original_img while old_img.size != tosize: - prompt = ( - self.check_prompt(usr_prompt) - if usr_prompt - else self.get_imagine_caption(old_img, imagine) - ) + prompt = (self.check_prompt(usr_prompt) if usr_prompt else + self.get_imagine_caption(old_img, imagine)) crop_w = 15 if old_img.size[0] != tosize[0] else 0 crop_h = 15 if old_img.size[1] != tosize[1] else 0 old_img = ImageOps.crop(old_img, 
(crop_w, crop_h, crop_w, crop_h)) temp_canvas_size = ( expand_ratio * old_img.width - if expand_ratio * old_img.width < tosize[0] - else tosize[0], + if expand_ratio * old_img.width < tosize[0] else tosize[0], expand_ratio * old_img.height - if expand_ratio * old_img.height < tosize[1] - else tosize[1], + if expand_ratio * old_img.height < tosize[1] else tosize[1], ) - temp_canvas, temp_mask = Image.new( - "RGB", temp_canvas_size, color="white" - ), Image.new("L", temp_canvas_size, color="white") + temp_canvas, temp_mask = Image.new("RGB", + temp_canvas_size, + color="white"), Image.new( + "L", + temp_canvas_size, + color="white") x, y = (temp_canvas.width - old_img.width) // 2, ( - temp_canvas.height - old_img.height - ) // 2 + temp_canvas.height - old_img.height) // 2 temp_canvas.paste(old_img, (x, y)) temp_mask.paste(0, (x, y, x + old_img.width, y + old_img.height)) resized_temp_canvas, resized_temp_mask = self.resize_image( - temp_canvas - ), self.resize_image(temp_mask) + temp_canvas), self.resize_image(temp_mask) image = self.inpaint( prompt=prompt, image=resized_temp_canvas, @@ -1640,11 +1603,11 @@ class InfinityOutPainting: @prompts( name="Extend An Image", - description=( - "useful when you need to extend an image into a larger image.like: extend" - " the image into a resolution of 2048x1024, extend the image into" - " 2048x1024. The input to this tool should be a comma separated string of" - " two, representing the image_path and the resolution of widthxheight" + description= + ("useful when you need to extend an image into a larger image.like: extend" + " the image into a resolution of 2048x1024, extend the image into" + " 2048x1024. The input to this tool should be a comma separated string of" + " two, representing the image_path and the resolution of widthxheight" ), ) def inference(self, inputs): @@ -1654,12 +1617,12 @@ class InfinityOutPainting: image = Image.open(image_path) image = ImageOps.crop(image, (10, 10, 10, 10)) out_painted_image = self.dowhile(image, tosize, 4, True, False) - updated_image_path = get_new_image_name(image_path, func_name="outpainting") + updated_image_path = get_new_image_name(image_path, + func_name="outpainting") out_painted_image.save(updated_image_path) print( f"\nProcessed InfinityOutPainting, Input Image: {image_path}, Input" - f" Resolution: {resolution}, Output Image: {updated_image_path}" - ) + f" Resolution: {resolution}, Output Image: {updated_image_path}") return updated_image_path @@ -1678,22 +1641,20 @@ class ObjectSegmenting: " pictureaccording to the given textlike: segment the cat,or can you" " segment an obeject for meThe input to this tool should be a comma" " separated string of two, representing the image_path, the text" - " description of the object to be found" - ), + " description of the object to be found"), ) def inference(self, inputs): image_path, det_prompt = inputs.split(",") print(f"image_path={image_path}, text_prompt={det_prompt}") image_pil, image = self.grounding.load_image(image_path) - boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, det_prompt) + boxes_filt, pred_phrases = self.grounding.get_grounding_boxes( + image, det_prompt) updated_image_path = self.sam.segment_image_with_boxes( - image_pil, image_path, boxes_filt, pred_phrases - ) + image_pil, image_path, boxes_filt, pred_phrases) print( f"\nProcessed ObejectSegmenting, Input Image: {image_path}, Object to be" - f" Segment {det_prompt}, Output Image: {updated_image_path}" - ) + f" Segment {det_prompt}, Output Image: 
{updated_image_path}") return updated_image_path def merge_masks(self, masks): @@ -1724,8 +1685,7 @@ class ObjectSegmenting: image_pil, image = self.grounding.load_image(image_path) boxes_filt, pred_phrases = self.grounding.get_grounding_boxes( - image, text_prompt - ) + image, text_prompt) image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) self.sam.sam_predictor.set_image(image) @@ -1738,9 +1698,10 @@ class ObjectSegmenting: # draw output image for mask in masks: - image = self.sam.show_mask( - mask[0].cpu().numpy(), image, random_color=True, transparency=0.3 - ) + image = self.sam.show_mask(mask[0].cpu().numpy(), + image, + random_color=True, + transparency=0.3) Image.fromarray(merged_mask) @@ -1750,9 +1711,8 @@ class ObjectSegmenting: class ImageEditing: template_model = True - def __init__( - self, Text2Box: Text2Box, Segmenting: Segmenting, Inpainting: Inpainting - ): + def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting, + Inpainting: Inpainting): print("Initializing ImageEditing") self.sam = Segmenting self.grounding = Text2Box @@ -1765,8 +1725,7 @@ class ImageEditing: mask_array = np.zeros_like(mask, dtype=bool) for idx in true_indices: padded_slice = tuple( - slice(max(0, i - padding), i + padding + 1) for i in idx - ) + slice(max(0, i - padding), i + padding + 1) for i in idx) mask_array[padded_slice] = True new_mask = (mask_array * 255).astype(np.uint8) # new_mask @@ -1774,38 +1733,34 @@ class ImageEditing: @prompts( name="Remove Something From The Photo", - description=( - "useful when you want to remove and object or something from the photo " - "from its description or location. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the object need to be removed. " - ), + description= + ("useful when you want to remove and object or something from the photo " + "from its description or location. " + "The input to this tool should be a comma separated string of two, " + "representing the image_path and the object need to be removed. "), ) def inference_remove(self, inputs): image_path, to_be_removed_txt = inputs.split(",")[0], ",".join( - inputs.split(",")[1:] - ) + inputs.split(",")[1:]) return self.inference_replace_sam( - f"{image_path},{to_be_removed_txt},background" - ) + f"{image_path},{to_be_removed_txt},background") @prompts( name="Replace Something From The Photo", - description=( - "useful when you want to replace an object from the object description or" - " location with another object from its description. The input to this tool" - " should be a comma separated string of three, representing the image_path," - " the object to be replaced, the object to be replaced with " - ), + description= + ("useful when you want to replace an object from the object description or" + " location with another object from its description. 
The input to this tool" + " should be a comma separated string of three, representing the image_path," + " the object to be replaced, the object to be replaced with "), ) def inference_replace_sam(self, inputs): image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",") - print(f"image_path={image_path}, to_be_replaced_txt={to_be_replaced_txt}") + print( + f"image_path={image_path}, to_be_replaced_txt={to_be_replaced_txt}") image_pil, image = self.grounding.load_image(image_path) boxes_filt, pred_phrases = self.grounding.get_grounding_boxes( - image, to_be_replaced_txt - ) + image, to_be_replaced_txt) image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) self.sam.sam_predictor.set_image(image) @@ -1817,19 +1772,16 @@ class ImageEditing: mask = self.pad_edge(mask, padding=20) # numpy mask_image = Image.fromarray(mask) - updated_image = self.inpaint( - prompt=replace_with_txt, image=image_pil, mask_image=mask_image - ) - updated_image_path = get_new_image_name( - image_path, func_name="replace-something" - ) + updated_image = self.inpaint(prompt=replace_with_txt, + image=image_pil, + mask_image=mask_image) + updated_image_path = get_new_image_name(image_path, + func_name="replace-something") updated_image = updated_image.resize(image_pil.size) updated_image.save(updated_image_path) - print( - f"\nProcessed ImageEditing, Input Image: {image_path}, Replace" - f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:" - f" {updated_image_path}" - ) + print(f"\nProcessed ImageEditing, Input Image: {image_path}, Replace" + f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:" + f" {updated_image_path}") return updated_image_path @@ -1851,10 +1803,9 @@ class BackgroundRemoving: @prompts( name="Remove the background", - description=( - "useful when you want to extract the object or remove the background," - "the input should be a string image_path" - ), + description= + ("useful when you want to extract the object or remove the background," + "the input should be a string image_path"), ) def inference(self, image_path): """ @@ -1868,9 +1819,8 @@ class BackgroundRemoving: mask = Image.fromarray(mask) image.putalpha(mask) - updated_image_path = get_new_image_name( - image_path, func_name="detect-something" - ) + updated_image_path = get_new_image_name(image_path, + func_name="detect-something") image.save(updated_image_path) return updated_image_path @@ -1893,6 +1843,7 @@ class BackgroundRemoving: class MultiModalVisualAgent: + def __init__( self, load_dict, @@ -1905,8 +1856,7 @@ class MultiModalVisualAgent: if "ImageCaptioning" not in load_dict: raise ValueError( "You have to load ImageCaptioning as a basic function for" - " MultiModalVisualAgent" - ) + " MultiModalVisualAgent") self.models = {} @@ -1916,17 +1866,18 @@ class MultiModalVisualAgent: for class_name, module in globals().items(): if getattr(module, "template_model", False): template_required_names = { - k - for k in inspect.signature(module.__init__).parameters.keys() - if k != "self" + k for k in inspect.signature( + module.__init__).parameters.keys() if k != "self" } - loaded_names = set([type(e).__name__ for e in self.models.values()]) + loaded_names = set( + [type(e).__name__ for e in self.models.values()]) if template_required_names.issubset(loaded_names): - self.models[class_name] = globals()[class_name]( - **{name: self.models[name] for name in template_required_names} - ) + self.models[class_name] = globals()[class_name](**{ + name: self.models[name] + for name in 
template_required_names + }) print(f"All the Available Functions: {self.models}") @@ -1936,13 +1887,13 @@ class MultiModalVisualAgent: if e.startswith("inference"): func = getattr(instance, e) self.tools.append( - Tool(name=func.name, description=func.description, func=func) - ) + Tool(name=func.name, + description=func.description, + func=func)) self.llm = OpenAI(temperature=0) - self.memory = ConversationBufferMemory( - memory_key="chat_history", output_key="output" - ) + self.memory = ConversationBufferMemory(memory_key="chat_history", + output_key="output") def init_agent(self, lang): self.memory.clear() @@ -1980,8 +1931,7 @@ class MultiModalVisualAgent: def run_text(self, text): self.agent.memory.buffer = cut_dialogue_history( - self.agent.memory.buffer, keep_last_n_words=500 - ) + self.agent.memory.buffer, keep_last_n_words=500) res = self.agent({"input": text.strip()}) res["output"] = res["output"].replace("\\", "/") @@ -1991,10 +1941,8 @@ class MultiModalVisualAgent: res["output"], ) - print( - f"\nProcessed run_text, Input text: {text}\n" - f"Current Memory: {self.agent.memory.buffer}" - ) + print(f"\nProcessed run_text, Input text: {text}\n" + f"Current Memory: {self.agent.memory.buffer}") return response @@ -2016,12 +1964,10 @@ class MultiModalVisualAgent: description = self.models["ImageCaptioning"].inference(image_filename) if lang == "Chinese": - Human_prompt = ( - f"\nHuman: 提供一张名为 {image_filename}的图片。它的描述是:" - f" {description}。 这些信息帮助你理解这个图像," - "但是你应该使用工具来完成下面的任务,而不是直接从我的描述中想象。" - ' 如果你明白了, 说 "收到". \n' - ) + Human_prompt = (f"\nHuman: 提供一张名为 {image_filename}的图片。它的描述是:" + f" {description}。 这些信息帮助你理解这个图像," + "但是你应该使用工具来完成下面的任务,而不是直接从我的描述中想象。" + ' 如果你明白了, 说 "收到". \n') AI_prompt = "收到。 " else: Human_prompt = ( @@ -2029,18 +1975,14 @@ class MultiModalVisualAgent: f" {description}. This information helps you to understand this image," " but you should use tools to finish following tasks, rather than" " directly imagine from my description. If you understand, say" - ' "Received". \n' - ) + ' "Received". \n') AI_prompt = "Received. 
" - self.agent.memory.buffer = ( - self.agent.memory.buffer + Human_prompt + "AI: " + AI_prompt - ) + self.agent.memory.buffer = (self.agent.memory.buffer + Human_prompt + + "AI: " + AI_prompt) - print( - f"\nProcessed run_image, Input image: {image_filename}\n" - f"Current Memory: {self.agent.memory.buffer}" - ) + print(f"\nProcessed run_image, Input image: {image_filename}\n" + f"Current Memory: {self.agent.memory.buffer}") return AI_prompt @@ -2087,7 +2029,10 @@ class MultiModalAgent: """ - def __init__(self, load_dict, temperature: int = 0.1, language: str = "english"): + def __init__(self, + load_dict, + temperature: int = 0.1, + language: str = "english"): self.load_dict = load_dict self.temperature = temperature self.langigage = language @@ -2123,7 +2068,10 @@ class MultiModalAgent: except Exception as error: return f"Error processing image: {str(error)}" - def chat(self, msg: str = None, language: str = "english", streaming: bool = False): + def chat(self, + msg: str = None, + language: str = "english", + streaming: bool = False): """ Run chat with the multi-modal agent diff --git a/swarms/agents/neural_architecture_search_worker.py b/swarms/agents/neural_architecture_search_worker.py index fd253b95..3bfd8323 100644 --- a/swarms/agents/neural_architecture_search_worker.py +++ b/swarms/agents/neural_architecture_search_worker.py @@ -2,6 +2,7 @@ class Replicator: + def __init__( self, model_name, diff --git a/swarms/agents/omni_modal_agent.py b/swarms/agents/omni_modal_agent.py index 007a2219..b6fdfbdc 100644 --- a/swarms/agents/omni_modal_agent.py +++ b/swarms/agents/omni_modal_agent.py @@ -3,23 +3,20 @@ from typing import Dict, List from langchain.base_language import BaseLanguageModel from langchain.tools.base import BaseTool from langchain_experimental.autonomous_agents.hugginggpt.repsonse_generator import ( - load_response_generator, -) + load_response_generator,) from langchain_experimental.autonomous_agents.hugginggpt.task_executor import ( - TaskExecutor, -) + TaskExecutor,) from langchain_experimental.autonomous_agents.hugginggpt.task_planner import ( - load_chat_planner, -) + load_chat_planner,) from transformers import load_tool from swarms.agents.message import Message class Step: - def __init__( - self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool - ): + + def __init__(self, task: str, id: int, dep: List[int], args: Dict[str, str], + tool: BaseTool): self.task = task self.id = id self.dep = dep @@ -28,6 +25,7 @@ class Step: class Plan: + def __init__(self, steps: List[Step]): self.steps = steps @@ -73,8 +71,7 @@ class OmniModalAgent: print("Loading tools...") self.tools = [ - load_tool(tool_name) - for tool_name in [ + load_tool(tool_name) for tool_name in [ "document-question-answering", "image-captioning", "image-question-answering", @@ -99,18 +96,15 @@ class OmniModalAgent: def run(self, input: str) -> str: """Run the OmniAgent""" - plan = self.chat_planner.plan( - inputs={ - "input": input, - "hf_tools": self.tools, - } - ) + plan = self.chat_planner.plan(inputs={ + "input": input, + "hf_tools": self.tools, + }) self.task_executor = TaskExecutor(plan) self.task_executor.run() response = self.response_generator.generate( - {"task_execution": self.task_executor} - ) + {"task_execution": self.task_executor}) return response diff --git a/swarms/agents/profitpilot.py b/swarms/agents/profitpilot.py index 6858dc72..a4ff13a5 100644 --- a/swarms/agents/profitpilot.py +++ b/swarms/agents/profitpilot.py @@ -145,13 +145,12 @@ def 
setup_knowledge_base(product_catalog: str = None): llm = OpenAI(temperature=0) embeddings = OpenAIEmbeddings() - docsearch = Chroma.from_texts( - texts, embeddings, collection_name="product-knowledge-base" - ) + docsearch = Chroma.from_texts(texts, + embeddings, + collection_name="product-knowledge-base") knowledge_base = RetrievalQA.from_chain_type( - llm=llm, chain_type="stuff", retriever=docsearch.as_retriever() - ) + llm=llm, chain_type="stuff", retriever=docsearch.as_retriever()) return knowledge_base @@ -163,8 +162,8 @@ def get_tools(product_catalog): Tool( name="ProductSearch", func=knowledge_base.run, - description=( - "useful for when you need to answer questions about product information" + description= + ("useful for when you need to answer questions about product information" ), ), # omnimodal agent @@ -194,8 +193,7 @@ class CustomPromptTemplateForTools(StringPromptTemplate): tools = self.tools_getter(kwargs["input"]) # Create a tools variable from the list of tools provided kwargs["tools"] = "\n".join( - [f"{tool.name}: {tool.description}" for tool in tools] - ) + [f"{tool.name}: {tool.description}" for tool in tools]) # Create a list of tool names for the tools provided kwargs["tool_names"] = ", ".join([tool.name for tool in tools]) return self.template.format(**kwargs) @@ -218,8 +216,7 @@ class SalesConvoOutputParser(AgentOutputParser): print("-------") if f"{self.ai_prefix}:" in text: return AgentFinish( - {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text - ) + {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text) regex = r"Action: (.*?)[\n]*Action Input: (.*)" match = re.search(regex, text) if not match: @@ -228,15 +225,15 @@ class SalesConvoOutputParser(AgentOutputParser): { "output": ( "I apologize, I was unable to find the answer to your question." - " Is there anything else I can help with?" - ) + " Is there anything else I can help with?") }, text, ) # raise OutputParserException(f"Could not parse LLM output: `{text}`") action = match.group(1) action_input = match.group(2) - return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text) + return AgentAction(action.strip(), + action_input.strip(" ").strip('"'), text) @property def _type(self) -> str: @@ -264,13 +261,11 @@ class ProfitPilot(Chain, BaseModel): "2": ( "Qualification: Qualify the prospect by confirming if they are the right" " person to talk to regarding your product/service. Ensure that they have" - " the authority to make purchasing decisions." - ), + " the authority to make purchasing decisions."), "3": ( "Value proposition: Briefly explain how your product/service can benefit" " the prospect. Focus on the unique selling points and value proposition of" - " your product/service that sets it apart from competitors." - ), + " your product/service that sets it apart from competitors."), "4": ( "Needs analysis: Ask open-ended questions to uncover the prospect's needs" " and pain points. Listen carefully to their responses and take notes." @@ -282,13 +277,11 @@ class ProfitPilot(Chain, BaseModel): "6": ( "Objection handling: Address any objections that the prospect may have" " regarding your product/service. Be prepared to provide evidence or" - " testimonials to support your claims." - ), + " testimonials to support your claims."), "7": ( "Close: Ask for the sale by proposing a next step. This could be a demo, a" " trial or a meeting with decision-makers. Ensure to summarize what has" - " been discussed and reiterate the benefits." 
- ), + " been discussed and reiterate the benefits."), } salesperson_name: str = "Ted Lasso" @@ -298,19 +291,16 @@ class ProfitPilot(Chain, BaseModel): "Sleep Haven is a premium mattress company that provides customers with the" " most comfortable and supportive sleeping experience possible. We offer a" " range of high-quality mattresses, pillows, and bedding accessories that are" - " designed to meet the unique needs of our customers." - ) + " designed to meet the unique needs of our customers.") company_values: str = ( "Our mission at Sleep Haven is to help people achieve a better night's sleep by" " providing them with the best possible sleep solutions. We believe that" " quality sleep is essential to overall health and well-being, and we are" " committed to helping our customers achieve optimal sleep by offering" - " exceptional products and customer service." - ) + " exceptional products and customer service.") conversation_purpose: str = ( "find out whether they are looking to achieve better sleep via buying a premier" - " mattress." - ) + " mattress.") conversation_type: str = "call" def retrieve_conversation_stage(self, key): @@ -336,8 +326,7 @@ class ProfitPilot(Chain, BaseModel): ) self.current_conversation_stage = self.retrieve_conversation_stage( - conversation_stage_id - ) + conversation_stage_id) print(f"Conversation Stage: {self.current_conversation_stage}") @@ -391,13 +380,15 @@ class ProfitPilot(Chain, BaseModel): return {} @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool = False, **kwargs): # noqa: F821 + def from_llm(cls, + llm: BaseLLM, + verbose: bool = False, + **kwargs): # noqa: F821 """Initialize the SalesGPT Controller.""" stage_analyzer_chain = StageAnalyzerChain.from_llm(llm, verbose=verbose) sales_conversation_utterance_chain = SalesConversationChain.from_llm( - llm, verbose=verbose - ) + llm, verbose=verbose) if "use_tools" in kwargs.keys() and kwargs["use_tools"] is False: sales_agent_executor = None @@ -430,7 +421,8 @@ class ProfitPilot(Chain, BaseModel): # WARNING: this output parser is NOT reliable yet # It makes assumptions about output from LLM which can break and throw an error - output_parser = SalesConvoOutputParser(ai_prefix=kwargs["salesperson_name"]) + output_parser = SalesConvoOutputParser( + ai_prefix=kwargs["salesperson_name"]) sales_agent_with_tools = LLMSingleActionAgent( llm_chain=llm_chain, @@ -441,12 +433,12 @@ class ProfitPilot(Chain, BaseModel): ) sales_agent_executor = AgentExecutor.from_agent_and_tools( - agent=sales_agent_with_tools, tools=tools, verbose=verbose - ) + agent=sales_agent_with_tools, tools=tools, verbose=verbose) return cls( stage_analyzer_chain=stage_analyzer_chain, - sales_conversation_utterance_chain=sales_conversation_utterance_chain, + sales_conversation_utterance_chain= + sales_conversation_utterance_chain, sales_agent_executor=sales_agent_executor, verbose=verbose, **kwargs, @@ -458,32 +450,27 @@ config = dict( salesperson_name="Ted Lasso", salesperson_role="Business Development Representative", company_name="Sleep Haven", - company_business=( - "Sleep Haven is a premium mattress company that provides customers with the" - " most comfortable and supportive sleeping experience possible. We offer a" - " range of high-quality mattresses, pillows, and bedding accessories that are" - " designed to meet the unique needs of our customers." - ), - company_values=( - "Our mission at Sleep Haven is to help people achieve a better night's sleep by" - " providing them with the best possible sleep solutions. 
We believe that" - " quality sleep is essential to overall health and well-being, and we are" - " committed to helping our customers achieve optimal sleep by offering" - " exceptional products and customer service." - ), - conversation_purpose=( - "find out whether they are looking to achieve better sleep via buying a premier" - " mattress." - ), + company_business= + ("Sleep Haven is a premium mattress company that provides customers with the" + " most comfortable and supportive sleeping experience possible. We offer a" + " range of high-quality mattresses, pillows, and bedding accessories that are" + " designed to meet the unique needs of our customers."), + company_values= + ("Our mission at Sleep Haven is to help people achieve a better night's sleep by" + " providing them with the best possible sleep solutions. We believe that" + " quality sleep is essential to overall health and well-being, and we are" + " committed to helping our customers achieve optimal sleep by offering" + " exceptional products and customer service."), + conversation_purpose= + ("find out whether they are looking to achieve better sleep via buying a premier" + " mattress."), conversation_history=[], conversation_type="call", conversation_stage=conversation_stages.get( "1", - ( - "Introduction: Start the conversation by introducing yourself and your" - " company. Be polite and respectful while keeping the tone of the" - " conversation professional." - ), + ("Introduction: Start the conversation by introducing yourself and your" + " company. Be polite and respectful while keeping the tone of the" + " conversation professional."), ), use_tools=True, product_catalog="sample_product_catalog.txt", diff --git a/swarms/agents/refiner_agent.py b/swarms/agents/refiner_agent.py index 2a1383e9..509484e3 100644 --- a/swarms/agents/refiner_agent.py +++ b/swarms/agents/refiner_agent.py @@ -1,9 +1,11 @@ class PromptRefiner: + def __init__(self, system_prompt: str, llm): super().__init__() self.system_prompt = system_prompt self.llm = llm def run(self, task: str): - refine = self.llm(f"System Prompt: {self.system_prompt} Current task: {task}") + refine = self.llm( + f"System Prompt: {self.system_prompt} Current task: {task}") return refine diff --git a/swarms/agents/registry.py b/swarms/agents/registry.py index aa1f1375..5cf2c0d5 100644 --- a/swarms/agents/registry.py +++ b/swarms/agents/registry.py @@ -10,6 +10,7 @@ class Registry(BaseModel): entries: Dict = {} def register(self, key: str): + def decorator(class_builder): self.entries[key] = class_builder return class_builder @@ -20,8 +21,7 @@ class Registry(BaseModel): if type not in self.entries: raise ValueError( f"{type} is not registered. 
Please register with the" - f' .register("{type}") method provided in {self.name} registry' - ) + f' .register("{type}") method provided in {self.name} registry') return self.entries[type](**kwargs) def get_all_entries(self): diff --git a/swarms/agents/simple_agent.py b/swarms/agents/simple_agent.py index 88327095..847cbc67 100644 --- a/swarms/agents/simple_agent.py +++ b/swarms/agents/simple_agent.py @@ -29,7 +29,8 @@ class SimpleAgent: def run(self, task: str) -> str: """Run method""" - metrics = print(colored(f"Agent {self.name} is running task: {task}", "red")) + metrics = print( + colored(f"Agent {self.name} is running task: {task}", "red")) print(metrics) response = self.flow.run(task) diff --git a/swarms/artifacts/base.py b/swarms/artifacts/base.py index dac7a523..1357a86b 100644 --- a/swarms/artifacts/base.py +++ b/swarms/artifacts/base.py @@ -10,9 +10,8 @@ from marshmallow.exceptions import RegistryError @define class BaseArtifact(ABC): id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) - name: str = field( - default=Factory(lambda self: self.id, takes_self=True), kw_only=True - ) + name: str = field(default=Factory(lambda self: self.id, takes_self=True), + kw_only=True) value: any = field() type: str = field( default=Factory(lambda self: self.__class__.__name__, takes_self=True), @@ -54,7 +53,8 @@ class BaseArtifact(ABC): class_registry.register("ListArtifact", ListArtifactSchema) try: - return class_registry.get_class(artifact_dict["type"])().load(artifact_dict) + return class_registry.get_class( + artifact_dict["type"])().load(artifact_dict) except RegistryError: raise ValueError("Unsupported artifact type") diff --git a/swarms/artifacts/main.py b/swarms/artifacts/main.py index 4b240b22..8845ada3 100644 --- a/swarms/artifacts/main.py +++ b/swarms/artifacts/main.py @@ -15,8 +15,7 @@ class Artifact(BaseModel): artifact_id: StrictStr = Field(..., description="ID of the artifact") file_name: StrictStr = Field(..., description="Filename of the artifact") relative_path: Optional[StrictStr] = Field( - None, description="Relative path of the artifact" - ) + None, description="Relative path of the artifact") __properties = ["artifact_id", "file_name", "relative_path"] class Config: @@ -49,12 +48,10 @@ class Artifact(BaseModel): if not isinstance(obj, dict): return Artifact.parse_obj(obj) - _obj = Artifact.parse_obj( - { - "artifact_id": obj.get("artifact_id"), - "file_name": obj.get("file_name"), - "relative_path": obj.get("relative_path"), - } - ) + _obj = Artifact.parse_obj({ + "artifact_id": obj.get("artifact_id"), + "file_name": obj.get("file_name"), + "relative_path": obj.get("relative_path"), + }) return _obj diff --git a/swarms/chunkers/__init__.py b/swarms/chunkers/__init__.py index 5e09586b..159e8d5b 100644 --- a/swarms/chunkers/__init__.py +++ b/swarms/chunkers/__init__.py @@ -3,7 +3,6 @@ # from swarms.chunkers.text import TextChunker # from swarms.chunkers.pdf import PdfChunker - # __all__ = [ # "BaseChunker", # "ChunkSeparator", diff --git a/swarms/chunkers/base.py b/swarms/chunkers/base.py index 0fabdcef..d243bd0d 100644 --- a/swarms/chunkers/base.py +++ b/swarms/chunkers/base.py @@ -48,15 +48,13 @@ class BaseChunker(ABC): kw_only=True, ) tokenizer: OpenAITokenizer = field( - default=Factory( - lambda: OpenAITokenizer( - model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL - ) - ), + default=Factory(lambda: OpenAITokenizer( + model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), kw_only=True, ) max_tokens: int = field( - default=Factory(lambda 
self: self.tokenizer.max_tokens, takes_self=True), + default=Factory(lambda self: self.tokenizer.max_tokens, + takes_self=True), kw_only=True, ) @@ -66,8 +64,9 @@ class BaseChunker(ABC): return [TextArtifact(c) for c in self._chunk_recursively(text)] def _chunk_recursively( - self, chunk: str, current_separator: Optional[ChunkSeparator] = None - ) -> list[str]: + self, + chunk: str, + current_separator: Optional[ChunkSeparator] = None) -> list[str]: token_count = self.tokenizer.count_tokens(chunk) if token_count <= self.max_tokens: @@ -79,7 +78,8 @@ class BaseChunker(ABC): half_token_count = token_count // 2 if current_separator: - separators = self.separators[self.separators.index(current_separator) :] + separators = self.separators[self.separators. + index(current_separator):] else: separators = self.separators @@ -102,26 +102,19 @@ class BaseChunker(ABC): if separator.is_prefix: first_subchunk = separator.value + separator.value.join( - subchanks[: balance_index + 1] - ) + subchanks[:balance_index + 1]) second_subchunk = separator.value + separator.value.join( - subchanks[balance_index + 1 :] - ) + subchanks[balance_index + 1:]) else: - first_subchunk = ( - separator.value.join(subchanks[: balance_index + 1]) - + separator.value - ) + first_subchunk = (separator.value.join( + subchanks[:balance_index + 1]) + separator.value) second_subchunk = separator.value.join( - subchanks[balance_index + 1 :] - ) + subchanks[balance_index + 1:]) first_subchunk_rec = self._chunk_recursively( - first_subchunk.strip(), separator - ) + first_subchunk.strip(), separator) second_subchunk_rec = self._chunk_recursively( - second_subchunk.strip(), separator - ) + second_subchunk.strip(), separator) if first_subchunk_rec and second_subchunk_rec: return first_subchunk_rec + second_subchunk_rec diff --git a/swarms/chunkers/omni_chunker.py b/swarms/chunkers/omni_chunker.py index 70a11380..c4870e2b 100644 --- a/swarms/chunkers/omni_chunker.py +++ b/swarms/chunkers/omni_chunker.py @@ -76,8 +76,7 @@ class OmniChunker: colored( f"Could not decode file with extension {file_extension}: {e}", "yellow", - ) - ) + )) return "" def chunk_content(self, content: str) -> List[str]: @@ -91,7 +90,7 @@ class OmniChunker: List[str]: The list of chunks. """ return [ - content[i : i + self.chunk_size] + content[i:i + self.chunk_size] for i in range(0, len(content), self.chunk_size) ] @@ -113,5 +112,4 @@ class OmniChunker: {self.metrics()} """, "cyan", - ) - ) + )) diff --git a/swarms/loaders/asana.py b/swarms/loaders/asana.py index dd14cff4..022b685b 100644 --- a/swarms/loaders/asana.py +++ b/swarms/loaders/asana.py @@ -18,9 +18,9 @@ class AsanaReader(BaseReader): self.client = asana.Client.access_token(asana_token) - def load_data( - self, workspace_id: Optional[str] = None, project_id: Optional[str] = None - ) -> List[Document]: + def load_data(self, + workspace_id: Optional[str] = None, + project_id: Optional[str] = None) -> List[Document]: """Load data from the workspace. 
Args: @@ -31,18 +31,20 @@ class AsanaReader(BaseReader): """ if workspace_id is None and project_id is None: - raise ValueError("Either workspace_id or project_id must be provided") + raise ValueError( + "Either workspace_id or project_id must be provided") if workspace_id is not None and project_id is not None: raise ValueError( - "Only one of workspace_id or project_id should be provided" - ) + "Only one of workspace_id or project_id should be provided") results = [] if workspace_id is not None: - workspace_name = self.client.workspaces.find_by_id(workspace_id)["name"] - projects = self.client.projects.find_all({"workspace": workspace_id}) + workspace_name = self.client.workspaces.find_by_id( + workspace_id)["name"] + projects = self.client.projects.find_all( + {"workspace": workspace_id}) # Case: Only project_id is provided else: # since we've handled the other cases, this means project_id is not None @@ -50,54 +52,58 @@ class AsanaReader(BaseReader): workspace_name = projects[0]["workspace"]["name"] for project in projects: - tasks = self.client.tasks.find_all( - { - "project": project["gid"], - "opt_fields": "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields", - } - ) + tasks = self.client.tasks.find_all({ + "project": + project["gid"], + "opt_fields": + "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields", + }) for task in tasks: - stories = self.client.tasks.stories(task["gid"], opt_fields="type,text") - comments = "\n".join( - [ - story["text"] - for story in stories - if story.get("type") == "comment" and "text" in story - ] - ) + stories = self.client.tasks.stories(task["gid"], + opt_fields="type,text") + comments = "\n".join([ + story["text"] + for story in stories + if story.get("type") == "comment" and "text" in story + ]) task_metadata = { - "task_id": task.get("gid", ""), - "name": task.get("name", ""), + "task_id": + task.get("gid", ""), + "name": + task.get("name", ""), "assignee": (task.get("assignee") or {}).get("name", ""), - "completed_on": task.get("completed_at", ""), - "completed_by": (task.get("completed_by") or {}).get("name", ""), - "project_name": project.get("name", ""), + "completed_on": + task.get("completed_at", ""), + "completed_by": (task.get("completed_by") or + {}).get("name", ""), + "project_name": + project.get("name", ""), "custom_fields": [ i["display_value"] for i in task.get("custom_fields") if task.get("custom_fields") is not None ], - "workspace_name": workspace_name, - "url": f"https://app.asana.com/0/{project['gid']}/{task['gid']}", + "workspace_name": + workspace_name, + "url": + f"https://app.asana.com/0/{project['gid']}/{task['gid']}", } if task.get("followers") is not None: task_metadata["followers"] = [ - i.get("name") for i in task.get("followers") if "name" in i + i.get("name") + for i in task.get("followers") + if "name" in i ] else: task_metadata["followers"] = [] results.append( Document( - text=task.get("name", "") - + " " - + task.get("notes", "") - + " " - + comments, + text=task.get("name", "") + " " + + task.get("notes", "") + " " + comments, extra_info=task_metadata, - ) - ) + )) return results diff --git a/swarms/loaders/base.py b/swarms/loaders/base.py index a59a93e2..2d5c7cdb 100644 --- a/swarms/loaders/base.py +++ b/swarms/loaders/base.py @@ -15,7 +15,6 @@ if TYPE_CHECKING: from haystack.schema import Document as HaystackDocument from semantic_kernel.memory.memory_record import MemoryRecord - #### DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}" 
DEFAULT_METADATA_TMPL = "{key}: {value}" @@ -48,7 +47,8 @@ class BaseComponent(BaseModel): # TODO: return type here not supported by current mypy version @classmethod - def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore + def from_dict(cls, data: Dict[str, Any], + **kwargs: Any) -> Self: # type: ignore if isinstance(kwargs, dict): data.update(kwargs) @@ -119,13 +119,10 @@ class BaseNode(BaseComponent): class Config: allow_population_by_field_name = True - id_: str = Field( - default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node." - ) + id_: str = Field(default_factory=lambda: str(uuid.uuid4()), + description="Unique ID of the node.") embedding: Optional[List[float]] = Field( - default=None, description="Embedding of the node." - ) - + default=None, description="Embedding of the node.") """" metadata fields - injected as part of the text shown to LLMs as context @@ -140,7 +137,8 @@ class BaseNode(BaseComponent): ) excluded_embed_metadata_keys: List[str] = Field( default_factory=list, - description="Metadata keys that are excluded from text for the embed model.", + description= + "Metadata keys that are excluded from text for the embed model.", ) excluded_llm_metadata_keys: List[str] = Field( default_factory=list, @@ -158,7 +156,8 @@ class BaseNode(BaseComponent): """Get Object type.""" @abstractmethod - def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: + def get_content(self, + metadata_mode: MetadataMode = MetadataMode.ALL) -> str: """Get object content.""" @abstractmethod @@ -189,7 +188,8 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.SOURCE] if isinstance(relation, list): - raise ValueError("Source object must be a single RelatedNodeInfo object") + raise ValueError( + "Source object must be a single RelatedNodeInfo object") return relation @property @@ -200,7 +200,8 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.PREVIOUS] if not isinstance(relation, RelatedNodeInfo): - raise ValueError("Previous object must be a single RelatedNodeInfo object") + raise ValueError( + "Previous object must be a single RelatedNodeInfo object") return relation @property @@ -211,7 +212,8 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.NEXT] if not isinstance(relation, RelatedNodeInfo): - raise ValueError("Next object must be a single RelatedNodeInfo object") + raise ValueError( + "Next object must be a single RelatedNodeInfo object") return relation @property @@ -222,7 +224,8 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.PARENT] if not isinstance(relation, RelatedNodeInfo): - raise ValueError("Parent object must be a single RelatedNodeInfo object") + raise ValueError( + "Parent object must be a single RelatedNodeInfo object") return relation @property @@ -233,7 +236,8 @@ class BaseNode(BaseComponent): relation = self.relationships[NodeRelationship.CHILD] if not isinstance(relation, list): - raise ValueError("Child objects must be a list of RelatedNodeInfo objects.") + raise ValueError( + "Child objects must be a list of RelatedNodeInfo objects.") return relation @property @@ -250,12 +254,10 @@ class BaseNode(BaseComponent): return self.metadata def __str__(self) -> str: - source_text_truncated = truncate_text( - self.get_content().strip(), TRUNCATE_LENGTH - ) - source_text_wrapped = textwrap.fill( - f"Text: {source_text_truncated}\n", width=WRAP_WIDTH - ) + source_text_truncated = 
truncate_text(self.get_content().strip(), + TRUNCATE_LENGTH) + source_text_wrapped = textwrap.fill(f"Text: {source_text_truncated}\n", + width=WRAP_WIDTH) return f"Node ID: {self.node_id}\n{source_text_wrapped}" def get_embedding(self) -> List[float]: @@ -281,28 +283,23 @@ class BaseNode(BaseComponent): class TextNode(BaseNode): text: str = Field(default="", description="Text content of the node.") start_char_idx: Optional[int] = Field( - default=None, description="Start char index of the node." - ) + default=None, description="Start char index of the node.") end_char_idx: Optional[int] = Field( - default=None, description="End char index of the node." - ) + default=None, description="End char index of the node.") text_template: str = Field( default=DEFAULT_TEXT_NODE_TMPL, - description=( - "Template for how text is formatted, with {content} and " - "{metadata_str} placeholders." - ), + description=("Template for how text is formatted, with {content} and " + "{metadata_str} placeholders."), ) metadata_template: str = Field( default=DEFAULT_METADATA_TMPL, - description=( - "Template for how metadata is formatted, with {key} and " - "{value} placeholders." - ), + description=("Template for how metadata is formatted, with {key} and " + "{value} placeholders."), ) metadata_seperator: str = Field( default="\n", - description="Separator between metadata fields when converting to string.", + description= + "Separator between metadata fields when converting to string.", ) @classmethod @@ -316,8 +313,7 @@ class TextNode(BaseNode): metadata = values.get("metadata", {}) doc_identity = str(text) + str(metadata) values["hash"] = str( - sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest() - ) + sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()) return values @classmethod @@ -325,15 +321,15 @@ class TextNode(BaseNode): """Get Object type.""" return ObjectType.TEXT - def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: + def get_content(self, + metadata_mode: MetadataMode = MetadataMode.NONE) -> str: """Get object content.""" metadata_str = self.get_metadata_str(mode=metadata_mode).strip() if not metadata_str: return self.text - return self.text_template.format( - content=self.text, metadata_str=metadata_str - ).strip() + return self.text_template.format(content=self.text, + metadata_str=metadata_str).strip() def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: """Metadata info string.""" @@ -350,13 +346,11 @@ class TextNode(BaseNode): if key in usable_metadata_keys: usable_metadata_keys.remove(key) - return self.metadata_seperator.join( - [ - self.metadata_template.format(key=key, value=str(value)) - for key, value in self.metadata.items() - if key in usable_metadata_keys - ] - ) + return self.metadata_seperator.join([ + self.metadata_template.format(key=key, value=str(value)) + for key, value in self.metadata.items() + if key in usable_metadata_keys + ]) def set_content(self, value: str) -> None: """Set the content of the node.""" @@ -480,7 +474,8 @@ class NodeWithScore(BaseComponent): else: raise ValueError("Node must be a TextNode to get text.") - def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: + def get_content(self, + metadata_mode: MetadataMode = MetadataMode.NONE) -> str: return self.node.get_content(metadata_mode=metadata_mode) def get_embedding(self) -> List[float]: @@ -517,12 +512,10 @@ class Document(TextNode): return self.id_ def __str__(self) -> str: - source_text_truncated = 
truncate_text( - self.get_content().strip(), TRUNCATE_LENGTH - ) - source_text_wrapped = textwrap.fill( - f"Text: {source_text_truncated}\n", width=WRAP_WIDTH - ) + source_text_truncated = truncate_text(self.get_content().strip(), + TRUNCATE_LENGTH) + source_text_wrapped = textwrap.fill(f"Text: {source_text_truncated}\n", + width=WRAP_WIDTH) return f"Doc ID: {self.doc_id}\n{source_text_wrapped}" def get_doc_id(self) -> str: @@ -538,22 +531,27 @@ class Document(TextNode): """Convert struct to Haystack document format.""" from haystack.schema import Document as HaystackDocument - return HaystackDocument( - content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_ - ) + return HaystackDocument(content=self.text, + meta=self.metadata, + embedding=self.embedding, + id=self.id_) @classmethod def from_haystack_format(cls, doc: "HaystackDocument") -> "Document": """Convert struct from Haystack document format.""" - return cls( - text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id - ) + return cls(text=doc.content, + metadata=doc.meta, + embedding=doc.embedding, + id_=doc.id) def to_embedchain_format(self) -> Dict[str, Any]: """Convert struct to EmbedChain document format.""" return { "doc_id": self.id_, - "data": {"content": self.text, "meta_data": self.metadata}, + "data": { + "content": self.text, + "meta_data": self.metadata + }, } @classmethod @@ -583,7 +581,8 @@ class Document(TextNode): return cls( text=doc._text, metadata={"additional_metadata": doc._additional_metadata}, - embedding=doc._embedding.tolist() if doc._embedding is not None else None, + embedding=doc._embedding.tolist() + if doc._embedding is not None else None, id_=doc._id, ) @@ -591,7 +590,10 @@ class Document(TextNode): def example(cls) -> "Document": return Document( text=SAMPLE_TEXT, - metadata={"filename": "README.md", "category": "codebase"}, + metadata={ + "filename": "README.md", + "category": "codebase" + }, ) @classmethod diff --git a/swarms/memory/base.py b/swarms/memory/base.py index 7f71c4b9..7c08af6f 100644 --- a/swarms/memory/base.py +++ b/swarms/memory/base.py @@ -30,32 +30,25 @@ class BaseVectorStore(ABC): embedding_driver: Any futures_executor: futures.Executor = field( - default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True - ) - - def upsert_text_artifacts( - self, - artifacts: dict[str, list[TextArtifact]], - meta: Optional[dict] = None, - **kwargs - ) -> None: - execute_futures_dict( - { - namespace: self.futures_executor.submit( - self.upsert_text_artifact, a, namespace, meta, **kwargs - ) - for namespace, artifact_list in artifacts.items() - for a in artifact_list - } - ) - - def upsert_text_artifact( - self, - artifact: TextArtifact, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs - ) -> str: + default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) + + def upsert_text_artifacts(self, + artifacts: dict[str, list[TextArtifact]], + meta: Optional[dict] = None, + **kwargs) -> None: + execute_futures_dict({ + namespace: + self.futures_executor.submit(self.upsert_text_artifact, a, + namespace, meta, **kwargs) + for namespace, artifact_list in artifacts.items() + for a in artifact_list + }) + + def upsert_text_artifact(self, + artifact: TextArtifact, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs) -> str: if not meta: meta = {} @@ -66,39 +59,37 @@ class BaseVectorStore(ABC): else: vector = artifact.generate_embedding(self.embedding_driver) - return self.upsert_vector( - vector, 
vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs - ) - - def upsert_text( - self, - string: str, - vector_id: Optional[str] = None, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs - ) -> str: - return self.upsert_vector( - self.embedding_driver.embed_string(string), - vector_id=vector_id, - namespace=namespace, - meta=meta if meta else {}, - **kwargs - ) + return self.upsert_vector(vector, + vector_id=artifact.id, + namespace=namespace, + meta=meta, + **kwargs) + + def upsert_text(self, + string: str, + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs) -> str: + return self.upsert_vector(self.embedding_driver.embed_string(string), + vector_id=vector_id, + namespace=namespace, + meta=meta if meta else {}, + **kwargs) @abstractmethod - def upsert_vector( - self, - vector: list[float], - vector_id: Optional[str] = None, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs - ) -> str: + def upsert_vector(self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs) -> str: ... @abstractmethod - def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry: + def load_entry(self, + vector_id: str, + namespace: Optional[str] = None) -> Entry: ... @abstractmethod @@ -106,12 +97,10 @@ class BaseVectorStore(ABC): ... @abstractmethod - def query( - self, - query: str, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - **kwargs - ) -> list[QueryResult]: + def query(self, + query: str, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs) -> list[QueryResult]: ... diff --git a/swarms/memory/chroma.py b/swarms/memory/chroma.py index 67ba4cb2..080245fb 100644 --- a/swarms/memory/chroma.py +++ b/swarms/memory/chroma.py @@ -80,10 +80,8 @@ class Chroma(VectorStore): import chromadb import chromadb.config except ImportError: - raise ImportError( - "Could not import chromadb python package. " - "Please install it with `pip install chromadb`." - ) + raise ImportError("Could not import chromadb python package. " + "Please install it with `pip install chromadb`.") if client is not None: self._client_settings = client_settings @@ -94,8 +92,7 @@ class Chroma(VectorStore): # If client_settings is provided with persist_directory specified, # then it is "in-memory and persisting to disk" mode. 
client_settings.persist_directory = ( - persist_directory or client_settings.persist_directory - ) + persist_directory or client_settings.persist_directory) if client_settings.persist_directory is not None: # Maintain backwards compatibility with chromadb < 0.4.0 major, minor, _ = chromadb.__version__.split(".") @@ -108,25 +105,23 @@ class Chroma(VectorStore): major, minor, _ = chromadb.__version__.split(".") if int(major) == 0 and int(minor) < 4: _client_settings = chromadb.config.Settings( - chroma_db_impl="duckdb+parquet", - ) + chroma_db_impl="duckdb+parquet",) else: - _client_settings = chromadb.config.Settings(is_persistent=True) + _client_settings = chromadb.config.Settings( + is_persistent=True) _client_settings.persist_directory = persist_directory else: _client_settings = chromadb.config.Settings() self._client_settings = _client_settings self._client = chromadb.Client(_client_settings) - self._persist_directory = ( - _client_settings.persist_directory or persist_directory - ) + self._persist_directory = (_client_settings.persist_directory or + persist_directory) self._embedding_function = embedding_function self._collection = self._client.get_or_create_collection( name=collection_name, embedding_function=self._embedding_function.embed_documents - if self._embedding_function is not None - else None, + if self._embedding_function is not None else None, metadata=collection_metadata, ) self.override_relevance_score_fn = relevance_score_fn @@ -149,10 +144,8 @@ class Chroma(VectorStore): try: import chromadb # noqa: F401 except ImportError: - raise ValueError( - "Could not import chromadb python package. " - "Please install it with `pip install chromadb`." - ) + raise ValueError("Could not import chromadb python package. " + "Please install it with `pip install chromadb`.") return self._collection.query( query_texts=query_texts, query_embeddings=query_embeddings, @@ -202,9 +195,9 @@ class Chroma(VectorStore): if non_empty_ids: metadatas = [metadatas[idx] for idx in non_empty_ids] texts_with_metadatas = [texts[idx] for idx in non_empty_ids] - embeddings_with_metadatas = ( - [embeddings[idx] for idx in non_empty_ids] if embeddings else None - ) + embeddings_with_metadatas = ([ + embeddings[idx] for idx in non_empty_ids + ] if embeddings else None) ids_with_metadata = [ids[idx] for idx in non_empty_ids] try: self._collection.upsert( @@ -225,8 +218,7 @@ class Chroma(VectorStore): if empty_ids: texts_without_metadatas = [texts[j] for j in empty_ids] embeddings_without_metadatas = ( - [embeddings[j] for j in empty_ids] if embeddings else None - ) + [embeddings[j] for j in empty_ids] if embeddings else None) ids_without_metadatas = [ids[j] for j in empty_ids] self._collection.upsert( embeddings=embeddings_without_metadatas, @@ -258,7 +250,9 @@ class Chroma(VectorStore): Returns: List[Document]: List of documents most similar to the query text. """ - docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) + docs_and_scores = self.similarity_search_with_score(query, + k, + filter=filter) return [doc for doc, _ in docs_and_scores] def similarity_search_by_vector( @@ -381,8 +375,7 @@ class Chroma(VectorStore): raise ValueError( "No supported normalization function" f" for distance metric of type: {distance}." - "Consider providing relevance_score_fn to Chroma constructor." 
- ) + "Consider providing relevance_score_fn to Chroma constructor.") def max_marginal_relevance_search_by_vector( self, @@ -428,7 +421,9 @@ class Chroma(VectorStore): candidates = _results_to_docs(results) - selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] + selected_results = [ + r for i, r in enumerate(candidates) if i in mmr_selected + ] return selected_results def max_marginal_relevance_search( @@ -523,10 +518,8 @@ class Chroma(VectorStore): It will also be called automatically when the object is destroyed. """ if self._persist_directory is None: - raise ValueError( - "You must specify a persist_directory on" - "creation to persist the collection." - ) + raise ValueError("You must specify a persist_directory on" + "creation to persist the collection.") import chromadb # Maintain backwards compatibility with chromadb < 0.4.0 @@ -543,7 +536,8 @@ class Chroma(VectorStore): """ return self.update_documents([document_id], [document]) - def update_documents(self, ids: List[str], documents: List[Document]) -> None: + def update_documents(self, ids: List[str], + documents: List[Document]) -> None: """Update a document in the collection. Args: @@ -558,17 +552,16 @@ class Chroma(VectorStore): ) embeddings = self._embedding_function.embed_documents(text) - if hasattr( - self._collection._client, "max_batch_size" - ): # for Chroma 0.4.10 and above + if hasattr(self._collection._client, + "max_batch_size"): # for Chroma 0.4.10 and above from chromadb.utils.batch_utils import create_batches for batch in create_batches( - api=self._collection._client, - ids=ids, - metadatas=metadata, - documents=text, - embeddings=embeddings, + api=self._collection._client, + ids=ids, + metadatas=metadata, + documents=text, + embeddings=embeddings, ): self._collection.update( ids=batch[0], @@ -628,16 +621,15 @@ class Chroma(VectorStore): ) if ids is None: ids = [str(uuid.uuid1()) for _ in texts] - if hasattr( - chroma_collection._client, "max_batch_size" - ): # for Chroma 0.4.10 and above + if hasattr(chroma_collection._client, + "max_batch_size"): # for Chroma 0.4.10 and above from chromadb.utils.batch_utils import create_batches for batch in create_batches( - api=chroma_collection._client, - ids=ids, - metadatas=metadatas, - documents=texts, + api=chroma_collection._client, + ids=ids, + metadatas=metadatas, + documents=texts, ): chroma_collection.add_texts( texts=batch[3] if batch[3] else [], @@ -645,7 +637,9 @@ class Chroma(VectorStore): ids=batch[0], ) else: - chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) + chroma_collection.add_texts(texts=texts, + metadatas=metadatas, + ids=ids) return chroma_collection @classmethod diff --git a/swarms/memory/cosine_similarity.py b/swarms/memory/cosine_similarity.py index 99d47368..9b183834 100644 --- a/swarms/memory/cosine_similarity.py +++ b/swarms/memory/cosine_similarity.py @@ -19,8 +19,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: if X.shape[1] != Y.shape[1]: raise ValueError( f"Number of columns in X and Y must be the same. X has shape {X.shape} " - f"and Y has shape {Y.shape}." - ) + f"and Y has shape {Y.shape}.") try: import simsimd as simd @@ -33,8 +32,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: except ImportError: logger.info( "Unable to import simsimd, defaulting to NumPy implementation. If you want " - "to use simsimd please install with `pip install simsimd`." 
- ) + "to use simsimd please install with `pip install simsimd`.") X_norm = np.linalg.norm(X, axis=1) Y_norm = np.linalg.norm(Y, axis=1) # Ignore divide by zero errors run time warnings as those are handled below. diff --git a/swarms/memory/db.py b/swarms/memory/db.py index 9f23b59f..8e6bad12 100644 --- a/swarms/memory/db.py +++ b/swarms/memory/db.py @@ -27,6 +27,7 @@ class NotFoundException(Exception): class TaskDB(ABC): + async def create_task( self, input: Optional[str], @@ -67,9 +68,9 @@ class TaskDB(ABC): async def list_tasks(self) -> List[Task]: raise NotImplementedError - async def list_steps( - self, task_id: str, status: Optional[Status] = None - ) -> List[Step]: + async def list_steps(self, + task_id: str, + status: Optional[Status] = None) -> List[Step]: raise NotImplementedError @@ -136,8 +137,8 @@ class InMemoryTaskDB(TaskDB): async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: task = await self.get_task(task_id) artifact = next( - filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None - ) + filter(lambda a: a.artifact_id == artifact_id, task.artifacts), + None) if not artifact: raise NotFoundException("Artifact", artifact_id) return artifact @@ -150,9 +151,9 @@ class InMemoryTaskDB(TaskDB): step_id: Optional[str] = None, ) -> Artifact: artifact_id = str(uuid.uuid4()) - artifact = Artifact( - artifact_id=artifact_id, file_name=file_name, relative_path=relative_path - ) + artifact = Artifact(artifact_id=artifact_id, + file_name=file_name, + relative_path=relative_path) task = await self.get_task(task_id) task.artifacts.append(artifact) @@ -165,9 +166,9 @@ class InMemoryTaskDB(TaskDB): async def list_tasks(self) -> List[Task]: return [task for task in self._tasks.values()] - async def list_steps( - self, task_id: str, status: Optional[Status] = None - ) -> List[Step]: + async def list_steps(self, + task_id: str, + status: Optional[Status] = None) -> List[Step]: task = await self.get_task(task_id) steps = task.steps if status: diff --git a/swarms/memory/ocean.py b/swarms/memory/ocean.py index da58c81c..339c3596 100644 --- a/swarms/memory/ocean.py +++ b/swarms/memory/ocean.py @@ -63,8 +63,7 @@ class OceanDB: try: embedding_function = MultiModalEmbeddingFunction(modality=modality) collection = self.client.create_collection( - collection_name, embedding_function=embedding_function - ) + collection_name, embedding_function=embedding_function) return collection except Exception as e: logging.error(f"Failed to create collection. Error {e}") @@ -91,7 +90,8 @@ class OceanDB: try: return collection.add(documents=[document], ids=[id]) except Exception as e: - logging.error(f"Failed to append document to the collection. Error {e}") + logging.error( + f"Failed to append document to the collection. Error {e}") raise def add_documents(self, collection, documents: List[str], ids: List[str]): @@ -137,7 +137,8 @@ class OceanDB: the results of the query """ try: - results = collection.query(query_texts=query_texts, n_results=n_results) + results = collection.query(query_texts=query_texts, + n_results=n_results) return results except Exception as e: logging.error(f"Failed to query the collection. 
Error {e}") diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index bd768459..09534cac 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -88,12 +88,12 @@ class PgVectorVectorStore(BaseVectorStore): create_engine_params: dict = field(factory=dict, kw_only=True) engine: Optional[Engine] = field(default=None, kw_only=True) table_name: str = field(kw_only=True) - _model: any = field( - default=Factory(lambda self: self.default_vector_model(), takes_self=True) - ) + _model: any = field(default=Factory( + lambda self: self.default_vector_model(), takes_self=True)) @connection_string.validator - def validate_connection_string(self, _, connection_string: Optional[str]) -> None: + def validate_connection_string(self, _, + connection_string: Optional[str]) -> None: # If an engine is provided, the connection string is not used. if self.engine is not None: return @@ -122,9 +122,8 @@ class PgVectorVectorStore(BaseVectorStore): If not, a connection string is used to create a new database connection here. """ if self.engine is None: - self.engine = create_engine( - self.connection_string, **self.create_engine_params - ) + self.engine = create_engine(self.connection_string, + **self.create_engine_params) def setup( self, @@ -142,14 +141,12 @@ class PgVectorVectorStore(BaseVectorStore): if create_schema: self._model.metadata.create_all(self.engine) - def upsert_vector( - self, - vector: list[float], - vector_id: Optional[str] = None, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs - ) -> str: + def upsert_vector(self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs) -> str: """Inserts or updates a vector in the collection.""" with Session(self.engine) as session: obj = self._model( @@ -164,9 +161,9 @@ class PgVectorVectorStore(BaseVectorStore): return str(obj.id) - def load_entry( - self, vector_id: str, namespace: Optional[str] = None - ) -> BaseVectorStore.Entry: + def load_entry(self, + vector_id: str, + namespace: Optional[str] = None) -> BaseVectorStore.Entry: """Retrieves a specific vector entry from the collection based on its identifier and optional namespace.""" with Session(self.engine) as session: result = session.get(self._model, vector_id) @@ -179,8 +176,8 @@ class PgVectorVectorStore(BaseVectorStore): ) def load_entries( - self, namespace: Optional[str] = None - ) -> list[BaseVectorStore.Entry]: + self, + namespace: Optional[str] = None) -> list[BaseVectorStore.Entry]: """Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace. """ @@ -197,19 +194,16 @@ class PgVectorVectorStore(BaseVectorStore): vector=result.vector, namespace=result.namespace, meta=result.meta, - ) - for result in results + ) for result in results ] - def query( - self, - query: str, - count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT, - namespace: Optional[str] = None, - include_vectors: bool = False, - distance_metric: str = "cosine_distance", - **kwargs - ) -> list[BaseVectorStore.QueryResult]: + def query(self, + query: str, + count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT, + namespace: Optional[str] = None, + include_vectors: bool = False, + distance_metric: str = "cosine_distance", + **kwargs) -> list[BaseVectorStore.QueryResult]: """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace. 
""" @@ -245,8 +239,7 @@ class PgVectorVectorStore(BaseVectorStore): score=result[1], meta=result[0].meta, namespace=result[0].namespace, - ) - for result in results + ) for result in results ] def default_vector_model(self) -> any: diff --git a/swarms/memory/pinecone.py b/swarms/memory/pinecone.py index 2374f12a..0269aa38 100644 --- a/swarms/memory/pinecone.py +++ b/swarms/memory/pinecone.py @@ -102,14 +102,12 @@ class PineconeVectorStoreStore(BaseVector): self.index = pinecone.Index(self.index_name) - def upsert_vector( - self, - vector: list[float], - vector_id: Optional[str] = None, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - **kwargs - ) -> str: + def upsert_vector(self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs) -> str: """Upsert vector""" vector_id = vector_id if vector_id else str_to_hash(str(vector)) @@ -120,10 +118,12 @@ class PineconeVectorStoreStore(BaseVector): return vector_id def load_entry( - self, vector_id: str, namespace: Optional[str] = None - ) -> Optional[BaseVector.Entry]: + self, + vector_id: str, + namespace: Optional[str] = None) -> Optional[BaseVector.Entry]: """Load entry""" - result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() + result = self.index.fetch(ids=[vector_id], + namespace=namespace).to_dict() vectors = list(result["vectors"].values()) if len(vectors) > 0: @@ -138,7 +138,8 @@ class PineconeVectorStoreStore(BaseVector): else: return None - def load_entries(self, namespace: Optional[str] = None) -> list[BaseVector.Entry]: + def load_entries(self, + namespace: Optional[str] = None) -> list[BaseVector.Entry]: """Load entries""" # This is a hacky way to query up to 10,000 values from Pinecone. 
Waiting on an official API for fetching # all values from a namespace: @@ -157,20 +158,18 @@ class PineconeVectorStoreStore(BaseVector): vector=r["values"], meta=r["metadata"], namespace=results["namespace"], - ) - for r in results["matches"] + ) for r in results["matches"] ] def query( - self, - query: str, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - # PineconeVectorStoreStorageDriver-specific params: - include_metadata=True, - **kwargs - ) -> list[BaseVector.QueryResult]: + self, + query: str, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + # PineconeVectorStoreStorageDriver-specific params: + include_metadata=True, + **kwargs) -> list[BaseVector.QueryResult]: """Query vectors""" vector = self.embedding_driver.embed_string(query) @@ -190,12 +189,14 @@ class PineconeVectorStoreStore(BaseVector): score=r["score"], meta=r["metadata"], namespace=results["namespace"], - ) - for r in results["matches"] + ) for r in results["matches"] ] def create_index(self, name: str, **kwargs) -> None: """Create index""" - params = {"name": name, "dimension": self.embedding_driver.dimensions} | kwargs + params = { + "name": name, + "dimension": self.embedding_driver.dimensions + } | kwargs pinecone.create_index(**params) diff --git a/swarms/memory/schemas.py b/swarms/memory/schemas.py index bbc71bc2..ce54208d 100644 --- a/swarms/memory/schemas.py +++ b/swarms/memory/schemas.py @@ -20,9 +20,9 @@ class Artifact(BaseModel): description="Id of the artifact", example="b225e278-8b4c-4f99-a696-8facf19f0e56", ) - file_name: str = Field( - ..., description="Filename of the artifact", example="main.py" - ) + file_name: str = Field(..., + description="Filename of the artifact", + example="main.py") relative_path: Optional[str] = Field( None, description="Relative path of the artifact in the agent's workspace", @@ -50,7 +50,8 @@ class StepInput(BaseModel): class StepOutput(BaseModel): __root__: Any = Field( ..., - description="Output that the task step has produced. Any value is allowed.", + description= + "Output that the task step has produced. Any value is allowed.", example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', ) @@ -81,9 +82,9 @@ class Task(TaskRequestBody): class StepRequestBody(BaseModel): - input: Optional[str] = Field( - None, description="Input prompt for the step.", example="Washington" - ) + input: Optional[str] = Field(None, + description="Input prompt for the step.", + example="Washington") additional_input: Optional[StepInput] = None @@ -104,22 +105,19 @@ class Step(StepRequestBody): description="The ID of the task step.", example="6bb1801a-fd80-45e8-899a-4dd723cc602e", ) - name: Optional[str] = Field( - None, description="The name of the task step.", example="Write to file" - ) + name: Optional[str] = Field(None, + description="The name of the task step.", + example="Write to file") status: Status = Field(..., description="The status of the task step.") output: Optional[str] = Field( None, description="Output of the task step.", - example=( - "I am going to use the write_to_file command and write Washington to a file" - " called output.txt best_score: best_score = equation_score idx_to_add = i @@ -57,8 +56,8 @@ def maximal_marginal_relevance( def filter_complex_metadata( documents: List[Document], *, - allowed_types: Tuple[Type, ...] = (str, bool, int, float) -) -> List[Document]: + allowed_types: Tuple[Type, + ...] 
= (str, bool, int, float)) -> List[Document]: """Filter out metadata types that are not supported for a vector store.""" updated_documents = [] for document in documents: diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 1f9ae052..6f6ea8ba 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -9,7 +9,6 @@ from swarms.models.huggingface import HuggingfaceLLM from swarms.models.wizard_storytelling import WizardLLMStoryTeller from swarms.models.mpt import MPT7B - # MultiModal Models from swarms.models.idefics import Idefics from swarms.models.kosmos_two import Kosmos @@ -27,7 +26,6 @@ import sys log_file = open("errors.txt", "w") sys.stderr = log_file - __all__ = [ "Anthropic", "Petals", diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py index 30ec22ce..634fa030 100644 --- a/swarms/models/anthropic.py +++ b/swarms/models/anthropic.py @@ -41,21 +41,24 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: """Validate specified keyword args are mutually exclusive.""" def decorator(func: Callable) -> Callable: + @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: """Validate exactly one arg in each group is not None.""" counts = [ - sum(1 for arg in arg_group if kwargs.get(arg) is not None) + sum(1 + for arg in arg_group + if kwargs.get(arg) is not None) for arg_group in arg_groups ] invalid_groups = [i for i, count in enumerate(counts) if count != 1] if invalid_groups: - invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups] - raise ValueError( - "Exactly one argument in each of the following" - " groups must be defined:" - f" {', '.join(invalid_group_names)}" - ) + invalid_group_names = [ + ", ".join(arg_groups[i]) for i in invalid_groups + ] + raise ValueError("Exactly one argument in each of the following" + " groups must be defined:" + f" {', '.join(invalid_group_names)}") return func(*args, **kwargs) return wrapper @@ -105,9 +108,10 @@ def mock_now(dt_value): # type: ignore datetime.datetime = real_datetime -def guard_import( - module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None -) -> Any: +def guard_import(module_name: str, + *, + pip_name: Optional[str] = None, + package: Optional[str] = None) -> Any: """Dynamically imports a module and raises a helpful exception if the module is not installed.""" try: @@ -115,8 +119,7 @@ def guard_import( except ImportError: raise ImportError( f"Could not import {module_name} python package. " - f"Please install it with `pip install {pip_name or module_name}`." - ) + f"Please install it with `pip install {pip_name or module_name}`.") return module @@ -132,23 +135,19 @@ def check_package_version( if lt_version is not None and imported_version >= parse(lt_version): raise ValueError( f"Expected {package} version to be < {lt_version}. Received " - f"{imported_version}." - ) + f"{imported_version}.") if lte_version is not None and imported_version > parse(lte_version): raise ValueError( f"Expected {package} version to be <= {lte_version}. Received " - f"{imported_version}." - ) + f"{imported_version}.") if gt_version is not None and imported_version <= parse(gt_version): raise ValueError( f"Expected {package} version to be > {gt_version}. Received " - f"{imported_version}." - ) + f"{imported_version}.") if gte_version is not None and imported_version < parse(gte_version): raise ValueError( f"Expected {package} version to be >= {gte_version}. Received " - f"{imported_version}." 
- ) + f"{imported_version}.") def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: @@ -180,19 +179,17 @@ def build_extra_kwargs( if field_name in extra_kwargs: raise ValueError(f"Found {field_name} supplied twice.") if field_name not in all_required_field_names: - warnings.warn( - f"""WARNING! {field_name} is not default parameter. + warnings.warn(f"""WARNING! {field_name} is not default parameter. {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""" - ) + Please confirm that {field_name} is what you intended.""") extra_kwargs[field_name] = values.pop(field_name) - invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys()) + invalid_model_kwargs = all_required_field_names.intersection( + extra_kwargs.keys()) if invalid_model_kwargs: raise ValueError( f"Parameters {invalid_model_kwargs} should be specified explicitly. " - "Instead they were passed in as part of `model_kwargs` parameter." - ) + "Instead they were passed in as part of `model_kwargs` parameter.") return extra_kwargs @@ -241,17 +238,16 @@ class _AnthropicCommon(BaseLanguageModel): def build_extra(cls, values: Dict) -> Dict: extra = values.get("model_kwargs", {}) all_required_field_names = get_pydantic_field_names(cls) - values["model_kwargs"] = build_extra_kwargs( - extra, values, all_required_field_names - ) + values["model_kwargs"] = build_extra_kwargs(extra, values, + all_required_field_names) return values @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["anthropic_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") - ) + get_from_dict_or_env(values, "anthropic_api_key", + "ANTHROPIC_API_KEY")) # Get custom api url from environment. values["anthropic_api_url"] = get_from_dict_or_env( values, @@ -281,8 +277,7 @@ class _AnthropicCommon(BaseLanguageModel): except ImportError: raise ImportError( "Could not import anthropic python package. " - "Please it install it with `pip install anthropic`." - ) + "Please it install it with `pip install anthropic`.") return values @property @@ -305,7 +300,8 @@ class _AnthropicCommon(BaseLanguageModel): """Get the identifying parameters.""" return {**{}, **self._default_params} - def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]: + def _get_anthropic_stop(self, + stop: Optional[List[str]] = None) -> List[str]: if not self.HUMAN_PROMPT or not self.AI_PROMPT: raise NameError("Please ensure the anthropic package is loaded") @@ -372,7 +368,8 @@ class Anthropic(LLM, _AnthropicCommon): return prompt # Already wrapped. # Guard against common errors in specifying wrong number of newlines. 
- corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt) + corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, + prompt) if n_subs == 1: return corrected_prompt @@ -405,9 +402,10 @@ class Anthropic(LLM, _AnthropicCommon): """ if self.streaming: completion = "" - for chunk in self._stream( - prompt=prompt, stop=stop, run_manager=run_manager, **kwargs - ): + for chunk in self._stream(prompt=prompt, + stop=stop, + run_manager=run_manager, + **kwargs): completion += chunk.text return completion @@ -433,9 +431,10 @@ class Anthropic(LLM, _AnthropicCommon): """Call out to Anthropic's completion endpoint asynchronously.""" if self.streaming: completion = "" - async for chunk in self._astream( - prompt=prompt, stop=stop, run_manager=run_manager, **kwargs - ): + async for chunk in self._astream(prompt=prompt, + stop=stop, + run_manager=run_manager, + **kwargs): completion += chunk.text return completion @@ -476,8 +475,10 @@ class Anthropic(LLM, _AnthropicCommon): params = {**self._default_params, **kwargs} for token in self.client.completions.create( - prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params - ): + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + stream=True, + **params): chunk = GenerationChunk(text=token.completion) yield chunk if run_manager: @@ -509,10 +510,10 @@ class Anthropic(LLM, _AnthropicCommon): params = {**self._default_params, **kwargs} async for token in await self.async_client.completions.create( - prompt=self._wrap_prompt(prompt), - stop_sequences=stop, - stream=True, - **params, + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + stream=True, + **params, ): chunk = GenerationChunk(text=token.completion) yield chunk diff --git a/swarms/models/bioclip.py b/swarms/models/bioclip.py index c2b4bfa5..d7052ef3 100644 --- a/swarms/models/bioclip.py +++ b/swarms/models/bioclip.py @@ -97,9 +97,8 @@ class BioClip: self.preprocess_val, ) = open_clip.create_model_and_transforms(model_path) self.tokenizer = open_clip.get_tokenizer(model_path) - self.device = ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - ) + self.device = (torch.device("cuda") + if torch.cuda.is_available() else torch.device("cpu")) self.model.to(self.device) self.model.eval() @@ -110,18 +109,17 @@ class BioClip: template: str = "this is a photo of ", context_length: int = 256, ): - image = torch.stack([self.preprocess_val(Image.open(img_path))]).to(self.device) - texts = self.tokenizer( - [template + l for l in labels], context_length=context_length - ).to(self.device) + image = torch.stack([self.preprocess_val(Image.open(img_path)) + ]).to(self.device) + texts = self.tokenizer([template + l for l in labels], + context_length=context_length).to(self.device) with torch.no_grad(): - image_features, text_features, logit_scale = self.model(image, texts) - logits = ( - (logit_scale * image_features @ text_features.t()) - .detach() - .softmax(dim=-1) - ) + image_features, text_features, logit_scale = self.model( + image, texts) + logits = ((logit_scale * + image_features @ text_features.t()).detach().softmax( + dim=-1)) sorted_indices = torch.argsort(logits, dim=-1, descending=True) logits = logits.cpu().numpy() sorted_indices = sorted_indices.cpu().numpy() @@ -139,11 +137,8 @@ class BioClip: fig, ax = plt.subplots(figsize=(5, 5)) ax.imshow(img) ax.axis("off") - title = ( - metadata["filename"] - + "\n" - + "\n".join([f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()]) - ) + title = 
(metadata["filename"] + "\n" + "\n".join( + [f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()])) ax.set_title(title, fontsize=14) plt.tight_layout() plt.show() diff --git a/swarms/models/biogpt.py b/swarms/models/biogpt.py index 83c31e55..ebec10b9 100644 --- a/swarms/models/biogpt.py +++ b/swarms/models/biogpt.py @@ -102,9 +102,9 @@ class BioGPT: list[dict]: A list of generated texts. """ set_seed(42) - generator = pipeline( - "text-generation", model=self.model, tokenizer=self.tokenizer - ) + generator = pipeline("text-generation", + model=self.model, + tokenizer=self.tokenizer) out = generator( text, max_length=self.max_length, @@ -149,13 +149,11 @@ class BioGPT: inputs = self.tokenizer(sentence, return_tensors="pt") set_seed(42) with torch.no_grad(): - beam_output = self.model.generate( - **inputs, - min_length=self.min_length, - max_length=self.max_length, - num_beams=num_beams, - early_stopping=early_stopping - ) + beam_output = self.model.generate(**inputs, + min_length=self.min_length, + max_length=self.max_length, + num_beams=num_beams, + early_stopping=early_stopping) return self.tokenizer.decode(beam_output[0], skip_special_tokens=True) # Feature 1: Set a new tokenizer and model diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index c24f262d..788bae62 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -124,13 +124,10 @@ class Dalle3: # Handling exceptions and printing the errors details print( colored( - ( - f"Error running Dalle3: {error} try optimizing your api key and" - " or try again" - ), + (f"Error running Dalle3: {error} try optimizing your api key and" + " or try again"), "red", - ) - ) + )) raise error def create_variations(self, img: str): @@ -157,22 +154,19 @@ class Dalle3: """ try: - response = self.client.images.create_variation( - img=open(img, "rb"), n=self.n, size=self.size - ) + response = self.client.images.create_variation(img=open(img, "rb"), + n=self.n, + size=self.size) img = response.data[0].url return img except (Exception, openai.OpenAIError) as error: print( colored( - ( - f"Error running Dalle3: {error} try optimizing your api key and" - " or try again" - ), + (f"Error running Dalle3: {error} try optimizing your api key and" + " or try again"), "red", - ) - ) + )) print(colored(f"Error running Dalle3: {error.http_status}", "red")) print(colored(f"Error running Dalle3: {error.error}", "red")) raise error diff --git a/swarms/models/distilled_whisperx.py b/swarms/models/distilled_whisperx.py index 0a60aaac..8fc5b99a 100644 --- a/swarms/models/distilled_whisperx.py +++ b/swarms/models/distilled_whisperx.py @@ -18,6 +18,7 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1): """ def decorator(func): + @wraps(func) async def wrapper(*args, **kwargs): retries = max_retries @@ -28,7 +29,9 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1): retries -= 1 if retries <= 0: raise - print(f"Retry after exception: {e}, Attempts remaining: {retries}") + print( + f"Retry after exception: {e}, Attempts remaining: {retries}" + ) await asyncio.sleep(delay) return wrapper @@ -62,7 +65,8 @@ class DistilWhisperModel: def __init__(self, model_id="distil-whisper/distil-large-v2"): self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + self.torch_dtype = torch.float16 if torch.cuda.is_available( + ) else torch.float32 self.model_id = model_id self.model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, @@ 
-119,14 +123,14 @@ class DistilWhisperModel: try: with torch.no_grad(): # Load the whole audio file, but process and transcribe it in chunks - audio_input = self.processor.audio_file_to_array(audio_file_path) + audio_input = self.processor.audio_file_to_array( + audio_file_path) sample_rate = audio_input.sampling_rate total_duration = len(audio_input.array) / sample_rate chunks = [ - audio_input.array[i : i + sample_rate * chunk_duration] - for i in range( - 0, len(audio_input.array), sample_rate * chunk_duration - ) + audio_input.array[i:i + sample_rate * chunk_duration] + for i in range(0, len(audio_input.array), sample_rate * + chunk_duration) ] print(colored("Starting real-time transcription...", "green")) @@ -139,22 +143,22 @@ class DistilWhisperModel: return_tensors="pt", padding=True, ) - processed_inputs = processed_inputs.input_values.to(self.device) + processed_inputs = processed_inputs.input_values.to( + self.device) # Generate transcription for the chunk logits = self.model.generate(processed_inputs) transcription = self.processor.batch_decode( - logits, skip_special_tokens=True - )[0] + logits, skip_special_tokens=True)[0] # Print the chunk's transcription print( - colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") - + transcription - ) + colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") + + transcription) # Wait for the chunk's duration to simulate real-time processing time.sleep(chunk_duration) except Exception as e: - print(colored(f"An error occurred during transcription: {e}", "red")) + print(colored(f"An error occurred during transcription: {e}", + "red")) diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py index a2d6bc0a..370569fb 100644 --- a/swarms/models/fastvit.py +++ b/swarms/models/fastvit.py @@ -11,7 +11,8 @@ from pydantic import BaseModel, StrictFloat, StrictInt, validator DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load the classes for image classification -with open(os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")) as f: +with open(os.path.join(os.path.dirname(__file__), + "fast_vit_classes.json")) as f: FASTVIT_IMAGENET_1K_CLASSES = json.load(f) @@ -21,7 +22,8 @@ class ClassificationResult(BaseModel): @validator("class_id", "confidence", pre=True, each_item=True) def check_list_contents(cls, v): - assert isinstance(v, int) or isinstance(v, float), "must be integer or float" + assert isinstance(v, int) or isinstance( + v, float), "must be integer or float" return v @@ -47,16 +49,16 @@ class FastViT: """ def __init__(self): - self.model = timm.create_model( - "hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True - ).to(DEVICE) + self.model = timm.create_model("hf_hub:timm/fastvit_s12.apple_in1k", + pretrained=True).to(DEVICE) data_config = timm.data.resolve_model_data_config(self.model) - self.transforms = timm.data.create_transform(**data_config, is_training=False) + self.transforms = timm.data.create_transform(**data_config, + is_training=False) self.model.eval() - def __call__( - self, img: str, confidence_threshold: float = 0.5 - ) -> ClassificationResult: + def __call__(self, + img: str, + confidence_threshold: float = 0.5) -> ClassificationResult: """classifies the input image and returns the top k classes and their probabilities""" img = Image.open(img).convert("RGB") img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE) @@ -65,9 +67,8 @@ class FastViT: probabilities = torch.nn.functional.softmax(output, dim=1) # Get top k classes and their probabilities - top_probs, top_classes = torch.topk( - 
probabilities, k=FASTVIT_IMAGENET_1K_CLASSES - ) + top_probs, top_classes = torch.topk(probabilities, + k=FASTVIT_IMAGENET_1K_CLASSES) # Filter by confidence threshold mask = top_probs > confidence_threshold diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py index d2d3ebe7..d7148d0e 100644 --- a/swarms/models/fuyu.py +++ b/swarms/models/fuyu.py @@ -46,9 +46,9 @@ class Fuyu: self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path) self.image_processor = FuyuImageProcessor() - self.processor = FuyuProcessor( - image_processor=self.image_processor, tokenizer=self.tokenizer, **kwargs - ) + self.processor = FuyuProcessor(image_processor=self.image_processor, + tokenizer=self.tokenizer, + **kwargs) self.model = FuyuForCausalLM.from_pretrained( pretrained_path, device_map=device_map, @@ -63,15 +63,17 @@ class Fuyu: def __call__(self, text: str, img: str): """Call the model with text and img paths""" image_pil = Image.open(img) - model_inputs = self.processor( - text=text, images=[image_pil], device=self.device_map - ) + model_inputs = self.processor(text=text, + images=[image_pil], + device=self.device_map) for k, v in model_inputs.items(): model_inputs[k] = v.to(self.device_map) - output = self.model.generate(**model_inputs, max_new_tokens=self.max_new_tokens) - text = self.processor.batch_decode(output[:, -7:], skip_special_tokens=True) + output = self.model.generate(**model_inputs, + max_new_tokens=self.max_new_tokens) + text = self.processor.batch_decode(output[:, -7:], + skip_special_tokens=True) return print(str(text)) def get_img_from_web(self, img_url: str): diff --git a/swarms/models/gpt4v.py b/swarms/models/gpt4v.py index 3fa87443..87393fab 100644 --- a/swarms/models/gpt4v.py +++ b/swarms/models/gpt4v.py @@ -130,19 +130,23 @@ class GPT4Vision: } # Image content - image_content = [ - {"type": "imavge_url", "image_url": img} - if img.startswith("http") - else {"type": "image", "data": img} - for img in img - ] - - messages = [ - { - "role": "user", - "content": image_content + [{"type": "text", "text": q} for q in tasks], - } - ] + image_content = [{ + "type": "imavge_url", + "image_url": img + } if img.startswith("http") else { + "type": "image", + "data": img + } for img in img] + + messages = [{ + "role": + "user", + "content": + image_content + [{ + "type": "text", + "text": q + } for q in tasks], + }] payload = { "model": "gpt-4-vision-preview", @@ -160,7 +164,8 @@ class GPT4Vision: timeout=self.timeout_seconds, ) response.raise_for_status() - answer = response.json()["choices"][0]["message"]["content"]["text"] + answer = response.json( + )["choices"][0]["message"]["content"]["text"] return GPT4VisionResponse(answer=answer) except requests.exceptions.HTTPError as error: self.logger.error( @@ -179,8 +184,7 @@ class GPT4Vision: except Exception as error: self.logger.error( f"Unexpected Error: {error} try optimizing your api key and try" - " again" - ) + " again") raise error from None raise TimeoutError("API Request timed out after multiple retries") @@ -212,18 +216,20 @@ class GPT4Vision: try: response = self.client.chat.completions.create( model=self.model, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": f"{task}"}, - { - "type": "image_url", - "image_url": f"{img}", - }, - ], - } - ], + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": f"{task}" + }, + { + "type": "image_url", + "image_url": f"{img}", + }, + ], + }], max_tokens=self.max_tokens, ) @@ -232,13 +238,10 @@ class GPT4Vision: except Exception as 
error: print( colored( - ( - f"Error when calling GPT4Vision, Error: {error} Try optimizing" - " your key, and try again" - ), + (f"Error when calling GPT4Vision, Error: {error} Try optimizing" + " your key, and try again"), "red", - ) - ) + )) async def arun(self, task: str, img: str) -> str: """ @@ -267,18 +270,20 @@ class GPT4Vision: try: response = await self.client.chat.completions.create( model=self.model, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": f"{task}"}, - { - "type": "image_url", - "image_url": f"{img}", - }, - ], - } - ], + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": f"{task}" + }, + { + "type": "image_url", + "image_url": f"{img}", + }, + ], + }], max_tokens=self.max_tokens, ) out = response.choices[0].text @@ -286,10 +291,7 @@ class GPT4Vision: except Exception as error: print( colored( - ( - f"Error when calling GPT4Vision, Error: {error} Try optimizing" - " your key, and try again" - ), + (f"Error when calling GPT4Vision, Error: {error} Try optimizing" + " your key, and try again"), "red", - ) - ) + )) diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py index 9279fea4..a84cc960 100644 --- a/swarms/models/huggingface.py +++ b/swarms/models/huggingface.py @@ -47,9 +47,8 @@ class HuggingfaceLLM: **kwargs, ): self.logger = logging.getLogger(__name__) - self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") - ) + self.device = (device if device else + ("cuda" if torch.cuda.is_available() else "cpu")) self.model_id = model_id self.max_length = max_length self.verbose = verbose @@ -58,9 +57,8 @@ class HuggingfaceLLM: self.model, self.tokenizer = None, None if self.distributed: - assert ( - torch.cuda.device_count() > 1 - ), "You need more than 1 gpu for distributed processing" + assert (torch.cuda.device_count() > + 1), "You need more than 1 gpu for distributed processing" bnb_config = None if quantize: @@ -75,17 +73,17 @@ class HuggingfaceLLM: try: self.tokenizer = AutoTokenizer.from_pretrained( - self.model_id, *args, **kwargs - ) + self.model_id, *args, **kwargs) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config, *args, **kwargs - ) + self.model_id, quantization_config=bnb_config, *args, **kwargs) self.model # .to(self.device) except Exception as e: # self.logger.error(f"Failed to load the model or the tokenizer: {e}") # raise - print(colored(f"Failed to load the model and or the tokenizer: {e}", "red")) + print( + colored(f"Failed to load the model and or the tokenizer: {e}", + "red")) def print_error(self, error: str): """Print error""" @@ -97,20 +95,18 @@ class HuggingfaceLLM: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - bnb_config = ( - BitsAndBytesConfig(**self.quantization_config) - if self.quantization_config - else None - ) + bnb_config = (BitsAndBytesConfig(**self.quantization_config) + if self.quantization_config else None) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config - ).to(self.device) + self.model_id, + quantization_config=bnb_config).to(self.device) if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}") raise def run(self, task: str): @@ -131,7 +127,8 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = 
self.tokenizer.encode(task, return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(task, + return_tensors="pt").to(self.device) # self.log.start() @@ -140,39 +137,36 @@ class HuggingfaceLLM: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=len(inputs) + + 1, + do_sample=True) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode( - [output_tokens], skip_special_tokens=True - ), + self.tokenizer.decode([output_tokens], + skip_special_tokens=True), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate( - inputs, max_length=max_length, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=max_length, + do_sample=True) del inputs return self.tokenizer.decode(outputs[0], skip_special_tokens=True) except Exception as e: print( colored( - ( - f"HuggingfaceLLM could not generate text because of error: {e}," - " try optimizing your arguments" - ), + (f"HuggingfaceLLM could not generate text because of error: {e}," + " try optimizing your arguments"), "red", - ) - ) + )) raise async def run_async(self, task: str, *args, **kwargs) -> str: @@ -216,7 +210,8 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(task, + return_tensors="pt").to(self.device) # self.log.start() @@ -225,26 +220,26 @@ class HuggingfaceLLM: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=len(inputs) + + 1, + do_sample=True) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode( - [output_tokens], skip_special_tokens=True - ), + self.tokenizer.decode([output_tokens], + skip_special_tokens=True), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate( - inputs, max_length=max_length, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=max_length, + do_sample=True) del inputs @@ -305,8 +300,7 @@ class HuggingfaceLLM: """, "red", - ) - ) + )) print(dashboard) diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index 73cb4991..41b8823d 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -65,9 +65,8 @@ class Idefics: torch_dtype=torch.bfloat16, max_length=100, ): - self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") - ) + self.device = (device if device else + ("cuda" if torch.cuda.is_available() else "cpu")) self.model = IdeficsForVisionText2Text.from_pretrained( checkpoint, torch_dtype=torch_dtype, @@ -96,21 +95,17 @@ class Idefics: list A list of generated text strings. 
""" - inputs = ( - self.processor( - prompts, add_end_of_utterance_token=False, return_tensors="pt" - ).to(self.device) - if batched_mode - else self.processor(prompts[0], return_tensors="pt").to(self.device) - ) + inputs = (self.processor( + prompts, add_end_of_utterance_token=False, return_tensors="pt").to( + self.device) if batched_mode else self.processor( + prompts[0], return_tensors="pt").to(self.device)) exit_condition = self.processor.tokenizer( - "", add_special_tokens=False - ).input_ids + "", add_special_tokens=False).input_ids bad_words_ids = self.processor.tokenizer( - ["", "", "", add_special_tokens=False - ).input_ids + "", add_special_tokens=False).input_ids bad_words_ids = self.processor.tokenizer( - ["", "", " 1 - ), "You need more than 1 gpu for distributed processing" + assert (torch.cuda.device_count() > + 1), "You need more than 1 gpu for distributed processing" bnb_config = None if quantize: @@ -83,8 +81,9 @@ class JinaEmbeddings: try: self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config, trust_remote_code=True - ) + self.model_id, + quantization_config=bnb_config, + trust_remote_code=True) self.model # .to(self.device) except Exception as e: @@ -97,11 +96,8 @@ class JinaEmbeddings: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - bnb_config = ( - BitsAndBytesConfig(**self.quantization_config) - if self.quantization_config - else None - ) + bnb_config = (BitsAndBytesConfig(**self.quantization_config) + if self.quantization_config else None) self.model = AutoModelForCausalLM.from_pretrained( self.model_id, @@ -112,7 +108,8 @@ class JinaEmbeddings: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}") raise def run(self, task: str): diff --git a/swarms/models/kosmos2.py b/swarms/models/kosmos2.py index 12d5638a..9a1eafba 100644 --- a/swarms/models/kosmos2.py +++ b/swarms/models/kosmos2.py @@ -14,11 +14,8 @@ class Detections(BaseModel): @root_validator def check_length(cls, values): - assert ( - len(values.get("xyxy")) - == len(values.get("class_id")) - == len(values.get("confidence")) - ), "All fields must have the same length." + assert (len(values.get("xyxy")) == len(values.get("class_id")) == len( + values.get("confidence"))), "All fields must have the same length." 
return values @validator("xyxy", "class_id", "confidence", pre=True, each_item=True) @@ -39,11 +36,9 @@ class Kosmos2(BaseModel): @classmethod def initialize(cls): model = AutoModelForVision2Seq.from_pretrained( - "ydshieh/kosmos-2-patch14-224", trust_remote_code=True - ) + "ydshieh/kosmos-2-patch14-224", trust_remote_code=True) processor = AutoProcessor.from_pretrained( - "ydshieh/kosmos-2-patch14-224", trust_remote_code=True - ) + "ydshieh/kosmos-2-patch14-224", trust_remote_code=True) return cls(model=model, processor=processor) def __call__(self, img: str) -> Detections: @@ -51,11 +46,12 @@ class Kosmos2(BaseModel): prompt = "An image of" inputs = self.processor(text=prompt, images=image, return_tensors="pt") - outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64) + outputs = self.model.generate(**inputs, + use_cache=True, + max_new_tokens=64) - generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[ - 0 - ] + generated_text = self.processor.batch_decode( + outputs, skip_special_tokens=True)[0] # The actual processing of generated_text to entities would go here # For the purpose of this example, assume a mock function 'extract_entities' exists: @@ -66,8 +62,8 @@ class Kosmos2(BaseModel): return detections def extract_entities( - self, text: str - ) -> List[Tuple[str, Tuple[float, float, float, float]]]: + self, + text: str) -> List[Tuple[str, Tuple[float, float, float, float]]]: # Placeholder function for entity extraction # This should be replaced with the actual method of extracting entities return [] @@ -80,19 +76,19 @@ class Kosmos2(BaseModel): if not entities: return Detections.empty() - class_ids = [0] * len(entities) # Replace with actual class ID extraction logic - xyxys = [ - ( - e[1][0] * image.width, - e[1][1] * image.height, - e[1][2] * image.width, - e[1][3] * image.height, - ) - for e in entities - ] + class_ids = [0] * len( + entities) # Replace with actual class ID extraction logic + xyxys = [( + e[1][0] * image.width, + e[1][1] * image.height, + e[1][2] * image.width, + e[1][3] * image.height, + ) for e in entities] confidences = [1.0] * len(entities) # Placeholder confidence - return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences) + return Detections(xyxy=xyxys, + class_id=class_ids, + confidence=confidences) # Usage: diff --git a/swarms/models/kosmos_two.py b/swarms/models/kosmos_two.py index 596886f3..402ad73d 100644 --- a/swarms/models/kosmos_two.py +++ b/swarms/models/kosmos_two.py @@ -46,11 +46,9 @@ class Kosmos: model_name="ydshieh/kosmos-2-patch14-224", ): self.model = AutoModelForVision2Seq.from_pretrained( - model_name, trust_remote_code=True - ) - self.processor = AutoProcessor.from_pretrained( - model_name, trust_remote_code=True - ) + model_name, trust_remote_code=True) + self.processor = AutoProcessor.from_pretrained(model_name, + trust_remote_code=True) def get_image(self, url): """Image""" @@ -73,8 +71,7 @@ class Kosmos: skip_special_tokens=True, )[0] processed_text, entities = self.processor.post_process_generation( - generated_texts - ) + generated_texts) def __call__(self, prompt, image): """Run call""" @@ -93,8 +90,7 @@ class Kosmos: skip_special_tokens=True, )[0] processed_text, entities = self.processor.post_process_generation( - generated_texts - ) + generated_texts) # tasks def multimodal_grounding(self, phrase, image_url): @@ -145,12 +141,10 @@ class Kosmos: elif isinstance(image, torch.Tensor): # pdb.set_trace() image_tensor = image.cpu() - reverse_norm_mean = 
torch.tensor([0.48145466, 0.4578275, 0.40821073])[ - :, None, None - ] - reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[ - :, None, None - ] + reverse_norm_mean = torch.tensor( + [0.48145466, 0.4578275, 0.40821073])[:, None, None] + reverse_norm_std = torch.tensor( + [0.26862954, 0.26130258, 0.27577711])[:, None, None] image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean pil_img = T.ToPILImage()(image_tensor) image_h = pil_img.height @@ -169,9 +163,9 @@ class Kosmos: # thickness of text text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1)) box_line = 3 - (c_width, text_height), _ = cv2.getTextSize( - "F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line - ) + (c_width, text_height), _ = cv2.getTextSize("F", + cv2.FONT_HERSHEY_COMPLEX, + text_size, text_line) base_height = int(text_height * 0.675) text_offset_original = text_height - base_height text_spaces = 3 @@ -187,9 +181,8 @@ class Kosmos: # draw bbox # random color color = tuple(np.random.randint(0, 255, size=3).tolist()) - new_image = cv2.rectangle( - new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line - ) + new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), + (orig_x2, orig_y2), color, box_line) l_o, r_o = ( box_line // 2 + box_line % 2, @@ -200,19 +193,15 @@ class Kosmos: y1 = orig_y1 - l_o if y1 < text_height + text_offset_original + 2 * text_spaces: - y1 = ( - orig_y1 - + r_o - + text_height - + text_offset_original - + 2 * text_spaces - ) + y1 = (orig_y1 + r_o + text_height + text_offset_original + + 2 * text_spaces) x1 = orig_x1 + r_o # add text background - (text_width, text_height), _ = cv2.getTextSize( - f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line - ) + (text_width, + text_height), _ = cv2.getTextSize(f" {entity_name}", + cv2.FONT_HERSHEY_COMPLEX, + text_size, text_line) text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = ( x1, y1 - (text_height + text_offset_original + 2 * text_spaces), @@ -222,23 +211,19 @@ class Kosmos: for prev_bbox in previous_bboxes: while is_overlapping( - (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox - ): - text_bg_y1 += ( - text_height + text_offset_original + 2 * text_spaces - ) - text_bg_y2 += ( - text_height + text_offset_original + 2 * text_spaces - ) + (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), + prev_bbox): + text_bg_y1 += (text_height + text_offset_original + + 2 * text_spaces) + text_bg_y2 += (text_height + text_offset_original + + 2 * text_spaces) y1 += text_height + text_offset_original + 2 * text_spaces if text_bg_y2 >= image_h: text_bg_y1 = max( 0, - image_h - - ( - text_height + text_offset_original + 2 * text_spaces - ), + image_h - (text_height + text_offset_original + + 2 * text_spaces), ) text_bg_y2 = image_h y1 = image_h @@ -255,9 +240,9 @@ class Kosmos: # white bg_color = [255, 255, 255] new_image[i, j] = ( - alpha * new_image[i, j] - + (1 - alpha) * np.array(bg_color) - ).astype(np.uint8) + alpha * new_image[i, j] + + (1 - alpha) * np.array(bg_color)).astype( + np.uint8) cv2.putText( new_image, @@ -270,7 +255,8 @@ class Kosmos: cv2.LINE_AA, ) # previous_locations.append((x1, y1)) - previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2)) + previous_bboxes.append( + (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2)) pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]]) if save_path: diff --git a/swarms/models/llava.py b/swarms/models/llava.py index 6f8019bc..7f49ad4a 100644 --- a/swarms/models/llava.py +++ b/swarms/models/llava.py @@ -48,9 +48,8 @@ class 
MultiModalLlava: revision=revision, ).to(self.device) - self.tokenizer = AutoTokenizer.from_pretrained( - model_name_or_path, use_fast=True - ) + self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, + use_fast=True) self.pipe = pipeline( "text-generation", model=self.model, diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py index 7f48a0d6..f14d9e39 100644 --- a/swarms/models/mistral.py +++ b/swarms/models/mistral.py @@ -49,7 +49,8 @@ class Mistral: # Check if the specified device is available if not torch.cuda.is_available() and device == "cuda": - raise ValueError("CUDA is not available. Please choose a different device.") + raise ValueError( + "CUDA is not available. Please choose a different device.") # Load the model and tokenizer self.model = None @@ -70,7 +71,8 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) + model_inputs = self.tokenizer([task], + return_tensors="pt").to(self.device) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, @@ -87,7 +89,8 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) + model_inputs = self.tokenizer([task], + return_tensors="pt").to(self.device) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, diff --git a/swarms/models/mpt.py b/swarms/models/mpt.py index 035e2b54..9fb6c90b 100644 --- a/swarms/models/mpt.py +++ b/swarms/models/mpt.py @@ -26,7 +26,10 @@ class MPT7B: """ - def __init__(self, model_name: str, tokenizer_name: str, max_tokens: int = 100): + def __init__(self, + model_name: str, + tokenizer_name: str, + max_tokens: int = 100): # Loading model and tokenizer details self.model_name = model_name self.tokenizer_name = tokenizer_name @@ -37,11 +40,9 @@ class MPT7B: self.logger = logging.getLogger(__name__) config = AutoModelForCausalLM.from_pretrained( - model_name, trust_remote_code=True - ).config + model_name, trust_remote_code=True).config self.model = AutoModelForCausalLM.from_pretrained( - model_name, config=config, trust_remote_code=True - ) + model_name, config=config, trust_remote_code=True) # Initializing a text-generation pipeline self.pipe = pipeline( @@ -114,9 +115,10 @@ class MPT7B: """ with torch.autocast("cuda", dtype=torch.bfloat16): - return self.pipe( - prompt, max_new_tokens=self.max_tokens, do_sample=True, use_cache=True - )[0]["generated_text"] + return self.pipe(prompt, + max_new_tokens=self.max_tokens, + do_sample=True, + use_cache=True)[0]["generated_text"] async def generate_async(self, prompt: str) -> str: """Generate Async""" diff --git a/swarms/models/nougat.py b/swarms/models/nougat.py index 34465c73..a362f94f 100644 --- a/swarms/models/nougat.py +++ b/swarms/models/nougat.py @@ -41,8 +41,10 @@ class Nougat: self.min_length = min_length self.max_new_tokens = max_new_tokens - self.processor = NougatProcessor.from_pretrained(self.model_name_or_path) - self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path) + self.processor = NougatProcessor.from_pretrained( + self.model_name_or_path) + self.model = VisionEncoderDecoderModel.from_pretrained( + self.model_name_or_path) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) @@ -63,8 +65,10 @@ class Nougat: max_new_tokens=self.max_new_tokens, ) - sequence = self.processor.batch_decode(outputs, skip_special_tokens=True)[0] - sequence = 
self.processor.post_process_generation(sequence, fix_markdown=False) + sequence = self.processor.batch_decode(outputs, + skip_special_tokens=True)[0] + sequence = self.processor.post_process_generation(sequence, + fix_markdown=False) out = print(repr(sequence)) return out diff --git a/swarms/models/openai_assistant.py b/swarms/models/openai_assistant.py index 6d0c518f..37b41191 100644 --- a/swarms/models/openai_assistant.py +++ b/swarms/models/openai_assistant.py @@ -55,9 +55,9 @@ class OpenAIAssistant: return thread def add_message_to_thread(self, thread_id: str, message: str): - message = self.client.beta.threads.add_message( - thread_id=thread_id, role=self.user, content=message - ) + message = self.client.beta.threads.add_message(thread_id=thread_id, + role=self.user, + content=message) return message def run(self, task: str): @@ -67,8 +67,7 @@ class OpenAIAssistant: instructions=self.instructions, ) - out = self.client.beta.threads.runs.retrieve( - thread_id=run.thread_id, run_id=run.id - ) + out = self.client.beta.threads.runs.retrieve(thread_id=run.thread_id, + run_id=run.id) return out diff --git a/swarms/models/openai_embeddings.py b/swarms/models/openai_embeddings.py index 81dea550..8eeb009d 100644 --- a/swarms/models/openai_embeddings.py +++ b/swarms/models/openai_embeddings.py @@ -28,9 +28,10 @@ from tenacity import ( from swarms.models.embeddings_base import Embeddings -def get_from_dict_or_env( - values: dict, key: str, env_key: str, default: Any = None -) -> Any: +def get_from_dict_or_env(values: dict, + key: str, + env_key: str, + default: Any = None) -> Any: import os return values.get(key) or os.getenv(env_key) or default @@ -43,7 +44,8 @@ def get_pydantic_field_names(cls: Any) -> Set[str]: logger = logging.getLogger(__name__) -def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]: +def _create_retry_decorator( + embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]: import llm min_seconds = 4 @@ -54,13 +56,11 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any reraise=True, stop=stop_after_attempt(embeddings.max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=( - retry_if_exception_type(llm.error.Timeout) - | retry_if_exception_type(llm.error.APIError) - | retry_if_exception_type(llm.error.APIConnectionError) - | retry_if_exception_type(llm.error.RateLimitError) - | retry_if_exception_type(llm.error.ServiceUnavailableError) - ), + retry=(retry_if_exception_type(llm.error.Timeout) | + retry_if_exception_type(llm.error.APIError) | + retry_if_exception_type(llm.error.APIConnectionError) | + retry_if_exception_type(llm.error.RateLimitError) | + retry_if_exception_type(llm.error.ServiceUnavailableError)), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -76,17 +76,16 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any: reraise=True, stop=stop_after_attempt(embeddings.max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=( - retry_if_exception_type(llm.error.Timeout) - | retry_if_exception_type(llm.error.APIError) - | retry_if_exception_type(llm.error.APIConnectionError) - | retry_if_exception_type(llm.error.RateLimitError) - | retry_if_exception_type(llm.error.ServiceUnavailableError) - ), + retry=(retry_if_exception_type(llm.error.Timeout) | + retry_if_exception_type(llm.error.APIError) | + retry_if_exception_type(llm.error.APIConnectionError) | + retry_if_exception_type(llm.error.RateLimitError) | + 
retry_if_exception_type(llm.error.ServiceUnavailableError)), before_sleep=before_sleep_log(logger, logging.WARNING), ) def wrap(func: Callable) -> Callable: + async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: async for _ in async_retrying: return await func(*args, **kwargs) @@ -118,7 +117,8 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: return _embed_with_retry(**kwargs) -async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: +async def async_embed_with_retry(embeddings: OpenAIEmbeddings, + **kwargs: Any) -> Any: """Use tenacity to retry the embedding call.""" @_async_retry_decorator(embeddings) @@ -225,11 +225,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): warnings.warn( f"""WARNING! {field_name} is not default parameter. {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""" - ) + Please confirm that {field_name} is what you intended.""") extra[field_name] = values.pop(field_name) - invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + invalid_model_kwargs = all_required_field_names.intersection( + extra.keys()) if invalid_model_kwargs: raise ValueError( f"Parameters {invalid_model_kwargs} should be specified explicitly. " @@ -242,9 +242,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["openai_api_key"] = get_from_dict_or_env( - values, "openai_api_key", "OPENAI_API_KEY" - ) + values["openai_api_key"] = get_from_dict_or_env(values, + "openai_api_key", + "OPENAI_API_KEY") values["openai_api_base"] = get_from_dict_or_env( values, "openai_api_base", @@ -284,10 +284,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): values["client"] = llm.Embedding except ImportError: - raise ImportError( - "Could not import openai python package. " - "Please install it with `pip install openai`." - ) + raise ImportError("Could not import openai python package. " + "Please install it with `pip install openai`.") return values @property @@ -315,8 +313,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return openai_args def _get_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None - ) -> List[List[float]]: + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None) -> List[List[float]]: embeddings: List[List[float]] = [[] for _ in range(len(texts))] try: import tiktoken @@ -324,8 +325,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): raise ImportError( "Could not import tiktoken python package. " "This is needed in order to for OpenAIEmbeddings. " - "Please install it with `pip install tiktoken`." - ) + "Please install it with `pip install tiktoken`.") tokens = [] indices = [] @@ -333,7 +333,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. 
Using cl100k_base encoding.") model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, text in enumerate(texts): @@ -347,7 +348,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append(token[j : j + self.embedding_ctx_length]) + tokens.append(token[j:j + self.embedding_ctx_length]) indices.append(i) batched_embeddings: List[List[float]] = [] @@ -366,7 +367,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in _iter: response = embed_with_retry( self, - input=tokens[i : i + _chunk_size], + input=tokens[i:i + _chunk_size], **self._invocation_params, ) batched_embeddings.extend(r["embedding"] for r in response["data"]) @@ -384,11 +385,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): self, input="", **self._invocation_params, - )[ - "data" - ][0]["embedding"] + )["data"][0]["embedding"] else: - average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) + average = np.average(_result, + axis=0, + weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings @@ -396,8 +397,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb async def _aget_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None - ) -> List[List[float]]: + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None) -> List[List[float]]: embeddings: List[List[float]] = [[] for _ in range(len(texts))] try: import tiktoken @@ -405,8 +409,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): raise ImportError( "Could not import tiktoken python package. " "This is needed in order to for OpenAIEmbeddings. " - "Please install it with `pip install tiktoken`." - ) + "Please install it with `pip install tiktoken`.") tokens = [] indices = [] @@ -414,7 +417,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. 
Using cl100k_base encoding.") model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, text in enumerate(texts): @@ -428,7 +432,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append(token[j : j + self.embedding_ctx_length]) + tokens.append(token[j:j + self.embedding_ctx_length]) indices.append(i) batched_embeddings: List[List[float]] = [] @@ -436,7 +440,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in range(0, len(tokens), _chunk_size): response = await async_embed_with_retry( self, - input=tokens[i : i + _chunk_size], + input=tokens[i:i + _chunk_size], **self._invocation_params, ) batched_embeddings.extend(r["embedding"] for r in response["data"]) @@ -450,22 +454,22 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i in range(len(texts)): _result = results[i] if len(_result) == 0: - average = ( - await async_embed_with_retry( - self, - input="", - **self._invocation_params, - ) - )["data"][0]["embedding"] + average = (await async_embed_with_retry( + self, + input="", + **self._invocation_params, + ))["data"][0]["embedding"] else: - average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) + average = np.average(_result, + axis=0, + weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings - def embed_documents( - self, texts: List[str], chunk_size: Optional[int] = 0 - ) -> List[List[float]]: + def embed_documents(self, + texts: List[str], + chunk_size: Optional[int] = 0) -> List[List[float]]: """Call out to OpenAI's embedding endpoint for embedding search docs. Args: @@ -481,8 +485,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return self._get_len_safe_embeddings(texts, engine=self.deployment) async def aembed_documents( - self, texts: List[str], chunk_size: Optional[int] = 0 - ) -> List[List[float]]: + self, + texts: List[str], + chunk_size: Optional[int] = 0) -> List[List[float]]: """Call out to OpenAI's embedding endpoint async for embedding search docs. Args: @@ -495,7 +500,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. - return await self._aget_len_safe_embeddings(texts, engine=self.deployment) + return await self._aget_len_safe_embeddings(texts, + engine=self.deployment) def embed_query(self, text: str) -> List[float]: """Call out to OpenAI's embedding endpoint for embedding query text. 
diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index 0c803755..e1a327b5 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -33,9 +33,8 @@ from langchain.utils.utils import build_extra_kwargs logger = logging.getLogger(__name__) -def update_token_usage( - keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any] -) -> None: +def update_token_usage(keys: Set[str], response: Dict[str, Any], + token_usage: Dict[str, Any]) -> None: """Update token usage.""" _keys_to_use = keys.intersection(response["usage"]) for _key in _keys_to_use: @@ -46,44 +45,42 @@ def update_token_usage( def _stream_response_to_generation_chunk( - stream_response: Dict[str, Any], -) -> GenerationChunk: + stream_response: Dict[str, Any],) -> GenerationChunk: """Convert a stream response to a generation chunk.""" return GenerationChunk( text=stream_response["choices"][0]["text"], generation_info=dict( - finish_reason=stream_response["choices"][0].get("finish_reason", None), + finish_reason=stream_response["choices"][0].get( + "finish_reason", None), logprobs=stream_response["choices"][0].get("logprobs", None), ), ) -def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None: +def _update_response(response: Dict[str, Any], + stream_response: Dict[str, Any]) -> None: """Update response from the stream response.""" response["choices"][0]["text"] += stream_response["choices"][0]["text"] response["choices"][0]["finish_reason"] = stream_response["choices"][0].get( - "finish_reason", None - ) - response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"] + "finish_reason", None) + response["choices"][0]["logprobs"] = stream_response["choices"][0][ + "logprobs"] def _streaming_response_template() -> Dict[str, Any]: return { - "choices": [ - { - "text": "", - "finish_reason": None, - "logprobs": None, - } - ] + "choices": [{ + "text": "", + "finish_reason": None, + "logprobs": None, + }] } def _create_retry_decorator( llm: Union[BaseOpenAI, OpenAIChat], - run_manager: Optional[ - Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] - ] = None, + run_manager: Optional[Union[AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun]] = None, ) -> Callable[[Any], Any]: import openai @@ -94,9 +91,9 @@ def _create_retry_decorator( openai.error.RateLimitError, openai.error.ServiceUnavailableError, ] - return create_base_retry_decorator( - error_types=errors, max_retries=llm.max_retries, run_manager=run_manager - ) + return create_base_retry_decorator(error_types=errors, + max_retries=llm.max_retries, + run_manager=run_manager) def completion_with_retry( @@ -206,7 +203,8 @@ class BaseOpenAI(BaseLLM): API but with different models. 
In those cases, in order to avoid erroring when tiktoken is called, you can specify a model name to use here.""" - def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore + def __new__(cls, + **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore """Initialize the OpenAI object.""" data.get("model_name", "") return super().__new__(cls) @@ -221,17 +219,16 @@ class BaseOpenAI(BaseLLM): """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) extra = values.get("model_kwargs", {}) - values["model_kwargs"] = build_extra_kwargs( - extra, values, all_required_field_names - ) + values["model_kwargs"] = build_extra_kwargs(extra, values, + all_required_field_names) return values @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["openai_api_key"] = get_from_dict_or_env( - values, "openai_api_key", "OPENAI_API_KEY" - ) + values["openai_api_key"] = get_from_dict_or_env(values, + "openai_api_key", + "OPENAI_API_KEY") values["openai_api_base"] = get_from_dict_or_env( values, "openai_api_base", @@ -255,10 +252,8 @@ class BaseOpenAI(BaseLLM): values["client"] = openai.Completion except ImportError: - raise ImportError( - "Could not import openai python package. " - "Please install it with `pip install openai`." - ) + raise ImportError("Could not import openai python package. " + "Please install it with `pip install openai`.") if values["streaming"] and values["n"] > 1: raise ValueError("Cannot stream results when n > 1.") if values["streaming"] and values["best_of"] > 1: @@ -295,9 +290,10 @@ class BaseOpenAI(BaseLLM): ) -> Iterator[GenerationChunk]: params = {**self._invocation_params, **kwargs, "stream": True} self.get_sub_prompts(params, [prompt], stop) # this mutates params - for stream_resp in completion_with_retry( - self, prompt=prompt, run_manager=run_manager, **params - ): + for stream_resp in completion_with_retry(self, + prompt=prompt, + run_manager=run_manager, + **params): chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk if run_manager: @@ -306,8 +302,7 @@ class BaseOpenAI(BaseLLM): chunk=chunk, verbose=self.verbose, logprobs=chunk.generation_info["logprobs"] - if chunk.generation_info - else None, + if chunk.generation_info else None, ) async def _astream( @@ -320,8 +315,7 @@ class BaseOpenAI(BaseLLM): params = {**self._invocation_params, **kwargs, "stream": True} self.get_sub_prompts(params, [prompt], stop) # this mutate params async for stream_resp in await acompletion_with_retry( - self, prompt=prompt, run_manager=run_manager, **params - ): + self, prompt=prompt, run_manager=run_manager, **params): chunk = _stream_response_to_generation_chunk(stream_resp) yield chunk if run_manager: @@ -330,8 +324,7 @@ class BaseOpenAI(BaseLLM): chunk=chunk, verbose=self.verbose, logprobs=chunk.generation_info["logprobs"] - if chunk.generation_info - else None, + if chunk.generation_info else None, ) def _generate( @@ -367,30 +360,32 @@ class BaseOpenAI(BaseLLM): for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") + raise ValueError( + "Cannot stream results with multiple prompts.") generation: Optional[GenerationChunk] = None - for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs): + for chunk in self._stream(_prompts[0], stop, run_manager, + **kwargs): if generation is None: generation = 
chunk else: generation += chunk assert generation is not None - choices.append( - { - "text": generation.text, - "finish_reason": generation.generation_info.get("finish_reason") - if generation.generation_info - else None, - "logprobs": generation.generation_info.get("logprobs") - if generation.generation_info - else None, - } - ) + choices.append({ + "text": + generation.text, + "finish_reason": + generation.generation_info.get("finish_reason") + if generation.generation_info else None, + "logprobs": + generation.generation_info.get("logprobs") + if generation.generation_info else None, + }) else: - response = completion_with_retry( - self, prompt=_prompts, run_manager=run_manager, **params - ) + response = completion_with_retry(self, + prompt=_prompts, + run_manager=run_manager, + **params) choices.extend(response["choices"]) update_token_usage(_keys, response, token_usage) return self.create_llm_result(choices, prompts, token_usage) @@ -414,32 +409,32 @@ class BaseOpenAI(BaseLLM): for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") + raise ValueError( + "Cannot stream results with multiple prompts.") generation: Optional[GenerationChunk] = None - async for chunk in self._astream( - _prompts[0], stop, run_manager, **kwargs - ): + async for chunk in self._astream(_prompts[0], stop, run_manager, + **kwargs): if generation is None: generation = chunk else: generation += chunk assert generation is not None - choices.append( - { - "text": generation.text, - "finish_reason": generation.generation_info.get("finish_reason") - if generation.generation_info - else None, - "logprobs": generation.generation_info.get("logprobs") - if generation.generation_info - else None, - } - ) + choices.append({ + "text": + generation.text, + "finish_reason": + generation.generation_info.get("finish_reason") + if generation.generation_info else None, + "logprobs": + generation.generation_info.get("logprobs") + if generation.generation_info else None, + }) else: - response = await acompletion_with_retry( - self, prompt=_prompts, run_manager=run_manager, **params - ) + response = await acompletion_with_retry(self, + prompt=_prompts, + run_manager=run_manager, + **params) choices.extend(response["choices"]) update_token_usage(_keys, response, token_usage) return self.create_llm_result(choices, prompts, token_usage) @@ -453,39 +448,35 @@ class BaseOpenAI(BaseLLM): """Get the sub prompts for llm call.""" if stop is not None: if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") + raise ValueError( + "`stop` found in both the input and default params.") params["stop"] = stop if params["max_tokens"] == -1: if len(prompts) != 1: raise ValueError( - "max_tokens set to -1 not supported for multiple inputs." 
- ) + "max_tokens set to -1 not supported for multiple inputs.") params["max_tokens"] = self.max_tokens_for_prompt(prompts[0]) sub_prompts = [ - prompts[i : i + self.batch_size] + prompts[i:i + self.batch_size] for i in range(0, len(prompts), self.batch_size) ] return sub_prompts - def create_llm_result( - self, choices: Any, prompts: List[str], token_usage: Dict[str, int] - ) -> LLMResult: + def create_llm_result(self, choices: Any, prompts: List[str], + token_usage: Dict[str, int]) -> LLMResult: """Create the LLMResult from the choices and prompts.""" generations = [] for i, _ in enumerate(prompts): - sub_choices = choices[i * self.n : (i + 1) * self.n] - generations.append( - [ - Generation( - text=choice["text"], - generation_info=dict( - finish_reason=choice.get("finish_reason"), - logprobs=choice.get("logprobs"), - ), - ) - for choice in sub_choices - ] - ) + sub_choices = choices[i * self.n:(i + 1) * self.n] + generations.append([ + Generation( + text=choice["text"], + generation_info=dict( + finish_reason=choice.get("finish_reason"), + logprobs=choice.get("logprobs"), + ), + ) for choice in sub_choices + ]) llm_output = {"token_usage": token_usage, "model_name": self.model_name} return LLMResult(generations=generations, llm_output=llm_output) @@ -500,7 +491,10 @@ class BaseOpenAI(BaseLLM): if self.openai_proxy: import openai - openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501 + openai.proxy = { + "http": self.openai_proxy, + "https": self.openai_proxy + } # type: ignore[assignment] # noqa: E501 return {**openai_creds, **self._default_params} @property @@ -524,14 +518,14 @@ class BaseOpenAI(BaseLLM): raise ImportError( "Could not import tiktoken python package. " "This is needed in order to calculate get_num_tokens. " - "Please install it with `pip install tiktoken`." - ) + "Please install it with `pip install tiktoken`.") model_name = self.tiktoken_model_name or self.model_name try: enc = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. Using cl100k_base encoding.") model = "cl100k_base" enc = tiktoken.get_encoding(model) @@ -593,9 +587,7 @@ class BaseOpenAI(BaseLLM): if context_size is None: raise ValueError( f"Unknown model: {modelname}. Please provide a valid OpenAI model name." 
- "Known models are: " - + ", ".join(model_token_mapping.keys()) - ) + "Known models are: " + ", ".join(model_token_mapping.keys())) return context_size @@ -673,14 +665,15 @@ class AzureOpenAI(BaseOpenAI): "OPENAI_API_VERSION", ) values["openai_api_type"] = get_from_dict_or_env( - values, "openai_api_type", "OPENAI_API_TYPE", "azure" - ) + values, "openai_api_type", "OPENAI_API_TYPE", "azure") return values @property def _identifying_params(self) -> Mapping[str, Any]: return { - **{"deployment_name": self.deployment_name}, + **{ + "deployment_name": self.deployment_name + }, **super()._identifying_params, } @@ -745,7 +738,9 @@ class OpenAIChat(BaseLLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = { + field.alias for field in cls.__fields__.values() + } extra = values.get("model_kwargs", {}) for field_name in list(values): @@ -759,9 +754,8 @@ class OpenAIChat(BaseLLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = get_from_dict_or_env( - values, "openai_api_key", "OPENAI_API_KEY" - ) + openai_api_key = get_from_dict_or_env(values, "openai_api_key", + "OPENAI_API_KEY") openai_api_base = get_from_dict_or_env( values, "openai_api_base", @@ -774,9 +768,10 @@ class OpenAIChat(BaseLLM): "OPENAI_PROXY", default="", ) - openai_organization = get_from_dict_or_env( - values, "openai_organization", "OPENAI_ORGANIZATION", default="" - ) + openai_organization = get_from_dict_or_env(values, + "openai_organization", + "OPENAI_ORGANIZATION", + default="") try: import openai @@ -786,20 +781,20 @@ class OpenAIChat(BaseLLM): if openai_organization: openai.organization = openai_organization if openai_proxy: - openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501 + openai.proxy = { + "http": openai_proxy, + "https": openai_proxy + } # type: ignore[assignment] # noqa: E501 except ImportError: - raise ImportError( - "Could not import openai python package. " - "Please install it with `pip install openai`." - ) + raise ImportError("Could not import openai python package. " + "Please install it with `pip install openai`.") try: values["client"] = openai.ChatCompletion except AttributeError: raise ValueError( "`openai` has no `ChatCompletion` attribute, this is likely " "due to an old version of the openai package. Try upgrading it " - "with `pip install --upgrade openai`." 
- ) + "with `pip install --upgrade openai`.") return values @property @@ -807,18 +802,27 @@ class OpenAIChat(BaseLLM): """Get the default parameters for calling OpenAI API.""" return self.model_kwargs - def _get_chat_params( - self, prompts: List[str], stop: Optional[List[str]] = None - ) -> Tuple: + def _get_chat_params(self, + prompts: List[str], + stop: Optional[List[str]] = None) -> Tuple: if len(prompts) > 1: raise ValueError( f"OpenAIChat currently only supports single prompt, got {prompts}" ) - messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] - params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} + messages = self.prefix_messages + [{ + "role": "user", + "content": prompts[0] + }] + params: Dict[str, Any] = { + **{ + "model": self.model_name + }, + **self._default_params + } if stop is not None: if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") + raise ValueError( + "`stop` found in both the input and default params.") params["stop"] = stop if params.get("max_tokens") == -1: # for ChatGPT api, omitting max_tokens is equivalent to having no limit @@ -834,9 +838,10 @@ class OpenAIChat(BaseLLM): ) -> Iterator[GenerationChunk]: messages, params = self._get_chat_params([prompt], stop) params = {**params, **kwargs, "stream": True} - for stream_resp in completion_with_retry( - self, messages=messages, run_manager=run_manager, **params - ): + for stream_resp in completion_with_retry(self, + messages=messages, + run_manager=run_manager, + **params): token = stream_resp["choices"][0]["delta"].get("content", "") chunk = GenerationChunk(text=token) yield chunk @@ -853,8 +858,7 @@ class OpenAIChat(BaseLLM): messages, params = self._get_chat_params([prompt], stop) params = {**params, **kwargs, "stream": True} async for stream_resp in await acompletion_with_retry( - self, messages=messages, run_manager=run_manager, **params - ): + self, messages=messages, run_manager=run_manager, **params): token = stream_resp["choices"][0]["delta"].get("content", "") chunk = GenerationChunk(text=token) yield chunk @@ -880,17 +884,19 @@ class OpenAIChat(BaseLLM): messages, params = self._get_chat_params(prompts, stop) params = {**params, **kwargs} - full_response = completion_with_retry( - self, messages=messages, run_manager=run_manager, **params - ) + full_response = completion_with_retry(self, + messages=messages, + run_manager=run_manager, + **params) llm_output = { "token_usage": full_response["usage"], "model_name": self.model_name, } return LLMResult( - generations=[ - [Generation(text=full_response["choices"][0]["message"]["content"])] - ], + generations=[[ + Generation( + text=full_response["choices"][0]["message"]["content"]) + ]], llm_output=llm_output, ) @@ -903,7 +909,8 @@ class OpenAIChat(BaseLLM): ) -> LLMResult: if self.streaming: generation: Optional[GenerationChunk] = None - async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs): + async for chunk in self._astream(prompts[0], stop, run_manager, + **kwargs): if generation is None: generation = chunk else: @@ -913,17 +920,19 @@ class OpenAIChat(BaseLLM): messages, params = self._get_chat_params(prompts, stop) params = {**params, **kwargs} - full_response = await acompletion_with_retry( - self, messages=messages, run_manager=run_manager, **params - ) + full_response = await acompletion_with_retry(self, + messages=messages, + run_manager=run_manager, + **params) llm_output = { "token_usage": full_response["usage"], "model_name": 
self.model_name, } return LLMResult( - generations=[ - [Generation(text=full_response["choices"][0]["message"]["content"])] - ], + generations=[[ + Generation( + text=full_response["choices"][0]["message"]["content"]) + ]], llm_output=llm_output, ) @@ -948,8 +957,7 @@ class OpenAIChat(BaseLLM): raise ImportError( "Could not import tiktoken python package. " "This is needed in order to calculate get_num_tokens. " - "Please install it with `pip install tiktoken`." - ) + "Please install it with `pip install tiktoken`.") enc = tiktoken.encoding_for_model(self.model_name) return enc.encode( diff --git a/swarms/models/openai_tokenizer.py b/swarms/models/openai_tokenizer.py index 9ff1fa08..26ec9221 100644 --- a/swarms/models/openai_tokenizer.py +++ b/swarms/models/openai_tokenizer.py @@ -71,16 +71,15 @@ class OpenAITokenizer(BaseTokenizer): @property def max_tokens(self) -> int: - tokens = next( - v - for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items() - if self.model.startswith(k) - ) + tokens = next(v for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items() + if self.model.startswith(k)) offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset - def count_tokens(self, text: str | list, model: Optional[str] = None) -> int: + def count_tokens(self, + text: str | list, + model: Optional[str] = None) -> int: """ Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb @@ -96,12 +95,12 @@ class OpenAITokenizer(BaseTokenizer): encoding = tiktoken.get_encoding("cl100k_base") if model in { - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4-0314", - "gpt-4-32k-0314", - "gpt-4-0613", - "gpt-4-32k-0613", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", }: tokens_per_message = 3 tokens_per_name = 1 @@ -113,21 +112,18 @@ class OpenAITokenizer(BaseTokenizer): elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model: logging.info( "gpt-3.5-turbo may update over time. Returning num tokens assuming" - " gpt-3.5-turbo-0613." - ) + " gpt-3.5-turbo-0613.") return self.count_tokens(text, model="gpt-3.5-turbo-0613") elif "gpt-4" in model: logging.info( "gpt-4 may update over time. Returning num tokens assuming" - " gpt-4-0613." - ) + " gpt-4-0613.") return self.count_tokens(text, model="gpt-4-0613") else: raise NotImplementedError( f"""token_count() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for - information on how messages are converted to tokens.""" - ) + information on how messages are converted to tokens.""") num_tokens = 0 @@ -144,5 +140,5 @@ class OpenAITokenizer(BaseTokenizer): return num_tokens else: return len( - self.encoding.encode(text, allowed_special=set(self.stop_sequences)) - ) + self.encoding.encode(text, + allowed_special=set(self.stop_sequences))) diff --git a/swarms/models/palm.py b/swarms/models/palm.py index ec8aafd6..c551c288 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -26,8 +26,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]: except ImportError: raise ImportError( "Could not import google-api-core python package. " - "Please install it with `pip install google-api-core`." 
- ) + "Please install it with `pip install google-api-core`.") multiplier = 2 min_seconds = 1 @@ -37,12 +36,15 @@ def _create_retry_decorator() -> Callable[[Any], Any]: return retry( reraise=True, stop=stop_after_attempt(max_retries), - wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds), - retry=( - retry_if_exception_type(google.api_core.exceptions.ResourceExhausted) - | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable) - | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError) - ), + wait=wait_exponential(multiplier=multiplier, + min=min_seconds, + max=max_seconds), + retry=(retry_if_exception_type( + google.api_core.exceptions.ResourceExhausted) | + retry_if_exception_type( + google.api_core.exceptions.ServiceUnavailable) | + retry_if_exception_type( + google.api_core.exceptions.GoogleAPIError)), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -64,7 +66,8 @@ def _strip_erroneous_leading_spaces(text: str) -> str: The PaLM API will sometimes erroneously return a single leading space in all lines > 1. This function strips that space. """ - has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:]) + has_leading_space = all( + not line or line[0] == " " for line in text.split("\n")[1:]) if has_leading_space: return text.replace("\n ", "\n") else: @@ -97,9 +100,8 @@ class GooglePalm(BaseLLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate api key, python package exists.""" - google_api_key = get_from_dict_or_env( - values, "google_api_key", "GOOGLE_API_KEY" - ) + google_api_key = get_from_dict_or_env(values, "google_api_key", + "GOOGLE_API_KEY") try: import google.generativeai as genai @@ -107,12 +109,12 @@ class GooglePalm(BaseLLM, BaseModel): except ImportError: raise ImportError( "Could not import google-generativeai python package. " - "Please install it with `pip install google-generativeai`." 
- ) + "Please install it with `pip install google-generativeai`.") values["client"] = genai - if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: + if values["temperature"] is not None and not 0 <= values[ + "temperature"] <= 1: raise ValueError("temperature must be in the range [0.0, 1.0]") if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: @@ -121,7 +123,8 @@ class GooglePalm(BaseLLM, BaseModel): if values["top_k"] is not None and values["top_k"] <= 0: raise ValueError("top_k must be positive") - if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0: + if values["max_output_tokens"] is not None and values[ + "max_output_tokens"] <= 0: raise ValueError("max_output_tokens must be greater than zero") return values diff --git a/swarms/models/pegasus.py b/swarms/models/pegasus.py index e388d40c..c2571f72 100644 --- a/swarms/models/pegasus.py +++ b/swarms/models/pegasus.py @@ -33,9 +33,10 @@ class PegasusEmbedding: """ - def __init__( - self, modality: str, multi_process: bool = False, n_processes: int = 4 - ): + def __init__(self, + modality: str, + multi_process: bool = False, + n_processes: int = 4): self.modality = modality self.multi_process = multi_process self.n_processes = n_processes @@ -43,8 +44,7 @@ class PegasusEmbedding: self.pegasus = Pegasus(modality, multi_process, n_processes) except Exception as e: logging.error( - f"Failed to initialize Pegasus with modality: {modality}: {e}" - ) + f"Failed to initialize Pegasus with modality: {modality}: {e}") raise def embed(self, data: Union[str, list[str]]): diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index 7eb923b4..fbb7c066 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -21,6 +21,4 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): return openai.Embedding.create( input=[text], model=model, - )["data"][ - 0 - ]["embedding"] + )["data"][0]["embedding"] diff --git a/swarms/models/speecht5.py b/swarms/models/speecht5.py index e98036ac..d1b476b9 100644 --- a/swarms/models/speecht5.py +++ b/swarms/models/speecht5.py @@ -90,17 +90,17 @@ class SpeechT5: self.processor = SpeechT5Processor.from_pretrained(self.model_name) self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name) self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) - self.embeddings_dataset = load_dataset(self.dataset_name, split="validation") + self.embeddings_dataset = load_dataset(self.dataset_name, + split="validation") def __call__(self, text: str, speaker_id: float = 7306): """Call the model on some text and return the speech.""" speaker_embedding = torch.tensor( - self.embeddings_dataset[speaker_id]["xvector"] - ).unsqueeze(0) + self.embeddings_dataset[speaker_id]["xvector"]).unsqueeze(0) inputs = self.processor(text=text, return_tensors="pt") - speech = self.model.generate_speech( - inputs["input_ids"], speaker_embedding, vocoder=self.vocoder - ) + speech = self.model.generate_speech(inputs["input_ids"], + speaker_embedding, + vocoder=self.vocoder) return speech def save_speech(self, speech, filename="speech.wav"): @@ -121,7 +121,8 @@ class SpeechT5: def set_embeddings_dataset(self, dataset_name): """Set the embeddings dataset to a new dataset.""" self.dataset_name = dataset_name - self.embeddings_dataset = load_dataset(self.dataset_name, split="validation") + self.embeddings_dataset = load_dataset(self.dataset_name, + split="validation") # Feature 1: Get sampling rate def get_sampling_rate(self): 
diff --git a/swarms/models/timm.py b/swarms/models/timm.py index 5d9b965a..5b17c76c 100644 --- a/swarms/models/timm.py +++ b/swarms/models/timm.py @@ -50,9 +50,8 @@ class TimmModel: in_chans=model_info.in_chans, ) - def __call__( - self, model_info: TimmModelInfo, input_tensor: torch.Tensor - ) -> torch.Size: + def __call__(self, model_info: TimmModelInfo, + input_tensor: torch.Tensor) -> torch.Size: """ Create and run a model specified by `model_info` on `input_tensor`. diff --git a/swarms/models/trocr.py b/swarms/models/trocr.py index f4a4156d..1b9e72e7 100644 --- a/swarms/models/trocr.py +++ b/swarms/models/trocr.py @@ -10,9 +10,8 @@ import requests class TrOCR: - def __init__( - self, - ): + + def __init__(self,): pass def __call__(self): diff --git a/swarms/models/vilt.py b/swarms/models/vilt.py index f95d265c..4725a317 100644 --- a/swarms/models/vilt.py +++ b/swarms/models/vilt.py @@ -23,11 +23,9 @@ class Vilt: def __init__(self): self.processor = ViltProcessor.from_pretrained( - "dandelin/vilt-b32-finetuned-vqa" - ) + "dandelin/vilt-b32-finetuned-vqa") self.model = ViltForQuestionAnswering.from_pretrained( - "dandelin/vilt-b32-finetuned-vqa" - ) + "dandelin/vilt-b32-finetuned-vqa") def __call__(self, text: str, image_url: str): """ diff --git a/swarms/models/wizard_storytelling.py b/swarms/models/wizard_storytelling.py index 49ffb70d..929fe10e 100644 --- a/swarms/models/wizard_storytelling.py +++ b/swarms/models/wizard_storytelling.py @@ -33,7 +33,8 @@ class WizardLLMStoryTeller: def __init__( self, - model_id: str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF", + model_id: + str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF", device: str = None, max_length: int = 500, quantize: bool = False, @@ -44,9 +45,8 @@ class WizardLLMStoryTeller: decoding=False, ): self.logger = logging.getLogger(__name__) - self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") - ) + self.device = (device if device else + ("cuda" if torch.cuda.is_available() else "cpu")) self.model_id = model_id self.max_length = max_length self.verbose = verbose @@ -56,9 +56,8 @@ class WizardLLMStoryTeller: # self.log = Logging() if self.distributed: - assert ( - torch.cuda.device_count() > 1 - ), "You need more than 1 gpu for distributed processing" + assert (torch.cuda.device_count() > + 1), "You need more than 1 gpu for distributed processing" bnb_config = None if quantize: @@ -74,8 +73,7 @@ class WizardLLMStoryTeller: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config - ) + self.model_id, quantization_config=bnb_config) self.model # .to(self.device) except Exception as e: @@ -88,20 +86,18 @@ class WizardLLMStoryTeller: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - bnb_config = ( - BitsAndBytesConfig(**self.quantization_config) - if self.quantization_config - else None - ) + bnb_config = (BitsAndBytesConfig(**self.quantization_config) + if self.quantization_config else None) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config - ).to(self.device) + self.model_id, + quantization_config=bnb_config).to(self.device) if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}") raise def run(self, prompt_text: 
str): @@ -120,9 +116,8 @@ class WizardLLMStoryTeller: max_length = self.max_length try: - inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode(prompt_text, + return_tensors="pt").to(self.device) # self.log.start() @@ -131,26 +126,26 @@ class WizardLLMStoryTeller: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=len(inputs) + + 1, + do_sample=True) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode( - [output_tokens], skip_special_tokens=True - ), + self.tokenizer.decode([output_tokens], + skip_special_tokens=True), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate( - inputs, max_length=max_length, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=max_length, + do_sample=True) del inputs return self.tokenizer.decode(outputs[0], skip_special_tokens=True) @@ -174,9 +169,8 @@ class WizardLLMStoryTeller: max_length = self.max_ try: - inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode(prompt_text, + return_tensors="pt").to(self.device) # self.log.start() @@ -185,26 +179,26 @@ class WizardLLMStoryTeller: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=len(inputs) + + 1, + do_sample=True) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode( - [output_tokens], skip_special_tokens=True - ), + self.tokenizer.decode([output_tokens], + skip_special_tokens=True), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate( - inputs, max_length=max_length, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=max_length, + do_sample=True) del inputs diff --git a/swarms/models/yarn_mistral.py b/swarms/models/yarn_mistral.py index ebe107a2..e3120e20 100644 --- a/swarms/models/yarn_mistral.py +++ b/swarms/models/yarn_mistral.py @@ -44,9 +44,8 @@ class YarnMistral128: decoding=False, ): self.logger = logging.getLogger(__name__) - self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") - ) + self.device = (device if device else + ("cuda" if torch.cuda.is_available() else "cpu")) self.model_id = model_id self.max_length = max_length self.verbose = verbose @@ -56,9 +55,8 @@ class YarnMistral128: # self.log = Logging() if self.distributed: - assert ( - torch.cuda.device_count() > 1 - ), "You need more than 1 gpu for distributed processing" + assert (torch.cuda.device_count() > + 1), "You need more than 1 gpu for distributed processing" bnb_config = None if quantize: @@ -93,20 +91,18 @@ class YarnMistral128: try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - bnb_config = ( - BitsAndBytesConfig(**self.quantization_config) - if self.quantization_config - else None - ) + bnb_config = (BitsAndBytesConfig(**self.quantization_config) + if self.quantization_config else None) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config - ).to(self.device) + self.model_id, + quantization_config=bnb_config).to(self.device) if 
self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}") raise def run(self, prompt_text: str): @@ -125,9 +121,8 @@ class YarnMistral128: max_length = self.max_length try: - inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode(prompt_text, + return_tensors="pt").to(self.device) # self.log.start() @@ -136,26 +131,26 @@ class YarnMistral128: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=len(inputs) + + 1, + do_sample=True) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode( - [output_tokens], skip_special_tokens=True - ), + self.tokenizer.decode([output_tokens], + skip_special_tokens=True), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate( - inputs, max_length=max_length, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=max_length, + do_sample=True) del inputs return self.tokenizer.decode(outputs[0], skip_special_tokens=True) @@ -202,9 +197,8 @@ class YarnMistral128: max_length = self.max_ try: - inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode(prompt_text, + return_tensors="pt").to(self.device) # self.log.start() @@ -213,26 +207,26 @@ class YarnMistral128: for _ in range(max_length): output_sequence = [] - outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=len(inputs) + + 1, + do_sample=True) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) # print token in real-time print( - self.tokenizer.decode( - [output_tokens], skip_special_tokens=True - ), + self.tokenizer.decode([output_tokens], + skip_special_tokens=True), end="", flush=True, ) inputs = outputs else: with torch.no_grad(): - outputs = self.model.generate( - inputs, max_length=max_length, do_sample=True - ) + outputs = self.model.generate(inputs, + max_length=max_length, + do_sample=True) del inputs diff --git a/swarms/models/zephyr.py b/swarms/models/zephyr.py index f75945ea..0ed23f19 100644 --- a/swarms/models/zephyr.py +++ b/swarms/models/zephyr.py @@ -28,7 +28,8 @@ class Zephyr: model_name: str = "HuggingFaceH4/zephyr-7b-alpha", tokenize: bool = False, add_generation_prompt: bool = True, - system_prompt: str = "You are a friendly chatbot who always responds in the style of a pirate", + system_prompt: + str = "You are a friendly chatbot who always responds in the style of a pirate", max_new_tokens: int = 300, temperature: float = 0.5, top_k: float = 50, @@ -70,7 +71,7 @@ class Zephyr: ) outputs = self.pipe(prompt) # max_new_token=self.max_new_tokens) print(outputs[0]["generated_text"]) - + def chat(self, message: str): """ Adds a user message to the conversation and generates a chatbot response. 
diff --git a/swarms/prompts/agent_output_parser.py b/swarms/prompts/agent_output_parser.py index 27f8ac24..e00db22d 100644 --- a/swarms/prompts/agent_output_parser.py +++ b/swarms/prompts/agent_output_parser.py @@ -24,9 +24,8 @@ class AgentOutputParser(BaseAgentOutputParser): @staticmethod def _preprocess_json_input(input_str: str) -> str: - corrected_str = re.sub( - r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str - ) + corrected_str = re.sub(r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', + r"\\\\", input_str) return corrected_str def _parse_json(self, text: str) -> dict: diff --git a/swarms/prompts/agent_prompt.py b/swarms/prompts/agent_prompt.py index c4897193..aa84ebf8 100644 --- a/swarms/prompts/agent_prompt.py +++ b/swarms/prompts/agent_prompt.py @@ -13,13 +13,23 @@ class PromptGenerator: self.performance_evaluation: List[str] = [] self.response_format = { "thoughts": { - "text": "thought", - "reasoning": "reasoning", - "plan": "- short bulleted\n- list that conveys\n- long-term plan", - "criticism": "constructive self-criticism", - "speak": "thoughts summary to say to user", + "text": + "thought", + "reasoning": + "reasoning", + "plan": + "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": + "constructive self-criticism", + "speak": + "thoughts summary to say to user", + }, + "command": { + "name": "command name", + "args": { + "arg name": "value" + } }, - "command": {"name": "command name", "args": {"arg name": "value"}}, } def add_constraint(self, constraint: str) -> None: @@ -72,7 +82,6 @@ class PromptGenerator: f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n" "You should only respond in JSON format as described below " f"\nResponse Format: \n{formatted_response_format} " - "\nEnsure the response can be parsed by Python json.loads" - ) + "\nEnsure the response can be parsed by Python json.loads") return prompt_string diff --git a/swarms/prompts/agent_prompts.py b/swarms/prompts/agent_prompts.py index 8d145fc0..3de5bcb2 100644 --- a/swarms/prompts/agent_prompts.py +++ b/swarms/prompts/agent_prompts.py @@ -7,25 +7,21 @@ def generate_agent_role_prompt(agent): "Finance Agent": ( "You are a seasoned finance analyst AI assistant. Your primary goal is to" " compose comprehensive, astute, impartial, and methodically arranged" - " financial reports based on provided data and trends." - ), + " financial reports based on provided data and trends."), "Travel Agent": ( "You are a world-travelled AI tour guide assistant. Your main purpose is to" " draft engaging, insightful, unbiased, and well-structured travel reports" " on given locations, including history, attractions, and cultural" - " insights." - ), + " insights."), "Academic Research Agent": ( "You are an AI academic research assistant. Your primary responsibility is" " to create thorough, academically rigorous, unbiased, and systematically" " organized reports on a given research topic, following the standards of" - " scholarly work." - ), + " scholarly work."), "Default Agent": ( "You are an AI critical thinker research assistant. Your sole purpose is to" " write well written, critically acclaimed, objective and structured" - " reports on given text." - ), + " reports on given text."), } return prompts.get(agent, "No such agent") @@ -44,8 +40,7 @@ def generate_report_prompt(question, research_summary): " focus on the answer to the question, should be well structured, informative," " in depth, with facts and numbers if available, a minimum of 1,200 words and" " with markdown syntax and apa format. 
Write all source urls at the end of the" - " report in apa format" - ) + " report in apa format") def generate_search_queries_prompt(question): @@ -57,8 +52,7 @@ def generate_search_queries_prompt(question): return ( "Write 4 google search queries to search online that form an objective opinion" f' from the following: "{question}"You must respond with a list of strings in' - ' the following format: ["query 1", "query 2", "query 3", "query 4"]' - ) + ' the following format: ["query 1", "query 2", "query 3", "query 4"]') def generate_resource_report_prompt(question, research_summary): @@ -80,8 +74,7 @@ def generate_resource_report_prompt(question, research_summary): " significance of each source. Ensure that the report is well-structured," " informative, in-depth, and follows Markdown syntax. Include relevant facts," " figures, and numbers whenever available. The report should have a minimum" - " length of 1,200 words." - ) + " length of 1,200 words.") def generate_outline_report_prompt(question, research_summary): @@ -98,8 +91,7 @@ def generate_outline_report_prompt(question, research_summary): " research report, including the main sections, subsections, and key points to" " be covered. The research report should be detailed, informative, in-depth," " and a minimum of 1,200 words. Use appropriate Markdown syntax to format the" - " outline and ensure readability." - ) + " outline and ensure readability.") def generate_concepts_prompt(question, research_summary): @@ -114,8 +106,7 @@ def generate_concepts_prompt(question, research_summary): " main concepts to learn for a research report on the following question or" f' topic: "{question}". The outline should provide a well-structured' " frameworkYou must respond with a list of strings in the following format:" - ' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]' - ) + ' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]') def generate_lesson_prompt(concept): @@ -131,8 +122,7 @@ def generate_lesson_prompt(concept): f"generate a comprehensive lesson about {concept} in Markdown syntax. This" f" should include the definitionof {concept}, its historical background and" " development, its applications or uses in differentfields, and notable events" - f" or facts related to {concept}." - ) + f" or facts related to {concept}.") return prompt diff --git a/swarms/prompts/base.py b/swarms/prompts/base.py index 54a0bc3f..8bb77236 100644 --- a/swarms/prompts/base.py +++ b/swarms/prompts/base.py @@ -11,9 +11,9 @@ if TYPE_CHECKING: from langchain.prompts.chat import ChatPromptTemplate -def get_buffer_string( - messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" -) -> str: +def get_buffer_string(messages: Sequence[BaseMessage], + human_prefix: str = "Human", + ai_prefix: str = "AI") -> str: """Convert sequence of Messages to strings and concatenate them into one string. Args: @@ -88,9 +88,9 @@ class BaseMessage(Serializable): class BaseMessageChunk(BaseMessage): - def _merge_kwargs_dict( - self, left: Dict[str, Any], right: Dict[str, Any] - ) -> Dict[str, Any]: + + def _merge_kwargs_dict(self, left: Dict[str, Any], + right: Dict[str, Any]) -> Dict[str, Any]: """Merge additional_kwargs from another BaseMessageChunk into this one.""" merged = left.copy() for k, v in right.items(): @@ -99,8 +99,7 @@ class BaseMessageChunk(BaseMessage): elif not isinstance(merged[k], type(v)): raise ValueError( f'additional_kwargs["{k}"] already exists in this message,' - " but with a different type." 
- ) + " but with a different type.") elif isinstance(merged[k], str): merged[k] += v elif isinstance(merged[k], dict): @@ -119,15 +118,12 @@ class BaseMessageChunk(BaseMessage): return self.__class__( content=self.content + other.content, additional_kwargs=self._merge_kwargs_dict( - self.additional_kwargs, other.additional_kwargs - ), + self.additional_kwargs, other.additional_kwargs), ) else: - raise TypeError( - 'unsupported operand type(s) for +: "' - f"{self.__class__.__name__}" - f'" and "{other.__class__.__name__}"' - ) + raise TypeError('unsupported operand type(s) for +: "' + f"{self.__class__.__name__}" + f'" and "{other.__class__.__name__}"') class HumanMessage(BaseMessage): diff --git a/swarms/prompts/chat_prompt.py b/swarms/prompts/chat_prompt.py index b0330e24..5f48488f 100644 --- a/swarms/prompts/chat_prompt.py +++ b/swarms/prompts/chat_prompt.py @@ -66,9 +66,10 @@ class SystemMessage(Message): of input messages. """ - def __init__( - self, content: str, role: str = "System", additional_kwargs: Dict = None - ): + def __init__(self, + content: str, + role: str = "System", + additional_kwargs: Dict = None): super().__init__(content, role, additional_kwargs) def get_type(self) -> str: @@ -106,9 +107,9 @@ class ChatMessage(Message): return "chat" -def get_buffer_string( - messages: Sequence[Message], human_prefix: str = "Human", ai_prefix: str = "AI" -) -> str: +def get_buffer_string(messages: Sequence[Message], + human_prefix: str = "Human", + ai_prefix: str = "AI") -> str: string_messages = [] for m in messages: message = f"{m.role}: {m.content}" diff --git a/swarms/prompts/debate.py b/swarms/prompts/debate.py index a11c7af4..5a6be762 100644 --- a/swarms/prompts/debate.py +++ b/swarms/prompts/debate.py @@ -38,7 +38,6 @@ def debate_monitor(game_description, word_limit, character_names): return prompt -def generate_character_header( - game_description, topic, character_name, character_description -): +def generate_character_header(game_description, topic, character_name, + character_description): pass diff --git a/swarms/prompts/multi_modal_prompts.py b/swarms/prompts/multi_modal_prompts.py index f558c3c4..dc2bccd5 100644 --- a/swarms/prompts/multi_modal_prompts.py +++ b/swarms/prompts/multi_modal_prompts.py @@ -1,7 +1,6 @@ ERROR_PROMPT = ( "An error has occurred for the following text: \n{promptedQuery} Please explain" - " this error.\n {e}" -) + " this error.\n {e}") IMAGE_PROMPT = """ provide a figure named {filename}. The description is: {description}. @@ -12,7 +11,6 @@ USER INPUT ============ """ - AUDIO_PROMPT = """ provide a audio named {filename}. The description is: {description}. @@ -41,7 +39,6 @@ USER INPUT ============ """ - EVAL_PREFIX = """{bot_name} can execute any user's request. {bot_name} has permission to handle one instance and can handle the environment in it at will. diff --git a/swarms/prompts/python.py b/swarms/prompts/python.py index 9d1f4a1e..cd34e9bd 100644 --- a/swarms/prompts/python.py +++ b/swarms/prompts/python.py @@ -3,30 +3,25 @@ PY_REFLEXION_COMPLETION_INSTRUCTION = ( "You are a Python writing assistant. You will be given your past function" " implementation, a series of unit tests, and a hint to change the implementation" " appropriately. Write your full implementation (restate the function" - " signature).\n\n-----" -) + " signature).\n\n-----") PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = ( "You are a Python writing assistant. You will be given a function implementation" " and a series of unit tests. 
Your goal is to write a few sentences to explain why" " your implementation is wrong as indicated by the tests. You will need this as a" " hint when you try again later. Only provide the few sentence description in your" - " answer, not the implementation.\n\n-----" -) + " answer, not the implementation.\n\n-----") USE_PYTHON_CODEBLOCK_INSTRUCTION = ( "Use a Python code block to write your response. For" - " example:\n```python\nprint('Hello world!')\n```" -) + " example:\n```python\nprint('Hello world!')\n```") PY_SIMPLE_CHAT_INSTRUCTION = ( "You are an AI that only responds with python code, NOT ENGLISH. You will be given" " a function signature and its docstring by the user. Write your full" - " implementation (restate the function signature)." -) + " implementation (restate the function signature).") PY_SIMPLE_CHAT_INSTRUCTION_V2 = ( "You are an AI that only responds with only python code. You will be given a" " function signature and its docstring by the user. Write your full implementation" - " (restate the function signature)." -) + " (restate the function signature).") PY_REFLEXION_CHAT_INSTRUCTION = ( "You are an AI Python assistant. You will be given your past function" " implementation, a series of unit tests, and a hint to change the implementation" @@ -36,8 +31,7 @@ PY_REFLEXION_CHAT_INSTRUCTION_V2 = ( "You are an AI Python assistant. You will be given your previous implementation of" " a function, a series of unit tests results, and your self-reflection on your" " previous implementation. Write your full implementation (restate the function" - " signature)." -) + " signature).") PY_REFLEXION_FEW_SHOT_ADD = '''Example 1: [previous impl]: ```python @@ -175,16 +169,14 @@ PY_SELF_REFLECTION_CHAT_INSTRUCTION = ( " implementation and a series of unit tests. Your goal is to write a few sentences" " to explain why your implementation is wrong as indicated by the tests. You will" " need this as a hint when you try again later. Only provide the few sentence" - " description in your answer, not the implementation." -) + " description in your answer, not the implementation.") PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = ( "You are a Python programming assistant. You will be given a function" " implementation and a series of unit test results. Your goal is to write a few" " sentences to explain why your implementation is wrong as indicated by the tests." " You will need this as guidance when you try again later. Only provide the few" " sentence description in your answer, not the implementation. You will be given a" - " few examples by the user." -) + " few examples by the user.") PY_SELF_REFLECTION_FEW_SHOT = """Example 1: [function impl]: ```python diff --git a/swarms/prompts/sales.py b/swarms/prompts/sales.py index 6c945332..6660e084 100644 --- a/swarms/prompts/sales.py +++ b/swarms/prompts/sales.py @@ -3,39 +3,31 @@ conversation_stages = { "Introduction: Start the conversation by introducing yourself and your company." " Be polite and respectful while keeping the tone of the conversation" " professional. Your greeting should be welcoming. Always clarify in your" - " greeting the reason why you are contacting the prospect." - ), + " greeting the reason why you are contacting the prospect."), "2": ( "Qualification: Qualify the prospect by confirming if they are the right person" " to talk to regarding your product/service. Ensure that they have the" - " authority to make purchasing decisions." 
- ), + " authority to make purchasing decisions."), "3": ( "Value proposition: Briefly explain how your product/service can benefit the" " prospect. Focus on the unique selling points and value proposition of your" - " product/service that sets it apart from competitors." - ), + " product/service that sets it apart from competitors."), "4": ( "Needs analysis: Ask open-ended questions to uncover the prospect's needs and" - " pain points. Listen carefully to their responses and take notes." - ), - "5": ( - "Solution presentation: Based on the prospect's needs, present your" - " product/service as the solution that can address their pain points." - ), - "6": ( - "Objection handling: Address any objections that the prospect may have" - " regarding your product/service. Be prepared to provide evidence or" - " testimonials to support your claims." - ), + " pain points. Listen carefully to their responses and take notes."), + "5": ("Solution presentation: Based on the prospect's needs, present your" + " product/service as the solution that can address their pain points." + ), + "6": + ("Objection handling: Address any objections that the prospect may have" + " regarding your product/service. Be prepared to provide evidence or" + " testimonials to support your claims."), "7": ( "Close: Ask for the sale by proposing a next step. This could be a demo, a" " trial or a meeting with decision-makers. Ensure to summarize what has been" - " discussed and reiterate the benefits." - ), + " discussed and reiterate the benefits."), } - SALES_AGENT_TOOLS_PROMPT = """ Never forget your name is {salesperson_name}. You work as a {salesperson_role}. You work at company named {company_name}. {company_name}'s business is the following: {company_business}. diff --git a/swarms/prompts/sales_prompts.py b/swarms/prompts/sales_prompts.py index ec4ef168..ce5303b3 100644 --- a/swarms/prompts/sales_prompts.py +++ b/swarms/prompts/sales_prompts.py @@ -20,7 +20,6 @@ The answer needs to be one number only, no words. If there is no conversation history, output 1. Do not answer anything else nor add anything to you answer.""" - SALES = """Never forget your name is {salesperson_name}. You work as a {salesperson_role}. You work at company named {company_name}. {company_name}'s business is the following: {company_business} Company values are the following. {company_values} @@ -50,34 +49,27 @@ conversation_stages = { "Introduction: Start the conversation by introducing yourself and your company." " Be polite and respectful while keeping the tone of the conversation" " professional. Your greeting should be welcoming. Always clarify in your" - " greeting the reason why you are contacting the prospect." - ), + " greeting the reason why you are contacting the prospect."), "2": ( "Qualification: Qualify the prospect by confirming if they are the right person" " to talk to regarding your product/service. Ensure that they have the" - " authority to make purchasing decisions." - ), + " authority to make purchasing decisions."), "3": ( "Value proposition: Briefly explain how your product/service can benefit the" " prospect. Focus on the unique selling points and value proposition of your" - " product/service that sets it apart from competitors." - ), + " product/service that sets it apart from competitors."), "4": ( "Needs analysis: Ask open-ended questions to uncover the prospect's needs and" - " pain points. Listen carefully to their responses and take notes." 
- ), - "5": ( - "Solution presentation: Based on the prospect's needs, present your" - " product/service as the solution that can address their pain points." - ), - "6": ( - "Objection handling: Address any objections that the prospect may have" - " regarding your product/service. Be prepared to provide evidence or" - " testimonials to support your claims." - ), + " pain points. Listen carefully to their responses and take notes."), + "5": ("Solution presentation: Based on the prospect's needs, present your" + " product/service as the solution that can address their pain points." + ), + "6": + ("Objection handling: Address any objections that the prospect may have" + " regarding your product/service. Be prepared to provide evidence or" + " testimonials to support your claims."), "7": ( "Close: Ask for the sale by proposing a next step. This could be a demo, a" " trial or a meeting with decision-makers. Ensure to summarize what has been" - " discussed and reiterate the benefits." - ), + " discussed and reiterate the benefits."), } diff --git a/swarms/prompts/summaries_prompts.py b/swarms/prompts/summaries_prompts.py index 01c4c502..646d1ba0 100644 --- a/swarms/prompts/summaries_prompts.py +++ b/swarms/prompts/summaries_prompts.py @@ -10,7 +10,6 @@ summary. Pick a suitable emoji for every bullet point. Your response should be i a YouTube video, use the following text: {{CONTENT}}. """ - SUMMARIZE_PROMPT_2 = """ Provide a very short summary, no more than three sentences, for the following article: @@ -25,7 +24,6 @@ Summary: """ - SUMMARIZE_PROMPT_3 = """ Provide a TL;DR for the following article: @@ -39,7 +37,6 @@ Instead of computing on the individual qubits themselves, we will then compute o TL;DR: """ - SUMMARIZE_PROMPT_4 = """ Provide a very short summary in four bullet points for the following article: @@ -54,7 +51,6 @@ Bulletpoints: """ - SUMMARIZE_PROMPT_5 = """ Please generate a summary of the following conversation and at the end summarize the to-do's for the support Agent: diff --git a/swarms/schemas/typings.py b/swarms/schemas/typings.py index d281a870..f59b16f7 100644 --- a/swarms/schemas/typings.py +++ b/swarms/schemas/typings.py @@ -7,7 +7,6 @@ import platform from enum import Enum from typing import Union - python_version = list(platform.python_version_tuple()) SUPPORT_ADD_NOTES = int(python_version[0]) >= 3 and int(python_version[1]) >= 11 @@ -19,13 +18,11 @@ class ChatbotError(Exception): def __init__(self, *args: object) -> None: if SUPPORT_ADD_NOTES: + super().add_note(( + "Please check that the input is correct, or you can resolve this" + " issue by filing an issue"),) super().add_note( - ( - "Please check that the input is correct, or you can resolve this" - " issue by filing an issue" - ), - ) - super().add_note("Project URL: https://github.com/acheong08/ChatGPT") + "Project URL: https://github.com/acheong08/ChatGPT") super().__init__(*args) diff --git a/swarms/structs/document.py b/swarms/structs/document.py index b87d3d91..505df6ae 100644 --- a/swarms/structs/document.py +++ b/swarms/structs/document.py @@ -63,9 +63,8 @@ class BaseDocumentTransformer(ABC): """ # noqa: E501 @abstractmethod - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], + **kwargs: Any) -> Sequence[Document]: """Transform a list of documents. Args: @@ -75,9 +74,8 @@ class BaseDocumentTransformer(ABC): A list of transformed Documents. 
""" - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], + **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a list of documents. Args: @@ -87,5 +85,4 @@ class BaseDocumentTransformer(ABC): A list of transformed Documents. """ return await asyncio.get_running_loop().run_in_executor( - None, partial(self.transform_documents, **kwargs), documents - ) + None, partial(self.transform_documents, **kwargs), documents) diff --git a/swarms/structs/flow.py b/swarms/structs/flow.py index 7be03036..a7a19258 100644 --- a/swarms/structs/flow.py +++ b/swarms/structs/flow.py @@ -19,14 +19,12 @@ from termcolor import colored import inspect import random - # Prompts DYNAMIC_STOP_PROMPT = """ When you have finished the task from the Human, output a special token: This will enable you to leave the autonomous loop. """ - # Constants FLOW_SYSTEM_PROMPT = f""" You are an autonomous agent granted autonomy from a Flow structure. @@ -40,7 +38,6 @@ to aid in these complex tasks. Your responses should be coherent, contextually r """ - # Utility functions @@ -184,8 +181,7 @@ class Flow: value = self.llm.__dict__.get(name, "Unknown") params_str_list.append( - f" {name.capitalize().replace('_', ' ')}: {value}" - ) + f" {name.capitalize().replace('_', ' ')}: {value}") return "\n".join(params_str_list) @@ -193,7 +189,7 @@ class Flow: """ Take the history and truncate it to fit into the model context length """ - truncated_history = self.memory[-1][-self.context_length :] + truncated_history = self.memory[-1][-self.context_length:] self.memory[-1] = truncated_history def add_task_to_memory(self, task: str): @@ -243,8 +239,7 @@ class Flow: ---------------------------------------- """, "green", - ) - ) + )) # print(dashboard) @@ -254,18 +249,17 @@ class Flow: print(colored("Initializing Autonomous Agent...", "yellow")) # print(colored("Loading modules...", "yellow")) # print(colored("Modules loaded successfully.", "green")) - print(colored("Autonomous Agent Activated.", "cyan", attrs=["bold"])) - print(colored("All systems operational. Executing task...", "green")) + print(colored("Autonomous Agent Activated.", "cyan", + attrs=["bold"])) + print(colored("All systems operational. Executing task...", + "green")) except Exception as error: print( colored( - ( - "Error activating autonomous agent. Try optimizing your" - " parameters..." - ), + ("Error activating autonomous agent. 
Try optimizing your" + " parameters..."), "red", - ) - ) + )) print(error) def run(self, task: str, **kwargs): @@ -307,7 +301,8 @@ class Flow: for i in range(self.max_loops): print(colored(f"\nLoop {i+1} of {self.max_loops}", "blue")) print("\n") - if self._check_stopping_condition(response) or parse_done_token(response): + if self._check_stopping_condition(response) or parse_done_token( + response): break # Adjust temperature, comment if no work @@ -351,7 +346,6 @@ class Flow: async def arun(self, task: str, **kwargs): """Async run""" pass - """ Run the autonomous agent loop @@ -387,7 +381,8 @@ class Flow: for i in range(self.max_loops): print(colored(f"\nLoop {i+1} of {self.max_loops}", "blue")) print("\n") - if self._check_stopping_condition(response) or parse_done_token(response): + if self._check_stopping_condition(response) or parse_done_token( + response): break # Adjust temperature, comment if no work @@ -565,7 +560,9 @@ class Flow: import boto3 s3 = boto3.client("s3") - s3.put_object(Bucket=bucket_name, Key=object_name, Body=json.dumps(self.memory)) + s3.put_object(Bucket=bucket_name, + Key=object_name, + Body=json.dumps(self.memory)) print(f"Backed up memory to S3: {bucket_name}/{object_name}") def analyze_feedback(self): @@ -684,8 +681,8 @@ class Flow: if hasattr(self.llm, name): value = getattr(self.llm, name) if isinstance( - value, (str, int, float, bool, list, dict, tuple, type(None)) - ): + value, + (str, int, float, bool, list, dict, tuple, type(None))): llm_params[name] = value else: llm_params[name] = str( @@ -745,7 +742,10 @@ class Flow: print(f"Flow state loaded from {file_path}") - def retry_on_failure(self, function, retries: int = 3, retry_delay: int = 1): + def retry_on_failure(self, + function, + retries: int = 3, + retry_delay: int = 1): """Retry wrapper for LLM calls.""" attempt = 0 while attempt < retries: diff --git a/swarms/structs/nonlinear_workflow.py b/swarms/structs/nonlinear_workflow.py index 2357f614..140c0d7b 100644 --- a/swarms/structs/nonlinear_workflow.py +++ b/swarms/structs/nonlinear_workflow.py @@ -8,9 +8,10 @@ class Task: Task is a unit of work that can be executed by an agent """ - def __init__( - self, id: str, parents: List["Task"] = None, children: List["Task"] = None - ): + def __init__(self, + id: str, + parents: List["Task"] = None, + children: List["Task"] = None): self.id = id self.parents = parents self.children = children @@ -79,7 +80,8 @@ class NonLinearWorkflow: for task in ordered_tasks: if task.can_execute: - future = self.executor.submit(self.agents.run, task.task_string) + future = self.executor.submit(self.agents.run, + task.task_string) futures_list[future] = task for future in as_completed(futures_list): @@ -95,7 +97,8 @@ class NonLinearWorkflow: def to_graph(self) -> Dict[str, set[str]]: """Convert the workflow to a graph""" graph = { - task.id: set(child.id for child in task.children) for task in self.tasks + task.id: set(child.id for child in task.children) + for task in self.tasks } return graph diff --git a/swarms/structs/sequential_workflow.py b/swarms/structs/sequential_workflow.py index 8c7d9760..8dd5abbd 100644 --- a/swarms/structs/sequential_workflow.py +++ b/swarms/structs/sequential_workflow.py @@ -61,13 +61,12 @@ class Task: if isinstance(self.flow, Flow): # Add a prompt to notify the Flow of the sequential workflow if "prompt" in self.kwargs: - self.kwargs["prompt"] += ( - f"\n\nPrevious output: {self.result}" if self.result else "" - ) + self.kwargs["prompt"] += (f"\n\nPrevious output: {self.result}" + if 
self.result else "") else: self.kwargs["prompt"] = f"Main task: {self.description}" + ( - f"\n\nPrevious output: {self.result}" if self.result else "" - ) + f"\n\nPrevious output: {self.result}" + if self.result else "") self.result = self.flow.run(*self.args, **self.kwargs) else: self.result = self.flow(*self.args, **self.kwargs) @@ -111,7 +110,8 @@ class SequentialWorkflow: restore_state_filepath: Optional[str] = None dashboard: bool = False - def add(self, task: str, flow: Union[Callable, Flow], *args, **kwargs) -> None: + def add(self, task: str, flow: Union[Callable, Flow], *args, + **kwargs) -> None: """ Add a task to the workflow. @@ -127,8 +127,7 @@ class SequentialWorkflow: # Append the task to the tasks list self.tasks.append( - Task(description=task, flow=flow, args=list(args), kwargs=kwargs) - ) + Task(description=task, flow=flow, args=list(args), kwargs=kwargs)) def reset_workflow(self) -> None: """Resets the workflow by clearing the results of each task.""" @@ -180,8 +179,9 @@ class SequentialWorkflow: raise ValueError(f"Task {task_description} not found in workflow.") def save_workflow_state( - self, filepath: Optional[str] = "sequential_workflow_state.json", **kwargs - ) -> None: + self, + filepath: Optional[str] = "sequential_workflow_state.json", + **kwargs) -> None: """ Saves the workflow state to a json file. @@ -202,16 +202,13 @@ class SequentialWorkflow: with open(filepath, "w") as f: # Saving the state as a json for simplicuty state = { - "tasks": [ - { - "description": task.description, - "args": task.args, - "kwargs": task.kwargs, - "result": task.result, - "history": task.history, - } - for task in self.tasks - ], + "tasks": [{ + "description": task.description, + "args": task.args, + "kwargs": task.kwargs, + "result": task.result, + "history": task.history, + } for task in self.tasks], "max_loops": self.max_loops, } json.dump(state, f, indent=4) @@ -223,8 +220,7 @@ class SequentialWorkflow: Sequential Workflow Initializing...""", "green", attrs=["bold", "underline"], - ) - ) + )) def workflow_dashboard(self, **kwargs) -> None: """ @@ -263,8 +259,7 @@ class SequentialWorkflow: """, "cyan", attrs=["bold", "underline"], - ) - ) + )) def workflow_shutdown(self, **kwargs) -> None: print( @@ -273,8 +268,7 @@ class SequentialWorkflow: Sequential Workflow Shutdown...""", "red", attrs=["bold", "underline"], - ) - ) + )) def add_objective_to_workflow(self, task: str, **kwargs) -> None: print( @@ -283,8 +277,7 @@ class SequentialWorkflow: Adding Objective to Workflow...""", "green", attrs=["bold", "underline"], - ) - ) + )) task = Task( description=task, @@ -349,13 +342,12 @@ class SequentialWorkflow: if "task" not in task.kwargs: raise ValueError( "The 'task' argument is required for the Flow flow" - f" execution in '{task.description}'" - ) + f" execution in '{task.description}'") # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") - task.result = task.flow.run( - flow_task_arg, *task.args, **task.kwargs - ) + task.result = task.flow.run(flow_task_arg, + *task.args, + **task.kwargs) else: # If it's not a Flow instance, call the flow directly task.result = task.flow(*task.args, **task.kwargs) @@ -373,19 +365,17 @@ class SequentialWorkflow: # Autosave the workflow state if self.autosave: - self.save_workflow_state("sequential_workflow_state.json") + self.save_workflow_state( + "sequential_workflow_state.json") except Exception as e: print( colored( - ( - f"Error initializing the Sequential workflow: {e} try" - " optimizing your inputs 
like the flow class and task" - " description" - ), + (f"Error initializing the Sequential workflow: {e} try" + " optimizing your inputs like the flow class and task" + " description"), "red", attrs=["bold", "underline"], - ) - ) + )) async def arun(self) -> None: """ @@ -405,13 +395,11 @@ class SequentialWorkflow: if "task" not in task.kwargs: raise ValueError( "The 'task' argument is required for the Flow flow" - f" execution in '{task.description}'" - ) + f" execution in '{task.description}'") # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") task.result = await task.flow.arun( - flow_task_arg, *task.args, **task.kwargs - ) + flow_task_arg, *task.args, **task.kwargs) else: # If it's not a Flow instance, call the flow directly task.result = await task.flow(*task.args, **task.kwargs) @@ -429,4 +417,5 @@ class SequentialWorkflow: # Autosave the workflow state if self.autosave: - self.save_workflow_state("sequential_workflow_state.json") + self.save_workflow_state( + "sequential_workflow_state.json") diff --git a/swarms/structs/task.py b/swarms/structs/task.py index 80f95d4d..6824bf0e 100644 --- a/swarms/structs/task.py +++ b/swarms/structs/task.py @@ -13,6 +13,7 @@ from swarms.artifacts.error_artifact import ErrorArtifact class BaseTask(ABC): + class State(Enum): PENDING = 1 EXECUTING = 2 @@ -33,11 +34,15 @@ class BaseTask(ABC): @property def parents(self) -> List[BaseTask]: - return [self.structure.find_task(parent_id) for parent_id in self.parent_ids] + return [ + self.structure.find_task(parent_id) for parent_id in self.parent_ids + ] @property def children(self) -> List[BaseTask]: - return [self.structure.find_task(child_id) for child_id in self.child_ids] + return [ + self.structure.find_task(child_id) for child_id in self.child_ids + ] def __rshift__(self, child: BaseTask) -> BaseTask: return self.add_child(child) @@ -118,8 +123,7 @@ class BaseTask(ABC): def can_execute(self) -> bool: return self.state == self.State.PENDING and all( - parent.is_finished() for parent in self.parents - ) + parent.is_finished() for parent in self.parents) def reset(self) -> BaseTask: self.state = self.State.PENDING @@ -132,10 +136,10 @@ class BaseTask(ABC): class Task(BaseModel): - input: Optional[StrictStr] = Field(None, description="Input prompt for the task") + input: Optional[StrictStr] = Field(None, + description="Input prompt for the task") additional_input: Optional[Any] = Field( - None, description="Input parameters for the task. Any value is allowed" - ) + None, description="Input parameters for the task. 
Any value is allowed") task_id: StrictStr = Field(..., description="ID of the task") class Config: diff --git a/swarms/structs/workflow.py b/swarms/structs/workflow.py index 762ee6cc..e4a841ed 100644 --- a/swarms/structs/workflow.py +++ b/swarms/structs/workflow.py @@ -65,11 +65,13 @@ class Workflow: def context(self, task: Task) -> Dict[str, Any]: """Context in tasks""" return { - "parent_output": task.parents[0].output - if task.parents and task.parents[0].output - else None, - "parent": task.parents[0] if task.parents else None, - "child": task.children[0] if task.children else None, + "parent_output": + task.parents[0].output + if task.parents and task.parents[0].output else None, + "parent": + task.parents[0] if task.parents else None, + "child": + task.children[0] if task.children else None, } def __run_from_task(self, task: Optional[Task]) -> None: diff --git a/swarms/swarms/autoscaler.py b/swarms/swarms/autoscaler.py index 5f6bedde..d0aaa598 100644 --- a/swarms/swarms/autoscaler.py +++ b/swarms/swarms/autoscaler.py @@ -87,7 +87,8 @@ class AutoScaler: while True: sleep(60) # check minute pending_tasks = self.task_queue.qsize() - active_agents = sum([1 for agent in self.agents_pool if agent.is_busy()]) + active_agents = sum( + [1 for agent in self.agents_pool if agent.is_busy()]) if pending_tasks / len(self.agents_pool) > self.busy_threshold: self.scale_up() diff --git a/swarms/swarms/base.py b/swarms/swarms/base.py index e99c9b38..6d8e0163 100644 --- a/swarms/swarms/base.py +++ b/swarms/swarms/base.py @@ -117,7 +117,9 @@ class AbstractSwarm(ABC): pass @abstractmethod - def broadcast(self, message: str, sender: Optional["AbstractWorker"] = None): + def broadcast(self, + message: str, + sender: Optional["AbstractWorker"] = None): """Broadcast a message to all workers""" pass diff --git a/swarms/swarms/battle_royal.py b/swarms/swarms/battle_royal.py index 2a02186e..7b5c2a99 100644 --- a/swarms/swarms/battle_royal.py +++ b/swarms/swarms/battle_royal.py @@ -77,19 +77,15 @@ class BattleRoyalSwarm: # Check for clashes and handle them for i, worker1 in enumerate(self.workers): for j, worker2 in enumerate(self.workers): - if ( - i != j - and worker1.is_within_proximity(worker2) - and set(worker1.teams) != set(worker2.teams) - ): + if (i != j and worker1.is_within_proximity(worker2) and + set(worker1.teams) != set(worker2.teams)): winner, loser = self.clash(worker1, worker2, question) print(f"Worker {winner.id} won over Worker {loser.id}") def communicate(self, sender: Worker, reciever: Worker, message: str): """Communicate a message from one worker to another.""" if sender.is_within_proximity(reciever) or any( - team in sender.teams for team in reciever.teams - ): + team in sender.teams for team in reciever.teams): pass def clash(self, worker1: Worker, worker2: Worker, question: str): diff --git a/swarms/swarms/god_mode.py b/swarms/swarms/god_mode.py index fe842f0a..7f302318 100644 --- a/swarms/swarms/god_mode.py +++ b/swarms/swarms/god_mode.py @@ -49,9 +49,8 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" - ) - ) + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan")) def run_all(self, task): """Run the task on all LLMs""" @@ -74,18 +73,15 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" - ) - ) + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan")) # 
New Features def save_responses_to_file(self, filename): """Save responses to file""" with open(filename, "w") as file: - table = [ - [f"LLM {i+1}", response] - for i, response in enumerate(self.last_responses) - ] + table = [[f"LLM {i+1}", response] + for i, response in enumerate(self.last_responses)] file.write(tabulate(table, headers=["LLM", "Response"])) @classmethod @@ -105,11 +101,9 @@ class GodMode: for i, task in enumerate(self.task_history): print(f"{i + 1}. {task}") print("\nLast Responses:") - table = [ - [f"LLM {i+1}", response] for i, response in enumerate(self.last_responses) - ] + table = [[f"LLM {i+1}", response] + for i, response in enumerate(self.last_responses)] print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" - ) - ) + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan")) diff --git a/swarms/swarms/groupchat.py b/swarms/swarms/groupchat.py index dd3e36a2..842ebac9 100644 --- a/swarms/swarms/groupchat.py +++ b/swarms/swarms/groupchat.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Dict, List from swarms.structs.flow import Flow - logger = logging.getLogger(__name__) @@ -34,7 +33,8 @@ class GroupChat: def next_agent(self, agent: Flow) -> Flow: """Return the next agent in the list.""" - return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)] + return self.agents[(self.agent_names.index(agent.name) + 1) % + len(self.agents)] def select_speaker_msg(self): """Return the message for selecting the next speaker.""" @@ -55,24 +55,17 @@ class GroupChat: if n_agents < 3: logger.warning( f"GroupChat is underpopulated with {n_agents} agents. Direct" - " communication would be more efficient." - ) + " communication would be more efficient.") name = selector.generate_reply( - self.format_history( - self.messages - + [ - { - "role": "system", - "content": ( - "Read the above conversation. Then select the next most" - f" suitable role from {self.agent_names} to play. Only" - " return the role." - ), - } - ] - ) - ) + self.format_history(self.messages + [{ + "role": + "system", + "content": + ("Read the above conversation. Then select the next most" + f" suitable role from {self.agent_names} to play. 
Only" + " return the role."), + }])) try: return self.agent_by_name(name["content"]) except ValueError: @@ -80,8 +73,7 @@ class GroupChat: def _participant_roles(self): return "\n".join( - [f"{agent.name}: {agent.system_message}" for agent in self.agents] - ) + [f"{agent.name}: {agent.system_message}" for agent in self.agents]) def format_history(self, messages: List[Dict]) -> str: formatted_messages = [] @@ -92,19 +84,21 @@ class GroupChat: class GroupChatManager: + def __init__(self, groupchat: GroupChat, selector: Flow): self.groupchat = groupchat self.selector = selector def __call__(self, task: str): - self.groupchat.messages.append({"role": self.selector.name, "content": task}) + self.groupchat.messages.append({ + "role": self.selector.name, + "content": task + }) for i in range(self.groupchat.max_round): - speaker = self.groupchat.select_speaker( - last_speaker=self.selector, selector=self.selector - ) + speaker = self.groupchat.select_speaker(last_speaker=self.selector, + selector=self.selector) reply = speaker.generate_reply( - self.groupchat.format_history(self.groupchat.messages) - ) + self.groupchat.format_history(self.groupchat.messages)) self.groupchat.messages.append(reply) print(reply) if i == self.groupchat.max_round - 1: diff --git a/swarms/swarms/multi_agent_collab.py b/swarms/swarms/multi_agent_collab.py index 9a5f27bc..a3b79d7f 100644 --- a/swarms/swarms/multi_agent_collab.py +++ b/swarms/swarms/multi_agent_collab.py @@ -5,16 +5,16 @@ from langchain.output_parsers import RegexParser # utils class BidOutputParser(RegexParser): + def get_format_instructions(self) -> str: return ( "Your response should be an integrater delimited by angled brackets like" - " this: " - ) + " this: ") -bid_parser = BidOutputParser( - regex=r"<(\d+)>", output_keys=["bid"], default_output_key="bid" -) +bid_parser = BidOutputParser(regex=r"<(\d+)>", + output_keys=["bid"], + default_output_key="bid") def select_next_speaker(step: int, agents, director) -> int: @@ -29,6 +29,7 @@ def select_next_speaker(step: int, agents, director) -> int: # main class MultiAgentCollaboration: + def __init__( self, agents, diff --git a/swarms/swarms/multi_agent_debate.py b/swarms/swarms/multi_agent_debate.py index 4bba3619..1c7ebdf9 100644 --- a/swarms/swarms/multi_agent_debate.py +++ b/swarms/swarms/multi_agent_debate.py @@ -46,7 +46,6 @@ class MultiAgentDebate: def format_results(self, results): formatted_results = "\n".join( - [f"Agent responded: {result['response']}" for result in results] - ) + [f"Agent responded: {result['response']}" for result in results]) return formatted_results diff --git a/swarms/swarms/orchestrate.py b/swarms/swarms/orchestrate.py index f522911b..d47771ab 100644 --- a/swarms/swarms/orchestrate.py +++ b/swarms/swarms/orchestrate.py @@ -111,7 +111,8 @@ class Orchestrator: self.chroma_client = chromadb.Client() - self.collection = self.chroma_client.create_collection(name=collection_name) + self.collection = self.chroma_client.create_collection( + name=collection_name) self.current_tasks = {} @@ -137,9 +138,8 @@ class Orchestrator: result = self.worker.run(task["content"]) # using the embed method to get the vector representation of the result - vector_representation = self.embed( - result, self.api_key, self.model_name - ) + vector_representation = self.embed(result, self.api_key, + self.model_name) self.collection.add( embeddings=[vector_representation], @@ -154,8 +154,7 @@ class Orchestrator: except Exception as error: logging.error( f"Failed to process task {id(task)} by agent 
{id(agent)}. Error:" - f" {error}" - ) + f" {error}") finally: with self.condition: self.agents.put(agent) @@ -163,8 +162,7 @@ class Orchestrator: def embed(self, input, api_key, model_name): openai = embedding_functions.OpenAIEmbeddingFunction( - api_key=api_key, model_name=model_name - ) + api_key=api_key, model_name=model_name) embedding = openai(input) return embedding @@ -175,13 +173,13 @@ class Orchestrator: try: # Query the vector database for documents created by the agents - results = self.collection.query(query_texts=[str(agent_id)], n_results=10) + results = self.collection.query(query_texts=[str(agent_id)], + n_results=10) return results except Exception as e: logging.error( - f"Failed to retrieve results from agent {agent_id}. Error {e}" - ) + f"Failed to retrieve results from agent {agent_id}. Error {e}") raise # @abstractmethod @@ -212,7 +210,8 @@ class Orchestrator: self.collection.add(documents=[result], ids=[str(id(result))]) except Exception as e: - logging.error(f"Failed to append the agent output to database. Error: {e}") + logging.error( + f"Failed to append the agent output to database. Error: {e}") raise def run(self, objective: str): @@ -225,8 +224,8 @@ class Orchestrator: self.task_queue.append(objective) results = [ - self.assign_task(agent_id, task) - for agent_id, task in zip(range(len(self.agents)), self.task_queue) + self.assign_task(agent_id, task) for agent_id, task in zip( + range(len(self.agents)), self.task_queue) ] for result in results: diff --git a/swarms/swarms/simple_swarm.py b/swarms/swarms/simple_swarm.py index 7e806215..a382c0d7 100644 --- a/swarms/swarms/simple_swarm.py +++ b/swarms/swarms/simple_swarm.py @@ -2,6 +2,7 @@ from queue import Queue, PriorityQueue class SimpleSwarm: + def __init__( self, llm, diff --git a/swarms/tools/autogpt.py b/swarms/tools/autogpt.py index cf5450e6..270504aa 100644 --- a/swarms/tools/autogpt.py +++ b/swarms/tools/autogpt.py @@ -8,8 +8,7 @@ import torch from langchain.agents import tool from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent from langchain.chains.qa_with_sources.loading import ( - BaseCombineDocumentsChain, -) + BaseCombineDocumentsChain,) from langchain.docstore.document import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.tools import BaseTool @@ -37,9 +36,10 @@ def pushd(new_dir): @tool -def process_csv( - llm, csv_file_path: str, instructions: str, output_path: Optional[str] = None -) -> str: +def process_csv(llm, + csv_file_path: str, + instructions: str, + output_path: Optional[str] = None) -> str: """Process a CSV by with pandas in a limited REPL.\ Only use this after writing data to disk as a csv file.\ Any figures must be saved to disk to be viewed by the human.\ @@ -49,7 +49,10 @@ def process_csv( df = pd.read_csv(csv_file_path) except Exception as e: return f"Error: {e}" - agent = create_pandas_dataframe_agent(llm, df, max_iterations=30, verbose=False) + agent = create_pandas_dataframe_agent(llm, + df, + max_iterations=30, + verbose=False) if output_path is not None: instructions += f" Save output to disk at {output_path}" try: @@ -79,7 +82,8 @@ async def async_load_playwright(url: str) -> str: text = soup.get_text() lines = (line.strip() for line in text.splitlines()) - chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) + chunks = ( + phrase.strip() for line in lines for phrase in line.split(" ")) results = "\n".join(chunk for chunk in chunks if chunk) except Exception as e: results = 
f"Error: {e}" @@ -113,8 +117,7 @@ class WebpageQATool(BaseTool): "Browse a webpage and retrieve the information relevant to the question." ) text_splitter: RecursiveCharacterTextSplitter = Field( - default_factory=_get_text_splitter - ) + default_factory=_get_text_splitter) qa_chain: BaseCombineDocumentsChain def _run(self, url: str, question: str) -> str: @@ -125,9 +128,12 @@ class WebpageQATool(BaseTool): results = [] # TODO: Handle this with a MapReduceChain for i in range(0, len(web_docs), 4): - input_docs = web_docs[i : i + 4] + input_docs = web_docs[i:i + 4] window_result = self.qa_chain( - {"input_documents": input_docs, "question": question}, + { + "input_documents": input_docs, + "question": question + }, return_only_outputs=True, ) results.append(f"Response from window {i} - {window_result}") @@ -135,7 +141,10 @@ class WebpageQATool(BaseTool): Document(page_content="\n".join(results), metadata={"source": url}) ] return self.qa_chain( - {"input_documents": results_docs, "question": question}, + { + "input_documents": results_docs, + "question": question + }, return_only_outputs=True, ) @@ -171,18 +180,17 @@ def VQAinference(self, inputs): torch_dtype = torch.float16 if "cuda" in device else torch.float32 processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") model = BlipForQuestionAnswering.from_pretrained( - "Salesforce/blip-vqa-base", torch_dtype=torch_dtype - ).to(device) + "Salesforce/blip-vqa-base", torch_dtype=torch_dtype).to(device) image_path, question = inputs.split(",") raw_image = Image.open(image_path).convert("RGB") - inputs = processor(raw_image, question, return_tensors="pt").to(device, torch_dtype) + inputs = processor(raw_image, question, + return_tensors="pt").to(device, torch_dtype) out = model.generate(**inputs) answer = processor.decode(out[0], skip_special_tokens=True) logger.debug( f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" - f" Question: {question}, Output Answer: {answer}" - ) + f" Question: {question}, Output Answer: {answer}") return answer diff --git a/swarms/tools/mm_models.py b/swarms/tools/mm_models.py index 58fe11e5..fd115bd6 100644 --- a/swarms/tools/mm_models.py +++ b/swarms/tools/mm_models.py @@ -25,13 +25,14 @@ from swarms.utils.main import BaseHandler, get_new_image_name class MaskFormer: + def __init__(self, device): print("Initializing MaskFormer to %s" % device) self.device = device - self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") + self.processor = CLIPSegProcessor.from_pretrained( + "CIDAS/clipseg-rd64-refined") self.model = CLIPSegForImageSegmentation.from_pretrained( - "CIDAS/clipseg-rd64-refined" - ).to(device) + "CIDAS/clipseg-rd64-refined").to(device) def inference(self, image_path, text): threshold = 0.5 @@ -39,9 +40,10 @@ class MaskFormer: padding = 20 original_image = Image.open(image_path) image = original_image.resize((512, 512)) - inputs = self.processor( - text=text, images=image, padding="max_length", return_tensors="pt" - ).to(self.device) + inputs = self.processor(text=text, + images=image, + padding="max_length", + return_tensors="pt").to(self.device) with torch.no_grad(): outputs = self.model(**inputs) mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold @@ -52,8 +54,7 @@ class MaskFormer: mask_array = np.zeros_like(mask, dtype=bool) for idx in true_indices: padded_slice = tuple( - slice(max(0, i - padding), i + padding + 1) for i in idx - ) + slice(max(0, i - padding), i + padding + 1) for i in idx) mask_array[padded_slice] = 
True visual_mask = (mask_array * 255).astype(np.uint8) image_mask = Image.fromarray(visual_mask) @@ -61,6 +62,7 @@ class MaskFormer: class ImageEditing: + def __init__(self, device): print("Initializing ImageEditing to %s" % device) self.device = device @@ -75,25 +77,24 @@ class ImageEditing: @tool( name="Remove Something From The Photo", - description=( - "useful when you want to remove and object or something from the photo " - "from its description or location. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the object need to be removed. " - ), + description= + ("useful when you want to remove and object or something from the photo " + "from its description or location. " + "The input to this tool should be a comma separated string of two, " + "representing the image_path and the object need to be removed. "), ) def inference_remove(self, inputs): image_path, to_be_removed_txt = inputs.split(",") - return self.inference_replace(f"{image_path},{to_be_removed_txt},background") + return self.inference_replace( + f"{image_path},{to_be_removed_txt},background") @tool( name="Replace Something From The Photo", - description=( - "useful when you want to replace an object from the object description or" - " location with another object from its description. The input to this tool" - " should be a comma separated string of three, representing the image_path," - " the object to be replaced, the object to be replaced with " - ), + description= + ("useful when you want to replace an object from the object description or" + " location with another object from its description. The input to this tool" + " should be a comma separated string of three, representing the image_path," + " the object to be replaced, the object to be replaced with "), ) def inference_replace(self, inputs): image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",") @@ -105,22 +106,21 @@ class ImageEditing: image=original_image.resize((512, 512)), mask_image=mask_image.resize((512, 512)), ).images[0] - updated_image_path = get_new_image_name( - image_path, func_name="replace-something" - ) + updated_image_path = get_new_image_name(image_path, + func_name="replace-something") updated_image = updated_image.resize(original_size) updated_image.save(updated_image_path) logger.debug( f"\nProcessed ImageEditing, Input Image: {image_path}, Replace" f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:" - f" {updated_image_path}" - ) + f" {updated_image_path}") return updated_image_path class InstructPix2Pix: + def __init__(self, device): print("Initializing InstructPix2Pix to %s" % device) self.device = device @@ -131,60 +131,56 @@ class InstructPix2Pix: torch_dtype=self.torch_dtype, ).to(device) self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( - self.pipe.scheduler.config - ) + self.pipe.scheduler.config) @tool( name="Instruct Image Using Text", - description=( - "useful when you want to the style of the image to be like the text. " - "like: make it look like a painting. or make it like a robot. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the text. " - ), + description= + ("useful when you want to the style of the image to be like the text. " + "like: make it look like a painting. or make it like a robot. " + "The input to this tool should be a comma separated string of two, " + "representing the image_path and the text. 
"), ) def inference(self, inputs): """Change style of image.""" logger.debug("===> Starting InstructPix2Pix Inference") image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) original_image = Image.open(image_path) - image = self.pipe( - text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2 - ).images[0] + image = self.pipe(text, + image=original_image, + num_inference_steps=40, + image_guidance_scale=1.2).images[0] updated_image_path = get_new_image_name(image_path, func_name="pix2pix") image.save(updated_image_path) logger.debug( f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:" - f" {text}, Output Image: {updated_image_path}" - ) + f" {text}, Output Image: {updated_image_path}") return updated_image_path class Text2Image: + def __init__(self, device): print("Initializing Text2Image to %s" % device) self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.pipe = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype - ) + "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype) self.pipe.to(device) self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" - ) + "fewer digits, cropped, worst quality, low quality") @tool( name="Generate Image From User Input Text", - description=( - "useful when you want to generate an image from a user input text and save" - " it to a file. like: generate an image of an object or something, or" - " generate an image that includes some objects. The input to this tool" - " should be a string, representing the text used to generate image. " - ), + description= + ("useful when you want to generate an image from a user input text and save" + " it to a file. like: generate an image of an object or something, or" + " generate an image that includes some objects. The input to this tool" + " should be a string, representing the text used to generate image. "), ) def inference(self, text): image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png") @@ -194,59 +190,59 @@ class Text2Image: logger.debug( f"\nProcessed Text2Image, Input Text: {text}, Output Image:" - f" {image_filename}" - ) + f" {image_filename}") return image_filename class VisualQuestionAnswering: + def __init__(self, device): print("Initializing VisualQuestionAnswering to %s" % device) self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.device = device - self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") + self.processor = BlipProcessor.from_pretrained( + "Salesforce/blip-vqa-base") self.model = BlipForQuestionAnswering.from_pretrained( - "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype - ).to(self.device) + "Salesforce/blip-vqa-base", + torch_dtype=self.torch_dtype).to(self.device) @tool( name="Answer Question About The Image", - description=( - "useful when you need an answer for a question based on an image. like:" - " what is the background color of the last image, how many cats in this" - " figure, what is in this figure. The input to this tool should be a comma" - " separated string of two, representing the image_path and the question" + description= + ("useful when you need an answer for a question based on an image. 
like:" + " what is the background color of the last image, how many cats in this" + " figure, what is in this figure. The input to this tool should be a comma" + " separated string of two, representing the image_path and the question" ), ) def inference(self, inputs): image_path, question = inputs.split(",") raw_image = Image.open(image_path).convert("RGB") - inputs = self.processor(raw_image, question, return_tensors="pt").to( - self.device, self.torch_dtype - ) + inputs = self.processor(raw_image, question, + return_tensors="pt").to(self.device, + self.torch_dtype) out = self.model.generate(**inputs) answer = self.processor.decode(out[0], skip_special_tokens=True) logger.debug( f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" - f" Question: {question}, Output Answer: {answer}" - ) + f" Question: {question}, Output Answer: {answer}") return answer class ImageCaptioning(BaseHandler): + def __init__(self, device): print("Initializing ImageCaptioning to %s" % device) self.device = device self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.processor = BlipProcessor.from_pretrained( - "Salesforce/blip-image-captioning-base" - ) + "Salesforce/blip-image-captioning-base") self.model = BlipForConditionalGeneration.from_pretrained( - "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype - ).to(self.device) + "Salesforce/blip-image-captioning-base", + torch_dtype=self.torch_dtype).to(self.device) def handle(self, filename: str): img = Image.open(filename) @@ -258,14 +254,13 @@ class ImageCaptioning(BaseHandler): img.save(filename, "PNG") print(f"Resize image form {width}x{height} to {width_new}x{height_new}") - inputs = self.processor(Image.open(filename), return_tensors="pt").to( - self.device, self.torch_dtype - ) + inputs = self.processor(Image.open(filename), + return_tensors="pt").to(self.device, + self.torch_dtype) out = self.model.generate(**inputs) description = self.processor.decode(out[0], skip_special_tokens=True) print( f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text:" - f" {description}" - ) + f" {description}") return IMAGE_PROMPT.format(filename=filename, description=description) diff --git a/swarms/tools/stt.py b/swarms/tools/stt.py index cfe3e656..da9d7f27 100644 --- a/swarms/tools/stt.py +++ b/swarms/tools/stt.py @@ -9,6 +9,7 @@ from pytube import YouTube class SpeechToText: + def __init__( self, video_url, @@ -61,14 +62,15 @@ class SpeechToText: compute_type = "float16" # 1. Transcribe with original Whisper (batched) 🗣️ - model = whisperx.load_model("large-v2", device, compute_type=compute_type) + model = whisperx.load_model("large-v2", + device, + compute_type=compute_type) audio = whisperx.load_audio(audio_file) result = model.transcribe(audio, batch_size=batch_size) # 2. Align Whisper output 🔍 model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=device - ) + language_code=result["language"], device=device) result = whisperx.align( result["segments"], model_a, @@ -80,8 +82,7 @@ class SpeechToText: # 3. Assign speaker labels 🏷️ diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=device - ) + use_auth_token=self.hf_api_key, device=device) diarize_model(audio_file) try: @@ -98,8 +99,7 @@ class SpeechToText: # 2. 
Align Whisper output 🔍 model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=self.device - ) + language_code=result["language"], device=self.device) result = whisperx.align( result["segments"], @@ -112,8 +112,7 @@ class SpeechToText: # 3. Assign speaker labels 🏷️ diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=self.device - ) + use_auth_token=self.hf_api_key, device=self.device) diarize_model(audio_file) diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py index 1b1072a5..29b0f5de 100644 --- a/swarms/tools/tool.py +++ b/swarms/tools/tool.py @@ -34,9 +34,8 @@ class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" -def _create_subset_model( - name: str, model: BaseModel, field_names: list -) -> Type[BaseModel]: +def _create_subset_model(name: str, model: BaseModel, + field_names: list) -> Type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" fields = {} for field_name in field_names: @@ -52,7 +51,11 @@ def _get_filtered_args( """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} + return { + k: schema[k] + for k in valid_keys + if k not in ("run_manager", "callbacks") + } class _SchemaConfig: @@ -82,9 +85,8 @@ def create_schema_from_function( del inferred_model.__fields__["callbacks"] # Pydantic adds placeholder virtual fields we need to strip valid_properties = _get_filtered_args(inferred_model, func) - return _create_subset_model( - f"{model_name}Schema", inferred_model, list(valid_properties) - ) + return _create_subset_model(f"{model_name}Schema", inferred_model, + list(valid_properties)) class ToolException(Exception): @@ -125,8 +127,7 @@ class ChildTool(BaseTool): "Expected annotation of 'Type[BaseModel]'" f" but got '{args_schema_type}'.\n" "Expected class looks like:\n" - f"{typehint_mandate}" - ) + f"{typehint_mandate}") name: str """The unique name of the tool that clearly communicates its purpose.""" @@ -147,7 +148,8 @@ class ChildTool(BaseTool): callbacks: Callbacks = Field(default=None, exclude=True) """Callbacks to be called during tool execution.""" - callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) + callback_manager: Optional[BaseCallbackManager] = Field(default=None, + exclude=True) """Deprecated. Please use callbacks instead.""" tags: Optional[List[str]] = None """Optional list of tags associated with the tool. Defaults to None @@ -162,9 +164,8 @@ class ChildTool(BaseTool): You can use these to eg identify a specific instance of a tool with its use case. 
""" - handle_tool_error: Optional[ - Union[bool, str, Callable[[ToolException], str]] - ] = False + handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], + str]]] = False """Handle the content of the ToolException thrown.""" class Config(Serializable.Config): @@ -244,7 +245,9 @@ class ChildTool(BaseTool): else: if input_args is not None: result = input_args.parse_obj(tool_input) - return {k: v for k, v in result.dict().items() if k in tool_input} + return { + k: v for k, v in result.dict().items() if k in tool_input + } return tool_input @root_validator() @@ -286,7 +289,8 @@ class ChildTool(BaseTool): *args, ) - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs(self, + tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: # For backwards compatibility, if run_input is a string, # pass as a positional argument. if isinstance(tool_input, str): @@ -325,7 +329,10 @@ class ChildTool(BaseTool): # TODO: maybe also pass through run_manager is _run supports kwargs new_arg_supported = signature(self._run).parameters.get("run_manager") run_manager = callback_manager.on_tool_start( - {"name": self.name, "description": self.description}, + { + "name": self.name, + "description": self.description + }, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, name=run_name, @@ -335,9 +342,7 @@ class ChildTool(BaseTool): tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) observation = ( self._run(*tool_args, run_manager=run_manager, **tool_kwargs) - if new_arg_supported - else self._run(*tool_args, **tool_kwargs) - ) + if new_arg_supported else self._run(*tool_args, **tool_kwargs)) except ToolException as e: if not self.handle_tool_error: run_manager.on_tool_error(e) @@ -354,19 +359,20 @@ class ChildTool(BaseTool): else: raise ValueError( "Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}" - ) - run_manager.on_tool_end( - str(observation), color="red", name=self.name, **kwargs - ) + f"or callable. 
Received: {self.handle_tool_error}") + run_manager.on_tool_end(str(observation), + color="red", + name=self.name, + **kwargs) return observation except (Exception, KeyboardInterrupt) as e: run_manager.on_tool_error(e) raise e else: - run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs - ) + run_manager.on_tool_end(str(observation), + color=color, + name=self.name, + **kwargs) return observation async def arun( @@ -399,7 +405,10 @@ class ChildTool(BaseTool): ) new_arg_supported = signature(self._arun).parameters.get("run_manager") run_manager = await callback_manager.on_tool_start( - {"name": self.name, "description": self.description}, + { + "name": self.name, + "description": self.description + }, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, name=run_name, @@ -408,11 +417,10 @@ class ChildTool(BaseTool): try: # We then call the tool on the tool input to get an observation tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) - observation = ( - await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) - if new_arg_supported - else await self._arun(*tool_args, **tool_kwargs) - ) + observation = (await self._arun(*tool_args, + run_manager=run_manager, + **tool_kwargs) if new_arg_supported + else await self._arun(*tool_args, **tool_kwargs)) except ToolException as e: if not self.handle_tool_error: await run_manager.on_tool_error(e) @@ -429,19 +437,20 @@ class ChildTool(BaseTool): else: raise ValueError( "Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}" - ) - await run_manager.on_tool_end( - str(observation), color="red", name=self.name, **kwargs - ) + f"or callable. Received: {self.handle_tool_error}") + await run_manager.on_tool_end(str(observation), + color="red", + name=self.name, + **kwargs) return observation except (Exception, KeyboardInterrupt) as e: await run_manager.on_tool_error(e) raise e else: - await run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs - ) + await run_manager.on_tool_end(str(observation), + color=color, + name=self.name, + **kwargs) return observation def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: @@ -459,7 +468,6 @@ class Tool(BaseTool): """The asynchronous version of the function.""" # --- Runnable --- - async def ainvoke( self, input: Union[str, Dict], @@ -469,8 +477,7 @@ class Tool(BaseTool): if not self.coroutine: # If the tool does not implement async, fall back to default implementation return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs) - ) + None, partial(self.invoke, input, config, **kwargs)) return await super().ainvoke(input, config, **kwargs) @@ -485,7 +492,8 @@ class Tool(BaseTool): # assume it takes a single string input. return {"tool_input": {"type": "string"}} - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs(self, + tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: """Convert tool input to pydantic model.""" args, kwargs = super()._to_args_and_kwargs(tool_input) # For backwards compatibility. 
The tool must be run with a single input @@ -504,16 +512,13 @@ class Tool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") - return ( - self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else self.func(*args, **kwargs) - ) + new_argument_supported = signature( + self.func).parameters.get("callbacks") + return (self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) if new_argument_supported else self.func(*args, **kwargs)) raise NotImplementedError("Tool does not support sync") async def _arun( @@ -524,31 +529,27 @@ class Tool(BaseTool): ) -> Any: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get( - "callbacks" - ) - return ( - await self.coroutine( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else await self.coroutine(*args, **kwargs) - ) + new_argument_supported = signature( + self.coroutine).parameters.get("callbacks") + return (await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) if new_argument_supported else await self.coroutine( + *args, **kwargs)) else: return await asyncio.get_running_loop().run_in_executor( - None, partial(self._run, run_manager=run_manager, **kwargs), *args - ) + None, partial(self._run, run_manager=run_manager, **kwargs), + *args) # TODO: this is for backwards compatibility, remove in future - def __init__( - self, name: str, func: Optional[Callable], description: str, **kwargs: Any - ) -> None: + def __init__(self, name: str, func: Optional[Callable], description: str, + **kwargs: Any) -> None: """Initialize tool.""" - super(Tool, self).__init__( - name=name, func=func, description=description, **kwargs - ) + super(Tool, self).__init__(name=name, + func=func, + description=description, + **kwargs) @classmethod def from_function( @@ -558,9 +559,8 @@ class Tool(BaseTool): description: str, return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, - coroutine: Optional[ - Callable[..., Awaitable[Any]] - ] = None, # This is last for compatibility, but should be after func + coroutine: Optional[Callable[..., Awaitable[ + Any]]] = None, # This is last for compatibility, but should be after func **kwargs: Any, ) -> Tool: """Initialize tool from a function.""" @@ -589,7 +589,6 @@ class StructuredTool(BaseTool): """The asynchronous version of the function.""" # --- Runnable --- - async def ainvoke( self, input: Union[str, Dict], @@ -599,8 +598,7 @@ class StructuredTool(BaseTool): if not self.coroutine: # If the tool does not implement async, fall back to default implementation return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs) - ) + None, partial(self.invoke, input, config, **kwargs)) return await super().ainvoke(input, config, **kwargs) @@ -619,16 +617,13 @@ class StructuredTool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") - return ( - self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else self.func(*args, **kwargs) - ) + new_argument_supported = signature( + self.func).parameters.get("callbacks") + return (self.func( + *args, + 
callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) if new_argument_supported else self.func(*args, **kwargs)) raise NotImplementedError("Tool does not support sync") async def _arun( @@ -639,18 +634,14 @@ class StructuredTool(BaseTool): ) -> str: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get( - "callbacks" - ) - return ( - await self.coroutine( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else await self.coroutine(*args, **kwargs) - ) + new_argument_supported = signature( + self.coroutine).parameters.get("callbacks") + return (await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) if new_argument_supported else await self.coroutine( + *args, **kwargs)) return await asyncio.get_running_loop().run_in_executor( None, partial(self._run, run_manager=run_manager, **kwargs), @@ -707,8 +698,7 @@ class StructuredTool(BaseTool): description = description or source_function.__doc__ if description is None: raise ValueError( - "Function must have a docstring if description not provided." - ) + "Function must have a docstring if description not provided.") # Description example: # search_api(query: str) - Searches the API for the query. @@ -716,7 +706,8 @@ class StructuredTool(BaseTool): description = f"{name}{sig} - {description.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{name}Schema", source_function) + _args_schema = create_schema_from_function(f"{name}Schema", + source_function) return cls( name=name, func=func, @@ -764,6 +755,7 @@ def tool( """ def _make_with_name(tool_name: str) -> Callable: + def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: if isinstance(dec_func, Runnable): runnable = dec_func @@ -771,14 +763,13 @@ def tool( if runnable.input_schema.schema().get("type") != "object": raise ValueError("Runnable must have an object schema.") - async def ainvoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any - ) -> Any: - return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) + async def ainvoke_wrapper(callbacks: Optional[Callbacks] = None, + **kwargs: Any) -> Any: + return await runnable.ainvoke(kwargs, + {"callbacks": callbacks}) - def invoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any - ) -> Any: + def invoke_wrapper(callbacks: Optional[Callbacks] = None, + **kwargs: Any) -> Any: return runnable.invoke(kwargs, {"callbacks": callbacks}) coroutine = ainvoke_wrapper @@ -811,8 +802,7 @@ def tool( if func.__doc__ is None: raise ValueError( "Function must have a docstring if " - "description not provided and infer_schema is False." 
- ) + "description not provided and infer_schema is False.") return Tool( name=tool_name, func=func, @@ -823,7 +813,8 @@ def tool( return _make_tool - if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): + if len(args) == 2 and isinstance(args[0], str) and isinstance( + args[1], Runnable): return _make_with_name(args[0])(args[1]) elif len(args) == 1 and isinstance(args[0], str): # if the argument is a string, then we use the string as the tool name diff --git a/swarms/tools/tool_registry.py b/swarms/tools/tool_registry.py index 5aa544e9..3354646a 100644 --- a/swarms/tools/tool_registry.py +++ b/swarms/tools/tool_registry.py @@ -6,6 +6,7 @@ FuncToolBuilder = Callable[[], ToolBuilder] class ToolsRegistry: + def __init__(self) -> None: self.tools: Dict[str, FuncToolBuilder] = {} @@ -18,8 +19,7 @@ class ToolsRegistry: if isinstance(ret, tool): return ret raise ValueError( - "Tool builder {} did not return a Tool instance".format(tool_name) - ) + "Tool builder {} did not return a Tool instance".format(tool_name)) def list_tools(self) -> List[str]: return list(self.tools.keys()) @@ -29,6 +29,7 @@ tools_registry = ToolsRegistry() def register(tool_name): + def decorator(tool: FuncToolBuilder): tools_registry.register(tool_name, tool) return tool diff --git a/swarms/utils/code_interpreter.py b/swarms/utils/code_interpreter.py index 80eb6700..c89ac7a7 100644 --- a/swarms/utils/code_interpreter.py +++ b/swarms/utils/code_interpreter.py @@ -118,14 +118,19 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): # Most of the time it doesn't matter, but we should figure out why it happens frequently with: # applescript yield {"output": traceback.format_exc()} - yield {"output": f"Retrying... ({retry_count}/{max_retries})"} + yield { + "output": f"Retrying... ({retry_count}/{max_retries})" + } yield {"output": "Restarting process."} self.start_process() retry_count += 1 if retry_count > max_retries: - yield {"output": "Maximum retries reached. Could not execute code."} + yield { + "output": + "Maximum retries reached. Could not execute code." 
+ } return while True: @@ -134,7 +139,8 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): else: time.sleep(0.1) try: - output = self.output_queue.get(timeout=0.3) # Waits for 0.3 seconds + output = self.output_queue.get( + timeout=0.3) # Waits for 0.3 seconds yield output except queue.Empty: if self.done.is_set(): diff --git a/swarms/utils/decorators.py b/swarms/utils/decorators.py index 8a5a5d56..2f22528b 100644 --- a/swarms/utils/decorators.py +++ b/swarms/utils/decorators.py @@ -6,6 +6,7 @@ import warnings def log_decorator(func): + def wrapper(*args, **kwargs): logging.info(f"Entering {func.__name__}") result = func(*args, **kwargs) @@ -16,6 +17,7 @@ def log_decorator(func): def error_decorator(func): + def wrapper(*args, **kwargs): try: return func(*args, **kwargs) @@ -27,18 +29,22 @@ def error_decorator(func): def timing_decorator(func): + def wrapper(*args, **kwargs): start_time = time.time() result = func(*args, **kwargs) end_time = time.time() - logging.info(f"{func.__name__} executed in {end_time - start_time} seconds") + logging.info( + f"{func.__name__} executed in {end_time - start_time} seconds") return result return wrapper def retry_decorator(max_retries=5): + def decorator(func): + @functools.wraps(func) def wrapper(*args, **kwargs): for _ in range(max_retries): @@ -77,16 +83,20 @@ def synchronized_decorator(func): def deprecated_decorator(func): + @functools.wraps(func) def wrapper(*args, **kwargs): - warnings.warn(f"{func.__name__} is deprecated", category=DeprecationWarning) + warnings.warn(f"{func.__name__} is deprecated", + category=DeprecationWarning) return func(*args, **kwargs) return wrapper def validate_inputs_decorator(validator): + def decorator(func): + @functools.wraps(func) def wrapper(*args, **kwargs): if not validator(*args, **kwargs): diff --git a/swarms/utils/futures.py b/swarms/utils/futures.py index 55a4e5d5..5c2dfdcd 100644 --- a/swarms/utils/futures.py +++ b/swarms/utils/futures.py @@ -5,6 +5,8 @@ T = TypeVar("T") def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]: - futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED) + futures.wait(fs_dict.values(), + timeout=None, + return_when=futures.ALL_COMPLETED) return {key: future.result() for key, future in fs_dict.items()} diff --git a/swarms/utils/hash.py b/swarms/utils/hash.py index 725cc6ba..458fc147 100644 --- a/swarms/utils/hash.py +++ b/swarms/utils/hash.py @@ -4,8 +4,7 @@ import hashlib def dataframe_to_hash(dataframe: pd.DataFrame) -> str: return hashlib.sha256( - pd.util.hash_pandas_object(dataframe, index=True).values - ).hexdigest() + pd.util.hash_pandas_object(dataframe, index=True).values).hexdigest() def str_to_hash(text: str, hash_algorithm: str = "sha256") -> str: diff --git a/swarms/utils/main.py b/swarms/utils/main.py index 9c1342aa..9d5eefdf 100644 --- a/swarms/utils/main.py +++ b/swarms/utils/main.py @@ -51,16 +51,16 @@ def get_new_image_name(org_img_name, func_name="update"): if len(name_split) == 1: most_org_file_name = name_split[0] recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.png".format( - this_new_uuid, func_name, recent_prev_file_name, most_org_file_name - ) + new_file_name = "{}_{}_{}_{}.png".format(this_new_uuid, func_name, + recent_prev_file_name, + most_org_file_name) else: assert len(name_split) == 4 most_org_file_name = name_split[3] recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.png".format( - this_new_uuid, func_name, recent_prev_file_name, most_org_file_name - ) + 
new_file_name = "{}_{}_{}_{}.png".format(this_new_uuid, func_name, + recent_prev_file_name, + most_org_file_name) return os.path.join(head, new_file_name) @@ -73,26 +73,26 @@ def get_new_dataframe_name(org_img_name, func_name="update"): if len(name_split) == 1: most_org_file_name = name_split[0] recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.csv".format( - this_new_uuid, func_name, recent_prev_file_name, most_org_file_name - ) + new_file_name = "{}_{}_{}_{}.csv".format(this_new_uuid, func_name, + recent_prev_file_name, + most_org_file_name) else: assert len(name_split) == 4 most_org_file_name = name_split[3] recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.csv".format( - this_new_uuid, func_name, recent_prev_file_name, most_org_file_name - ) + new_file_name = "{}_{}_{}_{}.csv".format(this_new_uuid, func_name, + recent_prev_file_name, + most_org_file_name) return os.path.join(head, new_file_name) # =======================> utils end - # =======================> ANSI BEGINNING class Code: + def __init__(self, value: int): self.value = value @@ -101,6 +101,7 @@ class Code: class Color(Code): + def bg(self) -> "Color": self.value += 10 return self @@ -147,6 +148,7 @@ class Color(Code): class Style(Code): + @staticmethod def reset() -> "Style": return Style(0) @@ -203,19 +205,19 @@ def dim_multiline(message: str) -> str: lines = message.split("\n") if len(lines) <= 1: return lines[0] - return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to(Color.black().bright()) + return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to( + Color.black().bright()) # +=============================> ANSI Ending - # ================================> upload base - STATIC_DIR = "static" class AbstractUploader(ABC): + @abstractmethod def upload(self, filepath: str) -> str: pass @@ -227,12 +229,13 @@ class AbstractUploader(ABC): # ================================> upload end - # ========================= upload s3 class S3Uploader(AbstractUploader): - def __init__(self, accessKey: str, secretKey: str, region: str, bucket: str): + + def __init__(self, accessKey: str, secretKey: str, region: str, + bucket: str): self.accessKey = accessKey self.secretKey = secretKey self.region = region @@ -263,11 +266,11 @@ class S3Uploader(AbstractUploader): # ========================= upload s3 - # ========================> upload/static class StaticUploader(AbstractUploader): + def __init__(self, server: str, path: Path, endpoint: str): self.server = server self.path = path @@ -292,7 +295,6 @@ class StaticUploader(AbstractUploader): # ========================> handlers/base - # from env import settings @@ -336,16 +338,19 @@ class FileType(Enum): class BaseHandler: + def handle(self, filename: str) -> str: raise NotImplementedError class FileHandler: + def __init__(self, handlers: Dict[FileType, BaseHandler], path: Path): self.handlers = handlers self.path = path - def register(self, filetype: FileType, handler: BaseHandler) -> "FileHandler": + def register(self, filetype: FileType, + handler: BaseHandler) -> "FileHandler": self.handlers[filetype] = handler return self @@ -353,8 +358,8 @@ class FileHandler: filetype = FileType.from_url(url) data = requests.get(url).content local_filename = os.path.join( - "file", str(uuid.uuid4())[0:8] + filetype.to_extension() - ) + "file", + str(uuid.uuid4())[0:8] + filetype.to_extension()) os.makedirs(os.path.dirname(local_filename), exist_ok=True) with open(local_filename, "wb") as f: size = f.write(data) @@ -363,17 +368,15 @@ class FileHandler: def 
handle(self, url: str) -> str: try: - if url.startswith(os.environ.get("SERVER", "http://localhost:8000")): + if url.startswith(os.environ.get("SERVER", + "http://localhost:8000")): local_filepath = url[ - len(os.environ.get("SERVER", "http://localhost:8000")) + 1 : - ] + len(os.environ.get("SERVER", "http://localhost:8000")) + 1:] local_filename = Path("file") / local_filepath.split("/")[-1] src = self.path / local_filepath - dst = ( - self.path - / os.environ.get("PLAYGROUND_DIR", "./playground") - / local_filename - ) + dst = (self.path / + os.environ.get("PLAYGROUND_DIR", "./playground") / + local_filename) os.makedirs(os.path.dirname(dst), exist_ok=True) shutil.copy(src, dst) else: @@ -383,8 +386,7 @@ class FileHandler: if FileType.from_url(url) == FileType.IMAGE: raise Exception( f"No handler for {FileType.from_url(url)}. " - "Please set USE_GPU to True in env/settings.py" - ) + "Please set USE_GPU to True in env/settings.py") else: raise Exception(f"No handler for {FileType.from_url(url)}") return handler.handle(local_filename) @@ -394,22 +396,21 @@ class FileHandler: # => base end - # ===========================> class CsvToDataframe(BaseHandler): + def handle(self, filename: str): df = pd.read_csv(filename) description = ( f"Dataframe with {len(df)} rows and {len(df.columns)} columns. " "Columns are: " - f"{', '.join(df.columns)}" - ) + f"{', '.join(df.columns)}") print( f"\nProcessed CsvToDataframe, Input CSV: {filename}, Output Description:" - f" {description}" - ) + f" {description}") - return DATAFRAME_PROMPT.format(filename=filename, description=description) + return DATAFRAME_PROMPT.format(filename=filename, + description=description) diff --git a/swarms/utils/parse_code.py b/swarms/utils/parse_code.py index a2f346ea..020c9bef 100644 --- a/swarms/utils/parse_code.py +++ b/swarms/utils/parse_code.py @@ -7,5 +7,6 @@ def extract_code_in_backticks_in_string(message: str) -> str: """ pattern = r"`` ``(.*?)`` " # Non-greedy match between six backticks - match = re.search(pattern, message, re.DOTALL) # re.DOTALL to match newline chars + match = re.search(pattern, message, + re.DOTALL) # re.DOTALL to match newline chars return match.group(1).strip() if match else None diff --git a/swarms/utils/revutils.py b/swarms/utils/revutils.py index 7868ae44..9db1e123 100644 --- a/swarms/utils/revutils.py +++ b/swarms/utils/revutils.py @@ -49,16 +49,12 @@ def get_input( """ Multiline input function. """ - return ( - session.prompt( - completer=completer, - multiline=True, - auto_suggest=AutoSuggestFromHistory(), - key_bindings=key_bindings, - ) - if session - else prompt(multiline=True) - ) + return (session.prompt( + completer=completer, + multiline=True, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=key_bindings, + ) if session else prompt(multiline=True)) async def get_input_async( @@ -68,15 +64,11 @@ async def get_input_async( """ Multiline input function. 
""" - return ( - await session.prompt_async( - completer=completer, - multiline=True, - auto_suggest=AutoSuggestFromHistory(), - ) - if session - else prompt(multiline=True) - ) + return (await session.prompt_async( + completer=completer, + multiline=True, + auto_suggest=AutoSuggestFromHistory(), + ) if session else prompt(multiline=True)) def get_filtered_keys_from_object(obj: object, *keys: str) -> any: @@ -94,9 +86,7 @@ def get_filtered_keys_from_object(obj: object, *keys: str) -> any: return {key for key in class_keys if key not in keys[1:]} # Check if all passed keys are valid if invalid_keys := set(keys) - class_keys: - raise ValueError( - f"Invalid keys: {invalid_keys}", - ) + raise ValueError(f"Invalid keys: {invalid_keys}",) # Only return specified keys that are in class_keys return {key for key in keys if key in class_keys} @@ -124,8 +114,8 @@ def random_int(min: int, max: int) -> int: if __name__ == "__main__": logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s", - ) + format= + "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s",) log = logging.getLogger(__name__) diff --git a/swarms/utils/serializable.py b/swarms/utils/serializable.py index 8f0e5ccf..47cc815f 100644 --- a/swarms/utils/serializable.py +++ b/swarms/utils/serializable.py @@ -106,21 +106,22 @@ class Serializable(BaseModel, ABC): lc_kwargs.update({key: secret_value}) return { - "lc": 1, - "type": "constructor", + "lc": + 1, + "type": + "constructor", "id": [*self.lc_namespace, self.__class__.__name__], - "kwargs": lc_kwargs - if not secrets - else _replace_secrets(lc_kwargs, secrets), + "kwargs": + lc_kwargs if not secrets else _replace_secrets( + lc_kwargs, secrets), } def to_json_not_implemented(self) -> SerializedNotImplemented: return to_json_not_implemented(self) -def _replace_secrets( - root: Dict[Any, Any], secrets_map: Dict[str, str] -) -> Dict[Any, Any]: +def _replace_secrets(root: Dict[Any, Any], + secrets_map: Dict[str, str]) -> Dict[Any, Any]: result = root.copy() for path, secret_id in secrets_map.items(): [*parts, last] = path.split(".") diff --git a/swarms/utils/static.py b/swarms/utils/static.py index 3b8a276d..23f13996 100644 --- a/swarms/utils/static.py +++ b/swarms/utils/static.py @@ -8,6 +8,7 @@ from swarms.utils.main import AbstractUploader class StaticUploader(AbstractUploader): + def __init__(self, server: str, path: Path, endpoint: str): self.server = server self.path = path diff --git a/swarms/workers/worker.py b/swarms/workers/worker.py index 9986666a..bef9682a 100644 --- a/swarms/workers/worker.py +++ b/swarms/workers/worker.py @@ -4,8 +4,7 @@ from typing import Dict, Union import faiss from langchain.chains.qa_with_sources.loading import ( - load_qa_with_sources_chain, -) + load_qa_with_sources_chain,) from langchain.docstore import InMemoryDocstore from langchain.embeddings import OpenAIEmbeddings from langchain.tools import ReadFileTool, WriteFileTool @@ -132,8 +131,7 @@ class Worker: ``` """ query_website_tool = WebpageQATool( - qa_chain=load_qa_with_sources_chain(self.llm) - ) + qa_chain=load_qa_with_sources_chain(self.llm)) self.tools = [ WriteFileTool(root_dir=ROOT_DIR), @@ -157,15 +155,13 @@ class Worker: embedding_size = 1536 index = faiss.IndexFlatL2(embedding_size) - self.vectorstore = FAISS( - embeddings_model.embed_query, index, InMemoryDocstore({}), {} - ) + self.vectorstore = FAISS(embeddings_model.embed_query, index, + InMemoryDocstore({}), {}) except Exception as error: raise RuntimeError( "Error setting up 
memory perhaps try try tuning the embedding size:" - f" {error}" - ) + f" {error}") def setup_agent(self): """ @@ -294,8 +290,6 @@ class Worker: def is_within_proximity(self, other_worker): """Using Euclidean distance for proximity check""" - distance = ( - (self.coordinates[0] - other_worker.coordinates[0]) ** 2 - + (self.coordinates[1] - other_worker.coordinates[1]) ** 2 - ) ** 0.5 + distance = ((self.coordinates[0] - other_worker.coordinates[0])**2 + + (self.coordinates[1] - other_worker.coordinates[1])**2)**0.5 return distance < 10 # threshold for proximity