diff --git a/playground/agents/mm_agent_example.py b/playground/agents/mm_agent_example.py index 5326af6e..6cedcb29 100644 --- a/playground/agents/mm_agent_example.py +++ b/playground/agents/mm_agent_example.py @@ -4,14 +4,16 @@ load_dict = {"ImageCaptioning": "cuda"} node = MultiModalAgent(load_dict) -text = node.run_text("What is your name? Generate a picture of yourself") +text = node.run_text( + "What is your name? Generate a picture of yourself" +) img = node.run_img("/image1", "What is this image about?") chat = node.chat( ( - "What is your name? Generate a picture of yourself. What is this image" - " about?" + "What is your name? Generate a picture of yourself. What is" + " this image about?" ), streaming=True, ) diff --git a/playground/demos/accountant_team/account_team2.py b/playground/demos/accountant_team/account_team2.py index 1e4de29e..1b9d3659 100644 --- a/playground/demos/accountant_team/account_team2.py +++ b/playground/demos/accountant_team/account_team2.py @@ -53,7 +53,8 @@ decision_making_support_agent = Agent( pdf_path = "bankstatement.pdf" fraud_detection_instructions = "Detect fraud in the document" summary_agent_instructions = ( - "Generate an actionable summary of the document with action steps to take" + "Generate an actionable summary of the document with action steps" + " to take" ) decision_making_support_agent_instructions = ( "Provide decision making support to the business owner:" @@ -77,5 +78,6 @@ summary_agent_output = summary_generator_agent.run( # Provide decision making support to the accountant decision_making_support_agent_output = decision_making_support_agent.run( - f"{decision_making_support_agent_instructions}: {summary_agent_output}" + f"{decision_making_support_agent_instructions}:" + f" {summary_agent_output}" ) diff --git a/playground/demos/accountant_team/accountant_team.py b/playground/demos/accountant_team/accountant_team.py index 0c71156d..fccd628b 100644 --- a/playground/demos/accountant_team/accountant_team.py +++ b/playground/demos/accountant_team/accountant_team.py @@ -81,7 +81,9 @@ class AccountantSwarms: super().__init__() self.pdf_path = pdf_path self.list_pdfs = list_pdfs - self.fraud_detection_instructions = fraud_detection_instructions + self.fraud_detection_instructions = ( + fraud_detection_instructions + ) self.summary_agent_instructions = summary_agent_instructions self.decision_making_support_agent_instructions = ( decision_making_support_agent_instructions @@ -98,7 +100,8 @@ class AccountantSwarms: # Generate an actionable summary of the document summary_agent_output = summary_generator_agent.run( - f"{self.summary_agent_instructions}: {fraud_detection_agent_output}" + f"{self.summary_agent_instructions}:" + f" {fraud_detection_agent_output}" ) # Provide decision making support to the accountant @@ -113,7 +116,9 @@ class AccountantSwarms: swarm = AccountantSwarms( pdf_path="tesla.pdf", fraud_detection_instructions="Detect fraud in the document", - summary_agent_instructions="Generate an actionable summary of the document", + summary_agent_instructions=( + "Generate an actionable summary of the document" + ), decision_making_support_agent_instructions=( "Provide decision making support to the business owner:" ), diff --git a/playground/demos/ad_gen/ad_gen.py b/playground/demos/ad_gen/ad_gen.py index b9a555ab..3d16eb25 100644 --- a/playground/demos/ad_gen/ad_gen.py +++ b/playground/demos/ad_gen/ad_gen.py @@ -30,15 +30,34 @@ class ProductAdConceptGenerator: "in a luxurious setting", "in a playful and colorful background", "in an ice cave setting", "in a serene and calm landscape" ] +<<<<<<< HEAD +======= + self.contexts = [ + "high realism product ad (extremely creative)" + ] +>>>>>>> 831147e ([CODE QUALITY]) def generate_concept(self): theme = random.choice(self.themes) context = random.choice(self.contexts) +<<<<<<< HEAD return f"An ad for {self.product_name} that embodies a {theme} theme {context}" # User input product_name = input("Enter a product name for ad creation (e.g., 'PS5', 'AirPods', 'Kirkland Vodka'): ") social_media_platform = input("Enter a social media platform (e.g., 'Facebook', 'Twitter', 'Instagram'): ") +======= + return ( + f"{theme} inside a {style} {self.product_name}, {context}" + ) + + +# User input +product_name = input( + "Enter a product name for ad creation (e.g., 'PS5', 'AirPods'," + " 'Kirkland Vodka'): " +) +>>>>>>> 831147e ([CODE QUALITY]) # Generate creative concept concept_generator = ProductAdConceptGenerator(product_name) @@ -53,6 +72,16 @@ ad_copy_prompt = f"Write a compelling {social_media_platform} ad copy for a prod ad_copy = ad_copy_agent.run(task=ad_copy_prompt) # Output the results +<<<<<<< HEAD print("Creative Concept:", creative_concept) print("Image Path:", image_paths[0] if image_paths else "No image generated") print("Ad Copy:", ad_copy) +======= +print("Creative Concept:", concept_result) +print("Design Ideas:", design_result) +print("Ad Copy:", copywriting_result) +print( + "Image Path:", + image_paths[0] if image_paths else "No image generated", +) +>>>>>>> 831147e ([CODE QUALITY]) diff --git a/playground/demos/ai_research_team/main.py b/playground/demos/ai_research_team/main.py index c3f0ee24..bda9e0de 100644 --- a/playground/demos/ai_research_team/main.py +++ b/playground/demos/ai_research_team/main.py @@ -48,7 +48,9 @@ paper_implementor_agent = Agent( paper = pdf_to_text(PDF_PATH) algorithmic_psuedocode_agent = paper_summarizer_agent.run( - "Focus on creating the algorithmic pseudocode for the novel method in this" - f" paper: {paper}" + "Focus on creating the algorithmic pseudocode for the novel" + f" method in this paper: {paper}" +) +pytorch_code = paper_implementor_agent.run( + algorithmic_psuedocode_agent ) -pytorch_code = paper_implementor_agent.run(algorithmic_psuedocode_agent) diff --git a/playground/demos/assembly/assembly.py b/playground/demos/assembly/assembly.py index 203739b1..b82e075c 100644 --- a/playground/demos/assembly/assembly.py +++ b/playground/demos/assembly/assembly.py @@ -7,10 +7,10 @@ from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( llm = GPT4VisionAPI() task = ( - "Analyze this image of an assembly line and identify any issues such as" - " misaligned parts, defects, or deviations from the standard assembly" - " process. IF there is anything unsafe in the image, explain why it is" - " unsafe and how it could be improved." + "Analyze this image of an assembly line and identify any issues" + " such as misaligned parts, defects, or deviations from the" + " standard assembly process. IF there is anything unsafe in the" + " image, explain why it is unsafe and how it could be improved." ) img = "assembly_line.jpg" diff --git a/playground/demos/autobloggen.py b/playground/demos/autobloggen.py index 8f9e2ec2..09b02674 100644 --- a/playground/demos/autobloggen.py +++ b/playground/demos/autobloggen.py @@ -9,7 +9,8 @@ from swarms.prompts.autobloggen import ( # Prompts topic_selection_task = ( - "Generate 10 topics on gaining mental clarity using ancient practices" + "Generate 10 topics on gaining mental clarity using ancient" + " practices" ) @@ -54,7 +55,9 @@ class AutoBlogGenSwarm: ): self.llm = llm() self.topic_selection_task = topic_selection_task - self.topic_selection_agent_prompt = topic_selection_agent_prompt + self.topic_selection_agent_prompt = ( + topic_selection_agent_prompt + ) self.objective = objective self.iterations = iterations self.max_retries = max_retries @@ -90,7 +93,9 @@ class AutoBlogGenSwarm: def step(self): """Steps through the task""" - topic_selection_agent = self.llm(self.topic_selection_agent_prompt) + topic_selection_agent = self.llm( + self.topic_selection_agent_prompt + ) topic_selection_agent = self.print_beautifully( "Topic Selection Agent", topic_selection_agent ) @@ -100,7 +105,9 @@ class AutoBlogGenSwarm: # Agent that reviews the draft review_agent = self.llm(self.get_review_prompt(draft_blog)) - review_agent = self.print_beautifully("Review Agent", review_agent) + review_agent = self.print_beautifully( + "Review Agent", review_agent + ) # Agent that publishes on social media distribution_agent = self.llm( @@ -119,7 +126,11 @@ class AutoBlogGenSwarm: except Exception as error: print( colored( - f"Error while running AutoBlogGenSwarm {error}", "red" + ( + "Error while running AutoBlogGenSwarm" + f" {error}" + ), + "red", ) ) if attempt == self.retry_attempts - 1: diff --git a/playground/demos/autotemp/autotemp.py b/playground/demos/autotemp/autotemp.py index b136bad7..baf8f091 100644 --- a/playground/demos/autotemp/autotemp.py +++ b/playground/demos/autotemp/autotemp.py @@ -47,7 +47,11 @@ class AutoTemp: """ score_text = self.llm(eval_prompt, temperature=0.5) score_match = re.search(r"\b\d+(\.\d)?\b", score_text) - return round(float(score_match.group()), 1) if score_match else 0.0 + return ( + round(float(score_match.group()), 1) + if score_match + else 0.0 + ) def run(self, prompt, temperature_string): print("Starting generation process...") diff --git a/playground/demos/autotemp/blog_gen.py b/playground/demos/autotemp/blog_gen.py index 85079f70..e11a1521 100644 --- a/playground/demos/autotemp/blog_gen.py +++ b/playground/demos/autotemp/blog_gen.py @@ -12,7 +12,9 @@ class BlogGen: blog_topic, temperature_range: str = "0.4,0.6,0.8,1.0,1.2", ): # Add blog_topic as an argument - self.openai_chat = OpenAIChat(openai_api_key=api_key, temperature=0.8) + self.openai_chat = OpenAIChat( + openai_api_key=api_key, temperature=0.8 + ) self.auto_temp = AutoTemp(api_key) self.temperature_range = temperature_range self.workflow = SequentialWorkflow(max_loops=5) @@ -52,11 +54,15 @@ class BlogGen: ) chosen_topic = topic_output.split("\n")[0] - print(colored("Selected topic: " + chosen_topic, "yellow")) + print( + colored("Selected topic: " + chosen_topic, "yellow") + ) # Initial draft generation with AutoTemp - initial_draft_prompt = self.DRAFT_WRITER_SYSTEM_PROMPT.replace( - "{{CHOSEN_TOPIC}}", chosen_topic + initial_draft_prompt = ( + self.DRAFT_WRITER_SYSTEM_PROMPT.replace( + "{{CHOSEN_TOPIC}}", chosen_topic + ) ) auto_temp_output = self.auto_temp.run( initial_draft_prompt, self.temperature_range @@ -89,13 +95,17 @@ class BlogGen: ) # Distribution preparation using OpenAIChat - distribution_prompt = self.DISTRIBUTION_AGENT_SYSTEM_PROMPT.replace( - "{{ARTICLE_TOPIC}}", chosen_topic + distribution_prompt = ( + self.DISTRIBUTION_AGENT_SYSTEM_PROMPT.replace( + "{{ARTICLE_TOPIC}}", chosen_topic + ) ) distribution_result = self.openai_chat.generate( [distribution_prompt] ) - distribution_output = distribution_result.generations[0][0].text + distribution_output = distribution_result.generations[0][ + 0 + ].text print( colored( ( diff --git a/playground/demos/autotemp/blog_gen_example.py b/playground/demos/autotemp/blog_gen_example.py index 2c2f1e24..e7109b5a 100644 --- a/playground/demos/autotemp/blog_gen_example.py +++ b/playground/demos/autotemp/blog_gen_example.py @@ -5,7 +5,9 @@ from blog_gen import BlogGen def main(): api_key = os.getenv("OPENAI_API_KEY") if not api_key: - raise ValueError("OPENAI_API_KEY environment variable not set.") + raise ValueError( + "OPENAI_API_KEY environment variable not set." + ) blog_topic = input("Enter the topic for the blog generation: ") diff --git a/playground/demos/developer_swarm/main.py b/playground/demos/developer_swarm/main.py index 54170985..18c0a346 100644 --- a/playground/demos/developer_swarm/main.py +++ b/playground/demos/developer_swarm/main.py @@ -37,12 +37,18 @@ llm = OpenAIChat(openai_api_key=api_key, max_tokens=5000) # Documentation agent documentation_agent = Agent( - llm=llm, sop=DOCUMENTATION_SOP, max_loops=1, + llm=llm, + sop=DOCUMENTATION_SOP, + max_loops=1, ) # Tests agent -tests_agent = Agent(llm=llm, sop=TEST_SOP, max_loops=2,) +tests_agent = Agent( + llm=llm, + sop=TEST_SOP, + max_loops=2, +) # Run the documentation agent @@ -52,6 +58,6 @@ documentation = documentation_agent.run( # Run the tests agent tests = tests_agent.run( - f"Write tests for the following code:{TASK} here is the documentation:" - f" {documentation}" + f"Write tests for the following code:{TASK} here is the" + f" documentation: {documentation}" ) diff --git a/playground/demos/nutrition/nutrition.py b/playground/demos/nutrition/nutrition.py index 51703cfc..aca079ba 100644 --- a/playground/demos/nutrition/nutrition.py +++ b/playground/demos/nutrition/nutrition.py @@ -12,14 +12,15 @@ openai_api_key = os.getenv("OPENAI_API_KEY") # Define prompts for various tasks MEAL_PLAN_PROMPT = ( "Based on the following user preferences: dietary restrictions as" - " vegetarian, preferred cuisines as Italian and Indian, a total caloric" - " intake of around 2000 calories per day, and an exclusion of legumes," - " create a detailed weekly meal plan. Include a variety of meals for" - " breakfast, lunch, dinner, and optional snacks." + " vegetarian, preferred cuisines as Italian and Indian, a total" + " caloric intake of around 2000 calories per day, and an" + " exclusion of legumes, create a detailed weekly meal plan." + " Include a variety of meals for breakfast, lunch, dinner, and" + " optional snacks." ) IMAGE_ANALYSIS_PROMPT = ( - "Identify the items in this fridge, including their quantities and" - " condition." + "Identify the items in this fridge, including their quantities" + " and condition." ) @@ -74,12 +75,15 @@ def generate_integrated_shopping_list( meal_plan_output, image_analysis, user_preferences ): # Prepare the prompt for the LLM - fridge_contents = image_analysis["choices"][0]["message"]["content"] + fridge_contents = image_analysis["choices"][0]["message"][ + "content" + ] prompt = ( - f"Based on this meal plan: {meal_plan_output}, and the following items" - f" in the fridge: {fridge_contents}, considering dietary preferences as" - " vegetarian with a preference for Italian and Indian cuisines," - " generate a comprehensive shopping list that includes only the items" + f"Based on this meal plan: {meal_plan_output}, and the" + f" following items in the fridge: {fridge_contents}," + " considering dietary preferences as vegetarian with a" + " preference for Italian and Indian cuisines, generate a" + " comprehensive shopping list that includes only the items" " needed." ) @@ -124,6 +128,10 @@ print("Integrated Shopping List:", integrated_shopping_list) with open("nutrition_output.txt", "w") as file: file.write("Meal Plan:\n" + meal_plan_output + "\n\n") - file.write("Integrated Shopping List:\n" + integrated_shopping_list + "\n") + file.write( + "Integrated Shopping List:\n" + + integrated_shopping_list + + "\n" + ) print("Outputs have been saved to nutrition_output.txt") diff --git a/playground/demos/positive_med/positive_med.py b/playground/demos/positive_med/positive_med.py index 3c9658cf..b92b9586 100644 --- a/playground/demos/positive_med/positive_med.py +++ b/playground/demos/positive_med/positive_med.py @@ -41,7 +41,9 @@ def get_review_prompt(article): return prompt -def social_media_prompt(article: str, goal: str = "Clicks and engagement"): +def social_media_prompt( + article: str, goal: str = "Clicks and engagement" +): prompt = SOCIAL_MEDIA_SYSTEM_PROMPT_AGENT.replace( "{{ARTICLE}}", article ).replace("{{GOAL}}", goal) @@ -50,11 +52,12 @@ def social_media_prompt(article: str, goal: str = "Clicks and engagement"): # Agent that generates topics topic_selection_task = ( - "Generate 10 topics on gaining mental clarity using ancient practices" + "Generate 10 topics on gaining mental clarity using ancient" + " practices" ) topics = llm( - f"Your System Instructions: {TOPIC_GENERATOR_SYSTEM_PROMPT}, Your current" - f" task: {topic_selection_task}" + f"Your System Instructions: {TOPIC_GENERATOR_SYSTEM_PROMPT}, Your" + f" current task: {topic_selection_task}" ) dashboard = print( diff --git a/playground/demos/swarm_of_mma_manufacturing/main.py b/playground/demos/swarm_of_mma_manufacturing/main.py index e868f5a5..802647b5 100644 --- a/playground/demos/swarm_of_mma_manufacturing/main.py +++ b/playground/demos/swarm_of_mma_manufacturing/main.py @@ -24,8 +24,12 @@ api_key = os.getenv("OPENAI_API_KEY") llm = GPT4VisionAPI(openai_api_key=api_key) -assembly_line = "playground/demos/swarm_of_mma_manufacturing/assembly_line.jpg" -red_robots = "playground/demos/swarm_of_mma_manufacturing/red_robots.jpg" +assembly_line = ( + "playground/demos/swarm_of_mma_manufacturing/assembly_line.jpg" +) +red_robots = ( + "playground/demos/swarm_of_mma_manufacturing/red_robots.jpg" +) robots = "playground/demos/swarm_of_mma_manufacturing/robots.jpg" tesla_assembly_line = ( "playground/demos/swarm_of_mma_manufacturing/tesla_assembly.jpg" @@ -35,29 +39,31 @@ tesla_assembly_line = ( # Define detailed prompts for each agent tasks = { "health_safety": ( - "Analyze the factory's working environment for health safety. Focus on" - " cleanliness, ventilation, spacing between workstations, and personal" - " protective equipment availability." + "Analyze the factory's working environment for health safety." + " Focus on cleanliness, ventilation, spacing between" + " workstations, and personal protective equipment" + " availability." ), "productivity": ( - "Review the factory's workflow efficiency, machine utilization, and" - " employee engagement. Identify operational delays or bottlenecks." + "Review the factory's workflow efficiency, machine" + " utilization, and employee engagement. Identify operational" + " delays or bottlenecks." ), "safety": ( - "Analyze the factory's safety measures, including fire exits, safety" - " signage, and emergency response equipment." + "Analyze the factory's safety measures, including fire exits," + " safety signage, and emergency response equipment." ), "security": ( - "Evaluate the factory's security systems, entry/exit controls, and" - " potential vulnerabilities." + "Evaluate the factory's security systems, entry/exit" + " controls, and potential vulnerabilities." ), "sustainability": ( - "Inspect the factory's sustainability practices, including waste" - " management, energy usage, and eco-friendly processes." + "Inspect the factory's sustainability practices, including" + " waste management, energy usage, and eco-friendly processes." ), "efficiency": ( - "Assess the manufacturing process's efficiency, considering the layout," - " logistics, and automation level." + "Assess the manufacturing process's efficiency, considering" + " the layout, logistics, and automation level." ), } @@ -73,7 +79,10 @@ efficiency_prompt = tasks["efficiency"] # Health security agent health_security_agent = Agent( - llm=llm, sop_list=health_safety_prompt, max_loops=2, multi_modal=True + llm=llm, + sop_list=health_safety_prompt, + max_loops=2, + multi_modal=True, ) # Quality control agent @@ -98,10 +107,14 @@ health_check = health_security_agent.run( ) # Add the third task to the productivity_check_agent -productivity_check = productivity_check_agent.run(health_check, assembly_line) +productivity_check = productivity_check_agent.run( + health_check, assembly_line +) # Add the fourth task to the security_check_agent -security_check = security_check_agent.add(productivity_check, red_robots) +security_check = security_check_agent.add( + productivity_check, red_robots +) # Add the fifth task to the efficiency_check_agent efficiency_check = efficiency_check_agent.run( diff --git a/playground/memory/qdrant/usage.py b/playground/memory/qdrant/usage.py index e2739b33..2b7c4a8e 100644 --- a/playground/memory/qdrant/usage.py +++ b/playground/memory/qdrant/usage.py @@ -2,7 +2,8 @@ from langchain.document_loaders import CSVLoader from swarms.memory import qdrant loader = CSVLoader( - file_path="../document_parsing/aipg/aipg.csv", encoding="utf-8-sig" + file_path="../document_parsing/aipg/aipg.csv", + encoding="utf-8-sig", ) docs = loader.load() diff --git a/playground/models/bingchat.py b/playground/models/bingchat.py index bf06ecc6..2af8472c 100644 --- a/playground/models/bingchat.py +++ b/playground/models/bingchat.py @@ -24,7 +24,9 @@ llm = OpenAIChat( ) # Initialize the Worker with the custom tool -worker = Worker(llm=llm, ai_name="EdgeGPT Worker", external_tools=[edgegpt]) +worker = Worker( + llm=llm, ai_name="EdgeGPT Worker", external_tools=[edgegpt] +) # Use the worker to process a task task = "Hello, my name is ChatGPT" diff --git a/playground/models/bioclip.py b/playground/models/bioclip.py index 11fb9f27..307cf798 100644 --- a/playground/models/bioclip.py +++ b/playground/models/bioclip.py @@ -17,5 +17,8 @@ labels = [ ] result = clip("swarms.jpeg", labels) -metadata = {"filename": "images/.jpg".split("/")[-1], "top_probs": result} +metadata = { + "filename": "images/.jpg".split("/")[-1], + "top_probs": result, +} clip.plot_image_with_metadata("swarms.jpeg", metadata) diff --git a/playground/models/distilled_whiserpx.py b/playground/models/distilled_whiserpx.py index 71e1d5ef..0742a1bc 100644 --- a/playground/models/distilled_whiserpx.py +++ b/playground/models/distilled_whiserpx.py @@ -7,4 +7,6 @@ model_wrapper = DistilWhisperModel() transcription = model_wrapper("path/to/audio.mp3") # For async usage -transcription = asyncio.run(model_wrapper.async_transcribe("path/to/audio.mp3")) +transcription = asyncio.run( + model_wrapper.async_transcribe("path/to/audio.mp3") +) diff --git a/playground/models/mpt.py b/playground/models/mpt.py index bdba8754..8ffa30db 100644 --- a/playground/models/mpt.py +++ b/playground/models/mpt.py @@ -1,7 +1,9 @@ from swarms.models.mpt import MPT mpt_instance = MPT( - "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150 + "mosaicml/mpt-7b-storywriter", + "EleutherAI/gpt-neox-20b", + max_tokens=150, ) mpt_instance.generate("Once upon a time in a land far, far away...") diff --git a/playground/structs/flow_tools.py b/playground/structs/flow_tools.py index b51d18ea..5ec51f59 100644 --- a/playground/structs/flow_tools.py +++ b/playground/structs/flow_tools.py @@ -31,7 +31,9 @@ async def async_load_playwright(url: str) -> str: text = soup.get_text() lines = (line.strip() for line in text.splitlines()) chunks = ( - phrase.strip() for line in lines for phrase in line.split(" ") + phrase.strip() + for line in lines + for phrase in line.split(" ") ) results = "\n".join(chunk for chunk in chunks if chunk) except Exception as e: @@ -60,6 +62,6 @@ agent = Agent( ) out = agent.run( - "Generate a 10,000 word blog on mental clarity and the benefits of" - " meditation." + "Generate a 10,000 word blog on mental clarity and the benefits" + " of meditation." ) diff --git a/playground/structs/sequential_workflow.py b/playground/structs/sequential_workflow.py index a5ee9edb..fa7ca16a 100644 --- a/playground/structs/sequential_workflow.py +++ b/playground/structs/sequential_workflow.py @@ -18,7 +18,9 @@ flow2 = Agent(llm=llm, max_loops=1, dashboard=False) workflow = SequentialWorkflow(max_loops=1) # Add tasks to the workflow -workflow.add("Generate a 10,000 word blog on health and wellness.", flow1) +workflow.add( + "Generate a 10,000 word blog on health and wellness.", flow1 +) # Suppose the next task takes the output of the first task as input workflow.add("Summarize the generated blog", flow2) diff --git a/playground/swarms/chat.py b/playground/swarms/chat.py index b0ebc39a..08783068 100644 --- a/playground/swarms/chat.py +++ b/playground/swarms/chat.py @@ -1,7 +1,11 @@ from swarms import Orchestrator, Worker # Instantiate the Orchestrator with 10 agents -orchestrator = Orchestrator(Worker, agent_list=[Worker] * 10, task_queue=[]) +orchestrator = Orchestrator( + Worker, agent_list=[Worker] * 10, task_queue=[] +) # Agent 1 sends a message to Agent 2 -orchestrator.chat(sender_id=1, receiver_id=2, message="Hello, Agent 2!") +orchestrator.chat( + sender_id=1, receiver_id=2, message="Hello, Agent 2!" +) diff --git a/playground/swarms/debate.py b/playground/swarms/debate.py index 4c97817d..5108d527 100644 --- a/playground/swarms/debate.py +++ b/playground/swarms/debate.py @@ -37,7 +37,9 @@ class DialogueAgent: [ self.system_message, HumanMessage( - content="\n".join(self.message_history + [self.prefix]) + content="\n".join( + self.message_history + [self.prefix] + ) ), ] ) @@ -76,7 +78,9 @@ class DialogueSimulator: def step(self) -> tuple[str, str]: # 1. choose the next speaker - speaker_idx = self.select_next_speaker(self._step, self.agents) + speaker_idx = self.select_next_speaker( + self._step, self.agents + ) speaker = self.agents[speaker_idx] # 2. next speaker sends message @@ -114,7 +118,9 @@ class BiddingDialogueAgent(DialogueAgent): message_history="\n".join(self.message_history), recent_message=self.message_history[-1], ) - bid_string = self.model([SystemMessage(content=prompt)]).content + bid_string = self.model( + [SystemMessage(content=prompt)] + ).content return bid_string @@ -127,7 +133,8 @@ The presidential candidates are: {', '.join(character_names)}.""" player_descriptor_system_message = SystemMessage( content=( - "You can add detail to the description of each presidential candidate." + "You can add detail to the description of each presidential" + " candidate." ) ) @@ -156,7 +163,9 @@ Your goal is to be as creative as possible and make the voters think you are the """ -def generate_character_system_message(character_name, character_header): +def generate_character_system_message( + character_name, character_header +): return SystemMessage(content=f"""{character_header} You will speak in the style of {character_name}, and exaggerate their personality. You will come up with creative ideas related to {topic}. @@ -183,7 +192,9 @@ character_headers = [ ) ] character_system_messages = [ - generate_character_system_message(character_name, character_headers) + generate_character_system_message( + character_name, character_headers + ) for character_name, character_headers in zip( character_names, character_headers ) @@ -209,8 +220,8 @@ for ( class BidOutputParser(RegexParser): def get_format_instructions(self) -> str: return ( - "Your response should be an integer delimited by angled brackets," - " like this: ." + "Your response should be an integer delimited by angled" + " brackets, like this: ." ) @@ -262,7 +273,9 @@ topic_specifier_prompt = [ Speak directly to the presidential candidates: {*character_names,}. Do not add anything else."""), ] -specified_topic = ChatOpenAI(temperature=1.0)(topic_specifier_prompt).content +specified_topic = ChatOpenAI(temperature=1.0)( + topic_specifier_prompt +).content print(f"Original topic:\n{topic}\n") print(f"Detailed topic:\n{specified_topic}\n") @@ -273,7 +286,8 @@ print(f"Detailed topic:\n{specified_topic}\n") wait=tenacity.wait_none(), # No waiting time between retries retry=tenacity.retry_if_exception_type(ValueError), before_sleep=lambda retry_state: print( - f"ValueError occurred: {retry_state.outcome.exception()}, retrying..." + f"ValueError occurred: {retry_state.outcome.exception()}," + " retrying..." ), retry_error_callback=lambda retry_state: 0, ) # Default value when all retries are exhausted @@ -286,7 +300,9 @@ def ask_for_bid(agent) -> str: return bid -def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int: +def select_next_speaker( + step: int, agents: List[DialogueAgent] +) -> int: bids = [] for agent in agents: bid = ask_for_bid(agent) @@ -309,7 +325,9 @@ def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int: characters = [] for character_name, character_system_message, bidding_template in zip( - character_names, character_system_messages, character_bidding_templates + character_names, + character_system_messages, + character_bidding_templates, ): characters.append( BiddingDialogueAgent( diff --git a/playground/swarms/dialogue_simulator.py b/playground/swarms/dialogue_simulator.py index 76f31f65..ee9241b6 100644 --- a/playground/swarms/dialogue_simulator.py +++ b/playground/swarms/dialogue_simulator.py @@ -2,7 +2,9 @@ from swarms.swarms import DialogueSimulator from swarms.workers.worker import Worker from swarms.models import OpenAIChat -llm = OpenAIChat(model_name="gpt-4", openai_api_key="api-key", temperature=0.5) +llm = OpenAIChat( + model_name="gpt-4", openai_api_key="api-key", temperature=0.5 +) worker1 = Worker( llm=llm, diff --git a/playground/swarms/groupchat.py b/playground/swarms/groupchat.py index f47bc18b..f53257c7 100644 --- a/playground/swarms/groupchat.py +++ b/playground/swarms/groupchat.py @@ -45,5 +45,7 @@ manager = Agent( agents = [flow1, flow2, flow3] group_chat = GroupChat(agents=agents, messages=[], max_round=10) -chat_manager = GroupChatManager(groupchat=group_chat, selector=manager) +chat_manager = GroupChatManager( + groupchat=group_chat, selector=manager +) chat_history = chat_manager("Write me a riddle") diff --git a/playground/swarms/multi_agent_debate.py b/playground/swarms/multi_agent_debate.py index d5382e56..6124a21c 100644 --- a/playground/swarms/multi_agent_debate.py +++ b/playground/swarms/multi_agent_debate.py @@ -1,4 +1,7 @@ -from swarms.swarms.multi_agent_debate import MultiAgentDebate, select_speaker +from swarms.swarms.multi_agent_debate import ( + MultiAgentDebate, + select_speaker, +) from swarms.workers.worker import Worker from swarms.models import OpenAIChat @@ -37,9 +40,9 @@ debate = MultiAgentDebate(agents, select_speaker) # Run task task = ( - "What were the winning boston marathon times for the past 5 years (ending" - " in 2022)? Generate a table of the year, name, country of origin, and" - " times." + "What were the winning boston marathon times for the past 5 years" + " (ending in 2022)? Generate a table of the year, name, country" + " of origin, and times." ) results = debate.run(task, max_iters=4) diff --git a/playground/swarms/orchestrate.py b/playground/swarms/orchestrate.py index a90a72e8..b0e17588 100644 --- a/playground/swarms/orchestrate.py +++ b/playground/swarms/orchestrate.py @@ -7,9 +7,13 @@ node = Worker( # Instantiate the Orchestrator with 10 agents -orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[]) +orchestrator = Orchestrator( + node, agent_list=[node] * 10, task_queue=[] +) # Agent 7 sends a message to Agent 9 orchestrator.chat( - sender_id=7, receiver_id=9, message="Can you help me with this task?" + sender_id=7, + receiver_id=9, + message="Can you help me with this task?", ) diff --git a/playground/swarms/orchestrator.py b/playground/swarms/orchestrator.py index a90a72e8..b0e17588 100644 --- a/playground/swarms/orchestrator.py +++ b/playground/swarms/orchestrator.py @@ -7,9 +7,13 @@ node = Worker( # Instantiate the Orchestrator with 10 agents -orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[]) +orchestrator = Orchestrator( + node, agent_list=[node] * 10, task_queue=[] +) # Agent 7 sends a message to Agent 9 orchestrator.chat( - sender_id=7, receiver_id=9, message="Can you help me with this task?" + sender_id=7, + receiver_id=9, + message="Can you help me with this task?", ) diff --git a/playground/swarms/swarms_example.py b/playground/swarms/swarms_example.py index 23b714d9..9f015807 100644 --- a/playground/swarms/swarms_example.py +++ b/playground/swarms/swarms_example.py @@ -8,8 +8,8 @@ swarm = HierarchicalSwarm(api_key) # Define an objective objective = ( - "Find 20 potential customers for a HierarchicalSwarm based AI Agent" - " automation infrastructure" + "Find 20 potential customers for a HierarchicalSwarm based AI" + " Agent automation infrastructure" ) # Run HierarchicalSwarm diff --git a/pyproject.toml b/pyproject.toml index 351442f9..78e7e1a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,17 +77,17 @@ mypy-protobuf = "^3.0.0" [tool.autopep8] -max_line_length = 80 +max_line_length = 70 ignore = "E501,W6" # or ["E501", "W6"] in-place = true recursive = true aggressive = 3 [tool.ruff] -line-length = 80 +line-length = 70 [tool.black] -line-length = 80 +line-length = 70 target-version = ['py38'] preview = true diff --git a/sequential_workflow_example.py b/sequential_workflow_example.py index 9d52b3c5..beefbfb1 100644 --- a/sequential_workflow_example.py +++ b/sequential_workflow_example.py @@ -33,12 +33,16 @@ agent3 = Agent(llm=biochat, max_loops=1, dashboard=False) workflow = SequentialWorkflow(max_loops=1) # Add tasks to the workflow -workflow.add("Generate a 10,000 word blog on health and wellness.", agent1) +workflow.add( + "Generate a 10,000 word blog on health and wellness.", agent1 +) # Suppose the next task takes the output of the first task as input workflow.add("Summarize the generated blog", agent2) -workflow.add("Create a references sheet of materials for the curriculm", agent3) +workflow.add( + "Create a references sheet of materials for the curriculm", agent3 +) # Run the workflow workflow.run() diff --git a/swarms/memory/base.py b/swarms/memory/base.py index f28da852..c0e8e4b6 100644 --- a/swarms/memory/base.py +++ b/swarms/memory/base.py @@ -28,7 +28,8 @@ class BaseVectorStore(ABC): embedding_driver: Any futures_executor: futures.Executor = field( - default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True + default=Factory(lambda: futures.ThreadPoolExecutor()), + kw_only=True, ) def upsert_text_artifacts( @@ -40,7 +41,11 @@ class BaseVectorStore(ABC): execute_futures_dict( { namespace: self.futures_executor.submit( - self.upsert_text_artifact, a, namespace, meta, **kwargs + self.upsert_text_artifact, + a, + namespace, + meta, + **kwargs, ) for namespace, artifact_list in artifacts.items() for a in artifact_list @@ -62,7 +67,9 @@ class BaseVectorStore(ABC): if artifact.embedding: vector = artifact.embedding else: - vector = artifact.generate_embedding(self.embedding_driver) + vector = artifact.generate_embedding( + self.embedding_driver + ) return self.upsert_vector( vector, @@ -106,7 +113,9 @@ class BaseVectorStore(ABC): ... @abstractmethod - def load_entries(self, namespace: Optional[str] = None) -> list[Entry]: + def load_entries( + self, namespace: Optional[str] = None + ) -> list[Entry]: ... @abstractmethod diff --git a/swarms/memory/chroma.py b/swarms/memory/chroma.py index 2f4e473f..79b92964 100644 --- a/swarms/memory/chroma.py +++ b/swarms/memory/chroma.py @@ -35,11 +35,18 @@ def _results_to_docs(results: Any) -> List[Document]: return [doc for doc, _ in _results_to_docs_and_scores(results)] -def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]: +def _results_to_docs_and_scores( + results: Any, +) -> List[Tuple[Document, float]]: return [ # TODO: Chroma can do batch querying, # we shouldn't hard code to the 1st result - (Document(page_content=result[0], metadata=result[1] or {}), result[2]) + ( + Document( + page_content=result[0], metadata=result[1] or {} + ), + result[2], + ) for result in zip( results["documents"][0], results["metadatas"][0], @@ -94,13 +101,16 @@ class Chroma(VectorStore): # If client_settings is provided with persist_directory specified, # then it is "in-memory and persisting to disk" mode. client_settings.persist_directory = ( - persist_directory or client_settings.persist_directory + persist_directory + or client_settings.persist_directory ) if client_settings.persist_directory is not None: # Maintain backwards compatibility with chromadb < 0.4.0 major, minor, _ = chromadb.__version__.split(".") if int(major) == 0 and int(minor) < 4: - client_settings.chroma_db_impl = "duckdb+parquet" + client_settings.chroma_db_impl = ( + "duckdb+parquet" + ) _client_settings = client_settings elif persist_directory: @@ -120,7 +130,8 @@ class Chroma(VectorStore): self._client_settings = _client_settings self._client = chromadb.Client(_client_settings) self._persist_directory = ( - _client_settings.persist_directory or persist_directory + _client_settings.persist_directory + or persist_directory ) self._embedding_function = embedding_function @@ -189,7 +200,9 @@ class Chroma(VectorStore): embeddings = None texts = list(texts) if self._embedding_function is not None: - embeddings = self._embedding_function.embed_documents(texts) + embeddings = self._embedding_function.embed_documents( + texts + ) if metadatas: # fill metadatas with empty dicts if somebody # did not specify metadata for all texts @@ -205,13 +218,17 @@ class Chroma(VectorStore): empty_ids.append(idx) if non_empty_ids: metadatas = [metadatas[idx] for idx in non_empty_ids] - texts_with_metadatas = [texts[idx] for idx in non_empty_ids] + texts_with_metadatas = [ + texts[idx] for idx in non_empty_ids + ] embeddings_with_metadatas = ( [embeddings[idx] for idx in non_empty_ids] if embeddings else None ) - ids_with_metadata = [ids[idx] for idx in non_empty_ids] + ids_with_metadata = [ + ids[idx] for idx in non_empty_ids + ] try: self._collection.upsert( metadatas=metadatas, @@ -222,7 +239,8 @@ class Chroma(VectorStore): except ValueError as e: if "Expected metadata value to be" in str(e): msg = ( - "Try filtering complex metadata from the document" + "Try filtering complex metadata from the" + " document" " using " "langchain.vectorstores.utils.filter_complex_metadata." ) @@ -230,9 +248,13 @@ class Chroma(VectorStore): else: raise e if empty_ids: - texts_without_metadatas = [texts[j] for j in empty_ids] + texts_without_metadatas = [ + texts[j] for j in empty_ids + ] embeddings_without_metadatas = ( - [embeddings[j] for j in empty_ids] if embeddings else None + [embeddings[j] for j in empty_ids] + if embeddings + else None ) ids_without_metadatas = [ids[j] for j in empty_ids] self._collection.upsert( @@ -351,7 +373,9 @@ class Chroma(VectorStore): where_document=where_document, ) else: - query_embedding = self._embedding_function.embed_query(query) + query_embedding = self._embedding_function.embed_query( + query + ) results = self.__query_collection( query_embeddings=[query_embedding], n_results=k, @@ -388,9 +412,9 @@ class Chroma(VectorStore): return self._max_inner_product_relevance_score_fn else: raise ValueError( - "No supported normalization function" - f" for distance metric of type: {distance}." - "Consider providing relevance_score_fn to Chroma constructor." + "No supported normalization function for distance" + f" metric of type: {distance}.Consider providing" + " relevance_score_fn to Chroma constructor." ) def max_marginal_relevance_search_by_vector( @@ -426,7 +450,12 @@ class Chroma(VectorStore): n_results=fetch_k, where=filter, where_document=where_document, - include=["metadatas", "documents", "distances", "embeddings"], + include=[ + "metadatas", + "documents", + "distances", + "embeddings", + ], ) mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), @@ -471,8 +500,8 @@ class Chroma(VectorStore): """ if self._embedding_function is None: raise ValueError( - "For MMR search, you must specify an embedding function" - " oncreation." + "For MMR search, you must specify an embedding" + " function oncreation." ) embedding = self._embedding_function.embed_query(query) @@ -546,7 +575,9 @@ class Chroma(VectorStore): if int(major) == 0 and int(minor) < 4: self._client.persist() - def update_document(self, document_id: str, document: Document) -> None: + def update_document( + self, document_id: str, document: Document + ) -> None: """Update a document in the collection. Args: @@ -568,8 +599,8 @@ class Chroma(VectorStore): metadata = [document.metadata for document in documents] if self._embedding_function is None: raise ValueError( - "For update, you must specify an embedding function on" - " creation." + "For update, you must specify an embedding function" + " on creation." ) embeddings = self._embedding_function.embed_documents(text) @@ -711,7 +742,9 @@ class Chroma(VectorStore): **kwargs, ) - def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: + def delete( + self, ids: Optional[List[str]] = None, **kwargs: Any + ) -> None: """Delete by vector IDs. Args: diff --git a/swarms/memory/cosine_similarity.py b/swarms/memory/cosine_similarity.py index cdcd1a2b..6e7b1df3 100644 --- a/swarms/memory/cosine_similarity.py +++ b/swarms/memory/cosine_similarity.py @@ -18,8 +18,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: Y = np.array(Y) if X.shape[1] != Y.shape[1]: raise ValueError( - "Number of columns in X and Y must be the same. X has shape" - f" {X.shape} and Y has shape {Y.shape}." + "Number of columns in X and Y must be the same. X has" + f" shape {X.shape} and Y has shape {Y.shape}." ) try: import simsimd as simd @@ -32,9 +32,9 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: return Z except ImportError: logger.info( - "Unable to import simsimd, defaulting to NumPy implementation. If" - " you want to use simsimd please install with `pip install" - " simsimd`." + "Unable to import simsimd, defaulting to NumPy" + " implementation. If you want to use simsimd please" + " install with `pip install simsimd`." ) X_norm = np.linalg.norm(X, axis=1) Y_norm = np.linalg.norm(Y, axis=1) @@ -68,9 +68,15 @@ def cosine_similarity_top_k( score_array = cosine_similarity(X, Y) score_threshold = score_threshold or -1.0 score_array[score_array < score_threshold] = 0 - top_k = min(top_k or len(score_array), np.count_nonzero(score_array)) - top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:] - top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1] + top_k = min( + top_k or len(score_array), np.count_nonzero(score_array) + ) + top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[ + -top_k: + ] + top_k_idxs = top_k_idxs[ + np.argsort(score_array.ravel()[top_k_idxs]) + ][::-1] ret_idxs = np.unravel_index(top_k_idxs, score_array.shape) scores = score_array.ravel()[top_k_idxs].tolist() return list(zip(*ret_idxs)), scores # type: ignore diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index ce591c6e..334ccf70 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -84,7 +84,9 @@ class PgVectorVectorStore(BaseVectorStore): """ - connection_string: Optional[str] = field(default=None, kw_only=True) + connection_string: Optional[str] = field( + default=None, kw_only=True + ) create_engine_params: dict = field(factory=dict, kw_only=True) engine: Optional[Engine] = field(default=None, kw_only=True) table_name: str = field(kw_only=True) @@ -104,12 +106,14 @@ class PgVectorVectorStore(BaseVectorStore): # If an engine is not provided, a connection string is required. if connection_string is None: - raise ValueError("An engine or connection string is required") + raise ValueError( + "An engine or connection string is required" + ) if not connection_string.startswith("postgresql://"): raise ValueError( - "The connection string must describe a Postgres database" - " connection" + "The connection string must describe a Postgres" + " database connection" ) @engine.validator @@ -120,7 +124,9 @@ class PgVectorVectorStore(BaseVectorStore): # If a connection string is not provided, an engine is required. if engine is None: - raise ValueError("An engine or connection string is required") + raise ValueError( + "An engine or connection string is required" + ) def __attrs_post_init__(self) -> None: """If a an engine is provided, it will be used to connect to the database. @@ -139,10 +145,14 @@ class PgVectorVectorStore(BaseVectorStore): ) -> None: """Provides a mechanism to initialize the database schema and extensions.""" if install_uuid_extension: - self.engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + self.engine.execute( + 'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";' + ) if install_vector_extension: - self.engine.execute('CREATE EXTENSION IF NOT EXISTS "vector";') + self.engine.execute( + 'CREATE EXTENSION IF NOT EXISTS "vector";' + ) if create_schema: self._model.metadata.create_all(self.engine) @@ -246,7 +256,9 @@ class PgVectorVectorStore(BaseVectorStore): return [ BaseVectorStore.QueryResult( id=str(result[0].id), - vector=result[0].vector if include_vectors else None, + vector=( + result[0].vector if include_vectors else None + ), score=result[1], meta=result[0].meta, namespace=result[0].namespace, diff --git a/swarms/memory/pinecone.py b/swarms/memory/pinecone.py index 9065d661..308273d9 100644 --- a/swarms/memory/pinecone.py +++ b/swarms/memory/pinecone.py @@ -111,7 +111,9 @@ class PineconeVectorStoreStore(BaseVectorStore): **kwargs, ) -> str: """Upsert vector""" - vector_id = vector_id if vector_id else str_to_hash(str(vector)) + vector_id = ( + vector_id if vector_id else str_to_hash(str(vector)) + ) params = {"namespace": namespace} | kwargs @@ -179,7 +181,11 @@ class PineconeVectorStoreStore(BaseVectorStore): vector = self.embedding_driver.embed_string(query) params = { - "top_k": count if count else BaseVectorStore.DEFAULT_QUERY_COUNT, + "top_k": ( + count + if count + else BaseVectorStore.DEFAULT_QUERY_COUNT + ), "namespace": namespace, "include_values": include_vectors, "include_metadata": include_metadata, diff --git a/swarms/memory/qdrant.py b/swarms/memory/qdrant.py index acac82f5..76a5785b 100644 --- a/swarms/memory/qdrant.py +++ b/swarms/memory/qdrant.py @@ -2,7 +2,11 @@ from typing import List from sentence_transformers import SentenceTransformer from httpx import RequestError from qdrant_client import QdrantClient -from qdrant_client.http.models import Distance, VectorParams, PointStruct +from qdrant_client.http.models import ( + Distance, + VectorParams, + PointStruct, +) class Qdrant: @@ -33,7 +37,9 @@ class Qdrant: https: bool = True, ): try: - self.client = QdrantClient(url=host, port=port, api_key=api_key) + self.client = QdrantClient( + url=host, port=port, api_key=api_key + ) self.collection_name = collection_name self._load_embedding_model(model_name) self._setup_collection() @@ -56,7 +62,10 @@ class Qdrant: try: exists = self.client.get_collection(self.collection_name) if exists: - print(f"Collection '{self.collection_name}' already exists.") + print( + f"Collection '{self.collection_name}' already" + " exists." + ) except Exception as e: self.client.create_collection( collection_name=self.collection_name, @@ -93,7 +102,8 @@ class Qdrant: ) else: print( - f"Document at index {i} is missing 'page_content' key" + f"Document at index {i} is missing" + " 'page_content' key" ) except Exception as e: print(f"Error processing document at index {i}: {e}") @@ -121,7 +131,9 @@ class Qdrant: SearchResult or None: Returns the search results if successful, otherwise None. """ try: - query_vector = self.model.encode(query, normalize_embeddings=True) + query_vector = self.model.encode( + query, normalize_embeddings=True + ) search_result = self.client.search( collection_name=self.collection_name, query_vector=query_vector, diff --git a/swarms/memory/schemas.py b/swarms/memory/schemas.py index 89f1453b..9147a909 100644 --- a/swarms/memory/schemas.py +++ b/swarms/memory/schemas.py @@ -9,7 +9,9 @@ from pydantic import BaseModel, Field class TaskInput(BaseModel): __root__: Any = Field( ..., - description="The input parameters for the task. Any value is allowed.", + description=( + "The input parameters for the task. Any value is allowed." + ), example='{\n"debug": false,\n"mode": "benchmarks"\n}', ) @@ -25,7 +27,9 @@ class Artifact(BaseModel): ) relative_path: Optional[str] = Field( None, - description="Relative path of the artifact in the agent's workspace", + description=( + "Relative path of the artifact in the agent's workspace" + ), example="python/code/", ) @@ -34,7 +38,9 @@ class ArtifactUpload(BaseModel): file: bytes = Field(..., description="File to upload") relative_path: Optional[str] = Field( None, - description="Relative path of the artifact in the agent's workspace", + description=( + "Relative path of the artifact in the agent's workspace" + ), example="python/code/", ) @@ -42,7 +48,10 @@ class ArtifactUpload(BaseModel): class StepInput(BaseModel): __root__: Any = Field( ..., - description="Input parameters for the task step. Any value is allowed.", + description=( + "Input parameters for the task step. Any value is" + " allowed." + ), example='{\n"file_to_refactor": "models.py"\n}', ) @@ -51,7 +60,8 @@ class StepOutput(BaseModel): __root__: Any = Field( ..., description=( - "Output that the task step has produced. Any value is allowed." + "Output that the task step has produced. Any value is" + " allowed." ), example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', ) @@ -61,7 +71,9 @@ class TaskRequestBody(BaseModel): input: Optional[str] = Field( None, description="Input prompt for the task.", - example="Write the words you receive to the file 'output.txt'.", + example=( + "Write the words you receive to the file 'output.txt'." + ), ) additional_input: Optional[TaskInput] = None @@ -84,7 +96,9 @@ class Task(TaskRequestBody): class StepRequestBody(BaseModel): input: Optional[str] = Field( - None, description="Input prompt for the step.", example="Washington" + None, + description="Input prompt for the step.", + example="Washington", ) additional_input: Optional[StepInput] = None @@ -107,22 +121,28 @@ class Step(StepRequestBody): example="6bb1801a-fd80-45e8-899a-4dd723cc602e", ) name: Optional[str] = Field( - None, description="The name of the task step.", example="Write to file" + None, + description="The name of the task step.", + example="Write to file", + ) + status: Status = Field( + ..., description="The status of the task step." ) - status: Status = Field(..., description="The status of the task step.") output: Optional[str] = Field( None, description="Output of the task step.", example=( - "I am going to use the write_to_file command and write Washington" - " to a file called output.txt best_score: best_score = equation_score idx_to_add = i idxs.append(idx_to_add) - selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) + selected = np.append( + selected, [embedding_list[idx_to_add]], axis=0 + ) return idxs diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 33870b31..e6089b35 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -10,14 +10,18 @@ from swarms.models.openai_models import ( from swarms.models.zephyr import Zephyr # noqa: E402 from swarms.models.biogpt import BioGPT # noqa: E402 from swarms.models.huggingface import HuggingfaceLLM # noqa: E402 -from swarms.models.wizard_storytelling import WizardLLMStoryTeller # noqa: E402 +from swarms.models.wizard_storytelling import ( + WizardLLMStoryTeller, +) # noqa: E402 from swarms.models.mpt import MPT7B # noqa: E402 # MultiModal Models from swarms.models.idefics import Idefics # noqa: E402 from swarms.models.vilt import Vilt # noqa: E402 from swarms.models.nougat import Nougat # noqa: E402 -from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # noqa: E402 +from swarms.models.layoutlm_document_qa import ( + LayoutLMDocumentQA, +) # noqa: E402 from swarms.models.gpt4_vision_api import GPT4VisionAPI # noqa: E402 # from swarms.models.gpt4v import GPT4Vision diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py index 1f47e1bf..adffe49d 100644 --- a/swarms/models/anthropic.py +++ b/swarms/models/anthropic.py @@ -45,10 +45,16 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: def wrapper(*args: Any, **kwargs: Any) -> Any: """Validate exactly one arg in each group is not None.""" counts = [ - sum(1 for arg in arg_group if kwargs.get(arg) is not None) + sum( + 1 + for arg in arg_group + if kwargs.get(arg) is not None + ) for arg_group in arg_groups ] - invalid_groups = [i for i, count in enumerate(counts) if count != 1] + invalid_groups = [ + i for i, count in enumerate(counts) if count != 1 + ] if invalid_groups: invalid_group_names = [ ", ".join(arg_groups[i]) for i in invalid_groups @@ -119,8 +125,9 @@ def guard_import( module = importlib.import_module(module_name, package) except ImportError: raise ImportError( - f"Could not import {module_name} python package. " - f"Please install it with `pip install {pip_name or module_name}`." + f"Could not import {module_name} python package. Please" + " install it with `pip install" + f" {pip_name or module_name}`." ) return module @@ -134,25 +141,33 @@ def check_package_version( ) -> None: """Check the version of a package.""" imported_version = parse(version(package)) - if lt_version is not None and imported_version >= parse(lt_version): + if lt_version is not None and imported_version >= parse( + lt_version + ): raise ValueError( - f"Expected {package} version to be < {lt_version}. Received " - f"{imported_version}." + f"Expected {package} version to be < {lt_version}." + f" Received {imported_version}." ) - if lte_version is not None and imported_version > parse(lte_version): + if lte_version is not None and imported_version > parse( + lte_version + ): raise ValueError( - f"Expected {package} version to be <= {lte_version}. Received " - f"{imported_version}." + f"Expected {package} version to be <= {lte_version}." + f" Received {imported_version}." ) - if gt_version is not None and imported_version <= parse(gt_version): + if gt_version is not None and imported_version <= parse( + gt_version + ): raise ValueError( - f"Expected {package} version to be > {gt_version}. Received " - f"{imported_version}." + f"Expected {package} version to be > {gt_version}." + f" Received {imported_version}." ) - if gte_version is not None and imported_version < parse(gte_version): + if gte_version is not None and imported_version < parse( + gte_version + ): raise ValueError( - f"Expected {package} version to be >= {gte_version}. Received " - f"{imported_version}." + f"Expected {package} version to be >= {gte_version}." + f" Received {imported_version}." ) @@ -185,9 +200,11 @@ def build_extra_kwargs( if field_name in extra_kwargs: raise ValueError(f"Found {field_name} supplied twice.") if field_name not in all_required_field_names: - warnings.warn(f"""WARNING! {field_name} is not default parameter. + warnings.warn( + f"""WARNING! {field_name} is not default parameter. {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""") + Please confirm that {field_name} is what you intended.""" + ) extra_kwargs[field_name] = values.pop(field_name) invalid_model_kwargs = all_required_field_names.intersection( @@ -195,8 +212,9 @@ def build_extra_kwargs( ) if invalid_model_kwargs: raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly." - " Instead they were passed in as part of `model_kwargs` parameter." + f"Parameters {invalid_model_kwargs} should be specified" + " explicitly. Instead they were passed in as part of" + " `model_kwargs` parameter." ) return extra_kwargs @@ -273,12 +291,16 @@ class _AnthropicCommon(BaseLanguageModel): check_package_version("anthropic", gte_version="0.3") values["client"] = anthropic.Anthropic( base_url=values["anthropic_api_url"], - api_key=values["anthropic_api_key"].get_secret_value(), + api_key=values[ + "anthropic_api_key" + ].get_secret_value(), timeout=values["default_request_timeout"], ) values["async_client"] = anthropic.AsyncAnthropic( base_url=values["anthropic_api_url"], - api_key=values["anthropic_api_key"].get_secret_value(), + api_key=values[ + "anthropic_api_key" + ].get_secret_value(), timeout=values["default_request_timeout"], ) values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT @@ -316,7 +338,9 @@ class _AnthropicCommon(BaseLanguageModel): self, stop: Optional[List[str]] = None ) -> List[str]: if not self.HUMAN_PROMPT or not self.AI_PROMPT: - raise NameError("Please ensure the anthropic package is loaded") + raise NameError( + "Please ensure the anthropic package is loaded" + ) if stop is None: stop = [] @@ -375,7 +399,9 @@ class Anthropic(LLM, _AnthropicCommon): def _wrap_prompt(self, prompt: str) -> str: if not self.HUMAN_PROMPT or not self.AI_PROMPT: - raise NameError("Please ensure the anthropic package is loaded") + raise NameError( + "Please ensure the anthropic package is loaded" + ) if prompt.startswith(self.HUMAN_PROMPT): return prompt # Already wrapped. @@ -389,7 +415,8 @@ class Anthropic(LLM, _AnthropicCommon): # As a last resort, wrap the prompt ourselves to emulate instruct-style. return ( - f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" + f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here" + " you go:\n" ) def _call( @@ -419,7 +446,10 @@ class Anthropic(LLM, _AnthropicCommon): if self.streaming: completion = "" for chunk in self._stream( - prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + prompt=prompt, + stop=stop, + run_manager=run_manager, + **kwargs, ): completion += chunk.text return completion @@ -447,7 +477,10 @@ class Anthropic(LLM, _AnthropicCommon): if self.streaming: completion = "" async for chunk in self._astream( - prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + prompt=prompt, + stop=stop, + run_manager=run_manager, + **kwargs, ): completion += chunk.text return completion @@ -533,10 +566,14 @@ class Anthropic(LLM, _AnthropicCommon): chunk = GenerationChunk(text=token.completion) yield chunk if run_manager: - await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + await run_manager.on_llm_new_token( + chunk.text, chunk=chunk + ) def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" if not self.count_tokens: - raise NameError("Please ensure the anthropic package is loaded") + raise NameError( + "Please ensure the anthropic package is loaded" + ) return self.count_tokens(text) diff --git a/swarms/models/base.py b/swarms/models/base.py index 4e92ae45..eacbc1cf 100644 --- a/swarms/models/base.py +++ b/swarms/models/base.py @@ -27,7 +27,9 @@ class AbstractModel(ABC): def chat(self, task: str, history: str = "") -> str: """Chat with the model""" - complete_task = task + " | " + history # Delimiter for clarity + complete_task = ( + task + " | " + history + ) # Delimiter for clarity return self.run(complete_task) def __call__(self, task: str) -> str: diff --git a/swarms/models/base_multimodal_model.py b/swarms/models/base_multimodal_model.py index a773b12f..34c1b4b6 100644 --- a/swarms/models/base_multimodal_model.py +++ b/swarms/models/base_multimodal_model.py @@ -107,7 +107,9 @@ class BaseMultiModalModel: image_pil = Image.open(BytesIO(response.content)) return image_pil except requests.RequestException as error: - print(f"Error fetching image from {img} and error: {error}") + print( + f"Error fetching image from {img} and error: {error}" + ) return None def encode_img(self, img: str): @@ -142,14 +144,18 @@ class BaseMultiModalModel: """ # Instantiate the thread pool executor - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with ThreadPoolExecutor( + max_workers=self.max_workers + ) as executor: results = executor.map(self.run, tasks, imgs) # Print the results for debugging for result in results: print(result) - def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]: + def run_batch( + self, tasks_images: List[Tuple[str, str]] + ) -> List[str]: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ @@ -176,7 +182,9 @@ class BaseMultiModalModel: """Process a batch of tasks and images asynchronously with retries""" loop = asyncio.get_event_loop() futures = [ - loop.run_in_executor(None, self.run_with_retries, task, img) + loop.run_in_executor( + None, self.run_with_retries, task, img + ) for task, img in tasks_images ] return await asyncio.gather(*futures) @@ -194,7 +202,9 @@ class BaseMultiModalModel: print(f"Error with the request {error}") continue - def run_batch_with_retries(self, tasks_images: List[Tuple[str, str]]): + def run_batch_with_retries( + self, tasks_images: List[Tuple[str, str]] + ): """Run the model with retries""" for i in range(self.retries): try: diff --git a/swarms/models/bioclip.py b/swarms/models/bioclip.py index 1c2627a6..e2d070af 100644 --- a/swarms/models/bioclip.py +++ b/swarms/models/bioclip.py @@ -112,11 +112,12 @@ class BioClip: template: str = "this is a photo of ", context_length: int = 256, ): - image = torch.stack([self.preprocess_val(Image.open(img_path))]).to( - self.device - ) + image = torch.stack( + [self.preprocess_val(Image.open(img_path))] + ).to(self.device) texts = self.tokenizer( - [template + l for l in labels], context_length=context_length + [template + l for l in labels], + context_length=context_length, ).to(self.device) with torch.no_grad(): @@ -128,7 +129,9 @@ class BioClip: .detach() .softmax(dim=-1) ) - sorted_indices = torch.argsort(logits, dim=-1, descending=True) + sorted_indices = torch.argsort( + logits, dim=-1, descending=True + ) logits = logits.cpu().numpy() sorted_indices = sorted_indices.cpu().numpy() @@ -149,7 +152,10 @@ class BioClip: metadata["filename"] + "\n" + "\n".join( - [f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()] + [ + f"{k}: {v*100:.1f}" + for k, v in metadata["top_probs"].items() + ] ) ) ax.set_title(title, fontsize=14) diff --git a/swarms/models/biogpt.py b/swarms/models/biogpt.py index d5e692f2..9ee5b513 100644 --- a/swarms/models/biogpt.py +++ b/swarms/models/biogpt.py @@ -34,7 +34,12 @@ advantage of BioGPT on biomedical literature to generate fluent descriptions for """ import torch -from transformers import pipeline, set_seed, BioGptTokenizer, BioGptForCausalLM +from transformers import ( + pipeline, + set_seed, + BioGptTokenizer, + BioGptForCausalLM, +) class BioGPT: @@ -85,8 +90,12 @@ class BioGPT: self.do_sample = do_sample self.min_length = min_length - self.model = BioGptForCausalLM.from_pretrained(self.model_name) - self.tokenizer = BioGptTokenizer.from_pretrained(self.model_name) + self.model = BioGptForCausalLM.from_pretrained( + self.model_name + ) + self.tokenizer = BioGptTokenizer.from_pretrained( + self.model_name + ) def __call__(self, text: str): """ @@ -103,7 +112,9 @@ class BioGPT: """ set_seed(42) generator = pipeline( - "text-generation", model=self.model, tokenizer=self.tokenizer + "text-generation", + model=self.model, + tokenizer=self.tokenizer, ) out = generator( text, @@ -156,7 +167,9 @@ class BioGPT: num_beams=num_beams, early_stopping=early_stopping, ) - return self.tokenizer.decode(beam_output[0], skip_special_tokens=True) + return self.tokenizer.decode( + beam_output[0], skip_special_tokens=True + ) # Feature 1: Set a new tokenizer and model def set_pretrained_model(self, model_name): @@ -167,8 +180,12 @@ class BioGPT: model_name (str): Name of the pretrained model. """ self.model_name = model_name - self.model = BioGptForCausalLM.from_pretrained(self.model_name) - self.tokenizer = BioGptTokenizer.from_pretrained(self.model_name) + self.model = BioGptForCausalLM.from_pretrained( + self.model_name + ) + self.tokenizer = BioGptTokenizer.from_pretrained( + self.model_name + ) # Feature 2: Get the model's config details def get_config(self): diff --git a/swarms/models/cohere_chat.py b/swarms/models/cohere_chat.py index 508e9073..1a31d82e 100644 --- a/swarms/models/cohere_chat.py +++ b/swarms/models/cohere_chat.py @@ -32,7 +32,9 @@ def _create_retry_decorator(llm) -> Callable[[Any], Any]: return retry( reraise=True, stop=stop_after_attempt(llm.max_retries), - wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + wait=wait_exponential( + multiplier=1, min=min_seconds, max=max_seconds + ), retry=(retry_if_exception_type(cohere.error.CohereError)), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -65,7 +67,9 @@ class BaseCohere(Serializable): client: Any #: :meta private: async_client: Any #: :meta private: - model: Optional[str] = Field(default=None, description="Model name to use.") + model: Optional[str] = Field( + default=None, description="Model name to use." + ) """Model name to use.""" temperature: float = 0.75 diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index 3c130670..40f63418 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -116,7 +116,9 @@ class Dalle3: byte_array = byte_stream.getvalue() return byte_array - @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds) + @backoff.on_exception( + backoff.expo, Exception, max_time=max_time_seconds + ) def __call__(self, task: str): """ Text to image conversion using the Dalle3 API @@ -169,8 +171,8 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing your api" - " key and or try again" + f"Error running Dalle3: {error} try" + " optimizing your api key and or try again" ), "red", ) @@ -198,7 +200,9 @@ class Dalle3: with open(full_path, "wb") as file: file.write(response.content) else: - raise ValueError(f"Failed to download image from {img_url}") + raise ValueError( + f"Failed to download image from {img_url}" + ) def create_variations(self, img: str): """ @@ -234,14 +238,21 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing your api" - " key and or try again" + f"Error running Dalle3: {error} try" + " optimizing your api key and or try again" ), "red", ) ) - print(colored(f"Error running Dalle3: {error.http_status}", "red")) - print(colored(f"Error running Dalle3: {error.error}", "red")) + print( + colored( + f"Error running Dalle3: {error.http_status}", + "red", + ) + ) + print( + colored(f"Error running Dalle3: {error.error}", "red") + ) raise error def print_dashboard(self): @@ -300,7 +311,9 @@ class Dalle3: executor.submit(self, task): task for task in tasks } results = [] - for future in concurrent.futures.as_completed(future_to_task): + for future in concurrent.futures.as_completed( + future_to_task + ): task = future_to_task[future] try: img = future.result() @@ -311,19 +324,27 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing" - " your api key and or try again" + f"Error running Dalle3: {error} try" + " optimizing your api key and or try" + " again" ), "red", ) ) print( colored( - f"Error running Dalle3: {error.http_status}", "red" + ( + "Error running Dalle3:" + f" {error.http_status}" + ), + "red", ) ) print( - colored(f"Error running Dalle3: {error.error}", "red") + colored( + f"Error running Dalle3: {error.error}", + "red", + ) ) raise error @@ -339,7 +360,9 @@ class Dalle3: """Str method for the Dalle3 class""" return f"Dalle3(image_url={self.image_url})" - @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries) + @backoff.on_exception( + backoff.expo, Exception, max_tries=max_retries + ) def rate_limited_call(self, task: str): """Rate limited call to the Dalle3 API""" return self.__call__(task) diff --git a/swarms/models/distilled_whisperx.py b/swarms/models/distilled_whisperx.py index 2b4fb5a5..951dcd10 100644 --- a/swarms/models/distilled_whisperx.py +++ b/swarms/models/distilled_whisperx.py @@ -6,7 +6,11 @@ from typing import Union import torch from termcolor import colored -from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline +from transformers import ( + AutoModelForSpeechSeq2Seq, + AutoProcessor, + pipeline, +) def async_retry(max_retries=3, exceptions=(Exception,), delay=1): @@ -29,8 +33,8 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1): if retries <= 0: raise print( - f"Retry after exception: {e}, Attempts remaining:" - f" {retries}" + f"Retry after exception: {e}, Attempts" + f" remaining: {retries}" ) await asyncio.sleep(delay) @@ -66,7 +70,9 @@ class DistilWhisperModel: def __init__(self, model_id="distil-whisper/distil-large-v2"): self.device = "cuda:0" if torch.cuda.is_available() else "cpu" self.torch_dtype = ( - torch.float16 if torch.cuda.is_available() else torch.float32 + torch.float16 + if torch.cuda.is_available() + else torch.float32 ) self.model_id = model_id self.model = AutoModelForSpeechSeq2Seq.from_pretrained( @@ -106,7 +112,9 @@ class DistilWhisperModel: :return: The transcribed text. """ loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, self.transcribe, inputs) + return await loop.run_in_executor( + None, self.transcribe, inputs + ) def real_time_transcribe(self, audio_file_path, chunk_duration=5): """ @@ -130,13 +138,21 @@ class DistilWhisperModel: sample_rate = audio_input.sampling_rate len(audio_input.array) / sample_rate chunks = [ - audio_input.array[i : i + sample_rate * chunk_duration] + audio_input.array[ + i : i + sample_rate * chunk_duration + ] for i in range( - 0, len(audio_input.array), sample_rate * chunk_duration + 0, + len(audio_input.array), + sample_rate * chunk_duration, ) ] - print(colored("Starting real-time transcription...", "green")) + print( + colored( + "Starting real-time transcription...", "green" + ) + ) for i, chunk in enumerate(chunks): # Process the current chunk @@ -146,8 +162,8 @@ class DistilWhisperModel: return_tensors="pt", padding=True, ) - processed_inputs = processed_inputs.input_values.to( - self.device + processed_inputs = ( + processed_inputs.input_values.to(self.device) ) # Generate transcription for the chunk @@ -158,7 +174,9 @@ class DistilWhisperModel: # Print the chunk's transcription print( - colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") + colored( + f"Chunk {i+1}/{len(chunks)}: ", "yellow" + ) + transcription ) @@ -167,5 +185,8 @@ class DistilWhisperModel: except Exception as e: print( - colored(f"An error occurred during transcription: {e}", "red") + colored( + f"An error occurred during transcription: {e}", + "red", + ) ) diff --git a/swarms/models/eleven_labs.py b/swarms/models/eleven_labs.py index 42f4dae1..2d55e864 100644 --- a/swarms/models/eleven_labs.py +++ b/swarms/models/eleven_labs.py @@ -13,7 +13,8 @@ def _import_elevenlabs() -> Any: import elevenlabs except ImportError as e: raise ImportError( - "Cannot import elevenlabs, please install `pip install elevenlabs`." + "Cannot import elevenlabs, please install `pip install" + " elevenlabs`." ) from e return elevenlabs @@ -52,16 +53,18 @@ class ElevenLabsText2SpeechTool(BaseTool): name: str = "eleven_labs_text2speech" description: str = ( - "A wrapper around Eleven Labs Text2Speech. " - "Useful for when you need to convert text to speech. " - "It supports multiple languages, including English, German, Polish, " - "Spanish, Italian, French, Portuguese, and Hindi. " + "A wrapper around Eleven Labs Text2Speech. Useful for when" + " you need to convert text to speech. It supports multiple" + " languages, including English, German, Polish, Spanish," + " Italian, French, Portuguese, and Hindi. " ) @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" - _ = get_from_dict_or_env(values, "eleven_api_key", "ELEVEN_API_KEY") + _ = get_from_dict_or_env( + values, "eleven_api_key", "ELEVEN_API_KEY" + ) return values @@ -102,7 +105,9 @@ class ElevenLabsText2SpeechTool(BaseTool): def save(self, speech_file: str, path: str) -> None: """Save the speech file to a path.""" - raise NotImplementedError("Saving not implemented for this tool.") + raise NotImplementedError( + "Saving not implemented for this tool." + ) def __str__(self): return "ElevenLabsText2SpeechTool" diff --git a/swarms/models/embeddings_base.py b/swarms/models/embeddings_base.py index 6dd700c4..b0f5e22e 100644 --- a/swarms/models/embeddings_base.py +++ b/swarms/models/embeddings_base.py @@ -14,7 +14,9 @@ class Embeddings(ABC): def embed_query(self, text: str) -> List[float]: """Embed query text.""" - async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + async def aembed_documents( + self, texts: List[str] + ) -> List[List[float]]: """Embed search docs.""" raise NotImplementedError diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py index c9a0d719..a6fc31f8 100644 --- a/swarms/models/fastvit.py +++ b/swarms/models/fastvit.py @@ -83,4 +83,6 @@ class FastViT: top_classes = top_classes.cpu().numpy().tolist() # top_class_labels = [FASTVIT_IMAGENET_1K_CLASSES[i] for i in top_classes] # Uncomment if class labels are needed - return ClassificationResult(class_id=top_classes, confidence=top_probs) + return ClassificationResult( + class_id=top_classes, confidence=top_probs + ) diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py index 79dc1c47..c1e51199 100644 --- a/swarms/models/fuyu.py +++ b/swarms/models/fuyu.py @@ -43,7 +43,9 @@ class Fuyu: self.device_map = device_map self.max_new_tokens = max_new_tokens - self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path) + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_path + ) self.image_processor = FuyuImageProcessor() self.processor = FuyuProcessor( image_processor=self.image_processor, @@ -87,5 +89,7 @@ class Fuyu: image_pil = Image.open(BytesIO(response.content)) return image_pil except requests.RequestException as error: - print(f"Error fetching image from {img} and error: {error}") + print( + f"Error fetching image from {img} and error: {error}" + ) return None diff --git a/swarms/models/gpt4_vision_api.py b/swarms/models/gpt4_vision_api.py index 5539cc4a..86ced6a2 100644 --- a/swarms/models/gpt4_vision_api.py +++ b/swarms/models/gpt4_vision_api.py @@ -15,7 +15,10 @@ from termcolor import colored try: import cv2 except ImportError: - print("OpenCV not installed. Please install OpenCV to use this model.") + print( + "OpenCV not installed. Please install OpenCV to use this" + " model." + ) raise ImportError # Load environment variables @@ -127,7 +130,10 @@ class GPT4VisionAPI: payload = { "model": self.model_name, "messages": [ - {"role": "system", "content": [self.system_prompt]}, + { + "role": "system", + "content": [self.system_prompt], + }, { "role": "user", "content": [ @@ -135,9 +141,7 @@ class GPT4VisionAPI: { "type": "image_url", "image_url": { - "url": ( - f"data:image/jpeg;base64,{base64_image}" - ) + "url": f"data:image/jpeg;base64,{base64_image}" }, }, ], @@ -241,7 +245,9 @@ class GPT4VisionAPI: if not success: break _, buffer = cv2.imencode(".jpg", frame) - base64_frames.append(base64.b64encode(buffer).decode("utf-8")) + base64_frames.append( + base64.b64encode(buffer).decode("utf-8") + ) video.release() print(len(base64_frames), "frames read.") @@ -266,7 +272,10 @@ class GPT4VisionAPI: payload = { "model": self.model_name, "messages": [ - {"role": "system", "content": [self.system_prompt]}, + { + "role": "system", + "content": [self.system_prompt], + }, { "role": "user", "content": [ @@ -274,9 +283,7 @@ class GPT4VisionAPI: { "type": "image_url", "image_url": { - "url": ( - f"data:image/jpeg;base64,{base64_image}" - ) + "url": f"data:image/jpeg;base64,{base64_image}" }, }, ], @@ -326,7 +333,9 @@ class GPT4VisionAPI: """ # Instantiate the thread pool executor - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with ThreadPoolExecutor( + max_workers=self.max_workers + ) as executor: results = executor.map(self.run, tasks, imgs) # Print the results for debugging @@ -372,9 +381,7 @@ class GPT4VisionAPI: { "type": "image_url", "image_url": { - "url": ( - f"data:image/jpeg;base64,{base64_image}" - ) + "url": f"data:image/jpeg;base64,{base64_image}" }, }, ], @@ -384,7 +391,9 @@ class GPT4VisionAPI: } async with aiohttp.ClientSession() as session: async with session.post( - self.openai_proxy, headers=headers, data=json.dumps(payload) + self.openai_proxy, + headers=headers, + data=json.dumps(payload), ) as response: out = await response.json() content = out["choices"][0]["message"]["content"] @@ -393,7 +402,9 @@ class GPT4VisionAPI: print(f"Error with the request {error}") raise error - def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]: + def run_batch( + self, tasks_images: List[Tuple[str, str]] + ) -> List[str]: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ @@ -420,7 +431,9 @@ class GPT4VisionAPI: """Process a batch of tasks and images asynchronously with retries""" loop = asyncio.get_event_loop() futures = [ - loop.run_in_executor(None, self.run_with_retries, task, img) + loop.run_in_executor( + None, self.run_with_retries, task, img + ) for task, img in tasks_images ] return await asyncio.gather(*futures) @@ -428,7 +441,9 @@ class GPT4VisionAPI: def health_check(self): """Health check for the GPT4Vision model""" try: - response = requests.get("https://api.openai.com/v1/engines") + response = requests.get( + "https://api.openai.com/v1/engines" + ) return response.status_code == 200 except requests.RequestException as error: print(f"Health check failed: {error}") diff --git a/swarms/models/gpt4v.py b/swarms/models/gpt4v.py index 8f2683e0..43eff12d 100644 --- a/swarms/models/gpt4v.py +++ b/swarms/models/gpt4v.py @@ -65,7 +65,9 @@ class GPT4Vision: model: str = "gpt-4-vision-preview" backoff_factor: float = 2.0 timeout_seconds: int = 10 - openai_api_key: Optional[str] = None or os.getenv("OPENAI_API_KEY") + openai_api_key: Optional[str] = None or os.getenv( + "OPENAI_API_KEY" + ) # 'Low' or 'High' for respesctively fast or high quality, but high more token usage quality: str = "low" # Max tokens to use for the API request, the maximum might be 3,000 but we don't know @@ -131,9 +133,14 @@ class GPT4Vision: return out except openai.OpenAIError as e: # logger.error(f"OpenAI API error: {e}") - return f"OpenAI API error: Could not process the image. {e}" + return ( + f"OpenAI API error: Could not process the image. {e}" + ) except Exception as e: - return f"Unexpected error occurred while processing the image. {e}" + return ( + "Unexpected error occurred while processing the" + f" image. {e}" + ) def clean_output(self, output: str): # Regex pattern to find the Choice object representation in the output @@ -182,11 +189,18 @@ class GPT4Vision: return print(response.choices[0]) except openai.OpenAIError as e: # logger.error(f"OpenAI API error: {e}") - return f"OpenAI API error: Could not process the image. {e}" + return ( + f"OpenAI API error: Could not process the image. {e}" + ) except Exception as e: - return f"Unexpected error occurred while processing the image. {e}" + return ( + "Unexpected error occurred while processing the" + f" image. {e}" + ) - def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]: + def run_batch( + self, tasks_images: List[Tuple[str, str]] + ) -> List[str]: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ @@ -213,7 +227,9 @@ class GPT4Vision: """Process a batch of tasks and images asynchronously with retries""" loop = asyncio.get_event_loop() futures = [ - loop.run_in_executor(None, self.run_with_retries, task, img) + loop.run_in_executor( + None, self.run_with_retries, task, img + ) for task, img in tasks_images ] return await asyncio.gather(*futures) @@ -240,7 +256,9 @@ class GPT4Vision: def health_check(self): """Health check for the GPT4Vision model""" try: - response = requests.get("https://api.openai.com/v1/engines") + response = requests.get( + "https://api.openai.com/v1/engines" + ) return response.status_code == 200 except requests.RequestException as error: print(f"Health check failed: {error}") diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py index 1db435f5..295949f5 100644 --- a/swarms/models/huggingface.py +++ b/swarms/models/huggingface.py @@ -7,7 +7,11 @@ from typing import List, Tuple import torch from termcolor import colored from torch.nn.parallel import DistributedDataParallel as DDP -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, +) class HuggingfaceLLM: @@ -173,7 +177,10 @@ class HuggingfaceLLM: self.model_id, *args, **kwargs ) self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config, *args, **kwargs + self.model_id, + quantization_config=bnb_config, + *args, + **kwargs, ) self.model # .to(self.device) @@ -182,7 +189,11 @@ class HuggingfaceLLM: # raise print( colored( - f"Failed to load the model and or the tokenizer: {e}", "red" + ( + "Failed to load the model and or the" + f" tokenizer: {e}" + ), + "red", ) ) @@ -198,7 +209,9 @@ class HuggingfaceLLM: """Load the model""" if not self.model or not self.tokenizer: try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_id + ) bnb_config = ( BitsAndBytesConfig(**self.quantization_config) @@ -214,7 +227,8 @@ class HuggingfaceLLM: self.model = DDP(self.model) except Exception as error: self.logger.error( - f"Failed to load the model or the tokenizer: {error}" + "Failed to load the model or the tokenizer:" + f" {error}" ) raise @@ -226,7 +240,9 @@ class HuggingfaceLLM: results = list(executor.map(self.run, tasks)) return results - def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]: + def run_batch( + self, tasks_images: List[Tuple[str, str]] + ) -> List[str]: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ @@ -254,9 +270,9 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode( + task, return_tensors="pt" + ).to(self.device) # self.log.start() @@ -266,7 +282,9 @@ class HuggingfaceLLM: output_sequence = [] outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True + inputs, + max_length=len(inputs) + 1, + do_sample=True, ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) @@ -274,7 +292,8 @@ class HuggingfaceLLM: # print token in real-time print( self.tokenizer.decode( - [output_tokens], skip_special_tokens=True + [output_tokens], + skip_special_tokens=True, ), end="", flush=True, @@ -287,13 +306,16 @@ class HuggingfaceLLM: ) del inputs - return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return self.tokenizer.decode( + outputs[0], skip_special_tokens=True + ) except Exception as e: print( colored( ( - "HuggingfaceLLM could not generate text because of" - f" error: {e}, try optimizing your arguments" + "HuggingfaceLLM could not generate text" + f" because of error: {e}, try optimizing your" + " arguments" ), "red", ) @@ -318,9 +340,9 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode( + task, return_tensors="pt" + ).to(self.device) # self.log.start() @@ -330,7 +352,9 @@ class HuggingfaceLLM: output_sequence = [] outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True + inputs, + max_length=len(inputs) + 1, + do_sample=True, ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) @@ -338,7 +362,8 @@ class HuggingfaceLLM: # print token in real-time print( self.tokenizer.decode( - [output_tokens], skip_special_tokens=True + [output_tokens], + skip_special_tokens=True, ), end="", flush=True, @@ -352,7 +377,9 @@ class HuggingfaceLLM: del inputs - return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return self.tokenizer.decode( + outputs[0], skip_special_tokens=True + ) except Exception as e: self.logger.error(f"Failed to generate the text: {e}") raise diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index 0cfcf1af..7c505d8a 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -100,10 +100,14 @@ class Idefics: """ inputs = ( self.processor( - prompts, add_end_of_utterance_token=False, return_tensors="pt" + prompts, + add_end_of_utterance_token=False, + return_tensors="pt", ).to(self.device) if batched_mode - else self.processor(prompts[0], return_tensors="pt").to(self.device) + else self.processor(prompts[0], return_tensors="pt").to( + self.device + ) ) exit_condition = self.processor.tokenizer( @@ -111,7 +115,8 @@ class Idefics: ).input_ids bad_words_ids = self.processor.tokenizer( - ["", "", "", "", "= image_h: diff --git a/swarms/models/llama_function_caller.py b/swarms/models/llama_function_caller.py index ca5ee5d3..78169208 100644 --- a/swarms/models/llama_function_caller.py +++ b/swarms/models/llama_function_caller.py @@ -170,7 +170,9 @@ class LlamaFunctionCaller: prompt = f"{task}\n\n" # Encode and send to the model - inputs = self.tokenizer([prompt], return_tensors="pt").to(self.runtime) + inputs = self.tokenizer([prompt], return_tensors="pt").to( + self.runtime + ) streamer = TextStreamer(self.tokenizer) diff --git a/swarms/models/llava.py b/swarms/models/llava.py index 6f8019bc..605904c3 100644 --- a/swarms/models/llava.py +++ b/swarms/models/llava.py @@ -70,7 +70,10 @@ class MultiModalLlava: def chat(self): """Interactive chat in terminal""" - print("Starting chat with LlavaModel. Type 'exit' to end the session.") + print( + "Starting chat with LlavaModel. Type 'exit' to end the" + " session." + ) while True: user_input = input("You: ") if user_input.lower() == "exit": diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py index 056a31bb..297ecf12 100644 --- a/swarms/models/mistral.py +++ b/swarms/models/mistral.py @@ -50,7 +50,8 @@ class Mistral: # Check if the specified device is available if not torch.cuda.is_available() and device == "cuda": raise ValueError( - "CUDA is not available. Please choose a different device." + "CUDA is not available. Please choose a different" + " device." ) # Load the model and tokenizer @@ -62,19 +63,25 @@ class Mistral: def load_model(self): try: - self.model = AutoModelForCausalLM.from_pretrained(self.model_name) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.model = AutoModelForCausalLM.from_pretrained( + self.model_name + ) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name + ) self.model.to(self.device) except Exception as e: - raise ValueError(f"Error loading the Mistral model: {str(e)}") + raise ValueError( + f"Error loading the Mistral model: {str(e)}" + ) def run(self, task: str): """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], return_tensors="pt").to( - self.device - ) + model_inputs = self.tokenizer( + [task], return_tensors="pt" + ).to(self.device) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, @@ -82,7 +89,9 @@ class Mistral: temperature=self.temperature, max_new_tokens=self.max_length, ) - output_text = self.tokenizer.batch_decode(generated_ids)[0] + output_text = self.tokenizer.batch_decode(generated_ids)[ + 0 + ] return output_text except Exception as e: raise ValueError(f"Error running the model: {str(e)}") @@ -91,9 +100,9 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], return_tensors="pt").to( - self.device - ) + model_inputs = self.tokenizer( + [task], return_tensors="pt" + ).to(self.device) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, @@ -101,7 +110,9 @@ class Mistral: temperature=self.temperature, max_new_tokens=self.max_length, ) - output_text = self.tokenizer.batch_decode(generated_ids)[0] + output_text = self.tokenizer.batch_decode(generated_ids)[ + 0 + ] return output_text except Exception as e: raise ValueError(f"Error running the model: {str(e)}") diff --git a/swarms/models/mpt.py b/swarms/models/mpt.py index c304355a..56f1bbdb 100644 --- a/swarms/models/mpt.py +++ b/swarms/models/mpt.py @@ -30,7 +30,10 @@ class MPT7B: """ def __init__( - self, model_name: str, tokenizer_name: str, max_tokens: int = 100 + self, + model_name: str, + tokenizer_name: str, + max_tokens: int = 100, ): # Loading model and tokenizer details self.model_name = model_name @@ -138,9 +141,13 @@ class MPT7B: """Call the model asynchronously""" "" return await self.run_async(task, *args, **kwargs) - def batch_generate(self, prompts: list, temperature: float = 1.0) -> list: + def batch_generate( + self, prompts: list, temperature: float = 1.0 + ) -> list: """Batch generate text""" - self.logger.info(f"Generating text for {len(prompts)} prompts...") + self.logger.info( + f"Generating text for {len(prompts)} prompts..." + ) results = [] with torch.autocast("cuda", dtype=torch.bfloat16): for prompt in prompts: diff --git a/swarms/models/nougat.py b/swarms/models/nougat.py index 0eceb362..453c6cae 100644 --- a/swarms/models/nougat.py +++ b/swarms/models/nougat.py @@ -61,7 +61,9 @@ class Nougat: def __call__(self, img: str): """Call the model with an image_path str as an input""" image = Image.open(img) - pixel_values = self.processor(image, return_tensors="pt").pixel_values + pixel_values = self.processor( + image, return_tensors="pt" + ).pixel_values # Generate transcriptions, here we only generate 30 tokens outputs = self.model.generate( @@ -92,7 +94,9 @@ class Nougat: # Convert the matches to a readable format cleaned_data = [ - "Date: {}, Amount: {}".format(date, amount.replace(",", "")) + "Date: {}, Amount: {}".format( + date, amount.replace(",", "") + ) for date, amount in matches ] diff --git a/swarms/models/openai_embeddings.py b/swarms/models/openai_embeddings.py index 08919d45..0cbbdbee 100644 --- a/swarms/models/openai_embeddings.py +++ b/swarms/models/openai_embeddings.py @@ -55,13 +55,17 @@ def _create_retry_decorator( return retry( reraise=True, stop=stop_after_attempt(embeddings.max_retries), - wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + wait=wait_exponential( + multiplier=1, min=min_seconds, max=max_seconds + ), retry=( retry_if_exception_type(llm.error.Timeout) | retry_if_exception_type(llm.error.APIError) | retry_if_exception_type(llm.error.APIConnectionError) | retry_if_exception_type(llm.error.RateLimitError) - | retry_if_exception_type(llm.error.ServiceUnavailableError) + | retry_if_exception_type( + llm.error.ServiceUnavailableError + ) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -77,13 +81,17 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any: async_retrying = AsyncRetrying( reraise=True, stop=stop_after_attempt(embeddings.max_retries), - wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + wait=wait_exponential( + multiplier=1, min=min_seconds, max=max_seconds + ), retry=( retry_if_exception_type(llm.error.Timeout) | retry_if_exception_type(llm.error.APIError) | retry_if_exception_type(llm.error.APIConnectionError) | retry_if_exception_type(llm.error.RateLimitError) - | retry_if_exception_type(llm.error.ServiceUnavailableError) + | retry_if_exception_type( + llm.error.ServiceUnavailableError + ) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -104,11 +112,15 @@ def _check_response(response: dict) -> dict: if any(len(d["embedding"]) == 1 for d in response["data"]): import llm - raise llm.error.APIError("OpenAI API returned an empty embedding") + raise llm.error.APIError( + "OpenAI API returned an empty embedding" + ) return response -def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: +def embed_with_retry( + embeddings: OpenAIEmbeddings, **kwargs: Any +) -> Any: """Use tenacity to retry the embedding call.""" retry_decorator = _create_retry_decorator(embeddings) @@ -176,9 +188,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): client: Any #: :meta private: model: str = "text-embedding-ada-002" - deployment: str = ( - model # to support Azure OpenAI Service custom deployment names - ) + deployment: str = model # to support Azure OpenAI Service custom deployment names openai_api_version: Optional[str] = None # to support Azure OpenAI Service custom endpoints openai_api_base: Optional[str] = None @@ -191,12 +201,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings): openai_api_key: Optional[str] = None openai_organization: Optional[str] = None allowed_special: Union[Literal["all"], Set[str]] = set() - disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all" + disallowed_special: Union[ + Literal["all"], Set[str], Sequence[str] + ] = "all" chunk_size: int = 1000 """Maximum number of texts to embed in each batch""" max_retries: int = 6 """Maximum number of retries to make when generating.""" - request_timeout: Optional[Union[float, Tuple[float, float]]] = None + request_timeout: Optional[Union[float, Tuple[float, float]]] = ( + None + ) """Timeout in seconds for the OpenAPI request.""" headers: Any = None tiktoken_model_name: Optional[str] = None @@ -226,7 +240,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): extra = values.get("model_kwargs", {}) for field_name in list(values): if field_name in extra: - raise ValueError(f"Found {field_name} supplied twice.") + raise ValueError( + f"Found {field_name} supplied twice." + ) if field_name not in all_required_field_names: warnings.warn( f"""WARNING! {field_name} is not default parameter. @@ -240,9 +256,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) if invalid_model_kwargs: raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified" - " explicitly. Instead they were passed in as part of" - " `model_kwargs` parameter." + f"Parameters {invalid_model_kwargs} should be" + " specified explicitly. Instead they were passed in" + " as part of `model_kwargs` parameter." ) values["model_kwargs"] = extra @@ -272,7 +288,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): "OPENAI_PROXY", default="", ) - if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): + if values["openai_api_type"] in ( + "azure", + "azure_ad", + "azuread", + ): default_api_version = "2022-12-01" else: default_api_version = "" @@ -324,9 +344,15 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return openai_args def _get_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, ) -> List[List[float]]: - embeddings: List[List[float]] = [[] for _ in range(len(texts))] + embeddings: List[List[float]] = [ + [] for _ in range(len(texts)) + ] try: import tiktoken except ImportError: @@ -343,7 +369,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): encoding = tiktoken.encoding_for_model(model_name) except KeyError: logger.warning( - "Warning: model not found. Using cl100k_base encoding." + "Warning: model not found. Using cl100k_base" + " encoding." ) model = "cl100k_base" encoding = tiktoken.get_encoding(model) @@ -358,7 +385,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append(token[j : j + self.embedding_ctx_length]) + tokens.append( + token[j : j + self.embedding_ctx_length] + ) indices.append(i) batched_embeddings: List[List[float]] = [] @@ -380,10 +409,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings): input=tokens[i : i + _chunk_size], **self._invocation_params, ) - batched_embeddings.extend(r["embedding"] for r in response["data"]) + batched_embeddings.extend( + r["embedding"] for r in response["data"] + ) - results: List[List[List[float]]] = [[] for _ in range(len(texts))] - num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))] + results: List[List[List[float]]] = [ + [] for _ in range(len(texts)) + ] + num_tokens_in_batch: List[List[int]] = [ + [] for _ in range(len(texts)) + ] for i in range(len(indices)): results[indices[i]].append(batched_embeddings[i]) num_tokens_in_batch[indices[i]].append(len(tokens[i])) @@ -400,16 +435,24 @@ class OpenAIEmbeddings(BaseModel, Embeddings): average = np.average( _result, axis=0, weights=num_tokens_in_batch[i] ) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embeddings[i] = ( + average / np.linalg.norm(average) + ).tolist() return embeddings # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb async def _aget_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, ) -> List[List[float]]: - embeddings: List[List[float]] = [[] for _ in range(len(texts))] + embeddings: List[List[float]] = [ + [] for _ in range(len(texts)) + ] try: import tiktoken except ImportError: @@ -426,7 +469,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): encoding = tiktoken.encoding_for_model(model_name) except KeyError: logger.warning( - "Warning: model not found. Using cl100k_base encoding." + "Warning: model not found. Using cl100k_base" + " encoding." ) model = "cl100k_base" encoding = tiktoken.get_encoding(model) @@ -441,7 +485,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append(token[j : j + self.embedding_ctx_length]) + tokens.append( + token[j : j + self.embedding_ctx_length] + ) indices.append(i) batched_embeddings: List[List[float]] = [] @@ -452,10 +498,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings): input=tokens[i : i + _chunk_size], **self._invocation_params, ) - batched_embeddings.extend(r["embedding"] for r in response["data"]) + batched_embeddings.extend( + r["embedding"] for r in response["data"] + ) - results: List[List[List[float]]] = [[] for _ in range(len(texts))] - num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))] + results: List[List[List[float]]] = [ + [] for _ in range(len(texts)) + ] + num_tokens_in_batch: List[List[int]] = [ + [] for _ in range(len(texts)) + ] for i in range(len(indices)): results[indices[i]].append(batched_embeddings[i]) num_tokens_in_batch[indices[i]].append(len(tokens[i])) @@ -474,7 +526,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): average = np.average( _result, axis=0, weights=num_tokens_in_batch[i] ) - embeddings[i] = (average / np.linalg.norm(average)).tolist() + embeddings[i] = ( + average / np.linalg.norm(average) + ).tolist() return embeddings @@ -493,7 +547,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. - return self._get_len_safe_embeddings(texts, engine=self.deployment) + return self._get_len_safe_embeddings( + texts, engine=self.deployment + ) async def aembed_documents( self, texts: List[str], chunk_size: Optional[int] = 0 diff --git a/swarms/models/openai_function_caller.py b/swarms/models/openai_function_caller.py index f0c41f2a..6542e457 100644 --- a/swarms/models/openai_function_caller.py +++ b/swarms/models/openai_function_caller.py @@ -3,7 +3,11 @@ from typing import Any, Dict, List, Optional, Union import openai import requests from pydantic import BaseModel, validator -from tenacity import retry, stop_after_attempt, wait_random_exponential +from tenacity import ( + retry, + stop_after_attempt, + wait_random_exponential, +) from termcolor import colored @@ -100,7 +104,9 @@ class FunctionSpecification(BaseModel): for req_param in self.required or []: if req_param not in params: - raise ValueError(f"Missing required parameter: {req_param}") + raise ValueError( + f"Missing required parameter: {req_param}" + ) class OpenAIFunctionCaller: @@ -220,7 +226,10 @@ class OpenAIFunctionCaller: elif message["role"] == "tool": print( colored( - f"function ({message['name']}): {message['content']}\n", + ( + f"function ({message['name']}):" + f" {message['content']}\n" + ), role_to_color[message["role"]], ) ) diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index 6366b8b0..14332ff2 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -27,7 +27,10 @@ from langchain.llms.base import BaseLLM, create_base_retry_decorator from langchain.pydantic_v1 import Field, root_validator from langchain.schema import Generation, LLMResult from langchain.schema.output import GenerationChunk -from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, +) from langchain.utils.utils import build_extra_kwargs @@ -44,7 +47,9 @@ def is_openai_v1() -> bool: def update_token_usage( - keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any] + keys: Set[str], + response: Dict[str, Any], + token_usage: Dict[str, Any], ) -> None: """Update token usage.""" _keys_to_use = keys.intersection(response["usage"]) @@ -65,7 +70,9 @@ def _stream_response_to_generation_chunk( finish_reason=stream_response["choices"][0].get( "finish_reason", None ), - logprobs=stream_response["choices"][0].get("logprobs", None), + logprobs=stream_response["choices"][0].get( + "logprobs", None + ), ), ) @@ -74,13 +81,15 @@ def _update_response( response: Dict[str, Any], stream_response: Dict[str, Any] ) -> None: """Update response from the stream response.""" - response["choices"][0]["text"] += stream_response["choices"][0]["text"] - response["choices"][0]["finish_reason"] = stream_response["choices"][0].get( - "finish_reason", None - ) - response["choices"][0]["logprobs"] = stream_response["choices"][0][ - "logprobs" + response["choices"][0]["text"] += stream_response["choices"][0][ + "text" ] + response["choices"][0]["finish_reason"] = stream_response[ + "choices" + ][0].get("finish_reason", None) + response["choices"][0]["logprobs"] = stream_response["choices"][ + 0 + ]["logprobs"] def _streaming_response_template() -> Dict[str, Any]: @@ -111,7 +120,9 @@ def _create_retry_decorator( openai.error.ServiceUnavailableError, ] return create_base_retry_decorator( - error_types=errors, max_retries=llm.max_retries, run_manager=run_manager + error_types=errors, + max_retries=llm.max_retries, + run_manager=run_manager, ) @@ -121,7 +132,9 @@ def completion_with_retry( **kwargs: Any, ) -> Any: """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + retry_decorator = _create_retry_decorator( + llm, run_manager=run_manager + ) @retry_decorator def _completion_with_retry(**kwargs: Any) -> Any: @@ -136,7 +149,9 @@ async def acompletion_with_retry( **kwargs: Any, ) -> Any: """Use tenacity to retry the async completion call.""" - retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + retry_decorator = _create_retry_decorator( + llm, run_manager=run_manager + ) @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: @@ -160,7 +175,9 @@ class BaseOpenAI(BaseLLM): attributes["openai_api_base"] = self.openai_api_base if self.openai_organization != "": - attributes["openai_organization"] = self.openai_organization + attributes["openai_organization"] = ( + self.openai_organization + ) if self.openai_proxy != "": attributes["openai_proxy"] = self.openai_proxy @@ -199,9 +216,13 @@ class BaseOpenAI(BaseLLM): openai_proxy: Optional[str] = None batch_size: int = 20 """Batch size to use when passing multiple documents to generate.""" - request_timeout: Optional[Union[float, Tuple[float, float]]] = None + request_timeout: Optional[Union[float, Tuple[float, float]]] = ( + None + ) """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" - logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict) + logit_bias: Optional[Dict[str, float]] = Field( + default_factory=dict + ) """Adjust the probability of specific tokens being generated.""" max_retries: int = 6 """Maximum number of retries to make when generating.""" @@ -278,7 +299,9 @@ class BaseOpenAI(BaseLLM): if values["streaming"] and values["n"] > 1: raise ValueError("Cannot stream results when n > 1.") if values["streaming"] and values["best_of"] > 1: - raise ValueError("Cannot stream results when best_of > 1.") + raise ValueError( + "Cannot stream results when best_of > 1." + ) return values @property @@ -310,7 +333,9 @@ class BaseOpenAI(BaseLLM): **kwargs: Any, ) -> Iterator[GenerationChunk]: params = {**self._invocation_params, **kwargs, "stream": True} - self.get_sub_prompts(params, [prompt], stop) # this mutates params + self.get_sub_prompts( + params, [prompt], stop + ) # this mutates params for stream_resp in completion_with_retry( self, prompt=prompt, run_manager=run_manager, **params ): @@ -336,7 +361,9 @@ class BaseOpenAI(BaseLLM): **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: params = {**self._invocation_params, **kwargs, "stream": True} - self.get_sub_prompts(params, [prompt], stop) # this mutate params + self.get_sub_prompts( + params, [prompt], stop + ) # this mutate params async for stream_resp in await acompletion_with_retry( self, prompt=prompt, run_manager=run_manager, **params ): @@ -404,7 +431,9 @@ class BaseOpenAI(BaseLLM): { "text": generation.text, "finish_reason": ( - generation.generation_info.get("finish_reason") + generation.generation_info.get( + "finish_reason" + ) if generation.generation_info else None ), @@ -417,7 +446,10 @@ class BaseOpenAI(BaseLLM): ) else: response = completion_with_retry( - self, prompt=_prompts, run_manager=run_manager, **params + self, + prompt=_prompts, + run_manager=run_manager, + **params, ) choices.extend(response["choices"]) update_token_usage(_keys, response, token_usage) @@ -459,7 +491,9 @@ class BaseOpenAI(BaseLLM): { "text": generation.text, "finish_reason": ( - generation.generation_info.get("finish_reason") + generation.generation_info.get( + "finish_reason" + ) if generation.generation_info else None ), @@ -472,7 +506,10 @@ class BaseOpenAI(BaseLLM): ) else: response = await acompletion_with_retry( - self, prompt=_prompts, run_manager=run_manager, **params + self, + prompt=_prompts, + run_manager=run_manager, + **params, ) choices.extend(response["choices"]) update_token_usage(_keys, response, token_usage) @@ -488,15 +525,19 @@ class BaseOpenAI(BaseLLM): if stop is not None: if "stop" in params: raise ValueError( - "`stop` found in both the input and default params." + "`stop` found in both the input and default" + " params." ) params["stop"] = stop if params["max_tokens"] == -1: if len(prompts) != 1: raise ValueError( - "max_tokens set to -1 not supported for multiple inputs." + "max_tokens set to -1 not supported for multiple" + " inputs." ) - params["max_tokens"] = self.max_tokens_for_prompt(prompts[0]) + params["max_tokens"] = self.max_tokens_for_prompt( + prompts[0] + ) sub_prompts = [ prompts[i : i + self.batch_size] for i in range(0, len(prompts), self.batch_size) @@ -504,7 +545,10 @@ class BaseOpenAI(BaseLLM): return sub_prompts def create_llm_result( - self, choices: Any, prompts: List[str], token_usage: Dict[str, int] + self, + choices: Any, + prompts: List[str], + token_usage: Dict[str, int], ) -> LLMResult: """Create the LLMResult from the choices and prompts.""" generations = [] @@ -522,8 +566,13 @@ class BaseOpenAI(BaseLLM): for choice in sub_choices ] ) - llm_output = {"token_usage": token_usage, "model_name": self.model_name} - return LLMResult(generations=generations, llm_output=llm_output) + llm_output = { + "token_usage": token_usage, + "model_name": self.model_name, + } + return LLMResult( + generations=generations, llm_output=llm_output + ) @property def _invocation_params(self) -> Dict[str, Any]: @@ -542,7 +591,10 @@ class BaseOpenAI(BaseLLM): @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" - return {**{"model_name": self.model_name}, **self._default_params} + return { + **{"model_name": self.model_name}, + **self._default_params, + } @property def _llm_type(self) -> str: @@ -558,9 +610,9 @@ class BaseOpenAI(BaseLLM): import tiktoken except ImportError: raise ImportError( - "Could not import tiktoken python package. " - "This is needed in order to calculate get_num_tokens. " - "Please install it with `pip install tiktoken`." + "Could not import tiktoken python package. This is" + " needed in order to calculate get_num_tokens. Please" + " install it with `pip install tiktoken`." ) model_name = self.tiktoken_model_name or self.model_name @@ -568,7 +620,8 @@ class BaseOpenAI(BaseLLM): enc = tiktoken.encoding_for_model(model_name) except KeyError: logger.warning( - "Warning: model not found. Using cl100k_base encoding." + "Warning: model not found. Using cl100k_base" + " encoding." ) model = "cl100k_base" enc = tiktoken.get_encoding(model) @@ -630,8 +683,8 @@ class BaseOpenAI(BaseLLM): if context_size is None: raise ValueError( - f"Unknown model: {modelname}. Please provide a valid OpenAI" - " model name.Known models are: " + f"Unknown model: {modelname}. Please provide a valid" + " OpenAI model name.Known models are: " + ", ".join(model_token_mapping.keys()) ) @@ -678,7 +731,10 @@ class OpenAI(BaseOpenAI): @property def _invocation_params(self) -> Dict[str, Any]: - return {**{"model": self.model_name}, **super()._invocation_params} + return { + **{"model": self.model_name}, + **super()._invocation_params, + } class AzureOpenAI(BaseOpenAI): @@ -802,7 +858,9 @@ class OpenAIChat(BaseLLM): for field_name in list(values): if field_name not in all_required_field_names: if field_name in extra: - raise ValueError(f"Found {field_name} supplied twice.") + raise ValueError( + f"Found {field_name} supplied twice." + ) extra[field_name] = values.pop(field_name) values["model_kwargs"] = extra return values @@ -826,7 +884,10 @@ class OpenAIChat(BaseLLM): default="", ) openai_organization = get_from_dict_or_env( - values, "openai_organization", "OPENAI_ORGANIZATION", default="" + values, + "openai_organization", + "OPENAI_ORGANIZATION", + default="", ) try: import openai @@ -847,9 +908,10 @@ class OpenAIChat(BaseLLM): values["client"] = openai.ChatCompletion except AttributeError: raise ValueError( - "`openai` has no `ChatCompletion` attribute, this is likely " - "due to an old version of the openai package. Try upgrading it " - "with `pip install --upgrade openai`." + "`openai` has no `ChatCompletion` attribute, this is" + " likely due to an old version of the openai package." + " Try upgrading it with `pip install --upgrade" + " openai`." ) return values @@ -863,8 +925,8 @@ class OpenAIChat(BaseLLM): ) -> Tuple: if len(prompts) > 1: raise ValueError( - "OpenAIChat currently only supports single prompt, got" - f" {prompts}" + "OpenAIChat currently only supports single prompt," + f" got {prompts}" ) messages = self.prefix_messages + [ {"role": "user", "content": prompts[0]} @@ -876,7 +938,8 @@ class OpenAIChat(BaseLLM): if stop is not None: if "stop" in params: raise ValueError( - "`stop` found in both the input and default params." + "`stop` found in both the input and default" + " params." ) params["stop"] = stop if params.get("max_tokens") == -1: @@ -896,7 +959,9 @@ class OpenAIChat(BaseLLM): for stream_resp in completion_with_retry( self, messages=messages, run_manager=run_manager, **params ): - token = stream_resp["choices"][0]["delta"].get("content", "") + token = stream_resp["choices"][0]["delta"].get( + "content", "" + ) chunk = GenerationChunk(text=token) yield chunk if run_manager: @@ -914,7 +979,9 @@ class OpenAIChat(BaseLLM): async for stream_resp in await acompletion_with_retry( self, messages=messages, run_manager=run_manager, **params ): - token = stream_resp["choices"][0]["delta"].get("content", "") + token = stream_resp["choices"][0]["delta"].get( + "content", "" + ) chunk = GenerationChunk(text=token) yield chunk if run_manager: @@ -929,7 +996,9 @@ class OpenAIChat(BaseLLM): ) -> LLMResult: if self.streaming: generation: Optional[GenerationChunk] = None - for chunk in self._stream(prompts[0], stop, run_manager, **kwargs): + for chunk in self._stream( + prompts[0], stop, run_manager, **kwargs + ): if generation is None: generation = chunk else: @@ -950,7 +1019,9 @@ class OpenAIChat(BaseLLM): generations=[ [ Generation( - text=full_response["choices"][0]["message"]["content"] + text=full_response["choices"][0]["message"][ + "content" + ] ) ] ], @@ -989,7 +1060,9 @@ class OpenAIChat(BaseLLM): generations=[ [ Generation( - text=full_response["choices"][0]["message"]["content"] + text=full_response["choices"][0]["message"][ + "content" + ] ) ] ], @@ -999,7 +1072,10 @@ class OpenAIChat(BaseLLM): @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" - return {**{"model_name": self.model_name}, **self._default_params} + return { + **{"model_name": self.model_name}, + **self._default_params, + } @property def _llm_type(self) -> str: @@ -1015,9 +1091,9 @@ class OpenAIChat(BaseLLM): import tiktoken except ImportError: raise ImportError( - "Could not import tiktoken python package. " - "This is needed in order to calculate get_num_tokens. " - "Please install it with `pip install tiktoken`." + "Could not import tiktoken python package. This is" + " needed in order to calculate get_num_tokens. Please" + " install it with `pip install tiktoken`." ) enc = tiktoken.encoding_for_model(self.model_name) diff --git a/swarms/models/palm.py b/swarms/models/palm.py index 8c9277d7..d61d4856 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -47,7 +47,9 @@ def _create_retry_decorator() -> Callable[[Any], Any]: | retry_if_exception_type( google.api_core.exceptions.ServiceUnavailable ) - | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError) + | retry_if_exception_type( + google.api_core.exceptions.GoogleAPIError + ) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -114,8 +116,9 @@ class GooglePalm(BaseLLM, BaseModel): genai.configure(api_key=google_api_key) except ImportError: raise ImportError( - "Could not import google-generativeai python package. " - "Please install it with `pip install google-generativeai`." + "Could not import google-generativeai python package." + " Please install it with `pip install" + " google-generativeai`." ) values["client"] = genai @@ -124,9 +127,14 @@ class GooglePalm(BaseLLM, BaseModel): values["temperature"] is not None and not 0 <= values["temperature"] <= 1 ): - raise ValueError("temperature must be in the range [0.0, 1.0]") + raise ValueError( + "temperature must be in the range [0.0, 1.0]" + ) - if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: + if ( + values["top_p"] is not None + and not 0 <= values["top_p"] <= 1 + ): raise ValueError("top_p must be in the range [0.0, 1.0]") if values["top_k"] is not None and values["top_k"] <= 0: @@ -136,7 +144,9 @@ class GooglePalm(BaseLLM, BaseModel): values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0 ): - raise ValueError("max_output_tokens must be greater than zero") + raise ValueError( + "max_output_tokens must be greater than zero" + ) return values @@ -165,8 +175,12 @@ class GooglePalm(BaseLLM, BaseModel): prompt_generations = [] for candidate in completion.candidates: raw_text = candidate["output"] - stripped_text = _strip_erroneous_leading_spaces(raw_text) - prompt_generations.append(Generation(text=stripped_text)) + stripped_text = _strip_erroneous_leading_spaces( + raw_text + ) + prompt_generations.append( + Generation(text=stripped_text) + ) generations.append(prompt_generations) return LLMResult(generations=generations) diff --git a/swarms/models/pegasus.py b/swarms/models/pegasus.py index e388d40c..b2f43980 100644 --- a/swarms/models/pegasus.py +++ b/swarms/models/pegasus.py @@ -34,16 +34,22 @@ class PegasusEmbedding: """ def __init__( - self, modality: str, multi_process: bool = False, n_processes: int = 4 + self, + modality: str, + multi_process: bool = False, + n_processes: int = 4, ): self.modality = modality self.multi_process = multi_process self.n_processes = n_processes try: - self.pegasus = Pegasus(modality, multi_process, n_processes) + self.pegasus = Pegasus( + modality, multi_process, n_processes + ) except Exception as e: logging.error( - f"Failed to initialize Pegasus with modality: {modality}: {e}" + "Failed to initialize Pegasus with modality:" + f" {modality}: {e}" ) raise @@ -52,5 +58,7 @@ class PegasusEmbedding: try: return self.pegasus.embed(data) except Exception as e: - logging.error(f"Failed to generate embeddings. Error: {e}") + logging.error( + f"Failed to generate embeddings. Error: {e}" + ) raise diff --git a/swarms/models/petals.py b/swarms/models/petals.py index 189c2477..7abc4590 100644 --- a/swarms/models/petals.py +++ b/swarms/models/petals.py @@ -38,6 +38,8 @@ class Petals: def __call__(self, prompt): """Generate text using the Petals API.""" params = self._default_params() - inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"] + inputs = self.tokenizer(prompt, return_tensors="pt")[ + "input_ids" + ] outputs = self.model.generate(inputs, **params) return self.tokenizer.decode(outputs[0]) diff --git a/swarms/models/sam.py b/swarms/models/sam.py index 7abde5ee..866c79ee 100644 --- a/swarms/models/sam.py +++ b/swarms/models/sam.py @@ -1,7 +1,12 @@ import cv2 import numpy as np from PIL import Image -from transformers import SamImageProcessor, SamModel, SamProcessor, pipeline +from transformers import ( + SamImageProcessor, + SamModel, + SamProcessor, + pipeline, +) try: import cv2 @@ -44,16 +49,18 @@ def compute_mask_iou_vectorized(masks: np.ndarray) -> np.ndarray: """ if np.any(masks.sum(axis=(1, 2)) == 0): raise ValueError( - "One or more masks are empty. Please filter out empty masks before" - " using `compute_iou_vectorized` function." + "One or more masks are empty. Please filter out empty" + " masks before using `compute_iou_vectorized` function." ) masks_bool = masks.astype(bool) masks_flat = masks_bool.reshape(masks.shape[0], -1) - intersection = np.logical_and(masks_flat[:, None], masks_flat[None, :]).sum( - axis=2 - ) - union = np.logical_or(masks_flat[:, None], masks_flat[None, :]).sum(axis=2) + intersection = np.logical_and( + masks_flat[:, None], masks_flat[None, :] + ).sum(axis=2) + union = np.logical_or( + masks_flat[:, None], masks_flat[None, :] + ).sum(axis=2) iou_matrix = intersection / union return iou_matrix @@ -96,7 +103,9 @@ def mask_non_max_suppression( def filter_masks_by_relative_area( - masks: np.ndarray, minimum_area: float = 0.01, maximum_area: float = 1.0 + masks: np.ndarray, + minimum_area: float = 0.01, + maximum_area: float = 1.0, ) -> np.ndarray: """ Filters masks based on their relative area within the total area of each mask. @@ -123,18 +132,21 @@ def filter_masks_by_relative_area( if not (0 <= minimum_area <= 1) or not (0 <= maximum_area <= 1): raise ValueError( - "`minimum_area` and `maximum_area` must be between 0 and 1." + "`minimum_area` and `maximum_area` must be between 0" + " and 1." ) if minimum_area > maximum_area: raise ValueError( - "`minimum_area` must be less than or equal to `maximum_area`." + "`minimum_area` must be less than or equal to" + " `maximum_area`." ) total_area = masks.shape[1] * masks.shape[2] relative_areas = masks.sum(axis=(1, 2)) / total_area return masks[ - (relative_areas >= minimum_area) & (relative_areas <= maximum_area) + (relative_areas >= minimum_area) + & (relative_areas <= maximum_area) ] @@ -170,7 +182,9 @@ def adjust_mask_features_by_relative_area( if feature_type == FeatureType.ISLAND else cv2.RETR_CCOMP ) - contours, _ = cv2.findContours(mask, operation, cv2.CHAIN_APPROX_SIMPLE) + contours, _ = cv2.findContours( + mask, operation, cv2.CHAIN_APPROX_SIMPLE + ) for contour in contours: area = cv2.contourArea(contour) @@ -180,7 +194,9 @@ def adjust_mask_features_by_relative_area( image=mask, contours=[contour], contourIdx=-1, - color=(0 if feature_type == FeatureType.ISLAND else 255), + color=( + 0 if feature_type == FeatureType.ISLAND else 255 + ), thickness=-1, ) return np.where(mask > 0, 1, 0).astype(bool) @@ -198,7 +214,9 @@ def masks_to_marks(masks: np.ndarray) -> sv.Detections: sv.Detections: An object containing the masks and their bounding box coordinates. """ - return sv.Detections(mask=masks, xyxy=sv.mask_to_xyxy(masks=masks)) + return sv.Detections( + mask=masks, xyxy=sv.mask_to_xyxy(masks=masks) + ) def refine_marks( @@ -262,11 +280,15 @@ class SegmentAnythingMarkGenerator: """ def __init__( - self, device: str = "cpu", model_name: str = "facebook/sam-vit-huge" + self, + device: str = "cpu", + model_name: str = "facebook/sam-vit-huge", ): self.model = SamModel.from_pretrained(model_name).to(device) self.processor = SamProcessor.from_pretrained(model_name) - self.image_processor = SamImageProcessor.from_pretrained(model_name) + self.image_processor = SamImageProcessor.from_pretrained( + model_name + ) self.pipeline = pipeline( task="mask-generation", model=self.model, @@ -285,7 +307,9 @@ class SegmentAnythingMarkGenerator: sv.Detections: An object containing the segmentation masks and their corresponding bounding box coordinates. """ - image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + image = Image.fromarray( + cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + ) outputs = self.pipeline(image, points_per_batch=64) masks = np.array(outputs["masks"]) return masks_to_marks(masks=masks) diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index a4e99fe4..e9a599d0 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -4,7 +4,9 @@ from openai import OpenAI client = OpenAI() -def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): +def get_ada_embeddings( + text: str, model: str = "text-embedding-ada-002" +): """ Simple function to get embeddings from ada @@ -16,6 +18,6 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): text = text.replace("\n", " ") - return client.embeddings.create(input=[text], model=model)["data"][0][ - "embedding" - ] + return client.embeddings.create(input=[text], model=model)[ + "data" + ][0]["embedding"] diff --git a/swarms/models/speecht5.py b/swarms/models/speecht5.py index 143a7514..cc6ef931 100644 --- a/swarms/models/speecht5.py +++ b/swarms/models/speecht5.py @@ -87,9 +87,15 @@ class SpeechT5: self.model_name = model_name self.vocoder_name = vocoder_name self.dataset_name = dataset_name - self.processor = SpeechT5Processor.from_pretrained(self.model_name) - self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name) - self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) + self.processor = SpeechT5Processor.from_pretrained( + self.model_name + ) + self.model = SpeechT5ForTextToSpeech.from_pretrained( + self.model_name + ) + self.vocoder = SpeechT5HifiGan.from_pretrained( + self.vocoder_name + ) self.embeddings_dataset = load_dataset( self.dataset_name, split="validation" ) @@ -101,7 +107,9 @@ class SpeechT5: ).unsqueeze(0) inputs = self.processor(text=text, return_tensors="pt") speech = self.model.generate_speech( - inputs["input_ids"], speaker_embedding, vocoder=self.vocoder + inputs["input_ids"], + speaker_embedding, + vocoder=self.vocoder, ) return speech @@ -112,13 +120,19 @@ class SpeechT5: def set_model(self, model_name: str): """Set the model to a new model.""" self.model_name = model_name - self.processor = SpeechT5Processor.from_pretrained(self.model_name) - self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name) + self.processor = SpeechT5Processor.from_pretrained( + self.model_name + ) + self.model = SpeechT5ForTextToSpeech.from_pretrained( + self.model_name + ) def set_vocoder(self, vocoder_name): """Set the vocoder to a new vocoder.""" self.vocoder_name = vocoder_name - self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) + self.vocoder = SpeechT5HifiGan.from_pretrained( + self.vocoder_name + ) def set_embeddings_dataset(self, dataset_name): """Set the embeddings dataset to a new dataset.""" @@ -148,7 +162,9 @@ class SpeechT5: # Feature 4: Change dataset split (train, validation, test) def change_dataset_split(self, split="train"): """Change dataset split (train, validation, test).""" - self.embeddings_dataset = load_dataset(self.dataset_name, split=split) + self.embeddings_dataset = load_dataset( + self.dataset_name, split=split + ) # Feature 5: Load a custom speaker embedding (xvector) for the text def load_custom_embedding(self, xvector): diff --git a/swarms/models/ssd_1b.py b/swarms/models/ssd_1b.py index 406678ef..d3b9086b 100644 --- a/swarms/models/ssd_1b.py +++ b/swarms/models/ssd_1b.py @@ -96,7 +96,9 @@ class SSD1B: byte_array = byte_stream.getvalue() return byte_array - @backoff.on_exception(backoff.expo, Exception, max_time=max_time_seconds) + @backoff.on_exception( + backoff.expo, Exception, max_time=max_time_seconds + ) def __call__(self, task: str, neg_prompt: str): """ Text to image conversion using the SSD1B API @@ -124,7 +126,9 @@ class SSD1B: if task in self.cache: return self.cache[task] try: - img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0] + img = self.pipe( + prompt=task, neg_prompt=neg_prompt + ).images[0] # Generate a unique filename for the image img_name = f"{uuid.uuid4()}.{self.image_format}" @@ -141,8 +145,8 @@ class SSD1B: print( colored( ( - f"Error running SSD1B: {error} try optimizing your api" - " key and or try again" + f"Error running SSD1B: {error} try optimizing" + " your api key and or try again" ), "red", ) @@ -218,7 +222,9 @@ class SSD1B: executor.submit(self, task): task for task in tasks } results = [] - for future in concurrent.futures.as_completed(future_to_task): + for future in concurrent.futures.as_completed( + future_to_task + ): task = future_to_task[future] try: img = future.result() @@ -229,18 +235,28 @@ class SSD1B: print( colored( ( - f"Error running SSD1B: {error} try optimizing" - " your api key and or try again" + f"Error running SSD1B: {error} try" + " optimizing your api key and or try" + " again" ), "red", ) ) print( colored( - f"Error running SSD1B: {error.http_status}", "red" + ( + "Error running SSD1B:" + f" {error.http_status}" + ), + "red", + ) + ) + print( + colored( + f"Error running SSD1B: {error.error}", + "red", ) ) - print(colored(f"Error running SSD1B: {error.error}", "red")) raise error def _generate_uuid(self): @@ -255,7 +271,9 @@ class SSD1B: """Str method for the SSD1B class""" return f"SSD1B(image_url={self.image_url})" - @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries) + @backoff.on_exception( + backoff.expo, Exception, max_tries=max_retries + ) def rate_limited_call(self, task: str): """Rate limited call to the SSD1B API""" return self.__call__(task) diff --git a/swarms/models/timm.py b/swarms/models/timm.py index 9947ec7b..d1c42165 100644 --- a/swarms/models/timm.py +++ b/swarms/models/timm.py @@ -34,7 +34,9 @@ class TimmModel: """Retrieve the list of supported models from timm.""" return timm.list_models() - def _create_model(self, model_info: TimmModelInfo) -> torch.nn.Module: + def _create_model( + self, model_info: TimmModelInfo + ) -> torch.nn.Module: """ Create a model instance from timm with specified parameters. diff --git a/swarms/models/whisperx_model.py b/swarms/models/whisperx_model.py index 338db6e3..a41d0430 100644 --- a/swarms/models/whisperx_model.py +++ b/swarms/models/whisperx_model.py @@ -93,7 +93,9 @@ class WhisperX: try: segments = result["segments"] - transcription = " ".join(segment["text"] for segment in segments) + transcription = " ".join( + segment["text"] for segment in segments + ) return transcription except KeyError: print("The key 'segments' is not found in the result.") @@ -128,7 +130,9 @@ class WhisperX: try: segments = result["segments"] - transcription = " ".join(segment["text"] for segment in segments) + transcription = " ".join( + segment["text"] for segment in segments + ) return transcription except KeyError: print("The key 'segments' is not found in the result.") diff --git a/swarms/models/wizard_storytelling.py b/swarms/models/wizard_storytelling.py index a34f6ec7..0dd6c1a1 100644 --- a/swarms/models/wizard_storytelling.py +++ b/swarms/models/wizard_storytelling.py @@ -2,7 +2,11 @@ import logging import torch from torch.nn.parallel import DistributedDataParallel as DDP -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, +) class WizardLLMStoryTeller: @@ -74,21 +78,27 @@ class WizardLLMStoryTeller: bnb_config = BitsAndBytesConfig(**quantization_config) try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_id + ) self.model = AutoModelForCausalLM.from_pretrained( self.model_id, quantization_config=bnb_config ) self.model # .to(self.device) except Exception as e: - self.logger.error(f"Failed to load the model or the tokenizer: {e}") + self.logger.error( + f"Failed to load the model or the tokenizer: {e}" + ) raise def load_model(self): """Load the model""" if not self.model or not self.tokenizer: try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_id + ) bnb_config = ( BitsAndBytesConfig(**self.quantization_config) @@ -104,7 +114,8 @@ class WizardLLMStoryTeller: self.model = DDP(self.model) except Exception as error: self.logger.error( - f"Failed to load the model or the tokenizer: {error}" + "Failed to load the model or the tokenizer:" + f" {error}" ) raise @@ -124,9 +135,9 @@ class WizardLLMStoryTeller: max_length = self.max_length try: - inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode( + prompt_text, return_tensors="pt" + ).to(self.device) # self.log.start() @@ -136,7 +147,9 @@ class WizardLLMStoryTeller: output_sequence = [] outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True + inputs, + max_length=len(inputs) + 1, + do_sample=True, ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) @@ -144,7 +157,8 @@ class WizardLLMStoryTeller: # print token in real-time print( self.tokenizer.decode( - [output_tokens], skip_special_tokens=True + [output_tokens], + skip_special_tokens=True, ), end="", flush=True, @@ -157,7 +171,9 @@ class WizardLLMStoryTeller: ) del inputs - return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return self.tokenizer.decode( + outputs[0], skip_special_tokens=True + ) except Exception as e: self.logger.error(f"Failed to generate the text: {e}") raise @@ -178,9 +194,9 @@ class WizardLLMStoryTeller: max_length = self.max_ try: - inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode( + prompt_text, return_tensors="pt" + ).to(self.device) # self.log.start() @@ -190,7 +206,9 @@ class WizardLLMStoryTeller: output_sequence = [] outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True + inputs, + max_length=len(inputs) + 1, + do_sample=True, ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) @@ -198,7 +216,8 @@ class WizardLLMStoryTeller: # print token in real-time print( self.tokenizer.decode( - [output_tokens], skip_special_tokens=True + [output_tokens], + skip_special_tokens=True, ), end="", flush=True, @@ -212,7 +231,9 @@ class WizardLLMStoryTeller: del inputs - return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return self.tokenizer.decode( + outputs[0], skip_special_tokens=True + ) except Exception as e: self.logger.error(f"Failed to generate the text: {e}") raise diff --git a/swarms/models/yarn_mistral.py b/swarms/models/yarn_mistral.py index 065e3140..7b5a9c02 100644 --- a/swarms/models/yarn_mistral.py +++ b/swarms/models/yarn_mistral.py @@ -2,7 +2,11 @@ import logging import torch from torch.nn.parallel import DistributedDataParallel as DDP -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, +) class YarnMistral128: @@ -74,7 +78,9 @@ class YarnMistral128: bnb_config = BitsAndBytesConfig(**quantization_config) try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_id + ) self.model = AutoModelForCausalLM.from_pretrained( self.model_id, quantization_config=bnb_config, @@ -86,14 +92,18 @@ class YarnMistral128: self.model # .to(self.device) except Exception as e: - self.logger.error(f"Failed to load the model or the tokenizer: {e}") + self.logger.error( + f"Failed to load the model or the tokenizer: {e}" + ) raise def load_model(self): """Load the model""" if not self.model or not self.tokenizer: try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_id + ) bnb_config = ( BitsAndBytesConfig(**self.quantization_config) @@ -109,7 +119,8 @@ class YarnMistral128: self.model = DDP(self.model) except Exception as error: self.logger.error( - f"Failed to load the model or the tokenizer: {error}" + "Failed to load the model or the tokenizer:" + f" {error}" ) raise @@ -129,9 +140,9 @@ class YarnMistral128: max_length = self.max_length try: - inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode( + prompt_text, return_tensors="pt" + ).to(self.device) # self.log.start() @@ -141,7 +152,9 @@ class YarnMistral128: output_sequence = [] outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True + inputs, + max_length=len(inputs) + 1, + do_sample=True, ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) @@ -149,7 +162,8 @@ class YarnMistral128: # print token in real-time print( self.tokenizer.decode( - [output_tokens], skip_special_tokens=True + [output_tokens], + skip_special_tokens=True, ), end="", flush=True, @@ -162,7 +176,9 @@ class YarnMistral128: ) del inputs - return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return self.tokenizer.decode( + outputs[0], skip_special_tokens=True + ) except Exception as e: self.logger.error(f"Failed to generate the text: {e}") raise @@ -206,9 +222,9 @@ class YarnMistral128: max_length = self.max_ try: - inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to( - self.device - ) + inputs = self.tokenizer.encode( + prompt_text, return_tensors="pt" + ).to(self.device) # self.log.start() @@ -218,7 +234,9 @@ class YarnMistral128: output_sequence = [] outputs = self.model.generate( - inputs, max_length=len(inputs) + 1, do_sample=True + inputs, + max_length=len(inputs) + 1, + do_sample=True, ) output_tokens = outputs[0][-1] output_sequence.append(output_tokens.item()) @@ -226,7 +244,8 @@ class YarnMistral128: # print token in real-time print( self.tokenizer.decode( - [output_tokens], skip_special_tokens=True + [output_tokens], + skip_special_tokens=True, ), end="", flush=True, @@ -240,7 +259,9 @@ class YarnMistral128: del inputs - return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return self.tokenizer.decode( + outputs[0], skip_special_tokens=True + ) except Exception as e: self.logger.error(f"Failed to generate the text: {e}") raise diff --git a/swarms/models/yi_200k.py b/swarms/models/yi_200k.py index 8f9f7635..1f1258aa 100644 --- a/swarms/models/yi_200k.py +++ b/swarms/models/yi_200k.py @@ -87,7 +87,9 @@ class Yi34B200k: top_k=self.top_k, top_p=self.top_p, ) - return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return self.tokenizer.decode( + outputs[0], skip_special_tokens=True + ) # # Example usage diff --git a/swarms/models/zephyr.py b/swarms/models/zephyr.py index 4fca5211..c5772295 100644 --- a/swarms/models/zephyr.py +++ b/swarms/models/zephyr.py @@ -68,7 +68,9 @@ class Zephyr: tokenize=self.tokenize, add_generation_prompt=self.add_generation_prompt, ) - outputs = self.pipe(prompt) # max_new_token=self.max_new_tokens) + outputs = self.pipe( + prompt + ) # max_new_token=self.max_new_tokens) print(outputs[0]["generated_text"]) def chat(self, message: str): diff --git a/swarms/prompts/__init__.py b/swarms/prompts/__init__.py index 825bddaa..6417dc85 100644 --- a/swarms/prompts/__init__.py +++ b/swarms/prompts/__init__.py @@ -2,7 +2,9 @@ from swarms.prompts.code_interpreter import CODE_INTERPRETER from swarms.prompts.finance_agent_prompt import FINANCE_AGENT_PROMPT from swarms.prompts.growth_agent_prompt import GROWTH_AGENT_PROMPT from swarms.prompts.legal_agent_prompt import LEGAL_AGENT_PROMPT -from swarms.prompts.operations_agent_prompt import OPERATIONS_AGENT_PROMPT +from swarms.prompts.operations_agent_prompt import ( + OPERATIONS_AGENT_PROMPT, +) from swarms.prompts.product_agent_prompt import PRODUCT_AGENT_PROMPT diff --git a/swarms/prompts/agent_output_parser.py b/swarms/prompts/agent_output_parser.py index 27f8ac24..0810d2ad 100644 --- a/swarms/prompts/agent_output_parser.py +++ b/swarms/prompts/agent_output_parser.py @@ -25,7 +25,9 @@ class AgentOutputParser(BaseAgentOutputParser): @staticmethod def _preprocess_json_input(input_str: str) -> str: corrected_str = re.sub( - r'(? None: @@ -66,13 +70,16 @@ class PromptGenerator: Returns: str: The generated prompt string. """ - formatted_response_format = json.dumps(self.response_format, indent=4) + formatted_response_format = json.dumps( + self.response_format, indent=4 + ) prompt_string = ( f"Constraints:\n{''.join(self.constraints)}\n\nCommands:\n{''.join(self.commands)}\n\nResources:\n{''.join(self.resources)}\n\nPerformance" f" Evaluation:\n{''.join(self.performance_evaluation)}\n\nYou" - " should only respond in JSON format as described below \nResponse" - f" Format: \n{formatted_response_format} \nEnsure the response can" - " be parsed by Python json.loads" + " should only respond in JSON format as described below" + " \nResponse Format:" + f" \n{formatted_response_format} \nEnsure the response" + " can be parsed by Python json.loads" ) return prompt_string diff --git a/swarms/prompts/agent_prompts.py b/swarms/prompts/agent_prompts.py index a8c3fca7..88853b09 100644 --- a/swarms/prompts/agent_prompts.py +++ b/swarms/prompts/agent_prompts.py @@ -5,26 +5,30 @@ def generate_agent_role_prompt(agent): """ prompts = { "Finance Agent": ( - "You are a seasoned finance analyst AI assistant. Your primary goal" - " is to compose comprehensive, astute, impartial, and methodically" - " arranged financial reports based on provided data and trends." + "You are a seasoned finance analyst AI assistant. Your" + " primary goal is to compose comprehensive, astute," + " impartial, and methodically arranged financial reports" + " based on provided data and trends." ), "Travel Agent": ( - "You are a world-travelled AI tour guide assistant. Your main" - " purpose is to draft engaging, insightful, unbiased, and" - " well-structured travel reports on given locations, including" - " history, attractions, and cultural insights." + "You are a world-travelled AI tour guide assistant. Your" + " main purpose is to draft engaging, insightful," + " unbiased, and well-structured travel reports on given" + " locations, including history, attractions, and cultural" + " insights." ), "Academic Research Agent": ( "You are an AI academic research assistant. Your primary" - " responsibility is to create thorough, academically rigorous," - " unbiased, and systematically organized reports on a given" - " research topic, following the standards of scholarly work." + " responsibility is to create thorough, academically" + " rigorous, unbiased, and systematically organized" + " reports on a given research topic, following the" + " standards of scholarly work." ), "Default Agent": ( - "You are an AI critical thinker research assistant. Your sole" - " purpose is to write well written, critically acclaimed, objective" - " and structured reports on given text." + "You are an AI critical thinker research assistant. Your" + " sole purpose is to write well written, critically" + " acclaimed, objective and structured reports on given" + " text." ), } @@ -39,12 +43,14 @@ def generate_report_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, answer the' - f' following question or topic: "{question}" in a detailed report --' - " The report should focus on the answer to the question, should be" - " well structured, informative, in depth, with facts and numbers if" - " available, a minimum of 1,200 words and with markdown syntax and apa" - " format. Write all source urls at the end of the report in apa format" + f'"""{research_summary}""" Using the above information,' + f' answer the following question or topic: "{question}" in a' + " detailed report -- The report should focus on the answer" + " to the question, should be well structured, informative," + " in depth, with facts and numbers if available, a minimum" + " of 1,200 words and with markdown syntax and apa format." + " Write all source urls at the end of the report in apa" + " format" ) @@ -55,10 +61,10 @@ def generate_search_queries_prompt(question): """ return ( - "Write 4 google search queries to search online that form an objective" - f' opinion from the following: "{question}"You must respond with a list' - ' of strings in the following format: ["query 1", "query 2", "query' - ' 3", "query 4"]' + "Write 4 google search queries to search online that form an" + f' objective opinion from the following: "{question}"You must' + " respond with a list of strings in the following format:" + ' ["query 1", "query 2", "query 3", "query 4"]' ) @@ -73,16 +79,17 @@ def generate_resource_report_prompt(question, research_summary): str: The resource report prompt for the given question and research summary. """ return ( - f'"""{research_summary}""" Based on the above information, generate a' - " bibliography recommendation report for the following question or" - f' topic: "{question}". The report should provide a detailed analysis' - " of each recommended resource, explaining how each source can" - " contribute to finding answers to the research question. Focus on the" - " relevance, reliability, and significance of each source. Ensure that" - " the report is well-structured, informative, in-depth, and follows" - " Markdown syntax. Include relevant facts, figures, and numbers" - " whenever available. The report should have a minimum length of 1,200" - " words." + f'"""{research_summary}""" Based on the above information,' + " generate a bibliography recommendation report for the" + f' following question or topic: "{question}". The report' + " should provide a detailed analysis of each recommended" + " resource, explaining how each source can contribute to" + " finding answers to the research question. Focus on the" + " relevance, reliability, and significance of each source." + " Ensure that the report is well-structured, informative," + " in-depth, and follows Markdown syntax. Include relevant" + " facts, figures, and numbers whenever available. The report" + " should have a minimum length of 1,200 words." ) @@ -94,14 +101,15 @@ def generate_outline_report_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, generate an' - " outline for a research report in Markdown syntax for the following" - f' question or topic: "{question}". The outline should provide a' - " well-structured framework for the research report, including the" - " main sections, subsections, and key points to be covered. The" - " research report should be detailed, informative, in-depth, and a" - " minimum of 1,200 words. Use appropriate Markdown syntax to format" - " the outline and ensure readability." + f'"""{research_summary}""" Using the above information,' + " generate an outline for a research report in Markdown" + f' syntax for the following question or topic: "{question}".' + " The outline should provide a well-structured framework for" + " the research report, including the main sections," + " subsections, and key points to be covered. The research" + " report should be detailed, informative, in-depth, and a" + " minimum of 1,200 words. Use appropriate Markdown syntax to" + " format the outline and ensure readability." ) @@ -113,11 +121,12 @@ def generate_concepts_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, generate a list' - " of 5 main concepts to learn for a research report on the following" - f' question or topic: "{question}". The outline should provide a' - " well-structured frameworkYou must respond with a list of strings in" - ' the following format: ["concepts 1", "concepts 2", "concepts 3",' + f'"""{research_summary}""" Using the above information,' + " generate a list of 5 main concepts to learn for a research" + f' report on the following question or topic: "{question}".' + " The outline should provide a well-structured frameworkYou" + " must respond with a list of strings in the following" + ' format: ["concepts 1", "concepts 2", "concepts 3",' ' "concepts 4, concepts 5"]' ) @@ -132,10 +141,11 @@ def generate_lesson_prompt(concept): """ prompt = ( - f"generate a comprehensive lesson about {concept} in Markdown syntax." - f" This should include the definitionof {concept}, its historical" - " background and development, its applications or uses in" - f" differentfields, and notable events or facts related to {concept}." + f"generate a comprehensive lesson about {concept} in Markdown" + f" syntax. This should include the definitionof {concept}," + " its historical background and development, its" + " applications or uses in differentfields, and notable" + f" events or facts related to {concept}." ) return prompt diff --git a/swarms/prompts/base.py b/swarms/prompts/base.py index 369063e6..a0e28c71 100644 --- a/swarms/prompts/base.py +++ b/swarms/prompts/base.py @@ -53,7 +53,10 @@ def get_buffer_string( else: raise ValueError(f"Got unsupported message type: {m}") message = f"{role}: {m.content}" - if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs: + if ( + isinstance(m, AIMessage) + and "function_call" in m.additional_kwargs + ): message += f"{m.additional_kwargs['function_call']}" string_messages.append(message) @@ -100,8 +103,8 @@ class BaseMessageChunk(BaseMessage): merged[k] = v elif not isinstance(merged[k], type(v)): raise ValueError( - f'additional_kwargs["{k}"] already exists in this message,' - " but with a different type." + f'additional_kwargs["{k}"] already exists in this' + " message, but with a different type." ) elif isinstance(merged[k], str): merged[k] += v @@ -109,7 +112,8 @@ class BaseMessageChunk(BaseMessage): merged[k] = self._merge_kwargs_dict(merged[k], v) else: raise ValueError( - f"Additional kwargs key {k} already exists in this message." + f"Additional kwargs key {k} already exists in" + " this message." ) return merged diff --git a/swarms/prompts/chat_prompt.py b/swarms/prompts/chat_prompt.py index bbdaa9c7..013aee28 100644 --- a/swarms/prompts/chat_prompt.py +++ b/swarms/prompts/chat_prompt.py @@ -10,10 +10,14 @@ class Message: Messages are the inputs and outputs of ChatModels. """ - def __init__(self, content: str, role: str, additional_kwargs: Dict = None): + def __init__( + self, content: str, role: str, additional_kwargs: Dict = None + ): self.content = content self.role = role - self.additional_kwargs = additional_kwargs if additional_kwargs else {} + self.additional_kwargs = ( + additional_kwargs if additional_kwargs else {} + ) @abstractmethod def get_type(self) -> str: @@ -65,7 +69,10 @@ class SystemMessage(Message): """ def __init__( - self, content: str, role: str = "System", additional_kwargs: Dict = None + self, + content: str, + role: str = "System", + additional_kwargs: Dict = None, ): super().__init__(content, role, additional_kwargs) @@ -97,7 +104,9 @@ class ChatMessage(Message): A Message that can be assigned an arbitrary speaker (i.e. role). """ - def __init__(self, content: str, role: str, additional_kwargs: Dict = None): + def __init__( + self, content: str, role: str, additional_kwargs: Dict = None + ): super().__init__(content, role, additional_kwargs) def get_type(self) -> str: @@ -112,7 +121,10 @@ def get_buffer_string( string_messages = [] for m in messages: message = f"{m.role}: {m.content}" - if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs: + if ( + isinstance(m, AIMessage) + and "function_call" in m.additional_kwargs + ): message += f"{m.additional_kwargs['function_call']}" string_messages.append(message) diff --git a/swarms/prompts/multi_modal_prompts.py b/swarms/prompts/multi_modal_prompts.py index 1c0830d6..83e9800c 100644 --- a/swarms/prompts/multi_modal_prompts.py +++ b/swarms/prompts/multi_modal_prompts.py @@ -1,6 +1,6 @@ ERROR_PROMPT = ( - "An error has occurred for the following text: \n{promptedQuery} Please" - " explain this error.\n {e}" + "An error has occurred for the following text: \n{promptedQuery}" + " Please explain this error.\n {e}" ) IMAGE_PROMPT = """ diff --git a/swarms/prompts/python.py b/swarms/prompts/python.py index 46df5cdc..a6210024 100644 --- a/swarms/prompts/python.py +++ b/swarms/prompts/python.py @@ -1,17 +1,19 @@ -PY_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only." +PY_SIMPLE_COMPLETION_INSTRUCTION = ( + "# Write the body of this function only." +) PY_REFLEXION_COMPLETION_INSTRUCTION = ( - "You are a Python writing assistant. You will be given your past function" - " implementation, a series of unit tests, and a hint to change the" - " implementation appropriately. Write your full implementation (restate the" - " function signature).\n\n-----" + "You are a Python writing assistant. You will be given your past" + " function implementation, a series of unit tests, and a hint to" + " change the implementation appropriately. Write your full" + " implementation (restate the function signature).\n\n-----" ) PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = ( "You are a Python writing assistant. You will be given a function" - " implementation and a series of unit tests. Your goal is to write a few" - " sentences to explain why your implementation is wrong as indicated by the" - " tests. You will need this as a hint when you try again later. Only" - " provide the few sentence description in your answer, not the" - " implementation.\n\n-----" + " implementation and a series of unit tests. Your goal is to" + " write a few sentences to explain why your implementation is" + " wrong as indicated by the tests. You will need this as a hint" + " when you try again later. Only provide the few sentence" + " description in your answer, not the implementation.\n\n-----" ) USE_PYTHON_CODEBLOCK_INSTRUCTION = ( "Use a Python code block to write your response. For" @@ -19,26 +21,28 @@ USE_PYTHON_CODEBLOCK_INSTRUCTION = ( ) PY_SIMPLE_CHAT_INSTRUCTION = ( - "You are an AI that only responds with python code, NOT ENGLISH. You will" - " be given a function signature and its docstring by the user. Write your" - " full implementation (restate the function signature)." + "You are an AI that only responds with python code, NOT ENGLISH." + " You will be given a function signature and its docstring by the" + " user. Write your full implementation (restate the function" + " signature)." ) PY_SIMPLE_CHAT_INSTRUCTION_V2 = ( - "You are an AI that only responds with only python code. You will be given" - " a function signature and its docstring by the user. Write your full" - " implementation (restate the function signature)." + "You are an AI that only responds with only python code. You will" + " be given a function signature and its docstring by the user." + " Write your full implementation (restate the function" + " signature)." ) PY_REFLEXION_CHAT_INSTRUCTION = ( - "You are an AI Python assistant. You will be given your past function" - " implementation, a series of unit tests, and a hint to change the" - " implementation appropriately. Write your full implementation (restate the" - " function signature)." + "You are an AI Python assistant. You will be given your past" + " function implementation, a series of unit tests, and a hint to" + " change the implementation appropriately. Write your full" + " implementation (restate the function signature)." ) PY_REFLEXION_CHAT_INSTRUCTION_V2 = ( "You are an AI Python assistant. You will be given your previous" - " implementation of a function, a series of unit tests results, and your" - " self-reflection on your previous implementation. Write your full" - " implementation (restate the function signature)." + " implementation of a function, a series of unit tests results," + " and your self-reflection on your previous implementation. Write" + " your full implementation (restate the function signature)." ) PY_REFLEXION_FEW_SHOT_ADD = '''Example 1: [previous impl]: @@ -173,19 +177,20 @@ END EXAMPLES ''' PY_SELF_REFLECTION_CHAT_INSTRUCTION = ( - "You are a Python programming assistant. You will be given a function" - " implementation and a series of unit tests. Your goal is to write a few" - " sentences to explain why your implementation is wrong as indicated by the" - " tests. You will need this as a hint when you try again later. Only" - " provide the few sentence description in your answer, not the" - " implementation." + "You are a Python programming assistant. You will be given a" + " function implementation and a series of unit tests. Your goal" + " is to write a few sentences to explain why your implementation" + " is wrong as indicated by the tests. You will need this as a" + " hint when you try again later. Only provide the few sentence" + " description in your answer, not the implementation." ) PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = ( - "You are a Python programming assistant. You will be given a function" - " implementation and a series of unit test results. Your goal is to write a" - " few sentences to explain why your implementation is wrong as indicated by" - " the tests. You will need this as guidance when you try again later. Only" - " provide the few sentence description in your answer, not the" + "You are a Python programming assistant. You will be given a" + " function implementation and a series of unit test results. Your" + " goal is to write a few sentences to explain why your" + " implementation is wrong as indicated by the tests. You will" + " need this as guidance when you try again later. Only provide" + " the few sentence description in your answer, not the" " implementation. You will be given a few examples by the user." ) PY_SELF_REFLECTION_FEW_SHOT = """Example 1: diff --git a/swarms/prompts/sales.py b/swarms/prompts/sales.py index 3a362174..d69f9086 100644 --- a/swarms/prompts/sales.py +++ b/swarms/prompts/sales.py @@ -1,40 +1,43 @@ conversation_stages = { "1": ( - "Introduction: Start the conversation by introducing yourself and your" - " company. Be polite and respectful while keeping the tone of the" - " conversation professional. Your greeting should be welcoming. Always" - " clarify in your greeting the reason why you are contacting the" - " prospect." + "Introduction: Start the conversation by introducing yourself" + " and your company. Be polite and respectful while keeping" + " the tone of the conversation professional. Your greeting" + " should be welcoming. Always clarify in your greeting the" + " reason why you are contacting the prospect." ), "2": ( - "Qualification: Qualify the prospect by confirming if they are the" - " right person to talk to regarding your product/service. Ensure that" - " they have the authority to make purchasing decisions." + "Qualification: Qualify the prospect by confirming if they" + " are the right person to talk to regarding your" + " product/service. Ensure that they have the authority to" + " make purchasing decisions." ), "3": ( - "Value proposition: Briefly explain how your product/service can" - " benefit the prospect. Focus on the unique selling points and value" - " proposition of your product/service that sets it apart from" - " competitors." + "Value proposition: Briefly explain how your product/service" + " can benefit the prospect. Focus on the unique selling" + " points and value proposition of your product/service that" + " sets it apart from competitors." ), "4": ( - "Needs analysis: Ask open-ended questions to uncover the prospect's" - " needs and pain points. Listen carefully to their responses and take" - " notes." + "Needs analysis: Ask open-ended questions to uncover the" + " prospect's needs and pain points. Listen carefully to their" + " responses and take notes." ), "5": ( - "Solution presentation: Based on the prospect's needs, present your" - " product/service as the solution that can address their pain points." + "Solution presentation: Based on the prospect's needs," + " present your product/service as the solution that can" + " address their pain points." ), "6": ( - "Objection handling: Address any objections that the prospect may have" - " regarding your product/service. Be prepared to provide evidence or" - " testimonials to support your claims." + "Objection handling: Address any objections that the prospect" + " may have regarding your product/service. Be prepared to" + " provide evidence or testimonials to support your claims." ), "7": ( - "Close: Ask for the sale by proposing a next step. This could be a" - " demo, a trial or a meeting with decision-makers. Ensure to summarize" - " what has been discussed and reiterate the benefits." + "Close: Ask for the sale by proposing a next step. This could" + " be a demo, a trial or a meeting with decision-makers." + " Ensure to summarize what has been discussed and reiterate" + " the benefits." ), } diff --git a/swarms/prompts/sales_prompts.py b/swarms/prompts/sales_prompts.py index 7c1f50ed..dbc2b40e 100644 --- a/swarms/prompts/sales_prompts.py +++ b/swarms/prompts/sales_prompts.py @@ -46,40 +46,43 @@ Conversation history: conversation_stages = { "1": ( - "Introduction: Start the conversation by introducing yourself and your" - " company. Be polite and respectful while keeping the tone of the" - " conversation professional. Your greeting should be welcoming. Always" - " clarify in your greeting the reason why you are contacting the" - " prospect." + "Introduction: Start the conversation by introducing yourself" + " and your company. Be polite and respectful while keeping" + " the tone of the conversation professional. Your greeting" + " should be welcoming. Always clarify in your greeting the" + " reason why you are contacting the prospect." ), "2": ( - "Qualification: Qualify the prospect by confirming if they are the" - " right person to talk to regarding your product/service. Ensure that" - " they have the authority to make purchasing decisions." + "Qualification: Qualify the prospect by confirming if they" + " are the right person to talk to regarding your" + " product/service. Ensure that they have the authority to" + " make purchasing decisions." ), "3": ( - "Value proposition: Briefly explain how your product/service can" - " benefit the prospect. Focus on the unique selling points and value" - " proposition of your product/service that sets it apart from" - " competitors." + "Value proposition: Briefly explain how your product/service" + " can benefit the prospect. Focus on the unique selling" + " points and value proposition of your product/service that" + " sets it apart from competitors." ), "4": ( - "Needs analysis: Ask open-ended questions to uncover the prospect's" - " needs and pain points. Listen carefully to their responses and take" - " notes." + "Needs analysis: Ask open-ended questions to uncover the" + " prospect's needs and pain points. Listen carefully to their" + " responses and take notes." ), "5": ( - "Solution presentation: Based on the prospect's needs, present your" - " product/service as the solution that can address their pain points." + "Solution presentation: Based on the prospect's needs," + " present your product/service as the solution that can" + " address their pain points." ), "6": ( - "Objection handling: Address any objections that the prospect may have" - " regarding your product/service. Be prepared to provide evidence or" - " testimonials to support your claims." + "Objection handling: Address any objections that the prospect" + " may have regarding your product/service. Be prepared to" + " provide evidence or testimonials to support your claims." ), "7": ( - "Close: Ask for the sale by proposing a next step. This could be a" - " demo, a trial or a meeting with decision-makers. Ensure to summarize" - " what has been discussed and reiterate the benefits." + "Close: Ask for the sale by proposing a next step. This could" + " be a demo, a trial or a meeting with decision-makers." + " Ensure to summarize what has been discussed and reiterate" + " the benefits." ), } diff --git a/swarms/prompts/self_operating_prompt.py b/swarms/prompts/self_operating_prompt.py index ce058d7b..bb4856e0 100644 --- a/swarms/prompts/self_operating_prompt.py +++ b/swarms/prompts/self_operating_prompt.py @@ -54,7 +54,9 @@ IMPORTANT: Avoid repeating actions such as doing the same CLICK event twice in a Objective: {objective} """ -USER_QUESTION = "Hello, I can help you with anything. What would you like done?" +USER_QUESTION = ( + "Hello, I can help you with anything. What would you like done?" +) SUMMARY_PROMPT = """ You are a Self-Operating Computer. You just completed a request from a user by operating the computer. Now you need to share the results. @@ -89,7 +91,8 @@ def format_vision_prompt(objective, previous_action): """ if previous_action: previous_action = ( - f"Here was the previous action you took: {previous_action}" + "Here was the previous action you took:" + f" {previous_action}" ) else: previous_action = "" diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index e8d6f196..3544fe68 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -21,7 +21,9 @@ from swarms.prompts.tools import ( ) from swarms.tools.tool import BaseTool from swarms.utils.code_interpreter import SubprocessCodeInterpreter -from swarms.utils.parse_code import extract_code_in_backticks_in_string +from swarms.utils.parse_code import ( + extract_code_in_backticks_in_string, +) from swarms.utils.pdf_to_text import pdf_to_text @@ -140,7 +142,7 @@ class Agent: get_llm_init_params(): Get LLM init params get_tool_description(): Get the tool description find_tool_by_name(name: str): Find a tool by name - + Example: >>> from swarms.models import OpenAIChat @@ -180,6 +182,7 @@ class Agent: dynamic_temperature_enabled: Optional[bool] = False, sop: Optional[str] = None, sop_list: Optional[List[str]] = None, + # memory: Optional[Vectorstore] = None, saved_state_path: Optional[str] = "flow_state.json", autosave: Optional[bool] = False, context_length: Optional[int] = 8192, @@ -189,7 +192,7 @@ class Agent: multi_modal: Optional[bool] = None, pdf_path: Optional[str] = None, list_of_pdf: Optional[str] = None, - tokenizer: Optional[str] = None, + tokenizer: Optional[Any] = None, *args, **kwargs: Any, ): @@ -305,8 +308,9 @@ class Agent: return "\n".join(tool_descriptions) except Exception as error: print( - f"Error getting tool description: {error} try adding a" - " description to the tool or removing the tool" + f"Error getting tool description: {error} try" + " adding a description to the tool or removing" + " the tool" ) else: return "No tools available" @@ -322,7 +326,9 @@ class Agent: """Construct the dynamic prompt""" tools_description = self.get_tool_description() - tool_prompt = self.tool_prompt_prep(tools_description, SCENARIOS) + tool_prompt = self.tool_prompt_prep( + tools_description, SCENARIOS + ) return tool_prompt @@ -435,27 +441,36 @@ class Agent: def activate_autonomous_agent(self): """Print the autonomous agent activation message""" try: - print(colored("Initializing Autonomous Agent...", "yellow")) + print( + colored("Initializing Autonomous Agent...", "yellow") + ) # print(colored("Loading modules...", "yellow")) # print(colored("Modules loaded successfully.", "green")) print( - colored("Autonomous Agent Activated.", "cyan", attrs=["bold"]) + colored( + "Autonomous Agent Activated.", + "cyan", + attrs=["bold"], + ) ) print( - colored("All systems operational. Executing task...", "green") + colored( + "All systems operational. Executing task...", + "green", + ) ) except Exception as error: print( colored( ( - "Error activating autonomous agent. Try optimizing your" - " parameters..." + "Error activating autonomous agent. Try" + " optimizing your parameters..." ), "red", ) ) print(error) - + def loop_count_print(self, loop_count, max_loops): """loop_count_print summary @@ -463,11 +478,9 @@ class Agent: loop_count (_type_): _description_ max_loops (_type_): _description_ """ - print( - colored(f"\nLoop {loop_count} of {max_loops}", "cyan") - ) + print(colored(f"\nLoop {loop_count} of {max_loops}", "cyan")) print("\n") - + def _history(self, user_name: str, task: str) -> str: """Generate the history for the history prompt @@ -480,8 +493,10 @@ class Agent: """ history = [f"{user_name}: {task}"] return history - - def _dynamic_prompt_setup(self, dynamic_prompt: str, task: str) -> str: + + def _dynamic_prompt_setup( + self, dynamic_prompt: str, task: str + ) -> str: """_dynamic_prompt_setup summary Args: @@ -491,11 +506,15 @@ class Agent: Returns: str: _description_ """ - dynamic_prompt = dynamic_prompt or self.construct_dynamic_prompt() + dynamic_prompt = ( + dynamic_prompt or self.construct_dynamic_prompt() + ) combined_prompt = f"{dynamic_prompt}\n{task}" return combined_prompt - def run(self, task: Optional[str], img: Optional[str] = None, **kwargs): + def run( + self, task: Optional[str], img: Optional[str] = None, **kwargs + ): """ Run the autonomous agent loop @@ -524,7 +543,10 @@ class Agent: loop_count = 0 # While the max_loops is auto or the loop count is less than the max_loops - while self.max_loops == "auto" or loop_count < self.max_loops: + while ( + self.max_loops == "auto" + or loop_count < self.max_loops + ): # Loop count loop_count += 1 self.loop_count_print(loop_count, self.max_loops) @@ -542,7 +564,9 @@ class Agent: self.dynamic_temperature() # Preparing the prompt - task = self.agent_history_prompt(FLOW_SYSTEM_PROMPT, response) + task = self.agent_history_prompt( + FLOW_SYSTEM_PROMPT, response + ) attempt = 0 while attempt < self.retry_attempts: @@ -581,7 +605,9 @@ class Agent: # print(response) break except Exception as e: - logging.error(f"Error generating response: {e}") + logging.error( + f"Error generating response: {e}" + ) attempt += 1 time.sleep(self.retry_interval) # Add the response to the history @@ -595,7 +621,10 @@ class Agent: if self.autosave: save_path = self.saved_state_path or "flow_state.json" print( - colored(f"Autosaving agent state to {save_path}", "green") + colored( + f"Autosaving agent state to {save_path}", + "green", + ) ) self.save_state(save_path) @@ -637,12 +666,16 @@ class Agent: # for i in range(self.max_loops): while self.max_loops == "auto" or loop_count < self.max_loops: loop_count += 1 - print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) + print( + colored( + f"\nLoop {loop_count} of {self.max_loops}", "blue" + ) + ) print("\n") - if self._check_stopping_condition(response) or parse_done_token( + if self._check_stopping_condition( response - ): + ) or parse_done_token(response): break # Adjust temperature, comment if no work @@ -650,7 +683,9 @@ class Agent: self.dynamic_temperature() # Preparing the prompt - task = self.agent_history_prompt(FLOW_SYSTEM_PROMPT, response) + task = self.agent_history_prompt( + FLOW_SYSTEM_PROMPT, response + ) attempt = 0 while attempt < self.retry_attempts: @@ -678,7 +713,11 @@ class Agent: if self.autosave: save_path = self.saved_state_path or "flow_state.json" - print(colored(f"Autosaving agent state to {save_path}", "green")) + print( + colored( + f"Autosaving agent state to {save_path}", "green" + ) + ) self.save_state(save_path) if self.return_history: @@ -737,7 +776,9 @@ class Agent: Args: tasks (List[str]): A list of tasks to run. """ - task_coroutines = [self.run_async(task, **kwargs) for task in tasks] + task_coroutines = [ + self.run_async(task, **kwargs) for task in tasks + ] completed_tasks = await asyncio.gather(*task_coroutines) return completed_tasks @@ -751,7 +792,9 @@ class Agent: return Agent(llm=llm, template=template) @staticmethod - def from_llm_and_template_file(llm: Any, template_file: str) -> "Agent": + def from_llm_and_template_file( + llm: Any, template_file: str + ) -> "Agent": """Create AgentStream from LLM and a template file.""" with open(template_file, "r") as f: template = f.read() @@ -785,16 +828,34 @@ class Agent: Prints the entire history and memory of the agent. Each message is colored and formatted for better readability. """ - print(colored("Agent History and Memory", "cyan", attrs=["bold"])) - print(colored("========================", "cyan", attrs=["bold"])) + print( + colored( + "Agent History and Memory", "cyan", attrs=["bold"] + ) + ) + print( + colored( + "========================", "cyan", attrs=["bold"] + ) + ) for loop_index, history in enumerate(self.memory, start=1): - print(colored(f"\nLoop {loop_index}:", "yellow", attrs=["bold"])) + print( + colored( + f"\nLoop {loop_index}:", "yellow", attrs=["bold"] + ) + ) for message in history: speaker, _, message_text = message.partition(": ") if "Human" in speaker: - print(colored(f"{speaker}:", "green") + f" {message_text}") + print( + colored(f"{speaker}:", "green") + + f" {message_text}" + ) else: - print(colored(f"{speaker}:", "blue") + f" {message_text}") + print( + colored(f"{speaker}:", "blue") + + f" {message_text}" + ) print(colored("------------------------", "cyan")) print(colored("End of Agent History", "cyan", attrs=["bold"])) @@ -963,7 +1024,16 @@ class Agent: value = getattr(self.llm, name) if isinstance( value, - (str, int, float, bool, list, dict, tuple, type(None)), + ( + str, + int, + float, + bool, + list, + dict, + tuple, + type(None), + ), ): llm_params[name] = value else: @@ -1110,7 +1180,9 @@ class Agent: text = text or self.pdf_connector() pass - def tools_prompt_prep(self, docs: str = None, scenarios: str = None): + def tools_prompt_prep( + self, docs: str = None, scenarios: str = None + ): """ Prepare the tool prompt """ diff --git a/swarms/structs/autoscaler.py b/swarms/structs/autoscaler.py index 16c7892b..7d4894ad 100644 --- a/swarms/structs/autoscaler.py +++ b/swarms/structs/autoscaler.py @@ -62,7 +62,9 @@ class AutoScaler: agent=None, ): self.agent = agent or Agent - self.agents_pool = [self.agent() for _ in range(initial_agents)] + self.agents_pool = [ + self.agent() for _ in range(initial_agents) + ] self.task_queue = queue.Queue() self.scale_up_factor = scale_up_factor self.idle_threshold = idle_threshold @@ -74,7 +76,8 @@ class AutoScaler: self.tasks_queue.put(task) except Exception as error: print( - f"Error adding task to queue: {error} try again with a new task" + f"Error adding task to queue: {error} try again with" + " a new task" ) @log_decorator @@ -84,20 +87,29 @@ class AutoScaler: """Add more agents""" try: with self.lock: - new_agents_counts = len(self.agents_pool) * self.scale_up_factor + new_agents_counts = ( + len(self.agents_pool) * self.scale_up_factor + ) for _ in range(new_agents_counts): self.agents_pool.append(Agent()) except Exception as error: - print(f"Error scaling up: {error} try again with a new task") + print( + f"Error scaling up: {error} try again with a new task" + ) def scale_down(self): """scale down""" try: with self.lock: - if len(self.agents_pool) > 10: # ensure minmum of 10 agents + if ( + len(self.agents_pool) > 10 + ): # ensure minmum of 10 agents del self.agents_pool[-1] # remove last agent except Exception as error: - print(f"Error scaling down: {error} try again with a new task") + print( + f"Error scaling down: {error} try again with a new" + " task" + ) @log_decorator @error_decorator @@ -109,19 +121,27 @@ class AutoScaler: sleep(60) # check minute pending_tasks = self.task_queue.qsize() active_agents = sum( - [1 for agent in self.agents_pool if agent.is_busy()] + [ + 1 + for agent in self.agents_pool + if agent.is_busy() + ] ) - if pending_tasks / len(self.agents_pool) > self.busy_threshold: + if ( + pending_tasks / len(self.agents_pool) + > self.busy_threshold + ): self.scale_up() elif ( - active_agents / len(self.agents_pool) < self.idle_threshold + active_agents / len(self.agents_pool) + < self.idle_threshold ): self.scale_down() except Exception as error: print( - f"Error monitoring and scaling: {error} try again with a new" - " task" + f"Error monitoring and scaling: {error} try again" + " with a new task" ) @log_decorator @@ -130,7 +150,9 @@ class AutoScaler: def start(self): """Start scaling""" try: - monitor_thread = threading.Thread(target=self.monitor_and_scale) + monitor_thread = threading.Thread( + target=self.monitor_and_scale + ) monitor_thread.start() while True: @@ -142,13 +164,17 @@ class AutoScaler: if available_agent: available_agent.run(task) except Exception as error: - print(f"Error starting: {error} try again with a new task") + print( + f"Error starting: {error} try again with a new task" + ) def check_agent_health(self): """Checks the health of each agent and replaces unhealthy agents.""" for i, agent in enumerate(self.agents_pool): if not agent.is_healthy(): - logging.warning(f"Replacing unhealthy agent at index {i}") + logging.warning( + f"Replacing unhealthy agent at index {i}" + ) self.agents_pool[i] = self.agent() def balance_load(self): @@ -159,7 +185,9 @@ class AutoScaler: task = self.task_queue.get() agent.run(task) - def set_scaling_strategy(self, strategy: Callable[[int, int], int]): + def set_scaling_strategy( + self, strategy: Callable[[int, int], int] + ): """Set a custom scaling strategy.""" self.custom_scale_strategy = strategy @@ -179,7 +207,11 @@ class AutoScaler: def report_agent_metrics(self) -> Dict[str, List[float]]: """Collects and reports metrics from each agent.""" - metrics = {"completion_time": [], "success_rate": [], "error_rate": []} + metrics = { + "completion_time": [], + "success_rate": [], + "error_rate": [], + } for agent in self.agents_pool: agent_metrics = agent.get_metrics() for key in metrics.keys(): diff --git a/swarms/structs/document.py b/swarms/structs/document.py index b87d3d91..7b99721f 100644 --- a/swarms/structs/document.py +++ b/swarms/structs/document.py @@ -87,5 +87,7 @@ class BaseDocumentTransformer(ABC): A list of transformed Documents. """ return await asyncio.get_running_loop().run_in_executor( - None, partial(self.transform_documents, **kwargs), documents + None, + partial(self.transform_documents, **kwargs), + documents, ) diff --git a/swarms/structs/sequential_workflow.py b/swarms/structs/sequential_workflow.py index 96ed6859..22e1236c 100644 --- a/swarms/structs/sequential_workflow.py +++ b/swarms/structs/sequential_workflow.py @@ -69,11 +69,18 @@ class Task: # Add a prompt to notify the Agent of the sequential workflow if "prompt" in self.kwargs: self.kwargs["prompt"] += ( - f"\n\nPrevious output: {self.result}" if self.result else "" + f"\n\nPrevious output: {self.result}" + if self.result + else "" ) else: - self.kwargs["prompt"] = f"Main task: {self.description}" + ( - f"\n\nPrevious output: {self.result}" if self.result else "" + self.kwargs["prompt"] = ( + f"Main task: {self.description}" + + ( + f"\n\nPrevious output: {self.result}" + if self.result + else "" + ) ) self.result = self.agent.run(*self.args, **self.kwargs) else: @@ -116,7 +123,9 @@ class SequentialWorkflow: autosave: bool = False name: str = (None,) description: str = (None,) - saved_state_filepath: Optional[str] = "sequential_workflow_state.json" + saved_state_filepath: Optional[str] = ( + "sequential_workflow_state.json" + ) restore_state_filepath: Optional[str] = None dashboard: bool = False @@ -181,7 +190,9 @@ class SequentialWorkflow: def remove_task(self, task: str) -> None: """Remove tasks from sequential workflow""" - self.tasks = [task for task in self.tasks if task.description != task] + self.tasks = [ + task for task in self.tasks if task.description != task + ] def update_task(self, task: str, **updates) -> None: """ @@ -330,7 +341,9 @@ class SequentialWorkflow: ) self.tasks.append(task) - def load_workflow_state(self, filepath: str = None, **kwargs) -> None: + def load_workflow_state( + self, filepath: str = None, **kwargs + ) -> None: """ Loads the workflow state from a json file and restores the workflow state. @@ -384,18 +397,22 @@ class SequentialWorkflow: # Ensure that 'task' is provided in the kwargs if "task" not in task.kwargs: raise ValueError( - "The 'task' argument is required for the" - " Agent agent execution in" - f" '{task.description}'" + "The 'task' argument is required" + " for the Agent agent execution" + f" in '{task.description}'" ) # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") task.result = task.agent.run( - flow_task_arg, *task.args, **task.kwargs + flow_task_arg, + *task.args, + **task.kwargs, ) else: # If it's not a Agent instance, call the agent directly - task.result = task.agent(*task.args, **task.kwargs) + task.result = task.agent( + *task.args, **task.kwargs + ) # Pass the result as an argument to the next task if it exists next_task_index = self.tasks.index(task) + 1 @@ -417,9 +434,9 @@ class SequentialWorkflow: print( colored( ( - f"Error initializing the Sequential workflow: {e} try" - " optimizing your inputs like the agent class and task" - " description" + "Error initializing the Sequential workflow:" + f" {e} try optimizing your inputs like the" + " agent class and task description" ), "red", attrs=["bold", "underline"], @@ -443,8 +460,9 @@ class SequentialWorkflow: # Ensure that 'task' is provided in the kwargs if "task" not in task.kwargs: raise ValueError( - "The 'task' argument is required for the Agent" - f" agent execution in '{task.description}'" + "The 'task' argument is required for" + " the Agent agent execution in" + f" '{task.description}'" ) # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") diff --git a/swarms/swarms/base.py b/swarms/swarms/base.py index 1ccc819c..15238a8a 100644 --- a/swarms/swarms/base.py +++ b/swarms/swarms/base.py @@ -144,7 +144,9 @@ class AbstractSwarm(ABC): pass @abstractmethod - def autoscaler(self, num_workers: int, worker: ["AbstractWorker"]): + def autoscaler( + self, num_workers: int, worker: ["AbstractWorker"] + ): """Autoscaler that acts like kubernetes for autonomous agents""" pass @@ -159,7 +161,9 @@ class AbstractSwarm(ABC): pass @abstractmethod - def assign_task(self, worker: "AbstractWorker", task: Any) -> Dict: + def assign_task( + self, worker: "AbstractWorker", task: Any + ) -> Dict: """Assign a task to a worker""" pass diff --git a/swarms/swarms/dialogue_simulator.py b/swarms/swarms/dialogue_simulator.py index c5257ef4..b5a07d7b 100644 --- a/swarms/swarms/dialogue_simulator.py +++ b/swarms/swarms/dialogue_simulator.py @@ -24,7 +24,10 @@ class DialogueSimulator: """ def __init__( - self, agents: List[Callable], max_iters: int = 10, name: str = None + self, + agents: List[Callable], + max_iters: int = 10, + name: str = None, ): self.agents = agents self.max_iters = max_iters @@ -60,7 +63,8 @@ class DialogueSimulator: def __repr__(self): return ( - f"DialogueSimulator({self.agents}, {self.max_iters}, {self.name})" + f"DialogueSimulator({self.agents}, {self.max_iters}," + f" {self.name})" ) def save_state(self): diff --git a/swarms/swarms/god_mode.py b/swarms/swarms/god_mode.py index 65377308..29178b2c 100644 --- a/swarms/swarms/god_mode.py +++ b/swarms/swarms/god_mode.py @@ -64,7 +64,11 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + tabulate( + table, + headers=["LLM", "Response"], + tablefmt="pretty", + ), "cyan", ) ) @@ -84,7 +88,11 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + tabulate( + table, + headers=["LLM", "Response"], + tablefmt="pretty", + ), "cyan", ) ) @@ -122,7 +130,11 @@ class GodMode: ] print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + tabulate( + table, + headers=["LLM", "Response"], + tablefmt="pretty", + ), "cyan", ) ) @@ -159,8 +171,8 @@ class GodMode: responses.append(future.result()) except Exception as error: print( - f"{future_to_llm[future]} generated an exception:" - f" {error}" + f"{future_to_llm[future]} generated an" + f" exception: {error}" ) self.last_responses = responses self.task_history.append(task) diff --git a/swarms/swarms/groupchat.py b/swarms/swarms/groupchat.py index 38b692e7..76f287bc 100644 --- a/swarms/swarms/groupchat.py +++ b/swarms/swarms/groupchat.py @@ -43,12 +43,15 @@ class GroupChat: for agent in self.agents: if agent.name in name: return agent - raise ValueError(f"No agent found with a name contained in '{name}'.") + raise ValueError( + f"No agent found with a name contained in '{name}'." + ) def next_agent(self, agent: Agent) -> Agent: """Return the next agent in the list.""" return self.agents[ - (self.agent_names.index(agent.name) + 1) % len(self.agents) + (self.agent_names.index(agent.name) + 1) + % len(self.agents) ] def select_speaker_msg(self): @@ -69,8 +72,8 @@ class GroupChat: n_agents = len(self.agent_names) if n_agents < 3: logger.warning( - f"GroupChat is underpopulated with {n_agents} agents. Direct" - " communication would be more efficient." + f"GroupChat is underpopulated with {n_agents} agents." + " Direct communication would be more efficient." ) name = selector.generate_reply( @@ -80,9 +83,10 @@ class GroupChat: { "role": "system", "content": ( - "Read the above conversation. Then select the next" - f" most suitable role from {self.agent_names} to" - " play. Only return the role." + "Read the above conversation. Then" + " select the next most suitable role" + f" from {self.agent_names} to play. Only" + " return the role." ), } ] @@ -95,13 +99,18 @@ class GroupChat: def _participant_roles(self): return "\n".join( - [f"{agent.name}: {agent.system_message}" for agent in self.agents] + [ + f"{agent.name}: {agent.system_message}" + for agent in self.agents + ] ) def format_history(self, messages: List[Dict]) -> str: formatted_messages = [] for message in messages: - formatted_message = f"'{message['role']}:{message['content']}" + formatted_message = ( + f"'{message['role']}:{message['content']}" + ) formatted_messages.append(formatted_message) return "\n".join(formatted_messages) diff --git a/swarms/swarms/multi_agent_collab.py b/swarms/swarms/multi_agent_collab.py index 19075d33..64b030d0 100644 --- a/swarms/swarms/multi_agent_collab.py +++ b/swarms/swarms/multi_agent_collab.py @@ -13,8 +13,8 @@ from swarms.utils.logger import logger class BidOutputParser(RegexParser): def get_format_instructions(self) -> str: return ( - "Your response should be an integrater delimited by angled brackets" - " like this: " + "Your response should be an integrater delimited by" + " angled brackets like this: " ) @@ -123,7 +123,9 @@ class MultiAgentCollaboration: def step(self) -> tuple[str, str]: """Steps through the multi-agent collaboration.""" - speaker_idx = self.select_next_speaker(self._step, self.agents) + speaker_idx = self.select_next_speaker( + self._step, self.agents + ) speaker = self.agents[speaker_idx] message = speaker.send() message = speaker.send() @@ -146,7 +148,8 @@ class MultiAgentCollaboration: wait=tenacity.wait_none(), retry=tenacity.retry_if_exception_type(ValueError), before_sleep=lambda retry_state: print( - f"ValueError occured: {retry_state.outcome.exception()}, retying..." + f"ValueError occured: {retry_state.outcome.exception()}," + " retying..." ), retry_error_callback=lambda retry_state: 0, ) @@ -167,7 +170,9 @@ class MultiAgentCollaboration: bid = self.ask_for_bid(agent) bids.append(bid) max_value = max(bids) - max_indices = [i for i, x in enumerate(bids) if x == max_value] + max_indices = [ + i for i, x in enumerate(bids) if x == max_value + ] idx = random.choice(max_indices) return idx @@ -176,7 +181,8 @@ class MultiAgentCollaboration: wait=tenacity.wait_none(), retry=tenacity.retry_if_exception_type(ValueError), before_sleep=lambda retry_state: print( - f"ValueError occured: {retry_state.outcome.exception()}, retying..." + f"ValueError occured: {retry_state.outcome.exception()}," + " retying..." ), retry_error_callback=lambda retry_state: 0, ) @@ -256,7 +262,9 @@ class MultiAgentCollaboration: for _ in range(self.max_iters): for agent in self.agents: result = agent.run(conversation) - self.results.append({"agent": agent, "response": result}) + self.results.append( + {"agent": agent, "response": result} + ) conversation += result if self.autosave: @@ -309,7 +317,9 @@ class MultiAgentCollaboration: """Tracks and reports the performance of each agent""" performance_data = {} for agent in self.agents: - performance_data[agent.name] = agent.get_performance_metrics() + performance_data[agent.name] = ( + agent.get_performance_metrics() + ) return performance_data def set_interaction_rules(self, rules): diff --git a/swarms/swarms/orchestrate.py b/swarms/swarms/orchestrate.py index b7a7d0e0..387c32e4 100644 --- a/swarms/swarms/orchestrate.py +++ b/swarms/swarms/orchestrate.py @@ -119,13 +119,17 @@ class Orchestrator: self.lock = threading.Lock() self.condition = threading.Condition(self.lock) - self.executor = ThreadPoolExecutor(max_workers=len(agent_list)) + self.executor = ThreadPoolExecutor( + max_workers=len(agent_list) + ) self.embed_func = embed_func if embed_func else self.embed # @abstractmethod - def assign_task(self, agent_id: int, task: Dict[str, Any]) -> None: + def assign_task( + self, agent_id: int, task: Dict[str, Any] + ) -> None: """Assign a task to a specific agent""" while True: @@ -156,8 +160,8 @@ class Orchestrator: except Exception as error: logging.error( - f"Failed to process task {id(task)} by agent {id(agent)}." - f" Error: {error}" + f"Failed to process task {id(task)} by agent" + f" {id(agent)}. Error: {error}" ) finally: with self.condition: @@ -185,7 +189,8 @@ class Orchestrator: return results except Exception as e: logging.error( - f"Failed to retrieve results from agent {agent_id}. Error {e}" + f"Failed to retrieve results from agent {agent_id}." + f" Error {e}" ) raise @@ -201,7 +206,9 @@ class Orchestrator: ) except Exception as e: - logging.error(f"Failed to update the vector database. Error: {e}") + logging.error( + f"Failed to update the vector database. Error: {e}" + ) raise # @abstractmethod @@ -214,11 +221,14 @@ class Orchestrator: """append the result of the swarm to a specifici collection in the database""" try: - self.collection.add(documents=[result], ids=[str(id(result))]) + self.collection.add( + documents=[result], ids=[str(id(result))] + ) except Exception as e: logging.error( - f"Failed to append the agent output to database. Error: {e}" + "Failed to append the agent output to database." + f" Error: {e}" ) raise @@ -241,7 +251,9 @@ class Orchestrator: for result in results: self.append_to_db(result) - logging.info(f"Successfully ran swarms with results: {results}") + logging.info( + f"Successfully ran swarms with results: {results}" + ) return results except Exception as e: logging.error(f"An error occured in swarm: {e}") @@ -264,7 +276,9 @@ class Orchestrator: """ - message_vector = self.embed(message, self.api_key, self.model_name) + message_vector = self.embed( + message, self.api_key, self.model_name + ) # store the mesage in the vector database self.collection.add( @@ -273,15 +287,21 @@ class Orchestrator: ids=[f"{sender_id}_to_{receiver_id}"], ) - self.run(objective=f"chat with agent {receiver_id} about {message}") + self.run( + objective=f"chat with agent {receiver_id} about {message}" + ) def add_agents(self, num_agents: int): for _ in range(num_agents): self.agents.put(self.agent()) - self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize()) + self.executor = ThreadPoolExecutor( + max_workers=self.agents.qsize() + ) def remove_agents(self, num_agents): for _ in range(num_agents): if not self.agents.empty(): self.agents.get() - self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize()) + self.executor = ThreadPoolExecutor( + max_workers=self.agents.qsize() + ) diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py index 105a2541..1029a183 100644 --- a/swarms/tools/tool.py +++ b/swarms/tools/tool.py @@ -124,10 +124,15 @@ class BaseTool(RunnableSerializable[Union[str, Dict], Any]): """Create the definition of the new tool class.""" super().__init_subclass__(**kwargs) - args_schema_type = cls.__annotations__.get("args_schema", None) + args_schema_type = cls.__annotations__.get( + "args_schema", None + ) if args_schema_type is not None: - if args_schema_type is None or args_schema_type == BaseModel: + if ( + args_schema_type is None + or args_schema_type == BaseModel + ): # Throw errors for common mis-annotations. # TODO: Use get_args / get_origin and fully # specify valid annotations. @@ -138,10 +143,11 @@ class ChildTool(BaseTool): ...""" name = cls.__name__ raise SchemaAnnotationError( - f"Tool definition for {name} must include valid type" - " annotations for argument 'args_schema' to behave as" - " expected.\nExpected annotation of 'Type[BaseModel]' but" - f" got '{args_schema_type}'.\nExpected class looks" + f"Tool definition for {name} must include valid" + " type annotations for argument 'args_schema' to" + " behave as expected.\nExpected annotation of" + " 'Type[BaseModel]' but got" + f" '{args_schema_type}'.\nExpected class looks" f" like:\n{typehint_mandate}" ) @@ -264,7 +270,9 @@ class ChildTool(BaseTool): if input_args is not None: result = input_args.parse_obj(tool_input) return { - k: v for k, v in result.dict().items() if k in tool_input + k: v + for k, v in result.dict().items() + if k in tool_input } return tool_input @@ -273,7 +281,10 @@ class ChildTool(BaseTool): """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: warnings.warn( - "callback_manager is deprecated. Please use callbacks instead.", + ( + "callback_manager is deprecated. Please use" + " callbacks instead." + ), DeprecationWarning, ) values["callbacks"] = values.pop("callback_manager", None) @@ -346,18 +357,28 @@ class ChildTool(BaseTool): self.metadata, ) # TODO: maybe also pass through run_manager is _run supports kwargs - new_arg_supported = signature(self._run).parameters.get("run_manager") + new_arg_supported = signature(self._run).parameters.get( + "run_manager" + ) run_manager = callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), + ( + tool_input + if isinstance(tool_input, str) + else str(tool_input) + ), color=start_color, name=run_name, **kwargs, ) try: - tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) + tool_args, tool_kwargs = self._to_args_and_kwargs( + parsed_input + ) observation = ( - self._run(*tool_args, run_manager=run_manager, **tool_kwargs) + self._run( + *tool_args, run_manager=run_manager, **tool_kwargs + ) if new_arg_supported else self._run(*tool_args, **tool_kwargs) ) @@ -376,12 +397,15 @@ class ChildTool(BaseTool): observation = self.handle_tool_error(e) else: raise ValueError( - "Got unexpected type of `handle_tool_error`. Expected" - " bool, str or callable. Received:" + "Got unexpected type of `handle_tool_error`." + " Expected bool, str or callable. Received:" f" {self.handle_tool_error}" ) run_manager.on_tool_end( - str(observation), color="red", name=self.name, **kwargs + str(observation), + color="red", + name=self.name, + **kwargs, ) return observation except (Exception, KeyboardInterrupt) as e: @@ -389,7 +413,10 @@ class ChildTool(BaseTool): raise e else: run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs + str(observation), + color=color, + name=self.name, + **kwargs, ) return observation @@ -421,17 +448,25 @@ class ChildTool(BaseTool): metadata, self.metadata, ) - new_arg_supported = signature(self._arun).parameters.get("run_manager") + new_arg_supported = signature(self._arun).parameters.get( + "run_manager" + ) run_manager = await callback_manager.on_tool_start( {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), + ( + tool_input + if isinstance(tool_input, str) + else str(tool_input) + ), color=start_color, name=run_name, **kwargs, ) try: # We then call the tool on the tool input to get an observation - tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) + tool_args, tool_kwargs = self._to_args_and_kwargs( + parsed_input + ) observation = ( await self._arun( *tool_args, run_manager=run_manager, **tool_kwargs @@ -454,12 +489,15 @@ class ChildTool(BaseTool): observation = self.handle_tool_error(e) else: raise ValueError( - "Got unexpected type of `handle_tool_error`. Expected" - " bool, str or callable. Received:" + "Got unexpected type of `handle_tool_error`." + " Expected bool, str or callable. Received:" f" {self.handle_tool_error}" ) await run_manager.on_tool_end( - str(observation), color="red", name=self.name, **kwargs + str(observation), + color="red", + name=self.name, + **kwargs, ) return observation except (Exception, KeyboardInterrupt) as e: @@ -467,11 +505,16 @@ class ChildTool(BaseTool): raise e else: await run_manager.on_tool_end( - str(observation), color=color, name=self.name, **kwargs + str(observation), + color=color, + name=self.name, + **kwargs, ) return observation - def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: + def __call__( + self, tool_input: str, callbacks: Callbacks = None + ) -> str: """Make tool callable.""" return self.run(tool_input, callbacks=callbacks) @@ -520,8 +563,8 @@ class Tool(BaseTool): all_args = list(args) + list(kwargs.values()) if len(all_args) != 1: raise ToolException( - f"Too many arguments to single-input tool {self.name}. Args:" - f" {all_args}" + "Too many arguments to single-input tool" + f" {self.name}. Args: {all_args}" ) return tuple(all_args), {} @@ -533,13 +576,17 @@ class Tool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get( - "callbacks" - ) + new_argument_supported = signature( + self.func + ).parameters.get("callbacks") return ( self.func( *args, - callbacks=run_manager.get_child() if run_manager else None, + callbacks=( + run_manager.get_child() + if run_manager + else None + ), **kwargs, ) if new_argument_supported @@ -555,13 +602,17 @@ class Tool(BaseTool): ) -> Any: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get( - "callbacks" - ) + new_argument_supported = signature( + self.coroutine + ).parameters.get("callbacks") return ( await self.coroutine( *args, - callbacks=run_manager.get_child() if run_manager else None, + callbacks=( + run_manager.get_child() + if run_manager + else None + ), **kwargs, ) if new_argument_supported @@ -602,7 +653,9 @@ class Tool(BaseTool): ) -> Tool: """Initialize tool from a function.""" if func is None and coroutine is None: - raise ValueError("Function and/or coroutine must be provided") + raise ValueError( + "Function and/or coroutine must be provided" + ) return cls( name=name, func=func, @@ -618,7 +671,9 @@ class StructuredTool(BaseTool): """Tool that can operate on any number of inputs.""" description: str = "" - args_schema: Type[BaseModel] = Field(..., description="The tool schema.") + args_schema: Type[BaseModel] = Field( + ..., description="The tool schema." + ) """The input arguments' schema.""" func: Optional[Callable[..., Any]] """The function to run when the tool is called.""" @@ -655,13 +710,17 @@ class StructuredTool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get( - "callbacks" - ) + new_argument_supported = signature( + self.func + ).parameters.get("callbacks") return ( self.func( *args, - callbacks=run_manager.get_child() if run_manager else None, + callbacks=( + run_manager.get_child() + if run_manager + else None + ), **kwargs, ) if new_argument_supported @@ -677,13 +736,17 @@ class StructuredTool(BaseTool): ) -> str: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get( - "callbacks" - ) + new_argument_supported = signature( + self.coroutine + ).parameters.get("callbacks") return ( await self.coroutine( *args, - callbacks=run_manager.get_child() if run_manager else None, + callbacks=( + run_manager.get_child() + if run_manager + else None + ), **kwargs, ) if new_argument_supported @@ -740,12 +803,15 @@ class StructuredTool(BaseTool): elif coroutine is not None: source_function = coroutine else: - raise ValueError("Function and/or coroutine must be provided") + raise ValueError( + "Function and/or coroutine must be provided" + ) name = name or source_function.__name__ description = description or source_function.__doc__ if description is None: raise ValueError( - "Function must have a docstring if description not provided." + "Function must have a docstring if description not" + " provided." ) # Description example: @@ -804,28 +870,41 @@ def tool( """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: + def _make_tool( + dec_func: Union[Callable, Runnable] + ) -> BaseTool: if isinstance(dec_func, Runnable): runnable = dec_func - if runnable.input_schema.schema().get("type") != "object": - raise ValueError("Runnable must have an object schema.") + if ( + runnable.input_schema.schema().get("type") + != "object" + ): + raise ValueError( + "Runnable must have an object schema." + ) async def ainvoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any + callbacks: Optional[Callbacks] = None, + **kwargs: Any, ) -> Any: return await runnable.ainvoke( kwargs, {"callbacks": callbacks} ) def invoke_wrapper( - callbacks: Optional[Callbacks] = None, **kwargs: Any + callbacks: Optional[Callbacks] = None, + **kwargs: Any, ) -> Any: - return runnable.invoke(kwargs, {"callbacks": callbacks}) + return runnable.invoke( + kwargs, {"callbacks": callbacks} + ) coroutine = ainvoke_wrapper func = invoke_wrapper - schema: Optional[Type[BaseModel]] = runnable.input_schema + schema: Optional[Type[BaseModel]] = ( + runnable.input_schema + ) description = repr(runnable) elif inspect.iscoroutinefunction(dec_func): coroutine = dec_func @@ -852,8 +931,8 @@ def tool( # a simple string->string function if func.__doc__ is None: raise ValueError( - "Function must have a docstring if " - "description not provided and infer_schema is False." + "Function must have a docstring if description" + " not provided and infer_schema is False." ) return Tool( name=tool_name, diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py index b8aca925..494f182f 100644 --- a/swarms/utils/__init__.py +++ b/swarms/utils/__init__.py @@ -1,7 +1,9 @@ from swarms.utils.markdown_message import display_markdown_message from swarms.utils.futures import execute_futures_dict from swarms.utils.code_interpreter import SubprocessCodeInterpreter -from swarms.utils.parse_code import extract_code_in_backticks_in_string +from swarms.utils.parse_code import ( + extract_code_in_backticks_in_string, +) from swarms.utils.pdf_to_text import pdf_to_text __all__ = [ diff --git a/swarms/utils/apa.py b/swarms/utils/apa.py index 4adcb5cf..f2e1bb38 100644 --- a/swarms/utils/apa.py +++ b/swarms/utils/apa.py @@ -102,7 +102,9 @@ class Action: tool_name: str = "" tool_input: dict = field(default_factory=lambda: {}) - tool_output_status: ToolCallStatus = ToolCallStatus.ToolCallSuccess + tool_output_status: ToolCallStatus = ( + ToolCallStatus.ToolCallSuccess + ) tool_output: str = "" def to_json(self): @@ -124,7 +126,9 @@ class Action: @dataclass class userQuery: task: str - additional_information: List[str] = field(default_factory=lambda: []) + additional_information: List[str] = field( + default_factory=lambda: [] + ) refine_prompt: str = field(default_factory=lambda: "") def print_self(self): diff --git a/swarms/utils/code_interpreter.py b/swarms/utils/code_interpreter.py index fc2f95f7..98fbab70 100644 --- a/swarms/utils/code_interpreter.py +++ b/swarms/utils/code_interpreter.py @@ -117,7 +117,10 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): # applescript yield {"output": traceback.format_exc()} yield { - "output": f"Retrying... ({retry_count}/{max_retries})" + "output": ( + "Retrying..." + f" ({retry_count}/{max_retries})" + ) } yield {"output": "Restarting process."} @@ -127,7 +130,8 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): if retry_count > max_retries: yield { "output": ( - "Maximum retries reached. Could not execute code." + "Maximum retries reached. Could not" + " execute code." ) } return diff --git a/swarms/utils/decorators.py b/swarms/utils/decorators.py index cf4a774c..e4c11574 100644 --- a/swarms/utils/decorators.py +++ b/swarms/utils/decorators.py @@ -32,7 +32,8 @@ def timing_decorator(func): result = func(*args, **kwargs) end_time = time.time() logging.info( - f"{func.__name__} executed in {end_time - start_time} seconds" + f"{func.__name__} executed in" + f" {end_time - start_time} seconds" ) return result @@ -48,7 +49,8 @@ def retry_decorator(max_retries=5): return func(*args, **kwargs) except Exception as error: logging.error( - f" Error in {func.__name__}: {str(error)} Retrying ...." + f" Error in {func.__name__}:" + f" {str(error)} Retrying ...." ) return func(*args, **kwargs) @@ -82,7 +84,8 @@ def deprecated_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): warnings.warn( - f"{func.__name__} is deprecated", category=DeprecationWarning + f"{func.__name__} is deprecated", + category=DeprecationWarning, ) return func(*args, **kwargs) diff --git a/swarms/utils/disable_logging.py b/swarms/utils/disable_logging.py index 93e59bb5..3b6884d2 100644 --- a/swarms/utils/disable_logging.py +++ b/swarms/utils/disable_logging.py @@ -3,6 +3,7 @@ import os import warnings import sys + def disable_logging(): log_file = open("errors.txt", "w") sys.stderr = log_file @@ -30,4 +31,6 @@ def disable_logging(): "wandb.docker.auth", ]: logger = logging.getLogger(logger_name) - logger.setLevel(logging.WARNING) # Supress DEBUG and info logs + logger.setLevel( + logging.WARNING + ) # Supress DEBUG and info logs diff --git a/swarms/utils/futures.py b/swarms/utils/futures.py index a5ffdf51..744b44e0 100644 --- a/swarms/utils/futures.py +++ b/swarms/utils/futures.py @@ -4,9 +4,13 @@ from typing import TypeVar T = TypeVar("T") -def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]: +def execute_futures_dict( + fs_dict: dict[str, futures.Future[T]] +) -> dict[str, T]: futures.wait( - fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED + fs_dict.values(), + timeout=None, + return_when=futures.ALL_COMPLETED, ) return {key: future.result() for key, future in fs_dict.items()} diff --git a/swarms/utils/loggers.py b/swarms/utils/loggers.py index d9845543..a0dec94d 100644 --- a/swarms/utils/loggers.py +++ b/swarms/utils/loggers.py @@ -14,7 +14,9 @@ from swarms.utils.apa import Action, ToolCallStatus # from autogpt.speech import say_text class JsonFileHandler(logging.FileHandler): - def __init__(self, filename, mode="a", encoding=None, delay=False): + def __init__( + self, filename, mode="a", encoding=None, delay=False + ): """ Initializes a new instance of the class with the specified file name, mode, encoding, and delay settings. @@ -84,7 +86,9 @@ class Logger: log_file = "activity.log" error_file = "error.log" - console_formatter = AutoGptFormatter("%(title_color)s %(message)s") + console_formatter = AutoGptFormatter( + "%(title_color)s %(message)s" + ) # Create a handler for console which simulate typing self.typing_console_handler = TypingConsoleHandler() @@ -113,8 +117,9 @@ class Logger: ) error_handler.setLevel(logging.ERROR) error_formatter = AutoGptFormatter( - "%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d" - " %(title)s %(message_no_color)s" + "%(asctime)s %(levelname)s" + " %(module)s:%(funcName)s:%(lineno)d %(title)s" + " %(message_no_color)s" ) error_handler.setFormatter(error_formatter) @@ -170,7 +175,9 @@ class Logger: content = "" self.typing_logger.log( - level, content, extra={"title": title, "color": title_color} + level, + content, + extra={"title": title, "color": title_color}, ) def debug( @@ -292,9 +299,9 @@ class Logger: additionalText = ( "Please ensure you've setup and configured everything" " correctly. Read" - " https://github.com/Torantulino/Auto-GPT#readme to double" - " check. You can also create a github issue or join the discord" - " and ask there!" + " https://github.com/Torantulino/Auto-GPT#readme to" + " double check. You can also create a github issue or" + " join the discord and ask there!" ) self.typewriter_log( @@ -368,10 +375,16 @@ class TypingConsoleHandler(logging.StreamHandler): transfer_enter = "" msg_transfered = str(msg).replace("\n", transfer_enter) transfer_space = "<4SPACE>" - msg_transfered = str(msg_transfered).replace(" ", transfer_space) + msg_transfered = str(msg_transfered).replace( + " ", transfer_space + ) words = msg_transfered.split() - words = [word.replace(transfer_enter, "\n") for word in words] - words = [word.replace(transfer_space, " ") for word in words] + words = [ + word.replace(transfer_enter, "\n") for word in words + ] + words = [ + word.replace(transfer_space, " ") for word in words + ] for i, word in enumerate(words): print(word, end="", flush=True) @@ -437,7 +450,9 @@ class AutoGptFormatter(logging.Formatter): record.title = getattr(record, "title", "") if hasattr(record, "msg"): - record.message_no_color = remove_color_codes(getattr(record, "msg")) + record.message_no_color = remove_color_codes( + getattr(record, "msg") + ) else: record.message_no_color = "" return super().format(record) @@ -471,8 +486,12 @@ def print_action_base(action: Action): None """ if action.content != "": - logger.typewriter_log(f"content:", Fore.YELLOW, f"{action.content}") - logger.typewriter_log(f"Thought:", Fore.YELLOW, f"{action.thought}") + logger.typewriter_log( + f"content:", Fore.YELLOW, f"{action.content}" + ) + logger.typewriter_log( + f"Thought:", Fore.YELLOW, f"{action.thought}" + ) if len(action.plan) > 0: logger.typewriter_log( f"Plan:", @@ -481,7 +500,9 @@ def print_action_base(action: Action): for line in action.plan: line = line.lstrip("- ") logger.typewriter_log("- ", Fore.GREEN, line.strip()) - logger.typewriter_log(f"Criticism:", Fore.YELLOW, f"{action.criticism}") + logger.typewriter_log( + f"Criticism:", Fore.YELLOW, f"{action.criticism}" + ) def print_action_tool(action: Action): @@ -495,15 +516,21 @@ def print_action_tool(action: Action): None """ logger.typewriter_log(f"Tool:", Fore.BLUE, f"{action.tool_name}") - logger.typewriter_log(f"Tool Input:", Fore.BLUE, f"{action.tool_input}") + logger.typewriter_log( + f"Tool Input:", Fore.BLUE, f"{action.tool_input}" + ) - output = action.tool_output if action.tool_output != "" else "None" + output = ( + action.tool_output if action.tool_output != "" else "None" + ) logger.typewriter_log(f"Tool Output:", Fore.BLUE, f"{output}") color = Fore.RED if action.tool_output_status == ToolCallStatus.ToolCallSuccess: color = Fore.GREEN - elif action.tool_output_status == ToolCallStatus.InputCannotParsed: + elif ( + action.tool_output_status == ToolCallStatus.InputCannotParsed + ): color = Fore.YELLOW logger.typewriter_log( diff --git a/swarms/utils/main.py b/swarms/utils/main.py index 73704552..c9c0f380 100644 --- a/swarms/utils/main.py +++ b/swarms/utils/main.py @@ -36,7 +36,9 @@ def cut_dialogue_history(history_memory, keep_last_n_words=500): paragraphs = history_memory.split("\n") last_n_tokens = n_tokens while last_n_tokens >= keep_last_n_words: - last_n_tokens = last_n_tokens - len(paragraphs[0].split(" ")) + last_n_tokens = last_n_tokens - len( + paragraphs[0].split(" ") + ) paragraphs = paragraphs[1:] return "\n" + "\n".join(paragraphs) @@ -51,14 +53,20 @@ def get_new_image_name(org_img_name, func_name="update"): most_org_file_name = name_split[0] recent_prev_file_name = name_split[0] new_file_name = "{}_{}_{}_{}.png".format( - this_new_uuid, func_name, recent_prev_file_name, most_org_file_name + this_new_uuid, + func_name, + recent_prev_file_name, + most_org_file_name, ) else: assert len(name_split) == 4 most_org_file_name = name_split[3] recent_prev_file_name = name_split[0] new_file_name = "{}_{}_{}_{}.png".format( - this_new_uuid, func_name, recent_prev_file_name, most_org_file_name + this_new_uuid, + func_name, + recent_prev_file_name, + most_org_file_name, ) return os.path.join(head, new_file_name) @@ -73,14 +81,20 @@ def get_new_dataframe_name(org_img_name, func_name="update"): most_org_file_name = name_split[0] recent_prev_file_name = name_split[0] new_file_name = "{}_{}_{}_{}.csv".format( - this_new_uuid, func_name, recent_prev_file_name, most_org_file_name + this_new_uuid, + func_name, + recent_prev_file_name, + most_org_file_name, ) else: assert len(name_split) == 4 most_org_file_name = name_split[3] recent_prev_file_name = name_split[0] new_file_name = "{}_{}_{}_{}.csv".format( - this_new_uuid, func_name, recent_prev_file_name, most_org_file_name + this_new_uuid, + func_name, + recent_prev_file_name, + most_org_file_name, ) return os.path.join(head, new_file_name) @@ -187,7 +201,11 @@ class ANSI: self.args = [] def join(self) -> str: - return ANSI.ESCAPE + ";".join([str(a) for a in self.args]) + ANSI.CLOSE + return ( + ANSI.ESCAPE + + ";".join([str(a) for a in self.args]) + + ANSI.CLOSE + ) def wrap(self, text: str) -> str: return self.join() + text + ANSI(Style.reset()).join() @@ -338,7 +356,9 @@ class BaseHandler: class FileHandler: - def __init__(self, handlers: Dict[FileType, BaseHandler], path: Path): + def __init__( + self, handlers: Dict[FileType, BaseHandler], path: Path + ): self.handlers = handlers self.path = path @@ -366,9 +386,16 @@ class FileHandler: os.environ.get("SERVER", "http://localhost:8000") ): local_filepath = url[ - len(os.environ.get("SERVER", "http://localhost:8000")) + 1 : + len( + os.environ.get( + "SERVER", "http://localhost:8000" + ) + ) + + 1 : ] - local_filename = Path("file") / local_filepath.split("/")[-1] + local_filename = ( + Path("file") / local_filepath.split("/")[-1] + ) src = self.path / local_filepath dst = ( self.path @@ -383,11 +410,14 @@ class FileHandler: if handler is None: if FileType.from_url(url) == FileType.IMAGE: raise Exception( - f"No handler for {FileType.from_url(url)}. " - "Please set USE_GPU to True in env/settings.py" + f"No handler for {FileType.from_url(url)}." + " Please set USE_GPU to True in" + " env/settings.py" ) else: - raise Exception(f"No handler for {FileType.from_url(url)}") + raise Exception( + f"No handler for {FileType.from_url(url)}" + ) return handler.handle(local_filename) except Exception as e: raise e diff --git a/swarms/utils/parse_code.py b/swarms/utils/parse_code.py index 9e3b8cb4..747bd0b6 100644 --- a/swarms/utils/parse_code.py +++ b/swarms/utils/parse_code.py @@ -6,7 +6,9 @@ def extract_code_in_backticks_in_string(message: str) -> str: To extract code from a string in markdown and return a string """ - pattern = r"`` ``(.*?)`` " # Non-greedy match between six backticks + pattern = ( # Non-greedy match between six backticks + r"`` ``(.*?)`` " + ) match = re.search( pattern, message, re.DOTALL ) # re.DOTALL to match newline chars diff --git a/swarms/utils/pdf_to_text.py b/swarms/utils/pdf_to_text.py index b8778841..35309dd3 100644 --- a/swarms/utils/pdf_to_text.py +++ b/swarms/utils/pdf_to_text.py @@ -4,7 +4,10 @@ import os try: import PyPDF2 except ImportError: - print("PyPDF2 not installed. Please install it using: pip install PyPDF2") + print( + "PyPDF2 not installed. Please install it using: pip install" + " PyPDF2" + ) sys.exit(1) @@ -34,9 +37,13 @@ def pdf_to_text(pdf_path): return text except FileNotFoundError: - raise FileNotFoundError(f"The file at {pdf_path} was not found.") + raise FileNotFoundError( + f"The file at {pdf_path} was not found." + ) except Exception as e: - raise Exception(f"An error occurred while reading the PDF file: {e}") + raise Exception( + f"An error occurred while reading the PDF file: {e}" + ) # Example usage diff --git a/swarms/utils/serializable.py b/swarms/utils/serializable.py index c7f9bc2c..de9444ef 100644 --- a/swarms/utils/serializable.py +++ b/swarms/utils/serializable.py @@ -74,7 +74,9 @@ class Serializable(BaseModel, ABC): super().__init__(**kwargs) self._lc_kwargs = kwargs - def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: + def to_json( + self, + ) -> Union[SerializedConstructor, SerializedNotImplemented]: if not self.lc_serializable: return self.to_json_not_implemented() @@ -93,7 +95,10 @@ class Serializable(BaseModel, ABC): break # Get a reference to self bound to each class in the MRO - this = cast(Serializable, self if cls is None else super(cls, self)) + this = cast( + Serializable, + self if cls is None else super(cls, self), + ) secrets.update(this.lc_secrets) lc_kwargs.update(this.lc_attributes) @@ -101,7 +106,9 @@ class Serializable(BaseModel, ABC): # include all secrets, even if not specified in kwargs # as these secrets may be passed as an environment variable instead for key in secrets.keys(): - secret_value = getattr(self, key, None) or lc_kwargs.get(key) + secret_value = getattr(self, key, None) or lc_kwargs.get( + key + ) if secret_value is not None: lc_kwargs.update({key: secret_value}) @@ -155,7 +162,10 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented: if hasattr(obj, "__name__"): _id = [*obj.__module__.split("."), obj.__name__] elif hasattr(obj, "__class__"): - _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__] + _id = [ + *obj.__class__.__module__.split("."), + obj.__class__.__name__, + ] except Exception: pass return { diff --git a/tests/embeddings/test_pegasus.py b/tests/embeddings/test_pegasus.py index e9632eae..64909d3b 100644 --- a/tests/embeddings/test_pegasus.py +++ b/tests/embeddings/test_pegasus.py @@ -11,7 +11,9 @@ def test_init(): def test_init_exception(): - with patch("your_module.Pegasus", side_effect=Exception("Test exception")): + with patch( + "your_module.Pegasus", side_effect=Exception("Test exception") + ): with pytest.raises(Exception) as e: PegasusEmbedding(modality="text") assert str(e.value) == "Test exception" @@ -26,7 +28,9 @@ def test_embed(): def test_embed_exception(): with patch("your_module.Pegasus") as MockPegasus: - MockPegasus.return_value.embed.side_effect = Exception("Test exception") + MockPegasus.return_value.embed.side_effect = Exception( + "Test exception" + ) embedder = PegasusEmbedding(modality="text") with pytest.raises(Exception) as e: embedder.embed("Hello world") diff --git a/tests/memory/qdrant.py b/tests/memory/qdrant.py index 76711420..12a6af84 100644 --- a/tests/memory/qdrant.py +++ b/tests/memory/qdrant.py @@ -28,7 +28,9 @@ def test_qdrant_init(qdrant_client, mock_qdrant_client): assert qdrant_client.client is not None -def test_load_embedding_model(qdrant_client, mock_sentence_transformer): +def test_load_embedding_model( + qdrant_client, mock_sentence_transformer +): qdrant_client._load_embedding_model("model_name") mock_sentence_transformer.assert_called_once_with("model_name") diff --git a/tests/memory/test_main.py b/tests/memory/test_main.py index 851de26a..63d56907 100644 --- a/tests/memory/test_main.py +++ b/tests/memory/test_main.py @@ -24,7 +24,9 @@ def test_init(ocean_db, mock_ocean_client): assert ocean_db.client.heartbeat() == "OK" -def test_create_collection(ocean_db, mock_ocean_client, mock_collection): +def test_create_collection( + ocean_db, mock_ocean_client, mock_collection +): mock_ocean_client.create_collection.return_value = mock_collection collection = ocean_db.create_collection("test", "text") assert collection == mock_collection @@ -34,14 +36,18 @@ def test_append_document(ocean_db, mock_collection): document = "test_document" id = "test_id" ocean_db.append_document(mock_collection, document, id) - mock_collection.add.assert_called_once_with(documents=[document], ids=[id]) + mock_collection.add.assert_called_once_with( + documents=[document], ids=[id] + ) def test_add_documents(ocean_db, mock_collection): documents = ["test_document1", "test_document2"] ids = ["test_id1", "test_id2"] ocean_db.add_documents(mock_collection, documents, ids) - mock_collection.add.assert_called_once_with(documents=documents, ids=ids) + mock_collection.add.assert_called_once_with( + documents=documents, ids=ids + ) def test_query(ocean_db, mock_collection): diff --git a/tests/memory/test_oceandb.py b/tests/memory/test_oceandb.py index c74b7c15..e760dc61 100644 --- a/tests/memory/test_oceandb.py +++ b/tests/memory/test_oceandb.py @@ -44,14 +44,18 @@ def test_append_document(): db = OceanDB(MockClient) collection = Mock() db.append_document(collection, "doc", "id") - collection.add.assert_called_once_with(documents=["doc"], ids=["id"]) + collection.add.assert_called_once_with( + documents=["doc"], ids=["id"] + ) def test_append_document_exception(): with patch("oceandb.Client") as MockClient: db = OceanDB(MockClient) collection = Mock() - collection.add.side_effect = Exception("Append document error") + collection.add.side_effect = Exception( + "Append document error" + ) with pytest.raises(Exception) as e: db.append_document(collection, "doc", "id") assert str(e.value) == "Append document error" @@ -73,7 +77,9 @@ def test_add_documents_exception(): collection = Mock() collection.add.side_effect = Exception("Add documents error") with pytest.raises(Exception) as e: - db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"]) + db.add_documents( + collection, ["doc1", "doc2"], ["id1", "id2"] + ) assert str(e.value) == "Add documents error" diff --git a/tests/memory/test_pg.py b/tests/memory/test_pg.py index e7b0587d..ba564586 100644 --- a/tests/memory/test_pg.py +++ b/tests/memory/test_pg.py @@ -23,7 +23,9 @@ def test_init(): def test_init_exception(): with pytest.raises(ValueError): PgVectorVectorStore( - connection_string="mysql://root:password@localhost:3306/test", + connection_string=( + "mysql://root:password@localhost:3306/test" + ), table_name="test", ) @@ -47,7 +49,10 @@ def test_upsert_vector(): table_name="test", ) store.upsert_vector( - [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"} + [1.0, 2.0, 3.0], + "test_id", + "test_namespace", + {"meta": "data"}, ) MockSession.assert_called() MockSession.return_value.merge.assert_called() diff --git a/tests/memory/test_pinecone.py b/tests/memory/test_pinecone.py index 106a6e81..9cc99781 100644 --- a/tests/memory/test_pinecone.py +++ b/tests/memory/test_pinecone.py @@ -10,7 +10,9 @@ def test_init(): "pinecone.Index" ) as MockIndex: store = PineconeVectorStore( - api_key=api_key, index_name="test_index", environment="test_env" + api_key=api_key, + index_name="test_index", + environment="test_env", ) MockInit.assert_called_once() MockIndex.assert_called_once() @@ -20,10 +22,15 @@ def test_init(): def test_upsert_vector(): with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: store = PineconeVectorStore( - api_key=api_key, index_name="test_index", environment="test_env" + api_key=api_key, + index_name="test_index", + environment="test_env", ) store.upsert_vector( - [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"} + [1.0, 2.0, 3.0], + "test_id", + "test_namespace", + {"meta": "data"}, ) MockIndex.return_value.upsert.assert_called() @@ -31,7 +38,9 @@ def test_upsert_vector(): def test_load_entry(): with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: store = PineconeVectorStore( - api_key=api_key, index_name="test_index", environment="test_env" + api_key=api_key, + index_name="test_index", + environment="test_env", ) store.load_entry("test_id", "test_namespace") MockIndex.return_value.fetch.assert_called() @@ -40,7 +49,9 @@ def test_load_entry(): def test_load_entries(): with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: store = PineconeVectorStore( - api_key=api_key, index_name="test_index", environment="test_env" + api_key=api_key, + index_name="test_index", + environment="test_env", ) store.load_entries("test_namespace") MockIndex.return_value.query.assert_called() @@ -49,7 +60,9 @@ def test_load_entries(): def test_query(): with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: store = PineconeVectorStore( - api_key=api_key, index_name="test_index", environment="test_env" + api_key=api_key, + index_name="test_index", + environment="test_env", ) store.query("test_query", 10, "test_namespace") MockIndex.return_value.query.assert_called() @@ -60,7 +73,9 @@ def test_create_index(): "pinecone.create_index" ) as MockCreateIndex: store = PineconeVectorStore( - api_key=api_key, index_name="test_index", environment="test_env" + api_key=api_key, + index_name="test_index", + environment="test_env", ) store.create_index("test_index") MockCreateIndex.assert_called() diff --git a/tests/models/test_LLM.py b/tests/models/test_LLM.py index a7ca149f..04d6a5f2 100644 --- a/tests/models/test_LLM.py +++ b/tests/models/test_LLM.py @@ -17,7 +17,9 @@ class TestLLM(unittest.TestCase): self.prompt = "Who won the FIFA World Cup in 1998?" def test_init(self): - self.assertEqual(self.llm_openai.openai_api_key, "mock_openai_key") + self.assertEqual( + self.llm_openai.openai_api_key, "mock_openai_key" + ) self.assertEqual(self.llm_hf.hf_repo_id, "mock_repo_id") self.assertEqual(self.llm_hf.hf_api_token, "mock_hf_token") @@ -41,7 +43,9 @@ class TestLLM(unittest.TestCase): with self.assertRaises(ValueError): LLM(hf_repo_id="mock_repo_id") - @patch.dict(os.environ, {"HUGGINGFACEHUB_API_TOKEN": "mock_hf_token"}) + @patch.dict( + os.environ, {"HUGGINGFACEHUB_API_TOKEN": "mock_hf_token"} + ) def test_hf_token_from_env(self): llm = LLM(hf_repo_id="mock_repo_id") self.assertEqual(llm.hf_api_token, "mock_hf_token") diff --git a/tests/models/test_ada.py b/tests/models/test_ada.py index e65e1470..43895e79 100644 --- a/tests/models/test_ada.py +++ b/tests/models/test_ada.py @@ -26,7 +26,9 @@ def test_texts(): def test_get_ada_embeddings_basic(test_texts): with patch("openai.resources.Embeddings.create") as mock_create: # Mocking the OpenAI API call - mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} + mock_create.return_value = { + "data": [{"embedding": [0.1, 0.2, 0.3]}] + } for text in test_texts: embedding = get_ada_embeddings(text) @@ -36,7 +38,8 @@ def test_get_ada_embeddings_basic(test_texts): 0.3, ], "Embedding does not match expected output" mock_create.assert_called_with( - input=[text.replace("\n", " ")], model="text-embedding-ada-002" + input=[text.replace("\n", " ")], + model="text-embedding-ada-002", ) @@ -44,16 +47,28 @@ def test_get_ada_embeddings_basic(test_texts): @pytest.mark.parametrize( "text, model, expected_call_model", [ - ("Hello World", "text-embedding-ada-002", "text-embedding-ada-002"), - ("Hello World", "text-embedding-ada-001", "text-embedding-ada-001"), + ( + "Hello World", + "text-embedding-ada-002", + "text-embedding-ada-002", + ), + ( + "Hello World", + "text-embedding-ada-001", + "text-embedding-ada-001", + ), ], ) def test_get_ada_embeddings_models(text, model, expected_call_model): with patch("openai.resources.Embeddings.create") as mock_create: - mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} + mock_create.return_value = { + "data": [{"embedding": [0.1, 0.2, 0.3]}] + } _ = get_ada_embeddings(text, model=model) - mock_create.assert_called_with(input=[text], model=expected_call_model) + mock_create.assert_called_with( + input=[text], model=expected_call_model + ) # Exception Test diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index fecd3585..cc48479a 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -11,7 +11,9 @@ class MockAnthropicClient: def __init__(self, *args, **kwargs): pass - def completions_create(self, prompt, stop_sequences, stream, **kwargs): + def completions_create( + self, prompt, stop_sequences, stream, **kwargs + ): return MockAnthropicResponse() @@ -43,7 +45,10 @@ def test_anthropic_init_default_values(anthropic_instance): assert anthropic_instance.top_p is None assert anthropic_instance.streaming is False assert anthropic_instance.default_request_timeout == 600 - assert anthropic_instance.anthropic_api_url == "https://test.anthropic.com" + assert ( + anthropic_instance.anthropic_api_url + == "https://test.anthropic.com" + ) assert anthropic_instance.anthropic_api_key == "test_api_key" @@ -168,7 +173,9 @@ def test_anthropic_async_call_method(anthropic_instance): def test_anthropic_async_stream_method(anthropic_instance): - async_generator = anthropic_instance.async_stream("Translate to French.") + async_generator = anthropic_instance.async_stream( + "Translate to French." + ) for token in async_generator: assert isinstance(token, str) @@ -192,7 +199,9 @@ def test_anthropic_wrap_prompt(anthropic_instance): def test_anthropic_convert_prompt(anthropic_instance): prompt = "What is the meaning of life?" converted_prompt = anthropic_instance.convert_prompt(prompt) - assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT) + assert converted_prompt.startswith( + anthropic_instance.HUMAN_PROMPT + ) assert converted_prompt.endswith(anthropic_instance.AI_PROMPT) @@ -226,21 +235,27 @@ def test_anthropic_async_stream_with_stop(anthropic_instance): assert isinstance(token, str) -def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance): +def test_anthropic_get_num_tokens_with_count_tokens( + anthropic_instance, +): anthropic_instance.count_tokens = Mock(return_value=10) text = "This is a test sentence." num_tokens = anthropic_instance.get_num_tokens(text) assert num_tokens == 10 -def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance): +def test_anthropic_get_num_tokens_without_count_tokens( + anthropic_instance, +): del anthropic_instance.count_tokens with pytest.raises(NameError): text = "This is a test sentence." anthropic_instance.get_num_tokens(text) -def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance): +def test_anthropic_wrap_prompt_without_human_ai_prompt( + anthropic_instance, +): del anthropic_instance.HUMAN_PROMPT del anthropic_instance.AI_PROMPT prompt = "What is the meaning of life?" diff --git a/tests/models/test_auto_temp.py b/tests/models/test_auto_temp.py index 76cdc7c3..7937d0dc 100644 --- a/tests/models/test_auto_temp.py +++ b/tests/models/test_auto_temp.py @@ -51,7 +51,9 @@ def test_run_no_scores(auto_temp_agent): max_workers=auto_temp_agent.max_workers ) as executor: with patch.object( - executor, "submit", side_effect=[None, None, None, None, None, None] + executor, + "submit", + side_effect=[None, None, None, None, None, None], ): result = auto_temp_agent.run(task, temperature_string) assert result == "No valid outputs generated." diff --git a/tests/models/test_bingchat.py b/tests/models/test_bingchat.py index 8f29f905..c87237e2 100644 --- a/tests/models/test_bingchat.py +++ b/tests/models/test_bingchat.py @@ -24,7 +24,9 @@ class TestBingChat(unittest.TestCase): def test_call(self): # Mocking the asynchronous behavior for the purpose of the test - self.chat.bot.ask = lambda *args, **kwargs: {"text": "Hello, Test!"} + self.chat.bot.ask = lambda *args, **kwargs: { + "text": "Hello, Test!" + } response = self.chat("Test prompt") self.assertEqual(response, "Hello, Test!") diff --git a/tests/models/test_bioclip.py b/tests/models/test_bioclip.py index 54ab5bb9..99e1e343 100644 --- a/tests/models/test_bioclip.py +++ b/tests/models/test_bioclip.py @@ -14,7 +14,9 @@ def sample_image_path(): @pytest.fixture def clip_instance(): - return BioClip("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224") + return BioClip( + "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" + ) # Basic tests for the BioClip class @@ -44,12 +46,16 @@ def test_clip_call_method(clip_instance, sample_image_path): assert len(result) == len(labels) -def test_clip_plot_image_with_metadata(clip_instance, sample_image_path): +def test_clip_plot_image_with_metadata( + clip_instance, sample_image_path +): metadata = { "filename": "sample_image.jpg", "top_probs": {"label1": 0.75, "label2": 0.65}, } - clip_instance.plot_image_with_metadata(sample_image_path, metadata) + clip_instance.plot_image_with_metadata( + sample_image_path, metadata + ) # More test cases can be added to cover additional functionality and edge cases @@ -147,7 +153,9 @@ def test_clip_inference_performance( # Test different preprocessing pipelines -def test_clip_preprocessing_pipelines(clip_instance, sample_image_path): +def test_clip_preprocessing_pipelines( + clip_instance, sample_image_path +): labels = ["label1", "label2"] image = Image.open(sample_image_path) diff --git a/tests/models/test_biogpt.py b/tests/models/test_biogpt.py index e1daa14e..38be125d 100644 --- a/tests/models/test_biogpt.py +++ b/tests/models/test_biogpt.py @@ -47,8 +47,8 @@ def test_cell_biology_response(biogpt_instance): # 40. Test for a question about protein structure def test_protein_structure_response(biogpt_instance): question = ( - "What's the difference between alpha helix and beta sheet structures in" - " proteins?" + "What's the difference between alpha helix and beta sheet" + " structures in proteins?" ) response = biogpt_instance(question) assert response and isinstance(response, str) @@ -77,7 +77,9 @@ def test_bioinformatics_response(biogpt_instance): # 44. Test for a neuroscience question def test_neuroscience_response(biogpt_instance): - question = "Explain the function of synapses in the nervous system." + question = ( + "Explain the function of synapses in the nervous system." + ) response = biogpt_instance(question) assert response and isinstance(response, str) @@ -157,7 +159,9 @@ def test_get_config_return_type(biogpt_instance): # 28. Test saving model functionality by checking if files are created @patch.object(BioGptForCausalLM, "save_pretrained") @patch.object(BioGptTokenizer, "save_pretrained") -def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance): +def test_save_model( + mock_save_model, mock_save_tokenizer, biogpt_instance +): path = "test_path" biogpt_instance.save_model(path) mock_save_model.assert_called_once_with(path) @@ -167,7 +171,9 @@ def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance): # 29. Test loading model from path @patch.object(BioGptForCausalLM, "from_pretrained") @patch.object(BioGptTokenizer, "from_pretrained") -def test_load_from_path(mock_load_model, mock_load_tokenizer, biogpt_instance): +def test_load_from_path( + mock_load_model, mock_load_tokenizer, biogpt_instance +): path = "test_path" biogpt_instance.load_from_path(path) mock_load_model.assert_called_once_with(path) @@ -184,7 +190,9 @@ def test_print_model_metadata(biogpt_instance): # 31. Test that beam_search_decoding uses the correct number of beams @patch.object(BioGptForCausalLM, "generate") -def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance): +def test_beam_search_decoding_num_beams( + mock_generate, biogpt_instance +): biogpt_instance.beam_search_decoding("test_sentence", num_beams=7) _, kwargs = mock_generate.call_args assert kwargs["num_beams"] == 7 @@ -192,8 +200,12 @@ def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance): # 32. Test if beam_search_decoding handles early_stopping @patch.object(BioGptForCausalLM, "generate") -def test_beam_search_decoding_early_stopping(mock_generate, biogpt_instance): - biogpt_instance.beam_search_decoding("test_sentence", early_stopping=False) +def test_beam_search_decoding_early_stopping( + mock_generate, biogpt_instance +): + biogpt_instance.beam_search_decoding( + "test_sentence", early_stopping=False + ) _, kwargs = mock_generate.call_args assert kwargs["early_stopping"] is False diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 08a0e39d..5e6fc948 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -41,7 +41,9 @@ def test_cohere_async_api_error_handling(cohere_instance): cohere_instance.model = "base" cohere_instance.cohere_api_key = "invalid-api-key" with pytest.raises(Exception): - cohere_instance.async_call("Error handling with invalid API key.") + cohere_instance.async_call( + "Error handling with invalid API key." + ) def test_cohere_stream_api_error_handling(cohere_instance): @@ -91,7 +93,9 @@ def test_cohere_convert_prompt(cohere_instance): def test_cohere_call_with_stop(cohere_instance): - response = cohere_instance("Translate to French.", stop=["stop1", "stop2"]) + response = cohere_instance( + "Translate to French.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Cohere" @@ -147,14 +151,20 @@ def test_base_cohere_import(): def test_base_cohere_validate_environment(): - values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"} + values = { + "cohere_api_key": "my-api-key", + "user_agent": "langchain", + } validated_values = BaseCohere.validate_environment(values) assert "client" in validated_values assert "async_client" in validated_values def test_base_cohere_validate_environment_without_cohere(): - values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"} + values = { + "cohere_api_key": "my-api-key", + "user_agent": "langchain", + } with patch.dict("sys.modules", {"cohere": None}): with pytest.raises(ImportError): BaseCohere.validate_environment(values) @@ -163,8 +173,12 @@ def test_base_cohere_validate_environment_without_cohere(): # Test cases for benchmarking generations with various models def test_cohere_generate_with_command_light(cohere_instance): cohere_instance.model = "command-light" - response = cohere_instance("Generate text with Command Light model.") - assert response.startswith("Generated text with Command Light model") + response = cohere_instance( + "Generate text with Command Light model." + ) + assert response.startswith( + "Generated text with Command Light model" + ) def test_cohere_generate_with_command(cohere_instance): @@ -187,8 +201,12 @@ def test_cohere_generate_with_base(cohere_instance): def test_cohere_generate_with_embed_english_v2(cohere_instance): cohere_instance.model = "embed-english-v2.0" - response = cohere_instance("Generate embeddings with English v2.0 model.") - assert response.startswith("Generated embeddings with English v2.0 model") + response = cohere_instance( + "Generate embeddings with English v2.0 model." + ) + assert response.startswith( + "Generated embeddings with English v2.0 model" + ) def test_cohere_generate_with_embed_english_light_v2(cohere_instance): @@ -213,8 +231,12 @@ def test_cohere_generate_with_embed_multilingual_v2(cohere_instance): def test_cohere_generate_with_embed_english_v3(cohere_instance): cohere_instance.model = "embed-english-v3.0" - response = cohere_instance("Generate embeddings with English v3.0 model.") - assert response.startswith("Generated embeddings with English v3.0 model") + response = cohere_instance( + "Generate embeddings with English v3.0 model." + ) + assert response.startswith( + "Generated embeddings with English v3.0 model" + ) def test_cohere_generate_with_embed_english_light_v3(cohere_instance): @@ -237,7 +259,9 @@ def test_cohere_generate_with_embed_multilingual_v3(cohere_instance): ) -def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance): +def test_cohere_generate_with_embed_multilingual_light_v3( + cohere_instance, +): cohere_instance.model = "embed-multilingual-light-v3.0" response = cohere_instance( "Generate embeddings with Multilingual Light v3.0 model." @@ -274,13 +298,17 @@ def test_cohere_call_with_embed_english_v3_model(cohere_instance): assert isinstance(response, str) -def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance): +def test_cohere_call_with_embed_multilingual_v2_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v2.0" response = cohere_instance("Translate to French.") assert isinstance(response, str) -def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance): +def test_cohere_call_with_embed_multilingual_v3_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v3.0" response = cohere_instance("Translate to French.") assert isinstance(response, str) @@ -300,7 +328,9 @@ def test_cohere_call_with_long_prompt(cohere_instance): def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance): cohere_instance.max_tokens = 10 - prompt = "This is a test prompt that will exceed the max tokens limit." + prompt = ( + "This is a test prompt that will exceed the max tokens limit." + ) with pytest.raises(ValueError): cohere_instance(prompt) @@ -333,14 +363,18 @@ def test_cohere_stream_with_embed_english_v3_model(cohere_instance): assert isinstance(token, str) -def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance): +def test_cohere_stream_with_embed_multilingual_v2_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v2.0" generator = cohere_instance.stream("Write a story.") for token in generator: assert isinstance(token, str) -def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance): +def test_cohere_stream_with_embed_multilingual_v3_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v3.0" generator = cohere_instance.stream("Write a story.") for token in generator: @@ -359,25 +393,33 @@ def test_cohere_async_call_with_base_model(cohere_instance): assert isinstance(response, str) -def test_cohere_async_call_with_embed_english_v2_model(cohere_instance): +def test_cohere_async_call_with_embed_english_v2_model( + cohere_instance, +): cohere_instance.model = "embed-english-v2.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_english_v3_model(cohere_instance): +def test_cohere_async_call_with_embed_english_v3_model( + cohere_instance, +): cohere_instance.model = "embed-english-v3.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance): +def test_cohere_async_call_with_embed_multilingual_v2_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v2.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) -def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance): +def test_cohere_async_call_with_embed_multilingual_v3_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v3.0" response = cohere_instance.async_call("Translate to French.") assert isinstance(response, str) @@ -397,28 +439,36 @@ def test_cohere_async_stream_with_base_model(cohere_instance): assert isinstance(token, str) -def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance): +def test_cohere_async_stream_with_embed_english_v2_model( + cohere_instance, +): cohere_instance.model = "embed-english-v2.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance): +def test_cohere_async_stream_with_embed_english_v3_model( + cohere_instance, +): cohere_instance.model = "embed-english-v3.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance): +def test_cohere_async_stream_with_embed_multilingual_v2_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v2.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: assert isinstance(token, str) -def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance): +def test_cohere_async_stream_with_embed_multilingual_v3_model( + cohere_instance, +): cohere_instance.model = "embed-multilingual-v3.0" async_generator = cohere_instance.async_stream("Write a story.") for token in async_generator: @@ -428,7 +478,9 @@ def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance): def test_cohere_representation_model_embedding(cohere_instance): # Test using the Representation model for text embedding cohere_instance.model = "embed-english-v3.0" - embedding = cohere_instance.embed("Generate an embedding for this text.") + embedding = cohere_instance.embed( + "Generate an embedding for this text." + ) assert isinstance(embedding, list) assert len(embedding) > 0 @@ -442,7 +494,9 @@ def test_cohere_representation_model_classification(cohere_instance): assert "score" in classification -def test_cohere_representation_model_language_detection(cohere_instance): +def test_cohere_representation_model_language_detection( + cohere_instance, +): # Test using the Representation model for language detection cohere_instance.model = "embed-english-v3.0" language = cohere_instance.detect_language( @@ -451,11 +505,15 @@ def test_cohere_representation_model_language_detection(cohere_instance): assert isinstance(language, str) -def test_cohere_representation_model_max_tokens_limit_exceeded(cohere_instance): +def test_cohere_representation_model_max_tokens_limit_exceeded( + cohere_instance, +): # Test handling max tokens limit exceeded error cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 - prompt = "This is a test prompt that will exceed the max tokens limit." + prompt = ( + "This is a test prompt that will exceed the max tokens limit." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -463,10 +521,14 @@ def test_cohere_representation_model_max_tokens_limit_exceeded(cohere_instance): # Add more production-grade test cases based on real-world scenarios -def test_cohere_representation_model_multilingual_embedding(cohere_instance): +def test_cohere_representation_model_multilingual_embedding( + cohere_instance, +): # Test using the Representation model for multilingual text embedding cohere_instance.model = "embed-multilingual-v3.0" - embedding = cohere_instance.embed("Generate multilingual embeddings.") + embedding = cohere_instance.embed( + "Generate multilingual embeddings." + ) assert isinstance(embedding, list) assert len(embedding) > 0 @@ -476,7 +538,9 @@ def test_cohere_representation_model_multilingual_classification( ): # Test using the Representation model for multilingual text classification cohere_instance.model = "embed-multilingual-v3.0" - classification = cohere_instance.classify("Classify multilingual text.") + classification = cohere_instance.classify( + "Classify multilingual text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification @@ -500,8 +564,8 @@ def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded( cohere_instance.model = "embed-multilingual-v3.0" cohere_instance.max_tokens = 10 prompt = ( - "This is a test prompt that will exceed the max tokens limit for" - " multilingual model." + "This is a test prompt that will exceed the max tokens limit" + " for multilingual model." ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -512,7 +576,9 @@ def test_cohere_representation_model_multilingual_light_embedding( ): # Test using the Representation model for multilingual light text embedding cohere_instance.model = "embed-multilingual-light-v3.0" - embedding = cohere_instance.embed("Generate multilingual light embeddings.") + embedding = cohere_instance.embed( + "Generate multilingual light embeddings." + ) assert isinstance(embedding, list) assert len(embedding) > 0 @@ -548,8 +614,8 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede cohere_instance.model = "embed-multilingual-light-v3.0" cohere_instance.max_tokens = 10 prompt = ( - "This is a test prompt that will exceed the max tokens limit for" - " multilingual light model." + "This is a test prompt that will exceed the max tokens limit" + " for multilingual light model." ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -558,14 +624,18 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede def test_cohere_command_light_model(cohere_instance): # Test using the Command Light model for text generation cohere_instance.model = "command-light" - response = cohere_instance("Generate text using Command Light model.") + response = cohere_instance( + "Generate text using Command Light model." + ) assert isinstance(response, str) def test_cohere_base_light_model(cohere_instance): # Test using the Base Light model for text generation cohere_instance.model = "base-light" - response = cohere_instance("Generate text using Base Light model.") + response = cohere_instance( + "Generate text using Base Light model." + ) assert isinstance(response, str) @@ -576,7 +646,9 @@ def test_cohere_generate_summarize_endpoint(cohere_instance): assert isinstance(response, str) -def test_cohere_representation_model_english_embedding(cohere_instance): +def test_cohere_representation_model_english_embedding( + cohere_instance, +): # Test using the Representation model for English text embedding cohere_instance.model = "embed-english-v3.0" embedding = cohere_instance.embed("Generate English embeddings.") @@ -584,10 +656,14 @@ def test_cohere_representation_model_english_embedding(cohere_instance): assert len(embedding) > 0 -def test_cohere_representation_model_english_classification(cohere_instance): +def test_cohere_representation_model_english_classification( + cohere_instance, +): # Test using the Representation model for English text classification cohere_instance.model = "embed-english-v3.0" - classification = cohere_instance.classify("Classify English text.") + classification = cohere_instance.classify( + "Classify English text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification @@ -611,17 +687,21 @@ def test_cohere_representation_model_english_max_tokens_limit_exceeded( cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 prompt = ( - "This is a test prompt that will exceed the max tokens limit for" - " English model." + "This is a test prompt that will exceed the max tokens limit" + " for English model." ) with pytest.raises(ValueError): cohere_instance.embed(prompt) -def test_cohere_representation_model_english_light_embedding(cohere_instance): +def test_cohere_representation_model_english_light_embedding( + cohere_instance, +): # Test using the Representation model for English light text embedding cohere_instance.model = "embed-english-light-v3.0" - embedding = cohere_instance.embed("Generate English light embeddings.") + embedding = cohere_instance.embed( + "Generate English light embeddings." + ) assert isinstance(embedding, list) assert len(embedding) > 0 @@ -631,7 +711,9 @@ def test_cohere_representation_model_english_light_classification( ): # Test using the Representation model for English light text classification cohere_instance.model = "embed-english-light-v3.0" - classification = cohere_instance.classify("Classify English light text.") + classification = cohere_instance.classify( + "Classify English light text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification @@ -655,8 +737,8 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( cohere_instance.model = "embed-english-light-v3.0" cohere_instance.max_tokens = 10 prompt = ( - "This is a test prompt that will exceed the max tokens limit for" - " English light model." + "This is a test prompt that will exceed the max tokens limit" + " for English light model." ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -665,7 +747,9 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( def test_cohere_command_model(cohere_instance): # Test using the Command model for text generation cohere_instance.model = "command" - response = cohere_instance("Generate text using the Command model.") + response = cohere_instance( + "Generate text using the Command model." + ) assert isinstance(response, str) @@ -679,7 +763,9 @@ def test_cohere_invalid_model(cohere_instance): cohere_instance("Generate text using an invalid model.") -def test_cohere_base_model_generation_with_max_tokens(cohere_instance): +def test_cohere_base_model_generation_with_max_tokens( + cohere_instance, +): # Test generating text using the base model with a specified max_tokens limit cohere_instance.model = "base" cohere_instance.max_tokens = 20 diff --git a/tests/models/test_dalle3.py b/tests/models/test_dalle3.py index 9b7cf0e1..00ba7bc9 100644 --- a/tests/models/test_dalle3.py +++ b/tests/models/test_dalle3.py @@ -23,9 +23,7 @@ def dalle3(mock_openai_client): def test_dalle3_call_success(dalle3, mock_openai_client): # Arrange task = "A painting of a dog" - expected_img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) + expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" mock_openai_client.images.generate.return_value = Mock( data=[Mock(url=expected_img_url)] ) @@ -35,7 +33,9 @@ def test_dalle3_call_success(dalle3, mock_openai_client): # Assert assert img_url == expected_img_url - mock_openai_client.images.generate.assert_called_once_with(prompt=task, n=4) + mock_openai_client.images.generate.assert_called_once_with( + prompt=task, n=4 + ) def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): @@ -45,7 +45,9 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): # Mocking OpenAIError mock_openai_client.images.generate.side_effect = OpenAIError( - expected_error_message, http_status=500, error="Internal Server Error" + expected_error_message, + http_status=500, + error="Internal Server Error", ) # Act and assert @@ -53,7 +55,9 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): dalle3(task) assert str(excinfo.value) == expected_error_message - mock_openai_client.images.generate.assert_called_once_with(prompt=task, n=4) + mock_openai_client.images.generate.assert_called_once_with( + prompt=task, n=4 + ) # Ensure the error message is printed in red captured = capsys.readouterr() @@ -62,12 +66,8 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): def test_dalle3_create_variations_success(dalle3, mock_openai_client): # Arrange - img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) - expected_variation_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" - ) + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" mock_openai_client.images.create_variation.return_value = Mock( data=[Mock(url=expected_variation_url)] ) @@ -84,16 +84,20 @@ def test_dalle3_create_variations_success(dalle3, mock_openai_client): assert kwargs["size"] == "1024x1024" -def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys): +def test_dalle3_create_variations_failure( + dalle3, mock_openai_client, capsys +): # Arrange - img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" expected_error_message = "Error running Dalle3: API Error" # Mocking OpenAIError - mock_openai_client.images.create_variation.side_effect = OpenAIError( - expected_error_message, http_status=500, error="Internal Server Error" + mock_openai_client.images.create_variation.side_effect = ( + OpenAIError( + expected_error_message, + http_status=500, + error="Internal Server Error", + ) ) # Act and assert @@ -158,9 +162,7 @@ def test_dalle3_convert_to_bytesio(): def test_dalle3_call_multiple_times(dalle3, mock_openai_client): # Arrange task = "A painting of a dog" - expected_img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) + expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" mock_openai_client.images.generate.return_value = Mock( data=[Mock(url=expected_img_url)] ) @@ -180,7 +182,9 @@ def test_dalle3_call_with_large_input(dalle3, mock_openai_client): task = "A" * 2048 # Input longer than API's limit expected_error_message = "Error running Dalle3: API Error" mock_openai_client.images.generate.side_effect = OpenAIError( - expected_error_message, http_status=500, error="Internal Server Error" + expected_error_message, + http_status=500, + error="Internal Server Error", ) # Act and assert @@ -228,14 +232,14 @@ def test_dalle3_convert_to_bytesio_invalid_format(dalle3): def test_dalle3_call_with_retry(dalle3, mock_openai_client): # Arrange task = "A painting of a dog" - expected_img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) + expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" # Simulate a retry scenario mock_openai_client.images.generate.side_effect = [ OpenAIError( - "Temporary error", http_status=500, error="Internal Server Error" + "Temporary error", + http_status=500, + error="Internal Server Error", ), Mock(data=[Mock(url=expected_img_url)]), ] @@ -248,19 +252,19 @@ def test_dalle3_call_with_retry(dalle3, mock_openai_client): assert mock_openai_client.images.generate.call_count == 2 -def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client): +def test_dalle3_create_variations_with_retry( + dalle3, mock_openai_client +): # Arrange - img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) - expected_variation_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" - ) + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" # Simulate a retry scenario mock_openai_client.images.create_variation.side_effect = [ OpenAIError( - "Temporary error", http_status=500, error="Internal Server Error" + "Temporary error", + http_status=500, + error="Internal Server Error", ), Mock(data=[Mock(url=expected_variation_url)]), ] @@ -273,14 +277,18 @@ def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client): assert mock_openai_client.images.create_variation.call_count == 2 -def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys): +def test_dalle3_call_exception_logging( + dalle3, mock_openai_client, capsys +): # Arrange task = "A painting of a dog" expected_error_message = "Error running Dalle3: API Error" # Mocking OpenAIError mock_openai_client.images.generate.side_effect = OpenAIError( - expected_error_message, http_status=500, error="Internal Server Error" + expected_error_message, + http_status=500, + error="Internal Server Error", ) # Act @@ -296,14 +304,16 @@ def test_dalle3_create_variations_exception_logging( dalle3, mock_openai_client, capsys ): # Arrange - img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" expected_error_message = "Error running Dalle3: API Error" # Mocking OpenAIError - mock_openai_client.images.create_variation.side_effect = OpenAIError( - expected_error_message, http_status=500, error="Internal Server Error" + mock_openai_client.images.create_variation.side_effect = ( + OpenAIError( + expected_error_message, + http_status=500, + error="Internal Server Error", + ) ) # Act @@ -328,7 +338,9 @@ def test_dalle3_call_no_api_key(): # Arrange task = "A painting of a dog" dalle3 = Dalle3(api_key=None) - expected_error_message = "Error running Dalle3: API Key is missing" + expected_error_message = ( + "Error running Dalle3: API Key is missing" + ) # Act and assert with pytest.raises(ValueError) as excinfo: @@ -339,11 +351,11 @@ def test_dalle3_call_no_api_key(): def test_dalle3_create_variations_no_api_key(): # Arrange - img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" dalle3 = Dalle3(api_key=None) - expected_error_message = "Error running Dalle3: API Key is missing" + expected_error_message = ( + "Error running Dalle3: API Key is missing" + ) # Act and assert with pytest.raises(ValueError) as excinfo: @@ -360,7 +372,9 @@ def test_dalle3_call_with_retry_max_retries_exceeded( # Simulate max retries exceeded mock_openai_client.images.generate.side_effect = OpenAIError( - "Temporary error", http_status=500, error="Internal Server Error" + "Temporary error", + http_status=500, + error="Internal Server Error", ) # Act and assert @@ -374,13 +388,15 @@ def test_dalle3_create_variations_with_retry_max_retries_exceeded( dalle3, mock_openai_client ): # Arrange - img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" # Simulate max retries exceeded - mock_openai_client.images.create_variation.side_effect = OpenAIError( - "Temporary error", http_status=500, error="Internal Server Error" + mock_openai_client.images.create_variation.side_effect = ( + OpenAIError( + "Temporary error", + http_status=500, + error="Internal Server Error", + ) ) # Act and assert @@ -393,14 +409,14 @@ def test_dalle3_create_variations_with_retry_max_retries_exceeded( def test_dalle3_call_retry_with_success(dalle3, mock_openai_client): # Arrange task = "A painting of a dog" - expected_img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) + expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" # Simulate success after a retry mock_openai_client.images.generate.side_effect = [ OpenAIError( - "Temporary error", http_status=500, error="Internal Server Error" + "Temporary error", + http_status=500, + error="Internal Server Error", ), Mock(data=[Mock(url=expected_img_url)]), ] @@ -417,17 +433,15 @@ def test_dalle3_create_variations_retry_with_success( dalle3, mock_openai_client ): # Arrange - img_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - ) - expected_variation_url = ( - "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" - ) + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" # Simulate success after a retry mock_openai_client.images.create_variation.side_effect = [ OpenAIError( - "Temporary error", http_status=500, error="Internal Server Error" + "Temporary error", + http_status=500, + error="Internal Server Error", ), Mock(data=[Mock(url=expected_variation_url)]), ] diff --git a/tests/models/test_distill_whisper.py b/tests/models/test_distill_whisper.py index 6f95a0e3..775bb896 100644 --- a/tests/models/test_distill_whisper.py +++ b/tests/models/test_distill_whisper.py @@ -8,7 +8,10 @@ import pytest import torch from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor -from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry +from swarms.models.distilled_whisperx import ( + DistilWhisperModel, + async_retry, +) @pytest.fixture @@ -16,7 +19,9 @@ def distil_whisper_model(): return DistilWhisperModel() -def create_audio_file(data: np.ndarray, sample_rate: int, file_path: str): +def create_audio_file( + data: np.ndarray, sample_rate: int, file_path: str +): data.tofile(file_path) return file_path @@ -29,10 +34,18 @@ def test_initialization(distil_whisper_model): def test_transcribe_audio_file(distil_whisper_model): - test_data = np.random.rand(16000) # Simulated audio data (1 second) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file: - audio_file_path = create_audio_file(test_data, 16000, audio_file.name) - transcription = distil_whisper_model.transcribe(audio_file_path) + test_data = np.random.rand( + 16000 + ) # Simulated audio data (1 second) + with tempfile.NamedTemporaryFile( + suffix=".wav", delete=False + ) as audio_file: + audio_file_path = create_audio_file( + test_data, 16000, audio_file.name + ) + transcription = distil_whisper_model.transcribe( + audio_file_path + ) os.remove(audio_file_path) assert isinstance(transcription, str) @@ -41,9 +54,15 @@ def test_transcribe_audio_file(distil_whisper_model): @pytest.mark.asyncio async def test_async_transcribe_audio_file(distil_whisper_model): - test_data = np.random.rand(16000) # Simulated audio data (1 second) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file: - audio_file_path = create_audio_file(test_data, 16000, audio_file.name) + test_data = np.random.rand( + 16000 + ) # Simulated audio data (1 second) + with tempfile.NamedTemporaryFile( + suffix=".wav", delete=False + ) as audio_file: + audio_file_path = create_audio_file( + test_data, 16000, audio_file.name + ) transcription = await distil_whisper_model.async_transcribe( audio_file_path ) @@ -54,8 +73,12 @@ async def test_async_transcribe_audio_file(distil_whisper_model): def test_transcribe_audio_data(distil_whisper_model): - test_data = np.random.rand(16000) # Simulated audio data (1 second) - transcription = distil_whisper_model.transcribe(test_data.tobytes()) + test_data = np.random.rand( + 16000 + ) # Simulated audio data (1 second) + transcription = distil_whisper_model.transcribe( + test_data.tobytes() + ) assert isinstance(transcription, str) assert transcription.strip() != "" @@ -63,7 +86,9 @@ def test_transcribe_audio_data(distil_whisper_model): @pytest.mark.asyncio async def test_async_transcribe_audio_data(distil_whisper_model): - test_data = np.random.rand(16000) # Simulated audio data (1 second) + test_data = np.random.rand( + 16000 + ) # Simulated audio data (1 second) transcription = await distil_whisper_model.async_transcribe( test_data.tobytes() ) @@ -73,9 +98,15 @@ async def test_async_transcribe_audio_data(distil_whisper_model): def test_real_time_transcribe(distil_whisper_model, capsys): - test_data = np.random.rand(16000 * 5) # Simulated audio data (5 seconds) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file: - audio_file_path = create_audio_file(test_data, 16000, audio_file.name) + test_data = np.random.rand( + 16000 * 5 + ) # Simulated audio data (5 seconds) + with tempfile.NamedTemporaryFile( + suffix=".wav", delete=False + ) as audio_file: + audio_file_path = create_audio_file( + test_data, 16000, audio_file.name + ) distil_whisper_model.real_time_transcribe( audio_file_path, chunk_duration=1 @@ -92,7 +123,9 @@ def test_real_time_transcribe_audio_file_not_found( distil_whisper_model, capsys ): audio_file_path = "non_existent_audio.wav" - distil_whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1) + distil_whisper_model.real_time_transcribe( + audio_file_path, chunk_duration=1 + ) captured = capsys.readouterr() assert "The audio file was not found." in captured.out @@ -100,7 +133,9 @@ def test_real_time_transcribe_audio_file_not_found( @pytest.fixture def mock_async_retry(): - def _mock_async_retry(retries=3, exceptions=(Exception,), delay=1): + def _mock_async_retry( + retries=3, exceptions=(Exception,), delay=1 + ): def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): @@ -110,7 +145,9 @@ def mock_async_retry(): return decorator - with patch("distil_whisper_model.async_retry", new=_mock_async_retry()): + with patch( + "distil_whisper_model.async_retry", new=_mock_async_retry() + ): yield @@ -144,15 +181,21 @@ async def test_async_retry_decorator_multiple_attempts(): return "Success" mock_async_function.attempts = 0 - decorated_function = async_retry(max_retries=2)(mock_async_function) + decorated_function = async_retry(max_retries=2)( + mock_async_function + ) result = await decorated_function() assert result == "Success" def test_create_audio_file(): - test_data = np.random.rand(16000) # Simulated audio data (1 second) + test_data = np.random.rand( + 16000 + ) # Simulated audio data (1 second) sample_rate = 16000 - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file: + with tempfile.NamedTemporaryFile( + suffix=".wav", delete=False + ) as audio_file: audio_file_path = create_audio_file( test_data, sample_rate, audio_file.name ) @@ -219,8 +262,12 @@ def test_file_not_found(whisper_model, invalid_audio_file_path): # Asynchronous tests @pytest.mark.asyncio -async def test_async_transcription_success(whisper_model, audio_file_path): - transcription = await whisper_model.async_transcribe(audio_file_path) +async def test_async_transcription_success( + whisper_model, audio_file_path +): + transcription = await whisper_model.async_transcribe( + audio_file_path + ) assert isinstance(transcription, str) @@ -233,8 +280,12 @@ async def test_async_transcription_failure( # Testing real-time transcription simulation -def test_real_time_transcription(whisper_model, audio_file_path, capsys): - whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1) +def test_real_time_transcription( + whisper_model, audio_file_path, capsys +): + whisper_model.real_time_transcribe( + audio_file_path, chunk_duration=1 + ) captured = capsys.readouterr() assert "Starting real-time transcription..." in captured.out @@ -272,10 +323,14 @@ async def test_async_transcribe_with_mocked_model( ): model_mock, processor_mock = mocked_model # Set up what the mock should return when it's called - model_mock.return_value.generate.return_value = torch.tensor([[0]]) + model_mock.return_value.generate.return_value = torch.tensor( + [[0]] + ) processor_mock.return_value.batch_decode.return_value = [ "mocked transcription" ] model_wrapper = DistilWhisperModel() - transcription = await model_wrapper.async_transcribe(audio_file_path) + transcription = await model_wrapper.async_transcribe( + audio_file_path + ) assert transcription == "mocked transcription" diff --git a/tests/models/test_elevenlab.py b/tests/models/test_elevenlab.py index 986ce937..b28ecb31 100644 --- a/tests/models/test_elevenlab.py +++ b/tests/models/test_elevenlab.py @@ -1,6 +1,9 @@ import pytest from unittest.mock import patch, mock_open -from swarms.models.eleven_labs import ElevenLabsText2SpeechTool, ElevenLabsModel +from swarms.models.eleven_labs import ( + ElevenLabsText2SpeechTool, + ElevenLabsModel, +) import os from dotenv import load_dotenv @@ -26,31 +29,45 @@ def test_run_text_to_speech(eleven_labs_tool): def test_play_speech(eleven_labs_tool): - with patch("builtins.open", mock_open(read_data="fake_audio_data")): + with patch( + "builtins.open", mock_open(read_data="fake_audio_data") + ): eleven_labs_tool.play(EXPECTED_SPEECH_FILE) def test_stream_speech(eleven_labs_tool): - with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file: + with patch( + "tempfile.NamedTemporaryFile", mock_open() + ) as mock_file: eleven_labs_tool.stream_speech(SAMPLE_TEXT) - mock_file.assert_called_with(mode="bx", suffix=".wav", delete=False) + mock_file.assert_called_with( + mode="bx", suffix=".wav", delete=False + ) # Testing fixture and environment variables def test_api_key_validation(eleven_labs_tool): - with patch("langchain.utils.get_from_dict_or_env", return_value=API_KEY): + with patch( + "langchain.utils.get_from_dict_or_env", return_value=API_KEY + ): values = {"eleven_api_key": None} - validated_values = eleven_labs_tool.validate_environment(values) + validated_values = eleven_labs_tool.validate_environment( + values + ) assert "eleven_api_key" in validated_values # Mocking the external library def test_run_text_to_speech_with_mock(eleven_labs_tool): - with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file, patch( + with patch( + "tempfile.NamedTemporaryFile", mock_open() + ) as mock_file, patch( "your_module._import_elevenlabs" ) as mock_elevenlabs: mock_elevenlabs_instance = mock_elevenlabs.return_value - mock_elevenlabs_instance.generate.return_value = b"fake_audio_data" + mock_elevenlabs_instance.generate.return_value = ( + b"fake_audio_data" + ) eleven_labs_tool.run(SAMPLE_TEXT) assert mock_file.call_args[1]["suffix"] == ".wav" assert mock_file.call_args[1]["delete"] is False @@ -67,7 +84,8 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool): with pytest.raises( RuntimeError, match=( - "Error while running ElevenLabsText2SpeechTool: Test Exception" + "Error while running ElevenLabsText2SpeechTool: Test" + " Exception" ), ): eleven_labs_tool.run(SAMPLE_TEXT) @@ -75,9 +93,12 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool): # Parameterized testing @pytest.mark.parametrize( - "model", [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL] + "model", + [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL], ) -def test_run_text_to_speech_with_different_models(eleven_labs_tool, model): +def test_run_text_to_speech_with_different_models( + eleven_labs_tool, model +): eleven_labs_tool.model = model speech_file = eleven_labs_tool.run(SAMPLE_TEXT) assert isinstance(speech_file, str) diff --git a/tests/models/test_fuyu.py b/tests/models/test_fuyu.py index a70cb42a..0fc74035 100644 --- a/tests/models/test_fuyu.py +++ b/tests/models/test_fuyu.py @@ -36,7 +36,9 @@ def fuyu_instance(): # Test using the fixture. def test_fuyu_processor_initialization(fuyu_instance): assert isinstance(fuyu_instance.processor, FuyuProcessor) - assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor) + assert isinstance( + fuyu_instance.image_processor, FuyuImageProcessor + ) # Test exception when providing an invalid image path. @@ -76,9 +78,12 @@ def test_tokenizer_type(fuyu_instance): def test_processor_has_image_processor_and_tokenizer(fuyu_instance): assert ( - fuyu_instance.processor.image_processor == fuyu_instance.image_processor + fuyu_instance.processor.image_processor + == fuyu_instance.image_processor + ) + assert ( + fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer ) - assert fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer def test_model_device_map(fuyu_instance): diff --git a/tests/models/test_gpt4_vision_api.py b/tests/models/test_gpt4_vision_api.py index bca3b5f6..c716bb7c 100644 --- a/tests/models/test_gpt4_vision_api.py +++ b/tests/models/test_gpt4_vision_api.py @@ -26,16 +26,21 @@ def test_init(vision_api): def test_encode_image(vision_api): with patch( - "builtins.open", mock_open(read_data=b"test_image_data"), create=True + "builtins.open", + mock_open(read_data=b"test_image_data"), + create=True, ): encoded_image = vision_api.encode_image(img) assert encoded_image == "dGVzdF9pbWFnZV9kYXRh" def test_run_success(vision_api): - expected_response = {"choices": [{"text": "This is the model's response."}]} + expected_response = { + "choices": [{"text": "This is the model's response."}] + } with patch( - "requests.post", return_value=Mock(json=lambda: expected_response) + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: result = vision_api.run("What is this?", img) mock_post.assert_called_once() @@ -53,16 +58,20 @@ def test_run_request_error(vision_api): def test_run_response_error(vision_api): expected_response = {"error": "Model Error"} with patch( - "requests.post", return_value=Mock(json=lambda: expected_response) + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: with pytest.raises(RuntimeError): vision_api.run("What is this?", img) def test_call(vision_api): - expected_response = {"choices": [{"text": "This is the model's response."}]} + expected_response = { + "choices": [{"text": "This is the model's response."}] + } with patch( - "requests.post", return_value=Mock(json=lambda: expected_response) + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: result = vision_api("What is this?", img) mock_post.assert_called_once() @@ -88,10 +97,14 @@ def test_initialization_with_custom_key(): def test_run_successful_response(gpt_api): task = "What is in the image?" img_url = img - response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]} + response_json = { + "choices": [{"text": "Answer from GPT-4 Vision"}] + } mock_response = Mock() mock_response.json.return_value = response_json - with patch("requests.post", return_value=mock_response) as mock_post: + with patch( + "requests.post", return_value=mock_response + ) as mock_post: result = gpt_api.run(task, img_url) mock_post.assert_called_once() assert result == response_json["choices"][0]["text"] @@ -100,7 +113,9 @@ def test_run_successful_response(gpt_api): def test_run_with_exception(gpt_api): task = "What is in the image?" img_url = img - with patch("requests.post", side_effect=Exception("Test Exception")): + with patch( + "requests.post", side_effect=Exception("Test Exception") + ): with pytest.raises(Exception): gpt_api.run(task, img_url) @@ -108,10 +123,14 @@ def test_run_with_exception(gpt_api): def test_call_method_successful_response(gpt_api): task = "What is in the image?" img_url = img - response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]} + response_json = { + "choices": [{"text": "Answer from GPT-4 Vision"}] + } mock_response = Mock() mock_response.json.return_value = response_json - with patch("requests.post", return_value=mock_response) as mock_post: + with patch( + "requests.post", return_value=mock_response + ) as mock_post: result = gpt_api(task, img_url) mock_post.assert_called_once() assert result == response_json @@ -120,7 +139,9 @@ def test_call_method_successful_response(gpt_api): def test_call_method_with_exception(gpt_api): task = "What is in the image?" img_url = img - with patch("requests.post", side_effect=Exception("Test Exception")): + with patch( + "requests.post", side_effect=Exception("Test Exception") + ): with pytest.raises(Exception): gpt_api(task, img_url) @@ -128,12 +149,16 @@ def test_call_method_with_exception(gpt_api): @pytest.mark.asyncio async def test_arun_success(vision_api): expected_response = { - "choices": [{"message": {"content": "This is the model's response."}}] + "choices": [ + {"message": {"content": "This is the model's response."}} + ] } with patch( "aiohttp.ClientSession.post", new_callable=AsyncMock, - return_value=AsyncMock(json=AsyncMock(return_value=expected_response)), + return_value=AsyncMock( + json=AsyncMock(return_value=expected_response) + ), ) as mock_post: result = await vision_api.arun("What is this?", img) mock_post.assert_called_once() @@ -153,10 +178,13 @@ async def test_arun_request_error(vision_api): def test_run_many_success(vision_api): expected_response = { - "choices": [{"message": {"content": "This is the model's response."}}] + "choices": [ + {"message": {"content": "This is the model's response."}} + ] } with patch( - "requests.post", return_value=Mock(json=lambda: expected_response) + "requests.post", + return_value=Mock(json=lambda: expected_response), ) as mock_post: tasks = ["What is this?", "What is that?"] imgs = [img, img] @@ -183,7 +211,9 @@ async def test_arun_json_decode_error(vision_api): with patch( "aiohttp.ClientSession.post", new_callable=AsyncMock, - return_value=AsyncMock(json=AsyncMock(side_effect=ValueError)), + return_value=AsyncMock( + json=AsyncMock(side_effect=ValueError) + ), ) as mock_post: with pytest.raises(ValueError): await vision_api.arun("What is this?", img) @@ -195,7 +225,9 @@ async def test_arun_api_error(vision_api): with patch( "aiohttp.ClientSession.post", new_callable=AsyncMock, - return_value=AsyncMock(json=AsyncMock(return_value=error_response)), + return_value=AsyncMock( + json=AsyncMock(return_value=error_response) + ), ) as mock_post: with pytest.raises(Exception, match="API Error"): await vision_api.arun("What is this?", img) diff --git a/tests/models/test_gpt4v.py b/tests/models/test_gpt4v.py index 8532d313..cd0ee6d5 100644 --- a/tests/models/test_gpt4v.py +++ b/tests/models/test_gpt4v.py @@ -104,7 +104,9 @@ def test_gpt4vision_process_img_nonexistent_file(): gpt4vision.process_img(img_path) -def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision): +def test_gpt4vision_call_single_task_single_image_no_openai_client( + gpt4vision, +): # Arrange img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." @@ -121,7 +123,9 @@ def test_gpt4vision_call_single_task_single_image_empty_response( img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." - mock_openai_client.chat.completions.create.return_value.choices = [] + mock_openai_client.chat.completions.create.return_value.choices = ( + [] + ) # Act response = gpt4vision(img_url, [task]) @@ -138,7 +142,9 @@ def test_gpt4vision_call_multiple_tasks_single_image_empty_responses( img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" tasks = ["Describe this image.", "What's in this picture?"] - mock_openai_client.chat.completions.create.return_value.choices = [] + mock_openai_client.chat.completions.create.return_value.choices = ( + [] + ) # Act responses = gpt4vision(img_url, tasks) @@ -180,7 +186,9 @@ def test_gpt4vision_call_retry_with_success_after_timeout( "choices": [ { "message": { - "content": {"text": "A description of the image."} + "content": { + "text": "A description of the image." + } } } ], @@ -216,12 +224,14 @@ def test_gpt4vision_call_single_task_single_image( img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." - expected_response = GPT4VisionResponse(answer="A description of the image.") - - mock_openai_client.chat.completions.create.return_value.choices[0].text = ( - expected_response.answer + expected_response = GPT4VisionResponse( + answer="A description of the image." ) + mock_openai_client.chat.completions.create.return_value.choices[ + 0 + ].text = expected_response.answer + # Act response = gpt4vision(img_url, [task]) @@ -240,12 +250,14 @@ def test_gpt4vision_call_single_task_multiple_images( ] task = "Describe these images." - expected_response = GPT4VisionResponse(answer="Descriptions of the images.") - - mock_openai_client.chat.completions.create.return_value.choices[0].text = ( - expected_response.answer + expected_response = GPT4VisionResponse( + answer="Descriptions of the images." ) + mock_openai_client.chat.completions.create.return_value.choices[ + 0 + ].text = expected_response.answer + # Act response = gpt4vision(img_urls, [task]) @@ -268,11 +280,14 @@ def test_gpt4vision_call_multiple_tasks_single_image( def create_mock_response(response): return { - "choices": [{"message": {"content": {"text": response.answer}}}] + "choices": [ + {"message": {"content": {"text": response.answer}}} + ] } mock_openai_client.chat.completions.create.side_effect = [ - create_mock_response(response) for response in expected_responses + create_mock_response(response) + for response in expected_responses ] # Act @@ -301,7 +316,9 @@ def test_gpt4vision_call_multiple_tasks_single_image( "choices": [ { "message": { - "content": {"text": expected_responses[i].answer} + "content": { + "text": expected_responses[i].answer + } } } ] @@ -335,7 +352,11 @@ def test_gpt4vision_call_multiple_tasks_multiple_images( ] mock_openai_client.chat.completions.create.side_effect = [ - {"choices": [{"message": {"content": {"text": response.answer}}}]} + { + "choices": [ + {"message": {"content": {"text": response.answer}}} + ] + } for response in expected_responses ] @@ -354,8 +375,8 @@ def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client): img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." - mock_openai_client.chat.completions.create.side_effect = HTTPError( - "HTTP Error" + mock_openai_client.chat.completions.create.side_effect = ( + HTTPError("HTTP Error") ) # Act and Assert @@ -363,13 +384,15 @@ def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client): gpt4vision(img_url, [task]) -def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client): +def test_gpt4vision_call_request_error( + gpt4vision, mock_openai_client +): # Arrange img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." - mock_openai_client.chat.completions.create.side_effect = RequestException( - "Request Error" + mock_openai_client.chat.completions.create.side_effect = ( + RequestException("Request Error") ) # Act and Assert @@ -377,13 +400,15 @@ def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client): gpt4vision(img_url, [task]) -def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client): +def test_gpt4vision_call_connection_error( + gpt4vision, mock_openai_client +): # Arrange img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." - mock_openai_client.chat.completions.create.side_effect = ConnectionError( - "Connection Error" + mock_openai_client.chat.completions.create.side_effect = ( + ConnectionError("Connection Error") ) # Act and Assert @@ -391,7 +416,9 @@ def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client): gpt4vision(img_url, [task]) -def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client): +def test_gpt4vision_call_retry_with_success( + gpt4vision, mock_openai_client +): # Arrange img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." diff --git a/tests/models/test_hf.py b/tests/models/test_hf.py index d3ff9a04..dce13338 100644 --- a/tests/models/test_hf.py +++ b/tests/models/test_hf.py @@ -39,7 +39,9 @@ def hugging_face_llm( return HuggingFaceLLM(model_id="test") -def test_init(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm): +def test_init( + hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm +): assert hugging_face_llm.model_id == "test" mock_autotokenizer.from_pretrained.assert_called_once_with("test") mock_automodelforcausallm.from_pretrained.assert_called_once_with( @@ -63,7 +65,9 @@ def test_init_with_quantize( HuggingFaceLLM(model_id="test", quantize=True) - mock_bitsandbytesconfig.assert_called_once_with(**quantization_config) + mock_bitsandbytesconfig.assert_called_once_with( + **quantization_config + ) mock_autotokenizer.from_pretrained.assert_called_once_with("test") mock_automodelforcausallm.from_pretrained.assert_called_once_with( "test", quantization_config=quantization_config diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 62261b9c..8d53b8e0 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -38,7 +38,9 @@ def test_llm_bad_model_initialization(): # Mocking the tokenizer and model to test run method @patch("swarms.models.huggingface.AutoTokenizer.from_pretrained") -@patch("swarms.models.huggingface.AutoModelForCausalLM.from_pretrained") +@patch( + "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained" +) def test_llm_run(mock_model, mock_tokenizer, llm_instance): mock_model.return_value.generate.return_value = "mocked output" mock_tokenizer.return_value.encode.return_value = "mocked input" @@ -80,7 +82,9 @@ def test_llm_memory_consumption(llm_instance): ) def test_llm_initialization_params(model_id, max_length): if max_length: - instance = HuggingfaceLLM(model_id=model_id, max_length=max_length) + instance = HuggingfaceLLM( + model_id=model_id, max_length=max_length + ) assert instance.max_length == max_length else: instance = HuggingfaceLLM(model_id=model_id) @@ -141,7 +145,9 @@ def test_llm_run_output_length(mock_run, llm_instance): # Test the tokenizer handling special tokens correctly @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode") @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode") -def test_llm_tokenizer_special_tokens(mock_decode, mock_encode, llm_instance): +def test_llm_tokenizer_special_tokens( + mock_decode, mock_encode, llm_instance +): mock_encode.return_value = "encoded input with special tokens" mock_decode.return_value = "decoded output with special tokens" result = llm_instance.run("test task with special tokens") @@ -192,7 +198,9 @@ def test_llm_run_model_exception(mock_generate, llm_instance): # Test the behavior when GPU is forced but not available @patch("torch.cuda.is_available", return_value=False) -def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance): +def test_llm_force_gpu_when_unavailable( + mock_is_available, llm_instance +): with pytest.raises(EnvironmentError): llm_instance.set_device( "cuda" @@ -251,7 +259,9 @@ def test_llm_caching_mechanism(mock_run, llm_instance): @patch("swarms.models.huggingface.HuggingfaceLLM._download_model") def test_llm_force_download(mock_download, llm_instance): llm_instance.download_model_with_progress(force_download=True) - mock_download.assert_called_once_with(llm_instance.model_id, force=True) + mock_download.assert_called_once_with( + llm_instance.model_id, force=True + ) # These tests are provided as examples. In real-world scenarios, you will need to adapt these tests to the actual logic of your `HuggingfaceLLM` class. diff --git a/tests/models/test_idefics.py b/tests/models/test_idefics.py index 2ee9f010..bb443533 100644 --- a/tests/models/test_idefics.py +++ b/tests/models/test_idefics.py @@ -46,8 +46,12 @@ def test_run(idefics_instance): prompts = [["User: Test"]] with patch.object( idefics_instance, "processor" - ) as mock_processor, patch.object(idefics_instance, "model") as mock_model: - mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} + ) as mock_processor, patch.object( + idefics_instance, "model" + ) as mock_model: + mock_processor.return_value = { + "input_ids": torch.tensor([1, 2, 3]) + } mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -61,8 +65,12 @@ def test_call(idefics_instance): prompts = [["User: Test"]] with patch.object( idefics_instance, "processor" - ) as mock_processor, patch.object(idefics_instance, "model") as mock_model: - mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} + ) as mock_processor, patch.object( + idefics_instance, "model" + ) as mock_model: + mock_processor.return_value = { + "input_ids": torch.tensor([1, 2, 3]) + } mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -75,7 +83,9 @@ def test_call(idefics_instance): def test_chat(idefics_instance): user_input = "User: Hello" response = "Model: Hi there!" - with patch.object(idefics_instance, "run", return_value=[response]): + with patch.object( + idefics_instance, "run", return_value=[response] + ): result = idefics_instance.chat(user_input) assert result == response @@ -87,7 +97,9 @@ def test_set_checkpoint(idefics_instance): new_checkpoint = "new_checkpoint" with patch.object( IdeficsForVisionText2Text, "from_pretrained" - ) as mock_from_pretrained, patch.object(AutoProcessor, "from_pretrained"): + ) as mock_from_pretrained, patch.object( + AutoProcessor, "from_pretrained" + ): idefics_instance.set_checkpoint(new_checkpoint) mock_from_pretrained.assert_called_with( diff --git a/tests/models/test_kosmos.py b/tests/models/test_kosmos.py index aaa756a3..1219f895 100644 --- a/tests/models/test_kosmos.py +++ b/tests/models/test_kosmos.py @@ -16,7 +16,9 @@ def mock_image_request(): img_data = open(TEST_IMAGE_URL, "rb").read() mock_resp = Mock() mock_resp.raw = img_data - with patch.object(requests, "get", return_value=mock_resp) as _fixture: + with patch.object( + requests, "get", return_value=mock_resp + ) as _fixture: yield _fixture @@ -109,12 +111,16 @@ def kosmos(): # Mocking the requests.get() method @pytest.fixture def mock_request_get(monkeypatch): - monkeypatch.setattr(requests, "get", lambda url, **kwargs: MockResponse()) + monkeypatch.setattr( + requests, "get", lambda url, **kwargs: MockResponse() + ) @pytest.mark.usefixtures("mock_request_get") def test_multimodal_grounding(kosmos): - kosmos.multimodal_grounding("Find the red apple in the image.", IMG_URL1) + kosmos.multimodal_grounding( + "Find the red apple in the image.", IMG_URL1 + ) @pytest.mark.usefixtures("mock_request_get") @@ -126,7 +132,9 @@ def test_referring_expression_comprehension(kosmos): @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_generation(kosmos): - kosmos.referring_expression_generation("It is on the table.", IMG_URL3) + kosmos.referring_expression_generation( + "It is on the table.", IMG_URL3 + ) @pytest.mark.usefixtures("mock_request_get") @@ -146,7 +154,9 @@ def test_grounded_image_captioning_detailed(kosmos): @pytest.mark.usefixtures("mock_request_get") def test_multimodal_grounding_2(kosmos): - kosmos.multimodal_grounding("Find the yellow fruit in the image.", IMG_URL2) + kosmos.multimodal_grounding( + "Find the yellow fruit in the image.", IMG_URL2 + ) @pytest.mark.usefixtures("mock_request_get") diff --git a/tests/models/test_kosmos2.py b/tests/models/test_kosmos2.py index 1ad824cc..7e4f0e5f 100644 --- a/tests/models/test_kosmos2.py +++ b/tests/models/test_kosmos2.py @@ -55,7 +55,9 @@ def mock_process_entities_to_detections(entities, image): def test_kosmos2_with_mocked_extraction_and_detection( kosmos2, sample_image, monkeypatch ): - monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) + monkeypatch.setattr( + kosmos2, "extract_entities", mock_extract_entities + ) monkeypatch.setattr( kosmos2, "process_entities_to_detections", @@ -73,7 +75,9 @@ def test_kosmos2_with_mocked_extraction_and_detection( # Test Kosmos2 with empty entity extraction -def test_kosmos2_with_empty_extraction(kosmos2, sample_image, monkeypatch): +def test_kosmos2_with_empty_extraction( + kosmos2, sample_image, monkeypatch +): monkeypatch.setattr(kosmos2, "extract_entities", lambda x: []) detections = kosmos2(img=sample_image) assert isinstance(detections, Detections) @@ -219,7 +223,9 @@ def test_kosmos2_with_invalid_hf_api_key(kosmos2, sample_image): # Test Kosmos2 with a very long generated text -def test_kosmos2_with_long_generated_text(kosmos2, sample_image, monkeypatch): +def test_kosmos2_with_long_generated_text( + kosmos2, sample_image, monkeypatch +): def mock_generate_text(*args, **kwargs): return "A" * 10000 @@ -246,7 +252,9 @@ def test_kosmos2_with_entities_containing_special_characters( ) ] - monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) + monkeypatch.setattr( + kosmos2, "extract_entities", mock_extract_entities + ) detections = kosmos2(img=sample_image) assert isinstance(detections, Detections) assert ( @@ -267,7 +275,9 @@ def test_kosmos2_with_image_containing_multiple_objects( ("entity2", (0.5, 0.6, 0.7, 0.8)), ] - monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) + monkeypatch.setattr( + kosmos2, "extract_entities", mock_extract_entities + ) detections = kosmos2(img=sample_image) assert isinstance(detections, Detections) assert ( @@ -285,7 +295,9 @@ def test_kosmos2_with_image_containing_no_objects( def mock_extract_entities(text): return [] - monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) + monkeypatch.setattr( + kosmos2, "extract_entities", mock_extract_entities + ) detections = kosmos2(img=sample_image) assert isinstance(detections, Detections) assert ( @@ -311,7 +323,9 @@ def test_kosmos2_with_valid_youtube_video_url(kosmos2): # Test Kosmos2 with an invalid YouTube video URL def test_kosmos2_with_invalid_youtube_video_url(kosmos2): - invalid_youtube_video_url = "https://www.youtube.com/invalid_video" + invalid_youtube_video_url = ( + "https://www.youtube.com/invalid_video" + ) with pytest.raises(Exception): kosmos2(video_url=invalid_youtube_video_url) diff --git a/tests/models/test_llama_function_caller.py b/tests/models/test_llama_function_caller.py index c38b2267..56ad481d 100644 --- a/tests/models/test_llama_function_caller.py +++ b/tests/models/test_llama_function_caller.py @@ -25,15 +25,26 @@ def test_llama_custom_function(llama_caller): function=sample_function, description="Sample custom function", arguments=[ - {"name": "arg1", "type": "string", "description": "Argument 1"}, - {"name": "arg2", "type": "string", "description": "Argument 2"}, + { + "name": "arg1", + "type": "string", + "description": "Argument 1", + }, + { + "name": "arg2", + "type": "string", + "description": "Argument 2", + }, ], ) result = llama_caller.call_function( "sample_function", arg1="arg1_value", arg2="arg2_value" ) - assert result == "Sample function called with args: arg1_value, arg2_value" + assert ( + result + == "Sample function called with args: arg1_value, arg2_value" + ) # Test streaming user prompts @@ -60,13 +71,23 @@ def test_llama_custom_function_invalid_arguments(llama_caller): function=sample_function, description="Sample custom function", arguments=[ - {"name": "arg1", "type": "string", "description": "Argument 1"}, - {"name": "arg2", "type": "string", "description": "Argument 2"}, + { + "name": "arg1", + "type": "string", + "description": "Argument 1", + }, + { + "name": "arg2", + "type": "string", + "description": "Argument 2", + }, ], ) with pytest.raises(TypeError): - llama_caller.call_function("sample_function", arg1="arg1_value") + llama_caller.call_function( + "sample_function", arg1="arg1_value" + ) # Test streaming with custom runtime diff --git a/tests/models/test_mpt7b.py b/tests/models/test_mpt7b.py index dfde578d..92b6c254 100644 --- a/tests/models/test_mpt7b.py +++ b/tests/models/test_mpt7b.py @@ -6,7 +6,9 @@ from swarms.models.mpt import MPT7B def test_mpt7b_init(): mpt = MPT7B( - "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150 + "mosaicml/mpt-7b-storywriter", + "EleutherAI/gpt-neox-20b", + max_tokens=150, ) assert isinstance(mpt, MPT7B) @@ -19,36 +21,55 @@ def test_mpt7b_init(): def test_mpt7b_run(): mpt = MPT7B( - "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150 + "mosaicml/mpt-7b-storywriter", + "EleutherAI/gpt-neox-20b", + max_tokens=150, + ) + output = mpt.run( + "generate", "Once upon a time in a land far, far away..." ) - output = mpt.run("generate", "Once upon a time in a land far, far away...") assert isinstance(output, str) - assert output.startswith("Once upon a time in a land far, far away...") + assert output.startswith( + "Once upon a time in a land far, far away..." + ) def test_mpt7b_run_invalid_task(): mpt = MPT7B( - "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150 + "mosaicml/mpt-7b-storywriter", + "EleutherAI/gpt-neox-20b", + max_tokens=150, ) with pytest.raises(ValueError): - mpt.run("invalid_task", "Once upon a time in a land far, far away...") + mpt.run( + "invalid_task", + "Once upon a time in a land far, far away...", + ) def test_mpt7b_generate(): mpt = MPT7B( - "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150 + "mosaicml/mpt-7b-storywriter", + "EleutherAI/gpt-neox-20b", + max_tokens=150, + ) + output = mpt.generate( + "Once upon a time in a land far, far away..." ) - output = mpt.generate("Once upon a time in a land far, far away...") assert isinstance(output, str) - assert output.startswith("Once upon a time in a land far, far away...") + assert output.startswith( + "Once upon a time in a land far, far away..." + ) def test_mpt7b_batch_generate(): mpt = MPT7B( - "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150 + "mosaicml/mpt-7b-storywriter", + "EleutherAI/gpt-neox-20b", + max_tokens=150, ) prompts = ["In the deep jungles,", "At the heart of the city,"] outputs = mpt.batch_generate(prompts, temperature=0.7) @@ -61,7 +82,9 @@ def test_mpt7b_batch_generate(): def test_mpt7b_unfreeze_model(): mpt = MPT7B( - "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150 + "mosaicml/mpt-7b-storywriter", + "EleutherAI/gpt-neox-20b", + max_tokens=150, ) mpt.unfreeze_model() diff --git a/tests/models/test_nougat.py b/tests/models/test_nougat.py index ac972e07..858845a6 100644 --- a/tests/models/test_nougat.py +++ b/tests/models/test_nougat.py @@ -22,7 +22,9 @@ def test_nougat_default_initialization(setup_nougat): def test_nougat_custom_initialization(): nougat = Nougat( - model_name_or_path="custom_path", min_length=10, max_new_tokens=50 + model_name_or_path="custom_path", + min_length=10, + max_new_tokens=50, ) assert nougat.model_name_or_path == "custom_path" assert nougat.min_length == 10 @@ -38,11 +40,16 @@ def test_model_initialization(setup_nougat): @pytest.mark.parametrize( - "cuda_available, expected_device", [(True, "cuda"), (False, "cpu")] + "cuda_available, expected_device", + [(True, "cuda"), (False, "cpu")], ) -def test_device_initialization(cuda_available, expected_device, monkeypatch): +def test_device_initialization( + cuda_available, expected_device, monkeypatch +): monkeypatch.setattr( - torch, "cuda", Mock(is_available=Mock(return_value=cuda_available)) + torch, + "cuda", + Mock(is_available=Mock(return_value=cuda_available)), ) nougat = Nougat() assert nougat.device == expected_device @@ -67,7 +74,9 @@ def test_get_image_invalid_path(setup_nougat): (10, 50), ], ) -def test_model_call_with_diff_params(setup_nougat, min_len, max_tokens): +def test_model_call_with_diff_params( + setup_nougat, min_len, max_tokens +): setup_nougat.min_length = min_len setup_nougat.max_new_tokens = max_tokens @@ -98,7 +107,8 @@ def test_model_call_mocked_output(setup_nougat): def mock_processor_and_model(): """Mock the NougatProcessor and VisionEncoderDecoderModel to simulate their behavior.""" with patch( - "transformers.NougatProcessor.from_pretrained", return_value=Mock() + "transformers.NougatProcessor.from_pretrained", + return_value=Mock(), ), patch( "transformers.VisionEncoderDecoderModel.from_pretrained", return_value=Mock(), @@ -161,7 +171,9 @@ def test_nougat_different_model_path(setup_nougat): @pytest.mark.usefixtures("mock_processor_and_model") def test_nougat_bad_image_path(setup_nougat): - with pytest.raises(Exception): # Adjust the exception type accordingly. + with pytest.raises( + Exception + ): # Adjust the exception type accordingly. setup_nougat("bad_image_path.png") diff --git a/tests/models/test_revgptv1.py b/tests/models/test_revgptv1.py index 5908b64e..cb539da6 100644 --- a/tests/models/test_revgptv1.py +++ b/tests/models/test_revgptv1.py @@ -16,12 +16,14 @@ class TestRevChatGPT(unittest.TestCase): def test_run_time(self): prompt = "Generate a 300 word essay about technology." self.model.run(prompt) - self.assertLess(self.model.end_time - self.model.start_time, 60) + self.assertLess( + self.model.end_time - self.model.start_time, 60 + ) def test_generate_summary(self): text = ( - "This is a sample text to summarize. It has multiple sentences and" - " details. The summary should be concise." + "This is a sample text to summarize. It has multiple" + " sentences and details. The summary should be concise." ) summary = self.model.generate_summary(text) self.assertLess(len(summary), len(text) / 2) @@ -64,7 +66,8 @@ class TestRevChatGPT(unittest.TestCase): title = "New Title" self.model.chatbot.change_title(convo_id, title) self.assertEqual( - self.model.chatbot.get_msg_history(convo_id)["title"], title + self.model.chatbot.get_msg_history(convo_id)["title"], + title, ) def test_delete_conversation(self): diff --git a/tests/models/test_speech_t5.py b/tests/models/test_speech_t5.py index f4d21a30..a33272fc 100644 --- a/tests/models/test_speech_t5.py +++ b/tests/models/test_speech_t5.py @@ -14,9 +14,13 @@ def speecht5_model(): def test_speecht5_init(speecht5_model): - assert isinstance(speecht5_model.processor, SpeechT5.processor.__class__) + assert isinstance( + speecht5_model.processor, SpeechT5.processor.__class__ + ) assert isinstance(speecht5_model.model, SpeechT5.model.__class__) - assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__) + assert isinstance( + speecht5_model.vocoder, SpeechT5.vocoder.__class__ + ) assert isinstance( speecht5_model.embeddings_dataset, torch.utils.data.Dataset ) @@ -43,7 +47,10 @@ def test_speecht5_set_model(speecht5_model): speecht5_model.set_model(new_model_name) assert speecht5_model.model_name == new_model_name assert speecht5_model.processor.model_name == new_model_name - assert speecht5_model.model.config.model_name_or_path == new_model_name + assert ( + speecht5_model.model.config.model_name_or_path + == new_model_name + ) speecht5_model.set_model(old_model_name) # Restore original model @@ -52,8 +59,13 @@ def test_speecht5_set_vocoder(speecht5_model): new_vocoder_name = "facebook/speecht5-hifigan" speecht5_model.set_vocoder(new_vocoder_name) assert speecht5_model.vocoder_name == new_vocoder_name - assert speecht5_model.vocoder.config.model_name_or_path == new_vocoder_name - speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder + assert ( + speecht5_model.vocoder.config.model_name_or_path + == new_vocoder_name + ) + speecht5_model.set_vocoder( + old_vocoder_name + ) # Restore original vocoder def test_speecht5_set_embeddings_dataset(speecht5_model): @@ -98,7 +110,9 @@ def test_speecht5_change_dataset_split(speecht5_model): def test_speecht5_load_custom_embedding(speecht5_model): xvector = [0.1, 0.2, 0.3, 0.4, 0.5] embedding = speecht5_model.load_custom_embedding(xvector) - assert torch.all(torch.eq(embedding, torch.tensor(xvector).unsqueeze(0))) + assert torch.all( + torch.eq(embedding, torch.tensor(xvector).unsqueeze(0)) + ) def test_speecht5_with_different_speakers(speecht5_model): @@ -109,7 +123,9 @@ def test_speecht5_with_different_speakers(speecht5_model): assert isinstance(speech, torch.Tensor) -def test_speecht5_save_speech_with_different_extensions(speecht5_model): +def test_speecht5_save_speech_with_different_extensions( + speecht5_model, +): text = "Hello, how are you?" speech = speecht5_model(text) extensions = [".wav", ".flac"] @@ -122,7 +138,9 @@ def test_speecht5_save_speech_with_different_extensions(speecht5_model): def test_speecht5_invalid_speaker_id(speecht5_model): text = "Hello, how are you?" - invalid_speaker_id = 9999 # Speaker ID that does not exist in the dataset + invalid_speaker_id = ( + 9999 # Speaker ID that does not exist in the dataset + ) with pytest.raises(IndexError): speecht5_model(text, speaker_id=invalid_speaker_id) @@ -142,4 +160,6 @@ def test_speecht5_change_vocoder_model(speecht5_model): speecht5_model.set_vocoder(new_vocoder_name) speech = speecht5_model(text) assert isinstance(speech, torch.Tensor) - speecht5_model.set_vocoder(old_vocoder_name) # Restore original vocoder + speecht5_model.set_vocoder( + old_vocoder_name + ) # Restore original vocoder diff --git a/tests/models/test_ssd_1b.py b/tests/models/test_ssd_1b.py index 7a7a897f..35cc4864 100644 --- a/tests/models/test_ssd_1b.py +++ b/tests/models/test_ssd_1b.py @@ -41,7 +41,9 @@ def test_ssd1b_parameterized_task(ssd1b_model, task): # Example of a test using mocks to isolate units of code def test_ssd1b_with_mock(ssd1b_model, mocker): - mocker.patch("your_module.StableDiffusionXLPipeline") # Mock the pipeline + mocker.patch( + "your_module.StableDiffusionXLPipeline" + ) # Mock the pipeline task = "A painting of a cat" image_url = ssd1b_model(task) assert isinstance(image_url, str) @@ -225,7 +227,9 @@ def test_ssd1b_repr_str(ssd1b_model): def test_ssd1b_rate_limited_call(ssd1b_model, mocker): task = "A painting of a dog" mocker.patch.object( - ssd1b_model, "__call__", side_effect=Exception("Rate limit exceeded") + ssd1b_model, + "__call__", + side_effect=Exception("Rate limit exceeded"), ) with pytest.raises(Exception, match="Rate limit exceeded"): ssd1b_model.rate_limited_call(task) diff --git a/tests/models/test_timm_model.py b/tests/models/test_timm_model.py index 07f68b05..97499c6a 100644 --- a/tests/models/test_timm_model.py +++ b/tests/models/test_timm_model.py @@ -6,7 +6,9 @@ from swarms.models.timm import TimmModel, TimmModelInfo @pytest.fixture def sample_model_info(): - return TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3) + return TimmModelInfo( + model_name="resnet18", pretrained=True, in_chans=3 + ) def test_get_supported_models(): @@ -25,7 +27,9 @@ def test_create_model(sample_model_info): def test_call(sample_model_info): model_handler = TimmModel() input_tensor = torch.randn(1, 3, 224, 224) - output_shape = model_handler.__call__(sample_model_info, input_tensor) + output_shape = model_handler.__call__( + sample_model_info, input_tensor + ) assert isinstance(output_shape, torch.Size) @@ -39,7 +43,9 @@ def test_call(sample_model_info): ) def test_create_model_parameterized(model_name, pretrained, in_chans): model_info = TimmModelInfo( - model_name=model_name, pretrained=pretrained, in_chans=in_chans + model_name=model_name, + pretrained=pretrained, + in_chans=in_chans, ) model_handler = TimmModel() model = model_handler._create_model(model_info) @@ -56,7 +62,9 @@ def test_create_model_parameterized(model_name, pretrained, in_chans): ) def test_call_parameterized(model_name, pretrained, in_chans): model_info = TimmModelInfo( - model_name=model_name, pretrained=pretrained, in_chans=in_chans + model_name=model_name, + pretrained=pretrained, + in_chans=in_chans, ) model_handler = TimmModel() input_tensor = torch.randn(1, in_chans, 224, 224) @@ -133,7 +141,9 @@ def test_marked_slow(): ) def test_marked_parameterized(model_name, pretrained, in_chans): model_info = TimmModelInfo( - model_name=model_name, pretrained=pretrained, in_chans=in_chans + model_name=model_name, + pretrained=pretrained, + in_chans=in_chans, ) model_handler = TimmModel() model = model_handler._create_model(model_info) diff --git a/tests/models/test_vilt.py b/tests/models/test_vilt.py index 8dcdce88..99e6848e 100644 --- a/tests/models/test_vilt.py +++ b/tests/models/test_vilt.py @@ -19,13 +19,17 @@ def test_vilt_initialization(vilt_instance): # 2. Test Model Predictions @patch.object(requests, "get") @patch.object(Image, "open") -def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): +def test_vilt_prediction( + mock_image_open, mock_requests_get, vilt_instance +): mock_image = Mock() mock_image_open.return_value = mock_image mock_requests_get.return_value.raw = Mock() # It's a mock response, so no real answer expected - with pytest.raises(Exception): # Ensure exception is more specific + with pytest.raises( + Exception + ): # Ensure exception is more specific vilt_instance( "What is this image", "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", @@ -34,7 +38,9 @@ def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): # 3. Test Exception Handling for network @patch.object( - requests, "get", side_effect=requests.RequestException("Network error") + requests, + "get", + side_effect=requests.RequestException("Network error"), ) def test_vilt_network_exception(vilt_instance): with pytest.raises(requests.RequestException): @@ -50,12 +56,17 @@ def test_vilt_network_exception(vilt_instance): [ ("What is this?", "http://example.com/image1.jpg"), ("Who is in the image?", "http://example.com/image2.jpg"), - ("Where was this picture taken?", "http://example.com/image3.jpg"), + ( + "Where was this picture taken?", + "http://example.com/image3.jpg", + ), # ... Add more scenarios ], ) def test_vilt_various_inputs(text, image_url, vilt_instance): - with pytest.raises(Exception): # Again, ensure exception is more specific + with pytest.raises( + Exception + ): # Again, ensure exception is more specific vilt_instance(text, image_url) diff --git a/tests/models/test_whisperx.py b/tests/models/test_whisperx.py index ed671cb2..4b0e4120 100644 --- a/tests/models/test_whisperx.py +++ b/tests/models/test_whisperx.py @@ -34,7 +34,9 @@ def test_speech_to_text_download_youtube_video( # Mock YouTube and streams video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" mock_stream = mock_streams().filter().first() - mock_stream.download.return_value = os.path.join(temp_dir, "video.mp4") + mock_stream.download.return_value = os.path.join( + temp_dir, "video.mp4" + ) mock_youtube.return_value = mock_youtube mock_youtube.streams = mock_streams @@ -68,7 +70,9 @@ def test_speech_to_text_transcribe_youtube_video( mock_load_audio.return_value = "audio_path" mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.return_value = {"segments": [{"text": "Hello, World!"}]} + mock_align.return_value = { + "segments": [{"text": "Hello, World!"}] + } # Mock diarization pipeline mock_diarization.return_value = None @@ -193,14 +197,18 @@ def test_speech_to_text_transcribe_diarization_failure( # Mock YouTube and streams video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" mock_stream = mock_streams().filter().first() - mock_stream.download.return_value = os.path.join(temp_dir, "video.mp4") + mock_stream.download.return_value = os.path.join( + temp_dir, "video.mp4" + ) mock_youtube.return_value = mock_youtube mock_youtube.streams = mock_streams # Mock whisperx functions mock_load_audio.return_value = "audio_path" mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.return_value = {"segments": [{"text": "Hello, World!"}]} + mock_align.return_value = { + "segments": [{"text": "Hello, World!"}] + } # Mock diarization pipeline to raise an exception mock_diarization.side_effect = Exception("Diarization failed") diff --git a/tests/models/test_yi_200k.py b/tests/models/test_yi_200k.py index 6b179ca1..9f3c236f 100644 --- a/tests/models/test_yi_200k.py +++ b/tests/models/test_yi_200k.py @@ -31,7 +31,9 @@ def test_yi34b_generate_text_with_length(yi34b_model, max_length): @pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) -def test_yi34b_generate_text_with_temperature(yi34b_model, temperature): +def test_yi34b_generate_text_with_temperature( + yi34b_model, temperature +): prompt = "There's a place where time stands still." generated_text = yi34b_model(prompt, temperature=temperature) assert isinstance(generated_text, str) @@ -89,7 +91,9 @@ def test_yi34b_generate_text_with_invalid_top_k(yi34b_model): def test_yi34b_generate_text_with_invalid_top_p(yi34b_model): prompt = "There's a place where time stands still." top_p = 1.5 # Invalid top_p - with pytest.raises(ValueError, match="top_p must be between 0.0 and 1.0"): + with pytest.raises( + ValueError, match="top_p must be between 0.0 and 1.0" + ): yi34b_model(prompt, top_p=top_p) @@ -98,15 +102,20 @@ def test_yi34b_generate_text_with_repitition_penalty( yi34b_model, repitition_penalty ): prompt = "There's a place where time stands still." - generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty) + generated_text = yi34b_model( + prompt, repitition_penalty=repitition_penalty + ) assert isinstance(generated_text, str) -def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model): +def test_yi34b_generate_text_with_invalid_repitition_penalty( + yi34b_model, +): prompt = "There's a place where time stands still." repitition_penalty = 0.0 # Invalid repitition_penalty with pytest.raises( - ValueError, match="repitition_penalty must be a positive float" + ValueError, + match="repitition_penalty must be a positive float", ): yi34b_model(prompt, repitition_penalty=repitition_penalty) diff --git a/tests/structs/test_flow.py b/tests/structs/test_flow.py index 056a3047..a8e1cf92 100644 --- a/tests/structs/test_flow.py +++ b/tests/structs/test_flow.py @@ -31,7 +31,9 @@ def basic_flow(mocked_llm): @pytest.fixture def flow_with_condition(mocked_llm): return Agent( - llm=mocked_llm, max_loops=5, stopping_condition=stop_when_repeats + llm=mocked_llm, + max_loops=5, + stopping_condition=stop_when_repeats, ) @@ -69,7 +71,9 @@ def test_run_without_stopping_condition(mocked_sleep, basic_flow): @patch("time.sleep", return_value=None) # to speed up tests -def test_run_with_stopping_condition(mocked_sleep, flow_with_condition): +def test_run_with_stopping_condition( + mocked_sleep, flow_with_condition +): response = flow_with_condition.run("Stop") assert response == "Stop" @@ -113,7 +117,9 @@ def test_flow_with_custom_stopping_condition(mocked_llm): return "terminate" in x.lower() agent = Agent( - llm=mocked_llm, max_loops=5, stopping_condition=stopping_condition + llm=mocked_llm, + max_loops=5, + stopping_condition=stopping_condition, ) assert agent.stopping_condition("Please terminate now") assert not agent.stopping_condition("Continue the process") @@ -127,7 +133,9 @@ def test_flow_call(basic_flow): # Test formatting the prompt def test_format_prompt(basic_flow): - formatted_prompt = basic_flow.format_prompt("Hello {name}", name="John") + formatted_prompt = basic_flow.format_prompt( + "Hello {name}", name="John" + ) assert formatted_prompt == "Hello John" @@ -155,7 +163,11 @@ def test_interactive_mode(basic_flow): # Test bulk run with varied inputs def test_bulk_run_varied_inputs(basic_flow): - inputs = [{"task": "Test1"}, {"task": "Test2"}, {"task": "Stop now"}] + inputs = [ + {"task": "Test1"}, + {"task": "Test2"}, + {"task": "Stop now"}, + ] responses = basic_flow.bulk_run(inputs) assert responses == ["Test1", "Test2", "Stop now"] @@ -179,7 +191,9 @@ def test_save_different_memory(basic_flow, tmp_path): # Test the stopping condition check def test_check_stopping_condition(flow_with_condition): - assert flow_with_condition._check_stopping_condition("Stop this process") + assert flow_with_condition._check_stopping_condition( + "Stop this process" + ) assert not flow_with_condition._check_stopping_condition( "Continue the task" ) @@ -211,7 +225,10 @@ def test_mocked_openai_chat(MockedOpenAIChat): @patch("time.sleep", return_value=None) def test_retry_attempts(mocked_sleep, basic_flow): basic_flow.retry_attempts = 2 - basic_flow.llm.side_effect = [Exception("Test Exception"), "Valid response"] + basic_flow.llm.side_effect = [ + Exception("Test Exception"), + "Valid response", + ] response = basic_flow.run("Test retry") assert response == "Valid response" @@ -235,7 +252,9 @@ def test_different_retry_intervals(mocked_sleep, basic_flow): # Test invoking the agent with additional kwargs @patch("time.sleep", return_value=None) def test_flow_call_with_kwargs(mocked_sleep, basic_flow): - response = basic_flow("Test call", param1="value1", param2="value2") + response = basic_flow( + "Test call", param1="value1", param2="value2" + ) assert response == "Test call" @@ -319,7 +338,9 @@ def test_flow_autosave(flow_instance): def test_flow_response_filtering(flow_instance): # Test the response filtering functionality flow_instance.add_response_filter("filter_this") - response = flow_instance.filtered_run("This message should filter_this") + response = flow_instance.filtered_run( + "This message should filter_this" + ) assert "filter_this" not in response @@ -378,10 +399,12 @@ def test_flow_autosave_path(flow_instance): def test_flow_response_length(flow_instance): # Test checking the length of the response response = flow_instance.run( - "Generate a 10,000 word long blog on mental clarity and the benefits of" - " meditation." + "Generate a 10,000 word long blog on mental clarity and the" + " benefits of meditation." + ) + assert ( + len(response) > flow_instance.get_response_length_threshold() ) - assert len(response) > flow_instance.get_response_length_threshold() def test_flow_set_response_length_threshold(flow_instance): @@ -470,7 +493,9 @@ def test_flow_get_conversation_log(flow_instance): flow_instance.run("Message 1") flow_instance.run("Message 2") conversation_log = flow_instance.get_conversation_log() - assert len(conversation_log) == 4 # Including system and user messages + assert ( + len(conversation_log) == 4 + ) # Including system and user messages def test_flow_clear_conversation_log(flow_instance): @@ -554,11 +579,21 @@ def test_flow_rollback(flow_instance): flow_instance.change_prompt("New prompt") state2 = flow_instance.get_state() flow_instance.rollback_to_state(state1) - assert flow_instance.get_current_prompt() == state1["current_prompt"] + assert ( + flow_instance.get_current_prompt() == state1["current_prompt"] + ) assert flow_instance.get_instructions() == state1["instructions"] - assert flow_instance.get_user_messages() == state1["user_messages"] - assert flow_instance.get_response_history() == state1["response_history"] - assert flow_instance.get_conversation_log() == state1["conversation_log"] + assert ( + flow_instance.get_user_messages() == state1["user_messages"] + ) + assert ( + flow_instance.get_response_history() + == state1["response_history"] + ) + assert ( + flow_instance.get_conversation_log() + == state1["conversation_log"] + ) assert ( flow_instance.is_dynamic_pacing_enabled() == state1["dynamic_pacing_enabled"] @@ -567,9 +602,14 @@ def test_flow_rollback(flow_instance): flow_instance.get_response_length_threshold() == state1["response_length_threshold"] ) - assert flow_instance.get_response_filters() == state1["response_filters"] + assert ( + flow_instance.get_response_filters() + == state1["response_filters"] + ) assert flow_instance.get_max_loops() == state1["max_loops"] - assert flow_instance.get_autosave_path() == state1["autosave_path"] + assert ( + flow_instance.get_autosave_path() == state1["autosave_path"] + ) assert flow_instance.get_state() == state1 @@ -587,9 +627,13 @@ def test_flow_contextual_intent(flow_instance): def test_flow_contextual_intent_override(flow_instance): # Test contextual intent override flow_instance.add_context("location", "New York") - response1 = flow_instance.run("What's the weather like in {location}?") + response1 = flow_instance.run( + "What's the weather like in {location}?" + ) flow_instance.add_context("location", "Los Angeles") - response2 = flow_instance.run("What's the weather like in {location}?") + response2 = flow_instance.run( + "What's the weather like in {location}?" + ) assert "New York" in response1 assert "Los Angeles" in response2 @@ -597,9 +641,13 @@ def test_flow_contextual_intent_override(flow_instance): def test_flow_contextual_intent_reset(flow_instance): # Test resetting contextual intent flow_instance.add_context("location", "New York") - response1 = flow_instance.run("What's the weather like in {location}?") + response1 = flow_instance.run( + "What's the weather like in {location}?" + ) flow_instance.reset_context() - response2 = flow_instance.run("What's the weather like in {location}?") + response2 = flow_instance.run( + "What's the weather like in {location}?" + ) assert "New York" in response1 assert "New York" in response2 @@ -624,7 +672,9 @@ def test_flow_non_interruptible(flow_instance): def test_flow_timeout(flow_instance): # Test conversation timeout flow_instance.timeout = 60 # Set a timeout of 60 seconds - response = flow_instance.run("This should take some time to respond.") + response = flow_instance.run( + "This should take some time to respond." + ) assert "Timed out" in response assert flow_instance.is_timed_out() is True @@ -673,14 +723,20 @@ def test_flow_save_and_load_conversation(flow_instance): def test_flow_inject_custom_system_message(flow_instance): # Test injecting a custom system message into the conversation - flow_instance.inject_custom_system_message("Custom system message") - assert "Custom system message" in flow_instance.get_message_history() + flow_instance.inject_custom_system_message( + "Custom system message" + ) + assert ( + "Custom system message" in flow_instance.get_message_history() + ) def test_flow_inject_custom_user_message(flow_instance): # Test injecting a custom user message into the conversation flow_instance.inject_custom_user_message("Custom user message") - assert "Custom user message" in flow_instance.get_message_history() + assert ( + "Custom user message" in flow_instance.get_message_history() + ) def test_flow_inject_custom_response(flow_instance): @@ -691,13 +747,23 @@ def test_flow_inject_custom_response(flow_instance): def test_flow_clear_injected_messages(flow_instance): # Test clearing injected messages from the conversation - flow_instance.inject_custom_system_message("Custom system message") + flow_instance.inject_custom_system_message( + "Custom system message" + ) flow_instance.inject_custom_user_message("Custom user message") flow_instance.inject_custom_response("Custom response") flow_instance.clear_injected_messages() - assert "Custom system message" not in flow_instance.get_message_history() - assert "Custom user message" not in flow_instance.get_message_history() - assert "Custom response" not in flow_instance.get_message_history() + assert ( + "Custom system message" + not in flow_instance.get_message_history() + ) + assert ( + "Custom user message" + not in flow_instance.get_message_history() + ) + assert ( + "Custom response" not in flow_instance.get_message_history() + ) def test_flow_disable_message_history(flow_instance): @@ -706,14 +772,20 @@ def test_flow_disable_message_history(flow_instance): response = flow_instance.run( "This message should not be recorded in history." ) - assert "This message should not be recorded in history." in response - assert len(flow_instance.get_message_history()) == 0 # History is empty + assert ( + "This message should not be recorded in history." in response + ) + assert ( + len(flow_instance.get_message_history()) == 0 + ) # History is empty def test_flow_enable_message_history(flow_instance): # Test enabling message history recording flow_instance.enable_message_history() - response = flow_instance.run("This message should be recorded in history.") + response = flow_instance.run( + "This message should be recorded in history." + ) assert "This message should be recorded in history." in response assert len(flow_instance.get_message_history()) == 1 @@ -723,7 +795,9 @@ def test_flow_custom_logger(flow_instance): custom_logger = logger # Replace with your custom logger class flow_instance.set_logger(custom_logger) response = flow_instance.run("Custom logger test") - assert "Logged using custom logger" in response # Verify logging message + assert ( + "Logged using custom logger" in response + ) # Verify logging message def test_flow_batch_processing(flow_instance): @@ -991,8 +1065,14 @@ def test_flow_custom_response(flow_instance): flow_instance.set_response_generator(custom_response_generator) assert flow_instance.run("Hello") == "Hi there!" - assert flow_instance.run("How are you?") == "I'm doing well, thank you." - assert flow_instance.run("What's your name?") == "I don't understand." + assert ( + flow_instance.run("How are you?") + == "I'm doing well, thank you." + ) + assert ( + flow_instance.run("What's your name?") + == "I don't understand." + ) def test_flow_message_validation(flow_instance): @@ -1003,8 +1083,12 @@ def test_flow_message_validation(flow_instance): flow_instance.set_message_validator(custom_message_validator) assert flow_instance.run("Valid message") is not None - assert flow_instance.run("") is None # Empty message should be rejected - assert flow_instance.run(None) is None # None message should be rejected + assert ( + flow_instance.run("") is None + ) # Empty message should be rejected + assert ( + flow_instance.run(None) is None + ) # None message should be rejected def test_flow_custom_logging(flow_instance): @@ -1029,10 +1113,15 @@ def test_flow_complex_use_case(flow_instance): flow_instance.add_context("user_id", "12345") flow_instance.run("Hello") flow_instance.run("How can I help you?") - assert flow_instance.get_response() == "Please provide more details." + assert ( + flow_instance.get_response() == "Please provide more details." + ) flow_instance.update_context("user_id", "54321") flow_instance.run("I need help with my order") - assert flow_instance.get_response() == "Sure, I can assist with that." + assert ( + flow_instance.get_response() + == "Sure, I can assist with that." + ) flow_instance.reset_conversation() assert len(flow_instance.get_message_history()) == 0 assert flow_instance.get_context("user_id") is None @@ -1071,7 +1160,9 @@ def test_flow_concurrent_requests(flow_instance): def test_flow_custom_timeout(flow_instance): # Test custom timeout handling - flow_instance.set_timeout(10) # Set a custom timeout of 10 seconds + flow_instance.set_timeout( + 10 + ) # Set a custom timeout of 10 seconds assert flow_instance.get_timeout() == 10 import time @@ -1125,8 +1216,13 @@ def test_flow_agent_history_prompt(flow_instance): system_prompt, history ) - assert "SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt - assert "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt + assert ( + "SYSTEM_PROMPT: This is the system prompt." + in agent_history_prompt + ) + assert ( + "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt + ) async def test_flow_run_concurrent(flow_instance): @@ -1158,7 +1254,9 @@ def test_flow_from_llm_and_template(): llm_instance = mocked_llm # Replace with your LLM class template = "This is a template for testing." - flow_instance = Agent.from_llm_and_template(llm_instance, template) + flow_instance = Agent.from_llm_and_template( + llm_instance, template + ) assert isinstance(flow_instance, Agent) @@ -1166,7 +1264,9 @@ def test_flow_from_llm_and_template(): def test_flow_from_llm_and_template_file(): # Test creating Agent instance from an LLM and a template file llm_instance = mocked_llm # Replace with your LLM class - template_file = "template.txt" # Create a template file for testing + template_file = ( # Create a template file for testing + "template.txt" + ) flow_instance = Agent.from_llm_and_template_file( llm_instance, template_file diff --git a/tests/structs/test_sequential_workflow.py b/tests/structs/test_sequential_workflow.py index 405cce2d..0d12991a 100644 --- a/tests/structs/test_sequential_workflow.py +++ b/tests/structs/test_sequential_workflow.py @@ -6,7 +6,10 @@ import pytest from swarms.models import OpenAIChat from swarms.structs.agent import Agent -from swarms.structs.sequential_workflow import SequentialWorkflow, Task +from swarms.structs.sequential_workflow import ( + SequentialWorkflow, + Task, +) # Mock the OpenAI API key using environment variables os.environ["OPENAI_API_KEY"] = "mocked_api_key" @@ -66,7 +69,10 @@ def test_sequential_workflow_initialization(): assert len(workflow.tasks) == 0 assert workflow.max_loops == 1 assert workflow.autosave is False - assert workflow.saved_state_filepath == "sequential_workflow_state.json" + assert ( + workflow.saved_state_filepath + == "sequential_workflow_state.json" + ) assert workflow.restore_state_filepath is None assert workflow.dashboard is False diff --git a/tests/structs/test_task.py b/tests/structs/test_task.py index cc6be26f..5db822d4 100644 --- a/tests/structs/test_task.py +++ b/tests/structs/test_task.py @@ -21,10 +21,11 @@ def llm(): def test_agent_run_task(llm): task = ( - "Analyze this image of an assembly line and identify any issues such as" - " misaligned parts, defects, or deviations from the standard assembly" - " process. IF there is anything unsafe in the image, explain why it is" - " unsafe and how it could be improved." + "Analyze this image of an assembly line and identify any" + " issues such as misaligned parts, defects, or deviations" + " from the standard assembly process. IF there is anything" + " unsafe in the image, explain why it is unsafe and how it" + " could be improved." ) img = "assembly_line.jpg" @@ -47,7 +48,9 @@ def test_agent_run_task(llm): @pytest.fixture def task(): agents = [Agent(llm=llm, id=f"Agent_{i}") for i in range(5)] - return Task(id="Task_1", task="Task_Name", agents=agents, dependencies=[]) + return Task( + id="Task_1", task="Task_Name", agents=agents, dependencies=[] + ) # Basic tests diff --git a/tests/swarms/test_autoscaler.py b/tests/swarms/test_autoscaler.py index 85955f00..fbf63637 100644 --- a/tests/swarms/test_autoscaler.py +++ b/tests/swarms/test_autoscaler.py @@ -34,7 +34,9 @@ def test_autoscaler_add_task(): def test_autoscaler_scale_up(): - autoscaler = AutoScaler(initial_agents=5, scale_up_factor=2, agent=agent) + autoscaler = AutoScaler( + initial_agents=5, scale_up_factor=2, agent=agent + ) autoscaler.scale_up() assert len(autoscaler.agents_pool) == 10 diff --git a/tests/swarms/test_dialogue_simulator.py b/tests/swarms/test_dialogue_simulator.py index 52cd6367..40665201 100644 --- a/tests/swarms/test_dialogue_simulator.py +++ b/tests/swarms/test_dialogue_simulator.py @@ -11,7 +11,9 @@ def test_dialoguesimulator_initialization(): @patch("swarms.workers.worker.Worker.run") def test_dialoguesimulator_run(mock_run): dialoguesimulator = DialogueSimulator(agents=[Worker] * 5) - dialoguesimulator.run(max_iters=5, name="Agent 1", message="Hello, world!") + dialoguesimulator.run( + max_iters=5, name="Agent 1", message="Hello, world!" + ) assert mock_run.call_count == 30 diff --git a/tests/swarms/test_multi_agent_collab.py b/tests/swarms/test_multi_agent_collab.py index e08979ca..e30358aa 100644 --- a/tests/swarms/test_multi_agent_collab.py +++ b/tests/swarms/test_multi_agent_collab.py @@ -75,8 +75,12 @@ def test_run(collaboration): def test_format_results(collaboration): - collaboration.results = [{"agent": "Agent1", "response": "Response1"}] - formatted_results = collaboration.format_results(collaboration.results) + collaboration.results = [ + {"agent": "Agent1", "response": "Response1"} + ] + formatted_results = collaboration.format_results( + collaboration.results + ) assert "Agent1 responded: Response1" in formatted_results diff --git a/tests/swarms/test_multi_agent_debate.py b/tests/swarms/test_multi_agent_debate.py index 25e15ae5..656ee9fb 100644 --- a/tests/swarms/test_multi_agent_debate.py +++ b/tests/swarms/test_multi_agent_debate.py @@ -61,6 +61,6 @@ def test_multiagentdebate_format_results(): formatted_results = multiagentdebate.format_results(results) assert ( formatted_results - == "Agent Agent 1 responded: Hello, world!\nAgent Agent 2 responded:" - " Goodbye, world!" + == "Agent Agent 1 responded: Hello, world!\nAgent Agent 2" + " responded: Goodbye, world!" ) diff --git a/tests/swarms/test_orchestrate.py b/tests/swarms/test_orchestrate.py index 7a73d92d..4136ad94 100644 --- a/tests/swarms/test_orchestrate.py +++ b/tests/swarms/test_orchestrate.py @@ -22,10 +22,14 @@ def mock_vector_db(): def orchestrator(mock_agent, mock_vector_db): agent_list = [mock_agent for _ in range(5)] task_queue = [] - return Orchestrator(mock_agent, agent_list, task_queue, mock_vector_db) + return Orchestrator( + mock_agent, agent_list, task_queue, mock_vector_db + ) -def test_assign_task(orchestrator, mock_agent, mock_task, mock_vector_db): +def test_assign_task( + orchestrator, mock_agent, mock_task, mock_vector_db +): orchestrator.task_queue.append(mock_task) orchestrator.assign_task(0, mock_task) diff --git a/tests/tools/test_base.py b/tests/tools/test_base.py index 007719b2..9f9c700f 100644 --- a/tests/tools/test_base.py +++ b/tests/tools/test_base.py @@ -3,7 +3,13 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from swarms.tools.tool import BaseTool, Runnable, StructuredTool, Tool, tool +from swarms.tools.tool import ( + BaseTool, + Runnable, + StructuredTool, + Tool, + tool, +) # Define test data test_input = {"key1": "value1", "key2": "value2"} @@ -59,14 +65,18 @@ def test_structured_tool_invoke(): def test_tool_creation(): - tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") + tool = Tool( + name="test_tool", func=lambda x: x, description="Test tool" + ) assert tool.name == "test_tool" assert tool.func is not None assert tool.description == "Test tool" def test_tool_ainvoke(): - tool = Tool(name="test_tool", func=lambda x: x, description="Test tool") + tool = Tool( + name="test_tool", func=lambda x: x, description="Test tool" + ) result = tool.ainvoke("input_data") assert result == "input_data" @@ -76,7 +86,9 @@ def test_tool_ainvoke_with_coroutine(): return input_data tool = Tool( - name="test_tool", coroutine=async_function, description="Test tool" + name="test_tool", + coroutine=async_function, + description="Test tool", ) result = tool.ainvoke("input_data") assert result == "input_data" @@ -86,7 +98,11 @@ def test_tool_args(): def sample_function(input_data): return input_data - tool = Tool(name="test_tool", func=sample_function, description="Test tool") + tool = Tool( + name="test_tool", + func=sample_function, + description="Test tool", + ) assert tool.args == {"tool_input": {"type": "string"}} @@ -166,7 +182,9 @@ def test_tool_ainvoke_exception(): def test_tool_ainvoke_with_coroutine_exception(): - tool = Tool(name="test_tool", coroutine=None, description="Test tool") + tool = Tool( + name="test_tool", coroutine=None, description="Test tool" + ) with pytest.raises(NotImplementedError): tool.ainvoke("input_data") @@ -289,7 +307,9 @@ def test_structured_tool_ainvoke_with_callbacks(): args_schema=SampleArgsSchema, ) callbacks = MagicMock() - result = tool.ainvoke({"tool_input": "input_data"}, callbacks=callbacks) + result = tool.ainvoke( + {"tool_input": "input_data"}, callbacks=callbacks + ) assert result == "input_data" callbacks.on_start.assert_called_once() callbacks.on_finish.assert_called_once() @@ -349,7 +369,9 @@ def test_structured_tool_ainvoke_with_new_argument(): func=sample_function, args_schema=SampleArgsSchema, ) - result = tool.ainvoke({"tool_input": "input_data"}, callbacks=None) + result = tool.ainvoke( + {"tool_input": "input_data"}, callbacks=None + ) assert result == "input_data" @@ -461,7 +483,9 @@ def test_tool_with_runnable(mock_runnable): def test_tool_with_invalid_argument(): # Test passing an invalid argument type with pytest.raises(ValueError): - tool(123) # Using an integer instead of a string/callable/Runnable + tool( + 123 + ) # Using an integer instead of a string/callable/Runnable def test_tool_with_multiple_arguments(mock_func): @@ -523,7 +547,9 @@ class MockSchema(BaseModel): # Test suite starts here class TestTool: # Basic Functionality Tests - def test_tool_with_valid_callable_creates_base_tool(self, mock_func): + def test_tool_with_valid_callable_creates_base_tool( + self, mock_func + ): result = tool(mock_func) assert isinstance(result, BaseTool) @@ -783,7 +809,9 @@ class TestTool: def thread_target(): results.append(threaded_function(5)) - threads = [threading.Thread(target=thread_target) for _ in range(10)] + threads = [ + threading.Thread(target=thread_target) for _ in range(10) + ] for t in threads: t.start() for t in threads: diff --git a/tests/utils/test_subprocess_code_interpreter.py b/tests/utils/test_subprocess_code_interpreter.py index c15c0e16..2c7f7e47 100644 --- a/tests/utils/test_subprocess_code_interpreter.py +++ b/tests/utils/test_subprocess_code_interpreter.py @@ -35,25 +35,38 @@ def test_base_code_interpreter_terminate_not_implemented(): interpreter.terminate() -def test_subprocess_code_interpreter_init(subprocess_code_interpreter): - assert isinstance(subprocess_code_interpreter, SubprocessCodeInterpreter) +def test_subprocess_code_interpreter_init( + subprocess_code_interpreter, +): + assert isinstance( + subprocess_code_interpreter, SubprocessCodeInterpreter + ) -def test_subprocess_code_interpreter_start_process(subprocess_code_interpreter): +def test_subprocess_code_interpreter_start_process( + subprocess_code_interpreter, +): subprocess_code_interpreter.start_process() assert subprocess_code_interpreter.process is not None -def test_subprocess_code_interpreter_terminate(subprocess_code_interpreter): +def test_subprocess_code_interpreter_terminate( + subprocess_code_interpreter, +): subprocess_code_interpreter.start_process() subprocess_code_interpreter.terminate() assert subprocess_code_interpreter.process.poll() is not None -def test_subprocess_code_interpreter_run_success(subprocess_code_interpreter): +def test_subprocess_code_interpreter_run_success( + subprocess_code_interpreter, +): code = 'print("Hello, World!")' result = list(subprocess_code_interpreter.run(code)) - assert any("Hello, World!" in output.get("output", "") for output in result) + assert any( + "Hello, World!" in output.get("output", "") + for output in result + ) def test_subprocess_code_interpreter_run_with_error( @@ -61,7 +74,9 @@ def test_subprocess_code_interpreter_run_with_error( ): code = 'print("Hello, World")\nraise ValueError("Error!")' result = list(subprocess_code_interpreter.run(code)) - assert any("Error!" in output.get("output", "") for output in result) + assert any( + "Error!" in output.get("output", "") for output in result + ) def test_subprocess_code_interpreter_run_with_keyboard_interrupt( @@ -73,7 +88,8 @@ def test_subprocess_code_interpreter_run_with_keyboard_interrupt( ) result = list(subprocess_code_interpreter.run(code)) assert any( - "KeyboardInterrupt" in output.get("output", "") for output in result + "KeyboardInterrupt" in output.get("output", "") + for output in result ) @@ -115,7 +131,10 @@ def test_subprocess_code_interpreter_run_retry_on_error( code = 'print("Hello, World!")' result = list(subprocess_code_interpreter.run(code)) - assert any("Hello, World!" in output.get("output", "") for output in result) + assert any( + "Hello, World!" in output.get("output", "") + for output in result + ) # Add more tests to cover other aspects of the code and edge cases as needed @@ -127,16 +146,24 @@ def test_subprocess_code_interpreter_line_postprocessor( subprocess_code_interpreter, ): line = "This is a test line" - processed_line = subprocess_code_interpreter.line_postprocessor(line) - assert processed_line == line # No processing, should remain the same + processed_line = subprocess_code_interpreter.line_postprocessor( + line + ) + assert ( + processed_line == line + ) # No processing, should remain the same def test_subprocess_code_interpreter_preprocess_code( subprocess_code_interpreter, ): code = 'print("Hello, World!")' - preprocessed_code = subprocess_code_interpreter.preprocess_code(code) - assert preprocessed_code == code # No preprocessing, should remain the same + preprocessed_code = subprocess_code_interpreter.preprocess_code( + code + ) + assert ( + preprocessed_code == code + ) # No preprocessing, should remain the same def test_subprocess_code_interpreter_detect_active_line( @@ -151,7 +178,9 @@ def test_subprocess_code_interpreter_detect_end_of_execution( subprocess_code_interpreter, ): line = "Execution completed." - end_of_execution = subprocess_code_interpreter.detect_end_of_execution(line) + end_of_execution = ( + subprocess_code_interpreter.detect_end_of_execution(line) + ) assert end_of_execution is True @@ -221,7 +250,10 @@ def test_subprocess_code_interpreter_run_with_preprocess_code( lambda x: x.upper() ) # Modify code in preprocess_code result = list(subprocess_code_interpreter.run(code)) - assert any("Hello, World!" in output.get("output", "") for output in result) + assert any( + "Hello, World!" in output.get("output", "") + for output in result + ) def test_subprocess_code_interpreter_run_with_exception( @@ -249,7 +281,9 @@ def test_subprocess_code_interpreter_run_with_active_line( def test_subprocess_code_interpreter_run_with_end_of_execution( subprocess_code_interpreter, capsys ): - code = 'print("Hello, World!")' # Simple code without active line marker + code = ( # Simple code without active line marker + 'print("Hello, World!")' + ) result = list(subprocess_code_interpreter.run(code)) assert any(output.get("active_line") is None for output in result) @@ -268,5 +302,6 @@ def test_subprocess_code_interpreter_run_with_unicode_characters( code = 'print("こんにちは、世界")' # Contains unicode characters result = list(subprocess_code_interpreter.run(code)) assert any( - "こんにちは、世界" in output.get("output", "") for output in result + "こんにちは、世界" in output.get("output", "") + for output in result )