From 49c7b97c09b04ff5a7bf2a56beea05acbc00cf0e Mon Sep 17 00:00:00 2001
From: Kye
Date: Thu, 23 Nov 2023 23:27:40 -0800
Subject: [PATCH] code quality fixes: line length = 80

---
 playground/agents/mm_agent_example.py | 5 +-
 playground/agents/revgpt_agent.py | 10 +-
 .../demos/accountant_team/accountant_team.py | 7 +-
 playground/demos/ai_research_team/main.py | 3 +-
 playground/demos/autotemp/autotemp.py | 18 ++-
 playground/demos/blog_gen/blog_gen.py | 38 ++++--
 .../multi_modal_auto_agent.py | 5 +-
 playground/demos/nutrition/nutrition.py | 32 +++--
 playground/demos/positive_med/positive_med.py | 13 +-
 playground/models/bioclip.py | 4 +-
 playground/models/idefics.py | 10 +-
 playground/models/llama_function_caller.py | 4 +-
 playground/models/vilt.py | 3 +-
 playground/structs/flow_tools.py | 7 +-
 playground/swarms/debate.py | 44 ++++---
 playground/swarms/multi_agent_debate.py | 6 +-
 playground/swarms/orchestrate.py | 4 +-
 playground/swarms/orchestrator.py | 4 +-
 playground/swarms/swarms_example.py | 5 +-
 pyproject.toml | 9 +-
 swarms/agents/omni_modal_agent.py | 7 +-
 swarms/memory/base.py | 22 ++--
 swarms/memory/chroma.py | 41 ++++--
 swarms/memory/cosine_similarity.py | 9 +-
 swarms/memory/db.py | 4 +-
 swarms/memory/ocean.py | 8 +-
 swarms/memory/pg.py | 15 ++-
 swarms/memory/pinecone.py | 17 ++-
 swarms/memory/schemas.py | 9 +-
 swarms/memory/utils.py | 2 +-
 swarms/models/__init__.py | 6 +-
 swarms/models/anthropic.py | 48 ++++---
 swarms/models/bioclip.py | 16 ++-
 swarms/models/biogpt.py | 2 +-
 swarms/models/cohere_chat.py | 16 ++-
 swarms/models/dalle3.py | 40 +++---
 swarms/models/distilled_whisperx.py | 21 ++-
 swarms/models/eleven_labs.py | 8 +-
 swarms/models/fastvit.py | 12 +-
 swarms/models/fuyu.py | 12 +-
 swarms/models/gpt4v.py | 7 +-
 swarms/models/huggingface.py | 33 +++--
 swarms/models/idefics.py | 4 +-
 swarms/models/jina_embeds.py | 12 +-
 swarms/models/kosmos2.py | 18 ++-
 swarms/models/kosmos_two.py | 38 ++++--
 swarms/models/llama_function_caller.py | 15 ++-
 swarms/models/mistral.py | 12 +-
 swarms/models/mpt.py | 9 +-
 swarms/models/nougat.py | 16 ++-
 swarms/models/openai_embeddings.py | 45 +++++--
 swarms/models/openai_function_caller.py | 14 +-
 swarms/models/openai_models.py | 121 ++++++++++++------
 swarms/models/palm.py | 26 +++-
 swarms/models/simple_ada.py | 4 +-
 swarms/models/speecht5.py | 8 +-
 swarms/models/ssd_1b.py | 32 +++--
 swarms/models/whisperx.py | 4 +-
 swarms/models/wizard_storytelling.py | 8 +-
 swarms/models/yarn_mistral.py | 8 +-
 swarms/prompts/agent_prompt.py | 16 +--
 swarms/prompts/agent_prompts.py | 98 +++++++-------
 swarms/prompts/base.py | 4 +-
 swarms/prompts/chat_prompt.py | 4 +-
 swarms/prompts/multi_modal_prompts.py | 4 +-
 swarms/prompts/python.py | 61 ++++-----
 swarms/prompts/sales.py | 33 ++---
 swarms/prompts/sales_prompts.py | 33 ++---
 swarms/structs/autoscaler.py | 23 +++-
 swarms/structs/flow.py | 30 +++--
 swarms/structs/non_linear_workflow.py | 15 ++-
 swarms/structs/sequential_workflow.py | 25 ++--
 swarms/swarms/autobloggen.py | 10 +-
 swarms/swarms/base.py | 4 +-
 swarms/swarms/dialogue_simulator.py | 11 +-
 swarms/swarms/god_mode.py | 24 +++-
 swarms/swarms/groupchat.py | 14 +-
 swarms/swarms/multi_agent_collab.py | 24 +++-
 swarms/swarms/orchestrate.py | 23 +++-
 swarms/tools/autogpt.py | 24 +++-
 swarms/tools/mm_models.py | 79 +++++++-----
 swarms/tools/tool.py | 100 +++++++++++----
 swarms/utils/apa.py | 4 +-
 swarms/utils/code_interpreter.py | 14 +-
 swarms/utils/decorators.py | 8 +-
 swarms/utils/futures.py | 4 +-
 swarms/utils/loggers.py | 28 ++--
 swarms/utils/main.py | 16 ++-
 swarms/utils/parse_code.py | 4 +-
 swarms/utils/serializable.py | 8 +-
 tests/agents/omni_modal.py | 4 +-
 tests/memory/oceandb.py | 4 +-
 tests/memory/pinecone.py | 4 +-
 tests/models/LLM.py | 4 +-
 tests/models/anthropic.py | 20 ++-
 tests/models/auto_temp.py | 4 +-
 tests/models/bingchat.py | 4 +-
 tests/models/bioclip.py | 4 +-
 tests/models/biogpt.py | 5 +-
 tests/models/cohere.py | 106 +++++++++----
 tests/models/dalle3.py | 60 ++++++---
 tests/models/distill_whisper.py | 35 +++--
 tests/models/elevenlab.py | 8 +-
 tests/models/fuyu.py | 4 +-
 tests/models/gpt4v.py | 60 ++++++---
 tests/models/hf.py | 13 +-
 tests/models/huggingface.py | 15 ++-
 tests/models/idefics.py | 25 ++--
 tests/models/kosmos.py | 12 +-
 tests/models/kosmos2.py | 25 +++-
 tests/models/llama_function_caller.py | 12 +-
 tests/models/nougat.py | 7 +-
 tests/models/revgptv1.py | 13 +-
 tests/models/speech_t5.py | 12 +-
 tests/models/ssd_1b.py | 16 ++-
 tests/models/timm_model.py | 24 +++-
 tests/models/vilt.py | 4 +-
 tests/models/whisperx.py | 16 ++-
 tests/models/yi_200k.py | 24 +++-
 tests/structs/flow.py | 68 +++++++---
 tests/swarms/godmode.py | 8 +-
 tests/swarms/groupchat.py | 12 +-
 tests/swarms/multi_agent_collab.py | 8 +-
 tests/swarms/multi_agent_debate.py | 9 +-
 tests/tools/base.py | 8 +-
 tests/utils/subprocess_code_interpreter.py | 52 ++++++--
 126 files changed, 1706 insertions(+), 728 deletions(-)

diff --git a/playground/agents/mm_agent_example.py b/playground/agents/mm_agent_example.py
index 0da0d469..5326af6e 100644
--- a/playground/agents/mm_agent_example.py
+++ b/playground/agents/mm_agent_example.py
@@ -9,6 +9,9 @@ text = node.run_text("What is your name? Generate a picture of yourself")
 img = node.run_img("/image1", "What is this image about?")
 
 chat = node.chat(
-    "What is your name? Generate a picture of yourself. What is this image about?",
+    (
+        "What is your name? Generate a picture of yourself. What is this image"
+        " about?"
+    ),
     streaming=True,
 )
diff --git a/playground/agents/revgpt_agent.py b/playground/agents/revgpt_agent.py
index 42d95359..16a720e8 100644
--- a/playground/agents/revgpt_agent.py
+++ b/playground/agents/revgpt_agent.py
@@ -10,13 +10,19 @@ config = {
     "plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")],
     "disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True",
     "PUID": os.getenv("REVGPT_PUID"),
-    "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")],
+    "unverified_plugin_domains": [
+        os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")
+    ],
 }
 
 llm = RevChatGPTModel(access_token=os.getenv("ACCESS_TOKEN"), **config)
 
 worker = Worker(ai_name="Optimus Prime", llm=llm)
 
-task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times."
+task = (
+    "What were the winning boston marathon times for the past 5 years (ending"
+    " in 2022)? Generate a table of the year, name, country of origin, and"
+    " times."
+) response = worker.run(task) print(response) diff --git a/playground/demos/accountant_team/accountant_team.py b/playground/demos/accountant_team/accountant_team.py index 61cc2f7a..d9edc2f6 100644 --- a/playground/demos/accountant_team/accountant_team.py +++ b/playground/demos/accountant_team/accountant_team.py @@ -103,7 +103,8 @@ class AccountantSwarms: # Provide decision making support to the accountant decision_making_support_agent_output = decision_making_support_agent.run( - f"{self.decision_making_support_agent_instructions}: {summary_agent_output}" + f"{self.decision_making_support_agent_instructions}:" + f" {summary_agent_output}" ) return decision_making_support_agent_output @@ -113,5 +114,7 @@ swarm = AccountantSwarms( pdf_path="tesla.pdf", fraud_detection_instructions="Detect fraud in the document", summary_agent_instructions="Generate an actionable summary of the document", - decision_making_support_agent_instructions="Provide decision making support to the business owner:", + decision_making_support_agent_instructions=( + "Provide decision making support to the business owner:" + ), ) diff --git a/playground/demos/ai_research_team/main.py b/playground/demos/ai_research_team/main.py index a297bc0a..77d8dbdc 100644 --- a/playground/demos/ai_research_team/main.py +++ b/playground/demos/ai_research_team/main.py @@ -48,6 +48,7 @@ paper_implementor_agent = Flow( paper = pdf_to_text(PDF_PATH) algorithmic_psuedocode_agent = paper_summarizer_agent.run( - f"Focus on creating the algorithmic pseudocode for the novel method in this paper: {paper}" + "Focus on creating the algorithmic pseudocode for the novel method in this" + f" paper: {paper}" ) pytorch_code = paper_implementor_agent.run(algorithmic_psuedocode_agent) diff --git a/playground/demos/autotemp/autotemp.py b/playground/demos/autotemp/autotemp.py index ab521606..b136bad7 100644 --- a/playground/demos/autotemp/autotemp.py +++ b/playground/demos/autotemp/autotemp.py @@ -9,11 +9,18 @@ class AutoTemp: """ def __init__( - self, api_key, default_temp=0.0, alt_temps=None, auto_select=True, max_workers=6 + self, + api_key, + default_temp=0.0, + alt_temps=None, + auto_select=True, + max_workers=6, ): self.api_key = api_key self.default_temp = default_temp - self.alt_temps = alt_temps if alt_temps else [0.4, 0.6, 0.8, 1.0, 1.2, 1.4] + self.alt_temps = ( + alt_temps if alt_temps else [0.4, 0.6, 0.8, 1.0, 1.2, 1.4] + ) self.auto_select = auto_select self.max_workers = max_workers self.llm = OpenAIChat( @@ -62,12 +69,15 @@ class AutoTemp: if not scores: return "No valid outputs generated.", None - sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True) + sorted_scores = sorted( + scores.items(), key=lambda item: item[1], reverse=True + ) best_temp, best_score = sorted_scores[0] best_output = outputs[best_temp] return ( - f"Best AutoTemp Output (Temp {best_temp} | Score: {best_score}):\n{best_output}" + f"Best AutoTemp Output (Temp {best_temp} | Score:" + f" {best_score}):\n{best_output}" if self.auto_select else "\n".join( f"Temp {temp} | Score: {score}:\n{outputs[temp]}" diff --git a/playground/demos/blog_gen/blog_gen.py b/playground/demos/blog_gen/blog_gen.py index 3781d895..84ab240d 100644 --- a/playground/demos/blog_gen/blog_gen.py +++ b/playground/demos/blog_gen/blog_gen.py @@ -7,7 +7,10 @@ from swarms.structs import SequentialWorkflow class BlogGen: def __init__( - self, api_key, blog_topic, temperature_range: str = "0.4,0.6,0.8,1.0,1.2" + self, + api_key, + blog_topic, + temperature_range: str = 
"0.4,0.6,0.8,1.0,1.2", ): # Add blog_topic as an argument self.openai_chat = OpenAIChat(openai_api_key=api_key, temperature=0.8) self.auto_temp = AutoTemp(api_key) @@ -40,7 +43,10 @@ class BlogGen: topic_output = topic_result.generations[0][0].text print( colored( - f"\nTopic Selection Task Output:\n----------------------------\n{topic_output}\n", + ( + "\nTopic Selection Task" + f" Output:\n----------------------------\n{topic_output}\n" + ), "white", ) ) @@ -58,7 +64,10 @@ class BlogGen: initial_draft_output = auto_temp_output # Assuming AutoTemp.run returns the best output directly print( colored( - f"\nInitial Draft Output:\n----------------------------\n{initial_draft_output}\n", + ( + "\nInitial Draft" + f" Output:\n----------------------------\n{initial_draft_output}\n" + ), "white", ) ) @@ -71,7 +80,10 @@ class BlogGen: review_output = review_result.generations[0][0].text print( colored( - f"\nReview Output:\n----------------------------\n{review_output}\n", + ( + "\nReview" + f" Output:\n----------------------------\n{review_output}\n" + ), "white", ) ) @@ -80,22 +92,28 @@ class BlogGen: distribution_prompt = self.DISTRIBUTION_AGENT_SYSTEM_PROMPT.replace( "{{ARTICLE_TOPIC}}", chosen_topic ) - distribution_result = self.openai_chat.generate([distribution_prompt]) + distribution_result = self.openai_chat.generate( + [distribution_prompt] + ) distribution_output = distribution_result.generations[0][0].text print( colored( - f"\nDistribution Output:\n----------------------------\n{distribution_output}\n", + ( + "\nDistribution" + f" Output:\n----------------------------\n{distribution_output}\n" + ), "white", ) ) # Final compilation of the blog - final_blog_content = ( - f"{initial_draft_output}\n\n{review_output}\n\n{distribution_output}" - ) + final_blog_content = f"{initial_draft_output}\n\n{review_output}\n\n{distribution_output}" print( colored( - f"\nFinal Blog Content:\n----------------------------\n{final_blog_content}\n", + ( + "\nFinal Blog" + f" Content:\n----------------------------\n{final_blog_content}\n" + ), "green", ) ) diff --git a/playground/demos/multi_modal_autonomous_agents/multi_modal_auto_agent.py b/playground/demos/multi_modal_autonomous_agents/multi_modal_auto_agent.py index b462795f..a2602706 100644 --- a/playground/demos/multi_modal_autonomous_agents/multi_modal_auto_agent.py +++ b/playground/demos/multi_modal_autonomous_agents/multi_modal_auto_agent.py @@ -4,7 +4,10 @@ from swarms.models import Idefics # Multi Modality Auto Agent llm = Idefics(max_length=2000) -task = "User: What is in this image? https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" +task = ( + "User: What is in this image?" + " https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" +) ## Initialize the workflow flow = Flow( diff --git a/playground/demos/nutrition/nutrition.py b/playground/demos/nutrition/nutrition.py index c263f2cd..ffdafd7c 100644 --- a/playground/demos/nutrition/nutrition.py +++ b/playground/demos/nutrition/nutrition.py @@ -10,9 +10,16 @@ load_dotenv() openai_api_key = os.getenv("OPENAI_API_KEY") # Define prompts for various tasks -MEAL_PLAN_PROMPT = "Based on the following user preferences: dietary restrictions as vegetarian, preferred cuisines as Italian and Indian, a total caloric intake of around 2000 calories per day, and an exclusion of legumes, create a detailed weekly meal plan. Include a variety of meals for breakfast, lunch, dinner, and optional snacks." 
+MEAL_PLAN_PROMPT = ( + "Based on the following user preferences: dietary restrictions as" + " vegetarian, preferred cuisines as Italian and Indian, a total caloric" + " intake of around 2000 calories per day, and an exclusion of legumes," + " create a detailed weekly meal plan. Include a variety of meals for" + " breakfast, lunch, dinner, and optional snacks." +) IMAGE_ANALYSIS_PROMPT = ( - "Identify the items in this fridge, including their quantities and condition." + "Identify the items in this fridge, including their quantities and" + " condition." ) @@ -45,7 +52,9 @@ def create_vision_agent(image_path): {"type": "text", "text": IMAGE_ANALYSIS_PROMPT}, { "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + }, }, ], } @@ -53,7 +62,9 @@ def create_vision_agent(image_path): "max_tokens": 300, } response = requests.post( - "https://api.openai.com/v1/chat/completions", headers=headers, json=payload + "https://api.openai.com/v1/chat/completions", + headers=headers, + json=payload, ) return response.json() @@ -65,10 +76,11 @@ def generate_integrated_shopping_list( # Prepare the prompt for the LLM fridge_contents = image_analysis["choices"][0]["message"]["content"] prompt = ( - f"Based on this meal plan: {meal_plan_output}, " - f"and the following items in the fridge: {fridge_contents}, " - f"considering dietary preferences as vegetarian with a preference for Italian and Indian cuisines, " - f"generate a comprehensive shopping list that includes only the items needed." + f"Based on this meal plan: {meal_plan_output}, and the following items" + f" in the fridge: {fridge_contents}, considering dietary preferences as" + " vegetarian with a preference for Italian and Indian cuisines," + " generate a comprehensive shopping list that includes only the items" + " needed." 
) # Send the prompt to the LLM and return the response @@ -94,7 +106,9 @@ user_preferences = { } # Generate Meal Plan -meal_plan_output = meal_plan_agent.run(f"Generate a meal plan: {user_preferences}") +meal_plan_output = meal_plan_agent.run( + f"Generate a meal plan: {user_preferences}" +) # Vision Agent - Analyze an Image image_analysis_output = create_vision_agent("full_fridge.jpg") diff --git a/playground/demos/positive_med/positive_med.py b/playground/demos/positive_med/positive_med.py index 6f7a2d3a..ea0c7c4e 100644 --- a/playground/demos/positive_med/positive_med.py +++ b/playground/demos/positive_med/positive_med.py @@ -39,9 +39,9 @@ def get_review_prompt(article): def social_media_prompt(article: str, goal: str = "Clicks and engagement"): - prompt = SOCIAL_MEDIA_SYSTEM_PROMPT_AGENT.replace("{{ARTICLE}}", article).replace( - "{{GOAL}}", goal - ) + prompt = SOCIAL_MEDIA_SYSTEM_PROMPT_AGENT.replace( + "{{ARTICLE}}", article + ).replace("{{GOAL}}", goal) return prompt @@ -50,7 +50,8 @@ topic_selection_task = ( "Generate 10 topics on gaining mental clarity using ancient practices" ) topics = llm( - f"Your System Instructions: {TOPIC_GENERATOR}, Your current task: {topic_selection_task}" + f"Your System Instructions: {TOPIC_GENERATOR}, Your current task:" + f" {topic_selection_task}" ) dashboard = print( @@ -109,7 +110,9 @@ reviewed_draft = print( # Agent that publishes on social media -distribution_agent = llm(social_media_prompt(draft_blog, goal="Clicks and engagement")) +distribution_agent = llm( + social_media_prompt(draft_blog, goal="Clicks and engagement") +) distribution_agent_out = print( colored( f""" diff --git a/playground/models/bioclip.py b/playground/models/bioclip.py index dcdd309b..11fb9f27 100644 --- a/playground/models/bioclip.py +++ b/playground/models/bioclip.py @@ -1,6 +1,8 @@ from swarms.models.bioclip import BioClip -clip = BioClip("hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224") +clip = BioClip( + "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" +) labels = [ "adenocarcinoma histopathology", diff --git a/playground/models/idefics.py b/playground/models/idefics.py index 032e0f3b..39d6f4eb 100644 --- a/playground/models/idefics.py +++ b/playground/models/idefics.py @@ -2,11 +2,17 @@ from swarms.models import idefics model = idefics() -user_input = "User: What is in this image? https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" +user_input = ( + "User: What is in this image?" + " https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" +) response = model.chat(user_input) print(response) -user_input = "User: And who is that? https://static.wikia.nocookie.net/asterix/images/2/25/R22b.gif/revision/latest?cb=20110815073052" +user_input = ( + "User: And who is that?" 
+ " https://static.wikia.nocookie.net/asterix/images/2/25/R22b.gif/revision/latest?cb=20110815073052" +) response = model.chat(user_input) print(response) diff --git a/playground/models/llama_function_caller.py b/playground/models/llama_function_caller.py index 43bca3a5..201009a8 100644 --- a/playground/models/llama_function_caller.py +++ b/playground/models/llama_function_caller.py @@ -28,7 +28,9 @@ llama_caller.add_func( ) # Call the function -result = llama_caller.call_function("get_weather", location="Paris", format="Celsius") +result = llama_caller.call_function( + "get_weather", location="Paris", format="Celsius" +) print(result) # Stream a user prompt diff --git a/playground/models/vilt.py b/playground/models/vilt.py index 127514e0..8e40f59d 100644 --- a/playground/models/vilt.py +++ b/playground/models/vilt.py @@ -3,5 +3,6 @@ from swarms.models.vilt import Vilt model = Vilt() output = model( - "What is this image", "http://images.cocodataset.org/val2017/000000039769.jpg" + "What is this image", + "http://images.cocodataset.org/val2017/000000039769.jpg", ) diff --git a/playground/structs/flow_tools.py b/playground/structs/flow_tools.py index 647f6617..42ec0f72 100644 --- a/playground/structs/flow_tools.py +++ b/playground/structs/flow_tools.py @@ -30,7 +30,9 @@ async def async_load_playwright(url: str) -> str: text = soup.get_text() lines = (line.strip() for line in text.splitlines()) - chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) + chunks = ( + phrase.strip() for line in lines for phrase in line.split(" ") + ) results = "\n".join(chunk for chunk in chunks if chunk) except Exception as e: results = f"Error: {e}" @@ -58,5 +60,6 @@ flow = Flow( ) out = flow.run( - "Generate a 10,000 word blog on mental clarity and the benefits of meditation." + "Generate a 10,000 word blog on mental clarity and the benefits of" + " meditation." ) diff --git a/playground/swarms/debate.py b/playground/swarms/debate.py index 2c47ed8e..4c97817d 100644 --- a/playground/swarms/debate.py +++ b/playground/swarms/debate.py @@ -36,7 +36,9 @@ class DialogueAgent: message = self.model( [ self.system_message, - HumanMessage(content="\n".join(self.message_history + [self.prefix])), + HumanMessage( + content="\n".join(self.message_history + [self.prefix]) + ), ] ) return message.content @@ -124,19 +126,19 @@ game_description = f"""Here is the topic for the presidential debate: {topic}. The presidential candidates are: {', '.join(character_names)}.""" player_descriptor_system_message = SystemMessage( - content="You can add detail to the description of each presidential candidate." + content=( + "You can add detail to the description of each presidential candidate." + ) ) def generate_character_description(character_name): character_specifier_prompt = [ player_descriptor_system_message, - HumanMessage( - content=f"""{game_description} + HumanMessage(content=f"""{game_description} Please reply with a creative description of the presidential candidate, {character_name}, in {word_limit} words or less, that emphasizes their personalities. Speak directly to {character_name}. 
- Do not add anything else.""" - ), + Do not add anything else."""), ] character_description = ChatOpenAI(temperature=1.0)( character_specifier_prompt @@ -155,9 +157,7 @@ Your goal is to be as creative as possible and make the voters think you are the def generate_character_system_message(character_name, character_header): - return SystemMessage( - content=( - f"""{character_header} + return SystemMessage(content=f"""{character_header} You will speak in the style of {character_name}, and exaggerate their personality. You will come up with creative ideas related to {topic}. Do not say the same things over and over again. @@ -169,13 +169,12 @@ Speak only from the perspective of {character_name}. Stop speaking the moment you finish speaking from your perspective. Never forget to keep your response to {word_limit} words! Do not add anything else. - """ - ) - ) + """) character_descriptions = [ - generate_character_description(character_name) for character_name in character_names + generate_character_description(character_name) + for character_name in character_names ] character_headers = [ generate_character_header(character_name, character_description) @@ -185,7 +184,9 @@ character_headers = [ ] character_system_messages = [ generate_character_system_message(character_name, character_headers) - for character_name, character_headers in zip(character_names, character_headers) + for character_name, character_headers in zip( + character_names, character_headers + ) ] for ( @@ -207,7 +208,10 @@ for ( class BidOutputParser(RegexParser): def get_format_instructions(self) -> str: - return "Your response should be an integer delimited by angled brackets, like this: ." + return ( + "Your response should be an integer delimited by angled brackets," + " like this: ." + ) bid_parser = BidOutputParser( @@ -248,8 +252,7 @@ for character_name, bidding_template in zip( topic_specifier_prompt = [ SystemMessage(content="You can make a task more specific."), - HumanMessage( - content=f"""{game_description} + HumanMessage(content=f"""{game_description} You are the debate moderator. Please make the debate topic more specific. @@ -257,8 +260,7 @@ topic_specifier_prompt = [ Be creative and imaginative. Please reply with the specified topic in {word_limit} words or less. Speak directly to the presidential candidates: {*character_names,}. - Do not add anything else.""" - ), + Do not add anything else."""), ] specified_topic = ChatOpenAI(temperature=1.0)(topic_specifier_prompt).content @@ -321,7 +323,9 @@ for character_name, character_system_message, bidding_template in zip( max_iters = 10 n = 0 -simulator = DialogueSimulator(agents=characters, selection_function=select_next_speaker) +simulator = DialogueSimulator( + agents=characters, selection_function=select_next_speaker +) simulator.reset() simulator.inject("Debate Moderator", specified_topic) print(f"(Debate Moderator): {specified_topic}") diff --git a/playground/swarms/multi_agent_debate.py b/playground/swarms/multi_agent_debate.py index f0bec797..d5382e56 100644 --- a/playground/swarms/multi_agent_debate.py +++ b/playground/swarms/multi_agent_debate.py @@ -36,7 +36,11 @@ agents = [worker1, worker2, worker3] debate = MultiAgentDebate(agents, select_speaker) # Run task -task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times." +task = ( + "What were the winning boston marathon times for the past 5 years (ending" + " in 2022)? 
Generate a table of the year, name, country of origin, and" + " times." +) results = debate.run(task, max_iters=4) # Print results diff --git a/playground/swarms/orchestrate.py b/playground/swarms/orchestrate.py index e43b75e3..a90a72e8 100644 --- a/playground/swarms/orchestrate.py +++ b/playground/swarms/orchestrate.py @@ -10,4 +10,6 @@ node = Worker( orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[]) # Agent 7 sends a message to Agent 9 -orchestrator.chat(sender_id=7, receiver_id=9, message="Can you help me with this task?") +orchestrator.chat( + sender_id=7, receiver_id=9, message="Can you help me with this task?" +) diff --git a/playground/swarms/orchestrator.py b/playground/swarms/orchestrator.py index e43b75e3..a90a72e8 100644 --- a/playground/swarms/orchestrator.py +++ b/playground/swarms/orchestrator.py @@ -10,4 +10,6 @@ node = Worker( orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[]) # Agent 7 sends a message to Agent 9 -orchestrator.chat(sender_id=7, receiver_id=9, message="Can you help me with this task?") +orchestrator.chat( + sender_id=7, receiver_id=9, message="Can you help me with this task?" +) diff --git a/playground/swarms/swarms_example.py b/playground/swarms/swarms_example.py index 6dabe4a1..23b714d9 100644 --- a/playground/swarms/swarms_example.py +++ b/playground/swarms/swarms_example.py @@ -7,7 +7,10 @@ api_key = "" swarm = HierarchicalSwarm(api_key) # Define an objective -objective = "Find 20 potential customers for a HierarchicalSwarm based AI Agent automation infrastructure" +objective = ( + "Find 20 potential customers for a HierarchicalSwarm based AI Agent" + " automation infrastructure" +) # Run HierarchicalSwarm swarm.run(objective) diff --git a/pyproject.toml b/pyproject.toml index 3dbf8570..eea95362 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,4 +84,11 @@ recursive = true aggressive = 3 [tool.ruff] -line-length = 80 \ No newline at end of file +line-length = 80 + +[tool.black] +line-length = 80 +target-version = ['py38'] +preview = true + + diff --git a/swarms/agents/omni_modal_agent.py b/swarms/agents/omni_modal_agent.py index 007a2219..6a22c477 100644 --- a/swarms/agents/omni_modal_agent.py +++ b/swarms/agents/omni_modal_agent.py @@ -18,7 +18,12 @@ from swarms.agents.message import Message class Step: def __init__( - self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool + self, + task: str, + id: int, + dep: List[int], + args: Dict[str, str], + tool: BaseTool, ): self.task = task self.id = id diff --git a/swarms/memory/base.py b/swarms/memory/base.py index 7f71c4b9..3ca49617 100644 --- a/swarms/memory/base.py +++ b/swarms/memory/base.py @@ -37,7 +37,7 @@ class BaseVectorStore(ABC): self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> None: execute_futures_dict( { @@ -54,7 +54,7 @@ class BaseVectorStore(ABC): artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: if not meta: meta = {} @@ -67,7 +67,11 @@ class BaseVectorStore(ABC): vector = artifact.generate_embedding(self.embedding_driver) return self.upsert_vector( - vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs + vector, + vector_id=artifact.id, + namespace=namespace, + meta=meta, + **kwargs, ) def upsert_text( @@ -76,14 +80,14 @@ class BaseVectorStore(ABC): vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: 
return self.upsert_vector( self.embedding_driver.embed_string(string), vector_id=vector_id, namespace=namespace, meta=meta if meta else {}, - **kwargs + **kwargs, ) @abstractmethod @@ -93,12 +97,14 @@ class BaseVectorStore(ABC): vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: ... @abstractmethod - def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry: + def load_entry( + self, vector_id: str, namespace: Optional[str] = None + ) -> Entry: ... @abstractmethod @@ -112,6 +118,6 @@ class BaseVectorStore(ABC): count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, - **kwargs + **kwargs, ) -> list[QueryResult]: ... diff --git a/swarms/memory/chroma.py b/swarms/memory/chroma.py index 67ba4cb2..2f4e473f 100644 --- a/swarms/memory/chroma.py +++ b/swarms/memory/chroma.py @@ -111,7 +111,9 @@ class Chroma(VectorStore): chroma_db_impl="duckdb+parquet", ) else: - _client_settings = chromadb.config.Settings(is_persistent=True) + _client_settings = chromadb.config.Settings( + is_persistent=True + ) _client_settings.persist_directory = persist_directory else: _client_settings = chromadb.config.Settings() @@ -124,9 +126,11 @@ class Chroma(VectorStore): self._embedding_function = embedding_function self._collection = self._client.get_or_create_collection( name=collection_name, - embedding_function=self._embedding_function.embed_documents - if self._embedding_function is not None - else None, + embedding_function=( + self._embedding_function.embed_documents + if self._embedding_function is not None + else None + ), metadata=collection_metadata, ) self.override_relevance_score_fn = relevance_score_fn @@ -203,7 +207,9 @@ class Chroma(VectorStore): metadatas = [metadatas[idx] for idx in non_empty_ids] texts_with_metadatas = [texts[idx] for idx in non_empty_ids] embeddings_with_metadatas = ( - [embeddings[idx] for idx in non_empty_ids] if embeddings else None + [embeddings[idx] for idx in non_empty_ids] + if embeddings + else None ) ids_with_metadata = [ids[idx] for idx in non_empty_ids] try: @@ -216,7 +222,8 @@ class Chroma(VectorStore): except ValueError as e: if "Expected metadata value to be" in str(e): msg = ( - "Try filtering complex metadata from the document using " + "Try filtering complex metadata from the document" + " using " "langchain.vectorstores.utils.filter_complex_metadata." ) raise ValueError(e.args[0] + "\n\n" + msg) @@ -258,7 +265,9 @@ class Chroma(VectorStore): Returns: List[Document]: List of documents most similar to the query text. """ - docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) + docs_and_scores = self.similarity_search_with_score( + query, k, filter=filter + ) return [doc for doc, _ in docs_and_scores] def similarity_search_by_vector( @@ -428,7 +437,9 @@ class Chroma(VectorStore): candidates = _results_to_docs(results) - selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] + selected_results = [ + r for i, r in enumerate(candidates) if i in mmr_selected + ] return selected_results def max_marginal_relevance_search( @@ -460,7 +471,8 @@ class Chroma(VectorStore): """ if self._embedding_function is None: raise ValueError( - "For MMR search, you must specify an embedding function oncreation." + "For MMR search, you must specify an embedding function" + " oncreation." 
) embedding = self._embedding_function.embed_query(query) @@ -543,7 +555,9 @@ class Chroma(VectorStore): """ return self.update_documents([document_id], [document]) - def update_documents(self, ids: List[str], documents: List[Document]) -> None: + def update_documents( + self, ids: List[str], documents: List[Document] + ) -> None: """Update a document in the collection. Args: @@ -554,7 +568,8 @@ class Chroma(VectorStore): metadata = [document.metadata for document in documents] if self._embedding_function is None: raise ValueError( - "For update, you must specify an embedding function on creation." + "For update, you must specify an embedding function on" + " creation." ) embeddings = self._embedding_function.embed_documents(text) @@ -645,7 +660,9 @@ class Chroma(VectorStore): ids=batch[0], ) else: - chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) + chroma_collection.add_texts( + texts=texts, metadatas=metadatas, ids=ids + ) return chroma_collection @classmethod diff --git a/swarms/memory/cosine_similarity.py b/swarms/memory/cosine_similarity.py index 99d47368..cdcd1a2b 100644 --- a/swarms/memory/cosine_similarity.py +++ b/swarms/memory/cosine_similarity.py @@ -18,8 +18,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: Y = np.array(Y) if X.shape[1] != Y.shape[1]: raise ValueError( - f"Number of columns in X and Y must be the same. X has shape {X.shape} " - f"and Y has shape {Y.shape}." + "Number of columns in X and Y must be the same. X has shape" + f" {X.shape} and Y has shape {Y.shape}." ) try: import simsimd as simd @@ -32,8 +32,9 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: return Z except ImportError: logger.info( - "Unable to import simsimd, defaulting to NumPy implementation. If you want " - "to use simsimd please install with `pip install simsimd`." + "Unable to import simsimd, defaulting to NumPy implementation. If" + " you want to use simsimd please install with `pip install" + " simsimd`." ) X_norm = np.linalg.norm(X, axis=1) Y_norm = np.linalg.norm(Y, axis=1) diff --git a/swarms/memory/db.py b/swarms/memory/db.py index 9f23b59f..4ffec16f 100644 --- a/swarms/memory/db.py +++ b/swarms/memory/db.py @@ -151,7 +151,9 @@ class InMemoryTaskDB(TaskDB): ) -> Artifact: artifact_id = str(uuid.uuid4()) artifact = Artifact( - artifact_id=artifact_id, file_name=file_name, relative_path=relative_path + artifact_id=artifact_id, + file_name=file_name, + relative_path=relative_path, ) task = await self.get_task(task_id) task.artifacts.append(artifact) diff --git a/swarms/memory/ocean.py b/swarms/memory/ocean.py index da58c81c..fb0873af 100644 --- a/swarms/memory/ocean.py +++ b/swarms/memory/ocean.py @@ -91,7 +91,9 @@ class OceanDB: try: return collection.add(documents=[document], ids=[id]) except Exception as e: - logging.error(f"Failed to append document to the collection. Error {e}") + logging.error( + f"Failed to append document to the collection. Error {e}" + ) raise def add_documents(self, collection, documents: List[str], ids: List[str]): @@ -137,7 +139,9 @@ class OceanDB: the results of the query """ try: - results = collection.query(query_texts=query_texts, n_results=n_results) + results = collection.query( + query_texts=query_texts, n_results=n_results + ) return results except Exception as e: logging.error(f"Failed to query the collection. 
Error {e}") diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index a421c887..ce591c6e 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -89,11 +89,15 @@ class PgVectorVectorStore(BaseVectorStore): engine: Optional[Engine] = field(default=None, kw_only=True) table_name: str = field(kw_only=True) _model: any = field( - default=Factory(lambda self: self.default_vector_model(), takes_self=True) + default=Factory( + lambda self: self.default_vector_model(), takes_self=True + ) ) @connection_string.validator - def validate_connection_string(self, _, connection_string: Optional[str]) -> None: + def validate_connection_string( + self, _, connection_string: Optional[str] + ) -> None: # If an engine is provided, the connection string is not used. if self.engine is not None: return @@ -104,7 +108,8 @@ class PgVectorVectorStore(BaseVectorStore): if not connection_string.startswith("postgresql://"): raise ValueError( - "The connection string must describe a Postgres database connection" + "The connection string must describe a Postgres database" + " connection" ) @engine.validator @@ -148,7 +153,7 @@ class PgVectorVectorStore(BaseVectorStore): vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: """Inserts or updates a vector in the collection.""" with Session(self.engine) as session: @@ -208,7 +213,7 @@ class PgVectorVectorStore(BaseVectorStore): namespace: Optional[str] = None, include_vectors: bool = False, distance_metric: str = "cosine_distance", - **kwargs + **kwargs, ) -> list[BaseVectorStore.QueryResult]: """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace. diff --git a/swarms/memory/pinecone.py b/swarms/memory/pinecone.py index 2374f12a..a7eb7442 100644 --- a/swarms/memory/pinecone.py +++ b/swarms/memory/pinecone.py @@ -108,7 +108,7 @@ class PineconeVectorStoreStore(BaseVector): vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: """Upsert vector""" vector_id = vector_id if vector_id else str_to_hash(str(vector)) @@ -123,7 +123,9 @@ class PineconeVectorStoreStore(BaseVector): self, vector_id: str, namespace: Optional[str] = None ) -> Optional[BaseVector.Entry]: """Load entry""" - result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() + result = self.index.fetch( + ids=[vector_id], namespace=namespace + ).to_dict() vectors = list(result["vectors"].values()) if len(vectors) > 0: @@ -138,7 +140,9 @@ class PineconeVectorStoreStore(BaseVector): else: return None - def load_entries(self, namespace: Optional[str] = None) -> list[BaseVector.Entry]: + def load_entries( + self, namespace: Optional[str] = None + ) -> list[BaseVector.Entry]: """Load entries""" # This is a hacky way to query up to 10,000 values from Pinecone. 
Waiting on an official API for fetching # all values from a namespace: @@ -169,7 +173,7 @@ class PineconeVectorStoreStore(BaseVector): include_vectors: bool = False, # PineconeVectorStoreStorageDriver-specific params: include_metadata=True, - **kwargs + **kwargs, ) -> list[BaseVector.QueryResult]: """Query vectors""" vector = self.embedding_driver.embed_string(query) @@ -196,6 +200,9 @@ class PineconeVectorStoreStore(BaseVector): def create_index(self, name: str, **kwargs) -> None: """Create index""" - params = {"name": name, "dimension": self.embedding_driver.dimensions} | kwargs + params = { + "name": name, + "dimension": self.embedding_driver.dimensions, + } | kwargs pinecone.create_index(**params) diff --git a/swarms/memory/schemas.py b/swarms/memory/schemas.py index bbc71bc2..89f1453b 100644 --- a/swarms/memory/schemas.py +++ b/swarms/memory/schemas.py @@ -50,7 +50,9 @@ class StepInput(BaseModel): class StepOutput(BaseModel): __root__: Any = Field( ..., - description="Output that the task step has produced. Any value is allowed.", + description=( + "Output that the task step has produced. Any value is allowed." + ), example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', ) @@ -112,8 +114,9 @@ class Step(StepRequestBody): None, description="Output of the task step.", example=( - "I am going to use the write_to_file command and write Washington to a file" - " called output.txt List[Document]: """Filter out metadata types that are not supported for a vector store.""" updated_documents = [] diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index f509087c..10bf2fab 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -7,7 +7,11 @@ sys.stderr = log_file from swarms.models.anthropic import Anthropic # noqa: E402 from swarms.models.petals import Petals # noqa: E402 from swarms.models.mistral import Mistral # noqa: E402 -from swarms.models.openai_models import OpenAI, AzureOpenAI, OpenAIChat # noqa: E402 +from swarms.models.openai_models import ( + OpenAI, + AzureOpenAI, + OpenAIChat, +) # noqa: E402 from swarms.models.zephyr import Zephyr # noqa: E402 from swarms.models.biogpt import BioGPT # noqa: E402 from swarms.models.huggingface import HuggingfaceLLM # noqa: E402 diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py index edaae087..1f47e1bf 100644 --- a/swarms/models/anthropic.py +++ b/swarms/models/anthropic.py @@ -50,7 +50,9 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: ] invalid_groups = [i for i, count in enumerate(counts) if count != 1] if invalid_groups: - invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups] + invalid_group_names = [ + ", ".join(arg_groups[i]) for i in invalid_groups + ] raise ValueError( "Exactly one argument in each of the following" " groups must be defined:" @@ -106,7 +108,10 @@ def mock_now(dt_value): # type: ignore def guard_import( - module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None + module_name: str, + *, + pip_name: Optional[str] = None, + package: Optional[str] = None, ) -> Any: """Dynamically imports a module and raises a helpful exception if the module is not installed.""" @@ -180,18 +185,18 @@ def build_extra_kwargs( if field_name in extra_kwargs: raise ValueError(f"Found {field_name} supplied twice.") if field_name not in all_required_field_names: - warnings.warn( - f"""WARNING! {field_name} is not default parameter. + warnings.warn(f"""WARNING! {field_name} is not default parameter. {field_name} was transferred to model_kwargs. 
- Please confirm that {field_name} is what you intended.""" - ) + Please confirm that {field_name} is what you intended.""") extra_kwargs[field_name] = values.pop(field_name) - invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys()) + invalid_model_kwargs = all_required_field_names.intersection( + extra_kwargs.keys() + ) if invalid_model_kwargs: raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. " - "Instead they were passed in as part of `model_kwargs` parameter." + f"Parameters {invalid_model_kwargs} should be specified explicitly." + " Instead they were passed in as part of `model_kwargs` parameter." ) return extra_kwargs @@ -250,7 +255,9 @@ class _AnthropicCommon(BaseLanguageModel): def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["anthropic_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") + get_from_dict_or_env( + values, "anthropic_api_key", "ANTHROPIC_API_KEY" + ) ) # Get custom api url from environment. values["anthropic_api_url"] = get_from_dict_or_env( @@ -305,7 +312,9 @@ class _AnthropicCommon(BaseLanguageModel): """Get the identifying parameters.""" return {**{}, **self._default_params} - def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]: + def _get_anthropic_stop( + self, stop: Optional[List[str]] = None + ) -> List[str]: if not self.HUMAN_PROMPT or not self.AI_PROMPT: raise NameError("Please ensure the anthropic package is loaded") @@ -354,8 +363,8 @@ class Anthropic(LLM, _AnthropicCommon): def raise_warning(cls, values: Dict) -> Dict: """Raise warning that this class is deprecated.""" warnings.warn( - "This Anthropic LLM is deprecated. " - "Please use `from langchain.chat_models import ChatAnthropic` instead" + "This Anthropic LLM is deprecated. Please use `from" + " langchain.chat_models import ChatAnthropic` instead" ) return values @@ -372,12 +381,16 @@ class Anthropic(LLM, _AnthropicCommon): return prompt # Already wrapped. # Guard against common errors in specifying wrong number of newlines. - corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt) + corrected_prompt, n_subs = re.subn( + r"^\n*Human:", self.HUMAN_PROMPT, prompt + ) if n_subs == 1: return corrected_prompt # As a last resort, wrap the prompt ourselves to emulate instruct-style. 
- return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" + return ( + f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" + ) def _call( self, @@ -476,7 +489,10 @@ class Anthropic(LLM, _AnthropicCommon): params = {**self._default_params, **kwargs} for token in self.client.completions.create( - prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + stream=True, + **params, ): chunk = GenerationChunk(text=token.completion) yield chunk diff --git a/swarms/models/bioclip.py b/swarms/models/bioclip.py index c2b4bfa5..1c2627a6 100644 --- a/swarms/models/bioclip.py +++ b/swarms/models/bioclip.py @@ -98,7 +98,9 @@ class BioClip: ) = open_clip.create_model_and_transforms(model_path) self.tokenizer = open_clip.get_tokenizer(model_path) self.device = ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") ) self.model.to(self.device) self.model.eval() @@ -110,13 +112,17 @@ class BioClip: template: str = "this is a photo of ", context_length: int = 256, ): - image = torch.stack([self.preprocess_val(Image.open(img_path))]).to(self.device) + image = torch.stack([self.preprocess_val(Image.open(img_path))]).to( + self.device + ) texts = self.tokenizer( [template + l for l in labels], context_length=context_length ).to(self.device) with torch.no_grad(): - image_features, text_features, logit_scale = self.model(image, texts) + image_features, text_features, logit_scale = self.model( + image, texts + ) logits = ( (logit_scale * image_features @ text_features.t()) .detach() @@ -142,7 +148,9 @@ class BioClip: title = ( metadata["filename"] + "\n" - + "\n".join([f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()]) + + "\n".join( + [f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()] + ) ) ax.set_title(title, fontsize=14) plt.tight_layout() diff --git a/swarms/models/biogpt.py b/swarms/models/biogpt.py index 83c31e55..d5e692f2 100644 --- a/swarms/models/biogpt.py +++ b/swarms/models/biogpt.py @@ -154,7 +154,7 @@ class BioGPT: min_length=self.min_length, max_length=self.max_length, num_beams=num_beams, - early_stopping=early_stopping + early_stopping=early_stopping, ) return self.tokenizer.decode(beam_output[0], skip_special_tokens=True) diff --git a/swarms/models/cohere_chat.py b/swarms/models/cohere_chat.py index c583b827..508e9073 100644 --- a/swarms/models/cohere_chat.py +++ b/swarms/models/cohere_chat.py @@ -96,7 +96,9 @@ class BaseCohere(Serializable): values, "cohere_api_key", "COHERE_API_KEY" ) client_name = values["user_agent"] - values["client"] = cohere.Client(cohere_api_key, client_name=client_name) + values["client"] = cohere.Client( + cohere_api_key, client_name=client_name + ) values["async_client"] = cohere.AsyncClient( cohere_api_key, client_name=client_name ) @@ -172,17 +174,23 @@ class Cohere(LLM, BaseCohere): """Return type of llm.""" return "cohere" - def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict: + def _invocation_params( + self, stop: Optional[List[str]], **kwargs: Any + ) -> dict: params = self._default_params if self.stop is not None and stop is not None: - raise ValueError("`stop` found in both the input and default params.") + raise ValueError( + "`stop` found in both the input and default params." 
+ ) elif self.stop is not None: params["stop_sequences"] = self.stop else: params["stop_sequences"] = stop return {**params, **kwargs} - def _process_response(self, response: Any, stop: Optional[List[str]]) -> str: + def _process_response( + self, response: Any, stop: Optional[List[str]] + ) -> str: text = response.generations[0].text # If stop tokens are provided, Cohere's endpoint returns them. # In order to make this consistent with other endpoints, we strip them. diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index 7d9bcf5d..3c130670 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -169,8 +169,8 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing your api key and" - " or try again" + f"Error running Dalle3: {error} try optimizing your api" + " key and or try again" ), "red", ) @@ -234,8 +234,8 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing your api key and" - " or try again" + f"Error running Dalle3: {error} try optimizing your api" + " key and or try again" ), "red", ) @@ -248,8 +248,7 @@ class Dalle3: """Print the Dalle3 dashboard""" print( colored( - ( - f"""Dalle3 Dashboard: + f"""Dalle3 Dashboard: -------------------- Model: {self.model} @@ -265,13 +264,14 @@ class Dalle3: -------------------- - """ - ), + """, "green", ) ) - def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5): + def process_batch_concurrently( + self, tasks: List[str], max_workers: int = 5 + ): """ Process a batch of tasks concurrently @@ -293,8 +293,12 @@ class Dalle3: ['https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png', """ - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_task = {executor.submit(self, task): task for task in tasks} + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + future_to_task = { + executor.submit(self, task): task for task in tasks + } results = [] for future in concurrent.futures.as_completed(future_to_task): task = future_to_task[future] @@ -307,14 +311,20 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing your api key and" - " or try again" + f"Error running Dalle3: {error} try optimizing" + " your api key and or try again" ), "red", ) ) - print(colored(f"Error running Dalle3: {error.http_status}", "red")) - print(colored(f"Error running Dalle3: {error.error}", "red")) + print( + colored( + f"Error running Dalle3: {error.http_status}", "red" + ) + ) + print( + colored(f"Error running Dalle3: {error.error}", "red") + ) raise error def _generate_uuid(self): diff --git a/swarms/models/distilled_whisperx.py b/swarms/models/distilled_whisperx.py index 98b3660a..2b4fb5a5 100644 --- a/swarms/models/distilled_whisperx.py +++ b/swarms/models/distilled_whisperx.py @@ -28,7 +28,10 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1): retries -= 1 if retries <= 0: raise - print(f"Retry after exception: {e}, Attempts remaining: {retries}") + print( + f"Retry after exception: {e}, Attempts remaining:" + f" {retries}" + ) await asyncio.sleep(delay) return wrapper @@ -62,7 +65,9 @@ class DistilWhisperModel: def __init__(self, model_id="distil-whisper/distil-large-v2"): self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + self.torch_dtype = ( + torch.float16 if torch.cuda.is_available() else torch.float32 + ) self.model_id = 
model_id self.model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, @@ -119,7 +124,9 @@ class DistilWhisperModel: try: with torch.no_grad(): # Load the whole audio file, but process and transcribe it in chunks - audio_input = self.processor.audio_file_to_array(audio_file_path) + audio_input = self.processor.audio_file_to_array( + audio_file_path + ) sample_rate = audio_input.sampling_rate len(audio_input.array) / sample_rate chunks = [ @@ -139,7 +146,9 @@ class DistilWhisperModel: return_tensors="pt", padding=True, ) - processed_inputs = processed_inputs.input_values.to(self.device) + processed_inputs = processed_inputs.input_values.to( + self.device + ) # Generate transcription for the chunk logits = self.model.generate(processed_inputs) @@ -157,4 +166,6 @@ class DistilWhisperModel: time.sleep(chunk_duration) except Exception as e: - print(colored(f"An error occurred during transcription: {e}", "red")) + print( + colored(f"An error occurred during transcription: {e}", "red") + ) diff --git a/swarms/models/eleven_labs.py b/swarms/models/eleven_labs.py index 2fece5b6..42f4dae1 100644 --- a/swarms/models/eleven_labs.py +++ b/swarms/models/eleven_labs.py @@ -79,7 +79,9 @@ class ElevenLabsText2SpeechTool(BaseTool): f.write(speech) return f.name except Exception as e: - raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}") + raise RuntimeError( + f"Error while running ElevenLabsText2SpeechTool: {e}" + ) def play(self, speech_file: str) -> None: """Play the text as speech.""" @@ -93,7 +95,9 @@ class ElevenLabsText2SpeechTool(BaseTool): """Stream the text as speech as it is generated. Play the text in your speakers.""" elevenlabs = _import_elevenlabs() - speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True) + speech_stream = elevenlabs.generate( + text=query, model=self.model, stream=True + ) elevenlabs.stream(speech_stream) def save(self, speech_file: str, path: str) -> None: diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py index d0478777..c9a0d719 100644 --- a/swarms/models/fastvit.py +++ b/swarms/models/fastvit.py @@ -10,7 +10,9 @@ from pydantic import BaseModel, StrictFloat, StrictInt, validator DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load the classes for image classification -with open(os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")) as f: +with open( + os.path.join(os.path.dirname(__file__), "fast_vit_classes.json") +) as f: FASTVIT_IMAGENET_1K_CLASSES = json.load(f) @@ -20,7 +22,9 @@ class ClassificationResult(BaseModel): @validator("class_id", "confidence", pre=True, each_item=True) def check_list_contents(cls, v): - assert isinstance(v, int) or isinstance(v, float), "must be integer or float" + assert isinstance(v, int) or isinstance( + v, float + ), "must be integer or float" return v @@ -50,7 +54,9 @@ class FastViT: "hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True ).to(DEVICE) data_config = timm.data.resolve_model_data_config(self.model) - self.transforms = timm.data.create_transform(**data_config, is_training=False) + self.transforms = timm.data.create_transform( + **data_config, is_training=False + ) self.model.eval() def __call__( diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py index 02ab3a25..ed955260 100644 --- a/swarms/models/fuyu.py +++ b/swarms/models/fuyu.py @@ -46,7 +46,9 @@ class Fuyu: self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path) self.image_processor = FuyuImageProcessor() self.processor = FuyuProcessor( - 
image_processor=self.image_processor, tokenizer=self.tokenizer, **kwargs + image_processor=self.image_processor, + tokenizer=self.tokenizer, + **kwargs, ) self.model = FuyuForCausalLM.from_pretrained( pretrained_path, @@ -69,8 +71,12 @@ class Fuyu: for k, v in model_inputs.items(): model_inputs[k] = v.to(self.device_map) - output = self.model.generate(**model_inputs, max_new_tokens=self.max_new_tokens) - text = self.processor.batch_decode(output[:, -7:], skip_special_tokens=True) + output = self.model.generate( + **model_inputs, max_new_tokens=self.max_new_tokens + ) + text = self.processor.batch_decode( + output[:, -7:], skip_special_tokens=True + ) return print(str(text)) def get_img_from_web(self, img_url: str): diff --git a/swarms/models/gpt4v.py b/swarms/models/gpt4v.py index 8411cb14..8f2683e0 100644 --- a/swarms/models/gpt4v.py +++ b/swarms/models/gpt4v.py @@ -190,12 +190,15 @@ class GPT4Vision: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ - executor.submit(self.run, task, img) for task, img in tasks_images + executor.submit(self.run, task, img) + for task, img in tasks_images ] results = [future.result() for future in futures] return results - async def run_batch_async(self, tasks_images: List[Tuple[str, str]]) -> List[str]: + async def run_batch_async( + self, tasks_images: List[Tuple[str, str]] + ) -> List[str]: """Process a batch of tasks and images asynchronously""" loop = asyncio.get_event_loop() futures = [ diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py index 0f226740..1db435f5 100644 --- a/swarms/models/huggingface.py +++ b/swarms/models/huggingface.py @@ -133,7 +133,9 @@ class HuggingfaceLLM: ): self.logger = logging.getLogger(__name__) self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_id = model_id self.max_length = max_length @@ -178,7 +180,11 @@ class HuggingfaceLLM: except Exception as e: # self.logger.error(f"Failed to load the model or the tokenizer: {e}") # raise - print(colored(f"Failed to load the model and or the tokenizer: {e}", "red")) + print( + colored( + f"Failed to load the model and or the tokenizer: {e}", "red" + ) + ) def print_error(self, error: str): """Print error""" @@ -207,12 +213,16 @@ class HuggingfaceLLM: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}" + ) raise def concurrent_run(self, tasks: List[str], max_workers: int = 5): """Concurrently generate text for a list of prompts.""" - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: results = list(executor.map(self.run, tasks)) return results @@ -220,7 +230,8 @@ class HuggingfaceLLM: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ - executor.submit(self.run, task, img) for task, img in tasks_images + executor.submit(self.run, task, img) + for task, img in tasks_images ] results = [future.result() for future in futures] return results @@ -243,7 +254,9 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(task, 
return_tensors="pt").to( + self.device + ) # self.log.start() @@ -279,8 +292,8 @@ class HuggingfaceLLM: print( colored( ( - f"HuggingfaceLLM could not generate text because of error: {e}," - " try optimizing your arguments" + "HuggingfaceLLM could not generate text because of" + f" error: {e}, try optimizing your arguments" ), "red", ) @@ -305,7 +318,9 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(task, return_tensors="pt").to( + self.device + ) # self.log.start() diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index 73cb4991..0cfcf1af 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -66,7 +66,9 @@ class Idefics: max_length=100, ): self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model = IdeficsForVisionText2Text.from_pretrained( checkpoint, diff --git a/swarms/models/jina_embeds.py b/swarms/models/jina_embeds.py index a72b8a9e..1d1ac3e6 100644 --- a/swarms/models/jina_embeds.py +++ b/swarms/models/jina_embeds.py @@ -54,7 +54,9 @@ class JinaEmbeddings: ): self.logger = logging.getLogger(__name__) self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_id = model_id self.max_length = max_length @@ -83,7 +85,9 @@ class JinaEmbeddings: try: self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config, trust_remote_code=True + self.model_id, + quantization_config=bnb_config, + trust_remote_code=True, ) self.model # .to(self.device) @@ -112,7 +116,9 @@ class JinaEmbeddings: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}" + ) raise def run(self, task: str): diff --git a/swarms/models/kosmos2.py b/swarms/models/kosmos2.py index b0e1a9f6..f81e0fdf 100644 --- a/swarms/models/kosmos2.py +++ b/swarms/models/kosmos2.py @@ -70,11 +70,13 @@ class Kosmos2(BaseModel): prompt = "An image of" inputs = self.processor(text=prompt, images=image, return_tensors="pt") - outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64) + outputs = self.model.generate( + **inputs, use_cache=True, max_new_tokens=64 + ) - generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[ - 0 - ] + generated_text = self.processor.batch_decode( + outputs, skip_special_tokens=True + )[0] # The actual processing of generated_text to entities would go here # For the purpose of this example, assume a mock function 'extract_entities' exists: @@ -99,7 +101,9 @@ class Kosmos2(BaseModel): if not entities: return Detections.empty() - class_ids = [0] * len(entities) # Replace with actual class ID extraction logic + class_ids = [0] * len( + entities + ) # Replace with actual class ID extraction logic xyxys = [ ( e[1][0] * image.width, @@ -111,7 +115,9 @@ class Kosmos2(BaseModel): ] confidences = [1.0] * len(entities) # Placeholder confidence - return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences) + return Detections( + xyxy=xyxys, class_id=class_ids, confidence=confidences + ) # Usage: diff --git a/swarms/models/kosmos_two.py b/swarms/models/kosmos_two.py index 
596886f3..c696ef34 100644 --- a/swarms/models/kosmos_two.py +++ b/swarms/models/kosmos_two.py @@ -145,12 +145,12 @@ class Kosmos: elif isinstance(image, torch.Tensor): # pdb.set_trace() image_tensor = image.cpu() - reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[ - :, None, None - ] - reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[ - :, None, None - ] + reverse_norm_mean = torch.tensor( + [0.48145466, 0.4578275, 0.40821073] + )[:, None, None] + reverse_norm_std = torch.tensor( + [0.26862954, 0.26130258, 0.27577711] + )[:, None, None] image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean pil_img = T.ToPILImage()(image_tensor) image_h = pil_img.height @@ -188,7 +188,11 @@ class Kosmos: # random color color = tuple(np.random.randint(0, 255, size=3).tolist()) new_image = cv2.rectangle( - new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line + new_image, + (orig_x1, orig_y1), + (orig_x2, orig_y2), + color, + box_line, ) l_o, r_o = ( @@ -211,7 +215,10 @@ class Kosmos: # add text background (text_width, text_height), _ = cv2.getTextSize( - f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line + f" {entity_name}", + cv2.FONT_HERSHEY_COMPLEX, + text_size, + text_line, ) text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = ( x1, @@ -222,7 +229,8 @@ class Kosmos: for prev_bbox in previous_bboxes: while is_overlapping( - (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox + (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), + prev_bbox, ): text_bg_y1 += ( text_height + text_offset_original + 2 * text_spaces @@ -230,14 +238,18 @@ class Kosmos: text_bg_y2 += ( text_height + text_offset_original + 2 * text_spaces ) - y1 += text_height + text_offset_original + 2 * text_spaces + y1 += ( + text_height + text_offset_original + 2 * text_spaces + ) if text_bg_y2 >= image_h: text_bg_y1 = max( 0, image_h - ( - text_height + text_offset_original + 2 * text_spaces + text_height + + text_offset_original + + 2 * text_spaces ), ) text_bg_y2 = image_h @@ -270,7 +282,9 @@ class Kosmos: cv2.LINE_AA, ) # previous_locations.append((x1, y1)) - previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2)) + previous_bboxes.append( + (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2) + ) pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]]) if save_path: diff --git a/swarms/models/llama_function_caller.py b/swarms/models/llama_function_caller.py index a991641a..ca5ee5d3 100644 --- a/swarms/models/llama_function_caller.py +++ b/swarms/models/llama_function_caller.py @@ -121,7 +121,11 @@ class LlamaFunctionCaller: ) def add_func( - self, name: str, function: Callable, description: str, arguments: List[Dict] + self, + name: str, + function: Callable, + description: str, + arguments: List[Dict], ): """ Adds a new function to the LlamaFunctionCaller. 
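# Sketch of the box-plus-label drawing that the kosmos_two hunks above wrap:
# draw a rectangle, measure the caption with getTextSize, paint a filled
# background, then put the text on top. Plain OpenCV + NumPy, values are
# illustrative only and unrelated to the patch itself.
import cv2
import numpy as np

image = np.zeros((240, 320, 3), dtype=np.uint8)
x1, y1, x2, y2 = 40, 60, 200, 180
color = (0, 255, 0)

cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

label = " a cat"
(text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_COMPLEX, 0.5, 1)
cv2.rectangle(image, (x1, y1 - text_h - 4), (x1 + text_w, y1), color, -1)
cv2.putText(
    image,
    label,
    (x1, y1 - 2),
    cv2.FONT_HERSHEY_COMPLEX,
    0.5,
    (0, 0, 0),
    1,
    cv2.LINE_AA,
)
cv2.imwrite("annotated.png", image)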
@@ -172,12 +176,17 @@ class LlamaFunctionCaller: if self.streaming: out = self.model.generate( - **inputs, streamer=streamer, max_new_tokens=self.max_tokens, **kwargs + **inputs, + streamer=streamer, + max_new_tokens=self.max_tokens, + **kwargs, ) return out else: - out = self.model.generate(**inputs, max_length=self.max_tokens, **kwargs) + out = self.model.generate( + **inputs, max_length=self.max_tokens, **kwargs + ) # return self.tokenizer.decode(out[0], skip_special_tokens=True) return out diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py index 7f48a0d6..056a31bb 100644 --- a/swarms/models/mistral.py +++ b/swarms/models/mistral.py @@ -49,7 +49,9 @@ class Mistral: # Check if the specified device is available if not torch.cuda.is_available() and device == "cuda": - raise ValueError("CUDA is not available. Please choose a different device.") + raise ValueError( + "CUDA is not available. Please choose a different device." + ) # Load the model and tokenizer self.model = None @@ -70,7 +72,9 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) + model_inputs = self.tokenizer([task], return_tensors="pt").to( + self.device + ) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, @@ -87,7 +91,9 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) + model_inputs = self.tokenizer([task], return_tensors="pt").to( + self.device + ) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, diff --git a/swarms/models/mpt.py b/swarms/models/mpt.py index 46d1a357..c304355a 100644 --- a/swarms/models/mpt.py +++ b/swarms/models/mpt.py @@ -29,7 +29,9 @@ class MPT7B: """ - def __init__(self, model_name: str, tokenizer_name: str, max_tokens: int = 100): + def __init__( + self, model_name: str, tokenizer_name: str, max_tokens: int = 100 + ): # Loading model and tokenizer details self.model_name = model_name self.tokenizer_name = tokenizer_name @@ -118,7 +120,10 @@ class MPT7B: """ with torch.autocast("cuda", dtype=torch.bfloat16): return self.pipe( - prompt, max_new_tokens=self.max_tokens, do_sample=True, use_cache=True + prompt, + max_new_tokens=self.max_tokens, + do_sample=True, + use_cache=True, )[0]["generated_text"] async def generate_async(self, prompt: str) -> str: diff --git a/swarms/models/nougat.py b/swarms/models/nougat.py index f156981c..82bb95f5 100644 --- a/swarms/models/nougat.py +++ b/swarms/models/nougat.py @@ -41,8 +41,12 @@ class Nougat: self.min_length = min_length self.max_new_tokens = max_new_tokens - self.processor = NougatProcessor.from_pretrained(self.model_name_or_path) - self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path) + self.processor = NougatProcessor.from_pretrained( + self.model_name_or_path + ) + self.model = VisionEncoderDecoderModel.from_pretrained( + self.model_name_or_path + ) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) @@ -63,8 +67,12 @@ class Nougat: max_new_tokens=self.max_new_tokens, ) - sequence = self.processor.batch_decode(outputs, skip_special_tokens=True)[0] - sequence = self.processor.post_process_generation(sequence, fix_markdown=False) + sequence = self.processor.batch_decode( + outputs, skip_special_tokens=True + )[0] + sequence = self.processor.post_process_generation( + sequence, fix_markdown=False + ) out = print(sequence) return out diff --git 
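# Sketch of the tokenize -> generate -> decode round trip that the Mistral and
# MPT7B hunks above re-wrap; "gpt2" is only a small stand-in checkpoint, not a
# model this patch touches.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

model_inputs = tokenizer(["Swarms of agents"], return_tensors="pt").to(device)
generated_ids = model.generate(
    **model_inputs, max_new_tokens=20, do_sample=True
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])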
a/swarms/models/openai_embeddings.py b/swarms/models/openai_embeddings.py index 81dea550..08919d45 100644 --- a/swarms/models/openai_embeddings.py +++ b/swarms/models/openai_embeddings.py @@ -43,7 +43,9 @@ def get_pydantic_field_names(cls: Any) -> Set[str]: logger = logging.getLogger(__name__) -def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]: +def _create_retry_decorator( + embeddings: OpenAIEmbeddings, +) -> Callable[[Any], Any]: import llm min_seconds = 4 @@ -118,7 +120,9 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: return _embed_with_retry(**kwargs) -async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: +async def async_embed_with_retry( + embeddings: OpenAIEmbeddings, **kwargs: Any +) -> Any: """Use tenacity to retry the embedding call.""" @_async_retry_decorator(embeddings) @@ -172,7 +176,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): client: Any #: :meta private: model: str = "text-embedding-ada-002" - deployment: str = model # to support Azure OpenAI Service custom deployment names + deployment: str = ( + model # to support Azure OpenAI Service custom deployment names + ) openai_api_version: Optional[str] = None # to support Azure OpenAI Service custom endpoints openai_api_base: Optional[str] = None @@ -229,11 +235,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) extra[field_name] = values.pop(field_name) - invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + invalid_model_kwargs = all_required_field_names.intersection( + extra.keys() + ) if invalid_model_kwargs: raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. " - "Instead they were passed in as part of `model_kwargs` parameter." + f"Parameters {invalid_model_kwargs} should be specified" + " explicitly. Instead they were passed in as part of" + " `model_kwargs` parameter." ) values["model_kwargs"] = extra @@ -333,7 +342,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. Using cl100k_base encoding." + ) model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, text in enumerate(texts): @@ -384,11 +395,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): self, input="", **self._invocation_params, - )[ - "data" - ][0]["embedding"] + )["data"][0]["embedding"] else: - average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) + average = np.average( + _result, axis=0, weights=num_tokens_in_batch[i] + ) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings @@ -414,7 +425,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. Using cl100k_base encoding." 
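# Sketch of the tiktoken fallback used by the embedding hunks above: look up
# the tokenizer for a model name and fall back to cl100k_base when the name is
# unknown. encoding_for is a hypothetical helper name for illustration.
import tiktoken


def encoding_for(model_name: str) -> "tiktoken.Encoding":
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        # unrecognised model names raise KeyError; cl100k_base is a safe default
        return tiktoken.get_encoding("cl100k_base")


print(len(encoding_for("text-embedding-ada-002").encode("hello world")))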
+ ) model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, text in enumerate(texts): @@ -458,7 +471,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) )["data"][0]["embedding"] else: - average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) + average = np.average( + _result, axis=0, weights=num_tokens_in_batch[i] + ) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings @@ -495,7 +510,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. - return await self._aget_len_safe_embeddings(texts, engine=self.deployment) + return await self._aget_len_safe_embeddings( + texts, engine=self.deployment + ) def embed_query(self, text: str) -> List[float]: """Call out to OpenAI's embedding endpoint for embedding query text. diff --git a/swarms/models/openai_function_caller.py b/swarms/models/openai_function_caller.py index bac0f28d..f0c41f2a 100644 --- a/swarms/models/openai_function_caller.py +++ b/swarms/models/openai_function_caller.py @@ -146,7 +146,8 @@ class OpenAIFunctionCaller: self.messages.append({"role": role, "content": content}) @retry( - wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3) + wait=wait_random_exponential(multiplier=1, max=40), + stop=stop_after_attempt(3), ) def chat_completion_request( self, @@ -194,17 +195,22 @@ class OpenAIFunctionCaller: elif message["role"] == "user": print( colored( - f"user: {message['content']}\n", role_to_color[message["role"]] + f"user: {message['content']}\n", + role_to_color[message["role"]], ) ) - elif message["role"] == "assistant" and message.get("function_call"): + elif message["role"] == "assistant" and message.get( + "function_call" + ): print( colored( f"assistant: {message['function_call']}\n", role_to_color[message["role"]], ) ) - elif message["role"] == "assistant" and not message.get("function_call"): + elif message["role"] == "assistant" and not message.get( + "function_call" + ): print( colored( f"assistant: {message['content']}\n", diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index fcf4a223..0547a264 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -62,19 +62,25 @@ def _stream_response_to_generation_chunk( return GenerationChunk( text=stream_response["choices"][0]["text"], generation_info=dict( - finish_reason=stream_response["choices"][0].get("finish_reason", None), + finish_reason=stream_response["choices"][0].get( + "finish_reason", None + ), logprobs=stream_response["choices"][0].get("logprobs", None), ), ) -def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None: +def _update_response( + response: Dict[str, Any], stream_response: Dict[str, Any] +) -> None: """Update response from the stream response.""" response["choices"][0]["text"] += stream_response["choices"][0]["text"] response["choices"][0]["finish_reason"] = stream_response["choices"][0].get( "finish_reason", None ) - response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"] + response["choices"][0]["logprobs"] = stream_response["choices"][0][ + "logprobs" + ] def _streaming_response_template() -> Dict[str, Any]: @@ -315,9 +321,11 @@ class BaseOpenAI(BaseLLM): chunk.text, chunk=chunk, verbose=self.verbose, - logprobs=chunk.generation_info["logprobs"] - if chunk.generation_info - else None, + logprobs=( + 
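# Sketch of the tenacity retry decorator reformatted in the
# openai_function_caller hunk above: random exponential backoff capped at 40s,
# giving up after 3 attempts. flaky_call is a hypothetical stand-in for the
# chat completion request.
import random

from tenacity import retry, stop_after_attempt, wait_random_exponential


@retry(
    wait=wait_random_exponential(multiplier=1, max=40),
    stop=stop_after_attempt(3),
)
def flaky_call() -> str:
    # pretend the upstream API fails roughly half the time
    if random.random() < 0.5:
        raise ConnectionError("transient failure")
    return "ok"


print(flaky_call())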
chunk.generation_info["logprobs"] + if chunk.generation_info + else None + ), ) async def _astream( @@ -339,9 +347,11 @@ class BaseOpenAI(BaseLLM): chunk.text, chunk=chunk, verbose=self.verbose, - logprobs=chunk.generation_info["logprobs"] - if chunk.generation_info - else None, + logprobs=( + chunk.generation_info["logprobs"] + if chunk.generation_info + else None + ), ) def _generate( @@ -377,10 +387,14 @@ class BaseOpenAI(BaseLLM): for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") + raise ValueError( + "Cannot stream results with multiple prompts." + ) generation: Optional[GenerationChunk] = None - for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs): + for chunk in self._stream( + _prompts[0], stop, run_manager, **kwargs + ): if generation is None: generation = chunk else: @@ -389,12 +403,16 @@ class BaseOpenAI(BaseLLM): choices.append( { "text": generation.text, - "finish_reason": generation.generation_info.get("finish_reason") - if generation.generation_info - else None, - "logprobs": generation.generation_info.get("logprobs") - if generation.generation_info - else None, + "finish_reason": ( + generation.generation_info.get("finish_reason") + if generation.generation_info + else None + ), + "logprobs": ( + generation.generation_info.get("logprobs") + if generation.generation_info + else None + ), } ) else: @@ -424,7 +442,9 @@ class BaseOpenAI(BaseLLM): for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") + raise ValueError( + "Cannot stream results with multiple prompts." + ) generation: Optional[GenerationChunk] = None async for chunk in self._astream( @@ -438,12 +458,16 @@ class BaseOpenAI(BaseLLM): choices.append( { "text": generation.text, - "finish_reason": generation.generation_info.get("finish_reason") - if generation.generation_info - else None, - "logprobs": generation.generation_info.get("logprobs") - if generation.generation_info - else None, + "finish_reason": ( + generation.generation_info.get("finish_reason") + if generation.generation_info + else None + ), + "logprobs": ( + generation.generation_info.get("logprobs") + if generation.generation_info + else None + ), } ) else: @@ -463,7 +487,9 @@ class BaseOpenAI(BaseLLM): """Get the sub prompts for llm call.""" if stop is not None: if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") + raise ValueError( + "`stop` found in both the input and default params." + ) params["stop"] = stop if params["max_tokens"] == -1: if len(prompts) != 1: @@ -541,7 +567,9 @@ class BaseOpenAI(BaseLLM): try: enc = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. Using cl100k_base encoding." + ) model = "cl100k_base" enc = tiktoken.get_encoding(model) @@ -602,8 +630,9 @@ class BaseOpenAI(BaseLLM): if context_size is None: raise ValueError( - f"Unknown model: {modelname}. Please provide a valid OpenAI model name." - "Known models are: " + ", ".join(model_token_mapping.keys()) + f"Unknown model: {modelname}. 
Please provide a valid OpenAI" + " model name.Known models are: " + + ", ".join(model_token_mapping.keys()) ) return context_size @@ -753,7 +782,9 @@ class OpenAIChat(BaseLLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = { + field.alias for field in cls.__fields__.values() + } extra = values.get("model_kwargs", {}) for field_name in list(values): @@ -820,13 +851,21 @@ class OpenAIChat(BaseLLM): ) -> Tuple: if len(prompts) > 1: raise ValueError( - f"OpenAIChat currently only supports single prompt, got {prompts}" + "OpenAIChat currently only supports single prompt, got" + f" {prompts}" ) - messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] - params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} + messages = self.prefix_messages + [ + {"role": "user", "content": prompts[0]} + ] + params: Dict[str, Any] = { + **{"model": self.model_name}, + **self._default_params, + } if stop is not None: if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") + raise ValueError( + "`stop` found in both the input and default params." + ) params["stop"] = stop if params.get("max_tokens") == -1: # for ChatGPT api, omitting max_tokens is equivalent to having no limit @@ -897,7 +936,11 @@ class OpenAIChat(BaseLLM): } return LLMResult( generations=[ - [Generation(text=full_response["choices"][0]["message"]["content"])] + [ + Generation( + text=full_response["choices"][0]["message"]["content"] + ) + ] ], llm_output=llm_output, ) @@ -911,7 +954,9 @@ class OpenAIChat(BaseLLM): ) -> LLMResult: if self.streaming: generation: Optional[GenerationChunk] = None - async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs): + async for chunk in self._astream( + prompts[0], stop, run_manager, **kwargs + ): if generation is None: generation = chunk else: @@ -930,7 +975,11 @@ class OpenAIChat(BaseLLM): } return LLMResult( generations=[ - [Generation(text=full_response["choices"][0]["message"]["content"])] + [ + Generation( + text=full_response["choices"][0]["message"]["content"] + ) + ] ], llm_output=llm_output, ) diff --git a/swarms/models/palm.py b/swarms/models/palm.py index ec8aafd6..8c9277d7 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -37,10 +37,16 @@ def _create_retry_decorator() -> Callable[[Any], Any]: return retry( reraise=True, stop=stop_after_attempt(max_retries), - wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds), + wait=wait_exponential( + multiplier=multiplier, min=min_seconds, max=max_seconds + ), retry=( - retry_if_exception_type(google.api_core.exceptions.ResourceExhausted) - | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable) + retry_if_exception_type( + google.api_core.exceptions.ResourceExhausted + ) + | retry_if_exception_type( + google.api_core.exceptions.ServiceUnavailable + ) | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError) ), before_sleep=before_sleep_log(logger, logging.WARNING), @@ -64,7 +70,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str: The PaLM API will sometimes erroneously return a single leading space in all lines > 1. This function strips that space. 
""" - has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:]) + has_leading_space = all( + not line or line[0] == " " for line in text.split("\n")[1:] + ) if has_leading_space: return text.replace("\n ", "\n") else: @@ -112,7 +120,10 @@ class GooglePalm(BaseLLM, BaseModel): values["client"] = genai - if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: + if ( + values["temperature"] is not None + and not 0 <= values["temperature"] <= 1 + ): raise ValueError("temperature must be in the range [0.0, 1.0]") if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: @@ -121,7 +132,10 @@ class GooglePalm(BaseLLM, BaseModel): if values["top_k"] is not None and values["top_k"] <= 0: raise ValueError("top_k must be positive") - if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0: + if ( + values["max_output_tokens"] is not None + and values["max_output_tokens"] <= 0 + ): raise ValueError("max_output_tokens must be greater than zero") return values diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index 3662dda2..a4e99fe4 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -16,4 +16,6 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): text = text.replace("\n", " ") - return client.embeddings.create(input=[text], model=model)["data"][0]["embedding"] + return client.embeddings.create(input=[text], model=model)["data"][0][ + "embedding" + ] diff --git a/swarms/models/speecht5.py b/swarms/models/speecht5.py index e98036ac..143a7514 100644 --- a/swarms/models/speecht5.py +++ b/swarms/models/speecht5.py @@ -90,7 +90,9 @@ class SpeechT5: self.processor = SpeechT5Processor.from_pretrained(self.model_name) self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name) self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) - self.embeddings_dataset = load_dataset(self.dataset_name, split="validation") + self.embeddings_dataset = load_dataset( + self.dataset_name, split="validation" + ) def __call__(self, text: str, speaker_id: float = 7306): """Call the model on some text and return the speech.""" @@ -121,7 +123,9 @@ class SpeechT5: def set_embeddings_dataset(self, dataset_name): """Set the embeddings dataset to a new dataset.""" self.dataset_name = dataset_name - self.embeddings_dataset = load_dataset(self.dataset_name, split="validation") + self.embeddings_dataset = load_dataset( + self.dataset_name, split="validation" + ) # Feature 1: Get sampling rate def get_sampling_rate(self): diff --git a/swarms/models/ssd_1b.py b/swarms/models/ssd_1b.py index caeba3fc..406678ef 100644 --- a/swarms/models/ssd_1b.py +++ b/swarms/models/ssd_1b.py @@ -141,8 +141,8 @@ class SSD1B: print( colored( ( - f"Error running SSD1B: {error} try optimizing your api key and" - " or try again" + f"Error running SSD1B: {error} try optimizing your api" + " key and or try again" ), "red", ) @@ -167,8 +167,7 @@ class SSD1B: """Print the SSD1B dashboard""" print( colored( - ( - f"""SSD1B Dashboard: + f"""SSD1B Dashboard: -------------------- Model: {self.model} @@ -184,13 +183,14 @@ class SSD1B: -------------------- - """ - ), + """, "green", ) ) - def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5): + def process_batch_concurrently( + self, tasks: List[str], max_workers: int = 5 + ): """ Process a batch of tasks concurrently @@ -211,8 +211,12 @@ class SSD1B: >>> print(results) """ - with 
concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_task = {executor.submit(self, task): task for task in tasks} + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + future_to_task = { + executor.submit(self, task): task for task in tasks + } results = [] for future in concurrent.futures.as_completed(future_to_task): task = future_to_task[future] @@ -225,13 +229,17 @@ class SSD1B: print( colored( ( - f"Error running SSD1B: {error} try optimizing your api key and" - " or try again" + f"Error running SSD1B: {error} try optimizing" + " your api key and or try again" ), "red", ) ) - print(colored(f"Error running SSD1B: {error.http_status}", "red")) + print( + colored( + f"Error running SSD1B: {error.http_status}", "red" + ) + ) print(colored(f"Error running SSD1B: {error.error}", "red")) raise error diff --git a/swarms/models/whisperx.py b/swarms/models/whisperx.py index ac592b35..338971da 100644 --- a/swarms/models/whisperx.py +++ b/swarms/models/whisperx.py @@ -66,7 +66,9 @@ class WhisperX: compute_type = "float16" # 1. Transcribe with original Whisper (batched) 🗣️ - model = whisperx.load_model("large-v2", device, compute_type=compute_type) + model = whisperx.load_model( + "large-v2", device, compute_type=compute_type + ) audio = whisperx.load_audio(audio_file) result = model.transcribe(audio, batch_size=batch_size) diff --git a/swarms/models/wizard_storytelling.py b/swarms/models/wizard_storytelling.py index 49ffb70d..a34f6ec7 100644 --- a/swarms/models/wizard_storytelling.py +++ b/swarms/models/wizard_storytelling.py @@ -45,7 +45,9 @@ class WizardLLMStoryTeller: ): self.logger = logging.getLogger(__name__) self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_id = model_id self.max_length = max_length @@ -101,7 +103,9 @@ class WizardLLMStoryTeller: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}" + ) raise def run(self, prompt_text: str): diff --git a/swarms/models/yarn_mistral.py b/swarms/models/yarn_mistral.py index ebe107a2..065e3140 100644 --- a/swarms/models/yarn_mistral.py +++ b/swarms/models/yarn_mistral.py @@ -45,7 +45,9 @@ class YarnMistral128: ): self.logger = logging.getLogger(__name__) self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_id = model_id self.max_length = max_length @@ -106,7 +108,9 @@ class YarnMistral128: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}" + ) raise def run(self, prompt_text: str): diff --git a/swarms/prompts/agent_prompt.py b/swarms/prompts/agent_prompt.py index c4897193..b36aea19 100644 --- a/swarms/prompts/agent_prompt.py +++ b/swarms/prompts/agent_prompt.py @@ -15,7 +15,9 @@ class PromptGenerator: "thoughts": { "text": "thought", "reasoning": "reasoning", - "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "plan": ( + "- short bulleted\n- list that conveys\n- long-term plan" + ), "criticism": "constructive self-criticism", "speak": "thoughts summary to 
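# Sketch of the future_to_task / as_completed pattern from the SSD1B hunk
# above: results are gathered as they finish and failures are reported per
# task. The generate stub and sample prompts are illustrative only.
import concurrent.futures


def generate(task: str) -> str:
    if not task:
        raise ValueError("empty task")
    return task.upper()


tasks = ["a portrait", "", "a landscape"]
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
    future_to_task = {executor.submit(generate, task): task for task in tasks}
    results = []
    for future in concurrent.futures.as_completed(future_to_task):
        task = future_to_task[future]
        try:
            results.append(future.result())
        except Exception as error:
            print(f"task {task!r} failed: {error}")
print(results)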
say to user", }, @@ -66,13 +68,11 @@ class PromptGenerator: """ formatted_response_format = json.dumps(self.response_format, indent=4) prompt_string = ( - f"Constraints:\n{''.join(self.constraints)}\n\n" - f"Commands:\n{''.join(self.commands)}\n\n" - f"Resources:\n{''.join(self.resources)}\n\n" - f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n" - "You should only respond in JSON format as described below " - f"\nResponse Format: \n{formatted_response_format} " - "\nEnsure the response can be parsed by Python json.loads" + f"Constraints:\n{''.join(self.constraints)}\n\nCommands:\n{''.join(self.commands)}\n\nResources:\n{''.join(self.resources)}\n\nPerformance" + f" Evaluation:\n{''.join(self.performance_evaluation)}\n\nYou" + " should only respond in JSON format as described below \nResponse" + f" Format: \n{formatted_response_format} \nEnsure the response can" + " be parsed by Python json.loads" ) return prompt_string diff --git a/swarms/prompts/agent_prompts.py b/swarms/prompts/agent_prompts.py index 8d145fc0..a8c3fca7 100644 --- a/swarms/prompts/agent_prompts.py +++ b/swarms/prompts/agent_prompts.py @@ -5,26 +5,26 @@ def generate_agent_role_prompt(agent): """ prompts = { "Finance Agent": ( - "You are a seasoned finance analyst AI assistant. Your primary goal is to" - " compose comprehensive, astute, impartial, and methodically arranged" - " financial reports based on provided data and trends." + "You are a seasoned finance analyst AI assistant. Your primary goal" + " is to compose comprehensive, astute, impartial, and methodically" + " arranged financial reports based on provided data and trends." ), "Travel Agent": ( - "You are a world-travelled AI tour guide assistant. Your main purpose is to" - " draft engaging, insightful, unbiased, and well-structured travel reports" - " on given locations, including history, attractions, and cultural" - " insights." + "You are a world-travelled AI tour guide assistant. Your main" + " purpose is to draft engaging, insightful, unbiased, and" + " well-structured travel reports on given locations, including" + " history, attractions, and cultural insights." ), "Academic Research Agent": ( - "You are an AI academic research assistant. Your primary responsibility is" - " to create thorough, academically rigorous, unbiased, and systematically" - " organized reports on a given research topic, following the standards of" - " scholarly work." + "You are an AI academic research assistant. Your primary" + " responsibility is to create thorough, academically rigorous," + " unbiased, and systematically organized reports on a given" + " research topic, following the standards of scholarly work." ), "Default Agent": ( - "You are an AI critical thinker research assistant. Your sole purpose is to" - " write well written, critically acclaimed, objective and structured" - " reports on given text." + "You are an AI critical thinker research assistant. Your sole" + " purpose is to write well written, critically acclaimed, objective" + " and structured reports on given text." ), } @@ -39,12 +39,12 @@ def generate_report_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, answer the following' - f' question or topic: "{question}" in a detailed report -- The report should' - " focus on the answer to the question, should be well structured, informative," - " in depth, with facts and numbers if available, a minimum of 1,200 words and" - " with markdown syntax and apa format. 
Write all source urls at the end of the" - " report in apa format" + f'"""{research_summary}""" Using the above information, answer the' + f' following question or topic: "{question}" in a detailed report --' + " The report should focus on the answer to the question, should be" + " well structured, informative, in depth, with facts and numbers if" + " available, a minimum of 1,200 words and with markdown syntax and apa" + " format. Write all source urls at the end of the report in apa format" ) @@ -55,9 +55,10 @@ def generate_search_queries_prompt(question): """ return ( - "Write 4 google search queries to search online that form an objective opinion" - f' from the following: "{question}"You must respond with a list of strings in' - ' the following format: ["query 1", "query 2", "query 3", "query 4"]' + "Write 4 google search queries to search online that form an objective" + f' opinion from the following: "{question}"You must respond with a list' + ' of strings in the following format: ["query 1", "query 2", "query' + ' 3", "query 4"]' ) @@ -73,14 +74,15 @@ def generate_resource_report_prompt(question, research_summary): """ return ( f'"""{research_summary}""" Based on the above information, generate a' - " bibliography recommendation report for the following question or topic:" - f' "{question}". The report should provide a detailed analysis of each' - " recommended resource, explaining how each source can contribute to finding" - " answers to the research question. Focus on the relevance, reliability, and" - " significance of each source. Ensure that the report is well-structured," - " informative, in-depth, and follows Markdown syntax. Include relevant facts," - " figures, and numbers whenever available. The report should have a minimum" - " length of 1,200 words." + " bibliography recommendation report for the following question or" + f' topic: "{question}". The report should provide a detailed analysis' + " of each recommended resource, explaining how each source can" + " contribute to finding answers to the research question. Focus on the" + " relevance, reliability, and significance of each source. Ensure that" + " the report is well-structured, informative, in-depth, and follows" + " Markdown syntax. Include relevant facts, figures, and numbers" + " whenever available. The report should have a minimum length of 1,200" + " words." ) @@ -92,13 +94,14 @@ def generate_outline_report_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, generate an outline for' - " a research report in Markdown syntax for the following question or topic:" - f' "{question}". The outline should provide a well-structured framework for the' - " research report, including the main sections, subsections, and key points to" - " be covered. The research report should be detailed, informative, in-depth," - " and a minimum of 1,200 words. Use appropriate Markdown syntax to format the" - " outline and ensure readability." + f'"""{research_summary}""" Using the above information, generate an' + " outline for a research report in Markdown syntax for the following" + f' question or topic: "{question}". The outline should provide a' + " well-structured framework for the research report, including the" + " main sections, subsections, and key points to be covered. The" + " research report should be detailed, informative, in-depth, and a" + " minimum of 1,200 words. Use appropriate Markdown syntax to format" + " the outline and ensure readability." 
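# The prompt hunks above rely on implicit concatenation of adjacent string
# literals to stay under 80 columns. Each continuation literal carries its own
# leading space; dropping that space silently glues words together (compare
# the pre-existing "definitionof" / "differentfields" artifacts a little
# further down). A tiny illustration with a hypothetical question:
question = "What is a swarm?"
prompt = (
    f'"""{question}""" Using the above information, answer the'
    " question in a detailed, well structured report."
)
assert "answer the question" in prompt  # the space came from the second literal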
) @@ -110,11 +113,12 @@ def generate_concepts_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, generate a list of 5' - " main concepts to learn for a research report on the following question or" - f' topic: "{question}". The outline should provide a well-structured' - " frameworkYou must respond with a list of strings in the following format:" - ' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]' + f'"""{research_summary}""" Using the above information, generate a list' + " of 5 main concepts to learn for a research report on the following" + f' question or topic: "{question}". The outline should provide a' + " well-structured frameworkYou must respond with a list of strings in" + ' the following format: ["concepts 1", "concepts 2", "concepts 3",' + ' "concepts 4, concepts 5"]' ) @@ -128,10 +132,10 @@ def generate_lesson_prompt(concept): """ prompt = ( - f"generate a comprehensive lesson about {concept} in Markdown syntax. This" - f" should include the definitionof {concept}, its historical background and" - " development, its applications or uses in differentfields, and notable events" - f" or facts related to {concept}." + f"generate a comprehensive lesson about {concept} in Markdown syntax." + f" This should include the definitionof {concept}, its historical" + " background and development, its applications or uses in" + f" differentfields, and notable events or facts related to {concept}." ) return prompt diff --git a/swarms/prompts/base.py b/swarms/prompts/base.py index 54a0bc3f..369063e6 100644 --- a/swarms/prompts/base.py +++ b/swarms/prompts/base.py @@ -12,7 +12,9 @@ if TYPE_CHECKING: def get_buffer_string( - messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" + messages: Sequence[BaseMessage], + human_prefix: str = "Human", + ai_prefix: str = "AI", ) -> str: """Convert sequence of Messages to strings and concatenate them into one string. diff --git a/swarms/prompts/chat_prompt.py b/swarms/prompts/chat_prompt.py index d1e08df9..bbdaa9c7 100644 --- a/swarms/prompts/chat_prompt.py +++ b/swarms/prompts/chat_prompt.py @@ -105,7 +105,9 @@ class ChatMessage(Message): def get_buffer_string( - messages: Sequence[Message], human_prefix: str = "Human", ai_prefix: str = "AI" + messages: Sequence[Message], + human_prefix: str = "Human", + ai_prefix: str = "AI", ) -> str: string_messages = [] for m in messages: diff --git a/swarms/prompts/multi_modal_prompts.py b/swarms/prompts/multi_modal_prompts.py index b552b68d..1c0830d6 100644 --- a/swarms/prompts/multi_modal_prompts.py +++ b/swarms/prompts/multi_modal_prompts.py @@ -1,6 +1,6 @@ ERROR_PROMPT = ( - "An error has occurred for the following text: \n{promptedQuery} Please explain" - " this error.\n {e}" + "An error has occurred for the following text: \n{promptedQuery} Please" + " explain this error.\n {e}" ) IMAGE_PROMPT = """ diff --git a/swarms/prompts/python.py b/swarms/prompts/python.py index 9d1f4a1e..46df5cdc 100644 --- a/swarms/prompts/python.py +++ b/swarms/prompts/python.py @@ -1,16 +1,17 @@ PY_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only." PY_REFLEXION_COMPLETION_INSTRUCTION = ( "You are a Python writing assistant. You will be given your past function" - " implementation, a series of unit tests, and a hint to change the implementation" - " appropriately. 
Write your full implementation (restate the function" - " signature).\n\n-----" + " implementation, a series of unit tests, and a hint to change the" + " implementation appropriately. Write your full implementation (restate the" + " function signature).\n\n-----" ) PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = ( - "You are a Python writing assistant. You will be given a function implementation" - " and a series of unit tests. Your goal is to write a few sentences to explain why" - " your implementation is wrong as indicated by the tests. You will need this as a" - " hint when you try again later. Only provide the few sentence description in your" - " answer, not the implementation.\n\n-----" + "You are a Python writing assistant. You will be given a function" + " implementation and a series of unit tests. Your goal is to write a few" + " sentences to explain why your implementation is wrong as indicated by the" + " tests. You will need this as a hint when you try again later. Only" + " provide the few sentence description in your answer, not the" + " implementation.\n\n-----" ) USE_PYTHON_CODEBLOCK_INSTRUCTION = ( "Use a Python code block to write your response. For" @@ -18,25 +19,26 @@ USE_PYTHON_CODEBLOCK_INSTRUCTION = ( ) PY_SIMPLE_CHAT_INSTRUCTION = ( - "You are an AI that only responds with python code, NOT ENGLISH. You will be given" - " a function signature and its docstring by the user. Write your full" - " implementation (restate the function signature)." + "You are an AI that only responds with python code, NOT ENGLISH. You will" + " be given a function signature and its docstring by the user. Write your" + " full implementation (restate the function signature)." ) PY_SIMPLE_CHAT_INSTRUCTION_V2 = ( - "You are an AI that only responds with only python code. You will be given a" - " function signature and its docstring by the user. Write your full implementation" - " (restate the function signature)." + "You are an AI that only responds with only python code. You will be given" + " a function signature and its docstring by the user. Write your full" + " implementation (restate the function signature)." ) PY_REFLEXION_CHAT_INSTRUCTION = ( "You are an AI Python assistant. You will be given your past function" - " implementation, a series of unit tests, and a hint to change the implementation" - " appropriately. Write your full implementation (restate the function signature)." + " implementation, a series of unit tests, and a hint to change the" + " implementation appropriately. Write your full implementation (restate the" + " function signature)." ) PY_REFLEXION_CHAT_INSTRUCTION_V2 = ( - "You are an AI Python assistant. You will be given your previous implementation of" - " a function, a series of unit tests results, and your self-reflection on your" - " previous implementation. Write your full implementation (restate the function" - " signature)." + "You are an AI Python assistant. You will be given your previous" + " implementation of a function, a series of unit tests results, and your" + " self-reflection on your previous implementation. Write your full" + " implementation (restate the function signature)." ) PY_REFLEXION_FEW_SHOT_ADD = '''Example 1: [previous impl]: @@ -172,18 +174,19 @@ END EXAMPLES ''' PY_SELF_REFLECTION_CHAT_INSTRUCTION = ( "You are a Python programming assistant. You will be given a function" - " implementation and a series of unit tests. Your goal is to write a few sentences" - " to explain why your implementation is wrong as indicated by the tests. 
You will" - " need this as a hint when you try again later. Only provide the few sentence" - " description in your answer, not the implementation." + " implementation and a series of unit tests. Your goal is to write a few" + " sentences to explain why your implementation is wrong as indicated by the" + " tests. You will need this as a hint when you try again later. Only" + " provide the few sentence description in your answer, not the" + " implementation." ) PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = ( "You are a Python programming assistant. You will be given a function" - " implementation and a series of unit test results. Your goal is to write a few" - " sentences to explain why your implementation is wrong as indicated by the tests." - " You will need this as guidance when you try again later. Only provide the few" - " sentence description in your answer, not the implementation. You will be given a" - " few examples by the user." + " implementation and a series of unit test results. Your goal is to write a" + " few sentences to explain why your implementation is wrong as indicated by" + " the tests. You will need this as guidance when you try again later. Only" + " provide the few sentence description in your answer, not the" + " implementation. You will be given a few examples by the user." ) PY_SELF_REFLECTION_FEW_SHOT = """Example 1: [function impl]: diff --git a/swarms/prompts/sales.py b/swarms/prompts/sales.py index 4f04f7fc..3a362174 100644 --- a/swarms/prompts/sales.py +++ b/swarms/prompts/sales.py @@ -1,23 +1,26 @@ conversation_stages = { "1": ( - "Introduction: Start the conversation by introducing yourself and your company." - " Be polite and respectful while keeping the tone of the conversation" - " professional. Your greeting should be welcoming. Always clarify in your" - " greeting the reason why you are contacting the prospect." + "Introduction: Start the conversation by introducing yourself and your" + " company. Be polite and respectful while keeping the tone of the" + " conversation professional. Your greeting should be welcoming. Always" + " clarify in your greeting the reason why you are contacting the" + " prospect." ), "2": ( - "Qualification: Qualify the prospect by confirming if they are the right person" - " to talk to regarding your product/service. Ensure that they have the" - " authority to make purchasing decisions." + "Qualification: Qualify the prospect by confirming if they are the" + " right person to talk to regarding your product/service. Ensure that" + " they have the authority to make purchasing decisions." ), "3": ( - "Value proposition: Briefly explain how your product/service can benefit the" - " prospect. Focus on the unique selling points and value proposition of your" - " product/service that sets it apart from competitors." + "Value proposition: Briefly explain how your product/service can" + " benefit the prospect. Focus on the unique selling points and value" + " proposition of your product/service that sets it apart from" + " competitors." ), "4": ( - "Needs analysis: Ask open-ended questions to uncover the prospect's needs and" - " pain points. Listen carefully to their responses and take notes." + "Needs analysis: Ask open-ended questions to uncover the prospect's" + " needs and pain points. Listen carefully to their responses and take" + " notes." ), "5": ( "Solution presentation: Based on the prospect's needs, present your" @@ -29,9 +32,9 @@ conversation_stages = { " testimonials to support your claims." 
), "7": ( - "Close: Ask for the sale by proposing a next step. This could be a demo, a" - " trial or a meeting with decision-makers. Ensure to summarize what has been" - " discussed and reiterate the benefits." + "Close: Ask for the sale by proposing a next step. This could be a" + " demo, a trial or a meeting with decision-makers. Ensure to summarize" + " what has been discussed and reiterate the benefits." ), } diff --git a/swarms/prompts/sales_prompts.py b/swarms/prompts/sales_prompts.py index 3f2b9f2b..7c1f50ed 100644 --- a/swarms/prompts/sales_prompts.py +++ b/swarms/prompts/sales_prompts.py @@ -46,24 +46,27 @@ Conversation history: conversation_stages = { "1": ( - "Introduction: Start the conversation by introducing yourself and your company." - " Be polite and respectful while keeping the tone of the conversation" - " professional. Your greeting should be welcoming. Always clarify in your" - " greeting the reason why you are contacting the prospect." + "Introduction: Start the conversation by introducing yourself and your" + " company. Be polite and respectful while keeping the tone of the" + " conversation professional. Your greeting should be welcoming. Always" + " clarify in your greeting the reason why you are contacting the" + " prospect." ), "2": ( - "Qualification: Qualify the prospect by confirming if they are the right person" - " to talk to regarding your product/service. Ensure that they have the" - " authority to make purchasing decisions." + "Qualification: Qualify the prospect by confirming if they are the" + " right person to talk to regarding your product/service. Ensure that" + " they have the authority to make purchasing decisions." ), "3": ( - "Value proposition: Briefly explain how your product/service can benefit the" - " prospect. Focus on the unique selling points and value proposition of your" - " product/service that sets it apart from competitors." + "Value proposition: Briefly explain how your product/service can" + " benefit the prospect. Focus on the unique selling points and value" + " proposition of your product/service that sets it apart from" + " competitors." ), "4": ( - "Needs analysis: Ask open-ended questions to uncover the prospect's needs and" - " pain points. Listen carefully to their responses and take notes." + "Needs analysis: Ask open-ended questions to uncover the prospect's" + " needs and pain points. Listen carefully to their responses and take" + " notes." ), "5": ( "Solution presentation: Based on the prospect's needs, present your" @@ -75,8 +78,8 @@ conversation_stages = { " testimonials to support your claims." ), "7": ( - "Close: Ask for the sale by proposing a next step. This could be a demo, a" - " trial or a meeting with decision-makers. Ensure to summarize what has been" - " discussed and reiterate the benefits." + "Close: Ask for the sale by proposing a next step. This could be a" + " demo, a trial or a meeting with decision-makers. Ensure to summarize" + " what has been discussed and reiterate the benefits." 
), } diff --git a/swarms/structs/autoscaler.py b/swarms/structs/autoscaler.py index be79a860..97e8a5ae 100644 --- a/swarms/structs/autoscaler.py +++ b/swarms/structs/autoscaler.py @@ -7,7 +7,11 @@ from typing import Callable, Dict, List from termcolor import colored from swarms.structs.flow import Flow -from swarms.utils.decorators import error_decorator, log_decorator, timing_decorator +from swarms.utils.decorators import ( + error_decorator, + log_decorator, + timing_decorator, +) class AutoScaler: @@ -69,7 +73,9 @@ class AutoScaler: try: self.tasks_queue.put(task) except Exception as error: - print(f"Error adding task to queue: {error} try again with a new task") + print( + f"Error adding task to queue: {error} try again with a new task" + ) @log_decorator @error_decorator @@ -108,10 +114,15 @@ class AutoScaler: if pending_tasks / len(self.agents_pool) > self.busy_threshold: self.scale_up() - elif active_agents / len(self.agents_pool) < self.idle_threshold: + elif ( + active_agents / len(self.agents_pool) < self.idle_threshold + ): self.scale_down() except Exception as error: - print(f"Error monitoring and scaling: {error} try again with a new task") + print( + f"Error monitoring and scaling: {error} try again with a new" + " task" + ) @log_decorator @error_decorator @@ -125,7 +136,9 @@ class AutoScaler: while True: task = self.task_queue.get() if task: - available_agent = next((agent for agent in self.agents_pool)) + available_agent = next( + (agent for agent in self.agents_pool) + ) if available_agent: available_agent.run(task) except Exception as error: diff --git a/swarms/structs/flow.py b/swarms/structs/flow.py index aa0060b4..166d619e 100644 --- a/swarms/structs/flow.py +++ b/swarms/structs/flow.py @@ -348,7 +348,8 @@ class Flow: return "\n".join(tool_descriptions) except Exception as error: print( - f"Error getting tool description: {error} try adding a description to the tool or removing the tool" + f"Error getting tool description: {error} try adding a" + " description to the tool or removing the tool" ) else: return "No tools available" @@ -479,8 +480,12 @@ class Flow: print(colored("Initializing Autonomous Agent...", "yellow")) # print(colored("Loading modules...", "yellow")) # print(colored("Modules loaded successfully.", "green")) - print(colored("Autonomous Agent Activated.", "cyan", attrs=["bold"])) - print(colored("All systems operational. Executing task...", "green")) + print( + colored("Autonomous Agent Activated.", "cyan", attrs=["bold"]) + ) + print( + colored("All systems operational. 
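# Sketch of the load check that the AutoScaler hunk above re-wraps: scale up
# when the queue backlog per agent is high, scale down when agents sit idle.
# The thresholds and the returned labels are illustrative, not the project's
# actual API.
def check_scaling(
    pending_tasks: int,
    active_agents: int,
    pool_size: int,
    busy_threshold: float = 0.7,
    idle_threshold: float = 0.2,
) -> str:
    if pool_size == 0:
        return "scale_up"
    if pending_tasks / pool_size > busy_threshold:
        return "scale_up"
    if active_agents / pool_size < idle_threshold:
        return "scale_down"
    return "hold"


print(check_scaling(pending_tasks=9, active_agents=3, pool_size=10))  # scale_up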
Executing task...", "green") + ) except Exception as error: print( colored( @@ -525,14 +530,16 @@ class Flow: loop_count = 0 while self.max_loops == "auto" or loop_count < self.max_loops: loop_count += 1 - print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) + print( + colored(f"\nLoop {loop_count} of {self.max_loops}", "blue") + ) print("\n") # Check to see if stopping token is in the output to stop the loop if self.stopping_token: - if self._check_stopping_condition(response) or parse_done_token( + if self._check_stopping_condition( response - ): + ) or parse_done_token(response): break # Adjust temperature, comment if no work @@ -629,7 +636,9 @@ class Flow: print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) print("\n") - if self._check_stopping_condition(response) or parse_done_token(response): + if self._check_stopping_condition(response) or parse_done_token( + response + ): break # Adjust temperature, comment if no work @@ -949,7 +958,8 @@ class Flow: if hasattr(self.llm, name): value = getattr(self.llm, name) if isinstance( - value, (str, int, float, bool, list, dict, tuple, type(None)) + value, + (str, int, float, bool, list, dict, tuple, type(None)), ): llm_params[name] = value else: @@ -1010,7 +1020,9 @@ class Flow: print(f"Flow state loaded from {file_path}") - def retry_on_failure(self, function, retries: int = 3, retry_delay: int = 1): + def retry_on_failure( + self, function, retries: int = 3, retry_delay: int = 1 + ): """Retry wrapper for LLM calls.""" attempt = 0 while attempt < retries: diff --git a/swarms/structs/non_linear_workflow.py b/swarms/structs/non_linear_workflow.py index 22cef91e..79bc0af7 100644 --- a/swarms/structs/non_linear_workflow.py +++ b/swarms/structs/non_linear_workflow.py @@ -7,7 +7,11 @@ from typing import Callable, List, Dict, Any, Sequence class Task: def __init__( - self, id: str, task: str, flows: Sequence[Flow], dependencies: List[str] = [] + self, + id: str, + task: str, + flows: Sequence[Flow], + dependencies: List[str] = [], ): self.id = id self.task = task @@ -20,7 +24,9 @@ class Task: for flow in self.flows: result = flow.run(self.task, *args) self.results.append(result) - args = [result] # The output of one flow becomes the input to the next + args = [ + result + ] # The output of one flow becomes the input to the next class Workflow: @@ -41,7 +47,10 @@ class Workflow: ): future = self.executor.submit( task.execute, - {dep: self.tasks[dep].results for dep in task.dependencies}, + { + dep: self.tasks[dep].results + for dep in task.dependencies + }, ) futures.append((future, task.id)) diff --git a/swarms/structs/sequential_workflow.py b/swarms/structs/sequential_workflow.py index 22ae4a21..1d7f411d 100644 --- a/swarms/structs/sequential_workflow.py +++ b/swarms/structs/sequential_workflow.py @@ -113,7 +113,9 @@ class SequentialWorkflow: restore_state_filepath: Optional[str] = None dashboard: bool = False - def add(self, task: str, flow: Union[Callable, Flow], *args, **kwargs) -> None: + def add( + self, task: str, flow: Union[Callable, Flow], *args, **kwargs + ) -> None: """ Add a task to the workflow. @@ -182,7 +184,9 @@ class SequentialWorkflow: raise ValueError(f"Task {task_description} not found in workflow.") def save_workflow_state( - self, filepath: Optional[str] = "sequential_workflow_state.json", **kwargs + self, + filepath: Optional[str] = "sequential_workflow_state.json", + **kwargs, ) -> None: """ Saves the workflow state to a json file. 
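# Sketch of the plain retry-with-delay wrapper reformatted in the Flow hunk
# above (no tenacity involved); `function` is any zero-argument callable and
# the lambda below is just a placeholder workload.
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def retry_on_failure(
    function: Callable[[], T], retries: int = 3, retry_delay: int = 1
) -> Optional[T]:
    attempt = 0
    while attempt < retries:
        try:
            return function()
        except Exception as error:
            print(f"attempt {attempt + 1} failed: {error}")
            attempt += 1
            time.sleep(retry_delay)
    return None


print(retry_on_failure(lambda: 40 + 2))  # 42 on the first attempt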
@@ -348,8 +352,9 @@ class SequentialWorkflow: # Ensure that 'task' is provided in the kwargs if "task" not in task.kwargs: raise ValueError( - "The 'task' argument is required for the Flow flow" - f" execution in '{task.description}'" + "The 'task' argument is required for the" + " Flow flow execution in" + f" '{task.description}'" ) # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") @@ -373,7 +378,9 @@ class SequentialWorkflow: # Autosave the workflow state if self.autosave: - self.save_workflow_state("sequential_workflow_state.json") + self.save_workflow_state( + "sequential_workflow_state.json" + ) except Exception as e: print( colored( @@ -404,8 +411,8 @@ class SequentialWorkflow: # Ensure that 'task' is provided in the kwargs if "task" not in task.kwargs: raise ValueError( - "The 'task' argument is required for the Flow flow" - f" execution in '{task.description}'" + "The 'task' argument is required for the Flow" + f" flow execution in '{task.description}'" ) # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") @@ -429,4 +436,6 @@ class SequentialWorkflow: # Autosave the workflow state if self.autosave: - self.save_workflow_state("sequential_workflow_state.json") + self.save_workflow_state( + "sequential_workflow_state.json" + ) diff --git a/swarms/swarms/autobloggen.py b/swarms/swarms/autobloggen.py index dec2620f..d732606b 100644 --- a/swarms/swarms/autobloggen.py +++ b/swarms/swarms/autobloggen.py @@ -103,7 +103,9 @@ class AutoBlogGenSwarm: review_agent = self.print_beautifully("Review Agent", review_agent) # Agent that publishes on social media - distribution_agent = self.llm(self.social_media_prompt(article=review_agent)) + distribution_agent = self.llm( + self.social_media_prompt(article=review_agent) + ) distribution_agent = self.print_beautifully( "Distribution Agent", distribution_agent ) @@ -115,7 +117,11 @@ class AutoBlogGenSwarm: for i in range(self.iterations): self.step() except Exception as error: - print(colored(f"Error while running AutoBlogGenSwarm {error}", "red")) + print( + colored( + f"Error while running AutoBlogGenSwarm {error}", "red" + ) + ) if attempt == self.retry_attempts - 1: raise diff --git a/swarms/swarms/base.py b/swarms/swarms/base.py index e99c9b38..1ccc819c 100644 --- a/swarms/swarms/base.py +++ b/swarms/swarms/base.py @@ -117,7 +117,9 @@ class AbstractSwarm(ABC): pass @abstractmethod - def broadcast(self, message: str, sender: Optional["AbstractWorker"] = None): + def broadcast( + self, message: str, sender: Optional["AbstractWorker"] = None + ): """Broadcast a message to all workers""" pass diff --git a/swarms/swarms/dialogue_simulator.py b/swarms/swarms/dialogue_simulator.py index ec86c414..2775daf0 100644 --- a/swarms/swarms/dialogue_simulator.py +++ b/swarms/swarms/dialogue_simulator.py @@ -23,7 +23,9 @@ class DialogueSimulator: >>> model.run("test") """ - def __init__(self, agents: List[Callable], max_iters: int = 10, name: str = None): + def __init__( + self, agents: List[Callable], max_iters: int = 10, name: str = None + ): self.agents = agents self.max_iters = max_iters self.name = name @@ -45,7 +47,8 @@ class DialogueSimulator: for receiver in self.agents: message_history = ( - f"Speaker Name: {speaker.name} and message: {speaker_message}" + f"Speaker Name: {speaker.name} and message:" + f" {speaker_message}" ) receiver.run(message_history) @@ -56,7 +59,9 @@ class DialogueSimulator: print(f"Error running dialogue simulator: {error}") def __repr__(self): - return 
f"DialogueSimulator({self.agents}, {self.max_iters}, {self.name})" + return ( + f"DialogueSimulator({self.agents}, {self.max_iters}, {self.name})" + ) def save_state(self): """Save the state of the dialogue simulator""" diff --git a/swarms/swarms/god_mode.py b/swarms/swarms/god_mode.py index e75d81d2..65377308 100644 --- a/swarms/swarms/god_mode.py +++ b/swarms/swarms/god_mode.py @@ -64,7 +64,8 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan", ) ) @@ -83,7 +84,8 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan", ) ) @@ -115,11 +117,13 @@ class GodMode: print(f"{i + 1}. {task}") print("\nLast Responses:") table = [ - [f"LLM {i+1}", response] for i, response in enumerate(self.last_responses) + [f"LLM {i+1}", response] + for i, response in enumerate(self.last_responses) ] print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan", ) ) @@ -137,7 +141,8 @@ class GodMode: """Asynchronous run the task string""" loop = asyncio.get_event_loop() futures = [ - loop.run_in_executor(None, lambda llm: llm(task), llm) for llm in self.llms + loop.run_in_executor(None, lambda llm: llm(task), llm) + for llm in self.llms ] for response in await asyncio.gather(*futures): print(response) @@ -145,13 +150,18 @@ class GodMode: def concurrent_run(self, task: str) -> List[str]: """Synchronously run the task on all llms and collect responses""" with ThreadPoolExecutor() as executor: - future_to_llm = {executor.submit(llm, task): llm for llm in self.llms} + future_to_llm = { + executor.submit(llm, task): llm for llm in self.llms + } responses = [] for future in as_completed(future_to_llm): try: responses.append(future.result()) except Exception as error: - print(f"{future_to_llm[future]} generated an exception: {error}") + print( + f"{future_to_llm[future]} generated an exception:" + f" {error}" + ) self.last_responses = responses self.task_history.append(task) return responses diff --git a/swarms/swarms/groupchat.py b/swarms/swarms/groupchat.py index 5cff3263..76de7e16 100644 --- a/swarms/swarms/groupchat.py +++ b/swarms/swarms/groupchat.py @@ -47,7 +47,9 @@ class GroupChat: def next_agent(self, agent: Flow) -> Flow: """Return the next agent in the list.""" - return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)] + return self.agents[ + (self.agent_names.index(agent.name) + 1) % len(self.agents) + ] def select_speaker_msg(self): """Return the message for selecting the next speaker.""" @@ -78,9 +80,9 @@ class GroupChat: { "role": "system", "content": ( - "Read the above conversation. Then select the next most" - f" suitable role from {self.agent_names} to play. Only" - " return the role." + "Read the above conversation. Then select the next" + f" most suitable role from {self.agent_names} to" + " play. Only return the role." 
), } ] @@ -126,7 +128,9 @@ class GroupChatManager: self.selector = selector def __call__(self, task: str): - self.groupchat.messages.append({"role": self.selector.name, "content": task}) + self.groupchat.messages.append( + {"role": self.selector.name, "content": task} + ) for i in range(self.groupchat.max_round): speaker = self.groupchat.select_speaker( last_speaker=self.selector, selector=self.selector diff --git a/swarms/swarms/multi_agent_collab.py b/swarms/swarms/multi_agent_collab.py index 85d9955b..98f32d47 100644 --- a/swarms/swarms/multi_agent_collab.py +++ b/swarms/swarms/multi_agent_collab.py @@ -13,8 +13,8 @@ from swarms.utils.logger import logger class BidOutputParser(RegexParser): def get_format_instructions(self) -> str: return ( - "Your response should be an integrater delimited by angled brackets like" - " this: " + "Your response should be an integer delimited by angled brackets" + " like this: " ) @@ -194,11 +194,15 @@ class MultiAgentCollaboration: print("\n") n += 1 - def select_next_speaker_roundtable(self, step: int, agents: List[Flow]) -> int: + def select_next_speaker_roundtable( + self, step: int, agents: List[Flow] + ) -> int: """Selects the next speaker.""" return step % len(agents) - def select_next_speaker_director(step: int, agents: List[Flow], director) -> int: + def select_next_speaker_director( + step: int, agents: List[Flow], director + ) -> int: # if the step if even => director # => director selects next speaker if step % 2 == 1: @@ -265,7 +269,10 @@ class MultiAgentCollaboration: def format_results(self, results): """Formats the results of the run method""" formatted_results = "\n".join( - [f"{result['agent']} responded: {result['response']}" for result in results] + [ + f"{result['agent']} responded: {result['response']}" + for result in results + ] ) return formatted_results @@ -291,7 +298,12 @@ class MultiAgentCollaboration: return state def __repr__(self): - return f"MultiAgentCollaboration(agents={self.agents}, selection_function={self.select_next_speaker}, max_iters={self.max_iters}, autosave={self.autosave}, saved_file_path_name={self.saved_file_path_name})" + return ( + f"MultiAgentCollaboration(agents={self.agents}," + f" selection_function={self.select_next_speaker}," + f" max_iters={self.max_iters}, autosave={self.autosave}," + f" saved_file_path_name={self.saved_file_path_name})" + ) def performance(self): """Tracks and reports the performance of each agent""" diff --git a/swarms/swarms/orchestrate.py b/swarms/swarms/orchestrate.py index f522911b..b7a7d0e0 100644 --- a/swarms/swarms/orchestrate.py +++ b/swarms/swarms/orchestrate.py @@ -111,7 +111,9 @@ class Orchestrator: self.chroma_client = chromadb.Client() - self.collection = self.chroma_client.create_collection(name=collection_name) + self.collection = self.chroma_client.create_collection( + name=collection_name + ) self.current_tasks = {} @@ -148,13 +150,14 @@ class Orchestrator: ) logging.info( - f"Task {id(str)} has been processed by agent {id(agent)} with" + f"Task {id(task)} has been processed by agent" + f" {id(agent)} with" ) except Exception as error: logging.error( - f"Failed to process task {id(task)} by agent {id(agent)}. Error:" - f" {error}" + f"Failed to process task {id(task)} by agent {id(agent)}." 
+ f" Error: {error}" ) finally: with self.condition: @@ -175,7 +178,9 @@ class Orchestrator: try: # Query the vector database for documents created by the agents - results = self.collection.query(query_texts=[str(agent_id)], n_results=10) + results = self.collection.query( + query_texts=[str(agent_id)], n_results=10 + ) return results except Exception as e: @@ -212,7 +217,9 @@ class Orchestrator: self.collection.add(documents=[result], ids=[str(id(result))]) except Exception as e: - logging.error(f"Failed to append the agent output to database. Error: {e}") + logging.error( + f"Failed to append the agent output to database. Error: {e}" + ) raise def run(self, objective: str): @@ -226,7 +233,9 @@ class Orchestrator: results = [ self.assign_task(agent_id, task) - for agent_id, task in zip(range(len(self.agents)), self.task_queue) + for agent_id, task in zip( + range(len(self.agents)), self.task_queue + ) ] for result in results: diff --git a/swarms/tools/autogpt.py b/swarms/tools/autogpt.py index cf5450e6..07062d11 100644 --- a/swarms/tools/autogpt.py +++ b/swarms/tools/autogpt.py @@ -6,7 +6,9 @@ from typing import Optional import pandas as pd import torch from langchain.agents import tool -from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent +from langchain.agents.agent_toolkits.pandas.base import ( + create_pandas_dataframe_agent, +) from langchain.chains.qa_with_sources.loading import ( BaseCombineDocumentsChain, ) @@ -38,7 +40,10 @@ def pushd(new_dir): @tool def process_csv( - llm, csv_file_path: str, instructions: str, output_path: Optional[str] = None + llm, + csv_file_path: str, + instructions: str, + output_path: Optional[str] = None, ) -> str: """Process a CSV by with pandas in a limited REPL.\ Only use this after writing data to disk as a csv file.\ @@ -49,7 +54,9 @@ def process_csv( df = pd.read_csv(csv_file_path) except Exception as e: return f"Error: {e}" - agent = create_pandas_dataframe_agent(llm, df, max_iterations=30, verbose=False) + agent = create_pandas_dataframe_agent( + llm, df, max_iterations=30, verbose=False + ) if output_path is not None: instructions += f" Save output to disk at {output_path}" try: @@ -79,7 +86,9 @@ async def async_load_playwright(url: str) -> str: text = soup.get_text() lines = (line.strip() for line in text.splitlines()) - chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) + chunks = ( + phrase.strip() for line in lines for phrase in line.split(" ") + ) results = "\n".join(chunk for chunk in chunks if chunk) except Exception as e: results = f"Error: {e}" @@ -110,7 +119,8 @@ def _get_text_splitter(): class WebpageQATool(BaseTool): name = "query_webpage" description = ( - "Browse a webpage and retrieve the information relevant to the question." + "Browse a webpage and retrieve the information relevant to the" + " question." 
) text_splitter: RecursiveCharacterTextSplitter = Field( default_factory=_get_text_splitter ) @@ -176,7 +186,9 @@ def VQAinference(self, inputs): image_path, question = inputs.split(",") raw_image = Image.open(image_path).convert("RGB") - inputs = processor(raw_image, question, return_tensors="pt").to(device, torch_dtype) + inputs = processor(raw_image, question, return_tensors="pt").to( + device, torch_dtype + ) out = model.generate(**inputs) answer = processor.decode(out[0], skip_special_tokens=True) diff --git a/swarms/tools/mm_models.py b/swarms/tools/mm_models.py index 58fe11e5..a218ff50 100644 --- a/swarms/tools/mm_models.py +++ b/swarms/tools/mm_models.py @@ -28,7 +28,9 @@ class MaskFormer: def __init__(self, device): print("Initializing MaskFormer to %s" % device) self.device = device - self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") + self.processor = CLIPSegProcessor.from_pretrained( + "CIDAS/clipseg-rd64-refined" + ) self.model = CLIPSegForImageSegmentation.from_pretrained( "CIDAS/clipseg-rd64-refined" ).to(device) @@ -76,23 +78,26 @@ class ImageEditing: @tool( name="Remove Something From The Photo", description=( - "useful when you want to remove and object or something from the photo " - "from its description or location. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the object need to be removed. " + "useful when you want to remove an object or something from the" + " photo from its description or location. The input to this tool" + " should be a comma separated string of two, representing the" + " image_path and the object to be removed. " ), ) def inference_remove(self, inputs): image_path, to_be_removed_txt = inputs.split(",") - return self.inference_replace(f"{image_path},{to_be_removed_txt},background") + return self.inference_replace( + f"{image_path},{to_be_removed_txt},background" + ) @tool( name="Replace Something From The Photo", description=( - "useful when you want to replace an object from the object description or" - " location with another object from its description. The input to this tool" - " should be a comma separated string of three, representing the image_path," - " the object to be replaced, the object to be replaced with " + "useful when you want to replace an object from the object" + " description or location with another object from its description." + " The input to this tool should be a comma separated string of" + " three, representing the image_path, the object to be replaced," + " the object to be replaced with " ), ) def inference_replace(self, inputs): @@ -137,10 +142,10 @@ class InstructPix2Pix: @tool( name="Instruct Image Using Text", description=( - "useful when you want to the style of the image to be like the text. " - "like: make it look like a painting. or make it like a robot. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the text. " + "useful when you want the style of the image to be like the" + " text. like: make it look like a painting. or make it like a" + " robot. The input to this tool should be a comma separated string" + " of two, representing the image_path and the text. 
" ), ) def inference(self, inputs): @@ -149,14 +154,17 @@ class InstructPix2Pix: image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) original_image = Image.open(image_path) image = self.pipe( - text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2 + text, + image=original_image, + num_inference_steps=40, + image_guidance_scale=1.2, ).images[0] updated_image_path = get_new_image_name(image_path, func_name="pix2pix") image.save(updated_image_path) logger.debug( - f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:" - f" {text}, Output Image: {updated_image_path}" + f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct" + f" Text: {text}, Output Image: {updated_image_path}" ) return updated_image_path @@ -173,17 +181,18 @@ class Text2Image: self.pipe.to(device) self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( - "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" + "longbody, lowres, bad anatomy, bad hands, missing fingers, extra" + " digit, fewer digits, cropped, worst quality, low quality" ) @tool( name="Generate Image From User Input Text", description=( - "useful when you want to generate an image from a user input text and save" - " it to a file. like: generate an image of an object or something, or" - " generate an image that includes some objects. The input to this tool" - " should be a string, representing the text used to generate image. " + "useful when you want to generate an image from a user input text" + " and save it to a file. like: generate an image of an object or" + " something, or generate an image that includes some objects. The" + " input to this tool should be a string, representing the text used" + " to generate image. " ), ) def inference(self, text): @@ -205,7 +214,9 @@ class VisualQuestionAnswering: print("Initializing VisualQuestionAnswering to %s" % device) self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.device = device - self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") + self.processor = BlipProcessor.from_pretrained( + "Salesforce/blip-vqa-base" + ) self.model = BlipForQuestionAnswering.from_pretrained( "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype ).to(self.device) @@ -213,10 +224,11 @@ class VisualQuestionAnswering: @tool( name="Answer Question About The Image", description=( - "useful when you need an answer for a question based on an image. like:" - " what is the background color of the last image, how many cats in this" - " figure, what is in this figure. The input to this tool should be a comma" - " separated string of two, representing the image_path and the question" + "useful when you need an answer for a question based on an image." + " like: what is the background color of the last image, how many" + " cats in this figure, what is in this figure. 
The input to this" + " tool should be a comma separated string of two, representing the" + " image_path and the question" ), ) def inference(self, inputs): @@ -229,8 +241,8 @@ class VisualQuestionAnswering: answer = self.processor.decode(out[0], skip_special_tokens=True) logger.debug( - f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" - f" Question: {question}, Output Answer: {answer}" + f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}," + f" Input Question: {question}, Output Answer: {answer}" ) return answer @@ -245,7 +257,8 @@ class ImageCaptioning(BaseHandler): "Salesforce/blip-image-captioning-base" ) self.model = BlipForConditionalGeneration.from_pretrained( - "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype + "Salesforce/blip-image-captioning-base", + torch_dtype=self.torch_dtype, ).to(self.device) def handle(self, filename: str): @@ -264,8 +277,8 @@ class ImageCaptioning(BaseHandler): out = self.model.generate(**inputs) description = self.processor.decode(out[0], skip_special_tokens=True) print( - f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text:" - f" {description}" + f"\nProcessed ImageCaptioning, Input Image: {filename}, Output" + f" Text: {description}" ) return IMAGE_PROMPT.format(filename=filename, description=description) diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py index a5ad3f75..8ae3b7cd 100644 --- a/swarms/tools/tool.py +++ b/swarms/tools/tool.py @@ -7,7 +7,17 @@ import warnings from abc import abstractmethod from functools import partial from inspect import signature -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( @@ -27,7 +37,11 @@ from pydantic import ( root_validator, validate_arguments, ) -from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable +from langchain.schema.runnable import ( + Runnable, + RunnableConfig, + RunnableSerializable, +) class SchemaAnnotationError(TypeError): @@ -52,7 +66,11 @@ def _get_filtered_args( """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} + return { + k: schema[k] + for k in valid_keys + if k not in ("run_manager", "callbacks") + } class _SchemaConfig: @@ -120,12 +138,11 @@ class ChildTool(BaseTool): ...""" name = cls.__name__ raise SchemaAnnotationError( - f"Tool definition for {name} must include valid type annotations" - " for argument 'args_schema' to behave as expected.\n" - "Expected annotation of 'Type[BaseModel]'" - f" but got '{args_schema_type}'.\n" - "Expected class looks like:\n" - f"{typehint_mandate}" + f"Tool definition for {name} must include valid type" + " annotations for argument 'args_schema' to behave as" + " expected.\nExpected annotation of 'Type[BaseModel]' but" + f" got '{args_schema_type}'.\nExpected class looks" + f" like:\n{typehint_mandate}" ) name: str @@ -147,7 +164,9 @@ class ChildTool(BaseTool): callbacks: Callbacks = Field(default=None, exclude=True) """Callbacks to be called during tool execution.""" - callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) + callback_manager: Optional[BaseCallbackManager] = Field( + default=None, 
exclude=True + ) """Deprecated. Please use callbacks instead.""" tags: Optional[List[str]] = None """Optional list of tags associated with the tool. Defaults to None @@ -244,7 +263,9 @@ class ChildTool(BaseTool): else: if input_args is not None: result = input_args.parse_obj(tool_input) - return {k: v for k, v in result.dict().items() if k in tool_input} + return { + k: v for k, v in result.dict().items() if k in tool_input + } return tool_input @root_validator() @@ -286,7 +307,9 @@ class ChildTool(BaseTool): *args, ) - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs( + self, tool_input: Union[str, Dict] + ) -> Tuple[Tuple, Dict]: # For backwards compatibility, if run_input is a string, # pass as a positional argument. if isinstance(tool_input, str): @@ -353,8 +376,9 @@ class ChildTool(BaseTool): observation = self.handle_tool_error(e) else: raise ValueError( - "Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}" + "Got unexpected type of `handle_tool_error`. Expected" + " bool, str or callable. Received:" + f" {self.handle_tool_error}" ) run_manager.on_tool_end( str(observation), color="red", name=self.name, **kwargs @@ -409,7 +433,9 @@ class ChildTool(BaseTool): # We then call the tool on the tool input to get an observation tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) observation = ( - await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) + await self._arun( + *tool_args, run_manager=run_manager, **tool_kwargs + ) if new_arg_supported else await self._arun(*tool_args, **tool_kwargs) ) @@ -428,8 +454,9 @@ class ChildTool(BaseTool): observation = self.handle_tool_error(e) else: raise ValueError( - "Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}" + "Got unexpected type of `handle_tool_error`. Expected" + " bool, str or callable. Received:" + f" {self.handle_tool_error}" ) await run_manager.on_tool_end( str(observation), color="red", name=self.name, **kwargs @@ -484,14 +511,17 @@ class Tool(BaseTool): # assume it takes a single string input. return {"tool_input": {"type": "string"}} - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs( + self, tool_input: Union[str, Dict] + ) -> Tuple[Tuple, Dict]: """Convert tool input to pydantic model.""" args, kwargs = super()._to_args_and_kwargs(tool_input) # For backwards compatibility. The tool must be run with a single input all_args = list(args) + list(kwargs.values()) if len(all_args) != 1: raise ToolException( - f"Too many arguments to single-input tool {self.name}. Args: {all_args}" + f"Too many arguments to single-input tool {self.name}. 
Args:" + f" {all_args}" ) return tuple(all_args), {} @@ -503,7 +533,9 @@ class Tool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") + new_argument_supported = signature(self.func).parameters.get( + "callbacks" + ) return ( self.func( *args, @@ -537,12 +569,18 @@ class Tool(BaseTool): ) else: return await asyncio.get_running_loop().run_in_executor( - None, partial(self._run, run_manager=run_manager, **kwargs), *args + None, + partial(self._run, run_manager=run_manager, **kwargs), + *args, ) # TODO: this is for backwards compatibility, remove in future def __init__( - self, name: str, func: Optional[Callable], description: str, **kwargs: Any + self, + name: str, + func: Optional[Callable], + description: str, + **kwargs: Any, ) -> None: """Initialize tool.""" super(Tool, self).__init__( @@ -617,7 +655,9 @@ class StructuredTool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") + new_argument_supported = signature(self.func).parameters.get( + "callbacks" + ) return ( self.func( *args, @@ -714,7 +754,9 @@ class StructuredTool(BaseTool): description = f"{name}{sig} - {description.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{name}Schema", source_function) + _args_schema = create_schema_from_function( + f"{name}Schema", source_function + ) return cls( name=name, func=func, @@ -772,7 +814,9 @@ def tool( async def ainvoke_wrapper( callbacks: Optional[Callbacks] = None, **kwargs: Any ) -> Any: - return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) + return await runnable.ainvoke( + kwargs, {"callbacks": callbacks} + ) def invoke_wrapper( callbacks: Optional[Callbacks] = None, **kwargs: Any @@ -821,7 +865,11 @@ def tool( return _make_tool - if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): + if ( + len(args) == 2 + and isinstance(args[0], str) + and isinstance(args[1], Runnable) + ): return _make_with_name(args[0])(args[1]) elif len(args) == 1 and isinstance(args[0], str): # if the argument is a string, then we use the string as the tool name diff --git a/swarms/utils/apa.py b/swarms/utils/apa.py index 94c6f158..4adcb5cf 100644 --- a/swarms/utils/apa.py +++ b/swarms/utils/apa.py @@ -144,7 +144,9 @@ class Singleton(abc.ABCMeta, type): def __call__(cls, *args, **kwargs): """Call method for the singleton metaclass.""" if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + cls._instances[cls] = super(Singleton, cls).__call__( + *args, **kwargs + ) return cls._instances[cls] diff --git a/swarms/utils/code_interpreter.py b/swarms/utils/code_interpreter.py index 86059a83..fc2f95f7 100644 --- a/swarms/utils/code_interpreter.py +++ b/swarms/utils/code_interpreter.py @@ -116,14 +116,20 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): # Most of the time it doesn't matter, but we should figure out why it happens frequently with: # applescript yield {"output": traceback.format_exc()} - yield {"output": f"Retrying... ({retry_count}/{max_retries})"} + yield { + "output": f"Retrying... ({retry_count}/{max_retries})" + } yield {"output": "Restarting process."} self.start_process() retry_count += 1 if retry_count > max_retries: - yield {"output": "Maximum retries reached. Could not execute code."} + yield { + "output": ( + "Maximum retries reached. Could not execute code." 
+ ) + } return while True: @@ -132,7 +138,9 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): else: time.sleep(0.1) try: - output = self.output_queue.get(timeout=0.3) # Waits for 0.3 seconds + output = self.output_queue.get( + timeout=0.3 + ) # Waits for 0.3 seconds yield output except queue.Empty: if self.done.is_set(): diff --git a/swarms/utils/decorators.py b/swarms/utils/decorators.py index 8a5a5d56..cf4a774c 100644 --- a/swarms/utils/decorators.py +++ b/swarms/utils/decorators.py @@ -31,7 +31,9 @@ def timing_decorator(func): start_time = time.time() result = func(*args, **kwargs) end_time = time.time() - logging.info(f"{func.__name__} executed in {end_time - start_time} seconds") + logging.info( + f"{func.__name__} executed in {end_time - start_time} seconds" + ) return result return wrapper @@ -79,7 +81,9 @@ def synchronized_decorator(func): def deprecated_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - warnings.warn(f"{func.__name__} is deprecated", category=DeprecationWarning) + warnings.warn( + f"{func.__name__} is deprecated", category=DeprecationWarning + ) return func(*args, **kwargs) return wrapper diff --git a/swarms/utils/futures.py b/swarms/utils/futures.py index 55a4e5d5..a5ffdf51 100644 --- a/swarms/utils/futures.py +++ b/swarms/utils/futures.py @@ -5,6 +5,8 @@ T = TypeVar("T") def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]: - futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED) + futures.wait( + fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED + ) return {key: future.result() for key, future in fs_dict.items()} diff --git a/swarms/utils/loggers.py b/swarms/utils/loggers.py index da822d1a..d9845543 100644 --- a/swarms/utils/loggers.py +++ b/swarms/utils/loggers.py @@ -113,8 +113,8 @@ class Logger: ) error_handler.setLevel(logging.ERROR) error_formatter = AutoGptFormatter( - "%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d %(title)s" - " %(message_no_color)s" + "%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d" + " %(title)s %(message_no_color)s" ) error_handler.setFormatter(error_formatter) @@ -140,7 +140,12 @@ class Logger: self.chat_plugins = [] def typewriter_log( - self, title="", title_color="", content="", speak_text=False, level=logging.INFO + self, + title="", + title_color="", + content="", + speak_text=False, + level=logging.INFO, ): """ Logs a message to the typewriter. @@ -255,7 +260,9 @@ class Logger: if isinstance(message, list): message = " ".join(message) self.logger.log( - level, message, extra={"title": str(title), "color": str(title_color)} + level, + message, + extra={"title": str(title), "color": str(title_color)}, ) def set_level(self, level): @@ -284,12 +291,15 @@ class Logger: if not additionalText: additionalText = ( "Please ensure you've setup and configured everything" - " correctly. Read https://github.com/Torantulino/Auto-GPT#readme to " - "double check. You can also create a github issue or join the discord" + " correctly. Read" + " https://github.com/Torantulino/Auto-GPT#readme to double" + " check. You can also create a github issue or join the discord" " and ask there!" 
) - self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText) + self.typewriter_log( + "DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText + ) def log_json(self, data: Any, file_name: str) -> None: """ @@ -367,7 +377,9 @@ class TypingConsoleHandler(logging.StreamHandler): print(word, end="", flush=True) if i < len(words) - 1: print(" ", end="", flush=True) - typing_speed = random.uniform(min_typing_speed, max_typing_speed) + typing_speed = random.uniform( + min_typing_speed, max_typing_speed + ) time.sleep(typing_speed) # type faster after each word min_typing_speed = min_typing_speed * 0.95 diff --git a/swarms/utils/main.py b/swarms/utils/main.py index a17d4782..73704552 100644 --- a/swarms/utils/main.py +++ b/swarms/utils/main.py @@ -201,7 +201,9 @@ def dim_multiline(message: str) -> str: lines = message.split("\n") if len(lines) <= 1: return lines[0] - return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to(Color.black().bright()) + return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to( + Color.black().bright() + ) # +=============================> ANSI Ending @@ -227,7 +229,9 @@ class AbstractUploader(ABC): class S3Uploader(AbstractUploader): - def __init__(self, accessKey: str, secretKey: str, region: str, bucket: str): + def __init__( + self, accessKey: str, secretKey: str, region: str, bucket: str + ): self.accessKey = accessKey self.secretKey = secretKey self.region = region @@ -338,7 +342,9 @@ class FileHandler: self.handlers = handlers self.path = path - def register(self, filetype: FileType, handler: BaseHandler) -> "FileHandler": + def register( + self, filetype: FileType, handler: BaseHandler + ) -> "FileHandler": self.handlers[filetype] = handler return self @@ -356,7 +362,9 @@ class FileHandler: def handle(self, url: str) -> str: try: - if url.startswith(os.environ.get("SERVER", "http://localhost:8000")): + if url.startswith( + os.environ.get("SERVER", "http://localhost:8000") + ): local_filepath = url[ len(os.environ.get("SERVER", "http://localhost:8000")) + 1 : ] diff --git a/swarms/utils/parse_code.py b/swarms/utils/parse_code.py index a2f346ea..9e3b8cb4 100644 --- a/swarms/utils/parse_code.py +++ b/swarms/utils/parse_code.py @@ -7,5 +7,7 @@ def extract_code_in_backticks_in_string(message: str) -> str: """ pattern = r"`` ``(.*?)`` " # Non-greedy match between six backticks - match = re.search(pattern, message, re.DOTALL) # re.DOTALL to match newline chars + match = re.search( + pattern, message, re.DOTALL + ) # re.DOTALL to match newline chars return match.group(1).strip() if match else None diff --git a/swarms/utils/serializable.py b/swarms/utils/serializable.py index 8f0e5ccf..c7f9bc2c 100644 --- a/swarms/utils/serializable.py +++ b/swarms/utils/serializable.py @@ -109,9 +109,11 @@ class Serializable(BaseModel, ABC): "lc": 1, "type": "constructor", "id": [*self.lc_namespace, self.__class__.__name__], - "kwargs": lc_kwargs - if not secrets - else _replace_secrets(lc_kwargs, secrets), + "kwargs": ( + lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets) + ), } def to_json_not_implemented(self) -> SerializedNotImplemented: diff --git a/tests/agents/omni_modal.py b/tests/agents/omni_modal.py index d106f66c..41aa050b 100644 --- a/tests/agents/omni_modal.py +++ b/tests/agents/omni_modal.py @@ -35,4 +35,6 @@ def test_omnimodalagent_run(omni_agent): def test_task_executor_initialization(omni_agent): - assert omni_agent.task_executor is not None, "TaskExecutor initialization failed" + assert ( + omni_agent.task_executor is not 
None + ), "TaskExecutor initialization failed" diff --git a/tests/memory/oceandb.py b/tests/memory/oceandb.py index 3e31afab..c74b7c15 100644 --- a/tests/memory/oceandb.py +++ b/tests/memory/oceandb.py @@ -30,7 +30,9 @@ def test_create_collection(): def test_create_collection_exception(): with patch("oceandb.Client") as MockClient: - MockClient.create_collection.side_effect = Exception("Create collection error") + MockClient.create_collection.side_effect = Exception( + "Create collection error" + ) db = OceanDB(MockClient) with pytest.raises(Exception) as e: db.create_collection("test", "modality") diff --git a/tests/memory/pinecone.py b/tests/memory/pinecone.py index bd037bef..106a6e81 100644 --- a/tests/memory/pinecone.py +++ b/tests/memory/pinecone.py @@ -6,7 +6,9 @@ api_key = os.getenv("PINECONE_API_KEY") or "" def test_init(): - with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex: + with patch("pinecone.init") as MockInit, patch( + "pinecone.Index" + ) as MockIndex: store = PineconeVectorStore( api_key=api_key, index_name="test_index", environment="test_env" ) diff --git a/tests/models/LLM.py b/tests/models/LLM.py index 20493519..a7ca149f 100644 --- a/tests/models/LLM.py +++ b/tests/models/LLM.py @@ -11,7 +11,9 @@ class TestLLM(unittest.TestCase): @patch.object(ChatOpenAI, "__init__", return_value=None) def setUp(self, mock_hf_init, mock_openai_init): self.llm_openai = LLM(openai_api_key="mock_openai_key") - self.llm_hf = LLM(hf_repo_id="mock_repo_id", hf_api_token="mock_hf_token") + self.llm_hf = LLM( + hf_repo_id="mock_repo_id", hf_api_token="mock_hf_token" + ) self.prompt = "Who won the FIFA World Cup in 1998?" def test_init(self): diff --git a/tests/models/anthropic.py b/tests/models/anthropic.py index e2447614..fecd3585 100644 --- a/tests/models/anthropic.py +++ b/tests/models/anthropic.py @@ -74,7 +74,9 @@ def test_anthropic_default_params(anthropic_instance): } -def test_anthropic_run(mock_anthropic_env, mock_requests_post, anthropic_instance): +def test_anthropic_run( + mock_anthropic_env, mock_requests_post, anthropic_instance +): mock_response = Mock() mock_response.json.return_value = {"completion": "Generated text"} mock_requests_post.return_value = mock_response @@ -98,7 +100,9 @@ def test_anthropic_run(mock_anthropic_env, mock_requests_post, anthropic_instanc ) -def test_anthropic_call(mock_anthropic_env, mock_requests_post, anthropic_instance): +def test_anthropic_call( + mock_anthropic_env, mock_requests_post, anthropic_instance +): mock_response = Mock() mock_response.json.return_value = {"completion": "Generated text"} mock_requests_post.return_value = mock_response @@ -193,18 +197,24 @@ def test_anthropic_convert_prompt(anthropic_instance): def test_anthropic_call_with_stop(anthropic_instance): - response = anthropic_instance("Translate to French.", stop=["stop1", "stop2"]) + response = anthropic_instance( + "Translate to French.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Anthropic" def test_anthropic_stream_with_stop(anthropic_instance): - generator = anthropic_instance.stream("Write a story.", stop=["stop1", "stop2"]) + generator = anthropic_instance.stream( + "Write a story.", stop=["stop1", "stop2"] + ) for token in generator: assert isinstance(token, str) def test_anthropic_async_call_with_stop(anthropic_instance): - response = anthropic_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"]) + response = anthropic_instance.async_call( + "Tell me a joke.", stop=["stop1", "stop2"] + ) assert response 
== "Mocked Response from Anthropic" diff --git a/tests/models/auto_temp.py b/tests/models/auto_temp.py index bd37e5bb..76cdc7c3 100644 --- a/tests/models/auto_temp.py +++ b/tests/models/auto_temp.py @@ -47,7 +47,9 @@ def test_run_auto_select(auto_temp_agent): def test_run_no_scores(auto_temp_agent): task = "Invalid task." temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4" - with ThreadPoolExecutor(max_workers=auto_temp_agent.max_workers) as executor: + with ThreadPoolExecutor( + max_workers=auto_temp_agent.max_workers + ) as executor: with patch.object( executor, "submit", side_effect=[None, None, None, None, None, None] ): diff --git a/tests/models/bingchat.py b/tests/models/bingchat.py index ce3af99d..8f29f905 100644 --- a/tests/models/bingchat.py +++ b/tests/models/bingchat.py @@ -44,7 +44,9 @@ class TestBingChat(unittest.TestCase): original_image_gen = BingChat.ImageGen BingChat.ImageGen = MockImageGen - img_path = self.chat.create_img("Test prompt", auth_cookie="mock_auth_cookie") + img_path = self.chat.create_img( + "Test prompt", auth_cookie="mock_auth_cookie" + ) self.assertEqual(img_path, "./output/mock_image.png") BingChat.ImageGen = original_image_gen diff --git a/tests/models/bioclip.py b/tests/models/bioclip.py index 50a65570..54ab5bb9 100644 --- a/tests/models/bioclip.py +++ b/tests/models/bioclip.py @@ -127,7 +127,9 @@ def test_clip_multiple_images(clip_instance, sample_image_path): # Test model inference performance -def test_clip_inference_performance(clip_instance, sample_image_path, benchmark): +def test_clip_inference_performance( + clip_instance, sample_image_path, benchmark +): labels = [ "adenocarcinoma histopathology", "brain MRI", diff --git a/tests/models/biogpt.py b/tests/models/biogpt.py index f420292b..e1daa14e 100644 --- a/tests/models/biogpt.py +++ b/tests/models/biogpt.py @@ -46,7 +46,10 @@ def test_cell_biology_response(biogpt_instance): # 40. Test for a question about protein structure def test_protein_structure_response(biogpt_instance): - question = "What's the difference between alpha helix and beta sheet structures in proteins?" + question = ( + "What's the difference between alpha helix and beta sheet structures in" + " proteins?" + ) response = biogpt_instance(question) assert response and isinstance(response, str) diff --git a/tests/models/cohere.py b/tests/models/cohere.py index d1bea935..08a0e39d 100644 --- a/tests/models/cohere.py +++ b/tests/models/cohere.py @@ -49,7 +49,9 @@ def test_cohere_stream_api_error_handling(cohere_instance): cohere_instance.model = "base" cohere_instance.cohere_api_key = "invalid-api-key" with pytest.raises(Exception): - generator = cohere_instance.stream("Error handling with invalid API key.") + generator = cohere_instance.stream( + "Error handling with invalid API key." 
+ ) for token in generator: pass @@ -94,13 +96,17 @@ def test_cohere_call_with_stop(cohere_instance): def test_cohere_stream_with_stop(cohere_instance): - generator = cohere_instance.stream("Write a story.", stop=["stop1", "stop2"]) + generator = cohere_instance.stream( + "Write a story.", stop=["stop1", "stop2"] + ) for token in generator: assert isinstance(token, str) def test_cohere_async_call_with_stop(cohere_instance): - response = cohere_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"]) + response = cohere_instance.async_call( + "Tell me a joke.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Cohere" @@ -187,14 +193,22 @@ def test_cohere_generate_with_embed_english_v2(cohere_instance): def test_cohere_generate_with_embed_english_light_v2(cohere_instance): cohere_instance.model = "embed-english-light-v2.0" - response = cohere_instance("Generate embeddings with English Light v2.0 model.") - assert response.startswith("Generated embeddings with English Light v2.0 model") + response = cohere_instance( + "Generate embeddings with English Light v2.0 model." + ) + assert response.startswith( + "Generated embeddings with English Light v2.0 model" + ) def test_cohere_generate_with_embed_multilingual_v2(cohere_instance): cohere_instance.model = "embed-multilingual-v2.0" - response = cohere_instance("Generate embeddings with Multilingual v2.0 model.") - assert response.startswith("Generated embeddings with Multilingual v2.0 model") + response = cohere_instance( + "Generate embeddings with Multilingual v2.0 model." + ) + assert response.startswith( + "Generated embeddings with Multilingual v2.0 model" + ) def test_cohere_generate_with_embed_english_v3(cohere_instance): @@ -205,14 +219,22 @@ def test_cohere_generate_with_embed_english_v3(cohere_instance): def test_cohere_generate_with_embed_english_light_v3(cohere_instance): cohere_instance.model = "embed-english-light-v3.0" - response = cohere_instance("Generate embeddings with English Light v3.0 model.") - assert response.startswith("Generated embeddings with English Light v3.0 model") + response = cohere_instance( + "Generate embeddings with English Light v3.0 model." + ) + assert response.startswith( + "Generated embeddings with English Light v3.0 model" + ) def test_cohere_generate_with_embed_multilingual_v3(cohere_instance): cohere_instance.model = "embed-multilingual-v3.0" - response = cohere_instance("Generate embeddings with Multilingual v3.0 model.") - assert response.startswith("Generated embeddings with Multilingual v3.0 model") + response = cohere_instance( + "Generate embeddings with Multilingual v3.0 model." + ) + assert response.startswith( + "Generated embeddings with Multilingual v3.0 model" + ) def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance): @@ -423,7 +445,9 @@ def test_cohere_representation_model_classification(cohere_instance): def test_cohere_representation_model_language_detection(cohere_instance): # Test using the Representation model for language detection cohere_instance.model = "embed-english-v3.0" - language = cohere_instance.detect_language("Detect the language of this text.") + language = cohere_instance.detect_language( + "Detect the language of this text." 
+ ) assert isinstance(language, str) @@ -447,7 +471,9 @@ def test_cohere_representation_model_multilingual_embedding(cohere_instance): assert len(embedding) > 0 -def test_cohere_representation_model_multilingual_classification(cohere_instance): +def test_cohere_representation_model_multilingual_classification( + cohere_instance, +): # Test using the Representation model for multilingual text classification cohere_instance.model = "embed-multilingual-v3.0" classification = cohere_instance.classify("Classify multilingual text.") @@ -456,7 +482,9 @@ def test_cohere_representation_model_multilingual_classification(cohere_instance assert "score" in classification -def test_cohere_representation_model_multilingual_language_detection(cohere_instance): +def test_cohere_representation_model_multilingual_language_detection( + cohere_instance, +): # Test using the Representation model for multilingual language detection cohere_instance.model = "embed-multilingual-v3.0" language = cohere_instance.detect_language( @@ -471,12 +499,17 @@ def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded( # Test handling max tokens limit exceeded error for multilingual model cohere_instance.model = "embed-multilingual-v3.0" cohere_instance.max_tokens = 10 - prompt = "This is a test prompt that will exceed the max tokens limit for multilingual model." + prompt = ( + "This is a test prompt that will exceed the max tokens limit for" + " multilingual model." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) -def test_cohere_representation_model_multilingual_light_embedding(cohere_instance): +def test_cohere_representation_model_multilingual_light_embedding( + cohere_instance, +): # Test using the Representation model for multilingual light text embedding cohere_instance.model = "embed-multilingual-light-v3.0" embedding = cohere_instance.embed("Generate multilingual light embeddings.") @@ -484,10 +517,14 @@ def test_cohere_representation_model_multilingual_light_embedding(cohere_instanc assert len(embedding) > 0 -def test_cohere_representation_model_multilingual_light_classification(cohere_instance): +def test_cohere_representation_model_multilingual_light_classification( + cohere_instance, +): # Test using the Representation model for multilingual light text classification cohere_instance.model = "embed-multilingual-light-v3.0" - classification = cohere_instance.classify("Classify multilingual light text.") + classification = cohere_instance.classify( + "Classify multilingual light text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification @@ -510,7 +547,10 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede # Test handling max tokens limit exceeded error for multilingual light model cohere_instance.model = "embed-multilingual-light-v3.0" cohere_instance.max_tokens = 10 - prompt = "This is a test prompt that will exceed the max tokens limit for multilingual light model." + prompt = ( + "This is a test prompt that will exceed the max tokens limit for" + " multilingual light model." 
+ ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -553,19 +593,26 @@ def test_cohere_representation_model_english_classification(cohere_instance): assert "score" in classification -def test_cohere_representation_model_english_language_detection(cohere_instance): +def test_cohere_representation_model_english_language_detection( + cohere_instance, +): # Test using the Representation model for English language detection cohere_instance.model = "embed-english-v3.0" - language = cohere_instance.detect_language("Detect the language of English text.") + language = cohere_instance.detect_language( + "Detect the language of English text." + ) assert isinstance(language, str) -def test_cohere_representation_model_english_max_tokens_limit_exceeded(cohere_instance): +def test_cohere_representation_model_english_max_tokens_limit_exceeded( + cohere_instance, +): # Test handling max tokens limit exceeded error for English model cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 prompt = ( - "This is a test prompt that will exceed the max tokens limit for English model." + "This is a test prompt that will exceed the max tokens limit for" + " English model." ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -579,7 +626,9 @@ def test_cohere_representation_model_english_light_embedding(cohere_instance): assert len(embedding) > 0 -def test_cohere_representation_model_english_light_classification(cohere_instance): +def test_cohere_representation_model_english_light_classification( + cohere_instance, +): # Test using the Representation model for English light text classification cohere_instance.model = "embed-english-light-v3.0" classification = cohere_instance.classify("Classify English light text.") @@ -588,7 +637,9 @@ def test_cohere_representation_model_english_light_classification(cohere_instanc assert "score" in classification -def test_cohere_representation_model_english_light_language_detection(cohere_instance): +def test_cohere_representation_model_english_light_language_detection( + cohere_instance, +): # Test using the Representation model for English light language detection cohere_instance.model = "embed-english-light-v3.0" language = cohere_instance.detect_language( @@ -603,7 +654,10 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( # Test handling max tokens limit exceeded error for English light model cohere_instance.model = "embed-english-light-v3.0" cohere_instance.max_tokens = 10 - prompt = "This is a test prompt that will exceed the max tokens limit for English light model." + prompt = ( + "This is a test prompt that will exceed the max tokens limit for" + " English light model." 
+    )
     with pytest.raises(ValueError):
         cohere_instance.embed(prompt)
diff --git a/tests/models/dalle3.py b/tests/models/dalle3.py
index a23d077e..9b7cf0e1 100644
--- a/tests/models/dalle3.py
+++ b/tests/models/dalle3.py
@@ -62,7 +62,9 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys):
 
 def test_dalle3_create_variations_success(dalle3, mock_openai_client):
     # Arrange
-    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
     expected_variation_url = (
         "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
     )
@@ -84,7 +86,9 @@ def test_dalle3_create_variations_success(dalle3, mock_openai_client):
 
 def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys):
     # Arrange
-    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
     expected_error_message = "Error running Dalle3: API Error"
 
     # Mocking OpenAIError
@@ -186,7 +190,9 @@ def test_dalle3_call_with_large_input(dalle3, mock_openai_client):
     assert str(excinfo.value) == expected_error_message
 
 
-def test_dalle3_create_variations_with_invalid_image_url(dalle3, mock_openai_client):
+def test_dalle3_create_variations_with_invalid_image_url(
+    dalle3, mock_openai_client
+):
     # Arrange
     img_url = "https://invalid-image-url.com"
     expected_error_message = "Error running Dalle3: Invalid image URL"
@@ -228,7 +234,9 @@ def test_dalle3_call_with_retry(dalle3, mock_openai_client):
 
     # Simulate a retry scenario
     mock_openai_client.images.generate.side_effect = [
-        OpenAIError("Temporary error", http_status=500, error="Internal Server Error"),
+        OpenAIError(
+            "Temporary error", http_status=500, error="Internal Server Error"
+        ),
         Mock(data=[Mock(url=expected_img_url)]),
     ]
 
@@ -242,14 +250,18 @@ def test_dalle3_call_with_retry(dalle3, mock_openai_client):
 
 def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client):
     # Arrange
-    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
     expected_variation_url = (
         "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
     )
 
     # Simulate a retry scenario
     mock_openai_client.images.create_variation.side_effect = [
-        OpenAIError("Temporary error", http_status=500, error="Internal Server Error"),
+        OpenAIError(
+            "Temporary error", http_status=500, error="Internal Server Error"
+        ),
         Mock(data=[Mock(url=expected_variation_url)]),
     ]
 
@@ -280,9 +292,13 @@ def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys):
     assert expected_error_message in captured.err
 
 
-def test_dalle3_create_variations_exception_logging(dalle3, mock_openai_client, capsys):
+def test_dalle3_create_variations_exception_logging(
+    dalle3, mock_openai_client, capsys
+):
     # Arrange
-    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
     expected_error_message = "Error running Dalle3: API Error"
 
     # Mocking OpenAIError
@@ -323,7 +339,9 @@ def test_dalle3_call_no_api_key():
 
 def test_dalle3_create_variations_no_api_key():
     # Arrange
-    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
     dalle3 = Dalle3(api_key=None)
 
     expected_error_message = "Error running Dalle3: API Key is missing"
@@ -334,7 +352,9 @@ def test_dalle3_create_variations_no_api_key():
     assert str(excinfo.value) == expected_error_message
 
 
-def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
+def test_dalle3_call_with_retry_max_retries_exceeded(
+    dalle3, mock_openai_client
+):
     # Arrange
     task = "A painting of a dog"
 
@@ -354,7 +374,9 @@ def test_dalle3_create_variations_with_retry_max_retries_exceeded(
     dalle3, mock_openai_client
 ):
     # Arrange
-    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
 
     # Simulate max retries exceeded
     mock_openai_client.images.create_variation.side_effect = OpenAIError(
@@ -377,7 +399,9 @@ def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
 
     # Simulate success after a retry
     mock_openai_client.images.generate.side_effect = [
-        OpenAIError("Temporary error", http_status=500, error="Internal Server Error"),
+        OpenAIError(
+            "Temporary error", http_status=500, error="Internal Server Error"
+        ),
         Mock(data=[Mock(url=expected_img_url)]),
     ]
 
@@ -389,16 +413,22 @@ def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
     assert mock_openai_client.images.generate.call_count == 2
 
 
-def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client):
+def test_dalle3_create_variations_retry_with_success(
+    dalle3, mock_openai_client
+):
     # Arrange
-    img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
     expected_variation_url = (
         "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
     )
 
     # Simulate success after a retry
     mock_openai_client.images.create_variation.side_effect = [
-        OpenAIError("Temporary error", http_status=500, error="Internal Server Error"),
+        OpenAIError(
+            "Temporary error", http_status=500, error="Internal Server Error"
+        ),
         Mock(data=[Mock(url=expected_variation_url)]),
     ]
 
diff --git a/tests/models/distill_whisper.py b/tests/models/distill_whisper.py
index d83caf62..6f95a0e3 100644
--- a/tests/models/distill_whisper.py
+++ b/tests/models/distill_whisper.py
@@ -44,7 +44,9 @@ async def test_async_transcribe_audio_file(distil_whisper_model):
     test_data = np.random.rand(16000)  # Simulated audio data (1 second)
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
         audio_file_path = create_audio_file(test_data, 16000, audio_file.name)
-        transcription = await distil_whisper_model.async_transcribe(audio_file_path)
+        transcription = await distil_whisper_model.async_transcribe(
+            audio_file_path
+        )
         os.remove(audio_file_path)
 
     assert isinstance(transcription, str)
@@ -62,7 +64,9 @@ def test_transcribe_audio_data(distil_whisper_model):
 
 @pytest.mark.asyncio
 async def test_async_transcribe_audio_data(distil_whisper_model):
     test_data = np.random.rand(16000)  # Simulated audio data (1 second)
-    transcription = await distil_whisper_model.async_transcribe(test_data.tobytes())
+    transcription = await distil_whisper_model.async_transcribe(
+        test_data.tobytes()
+    )
     assert isinstance(transcription, str)
     assert transcription.strip() != ""
@@ -73,7 +77,9 @@ def test_real_time_transcribe(distil_whisper_model, capsys):
 
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
         audio_file_path = create_audio_file(test_data, 16000, audio_file.name)
-        distil_whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
+        distil_whisper_model.real_time_transcribe(
+            audio_file_path, chunk_duration=1
+        )
         os.remove(audio_file_path)
 
 
@@ -82,7 +88,9 @@ def test_real_time_transcribe(distil_whisper_model, capsys):
     assert "Chunk" in captured.out
 
 
-def test_real_time_transcribe_audio_file_not_found(distil_whisper_model, capsys):
+def test_real_time_transcribe_audio_file_not_found(
+    distil_whisper_model, capsys
+):
     audio_file_path = "non_existent_audio.wav"
     distil_whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
 
@@ -145,7 +153,9 @@ def test_create_audio_file():
     test_data = np.random.rand(16000)  # Simulated audio data (1 second)
     sample_rate = 16000
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
-        audio_file_path = create_audio_file(test_data, sample_rate, audio_file.name)
+        audio_file_path = create_audio_file(
+            test_data, sample_rate, audio_file.name
+        )
         assert os.path.exists(audio_file_path)
         os.remove(audio_file_path)
 
@@ -215,7 +225,9 @@ async def test_async_transcription_success(whisper_model, audio_file_path):
 
 
 @pytest.mark.asyncio
-async def test_async_transcription_failure(whisper_model, invalid_audio_file_path):
+async def test_async_transcription_failure(
+    whisper_model, invalid_audio_file_path
+):
     with pytest.raises(Exception):
         await whisper_model.async_transcribe(invalid_audio_file_path)
 
@@ -248,17 +260,22 @@ def mocked_model(monkeypatch):
         model_mock,
     )
     monkeypatch.setattr(
-        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock
+        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
+        processor_mock,
     )
     return model_mock, processor_mock
 
 
 @pytest.mark.asyncio
-async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path):
+async def test_async_transcribe_with_mocked_model(
+    mocked_model, audio_file_path
+):
     model_mock, processor_mock = mocked_model
     # Set up what the mock should return when it's called
     model_mock.return_value.generate.return_value = torch.tensor([[0]])
-    processor_mock.return_value.batch_decode.return_value = ["mocked transcription"]
+    processor_mock.return_value.batch_decode.return_value = [
+        "mocked transcription"
+    ]
     model_wrapper = DistilWhisperModel()
     transcription = await model_wrapper.async_transcribe(audio_file_path)
     assert transcription == "mocked transcription"
diff --git a/tests/models/elevenlab.py b/tests/models/elevenlab.py
index 7dbcf2ea..986ce937 100644
--- a/tests/models/elevenlab.py
+++ b/tests/models/elevenlab.py
@@ -61,10 +61,14 @@ def test_run_text_to_speech_with_mock(eleven_labs_tool):
 
 def test_run_text_to_speech_error_handling(eleven_labs_tool):
     with patch("your_module._import_elevenlabs") as mock_elevenlabs:
         mock_elevenlabs_instance = mock_elevenlabs.return_value
-        mock_elevenlabs_instance.generate.side_effect = Exception("Test Exception")
+        mock_elevenlabs_instance.generate.side_effect = Exception(
+            "Test Exception"
+        )
         with pytest.raises(
             RuntimeError,
-            match="Error while running ElevenLabsText2SpeechTool: Test Exception",
+            match=(
+                "Error while running ElevenLabsText2SpeechTool: Test Exception"
+            ),
         ):
             eleven_labs_tool.run(SAMPLE_TEXT)
diff --git a/tests/models/fuyu.py b/tests/models/fuyu.py
index 9a26dbfb..a70cb42a 100644
--- a/tests/models/fuyu.py
+++ b/tests/models/fuyu.py
@@ -75,7 +75,9 @@ def test_tokenizer_type(fuyu_instance):
 
 
 def test_processor_has_image_processor_and_tokenizer(fuyu_instance):
-    assert fuyu_instance.processor.image_processor == fuyu_instance.image_processor
+    assert (
+        fuyu_instance.processor.image_processor == fuyu_instance.image_processor
+    )
     assert fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer
diff --git a/tests/models/gpt4v.py b/tests/models/gpt4v.py
index 23e97d03..8532d313 100644
--- a/tests/models/gpt4v.py
+++ b/tests/models/gpt4v.py
@@ -4,7 +4,12 @@
 from unittest.mock import Mock
 
 import pytest
 from dotenv import load_dotenv
-from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
+from requests.exceptions import (
+    ConnectionError,
+    HTTPError,
+    RequestException,
+    Timeout,
+)
 
 from swarms.models.gpt4v import GPT4Vision, GPT4VisionResponse
@@ -173,7 +178,11 @@ def test_gpt4vision_call_retry_with_success_after_timeout(
         Timeout("Request timed out"),
         {
             "choices": [
-                {"message": {"content": {"text": "A description of the image."}}}
+                {
+                    "message": {
+                        "content": {"text": "A description of the image."}
+                    }
+                }
             ],
         },
     ]
@@ -200,16 +209,18 @@ def test_gpt4vision_process_img():
     assert img_data.startswith("/9j/")  # Base64-encoded image data
 
 
-def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client):
+def test_gpt4vision_call_single_task_single_image(
+    gpt4vision, mock_openai_client
+):
     # Arrange
     img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
     task = "Describe this image."
 
     expected_response = GPT4VisionResponse(answer="A description of the image.")
 
-    mock_openai_client.chat.completions.create.return_value.choices[
-        0
-    ].text = expected_response.answer
+    mock_openai_client.chat.completions.create.return_value.choices[0].text = (
+        expected_response.answer
+    )
 
     # Act
     response = gpt4vision(img_url, [task])
@@ -219,16 +230,21 @@ def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client
     mock_openai_client.chat.completions.create.assert_called_once()
 
 
-def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_client):
+def test_gpt4vision_call_single_task_multiple_images(
+    gpt4vision, mock_openai_client
+):
     # Arrange
-    img_urls = ["https://example.com/image1.jpg", "https://example.com/image2.jpg"]
+    img_urls = [
+        "https://example.com/image1.jpg",
+        "https://example.com/image2.jpg",
+    ]
     task = "Describe these images."
expected_response = GPT4VisionResponse(answer="Descriptions of the images.") - mock_openai_client.chat.completions.create.return_value.choices[ - 0 - ].text = expected_response.answer + mock_openai_client.chat.completions.create.return_value.choices[0].text = ( + expected_response.answer + ) # Act response = gpt4vision(img_urls, [task]) @@ -238,7 +254,9 @@ def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_cli mock_openai_client.chat.completions.create.assert_called_once() -def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client): +def test_gpt4vision_call_multiple_tasks_single_image( + gpt4vision, mock_openai_client +): # Arrange img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" tasks = ["Describe this image.", "What's in this picture?"] @@ -249,7 +267,9 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_cli ] def create_mock_response(response): - return {"choices": [{"message": {"content": {"text": response.answer}}}]} + return { + "choices": [{"message": {"content": {"text": response.answer}}}] + } mock_openai_client.chat.completions.create.side_effect = [ create_mock_response(response) for response in expected_responses @@ -279,7 +299,11 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_cli mock_openai_client.chat.completions.create.side_effect = [ { "choices": [ - {"message": {"content": {"text": expected_responses[i].answer}}} + { + "message": { + "content": {"text": expected_responses[i].answer} + } + } ] } for i in range(len(expected_responses)) @@ -295,7 +319,9 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_cli ) # Should be called only once -def test_gpt4vision_call_multiple_tasks_multiple_images(gpt4vision, mock_openai_client): +def test_gpt4vision_call_multiple_tasks_multiple_images( + gpt4vision, mock_openai_client +): # Arrange img_urls = [ "https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", @@ -328,7 +354,9 @@ def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client): img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." 
- mock_openai_client.chat.completions.create.side_effect = HTTPError("HTTP Error") + mock_openai_client.chat.completions.create.side_effect = HTTPError( + "HTTP Error" + ) # Act and Assert with pytest.raises(HTTPError): diff --git a/tests/models/hf.py b/tests/models/hf.py index ab3b648d..d3ff9a04 100644 --- a/tests/models/hf.py +++ b/tests/models/hf.py @@ -26,7 +26,10 @@ def mock_bitsandbytesconfig(): @pytest.fixture def hugging_face_llm( - mock_torch, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig + mock_torch, + mock_autotokenizer, + mock_automodelforcausallm, + mock_bitsandbytesconfig, ): HuggingFaceLLM.torch = mock_torch HuggingFaceLLM.AutoTokenizer = mock_autotokenizer @@ -70,8 +73,12 @@ def test_init_with_quantize( def test_generate_text(hugging_face_llm): prompt_text = "test prompt" expected_output = "test output" - hugging_face_llm.tokenizer.encode.return_value = torch.tensor([0]) # Mock tensor - hugging_face_llm.model.generate.return_value = torch.tensor([0]) # Mock tensor + hugging_face_llm.tokenizer.encode.return_value = torch.tensor( + [0] + ) # Mock tensor + hugging_face_llm.model.generate.return_value = torch.tensor( + [0] + ) # Mock tensor hugging_face_llm.tokenizer.decode.return_value = expected_output output = hugging_face_llm.generate_text(prompt_text) diff --git a/tests/models/huggingface.py b/tests/models/huggingface.py index 9a27054a..62261b9c 100644 --- a/tests/models/huggingface.py +++ b/tests/models/huggingface.py @@ -84,7 +84,9 @@ def test_llm_initialization_params(model_id, max_length): assert instance.max_length == max_length else: instance = HuggingfaceLLM(model_id=model_id) - assert instance.max_length == 500 # Assuming 500 is the default max_length + assert ( + instance.max_length == 500 + ) # Assuming 500 is the default max_length # Test for setting an invalid device @@ -180,7 +182,8 @@ def test_llm_long_input_warning(mock_warning, llm_instance): # Test for run method behavior when model raises an exception @patch( - "swarms.models.huggingface.HuggingfaceLLM._model.generate", side_effect=RuntimeError + "swarms.models.huggingface.HuggingfaceLLM._model.generate", + side_effect=RuntimeError, ) def test_llm_run_model_exception(mock_generate, llm_instance): with pytest.raises(RuntimeError): @@ -191,7 +194,9 @@ def test_llm_run_model_exception(mock_generate, llm_instance): @patch("torch.cuda.is_available", return_value=False) def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance): with pytest.raises(EnvironmentError): - llm_instance.set_device("cuda") # Attempt to set CUDA when it's not available + llm_instance.set_device( + "cuda" + ) # Attempt to set CUDA when it's not available # Test for proper cleanup after model use (releasing resources) @@ -225,7 +230,9 @@ def test_llm_multilingual_input(mock_run, llm_instance): mock_run.return_value = "mocked multilingual output" multilingual_input = "Bonjour, ceci est un test multilingue." 
result = llm_instance.run(multilingual_input) - assert isinstance(result, str) # Simple check to ensure output is string type + assert isinstance( + result, str + ) # Simple check to ensure output is string type # Test caching mechanism to prevent re-running the same inputs diff --git a/tests/models/idefics.py b/tests/models/idefics.py index 610657bd..2ee9f010 100644 --- a/tests/models/idefics.py +++ b/tests/models/idefics.py @@ -1,7 +1,11 @@ import pytest from unittest.mock import patch import torch -from swarms.models.idefics import Idefics, IdeficsForVisionText2Text, AutoProcessor +from swarms.models.idefics import ( + Idefics, + IdeficsForVisionText2Text, + AutoProcessor, +) @pytest.fixture @@ -30,7 +34,8 @@ def test_init_default(idefics_instance): ) def test_init_device(device, expected): with patch( - "torch.cuda.is_available", return_value=True if expected == "cuda" else False + "torch.cuda.is_available", + return_value=True if expected == "cuda" else False, ): instance = Idefics(device=device) assert instance.device == expected @@ -39,9 +44,9 @@ def test_init_device(device, expected): # Test `run` method def test_run(idefics_instance): prompts = [["User: Test"]] - with patch.object(idefics_instance, "processor") as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: + with patch.object( + idefics_instance, "processor" + ) as mock_processor, patch.object(idefics_instance, "model") as mock_model: mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -54,9 +59,9 @@ def test_run(idefics_instance): # Test `__call__` method (using the same logic as run for simplicity) def test_call(idefics_instance): prompts = [["User: Test"]] - with patch.object(idefics_instance, "processor") as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: + with patch.object( + idefics_instance, "processor" + ) as mock_processor, patch.object(idefics_instance, "model") as mock_model: mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -85,7 +90,9 @@ def test_set_checkpoint(idefics_instance): ) as mock_from_pretrained, patch.object(AutoProcessor, "from_pretrained"): idefics_instance.set_checkpoint(new_checkpoint) - mock_from_pretrained.assert_called_with(new_checkpoint, torch_dtype=torch.bfloat16) + mock_from_pretrained.assert_called_with( + new_checkpoint, torch_dtype=torch.bfloat16 + ) # Test `set_device` method diff --git a/tests/models/kosmos.py b/tests/models/kosmos.py index 11d224d1..aaa756a3 100644 --- a/tests/models/kosmos.py +++ b/tests/models/kosmos.py @@ -45,7 +45,9 @@ def test_get_image(mock_image_request): # Test multimodal grounding def test_multimodal_grounding(mock_image_request): kosmos = Kosmos() - kosmos.multimodal_grounding("Find the red apple in the image.", TEST_IMAGE_URL) + kosmos.multimodal_grounding( + "Find the red apple in the image.", TEST_IMAGE_URL + ) # TODO: Validate the result if possible @@ -117,7 +119,9 @@ def test_multimodal_grounding(kosmos): @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension(kosmos): - kosmos.referring_expression_comprehension("Show me the green bottle.", IMG_URL2) + kosmos.referring_expression_comprehension( + "Show me the green bottle.", IMG_URL2 + ) @pytest.mark.usefixtures("mock_request_get") @@ -147,7 +151,9 @@ 
def test_multimodal_grounding_2(kosmos): @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension_2(kosmos): - kosmos.referring_expression_comprehension("Where is the water bottle?", IMG_URL3) + kosmos.referring_expression_comprehension( + "Where is the water bottle?", IMG_URL3 + ) @pytest.mark.usefixtures("mock_request_get") diff --git a/tests/models/kosmos2.py b/tests/models/kosmos2.py index 2ff01092..1ad824cc 100644 --- a/tests/models/kosmos2.py +++ b/tests/models/kosmos2.py @@ -36,7 +36,10 @@ def test_kosmos2_with_sample_image(kosmos2, sample_image): # Mocked extract_entities function for testing def mock_extract_entities(text): - return [("entity1", (0.1, 0.2, 0.3, 0.4)), ("entity2", (0.5, 0.6, 0.7, 0.8))] + return [ + ("entity1", (0.1, 0.2, 0.3, 0.4)), + ("entity2", (0.5, 0.6, 0.7, 0.8)), + ] # Mocked process_entities_to_detections function for testing @@ -54,7 +57,9 @@ def test_kosmos2_with_mocked_extraction_and_detection( ): monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) monkeypatch.setattr( - kosmos2, "process_entities_to_detections", mock_process_entities_to_detections + kosmos2, + "process_entities_to_detections", + mock_process_entities_to_detections, ) detections = kosmos2(img=sample_image) @@ -234,7 +239,12 @@ def test_kosmos2_with_entities_containing_special_characters( kosmos2, sample_image, monkeypatch ): def mock_extract_entities(text): - return [("entity1 with special characters (ü, ö, etc.)", (0.1, 0.2, 0.3, 0.4))] + return [ + ( + "entity1 with special characters (ü, ö, etc.)", + (0.1, 0.2, 0.3, 0.4), + ) + ] monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) detections = kosmos2(img=sample_image) @@ -252,7 +262,10 @@ def test_kosmos2_with_image_containing_multiple_objects( kosmos2, sample_image, monkeypatch ): def mock_extract_entities(text): - return [("entity1", (0.1, 0.2, 0.3, 0.4)), ("entity2", (0.5, 0.6, 0.7, 0.8))] + return [ + ("entity1", (0.1, 0.2, 0.3, 0.4)), + ("entity2", (0.5, 0.6, 0.7, 0.8)), + ] monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) detections = kosmos2(img=sample_image) @@ -266,7 +279,9 @@ def test_kosmos2_with_image_containing_multiple_objects( # Test Kosmos2 with image containing no objects -def test_kosmos2_with_image_containing_no_objects(kosmos2, sample_image, monkeypatch): +def test_kosmos2_with_image_containing_no_objects( + kosmos2, sample_image, monkeypatch +): def mock_extract_entities(text): return [] diff --git a/tests/models/llama_function_caller.py b/tests/models/llama_function_caller.py index c54c264b..c38b2267 100644 --- a/tests/models/llama_function_caller.py +++ b/tests/models/llama_function_caller.py @@ -72,7 +72,9 @@ def test_llama_custom_function_invalid_arguments(llama_caller): # Test streaming with custom runtime def test_llama_custom_runtime(): llama_caller = LlamaFunctionCaller( - model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + model_id="Your-Model-ID", + cache_dir="Your-Cache-Directory", + runtime="cuda", ) user_prompt = "Tell me about the tallest mountain in the world." 
response = llama_caller(user_prompt) @@ -83,7 +85,9 @@ def test_llama_custom_runtime(): # Test caching functionality def test_llama_cache(): llama_caller = LlamaFunctionCaller( - model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + model_id="Your-Model-ID", + cache_dir="Your-Cache-Directory", + runtime="cuda", ) # Perform a request to populate the cache @@ -99,7 +103,9 @@ def test_llama_cache(): # Test response length within max_tokens limit def test_llama_response_length(): llama_caller = LlamaFunctionCaller( - model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + model_id="Your-Model-ID", + cache_dir="Your-Cache-Directory", + runtime="cuda", ) # Generate a long prompt diff --git a/tests/models/nougat.py b/tests/models/nougat.py index e61a45af..ac972e07 100644 --- a/tests/models/nougat.py +++ b/tests/models/nougat.py @@ -21,7 +21,9 @@ def test_nougat_default_initialization(setup_nougat): def test_nougat_custom_initialization(): - nougat = Nougat(model_name_or_path="custom_path", min_length=10, max_new_tokens=50) + nougat = Nougat( + model_name_or_path="custom_path", min_length=10, max_new_tokens=50 + ) assert nougat.model_name_or_path == "custom_path" assert nougat.min_length == 10 assert nougat.max_new_tokens == 50 @@ -98,7 +100,8 @@ def mock_processor_and_model(): with patch( "transformers.NougatProcessor.from_pretrained", return_value=Mock() ), patch( - "transformers.VisionEncoderDecoderModel.from_pretrained", return_value=Mock() + "transformers.VisionEncoderDecoderModel.from_pretrained", + return_value=Mock(), ): yield diff --git a/tests/models/revgptv1.py b/tests/models/revgptv1.py index 12ceeea0..5908b64e 100644 --- a/tests/models/revgptv1.py +++ b/tests/models/revgptv1.py @@ -19,7 +19,10 @@ class TestRevChatGPT(unittest.TestCase): self.assertLess(self.model.end_time - self.model.start_time, 60) def test_generate_summary(self): - text = "This is a sample text to summarize. It has multiple sentences and details. The summary should be concise." + text = ( + "This is a sample text to summarize. It has multiple sentences and" + " details. The summary should be concise." 
+ ) summary = self.model.generate_summary(text) self.assertLess(len(summary), len(text) / 2) @@ -60,7 +63,9 @@ class TestRevChatGPT(unittest.TestCase): convo_id = "123" title = "New Title" self.model.chatbot.change_title(convo_id, title) - self.assertEqual(self.model.chatbot.get_msg_history(convo_id)["title"], title) + self.assertEqual( + self.model.chatbot.get_msg_history(convo_id)["title"], title + ) def test_delete_conversation(self): convo_id = "123" @@ -76,7 +81,9 @@ class TestRevChatGPT(unittest.TestCase): def test_rollback_conversation(self): original_convo_id = self.model.chatbot.conversation_id self.model.chatbot.rollback_conversation(1) - self.assertNotEqual(original_convo_id, self.model.chatbot.conversation_id) + self.assertNotEqual( + original_convo_id, self.model.chatbot.conversation_id + ) if __name__ == "__main__": diff --git a/tests/models/speech_t5.py b/tests/models/speech_t5.py index 4e5f4cb1..f4d21a30 100644 --- a/tests/models/speech_t5.py +++ b/tests/models/speech_t5.py @@ -17,7 +17,9 @@ def test_speecht5_init(speecht5_model): assert isinstance(speecht5_model.processor, SpeechT5.processor.__class__) assert isinstance(speecht5_model.model, SpeechT5.model.__class__) assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__) - assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset) + assert isinstance( + speecht5_model.embeddings_dataset, torch.utils.data.Dataset + ) def test_speecht5_call(speecht5_model): @@ -59,8 +61,12 @@ def test_speecht5_set_embeddings_dataset(speecht5_model): new_dataset_name = "Matthijs/cmu-arctic-xvectors-test" speecht5_model.set_embeddings_dataset(new_dataset_name) assert speecht5_model.dataset_name == new_dataset_name - assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset) - speecht5_model.set_embeddings_dataset(old_dataset_name) # Restore original dataset + assert isinstance( + speecht5_model.embeddings_dataset, torch.utils.data.Dataset + ) + speecht5_model.set_embeddings_dataset( + old_dataset_name + ) # Restore original dataset def test_speecht5_get_sampling_rate(speecht5_model): diff --git a/tests/models/ssd_1b.py b/tests/models/ssd_1b.py index 7bd3154c..7a7a897f 100644 --- a/tests/models/ssd_1b.py +++ b/tests/models/ssd_1b.py @@ -19,18 +19,24 @@ def test_ssd1b_call(ssd1b_model): neg_prompt = "ugly, blurry, poor quality" image_url = ssd1b_model(task, neg_prompt) assert isinstance(image_url, str) - assert image_url.startswith("https://") # Assuming it starts with "https://" + assert image_url.startswith( + "https://" + ) # Assuming it starts with "https://" # Add more tests for various aspects of the class and methods # Example of a parameterized test for different tasks -@pytest.mark.parametrize("task", ["A painting of a cat", "A painting of a tree"]) +@pytest.mark.parametrize( + "task", ["A painting of a cat", "A painting of a tree"] +) def test_ssd1b_parameterized_task(ssd1b_model, task): image_url = ssd1b_model(task) assert isinstance(image_url, str) - assert image_url.startswith("https://") # Assuming it starts with "https://" + assert image_url.startswith( + "https://" + ) # Assuming it starts with "https://" # Example of a test using mocks to isolate units of code @@ -39,7 +45,9 @@ def test_ssd1b_with_mock(ssd1b_model, mocker): task = "A painting of a cat" image_url = ssd1b_model(task) assert isinstance(image_url, str) - assert image_url.startswith("https://") # Assuming it starts with "https://" + assert image_url.startswith( + "https://" + ) # Assuming it starts with 
"https://" def test_ssd1b_call_with_cache(ssd1b_model): diff --git a/tests/models/timm_model.py b/tests/models/timm_model.py index a3e62605..07f68b05 100644 --- a/tests/models/timm_model.py +++ b/tests/models/timm_model.py @@ -66,7 +66,9 @@ def test_call_parameterized(model_name, pretrained, in_chans): def test_get_supported_models_mock(): model_handler = TimmModel() - model_handler._get_supported_models = Mock(return_value=["resnet18", "resnet50"]) + model_handler._get_supported_models = Mock( + return_value=["resnet18", "resnet50"] + ) supported_models = model_handler._get_supported_models() assert supported_models == ["resnet18", "resnet50"] @@ -80,7 +82,9 @@ def test_create_model_mock(sample_model_info): def test_call_exception(): model_handler = TimmModel() - model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="invalid_model", pretrained=True, in_chans=3 + ) input_tensor = torch.randn(1, 3, 224, 224) with pytest.raises(Exception): model_handler.__call__(model_info, input_tensor) @@ -111,7 +115,9 @@ def test_environment_variable(): @pytest.mark.slow def test_marked_slow(): model_handler = TimmModel() - model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="resnet18", pretrained=True, in_chans=3 + ) input_tensor = torch.randn(1, 3, 224, 224) output_shape = model_handler(model_info, input_tensor) assert isinstance(output_shape, torch.Size) @@ -136,7 +142,9 @@ def test_marked_parameterized(model_name, pretrained, in_chans): def test_exception_testing(): model_handler = TimmModel() - model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="invalid_model", pretrained=True, in_chans=3 + ) input_tensor = torch.randn(1, 3, 224, 224) with pytest.raises(Exception): model_handler.__call__(model_info, input_tensor) @@ -144,7 +152,9 @@ def test_exception_testing(): def test_parameterized_testing(): model_handler = TimmModel() - model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="resnet18", pretrained=True, in_chans=3 + ) input_tensor = torch.randn(1, 3, 224, 224) output_shape = model_handler.__call__(model_info, input_tensor) assert isinstance(output_shape, torch.Size) @@ -153,7 +163,9 @@ def test_parameterized_testing(): def test_use_mocks_and_monkeypatching(): model_handler = TimmModel() model_handler._create_model = Mock(return_value=torch.nn.Module()) - model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="resnet18", pretrained=True, in_chans=3 + ) model = model_handler._create_model(model_info) assert isinstance(model, torch.nn.Module) diff --git a/tests/models/vilt.py b/tests/models/vilt.py index b376f41b..8dcdce88 100644 --- a/tests/models/vilt.py +++ b/tests/models/vilt.py @@ -33,7 +33,9 @@ def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): # 3. 
Test Exception Handling for network -@patch.object(requests, "get", side_effect=requests.RequestException("Network error")) +@patch.object( + requests, "get", side_effect=requests.RequestException("Network error") +) def test_vilt_network_exception(vilt_instance): with pytest.raises(requests.RequestException): vilt_instance( diff --git a/tests/models/whisperx.py b/tests/models/whisperx.py index bcbd02e9..5fad3431 100644 --- a/tests/models/whisperx.py +++ b/tests/models/whisperx.py @@ -28,7 +28,9 @@ def test_speech_to_text_install(mock_run): # Mock pytube.YouTube and pytube.Streams for download tests @patch("pytube.YouTube") @patch.object(YouTube, "streams") -def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_dir): +def test_speech_to_text_download_youtube_video( + mock_streams, mock_youtube, temp_dir +): # Mock YouTube and streams video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" mock_stream = mock_streams().filter().first() @@ -116,7 +118,11 @@ def test_speech_to_text_transcribe_whisperx_failure( @patch("whisperx.align") @patch.object(whisperx.DiarizationPipeline, "__call__") def test_speech_to_text_transcribe_missing_segments( - mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model + mock_diarization, + mock_align, + mock_align_model, + mock_load_audio, + mock_load_model, ): # Mock whisperx functions to return incomplete output mock_load_model.return_value = mock_load_model @@ -142,7 +148,11 @@ def test_speech_to_text_transcribe_missing_segments( @patch("whisperx.align") @patch.object(whisperx.DiarizationPipeline, "__call__") def test_speech_to_text_transcribe_align_failure( - mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model + mock_diarization, + mock_align, + mock_align_model, + mock_load_audio, + mock_load_model, ): # Mock whisperx functions to raise an exception during align mock_load_model.return_value = mock_load_model diff --git a/tests/models/yi_200k.py b/tests/models/yi_200k.py index 72a6d1b2..6b179ca1 100644 --- a/tests/models/yi_200k.py +++ b/tests/models/yi_200k.py @@ -39,21 +39,27 @@ def test_yi34b_generate_text_with_temperature(yi34b_model, temperature): def test_yi34b_generate_text_with_invalid_prompt(yi34b_model): prompt = None # Invalid prompt - with pytest.raises(ValueError, match="Input prompt must be a non-empty string"): + with pytest.raises( + ValueError, match="Input prompt must be a non-empty string" + ): yi34b_model(prompt) def test_yi34b_generate_text_with_invalid_max_length(yi34b_model): prompt = "There's a place where time stands still." max_length = -1 # Invalid max_length - with pytest.raises(ValueError, match="max_length must be a positive integer"): + with pytest.raises( + ValueError, match="max_length must be a positive integer" + ): yi34b_model(prompt, max_length=max_length) def test_yi34b_generate_text_with_invalid_temperature(yi34b_model): prompt = "There's a place where time stands still." temperature = 2.0 # Invalid temperature - with pytest.raises(ValueError, match="temperature must be between 0.01 and 1.0"): + with pytest.raises( + ValueError, match="temperature must be between 0.01 and 1.0" + ): yi34b_model(prompt, temperature=temperature) @@ -74,7 +80,9 @@ def test_yi34b_generate_text_with_top_p(yi34b_model, top_p): def test_yi34b_generate_text_with_invalid_top_k(yi34b_model): prompt = "There's a place where time stands still." 
top_k = -1 # Invalid top_k - with pytest.raises(ValueError, match="top_k must be a non-negative integer"): + with pytest.raises( + ValueError, match="top_k must be a non-negative integer" + ): yi34b_model(prompt, top_k=top_k) @@ -86,7 +94,9 @@ def test_yi34b_generate_text_with_invalid_top_p(yi34b_model): @pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5]) -def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, repitition_penalty): +def test_yi34b_generate_text_with_repitition_penalty( + yi34b_model, repitition_penalty +): prompt = "There's a place where time stands still." generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty) assert isinstance(generated_text, str) @@ -95,7 +105,9 @@ def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, repitition_pen def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model): prompt = "There's a place where time stands still." repitition_penalty = 0.0 # Invalid repitition_penalty - with pytest.raises(ValueError, match="repitition_penalty must be a positive float"): + with pytest.raises( + ValueError, match="repitition_penalty must be a positive float" + ): yi34b_model(prompt, repitition_penalty=repitition_penalty) diff --git a/tests/structs/flow.py b/tests/structs/flow.py index edc4b9c7..84034a08 100644 --- a/tests/structs/flow.py +++ b/tests/structs/flow.py @@ -30,7 +30,9 @@ def basic_flow(mocked_llm): @pytest.fixture def flow_with_condition(mocked_llm): - return Flow(llm=mocked_llm, max_loops=5, stopping_condition=stop_when_repeats) + return Flow( + llm=mocked_llm, max_loops=5, stopping_condition=stop_when_repeats + ) # Basic Tests @@ -61,7 +63,9 @@ def test_provide_feedback(basic_flow): @patch("time.sleep", return_value=None) # to speed up tests def test_run_without_stopping_condition(mocked_sleep, basic_flow): response = basic_flow.run("Test task") - assert response == "Test task" # since our mocked llm doesn't modify the response + assert ( + response == "Test task" + ) # since our mocked llm doesn't modify the response @patch("time.sleep", return_value=None) # to speed up tests @@ -108,7 +112,9 @@ def test_flow_with_custom_stopping_condition(mocked_llm): def stopping_condition(x): return "terminate" in x.lower() - flow = Flow(llm=mocked_llm, max_loops=5, stopping_condition=stopping_condition) + flow = Flow( + llm=mocked_llm, max_loops=5, stopping_condition=stopping_condition + ) assert flow.stopping_condition("Please terminate now") assert not flow.stopping_condition("Continue the process") @@ -174,7 +180,9 @@ def test_save_different_memory(basic_flow, tmp_path): # Test the stopping condition check def test_check_stopping_condition(flow_with_condition): assert flow_with_condition._check_stopping_condition("Stop this process") - assert not flow_with_condition._check_stopping_condition("Continue the task") + assert not flow_with_condition._check_stopping_condition( + "Continue the task" + ) # Test without providing max loops (default value should be 5) @@ -370,7 +378,8 @@ def test_flow_autosave_path(flow_instance): def test_flow_response_length(flow_instance): # Test checking the length of the response response = flow_instance.run( - "Generate a 10,000 word long blog on mental clarity and the benefits of meditation." + "Generate a 10,000 word long blog on mental clarity and the benefits of" + " meditation." 
) assert len(response) > flow_instance.get_response_length_threshold() @@ -550,7 +559,10 @@ def test_flow_rollback(flow_instance): assert flow_instance.get_user_messages() == state1["user_messages"] assert flow_instance.get_response_history() == state1["response_history"] assert flow_instance.get_conversation_log() == state1["conversation_log"] - assert flow_instance.is_dynamic_pacing_enabled() == state1["dynamic_pacing_enabled"] + assert ( + flow_instance.is_dynamic_pacing_enabled() + == state1["dynamic_pacing_enabled"] + ) assert ( flow_instance.get_response_length_threshold() == state1["response_length_threshold"] @@ -565,7 +577,9 @@ def test_flow_contextual_intent(flow_instance): # Test contextual intent handling flow_instance.add_context("location", "New York") flow_instance.add_context("time", "tomorrow") - response = flow_instance.run("What's the weather like in {location} at {time}?") + response = flow_instance.run( + "What's the weather like in {location} at {time}?" + ) assert "New York" in response assert "tomorrow" in response @@ -689,7 +703,9 @@ def test_flow_clear_injected_messages(flow_instance): def test_flow_disable_message_history(flow_instance): # Test disabling message history recording flow_instance.disable_message_history() - response = flow_instance.run("This message should not be recorded in history.") + response = flow_instance.run( + "This message should not be recorded in history." + ) assert "This message should not be recorded in history." in response assert len(flow_instance.get_message_history()) == 0 # History is empty @@ -800,16 +816,24 @@ def test_flow_input_validation(flow_instance): ) # Invalid logger type, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context(None, "value") # None key, should raise ValueError + flow_instance.add_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context("key", None) # None value, should raise ValueError + flow_instance.add_context( + "key", None + ) # None value, should raise ValueError with pytest.raises(ValueError): - flow_instance.update_context(None, "value") # None key, should raise ValueError + flow_instance.update_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.update_context("key", None) # None value, should raise ValueError + flow_instance.update_context( + "key", None + ) # None value, should raise ValueError def test_flow_conversation_reset(flow_instance): @@ -913,16 +937,24 @@ def test_flow_error_handling(flow_instance): ) # Invalid logger type, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context(None, "value") # None key, should raise ValueError + flow_instance.add_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context("key", None) # None value, should raise ValueError + flow_instance.add_context( + "key", None + ) # None value, should raise ValueError with pytest.raises(ValueError): - flow_instance.update_context(None, "value") # None key, should raise ValueError + flow_instance.update_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.update_context("key", None) # None value, should raise ValueError + flow_instance.update_context( + "key", None + ) # None value, should raise ValueError def test_flow_context_operations(flow_instance): @@ -1089,7 +1121,9 @@ def 
test_flow_agent_history_prompt(flow_instance): system_prompt = "This is the system prompt." history = ["User: Hi", "AI: Hello"] - agent_history_prompt = flow_instance.agent_history_prompt(system_prompt, history) + agent_history_prompt = flow_instance.agent_history_prompt( + system_prompt, history + ) assert "SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt assert "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt diff --git a/tests/swarms/godmode.py b/tests/swarms/godmode.py index fa9e0c13..8f528026 100644 --- a/tests/swarms/godmode.py +++ b/tests/swarms/godmode.py @@ -16,7 +16,13 @@ def test_godmode_run(monkeypatch): godmode = GodMode(llms=[LLM] * 5) responses = godmode.run("task1") assert len(responses) == 5 - assert responses == ["response", "response", "response", "response", "response"] + assert responses == [ + "response", + "response", + "response", + "response", + "response", + ] @patch("builtins.print") diff --git a/tests/swarms/groupchat.py b/tests/swarms/groupchat.py index b25e7f91..56979d52 100644 --- a/tests/swarms/groupchat.py +++ b/tests/swarms/groupchat.py @@ -165,7 +165,9 @@ def test_groupchat_select_speaker(): # Simulate selecting the next speaker last_speaker = agent1 - next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is agent2 assert next_speaker == agent2 @@ -183,7 +185,9 @@ def test_groupchat_underpopulated_group(): # Simulate selecting the next speaker in an underpopulated group last_speaker = agent1 - next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is the same as the last speaker in an underpopulated group assert next_speaker == last_speaker @@ -207,7 +211,9 @@ def test_groupchat_max_rounds(): last_speaker = next_speaker # Try one more round, should stay with the last speaker - next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is the same as the last speaker after reaching max rounds assert next_speaker == last_speaker diff --git a/tests/swarms/multi_agent_collab.py b/tests/swarms/multi_agent_collab.py index 3f7a0729..bea2c795 100644 --- a/tests/swarms/multi_agent_collab.py +++ b/tests/swarms/multi_agent_collab.py @@ -115,7 +115,10 @@ def test_repr(collaboration): def test_load(collaboration): - state = {"step": 5, "results": [{"agent": "Agent1", "response": "Response1"}]} + state = { + "step": 5, + "results": [{"agent": "Agent1", "response": "Response1"}], + } with open(collaboration.saved_file_path_name, "w") as file: json.dump(state, file) @@ -140,7 +143,8 @@ def test_save(collaboration, tmp_path): # Example of parameterized test for different selection functions @pytest.mark.parametrize( - "selection_function", [select_next_speaker_director, select_speaker_round_table] + "selection_function", + [select_next_speaker_director, select_speaker_round_table], ) def test_selection_functions(collaboration, selection_function): collaboration.select_next_speaker = selection_function diff --git a/tests/swarms/multi_agent_debate.py b/tests/swarms/multi_agent_debate.py index a2687f54..25e15ae5 100644 --- a/tests/swarms/multi_agent_debate.py +++ b/tests/swarms/multi_agent_debate.py @@ -1,5 +1,9 @@ 
from unittest.mock import patch -from swarms.swarms.multi_agent_debate import MultiAgentDebate, Worker, select_speaker +from swarms.swarms.multi_agent_debate import ( + MultiAgentDebate, + Worker, + select_speaker, +) def test_multiagentdebate_initialization(): @@ -57,5 +61,6 @@ def test_multiagentdebate_format_results(): formatted_results = multiagentdebate.format_results(results) assert ( formatted_results - == "Agent Agent 1 responded: Hello, world!\nAgent Agent 2 responded: Goodbye, world!" + == "Agent Agent 1 responded: Hello, world!\nAgent Agent 2 responded:" + " Goodbye, world!" ) diff --git a/tests/tools/base.py b/tests/tools/base.py index 4f7e2b4b..007719b2 100644 --- a/tests/tools/base.py +++ b/tests/tools/base.py @@ -75,7 +75,9 @@ def test_tool_ainvoke_with_coroutine(): async def async_function(input_data): return input_data - tool = Tool(name="test_tool", coroutine=async_function, description="Test tool") + tool = Tool( + name="test_tool", coroutine=async_function, description="Test tool" + ) result = tool.ainvoke("input_data") assert result == "input_data" @@ -560,7 +562,9 @@ class TestTool: with pytest.raises(ValueError): tool(no_doc_func) - def test_tool_raises_error_runnable_without_object_schema(self, mock_runnable): + def test_tool_raises_error_runnable_without_object_schema( + self, mock_runnable + ): with pytest.raises(ValueError): tool(mock_runnable) diff --git a/tests/utils/subprocess_code_interpreter.py b/tests/utils/subprocess_code_interpreter.py index ab7c748f..c15c0e16 100644 --- a/tests/utils/subprocess_code_interpreter.py +++ b/tests/utils/subprocess_code_interpreter.py @@ -4,7 +4,10 @@ import time import pytest -from swarms.utils.code_interpreter import BaseCodeInterpreter, SubprocessCodeInterpreter +from swarms.utils.code_interpreter import ( + BaseCodeInterpreter, + SubprocessCodeInterpreter, +) @pytest.fixture @@ -53,7 +56,9 @@ def test_subprocess_code_interpreter_run_success(subprocess_code_interpreter): assert any("Hello, World!" in output.get("output", "") for output in result) -def test_subprocess_code_interpreter_run_with_error(subprocess_code_interpreter): +def test_subprocess_code_interpreter_run_with_error( + subprocess_code_interpreter, +): code = 'print("Hello, World")\nraise ValueError("Error!")' result = list(subprocess_code_interpreter.run(code)) assert any("Error!" in output.get("output", "") for output in result) @@ -62,9 +67,14 @@ def test_subprocess_code_interpreter_run_with_error(subprocess_code_interpreter) def test_subprocess_code_interpreter_run_with_keyboard_interrupt( subprocess_code_interpreter, ): - code = 'import time\ntime.sleep(2)\nprint("Hello, World")\nraise KeyboardInterrupt' + code = ( + 'import time\ntime.sleep(2)\nprint("Hello, World")\nraise' + " KeyboardInterrupt" + ) result = list(subprocess_code_interpreter.run(code)) - assert any("KeyboardInterrupt" in output.get("output", "") for output in result) + assert any( + "KeyboardInterrupt" in output.get("output", "") for output in result + ) def test_subprocess_code_interpreter_run_max_retries( @@ -78,7 +88,8 @@ def test_subprocess_code_interpreter_run_max_retries( code = 'print("Hello, World!")' result = list(subprocess_code_interpreter.run(code)) assert any( - "Maximum retries reached. Could not execute code." in output.get("output", "") + "Maximum retries reached. Could not execute code." 
+ in output.get("output", "") for output in result ) @@ -112,19 +123,25 @@ def test_subprocess_code_interpreter_run_retry_on_error( # Import statements and fixtures from the previous code block -def test_subprocess_code_interpreter_line_postprocessor(subprocess_code_interpreter): +def test_subprocess_code_interpreter_line_postprocessor( + subprocess_code_interpreter, +): line = "This is a test line" processed_line = subprocess_code_interpreter.line_postprocessor(line) assert processed_line == line # No processing, should remain the same -def test_subprocess_code_interpreter_preprocess_code(subprocess_code_interpreter): +def test_subprocess_code_interpreter_preprocess_code( + subprocess_code_interpreter, +): code = 'print("Hello, World!")' preprocessed_code = subprocess_code_interpreter.preprocess_code(code) assert preprocessed_code == code # No preprocessing, should remain the same -def test_subprocess_code_interpreter_detect_active_line(subprocess_code_interpreter): +def test_subprocess_code_interpreter_detect_active_line( + subprocess_code_interpreter, +): line = "Active line: 5" active_line = subprocess_code_interpreter.detect_active_line(line) assert active_line == 5 @@ -172,7 +189,9 @@ def test_subprocess_code_interpreter_handle_stream_output_stdout( subprocess_code_interpreter, ): line = "This is a test line" - subprocess_code_interpreter.handle_stream_output(threading.current_thread(), False) + subprocess_code_interpreter.handle_stream_output( + threading.current_thread(), False + ) subprocess_code_interpreter.process.stdout.write(line + "\n") subprocess_code_interpreter.process.stdout.flush() time.sleep(0.1) @@ -184,7 +203,9 @@ def test_subprocess_code_interpreter_handle_stream_output_stderr( subprocess_code_interpreter, ): line = "This is an error line" - subprocess_code_interpreter.handle_stream_output(threading.current_thread(), True) + subprocess_code_interpreter.handle_stream_output( + threading.current_thread(), True + ) subprocess_code_interpreter.process.stderr.write(line + "\n") subprocess_code_interpreter.process.stderr.flush() time.sleep(0.1) @@ -207,12 +228,13 @@ def test_subprocess_code_interpreter_run_with_exception( subprocess_code_interpreter, capsys ): code = 'print("Hello, World!")' - subprocess_code_interpreter.start_cmd = ( - "nonexistent_command" # Force an exception during subprocess creation + subprocess_code_interpreter.start_cmd = ( # Force an exception during subprocess creation + "nonexistent_command" ) result = list(subprocess_code_interpreter.run(code)) assert any( - "Maximum retries reached" in output.get("output", "") for output in result + "Maximum retries reached" in output.get("output", "") + for output in result ) @@ -245,4 +267,6 @@ def test_subprocess_code_interpreter_run_with_unicode_characters( ): code = 'print("こんにちは、世界")' # Contains unicode characters result = list(subprocess_code_interpreter.run(code)) - assert any("こんにちは、世界" in output.get("output", "") for output in result) + assert any( + "こんにちは、世界" in output.get("output", "") for output in result + )