diff --git a/example.py b/example.py index 4d90b7c9..35413816 100644 --- a/example.py +++ b/example.py @@ -1,12 +1,11 @@ from swarms import Agent, Anthropic -## Initialize the workflow +# Initialize the agent agent = Agent( agent_name="Transcript Generator", agent_description=( - "Generate a transcript for a youtube video on what swarms" - " are!" + "Generate a transcript for a youtube video on what swarms" " are!" ), llm=Anthropic(), max_loops=3, @@ -18,5 +17,5 @@ agent = Agent( interactive=True, ) -# Run the workflow on a task +# Run the Agent on a task agent("Generate a transcript for a youtube video on what swarms are!") diff --git a/playground/agents/command_r_tool_agent.py b/playground/agents/command_r_tool_agent.py index 9cbd73ad..e6fe075a 100644 --- a/playground/agents/command_r_tool_agent.py +++ b/playground/agents/command_r_tool_agent.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field from transformers import AutoModelForCausalLM, AutoTokenizer from swarms import ToolAgent -from swarms.utils.json_utils import base_model_to_json +from swarms.tools.json_utils import base_model_to_json # Model name model_name = "CohereForAI/c4ai-command-r-v01-4bit" @@ -28,9 +28,7 @@ class APIExampleRequestSchema(BaseModel): headers: dict = Field( ..., description="The headers for the example request" ) - body: dict = Field( - ..., description="The body of the example request" - ) + body: dict = Field(..., description="The body of the example request") response: dict = Field( ..., description="The expected response of the example request", diff --git a/playground/agents/full_stack_agent.py b/playground/agents/full_stack_agent.py index 510f5c98..0db12ad3 100644 --- a/playground/agents/full_stack_agent.py +++ b/playground/agents/full_stack_agent.py @@ -14,8 +14,7 @@ def search_api(query: str, max_results: int = 10): agent = Agent( agent_name="Youtube Transcript Generator", agent_description=( - "Generate a transcript for a youtube video on what swarms" - " are!" + "Generate a transcript for a youtube video on what swarms" " are!" ), llm=Anthropic(), max_loops="auto", diff --git a/playground/agents/jamba_tool_agent.py b/playground/agents/jamba_tool_agent.py index 3ca293cd..032272a3 100644 --- a/playground/agents/jamba_tool_agent.py +++ b/playground/agents/jamba_tool_agent.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field from transformers import AutoModelForCausalLM, AutoTokenizer from swarms import ToolAgent -from swarms.utils.json_utils import base_model_to_json +from swarms.tools.json_utils import base_model_to_json # Model name model_name = "ai21labs/Jamba-v0.1" @@ -28,9 +28,7 @@ class APIExampleRequestSchema(BaseModel): headers: dict = Field( ..., description="The headers for the example request" ) - body: dict = Field( - ..., description="The body of the example request" - ) + body: dict = Field(..., description="The body of the example request") response: dict = Field( ..., description="The expected response of the example request", diff --git a/playground/agents/mm_agent_example.py b/playground/agents/mm_agent_example.py index 6cedcb29..d564fc02 100644 --- a/playground/agents/mm_agent_example.py +++ b/playground/agents/mm_agent_example.py @@ -4,9 +4,7 @@ load_dict = {"ImageCaptioning": "cuda"} node = MultiModalAgent(load_dict) -text = node.run_text( - "What is your name? Generate a picture of yourself" -) +text = node.run_text("What is your name? 
Generate a picture of yourself") img = node.run_img("/image1", "What is this image about?") diff --git a/playground/agents/tool_agent_pydantic.py b/playground/agents/tool_agent_pydantic.py index da5f4825..cd564480 100644 --- a/playground/agents/tool_agent_pydantic.py +++ b/playground/agents/tool_agent_pydantic.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field from transformers import AutoModelForCausalLM, AutoTokenizer from swarms import ToolAgent -from swarms.utils.json_utils import base_model_to_json +from swarms.tools.json_utils import base_model_to_json # Load the pre-trained model and tokenizer model = AutoModelForCausalLM.from_pretrained( @@ -17,9 +17,7 @@ tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-12b") class Schema(BaseModel): name: str = Field(..., title="Name of the person") agent: int = Field(..., title="Age of the person") - is_student: bool = Field( - ..., title="Whether the person is a student" - ) + is_student: bool = Field(..., title="Whether the person is a student") courses: list[str] = Field( ..., title="List of courses the person is taking" ) @@ -29,9 +27,7 @@ class Schema(BaseModel): tool_schema = base_model_to_json(Schema) # Define the task to generate a person's information -task = ( - "Generate a person's information based on the following schema:" -) +task = "Generate a person's information based on the following schema:" # Create an instance of the ToolAgent class agent = ToolAgent( diff --git a/playground/agents/tool_agent_with_llm.py b/playground/agents/tool_agent_with_llm.py index 3582be21..5babf461 100644 --- a/playground/agents/tool_agent_with_llm.py +++ b/playground/agents/tool_agent_with_llm.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv from pydantic import BaseModel, Field from swarms import OpenAIChat, ToolAgent -from swarms.utils.json_utils import base_model_to_json +from swarms.tools.json_utils import base_model_to_json # Load the environment variables load_dotenv() @@ -19,9 +19,7 @@ chat = OpenAIChat( class Schema(BaseModel): name: str = Field(..., title="Name of the person") agent: int = Field(..., title="Age of the person") - is_student: bool = Field( - ..., title="Whether the person is a student" - ) + is_student: bool = Field(..., title="Whether the person is a student") courses: list[str] = Field( ..., title="List of courses the person is taking" ) @@ -31,9 +29,7 @@ class Schema(BaseModel): tool_schema = base_model_to_json(Schema) # Define the task to generate a person's information -task = ( - "Generate a person's information based on the following schema:" -) +task = "Generate a person's information based on the following schema:" # Create an instance of the ToolAgent class agent = ToolAgent( diff --git a/playground/creation_engine/omni_model_agent.py b/playground/creation_engine/omni_model_agent.py index b261c2f7..03428ef5 100644 --- a/playground/creation_engine/omni_model_agent.py +++ b/playground/creation_engine/omni_model_agent.py @@ -34,9 +34,7 @@ def text_to_video(task: str): step = 4 # Options: [1,2,4,8] repo = "ByteDance/AnimateDiff-Lightning" ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors" - base = ( # Choose to your favorite base model. - "emilianJR/epiCRealism" - ) + base = "emilianJR/epiCRealism" # Choose to your favorite base model. 
adapter = MotionAdapter().to(device, dtype) adapter.load_state_dict( diff --git a/playground/demos/ad_gen/ad_gen_example.py b/playground/demos/ad_gen/ad_gen_example.py index 978ab502..21d9f315 100644 --- a/playground/demos/ad_gen/ad_gen_example.py +++ b/playground/demos/ad_gen/ad_gen_example.py @@ -61,16 +61,12 @@ class ProductAdConceptGenerator: "in an ice cave setting", "in a serene and calm landscape", ] - self.contexts = [ - "high realism product ad (extremely creative)" - ] + self.contexts = ["high realism product ad (extremely creative)"] def generate_concept(self): theme = random.choice(self.themes) context = random.choice(self.contexts) - return ( - f"{theme} inside a {style} {self.product_name}, {context}" - ) + return f"{theme} inside a {style} {self.product_name}, {context}" # User input diff --git a/playground/demos/ai_acceleerated_learning/Podgraph .py b/playground/demos/ai_acceleerated_learning/Podgraph .py index 70944b31..d632b7de 100644 --- a/playground/demos/ai_acceleerated_learning/Podgraph .py +++ b/playground/demos/ai_acceleerated_learning/Podgraph .py @@ -31,9 +31,7 @@ def test_find_most_similar_podcasts(): graph = create_graph() weight_edges(graph) user_list = create_user_list() - most_similar_podcasts = find_most_similar_podcasts( - graph, user_list - ) + most_similar_podcasts = find_most_similar_podcasts(graph, user_list) assert isinstance(most_similar_podcasts, list) diff --git a/playground/demos/ai_acceleerated_learning/main.py b/playground/demos/ai_acceleerated_learning/main.py index 44eba542..6366c005 100644 --- a/playground/demos/ai_acceleerated_learning/main.py +++ b/playground/demos/ai_acceleerated_learning/main.py @@ -45,9 +45,7 @@ def execute_concurrently(callable_functions: callable, max_workers=5): ) as executor: futures = [] for i, (fn, args, kwargs) in enumerate(callable_functions): - futures.append( - executor.submit(worker, fn, args, kwargs, i) - ) + futures.append(executor.submit(worker, fn, args, kwargs, i)) # Wait for all threads to complete concurrent.futures.wait(futures) @@ -56,9 +54,7 @@ def execute_concurrently(callable_functions: callable, max_workers=5): # Adjusting the function to extract specific column values -def extract_and_create_agents( - csv_file_path: str, target_columns: list -): +def extract_and_create_agents(csv_file_path: str, target_columns: list): """ Reads a CSV file, extracts "Project Name" and "Lightning Proposal" for each row, creates an Agent for each, and adds it to the swarm network. @@ -138,8 +134,7 @@ def extract_and_create_agents( # Log the agent logger.info( - f"Agent created: {agent_name} with long term" - " memory" + f"Agent created: {agent_name} with long term" " memory" ) agents.append(agent) diff --git a/playground/demos/ai_acceleerated_learning/test_Vocal.py b/playground/demos/ai_acceleerated_learning/test_Vocal.py index b8e1e14f..41433b87 100644 --- a/playground/demos/ai_acceleerated_learning/test_Vocal.py +++ b/playground/demos/ai_acceleerated_learning/test_Vocal.py @@ -16,9 +16,7 @@ def test_pass(): def test_invalid_sports(): assert ( - vocal.generate_video( - "I just ate some delicious tacos", "tacos" - ) + vocal.generate_video("I just ate some delicious tacos", "tacos") == "Invalid sports entered!! Please enter a valid sport." 
) diff --git a/playground/demos/ai_research_team/main_example.py b/playground/demos/ai_research_team/main_example.py index bda9e0de..dc6e54ae 100644 --- a/playground/demos/ai_research_team/main_example.py +++ b/playground/demos/ai_research_team/main_example.py @@ -51,6 +51,4 @@ algorithmic_psuedocode_agent = paper_summarizer_agent.run( "Focus on creating the algorithmic pseudocode for the novel" f" method in this paper: {paper}" ) -pytorch_code = paper_implementor_agent.run( - algorithmic_psuedocode_agent -) +pytorch_code = paper_implementor_agent.run(algorithmic_psuedocode_agent) diff --git a/playground/demos/autobloggen_example.py b/playground/demos/autobloggen_example.py index 09b02674..ae6cdc60 100644 --- a/playground/demos/autobloggen_example.py +++ b/playground/demos/autobloggen_example.py @@ -55,9 +55,7 @@ class AutoBlogGenSwarm: ): self.llm = llm() self.topic_selection_task = topic_selection_task - self.topic_selection_agent_prompt = ( - topic_selection_agent_prompt - ) + self.topic_selection_agent_prompt = topic_selection_agent_prompt self.objective = objective self.iterations = iterations self.max_retries = max_retries @@ -93,9 +91,7 @@ class AutoBlogGenSwarm: def step(self): """Steps through the task""" - topic_selection_agent = self.llm( - self.topic_selection_agent_prompt - ) + topic_selection_agent = self.llm(self.topic_selection_agent_prompt) topic_selection_agent = self.print_beautifully( "Topic Selection Agent", topic_selection_agent ) @@ -105,9 +101,7 @@ class AutoBlogGenSwarm: # Agent that reviews the draft review_agent = self.llm(self.get_review_prompt(draft_blog)) - review_agent = self.print_beautifully( - "Review Agent", review_agent - ) + review_agent = self.print_beautifully("Review Agent", review_agent) # Agent that publishes on social media distribution_agent = self.llm( diff --git a/playground/demos/autotemp/autotemp_example.py b/playground/demos/autotemp/autotemp_example.py index f086f112..f77d46c2 100644 --- a/playground/demos/autotemp/autotemp_example.py +++ b/playground/demos/autotemp/autotemp_example.py @@ -48,11 +48,7 @@ class AutoTemp: """ score_text = self.llm(eval_prompt, temperature=0.5) score_match = re.search(r"\b\d+(\.\d)?\b", score_text) - return ( - round(float(score_match.group()), 1) - if score_match - else 0.0 - ) + return round(float(score_match.group()), 1) if score_match else 0.0 def run(self, prompt, temperature_string): print("Starting generation process...") diff --git a/playground/demos/autotemp/blog_gen_example.py b/playground/demos/autotemp/blog_gen_example.py index fe2a2317..40f5d0e7 100644 --- a/playground/demos/autotemp/blog_gen_example.py +++ b/playground/demos/autotemp/blog_gen_example.py @@ -56,15 +56,11 @@ class BlogGen: ) chosen_topic = topic_output.split("\n")[0] - print( - colored("Selected topic: " + chosen_topic, "yellow") - ) + print(colored("Selected topic: " + chosen_topic, "yellow")) # Initial draft generation with AutoTemp - initial_draft_prompt = ( - self.DRAFT_WRITER_SYSTEM_PROMPT.replace( - "{{CHOSEN_TOPIC}}", chosen_topic - ) + initial_draft_prompt = self.DRAFT_WRITER_SYSTEM_PROMPT.replace( + "{{CHOSEN_TOPIC}}", chosen_topic ) auto_temp_output = self.auto_temp.run( initial_draft_prompt, self.temperature_range diff --git a/playground/demos/education/education_example.py b/playground/demos/education/education_example.py index 31c08f0d..670a7b29 100644 --- a/playground/demos/education/education_example.py +++ b/playground/demos/education/education_example.py @@ -12,9 +12,7 @@ api_key = os.getenv("OPENAI_API_KEY") 
stability_api_key = os.getenv("STABILITY_API_KEY") # Initialize language model -llm = OpenAIChat( - openai_api_key=api_key, temperature=0.5, max_tokens=3000 -) +llm = OpenAIChat(openai_api_key=api_key, temperature=0.5, max_tokens=3000) # User preferences (can be dynamically set in a real application) user_preferences = { @@ -30,9 +28,7 @@ curriculum_prompt = edu_prompts.CURRICULUM_DESIGN_PROMPT.format( interactive_prompt = edu_prompts.INTERACTIVE_LEARNING_PROMPT.format( **user_preferences ) -sample_prompt = edu_prompts.SAMPLE_TEST_PROMPT.format( - **user_preferences -) +sample_prompt = edu_prompts.SAMPLE_TEST_PROMPT.format(**user_preferences) image_prompt = edu_prompts.IMAGE_GENERATION_PROMPT.format( **user_preferences ) @@ -49,9 +45,7 @@ workflow = SequentialWorkflow(max_loops=1) # Add tasks to workflow with personalized prompts workflow.add(curriculum_agent, "Generate a curriculum") -workflow.add( - interactive_learning_agent, "Generate an interactive lesson" -) +workflow.add(interactive_learning_agent, "Generate an interactive lesson") workflow.add(sample_lesson_agent, "Generate a practice test") # Execute the workflow for text-based tasks diff --git a/playground/demos/grupa/app_example.py b/playground/demos/grupa/app_example.py index ff5fc27d..8fafc0cd 100644 --- a/playground/demos/grupa/app_example.py +++ b/playground/demos/grupa/app_example.py @@ -11,9 +11,7 @@ from swarms.structs import Agent load_dotenv() -FEATURE = ( - "Implement an all-new signup system in typescript using supabase" -) +FEATURE = "Implement an all-new signup system in typescript using supabase" CODEBASE = """ import React, { useState } from 'react'; @@ -68,9 +66,7 @@ feature_implementer_backend = Agent( ) # Create another agent for a different task -tester_agent = Agent( - llm=llm, max_loops=1, sop=TEST_SOP, autosave=True -) +tester_agent = Agent(llm=llm, max_loops=1, sop=TEST_SOP, autosave=True) # Create another agent for a different task documenting_agent = Agent( diff --git a/playground/demos/multimodal_tot/idea2img_example.py b/playground/demos/multimodal_tot/idea2img_example.py index 4a6c1da3..94863ae4 100644 --- a/playground/demos/multimodal_tot/idea2img_example.py +++ b/playground/demos/multimodal_tot/idea2img_example.py @@ -44,9 +44,7 @@ class Idea2Image(Agent): print(f"Generated image at: {img}") analysis = ( - self.vision_api.run(img, current_prompt) - if img - else None + self.vision_api.run(img, current_prompt) if img else None ) if analysis: current_prompt += ( @@ -147,9 +145,7 @@ gpt_api = OpenAIChat(openai_api_key=openai_api_key) # Define the modified Idea2Image class here # Streamlit UI layout -st.title( - "Explore the infinite Multi-Modal Idea Space with Idea2Image" -) +st.title("Explore the infinite Multi-Modal Idea Space with Idea2Image") user_prompt = st.text_input("Prompt for image generation:") num_iterations = st.number_input( "Enter the number of iterations for image improvement:", @@ -168,9 +164,7 @@ if st.button("Generate Image"): user_prompt, num_iterations, run_folder ) - for i, (enriched_prompt, img_path, analysis) in enumerate( - results - ): + for i, (enriched_prompt, img_path, analysis) in enumerate(results): st.write(f"Iteration {i+1}:") st.write("Enriched Prompt:", enriched_prompt) if img_path: diff --git a/playground/demos/multimodal_tot/main_example.py b/playground/demos/multimodal_tot/main_example.py index 2a0494dc..59062425 100644 --- a/playground/demos/multimodal_tot/main_example.py +++ b/playground/demos/multimodal_tot/main_example.py @@ -96,9 +96,7 @@ for _ in 
range(max_iterations): # Evaluate the image by passing the file path score = evaluate_img(llm, task, img_path) print( - colored( - f"Evaluated Image Score: {score} for {img_path}", "cyan" - ) + colored(f"Evaluated Image Score: {score} for {img_path}", "cyan") ) # Update the best score and image path if necessary diff --git a/playground/demos/nutrition/nutrition_example.py b/playground/demos/nutrition/nutrition_example.py index b4331db6..19f500b8 100644 --- a/playground/demos/nutrition/nutrition_example.py +++ b/playground/demos/nutrition/nutrition_example.py @@ -77,9 +77,7 @@ def generate_integrated_shopping_list( meal_plan_output, image_analysis, user_preferences ): # Prepare the prompt for the LLM - fridge_contents = image_analysis["choices"][0]["message"][ - "content" - ] + fridge_contents = image_analysis["choices"][0]["message"]["content"] prompt = ( f"Based on this meal plan: {meal_plan_output}, and the" f" following items in the fridge: {fridge_contents}," @@ -131,9 +129,7 @@ print("Integrated Shopping List:", integrated_shopping_list) with open("nutrition_output.txt", "w") as file: file.write("Meal Plan:\n" + meal_plan_output + "\n\n") file.write( - "Integrated Shopping List:\n" - + integrated_shopping_list - + "\n" + "Integrated Shopping List:\n" + integrated_shopping_list + "\n" ) print("Outputs have been saved to nutrition_output.txt") diff --git a/playground/demos/positive_med/positive_med_example.py b/playground/demos/positive_med/positive_med_example.py index 09cbb411..288dc40f 100644 --- a/playground/demos/positive_med/positive_med_example.py +++ b/playground/demos/positive_med/positive_med_example.py @@ -42,9 +42,7 @@ def get_review_prompt(article): return prompt -def social_media_prompt( - article: str, goal: str = "Clicks and engagement" -): +def social_media_prompt(article: str, goal: str = "Clicks and engagement"): prompt = SOCIAL_MEDIA_SYSTEM_PROMPT_AGENT.replace( "{{ARTICLE}}", article ).replace("{{GOAL}}", goal) diff --git a/playground/demos/swarm_hackathon/Bants.py b/playground/demos/swarm_hackathon/Bants.py index 8efca381..b8544254 100644 --- a/playground/demos/swarm_hackathon/Bants.py +++ b/playground/demos/swarm_hackathon/Bants.py @@ -24,9 +24,7 @@ async def handle_websocket(websocket, path): # Broadcast the message to all other users in the public group chats. for other_websocket in public_group_chats: if other_websocket != websocket: - await other_websocket.send( - f"{username}: {message}" - ) + await other_websocket.send(f"{username}: {message}") finally: # Remove the user from the list of public group chats. 
public_group_chats.remove(websocket) diff --git a/playground/demos/swarm_hackathon/Ego.py b/playground/demos/swarm_hackathon/Ego.py index dceb5a76..17f656cf 100644 --- a/playground/demos/swarm_hackathon/Ego.py +++ b/playground/demos/swarm_hackathon/Ego.py @@ -48,9 +48,7 @@ def generate_conversation(characters, topic): # Generate the conversation -conversation = generate_conversation( - character_names, conversation_topic -) +conversation = generate_conversation(character_names, conversation_topic) # Play the conversation for line in conversation: diff --git a/playground/demos/swarm_hackathon/main.py b/playground/demos/swarm_hackathon/main.py index 2e8eed8c..62e69cbf 100644 --- a/playground/demos/swarm_hackathon/main.py +++ b/playground/demos/swarm_hackathon/main.py @@ -48,9 +48,7 @@ def execute_concurrently(callable_functions: callable, max_workers=5): ) as executor: futures = [] for i, (fn, args, kwargs) in enumerate(callable_functions): - futures.append( - executor.submit(worker, fn, args, kwargs, i) - ) + futures.append(executor.submit(worker, fn, args, kwargs, i)) # Wait for all threads to complete concurrent.futures.wait(futures) @@ -59,9 +57,7 @@ def execute_concurrently(callable_functions: callable, max_workers=5): # Adjusting the function to extract specific column values -def extract_and_create_agents( - csv_file_path: str, target_columns: list -): +def extract_and_create_agents(csv_file_path: str, target_columns: list): """ Reads a CSV file, extracts "Project Name" and "Lightning Proposal" for each row, creates an Agent for each, and adds it to the swarm network. diff --git a/playground/demos/swarm_of_mma_manufacturing/main_example.py b/playground/demos/swarm_of_mma_manufacturing/main_example.py index 02a3cc1a..ce2d9514 100644 --- a/playground/demos/swarm_of_mma_manufacturing/main_example.py +++ b/playground/demos/swarm_of_mma_manufacturing/main_example.py @@ -31,9 +31,7 @@ llm = GPT4VisionAPI(openai_api_key=api_key, max_tokens=2000) assembly_line = ( "playground/demos/swarm_of_mma_manufacturing/assembly_line.jpg" ) -red_robots = ( - "playground/demos/swarm_of_mma_manufacturing/red_robots.jpg" -) +red_robots = "playground/demos/swarm_of_mma_manufacturing/red_robots.jpg" robots = "playground/demos/swarm_of_mma_manufacturing/robots.jpg" tesla_assembly_line = ( "playground/demos/swarm_of_mma_manufacturing/tesla_assembly.jpg" @@ -127,31 +125,19 @@ health_check = health_security_agent.run( print( - colored( - "--------------- Productivity agents initializing...", "green" - ) + colored("--------------- Productivity agents initializing...", "green") ) # Add the third task to the productivity_check_agent productivity_check = productivity_check_agent.run( health_check, assembly_line ) -print( - colored( - "--------------- Security agents initializing...", "green" - ) -) +print(colored("--------------- Security agents initializing...", "green")) # Add the fourth task to the security_check_agent -security_check = security_check_agent.run( - productivity_check, red_robots -) +security_check = security_check_agent.run(productivity_check, red_robots) -print( - colored( - "--------------- Efficiency agents initializing...", "cyan" - ) -) +print(colored("--------------- Efficiency agents initializing...", "cyan")) # Add the fifth task to the efficiency_check_agent efficiency_check = efficiency_check_agent.run( security_check, tesla_assembly_line diff --git a/playground/demos/urban_planning/urban_planning_example.py b/playground/demos/urban_planning/urban_planning_example.py index 
2a52ced7..59b33bb8 100644 --- a/playground/demos/urban_planning/urban_planning_example.py +++ b/playground/demos/urban_planning/urban_planning_example.py @@ -12,9 +12,7 @@ api_key = os.getenv("OPENAI_API_KEY") stability_api_key = os.getenv("STABILITY_API_KEY") # Initialize language model -llm = OpenAIChat( - openai_api_key=api_key, temperature=0.5, max_tokens=3000 -) +llm = OpenAIChat(openai_api_key=api_key, temperature=0.5, max_tokens=3000) # Initialize Vision model vision_api = GPT4VisionAPI(api_key=api_key) @@ -51,17 +49,13 @@ workflow = SequentialWorkflow(max_loops=1) # Add tasks to workflow with personalized prompts workflow.add(architecture_analysis_agent, "Architecture Analysis") -workflow.add( - infrastructure_evaluation_agent, "Infrastructure Evaluation" -) +workflow.add(infrastructure_evaluation_agent, "Infrastructure Evaluation") workflow.add(traffic_flow_analysis_agent, "Traffic Flow Analysis") workflow.add( environmental_impact_assessment_agent, "Environmental Impact Assessment", ) -workflow.add( - public_space_utilization_agent, "Public Space Utilization" -) +workflow.add(public_space_utilization_agent, "Public Space Utilization") workflow.add( socioeconomic_impact_analysis_agent, "Socioeconomic Impact Analysis", diff --git a/playground/examples/example_qwenvlmultimodal.py b/playground/examples/example_qwenvlmultimodal.py index 561b6f88..f338a508 100644 --- a/playground/examples/example_qwenvlmultimodal.py +++ b/playground/examples/example_qwenvlmultimodal.py @@ -8,9 +8,7 @@ model = QwenVLMultiModal( ) # Run the model -response = model( - "Hello, how are you?", "https://example.com/image.jpg" -) +response = model("Hello, how are you?", "https://example.com/image.jpg") # Print the response print(response) diff --git a/playground/examples/example_toolagent.py b/playground/examples/example_toolagent.py index 93e07ff3..c6adf00f 100644 --- a/playground/examples/example_toolagent.py +++ b/playground/examples/example_toolagent.py @@ -3,9 +3,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from swarms import ToolAgent # Load the pre-trained model and tokenizer -model = AutoModelForCausalLM.from_pretrained( - "databricks/dolly-v2-12b" -) +model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-12b") tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-12b") # Define a JSON schema for person's information @@ -20,9 +18,7 @@ json_schema = { } # Define the task to generate a person's information -task = ( - "Generate a person's information based on the following schema:" -) +task = "Generate a person's information based on the following schema:" # Create an instance of the ToolAgent class agent = ToolAgent( diff --git a/playground/models/llava_example.py b/playground/models/llava_example.py index 561b6f88..f338a508 100644 --- a/playground/models/llava_example.py +++ b/playground/models/llava_example.py @@ -8,9 +8,7 @@ model = QwenVLMultiModal( ) # Run the model -response = model( - "Hello, how are you?", "https://example.com/image.jpg" -) +response = model("Hello, how are you?", "https://example.com/image.jpg") # Print the response print(response) diff --git a/playground/models/together_example.py b/playground/models/together_example.py index f730f72f..273cfa97 100644 --- a/playground/models/together_example.py +++ b/playground/models/together_example.py @@ -7,6 +7,4 @@ model = TogetherLLM( ) # Run the model -model.run( - "Generate a blog post about the best way to make money online." 
-) +model.run("Generate a blog post about the best way to make money online.") diff --git a/playground/structs/debate_example.py b/playground/structs/debate_example.py index 7cf0290b..6758dace 100644 --- a/playground/structs/debate_example.py +++ b/playground/structs/debate_example.py @@ -35,9 +35,7 @@ class DialogueAgent: [ self.system_message, HumanMessage( - content="\n".join( - self.message_history + [self.prefix] - ) + content="\n".join(self.message_history + [self.prefix]) ), ] ) @@ -76,9 +74,7 @@ class DialogueSimulator: def step(self) -> tuple[str, str]: # 1. choose the next speaker - speaker_idx = self.select_next_speaker( - self._step, self.agents - ) + speaker_idx = self.select_next_speaker(self._step, self.agents) speaker = self.agents[speaker_idx] # 2. next speaker sends message @@ -116,9 +112,7 @@ class BiddingDialogueAgent(DialogueAgent): message_history="\n".join(self.message_history), recent_message=self.message_history[-1], ) - bid_string = self.model( - [SystemMessage(content=prompt)] - ).content + bid_string = self.model([SystemMessage(content=prompt)]).content return bid_string @@ -140,10 +134,12 @@ player_descriptor_system_message = SystemMessage( def generate_character_description(character_name): character_specifier_prompt = [ player_descriptor_system_message, - HumanMessage(content=f"""{game_description} + HumanMessage( + content=f"""{game_description} Please reply with a creative description of the presidential candidate, {character_name}, in {word_limit} words or less, that emphasizes their personalities. Speak directly to {character_name}. - Do not add anything else."""), + Do not add anything else.""" + ), ] character_description = ChatOpenAI(temperature=1.0)( character_specifier_prompt @@ -161,10 +157,9 @@ Your goal is to be as creative as possible and make the voters think you are the """ -def generate_character_system_message( - character_name, character_header -): - return SystemMessage(content=f"""{character_header} +def generate_character_system_message(character_name, character_header): + return SystemMessage( + content=f"""{character_header} You will speak in the style of {character_name}, and exaggerate their personality. You will come up with creative ideas related to {topic}. Do not say the same things over and over again. @@ -176,7 +171,8 @@ Speak only from the perspective of {character_name}. Stop speaking the moment you finish speaking from your perspective. Never forget to keep your response to {word_limit} words! Do not add anything else. - """) + """ + ) character_descriptions = [ @@ -190,9 +186,7 @@ character_headers = [ ) ] character_system_messages = [ - generate_character_system_message( - character_name, character_headers - ) + generate_character_system_message(character_name, character_headers) for character_name, character_headers in zip( character_names, character_headers ) @@ -261,7 +255,8 @@ for character_name, bidding_template in zip( topic_specifier_prompt = [ SystemMessage(content="You can make a task more specific."), - HumanMessage(content=f"""{game_description} + HumanMessage( + content=f"""{game_description} You are the debate moderator. Please make the debate topic more specific. @@ -269,7 +264,8 @@ topic_specifier_prompt = [ Be creative and imaginative. Please reply with the specified topic in {word_limit} words or less. Speak directly to the presidential candidates: {*character_names,}. 
- Do not add anything else."""), + Do not add anything else.""" + ), ] specified_topic = ChatOpenAI(temperature=1.0)( topic_specifier_prompt @@ -298,9 +294,7 @@ def ask_for_bid(agent) -> str: return bid -def select_next_speaker( - step: int, agents: List[DialogueAgent] -) -> int: +def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int: bids = [] for agent in agents: bid = ask_for_bid(agent) diff --git a/playground/structs/groupchat_example.py b/playground/structs/groupchat_example.py index 5c9d1a7c..c55aea69 100644 --- a/playground/structs/groupchat_example.py +++ b/playground/structs/groupchat_example.py @@ -44,7 +44,5 @@ manager = Agent( agents = [flow1, flow2, flow3] group_chat = GroupChat(agents=agents, messages=[], max_round=10) -chat_manager = GroupChatManager( - groupchat=group_chat, selector=manager -) +chat_manager = GroupChatManager(groupchat=group_chat, selector=manager) chat_history = chat_manager("Write me a riddle") diff --git a/playground/structs/kyle_hackathon.py b/playground/structs/kyle_hackathon.py index 1de48f1b..137ebf70 100644 --- a/playground/structs/kyle_hackathon.py +++ b/playground/structs/kyle_hackathon.py @@ -6,7 +6,7 @@ from swarms import Agent, OpenAIChat from swarms.agents.multion_agent import MultiOnAgent from swarms.memory.chroma_db import ChromaDB from swarms.tools.tool import tool -from swarms.utils.code_interpreter import SubprocessCodeInterpreter +from swarms.tools.code_interpreter import SubprocessCodeInterpreter # Load the environment variables load_dotenv() diff --git a/playground/structs/message_pool_example.py b/playground/structs/message_pool_example.py index 6dbad128..da815157 100644 --- a/playground/structs/message_pool_example.py +++ b/playground/structs/message_pool_example.py @@ -8,9 +8,7 @@ agent3 = Agent(llm=OpenAIChat(), agent_name="agent3") moderator = Agent(agent_name="moderator") agents = [agent1, agent2, agent3] -message_pool = MessagePool( - agents=agents, moderator=moderator, turns=5 -) +message_pool = MessagePool(agents=agents, moderator=moderator, turns=5) message_pool.add(agent=agent1, content="Hello, agent2!", turn=1) message_pool.add(agent=agent2, content="Hello, agent1!", turn=1) message_pool.add(agent=agent3, content="Hello, agent1!", turn=1) diff --git a/playground/structs/orchestrate_example.py b/playground/structs/orchestrate_example.py index 6b91b74f..33825b2f 100644 --- a/playground/structs/orchestrate_example.py +++ b/playground/structs/orchestrate_example.py @@ -7,9 +7,7 @@ node = Worker( # Instantiate the Orchestrator with 10 agents -orchestrator = Orchestrator( - node, agent_list=[node] * 10, task_queue=[] -) +orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[]) # Agent 7 sends a message to Agent 9 orchestrator.chat( diff --git a/playground/structs/tool_agent.py b/playground/structs/tool_agent.py index ae10a168..02783ff3 100644 --- a/playground/structs/tool_agent.py +++ b/playground/structs/tool_agent.py @@ -21,9 +21,7 @@ json_schema = { } # Define the task to generate a person's information -task = ( - "Generate a person's information based on the following schema:" -) +task = "Generate a person's information based on the following schema:" # Create an instance of the ToolAgent class agent = ToolAgent( diff --git a/playground/swarms/automate_docs.py b/playground/swarms/automate_docs.py index f3268fdb..322ed577 100644 --- a/playground/swarms/automate_docs.py +++ b/playground/swarms/automate_docs.py @@ -100,9 +100,7 @@ class PythonDocumentationSwarm: with open(file_path, "w") as 
file: file.write(doc_content) - logger.info( - f"Documentation generated for {item.__name__}." - ) + logger.info(f"Documentation generated for {item.__name__}.") except Exception as e: logger.error( f"Error processing documentation for {item.__name__}." @@ -130,8 +128,7 @@ class PythonDocumentationSwarm: thread.join() logger.info( - "Documentation generated in 'swarms.structs'" - " directory." + "Documentation generated in 'swarms.structs'" " directory." ) except Exception as e: logger.error("Error running documentation process.") @@ -143,8 +140,7 @@ class PythonDocumentationSwarm: executor.map(self.process_documentation, python_items) logger.info( - "Documentation generated in 'swarms.structs'" - " directory." + "Documentation generated in 'swarms.structs'" " directory." ) except Exception as e: logger.error("Error running documentation process.") diff --git a/playground/swarms/hierarchical_swarm.py b/playground/swarms/hierarchical_swarm.py index f0357711..58397e04 100644 --- a/playground/swarms/hierarchical_swarm.py +++ b/playground/swarms/hierarchical_swarm.py @@ -4,7 +4,7 @@ B -> W1, W2, W3 """ from typing import List, Optional from pydantic import BaseModel, Field -from swarms.utils.json_utils import str_to_json +from swarms.tools.json_utils import str_to_json class HierarchicalSwarm(BaseModel): diff --git a/playground/tools/agent_with_tools_example.py b/playground/tools/agent_with_tools_example.py index 35b61703..87471899 100644 --- a/playground/tools/agent_with_tools_example.py +++ b/playground/tools/agent_with_tools_example.py @@ -35,7 +35,6 @@ agent = Agent( ) out = agent.run( - "Use the search api to find the best restaurants in New York" - " City." + "Use the search api to find the best restaurants in New York" " City." ) print(out) diff --git a/pyproject.toml b/pyproject.toml index 5e58c2e7..7154b0fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "4.8.2" +version = "4.8.4" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -35,25 +35,22 @@ python = ">=3.9,<4.0" torch = ">=2.1.1,<3.0" transformers = ">= 4.39.0, <5.0.0" asyncio = ">=3.4.3,<4.0" -einops = "0.7.0" -langchain-core = "0.1.33" langchain-community = "0.0.29" langchain-experimental = "0.0.55" backoff = "2.2.1" toml = "*" pypdf = "4.1.0" -httpx = "0.24.1" ratelimit = "2.2.1" loguru = "0.7.2" pydantic = "2.6.4" tenacity = "8.2.3" Pillow = "10.2.0" -rich = "13.5.2" psutil = "*" sentry-sdk = "*" python-dotenv = "*" accelerate = "0.28.0" opencv-python = "^4.9.0.80" +PyYAML = "*" [tool.poetry.group.lint.dependencies] black = "^23.1.0" @@ -71,7 +68,7 @@ pandas = "^2.2.2" fastapi = "^0.110.1" [tool.ruff] -line-length = 128 +line-length = 75 [tool.ruff.lint] select = ["E", "F", "W", "I", "UP"] @@ -84,6 +81,21 @@ preview = true "swarms/prompts/**.py" = ["E501"] [tool.black] -line-length = 70 -target-version = ['py38'] -preview = true +target-version = ["py38"] +line-length = 75 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | docs +)/ +''' + diff --git a/requirements.txt b/requirements.txt index 3814ccf0..9b7843ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,6 @@ torch>=2.1.1,<3.0 transformers==4.39.0 asyncio>=3.4.3,<4.0 -einops==0.7.0 -langchain-core==0.1.33 langchain-community==0.0.29 langchain-experimental==0.0.55 backoff==2.2.1 diff --git a/scripts/auto_tests_docs/auto_docs_functions.py 
b/scripts/auto_tests_docs/auto_docs_functions.py index 37bf376d..e5d55fc7 100644 --- a/scripts/auto_tests_docs/auto_docs_functions.py +++ b/scripts/auto_tests_docs/auto_docs_functions.py @@ -52,9 +52,7 @@ def main(): # Gathering all functions from the swarms.utils module functions = [ obj - for name, obj in inspect.getmembers( - sys.modules["swarms.utils"] - ) + for name, obj in inspect.getmembers(sys.modules["swarms.utils"]) if inspect.isfunction(obj) ] diff --git a/scripts/auto_tests_docs/auto_docs_omni.py b/scripts/auto_tests_docs/auto_docs_omni.py index 7fd3cde6..f60650c4 100644 --- a/scripts/auto_tests_docs/auto_docs_omni.py +++ b/scripts/auto_tests_docs/auto_docs_omni.py @@ -57,9 +57,7 @@ def process_documentation( with open(file_path, "w") as file: file.write(doc_content) - print( - f"Processed documentation for {item.__name__}. at {file_path}" - ) + print(f"Processed documentation for {item.__name__}. at {file_path}") def main(module: str = "docs/swarms/structs"): diff --git a/scripts/auto_tests_docs/auto_tests.py b/scripts/auto_tests_docs/auto_tests.py index c9d7c95e..38f00482 100644 --- a/scripts/auto_tests_docs/auto_tests.py +++ b/scripts/auto_tests_docs/auto_tests.py @@ -68,9 +68,7 @@ def create_test(cls): # Process with OpenAI model (assuming the model's __call__ method takes this input and returns processed content) processed_content = model( - TEST_WRITER_SOP_PROMPT( - input_content, "swarms", "swarms.memory" - ) + TEST_WRITER_SOP_PROMPT(input_content, "swarms", "swarms.memory") ) processed_content = extract_code_from_markdown(processed_content) diff --git a/scripts/auto_tests_docs/auto_tests_functions.py b/scripts/auto_tests_docs/auto_tests_functions.py index 4fa2fafd..109ab46c 100644 --- a/scripts/auto_tests_docs/auto_tests_functions.py +++ b/scripts/auto_tests_docs/auto_tests_functions.py @@ -57,9 +57,7 @@ def main(): # Gathering all functions from the swarms.utils module functions = [ obj - for name, obj in inspect.getmembers( - sys.modules["swarms.utils"] - ) + for name, obj in inspect.getmembers(sys.modules["swarms.utils"]) if inspect.isfunction(obj) ] diff --git a/scripts/auto_tests_docs/mkdocs_handler.py b/scripts/auto_tests_docs/mkdocs_handler.py index 8b1dc0a0..e5a6410b 100644 --- a/scripts/auto_tests_docs/mkdocs_handler.py +++ b/scripts/auto_tests_docs/mkdocs_handler.py @@ -22,9 +22,7 @@ def generate_file_list(directory, output_file): # Remove the file extension file_name, _ = os.path.splitext(file) # Write the file name and path to the output file - f.write( - f'- {file_name}: "swarms/utils/{file_path}"\n' - ) + f.write(f'- {file_name}: "swarms/utils/{file_path}"\n') # Use the function to generate the file list diff --git a/scripts/get_package_requirements.py b/scripts/get_package_requirements.py index 99e139da..5f48ba31 100644 --- a/scripts/get_package_requirements.py +++ b/scripts/get_package_requirements.py @@ -13,18 +13,15 @@ def get_package_versions(requirements_path, output_path): for requirement in requirements: # Skip empty lines and comments - if ( - requirement.strip() == "" - or requirement.strip().startswith("#") + if requirement.strip() == "" or requirement.strip().startswith( + "#" ): continue # Extract package name package_name = requirement.split("==")[0].strip() try: - version = pkg_resources.get_distribution( - package_name - ).version + version = pkg_resources.get_distribution(package_name).version package_versions.append(f"{package_name}=={version}") except pkg_resources.DistributionNotFound: package_versions.append(f"{package_name}: not 
installed") diff --git a/swarms/agents/tool_agent.py b/swarms/agents/tool_agent.py index 0de72778..001a96d3 100644 --- a/swarms/agents/tool_agent.py +++ b/swarms/agents/tool_agent.py @@ -147,7 +147,5 @@ class ToolAgent(Agent): ) except Exception as error: - logger.error( - f"Error running {self.name} for task: {task}" - ) + logger.error(f"Error running {self.name} for task: {task}") raise error diff --git a/swarms/memory/cosine_similarity.py b/swarms/memory/cosine_similarity.py index 94c5e585..11115720 100644 --- a/swarms/memory/cosine_similarity.py +++ b/swarms/memory/cosine_similarity.py @@ -69,15 +69,11 @@ def cosine_similarity_top_k( score_array = cosine_similarity(X, Y) score_threshold = score_threshold or -1.0 score_array[score_array < score_threshold] = 0 - top_k = min( - top_k or len(score_array), np.count_nonzero(score_array) - ) - top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[ - -top_k: + top_k = min(top_k or len(score_array), np.count_nonzero(score_array)) + top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:] + top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][ + ::-1 ] - top_k_idxs = top_k_idxs[ - np.argsort(score_array.ravel()[top_k_idxs]) - ][::-1] ret_idxs = np.unravel_index(top_k_idxs, score_array.shape) scores = score_array.ravel()[top_k_idxs].tolist() return list(zip(*ret_idxs)), scores # type: ignore diff --git a/swarms/memory/dict_shared_memory.py b/swarms/memory/dict_shared_memory.py index f81e2fd4..8ac92d17 100644 --- a/swarms/memory/dict_shared_memory.py +++ b/swarms/memory/dict_shared_memory.py @@ -44,9 +44,7 @@ class DictSharedMemory: entry_id = str(uuid.uuid4()) data = {} epoch = datetime.datetime.utcfromtimestamp(0) - epoch = ( - datetime.datetime.utcnow() - epoch - ).total_seconds() + epoch = (datetime.datetime.utcnow() - epoch).total_seconds() data[entry_id] = { "agent": agent_id, "epoch": epoch, diff --git a/swarms/memory/lanchain_chroma.py b/swarms/memory/lanchain_chroma.py index cd5d832a..f3dd3d20 100644 --- a/swarms/memory/lanchain_chroma.py +++ b/swarms/memory/lanchain_chroma.py @@ -170,9 +170,7 @@ class LangchainChromaVectorMemory(AbstractVectorDatabase): ) texts = [text.page_content for text in texts] elif type == "cos": - texts = self.db.similarity_search_with_score( - query=query, k=k - ) + texts = self.db.similarity_search_with_score(query=query, k=k) texts = [ text[0].page_content for text in texts diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index e0bc72d2..a1b2605f 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -34,9 +34,7 @@ class PostgresDB(AbstractVectorDatabase): table_name (str): The name of the table in the database. 
""" - self.engine = create_engine( - connection_string, *args, **kwargs - ) + self.engine = create_engine(connection_string, *args, **kwargs) self.table_name = table_name self.VectorModel = self._create_vector_model() diff --git a/swarms/memory/pinecone.py b/swarms/memory/pinecone.py index d33cb9cd..fb9d32ba 100644 --- a/swarms/memory/pinecone.py +++ b/swarms/memory/pinecone.py @@ -123,9 +123,7 @@ class PineconeDB(AbstractVectorDatabase): Returns: str: _description_ """ - vector_id = ( - vector_id if vector_id else str_to_hash(str(vector)) - ) + vector_id = vector_id if vector_id else str_to_hash(str(vector)) params = {"namespace": namespace} | kwargs diff --git a/swarms/memory/short_term_memory.py b/swarms/memory/short_term_memory.py index 370d4e6a..82af9680 100644 --- a/swarms/memory/short_term_memory.py +++ b/swarms/memory/short_term_memory.py @@ -40,9 +40,7 @@ class ShortTermMemory(BaseStructure): self.medium_term_memory = [] self.lock = threading.Lock() - def add( - self, role: str = None, message: str = None, *args, **kwargs - ): + def add(self, role: str = None, message: str = None, *args, **kwargs): """Add a message to the short term memory. Args: @@ -160,9 +158,7 @@ class ShortTermMemory(BaseStructure): with open(filename, "w") as f: json.dump( { - "short_term_memory": ( - self.short_term_memory - ), + "short_term_memory": (self.short_term_memory), "medium_term_memory": ( self.medium_term_memory ), @@ -184,9 +180,7 @@ class ShortTermMemory(BaseStructure): with self.lock: with open(filename) as f: data = json.load(f) - self.short_term_memory = data.get( - "short_term_memory", [] - ) + self.short_term_memory = data.get("short_term_memory", []) self.medium_term_memory = data.get( "medium_term_memory", [] ) diff --git a/swarms/memory/sqlite.py b/swarms/memory/sqlite.py index 7a391303..7922e274 100644 --- a/swarms/memory/sqlite.py +++ b/swarms/memory/sqlite.py @@ -5,9 +5,7 @@ from swarms.memory.base_vectordb import AbstractVectorDatabase try: import sqlite3 except ImportError: - raise ImportError( - "Please install sqlite3 to use the SQLiteDB class." - ) + raise ImportError("Please install sqlite3 to use the SQLiteDB class.") class SQLiteDB(AbstractVectorDatabase): diff --git a/swarms/memory/weaviate_db.py b/swarms/memory/weaviate_db.py index 05ad5388..b5b89e04 100644 --- a/swarms/memory/weaviate_db.py +++ b/swarms/memory/weaviate_db.py @@ -126,9 +126,7 @@ class WeaviateDB(AbstractVectorDatabase): print(f"Error adding object: {error}") raise - def query( - self, collection_name: str, query: str, limit: int = 10 - ): + def query(self, collection_name: str, query: str, limit: int = 10): """Query objects from a specified collection. 
Args: diff --git a/swarms/models/base_embedding_model.py b/swarms/models/base_embedding_model.py index bb244c6c..cb5e4d28 100644 --- a/swarms/models/base_embedding_model.py +++ b/swarms/models/base_embedding_model.py @@ -25,9 +25,7 @@ class BaseEmbeddingModel( tokenizer: Callable = None chunker: Callable = None - def embed_text_artifact( - self, artifact: TextArtifact - ) -> list[float]: + def embed_text_artifact(self, artifact: TextArtifact) -> list[float]: return self.embed_string(artifact.to_text()) def embed_string(self, string: str) -> list[float]: diff --git a/swarms/models/base_llm.py b/swarms/models/base_llm.py index d69f21a8..6d8ae898 100644 --- a/swarms/models/base_llm.py +++ b/swarms/models/base_llm.py @@ -154,15 +154,11 @@ class AbstractLLM(ABC): Returns: _type_: _description_ """ - return await asyncio.gather( - *(self.arun(task) for task in tasks) - ) + return await asyncio.gather(*(self.arun(task) for task in tasks)) def chat(self, task: str, history: str = "") -> str: """Chat with the model""" - complete_task = ( - task + " | " + history - ) # Delimiter for clarity + complete_task = task + " | " + history # Delimiter for clarity return self.run(complete_task) def __call__(self, task: str) -> str: @@ -209,9 +205,7 @@ class AbstractLLM(ABC): def log_event(self, message: str): """Log an event.""" - logging.info( - f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {message}" - ) + logging.info(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {message}") def save_checkpoint(self, checkpoint_dir: str = "checkpoints"): """Save the model state.""" diff --git a/swarms/models/base_multimodal_model.py b/swarms/models/base_multimodal_model.py index 25975eaa..0f1d38b1 100644 --- a/swarms/models/base_multimodal_model.py +++ b/swarms/models/base_multimodal_model.py @@ -135,9 +135,7 @@ class BaseMultiModalModel: image_pil = Image.open(BytesIO(response.content)) return image_pil except requests.RequestException as error: - print( - f"Error fetching image from {img} and error: {error}" - ) + print(f"Error fetching image from {img} and error: {error}") return None def encode_img(self, img: str): @@ -190,9 +188,7 @@ class BaseMultiModalModel: """Clear the chat history""" self.chat_history = [] - def run_many( - self, tasks: List[str], imgs: List[str], *args, **kwargs - ): + def run_many(self, tasks: List[str], imgs: List[str], *args, **kwargs): """ Run the model on multiple tasks and images all at once using concurrent @@ -206,18 +202,14 @@ class BaseMultiModalModel: """ # Instantiate the thread pool executor - with ThreadPoolExecutor( - max_workers=self.max_workers - ) as executor: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: results = executor.map(self.run, tasks, imgs) # Print the results for debugging for result in results: print(result) - def run_batch( - self, tasks_images: List[Tuple[str, str]] - ) -> List[str]: + def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ @@ -244,9 +236,7 @@ class BaseMultiModalModel: """Process a batch of tasks and images asynchronously with retries""" loop = asyncio.get_event_loop() futures = [ - loop.run_in_executor( - None, self.run_with_retries, task, img - ) + loop.run_in_executor(None, self.run_with_retries, task, img) for task, img in tasks_images ] return await asyncio.gather(*futures) @@ -264,9 +254,7 @@ class BaseMultiModalModel: print(f"Error with the request {error}") continue - def run_batch_with_retries( - 
self, tasks_images: List[Tuple[str, str]] - ): + def run_batch_with_retries(self, tasks_images: List[Tuple[str, str]]): """Run the model with retries""" for i in range(self.retries): try: diff --git a/swarms/models/cog_vlm.py b/swarms/models/cog_vlm.py index a3f820c5..740f8c22 100644 --- a/swarms/models/cog_vlm.py +++ b/swarms/models/cog_vlm.py @@ -299,9 +299,7 @@ class CogVLMMultiModal(BaseMultiModalModel): """ messages = params["messages"] temperature = float(params.get("temperature", 1.0)) - repetition_penalty = float( - params.get("repetition_penalty", 1.0) - ) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) top_p = float(params.get("top_p", 1.0)) max_new_tokens = int(params.get("max_tokens", 256)) query, history, image_list = self.process_history_and_images( @@ -318,9 +316,7 @@ class CogVLMMultiModal(BaseMultiModalModel): ) inputs = { "input_ids": ( - input_by_model["input_ids"] - .unsqueeze(0) - .to(self.device) + input_by_model["input_ids"].unsqueeze(0).to(self.device) ), "token_type_ids": ( input_by_model["token_type_ids"] @@ -379,9 +375,7 @@ class CogVLMMultiModal(BaseMultiModalModel): "text": generated_text, "usage": { "prompt_tokens": input_echo_len, - "completion_tokens": ( - total_len - input_echo_len - ), + "completion_tokens": (total_len - input_echo_len), "total_tokens": total_len, }, } @@ -437,9 +431,7 @@ class CogVLMMultiModal(BaseMultiModalModel): for item in content: if isinstance(item, ImageUrlContent): image_url = item.image_url.url - if image_url.startswith( - "data:image/jpeg;base64," - ): + if image_url.startswith("data:image/jpeg;base64,"): base64_encoded_image = image_url.split( "data:image/jpeg;base64," )[1] @@ -471,9 +463,7 @@ class CogVLMMultiModal(BaseMultiModalModel): text_content, ) else: - raise AssertionError( - "assistant reply before user" - ) + raise AssertionError("assistant reply before user") else: raise AssertionError(f"unrecognized role: {role}") diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index 0e02c3d6..786807dc 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -199,9 +199,7 @@ class Dalle3: with open(full_path, "wb") as file: file.write(response.content) else: - raise ValueError( - f"Failed to download image from {img_url}" - ) + raise ValueError(f"Failed to download image from {img_url}") def create_variations(self, img: str): """ @@ -249,9 +247,7 @@ class Dalle3: "red", ) ) - print( - colored(f"Error running Dalle3: {error.error}", "red") - ) + print(colored(f"Error running Dalle3: {error.error}", "red")) raise error def print_dashboard(self): @@ -310,9 +306,7 @@ class Dalle3: executor.submit(self, task): task for task in tasks } results = [] - for future in concurrent.futures.as_completed( - future_to_task - ): + for future in concurrent.futures.as_completed(future_to_task): task = future_to_task[future] try: img = future.result() @@ -359,9 +353,7 @@ class Dalle3: """Str method for the Dalle3 class""" return f"Dalle3(image_url={self.image_url})" - @backoff.on_exception( - backoff.expo, Exception, max_tries=max_retries - ) + @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries) def rate_limited_call(self, task: str): """Rate limited call to the Dalle3 API""" return self.__call__(task) diff --git a/swarms/models/distilled_whisperx.py b/swarms/models/distilled_whisperx.py index 951dcd10..bd2bbcf3 100644 --- a/swarms/models/distilled_whisperx.py +++ b/swarms/models/distilled_whisperx.py @@ -70,9 +70,7 @@ class DistilWhisperModel: def __init__(self, 
model_id="distil-whisper/distil-large-v2"): self.device = "cuda:0" if torch.cuda.is_available() else "cpu" self.torch_dtype = ( - torch.float16 - if torch.cuda.is_available() - else torch.float32 + torch.float16 if torch.cuda.is_available() else torch.float32 ) self.model_id = model_id self.model = AutoModelForSpeechSeq2Seq.from_pretrained( @@ -112,9 +110,7 @@ class DistilWhisperModel: :return: The transcribed text. """ loop = asyncio.get_event_loop() - return await loop.run_in_executor( - None, self.transcribe, inputs - ) + return await loop.run_in_executor(None, self.transcribe, inputs) def real_time_transcribe(self, audio_file_path, chunk_duration=5): """ @@ -138,9 +134,7 @@ class DistilWhisperModel: sample_rate = audio_input.sampling_rate len(audio_input.array) / sample_rate chunks = [ - audio_input.array[ - i : i + sample_rate * chunk_duration - ] + audio_input.array[i : i + sample_rate * chunk_duration] for i in range( 0, len(audio_input.array), @@ -149,9 +143,7 @@ class DistilWhisperModel: ] print( - colored( - "Starting real-time transcription...", "green" - ) + colored("Starting real-time transcription...", "green") ) for i, chunk in enumerate(chunks): @@ -162,8 +154,8 @@ class DistilWhisperModel: return_tensors="pt", padding=True, ) - processed_inputs = ( - processed_inputs.input_values.to(self.device) + processed_inputs = processed_inputs.input_values.to( + self.device ) # Generate transcription for the chunk @@ -174,9 +166,7 @@ class DistilWhisperModel: # Print the chunk's transcription print( - colored( - f"Chunk {i+1}/{len(chunks)}: ", "yellow" - ) + colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") + transcription ) diff --git a/swarms/models/gemini.py b/swarms/models/gemini.py index 276cd05d..9c100814 100644 --- a/swarms/models/gemini.py +++ b/swarms/models/gemini.py @@ -112,9 +112,7 @@ class Gemini(BaseMultiModalModel): ) # Initialize the model - self.model = genai.GenerativeModel( - model_name, *args, **kwargs - ) + self.model = genai.GenerativeModel(model_name, *args, **kwargs) # Check for the key if self.gemini_api_key is None: @@ -211,9 +209,7 @@ class Gemini(BaseMultiModalModel): raise ValueError("Please provide a Gemini API key") # Load the image - img = [ - {"mime_type": type, "data": Path(img).read_bytes()} - ] + img = [{"mime_type": type, "data": Path(img).read_bytes()}] except Exception as error: print(f"Error processing image: {error}") diff --git a/swarms/models/gpt4_sam.py b/swarms/models/gpt4_sam.py index 37dde6a0..1fda68c5 100644 --- a/swarms/models/gpt4_sam.py +++ b/swarms/models/gpt4_sam.py @@ -42,9 +42,7 @@ class GPT4VSAM(BaseMultiModalModel): self.device = device self.return_related_marks = return_related_marks - self.sam = SegmentAnythingMarkGenerator( - device, *args, **kwargs - ) + self.sam = SegmentAnythingMarkGenerator(device, *args, **kwargs) self.visualizer = MarkVisualizer(*args, **kwargs) def load_img(self, img: str) -> Any: diff --git a/swarms/models/gpt4_vision_api.py b/swarms/models/gpt4_vision_api.py index 5966a0b6..724a2def 100644 --- a/swarms/models/gpt4_vision_api.py +++ b/swarms/models/gpt4_vision_api.py @@ -15,8 +15,7 @@ try: import cv2 except ImportError: print( - "OpenCV not installed. Please install OpenCV to use this" - " model." + "OpenCV not installed. Please install OpenCV to use this" " model." 
) raise ImportError @@ -248,9 +247,7 @@ class GPT4VisionAPI(BaseMultiModalModel): if not success: break _, buffer = cv2.imencode(".jpg", frame) - base64_frames.append( - base64.b64encode(buffer).decode("utf-8") - ) + base64_frames.append(base64.b64encode(buffer).decode("utf-8")) video.release() print(len(base64_frames), "frames read.") @@ -433,9 +430,7 @@ class GPT4VisionAPI(BaseMultiModalModel): def health_check(self): """Health check for the GPT4Vision model""" try: - response = requests.get( - "https://api.openai.com/v1/engines" - ) + response = requests.get("https://api.openai.com/v1/engines") return response.status_code == 200 except requests.RequestException as error: print(f"Health check failed: {error}") diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py index 957b398f..e459b251 100644 --- a/swarms/models/huggingface.py +++ b/swarms/models/huggingface.py @@ -203,9 +203,7 @@ class HuggingfaceLLM(AbstractLLM): results = list(executor.map(self.run, tasks)) return results - def run_batch( - self, tasks_images: List[Tuple[str, str]] - ) -> List[str]: + def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index cc654221..8a9c501f 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -77,9 +77,7 @@ class Idefics(BaseMultiModalModel): def __init__( self, - model_name: Optional[ - str - ] = "HuggingFaceM4/idefics-9b-instruct", + model_name: Optional[str] = "HuggingFaceM4/idefics-9b-instruct", device: Callable = autodetect_device, torch_dtype=torch.bfloat16, max_length: int = 100, diff --git a/swarms/models/kosmos_two.py b/swarms/models/kosmos_two.py index ce19c37d..59f1f98e 100644 --- a/swarms/models/kosmos_two.py +++ b/swarms/models/kosmos_two.py @@ -87,8 +87,8 @@ class Kosmos(BaseMultiModalModel): skip_special_tokens=True, )[0] - processed_text, entities = ( - self.processor.post_process_generation(generated_texts) + processed_text, entities = self.processor.post_process_generation( + generated_texts ) return processed_text, entities @@ -189,9 +189,7 @@ class Kosmos(BaseMultiModalModel): ) # draw bbox # random color - color = tuple( - np.random.randint(0, 255, size=3).tolist() - ) + color = tuple(np.random.randint(0, 255, size=3).tolist()) new_image = cv2.rectangle( new_image, (orig_x1, orig_y1), @@ -210,9 +208,7 @@ class Kosmos(BaseMultiModalModel): if ( y1 - < text_height - + text_offset_original - + 2 * text_spaces + < text_height + text_offset_original + 2 * text_spaces ): y1 = ( orig_y1 diff --git a/swarms/models/medical_sam.py b/swarms/models/medical_sam.py index 8d096ba5..8a183e79 100644 --- a/swarms/models/medical_sam.py +++ b/swarms/models/medical_sam.py @@ -115,12 +115,10 @@ class MedicalSAM: if len(box_torch.shape) == 2: box_torch = box_torch[:, None, :] - sparse_embeddings, dense_embeddings = ( - self.model.prompt_encoder( - points=None, - boxes=box_torch, - masks=None, - ) + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=None, + boxes=box_torch, + masks=None, ) low_res_logits, _ = self.model.mask_decoder( diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py index dc7ba462..e2003c97 100644 --- a/swarms/models/mistral.py +++ b/swarms/models/mistral.py @@ -74,9 +74,9 @@ class Mistral(AbstractLLM): """Run the model on a given task.""" try: - model_inputs = self.tokenizer( - [task], return_tensors="pt" - 
).to(self.device) + model_inputs = self.tokenizer([task], return_tensors="pt").to( + self.device + ) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, @@ -85,9 +85,7 @@ class Mistral(AbstractLLM): max_new_tokens=self.max_length, **kwargs, ) - output_text = self.tokenizer.batch_decode(generated_ids)[ - 0 - ] + output_text = self.tokenizer.batch_decode(generated_ids)[0] return output_text except Exception as e: raise ValueError(f"Error running the model: {str(e)}") diff --git a/swarms/models/mpt.py b/swarms/models/mpt.py index 543e3f41..6b5c2ace 100644 --- a/swarms/models/mpt.py +++ b/swarms/models/mpt.py @@ -146,9 +146,7 @@ class MPT7B: self, prompts: list, temperature: float = 1.0 ) -> list: """Batch generate text""" - self.logger.info( - f"Generating text for {len(prompts)} prompts..." - ) + self.logger.info(f"Generating text for {len(prompts)} prompts...") results = [] with torch.autocast("cuda", dtype=torch.bfloat16): for prompt in prompts: diff --git a/swarms/models/openai_embeddings.py b/swarms/models/openai_embeddings.py index f352ee17..e356f204 100644 --- a/swarms/models/openai_embeddings.py +++ b/swarms/models/openai_embeddings.py @@ -53,9 +53,7 @@ def _create_retry_decorator( | retry_if_exception_type(llm.error.APIError) | retry_if_exception_type(llm.error.APIConnectionError) | retry_if_exception_type(llm.error.RateLimitError) - | retry_if_exception_type( - llm.error.ServiceUnavailableError - ) + | retry_if_exception_type(llm.error.ServiceUnavailableError) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -79,9 +77,7 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any: | retry_if_exception_type(llm.error.APIError) | retry_if_exception_type(llm.error.APIConnectionError) | retry_if_exception_type(llm.error.RateLimitError) - | retry_if_exception_type( - llm.error.ServiceUnavailableError - ) + | retry_if_exception_type(llm.error.ServiceUnavailableError) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -102,15 +98,11 @@ def _check_response(response: dict) -> dict: if any(len(d["embedding"]) == 1 for d in response["data"]): import llm - raise llm.error.APIError( - "OpenAI API returned an empty embedding" - ) + raise llm.error.APIError("OpenAI API returned an empty embedding") return response -def embed_with_retry( - embeddings: OpenAIEmbeddings, **kwargs: Any -) -> Any: +def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: """Use tenacity to retry the embedding call.""" retry_decorator = _create_retry_decorator(embeddings) @@ -181,7 +173,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): client: Any = None #: :meta private: model: str = "text-embedding-ada-002" - deployment: str = model # to support Azure OpenAI Service custom deployment names + deployment: str = ( + model # to support Azure OpenAI Service custom deployment names + ) openai_api_version: str | None = None # to support Azure OpenAI Service custom endpoints openai_api_base: str | None = None @@ -194,9 +188,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): openai_api_key: str | None = None openai_organization: str | None = None allowed_special: Literal["all"] | set[str] = set() - disallowed_special: Literal["all"] | set[str] | Sequence[ - str - ] = "all" + disallowed_special: Literal["all"] | set[str] | Sequence[str] = "all" chunk_size: int = 1000 """Maximum number of texts to embed in each batch""" max_retries: int = 6 @@ -228,9 +220,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): extra = values.get("model_kwargs", {}) 
for field_name in list(values): if field_name in extra: - raise ValueError( - f"Found {field_name} supplied twice." - ) + raise ValueError(f"Found {field_name} supplied twice.") if field_name not in all_required_field_names: warnings.warn( f"""WARNING! {field_name} is not default parameter. @@ -339,9 +329,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): engine: str, chunk_size: int | None = None, ) -> list[list[float]]: - embeddings: list[list[float]] = [ - [] for _ in range(len(texts)) - ] + embeddings: list[list[float]] = [[] for _ in range(len(texts))] try: import tiktoken except ImportError: @@ -358,8 +346,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): encoding = tiktoken.encoding_for_model(model_name) except KeyError: logger.warning( - "Warning: model not found. Using cl100k_base" - " encoding." + "Warning: model not found. Using cl100k_base" " encoding." ) model = "cl100k_base" encoding = tiktoken.get_encoding(model) @@ -374,9 +361,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append( - token[j : j + self.embedding_ctx_length] - ) + tokens.append(token[j : j + self.embedding_ctx_length]) indices.append(i) batched_embeddings: list[list[float]] = [] @@ -402,9 +387,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): r["embedding"] for r in response["data"] ) - results: list[list[list[float]]] = [ - [] for _ in range(len(texts)) - ] + results: list[list[list[float]]] = [[] for _ in range(len(texts))] num_tokens_in_batch: list[list[int]] = [ [] for _ in range(len(texts)) ] @@ -424,9 +407,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): average = np.average( _result, axis=0, weights=num_tokens_in_batch[i] ) - embeddings[i] = ( - average / np.linalg.norm(average) - ).tolist() + embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings @@ -439,9 +420,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): engine: str, chunk_size: int | None = None, ) -> list[list[float]]: - embeddings: list[list[float]] = [ - [] for _ in range(len(texts)) - ] + embeddings: list[list[float]] = [[] for _ in range(len(texts))] try: import tiktoken except ImportError: @@ -458,8 +437,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): encoding = tiktoken.encoding_for_model(model_name) except KeyError: logger.warning( - "Warning: model not found. Using cl100k_base" - " encoding." + "Warning: model not found. Using cl100k_base" " encoding." 
) model = "cl100k_base" encoding = tiktoken.get_encoding(model) @@ -474,9 +452,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): disallowed_special=self.disallowed_special, ) for j in range(0, len(token), self.embedding_ctx_length): - tokens.append( - token[j : j + self.embedding_ctx_length] - ) + tokens.append(token[j : j + self.embedding_ctx_length]) indices.append(i) batched_embeddings: list[list[float]] = [] @@ -491,9 +467,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): r["embedding"] for r in response["data"] ) - results: list[list[list[float]]] = [ - [] for _ in range(len(texts)) - ] + results: list[list[list[float]]] = [[] for _ in range(len(texts))] num_tokens_in_batch: list[list[int]] = [ [] for _ in range(len(texts)) ] @@ -515,9 +489,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): average = np.average( _result, axis=0, weights=num_tokens_in_batch[i] ) - embeddings[i] = ( - average / np.linalg.norm(average) - ).tolist() + embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings @@ -536,9 +508,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. - return self._get_len_safe_embeddings( - texts, engine=self.deployment - ) + return self._get_len_safe_embeddings(texts, engine=self.deployment) async def aembed_documents( self, texts: list[str], chunk_size: int | None = 0 diff --git a/swarms/models/palm.py b/swarms/models/palm.py index 1d7f71d6..715715c5 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -129,14 +129,9 @@ class GooglePalm(BaseLLM, BaseModel): values["temperature"] is not None and not 0 <= values["temperature"] <= 1 ): - raise ValueError( - "temperature must be in the range [0.0, 1.0]" - ) + raise ValueError("temperature must be in the range [0.0, 1.0]") - if ( - values["top_p"] is not None - and not 0 <= values["top_p"] <= 1 - ): + if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: raise ValueError("top_p must be in the range [0.0, 1.0]") if values["top_k"] is not None and values["top_k"] <= 0: @@ -146,9 +141,7 @@ class GooglePalm(BaseLLM, BaseModel): values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0 ): - raise ValueError( - "max_output_tokens must be greater than zero" - ) + raise ValueError("max_output_tokens must be greater than zero") return values @@ -177,12 +170,8 @@ class GooglePalm(BaseLLM, BaseModel): prompt_generations = [] for candidate in completion.candidates: raw_text = candidate["output"] - stripped_text = _strip_erroneous_leading_spaces( - raw_text - ) - prompt_generations.append( - Generation(text=stripped_text) - ) + stripped_text = _strip_erroneous_leading_spaces(raw_text) + prompt_generations.append(Generation(text=stripped_text)) generations.append(prompt_generations) return LLMResult(generations=generations) diff --git a/swarms/models/qwen.py b/swarms/models/qwen.py index b5a4ed1a..9d8f8b9c 100644 --- a/swarms/models/qwen.py +++ b/swarms/models/qwen.py @@ -139,6 +139,4 @@ class QwenVLMultiModal(BaseMultiModalModel): ) return response, history except Exception as e: - raise Exception( - "An error occurred during the chat." 
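# --- Illustrative sketch (editor's example, not from the patch): the
# OpenAIEmbeddings hunks above implement "length-safe" embedding: long inputs
# are split into context-window-sized token chunks, each chunk is embedded
# separately, and the chunk vectors are combined with a token-count-weighted
# average that is then re-normalized, exactly as in the np.average /
# np.linalg.norm lines above. The combination step in isolation (names here
# are illustrative, not from the patch):
import numpy as np


def combine_chunk_embeddings(
    chunk_embeddings: list[list[float]], chunk_token_counts: list[int]
) -> list[float]:
    # Weight each chunk's vector by how many tokens it covered.
    average = np.average(chunk_embeddings, axis=0, weights=chunk_token_counts)
    # Re-normalize so the combined result is a unit vector.
    return (average / np.linalg.norm(average)).tolist()
# --- End of sketch.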
- ) from e + raise Exception("An error occurred during the chat.") from e diff --git a/swarms/models/sampling_params.py b/swarms/models/sampling_params.py index d231c295..5f33ac82 100644 --- a/swarms/models/sampling_params.py +++ b/swarms/models/sampling_params.py @@ -143,9 +143,7 @@ class SamplingParams: self.logprobs = logprobs self.prompt_logprobs = prompt_logprobs self.skip_special_tokens = skip_special_tokens - self.spaces_between_special_tokens = ( - spaces_between_special_tokens - ) + self.spaces_between_special_tokens = spaces_between_special_tokens self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output self._verify_args() @@ -189,31 +187,23 @@ class SamplingParams: f" {self.temperature}." ) if not 0.0 < self.top_p <= 1.0: - raise ValueError( - f"top_p must be in (0, 1], got {self.top_p}." - ) + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") if self.top_k < -1 or self.top_k == 0: raise ValueError( "top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." ) if not 0.0 <= self.min_p <= 1.0: - raise ValueError( - f"min_p must be in [0, 1], got {self.min_p}." - ) + raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") if self.max_tokens is not None and self.max_tokens < 1: raise ValueError( - "max_tokens must be at least 1, got" - f" {self.max_tokens}." + "max_tokens must be at least 1, got" f" {self.max_tokens}." ) if self.logprobs is not None and self.logprobs < 0: raise ValueError( f"logprobs must be non-negative, got {self.logprobs}." ) - if ( - self.prompt_logprobs is not None - and self.prompt_logprobs < 0 - ): + if self.prompt_logprobs is not None and self.prompt_logprobs < 0: raise ValueError( "prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}." @@ -230,13 +220,9 @@ class SamplingParams: "temperature must be 0 when using beam search." ) if self.top_p < 1.0 - _SAMPLING_EPS: - raise ValueError( - "top_p must be 1 when using beam search." - ) + raise ValueError("top_p must be 1 when using beam search.") if self.top_k != -1: - raise ValueError( - "top_k must be -1 when using beam search." 
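# --- Illustrative sketch (editor's example, not from the patch): the
# SamplingParams hunks above only collapse multi-line range checks onto
# single lines; the validation pattern itself is a plain guard clause, e.g.
# for the top_p nucleus-sampling cutoff:
def check_top_p(top_p: float) -> None:
    # top_p is a probability mass cutoff, so it must lie in (0, 1].
    if not 0.0 < top_p <= 1.0:
        raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
# --- End of sketch.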
- ) + raise ValueError("top_k must be -1 when using beam search.") if self.early_stopping not in [True, False, "never"]: raise ValueError( "early_stopping must be True, False, or 'never', " diff --git a/swarms/models/speecht5.py b/swarms/models/speecht5.py index b9f2653b..5cd9bc9e 100644 --- a/swarms/models/speecht5.py +++ b/swarms/models/speecht5.py @@ -88,15 +88,11 @@ class SpeechT5: self.model_name = model_name self.vocoder_name = vocoder_name self.dataset_name = dataset_name - self.processor = SpeechT5Processor.from_pretrained( - self.model_name - ) + self.processor = SpeechT5Processor.from_pretrained(self.model_name) self.model = SpeechT5ForTextToSpeech.from_pretrained( self.model_name ) - self.vocoder = SpeechT5HifiGan.from_pretrained( - self.vocoder_name - ) + self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) self.embeddings_dataset = load_dataset( self.dataset_name, split="validation" ) @@ -121,9 +117,7 @@ class SpeechT5: def set_model(self, model_name: str): """Set the model to a new model.""" self.model_name = model_name - self.processor = SpeechT5Processor.from_pretrained( - self.model_name - ) + self.processor = SpeechT5Processor.from_pretrained(self.model_name) self.model = SpeechT5ForTextToSpeech.from_pretrained( self.model_name ) @@ -131,9 +125,7 @@ class SpeechT5: def set_vocoder(self, vocoder_name): """Set the vocoder to a new vocoder.""" self.vocoder_name = vocoder_name - self.vocoder = SpeechT5HifiGan.from_pretrained( - self.vocoder_name - ) + self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) def set_embeddings_dataset(self, dataset_name): """Set the embeddings dataset to a new dataset.""" diff --git a/swarms/models/ssd_1b.py b/swarms/models/ssd_1b.py index 3042d1ab..fab419a7 100644 --- a/swarms/models/ssd_1b.py +++ b/swarms/models/ssd_1b.py @@ -127,9 +127,7 @@ class SSD1B: if task in self.cache: return self.cache[task] try: - img = self.pipe( - prompt=task, neg_prompt=neg_prompt - ).images[0] + img = self.pipe(prompt=task, neg_prompt=neg_prompt).images[0] # Generate a unique filename for the image img_name = f"{uuid.uuid4()}.{self.image_format}" @@ -223,9 +221,7 @@ class SSD1B: executor.submit(self, task): task for task in tasks } results = [] - for future in concurrent.futures.as_completed( - future_to_task - ): + for future in concurrent.futures.as_completed(future_to_task): task = future_to_task[future] try: img = future.result() @@ -272,9 +268,7 @@ class SSD1B: """Str method for the SSD1B class""" return f"SSD1B(image_url={self.image_url})" - @backoff.on_exception( - backoff.expo, Exception, max_tries=max_retries - ) + @backoff.on_exception(backoff.expo, Exception, max_tries=max_retries) def rate_limited_call(self, task: str): """Rate limited call to the SSD1B API""" return self.__call__(task) diff --git a/swarms/models/together.py b/swarms/models/together.py index 37d9d0e5..06cc18ba 100644 --- a/swarms/models/together.py +++ b/swarms/models/together.py @@ -120,9 +120,7 @@ class TogetherLLM(AbstractLLM): out = response.json() content = ( - out["choices"][0] - .get("message", {}) - .get("content", None) + out["choices"][0].get("message", {}).get("content", None) ) if self.streaming_enabled: content = self.stream_response(content) diff --git a/swarms/models/ultralytics_model.py b/swarms/models/ultralytics_model.py index 3cb9c956..fcaf319d 100644 --- a/swarms/models/ultralytics_model.py +++ b/swarms/models/ultralytics_model.py @@ -15,9 +15,7 @@ class UltralyticsModel(BaseMultiModalModel): **kwargs: Arbitrary keyword arguments. 
""" - def __init__( - self, model_name: str = "yolov8n.pt", *args, **kwargs - ): + def __init__(self, model_name: str = "yolov8n.pt", *args, **kwargs): super().__init__(*args, **kwargs) self.model_name = model_name diff --git a/swarms/models/wizard_storytelling.py b/swarms/models/wizard_storytelling.py index 0dd6c1a1..1052d88b 100644 --- a/swarms/models/wizard_storytelling.py +++ b/swarms/models/wizard_storytelling.py @@ -78,9 +78,7 @@ class WizardLLMStoryTeller: bnb_config = BitsAndBytesConfig(**quantization_config) try: - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_id - ) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.model = AutoModelForCausalLM.from_pretrained( self.model_id, quantization_config=bnb_config ) diff --git a/swarms/models/yarn_mistral.py b/swarms/models/yarn_mistral.py index ff65b856..f55b9996 100644 --- a/swarms/models/yarn_mistral.py +++ b/swarms/models/yarn_mistral.py @@ -78,9 +78,7 @@ class YarnMistral128: bnb_config = BitsAndBytesConfig(**quantization_config) try: - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_id - ) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.model = AutoModelForCausalLM.from_pretrained( self.model_id, quantization_config=bnb_config, diff --git a/swarms/models/yi_200k.py b/swarms/models/yi_200k.py index 1f1258aa..8f9f7635 100644 --- a/swarms/models/yi_200k.py +++ b/swarms/models/yi_200k.py @@ -87,9 +87,7 @@ class Yi34B200k: top_k=self.top_k, top_p=self.top_p, ) - return self.tokenizer.decode( - outputs[0], skip_special_tokens=True - ) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) # # Example usage diff --git a/swarms/prompts/self_operating_prompt.py b/swarms/prompts/self_operating_prompt.py index bb4856e0..50e13eb4 100644 --- a/swarms/prompts/self_operating_prompt.py +++ b/swarms/prompts/self_operating_prompt.py @@ -91,8 +91,7 @@ def format_vision_prompt(objective, previous_action): """ if previous_action: previous_action = ( - "Here was the previous action you took:" - f" {previous_action}" + "Here was the previous action you took:" f" {previous_action}" ) else: previous_action = "" diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index fc7a023e..2b079c9d 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -19,7 +19,7 @@ from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( ) from swarms.structs.conversation import Conversation from swarms.tools.tool import BaseTool -from swarms.utils.code_interpreter import SubprocessCodeInterpreter +from swarms.tools.code_interpreter import SubprocessCodeInterpreter from swarms.utils.data_to_text import data_to_text from swarms.utils.parse_code import extract_code_from_markdown from swarms.utils.pdf_to_text import pdf_to_text @@ -326,9 +326,7 @@ class Agent: ) else: - tools_prompt = tool_usage_worker_prompt( - tools=self.tools - ) + tools_prompt = tool_usage_worker_prompt(tools=self.tools) # Append the tools prompt to the short_term_memory self.short_memory.add( @@ -354,9 +352,7 @@ class Agent: f"{self.agent_name}.log", level="INFO", colorize=True, - format=( - "{time} {message}" - ), + format=("{time} {message}"), backtrace=True, diagnose=True, ) @@ -403,9 +399,7 @@ class Agent: self.llm.temperature = 0.7 except Exception as error: print( - colored( - f"Error dynamically changing temperature: {error}" - ) + colored(f"Error dynamically changing temperature: {error}") ) def format_prompt(self, template, **kwargs: Any) -> str: @@ -418,24 +412,16 @@ class Agent: 
logger.info(f"Adding task to memory: {task}") self.short_memory.add(f"{self.user_name}: {task}") except Exception as error: - print( - colored( - f"Error adding task to memory: {error}", "red" - ) - ) + print(colored(f"Error adding task to memory: {error}", "red")) def add_message_to_memory(self, message: str): """Add the message to the memory""" try: logger.info(f"Adding message to memory: {message}") - self.short_memory.add( - role=self.agent_name, content=message - ) + self.short_memory.add(role=self.agent_name, content=message) except Exception as error: print( - colored( - f"Error adding message to memory: {error}", "red" - ) + colored(f"Error adding message to memory: {error}", "red") ) def add_message_to_memory_and_truncate(self, message: str): @@ -549,9 +535,7 @@ class Agent: history = [f"{user_name}: {task}"] return history - def _dynamic_prompt_setup( - self, dynamic_prompt: str, task: str - ) -> str: + def _dynamic_prompt_setup(self, dynamic_prompt: str, task: str) -> str: """_dynamic_prompt_setup summary Args: @@ -561,9 +545,7 @@ class Agent: Returns: str: _description_ """ - dynamic_prompt = ( - dynamic_prompt or self.construct_dynamic_prompt() - ) + dynamic_prompt = dynamic_prompt or self.construct_dynamic_prompt() combined_prompt = f"{dynamic_prompt}\n{task}" return combined_prompt @@ -581,9 +563,7 @@ class Agent: self.activate_autonomous_agent() if task: - self.short_memory.add( - role=self.user_name, content=task - ) + self.short_memory.add(role=self.user_name, content=task) loop_count = 0 response = None @@ -600,9 +580,7 @@ class Agent: if self.dynamic_temperature_enabled: self.dynamic_temperature() - task_prompt = ( - self.short_memory.return_history_as_string() - ) + task_prompt = self.short_memory.return_history_as_string() attempt = 0 success = False @@ -621,9 +599,7 @@ class Agent: if self.tools: # Extract code from markdown - response = extract_code_from_markdown( - response - ) + response = extract_code_from_markdown(response) # Execute the tool by name execute_tool_by_name( @@ -634,15 +610,13 @@ class Agent: if self.code_interpreter: # Extract code from markdown - extracted_code = ( - extract_code_from_markdown(response) + extracted_code = extract_code_from_markdown( + response ) # Execute the code # execution = execute_command(extracted_code) - execution = CodeExecutor().run( - extracted_code - ) + execution = CodeExecutor().run(extracted_code) # Add the execution to the memory self.short_memory.add( @@ -658,9 +632,7 @@ class Agent: ) if self.evaluator: - evaluated_response = self.evaluator( - response - ) + evaluated_response = self.evaluator(response) print( "Evaluated Response:" f" {evaluated_response}" @@ -672,9 +644,7 @@ class Agent: # Sentiment analysis if self.sentiment_analyzer: - sentiment = self.sentiment_analyzer( - response - ) + sentiment = self.sentiment_analyzer(response) print(f"Sentiment: {sentiment}") if sentiment > self.sentiment_threshold: @@ -726,9 +696,8 @@ class Agent: and self._check_stopping_condition(response) ): break - elif ( - self.stopping_func is not None - and self.stopping_func(response) + elif self.stopping_func is not None and self.stopping_func( + response ): break @@ -826,9 +795,7 @@ class Agent: context = f""" System: This reminds you of these events from your past: [{ltr}] """ - return self.short_memory.add( - role=self.agent_name, content=context - ) + return self.short_memory.add(role=self.agent_name, content=context) def add_memory(self, message: str): """Add a memory to the agent @@ -840,9 +807,7 @@ class Agent: _type_: 
_description_ """ logger.info(f"Adding memory: {message}") - return self.short_memory.add( - role=self.agent_name, content=message - ) + return self.short_memory.add(role=self.agent_name, content=message) async def run_concurrent(self, tasks: List[str], **kwargs): """ @@ -889,9 +854,7 @@ class Agent: json.dump(self.short_memory, f) # print(f"Saved agent history to {file_path}") except Exception as error: - print( - colored(f"Error saving agent history: {error}", "red") - ) + print(colored(f"Error saving agent history: {error}", "red")) def load(self, file_path: str): """ @@ -916,23 +879,11 @@ class Agent: Prints the entire history and memory of the agent. Each message is colored and formatted for better readability. """ - print( - colored( - "Agent History and Memory", "cyan", attrs=["bold"] - ) - ) - print( - colored( - "========================", "cyan", attrs=["bold"] - ) - ) - for loop_index, history in enumerate( - self.short_memory, start=1 - ): + print(colored("Agent History and Memory", "cyan", attrs=["bold"])) + print(colored("========================", "cyan", attrs=["bold"])) + for loop_index, history in enumerate(self.short_memory, start=1): print( - colored( - f"\nLoop {loop_index}:", "yellow", attrs=["bold"] - ) + colored(f"\nLoop {loop_index}:", "yellow", attrs=["bold"]) ) for message in history: speaker, _, message_text = message.partition(": ") @@ -943,8 +894,7 @@ class Agent: ) else: print( - colored(f"{speaker}:", "blue") - + f" {message_text}" + colored(f"{speaker}:", "blue") + f" {message_text}" ) print(colored("------------------------", "cyan")) print(colored("End of Agent History", "cyan", attrs=["bold"])) @@ -975,9 +925,7 @@ class Agent: self.short_memory.add( role=self.agent_name, content=response ) - self.short_memory.add( - role=self.user_name, content=task - ) + self.short_memory.add(role=self.user_name, content=task) else: self.short_memory.add( role=self.agent_name, content=response @@ -1054,9 +1002,7 @@ class Agent: Apply the response filters to the response """ - logger.info( - f"Applying response filters to response: {response}" - ) + logger.info(f"Applying response filters to response: {response}") for word in self.response_filters: response = response.replace(word, "[FILTERED]") return response @@ -1096,9 +1042,7 @@ class Agent: with open(file_path, "w") as f: yaml.dump(self.__dict__, f) except Exception as error: - print( - colored(f"Error saving agent to YAML: {error}", "red") - ) + print(colored(f"Error saving agent to YAML: {error}", "red")) def save_state(self, file_path: str) -> None: """ @@ -1126,9 +1070,7 @@ class Agent: "retry_interval": self.retry_interval, "interactive": self.interactive, "dashboard": self.dashboard, - "dynamic_temperature": ( - self.dynamic_temperature_enabled - ), + "dynamic_temperature": (self.dynamic_temperature_enabled), "autosave": self.autosave, "saved_state_path": self.saved_state_path, "max_loops": self.max_loops, @@ -1137,14 +1079,10 @@ class Agent: with open(file_path, "w") as f: json.dump(state, f, indent=4) - saved = colored( - f"Saved agent state to: {file_path}", "green" - ) + saved = colored(f"Saved agent state to: {file_path}", "green") print(saved) except Exception as error: - print( - colored(f"Error saving agent state: {error}", "red") - ) + print(colored(f"Error saving agent state: {error}", "red")) def state_to_str(self): """Transform the JSON into a string""" @@ -1163,9 +1101,7 @@ class Agent: "retry_interval": self.retry_interval, "interactive": self.interactive, "dashboard": self.dashboard, - 
"dynamic_temperature": ( - self.dynamic_temperature_enabled - ), + "dynamic_temperature": (self.dynamic_temperature_enabled), "autosave": self.autosave, "saved_state_path": self.saved_state_path, "max_loops": self.max_loops, @@ -1214,9 +1150,7 @@ class Agent: print(f"Agent state loaded from {file_path}") except Exception as error: - print( - colored(f"Error loading agent state: {error}", "red") - ) + print(colored(f"Error loading agent state: {error}", "red")) def retry_on_failure( self, @@ -1232,9 +1166,7 @@ class Agent: try: return function() except Exception as error: - logging.error( - f"Error generating response: {error}" - ) + logging.error(f"Error generating response: {error}") attempt += 1 time.sleep(retry_delay) raise Exception("All retry attempts failed") @@ -1320,9 +1252,7 @@ class Agent: for doc in docs: data = data_to_text(doc) - return self.short_memory.add( - role=self.user_name, content=data - ) + return self.short_memory.add(role=self.user_name, content=data) except Exception as error: print(colored(f"Error ingesting docs: {error}", "red")) @@ -1338,9 +1268,7 @@ class Agent: try: logger.info(f"Ingesting pdf: {pdf}") text = pdf_to_text(pdf) - return self.short_memory.add( - role=self.user_name, content=text - ) + return self.short_memory.add(role=self.user_name, content=text) except Exception as error: print(colored(f"Error ingesting pdf: {error}", "red")) @@ -1361,11 +1289,7 @@ class Agent: message = f"{agent_name}: {message}" return self.run(message, *args, **kwargs) except Exception as error: - print( - colored( - f"Error sending agent message: {error}", "red" - ) - ) + print(colored(f"Error sending agent message: {error}", "red")) def truncate_history(self): """ @@ -1407,9 +1331,7 @@ class Agent: for file in files: text = data_to_text(file) - return self.short_memory.add( - role=self.user_name, content=text - ) + return self.short_memory.add(role=self.user_name, content=text) except Exception as error: print( colored( diff --git a/swarms/structs/agent_rearrange.py b/swarms/structs/agent_rearrange.py index 4c56d0df..1a980985 100644 --- a/swarms/structs/agent_rearrange.py +++ b/swarms/structs/agent_rearrange.py @@ -117,9 +117,7 @@ class AgentRearrange(BaseSwarm): return None task_to_run = specific_tasks.get(dest_agent_name, task) if self.custom_prompt: - out = dest_agent.run( - f"{task_to_run} {self.custom_prompt}" - ) + out = dest_agent.run(f"{task_to_run} {self.custom_prompt}") else: out = dest_agent.run(f"{task_to_run} (from {source})") return out @@ -138,9 +136,7 @@ class AgentRearrange(BaseSwarm): results.append(result) else: for destination in destinations: - task = specific_tasks.get( - destination, default_task - ) + task = specific_tasks.get(destination, default_task) destination_agent = self.self_find_agent_by_name( destination ) @@ -156,9 +152,7 @@ class AgentRearrange(BaseSwarm): **specific_tasks, ): self.flows.clear() # Reset previous flows - results = self.process_flows( - pattern, default_task, specific_tasks - ) + results = self.process_flows(pattern, default_task, specific_tasks) return results diff --git a/swarms/structs/async_workflow.py b/swarms/structs/async_workflow.py index 6cf9e312..9ac9018a 100644 --- a/swarms/structs/async_workflow.py +++ b/swarms/structs/async_workflow.py @@ -67,9 +67,7 @@ class AsyncWorkflow: except Exception as error: logger.error(f"[ERROR][AsyncWorkflow] {error}") - async def delete( - self, task: Any = None, tasks: List[Task] = None - ): + async def delete(self, task: Any = None, tasks: List[Task] = None): """Delete a task 
from the workflow""" try: if task: diff --git a/swarms/structs/auto_swarm.py b/swarms/structs/auto_swarm.py index 10946331..717829dc 100644 --- a/swarms/structs/auto_swarm.py +++ b/swarms/structs/auto_swarm.py @@ -140,9 +140,7 @@ class AutoSwarmRouter(BaseSwarm): if self.name in self.swarm_dict: # If a match is found then send the task to the swarm - out = self.swarm_dict[self.name].run( - task, *args, **kwargs - ) + out = self.swarm_dict[self.name].run(task, *args, **kwargs) if self.custom_postprocess: # If custom postprocess function is provided then run it @@ -151,9 +149,7 @@ class AutoSwarmRouter(BaseSwarm): return out # If no match is found then return None - raise ValueError( - f"Swarm with name {self.name} not found." - ) + raise ValueError(f"Swarm with name {self.name} not found.") except Exception as e: logger.error(f"Error: {e}") raise e diff --git a/swarms/structs/autoscaler.py b/swarms/structs/autoscaler.py index 860b6423..c3a384e3 100644 --- a/swarms/structs/autoscaler.py +++ b/swarms/structs/autoscaler.py @@ -155,9 +155,7 @@ class AutoScaler(BaseStructure): for _ in range(new_agents_counts): self.agents_pool.append(self.agents[0]()) except Exception as error: - print( - f"Error scaling up: {error} try again with a new task" - ) + print(f"Error scaling up: {error} try again with a new task") def scale_down(self): """scale down""" @@ -169,13 +167,10 @@ class AutoScaler(BaseStructure): del self.agents_pool[-1] # remove last agent except Exception as error: print( - f"Error scaling down: {error} try again with a new" - " task" + f"Error scaling down: {error} try again with a new" " task" ) - def run( - self, agent_id, task: Optional[str] = None, *args, **kwargs - ): + def run(self, agent_id, task: Optional[str] = None, *args, **kwargs): """Run agent the task on the agent id Args: @@ -203,11 +198,7 @@ class AutoScaler(BaseStructure): sleep(60) # check minute pending_tasks = self.task_queue.qsize() active_agents = sum( - [ - 1 - for agent in self.agents_pool - if agent.is_busy() - ] + [1 for agent in self.agents_pool if agent.is_busy()] ) if ( @@ -246,17 +237,13 @@ class AutoScaler(BaseStructure): if available_agent: available_agent.run(task) except Exception as error: - print( - f"Error starting: {error} try again with a new task" - ) + print(f"Error starting: {error} try again with a new task") def check_agent_health(self): """Checks the health of each agent and replaces unhealthy agents.""" for i, agent in enumerate(self.agents_pool): if not agent.is_healthy(): - logging.warning( - f"Replacing unhealthy agent at index {i}" - ) + logging.warning(f"Replacing unhealthy agent at index {i}") self.agents_pool[i] = self.agent() def balance_load(self): @@ -273,9 +260,7 @@ class AutoScaler(BaseStructure): " task" ) - def set_scaling_strategy( - self, strategy: Callable[[int, int], int] - ): + def set_scaling_strategy(self, strategy: Callable[[int, int], int]): """Set a custom scaling strategy.""" self.custom_scale_strategy = strategy diff --git a/swarms/structs/base_structure.py b/swarms/structs/base_structure.py index 170960cf..5cc7f57e 100644 --- a/swarms/structs/base_structure.py +++ b/swarms/structs/base_structure.py @@ -187,9 +187,7 @@ class BaseStructure(BaseModel): async def run_async(self, *args, **kwargs): """Run the structure asynchronously.""" loop = asyncio.get_event_loop() - return await loop.run_in_executor( - None, self.run, *args, **kwargs - ) + return await loop.run_in_executor(None, self.run, *args, **kwargs) async def save_metadata_async(self, metadata: Dict[str, 
Any]): """Save metadata to file asynchronously. @@ -222,9 +220,7 @@ class BaseStructure(BaseModel): None, self.log_error, error_message ) - async def save_artifact_async( - self, artifact: Any, artifact_name: str - ): + async def save_artifact_async(self, artifact: Any, artifact_name: str): """Save artifact to file asynchronously. Args: @@ -266,9 +262,7 @@ class BaseStructure(BaseModel): None, self.log_event, event, event_type ) - async def asave_to_file( - self, data: Any, file: str, *args, **kwargs - ): + async def asave_to_file(self, data: Any, file: str, *args, **kwargs): """Save data to file asynchronously. Args: @@ -357,8 +351,7 @@ class BaseStructure(BaseModel): """ with ThreadPoolExecutor(max_workers=batch_size) as executor: futures = [ - executor.submit(self.run, data) - for data in batched_data + executor.submit(self.run, data) for data in batched_data ] return [future.result() for future in futures] @@ -418,9 +411,7 @@ class BaseStructure(BaseModel): _type_: _description_ """ self.monitor_resources() - return self.run_batched( - batched_data, batch_size, *args, **kwargs - ) + return self.run_batched(batched_data, batch_size, *args, **kwargs) # x = BaseStructure() diff --git a/swarms/structs/base_swarm.py b/swarms/structs/base_swarm.py index 30012d80..95f2b57f 100644 --- a/swarms/structs/base_swarm.py +++ b/swarms/structs/base_swarm.py @@ -361,9 +361,7 @@ class BaseSwarm(ABC): task (Optional[str], optional): _description_. Defaults to None. """ loop = asyncio.get_event_loop() - result = loop.run_until_complete( - self.arun(task, *args, **kwargs) - ) + result = loop.run_until_complete(self.arun(task, *args, **kwargs)) return result def run_batch_async(self, tasks: List[str], *args, **kwargs): @@ -533,9 +531,7 @@ class BaseSwarm(ABC): Agent: Instance of Agent representing the retrieved Agent, or None if not found. """ - def join_swarm( - self, from_entity: Agent | Agent, to_entity: Agent - ): + def join_swarm(self, from_entity: Agent | Agent, to_entity: Agent): """ Add a relationship between a Swarm and an Agent or other Swarm to the registry. diff --git a/swarms/structs/base_workflow.py b/swarms/structs/base_workflow.py index e5ac4811..e2f0b0c7 100644 --- a/swarms/structs/base_workflow.py +++ b/swarms/structs/base_workflow.py @@ -68,9 +68,7 @@ class BaseWorkflow(BaseStructure): elif tasks: self.task_pool.extend(tasks) else: - raise ValueError( - "You must provide a task or a list of tasks" - ) + raise ValueError("You must provide a task or a list of tasks") def add_agent(self, agent: Agent, *args, **kwargs): return self.agent_pool(agent) @@ -122,23 +120,17 @@ class BaseWorkflow(BaseStructure): Dict[str, Any]: The results of each task in the workflow """ try: - return { - task.description: task.result for task in self.tasks - } + return {task.description: task.result for task in self.tasks} except Exception as error: print( - colored( - f"Error getting task results: {error}", "red" - ), + colored(f"Error getting task results: {error}", "red"), ) def remove_task(self, task: str) -> None: """Remove tasks from sequential workflow""" try: self.tasks = [ - task - for task in self.tasks - if task.description != task + task for task in self.tasks if task.description != task ] except Exception as error: print( @@ -177,9 +169,7 @@ class BaseWorkflow(BaseStructure): task.kwargs.update(updates) break else: - raise ValueError( - f"Task {task} not found in workflow." 
- ) + raise ValueError(f"Task {task} not found in workflow.") except Exception as error: print( colored( @@ -214,9 +204,7 @@ class BaseWorkflow(BaseStructure): self.tasks.remove(task) break else: - raise ValueError( - f"Task {task} not found in workflow." - ) + raise ValueError(f"Task {task} not found in workflow.") except Exception as error: print( colored( @@ -299,9 +287,7 @@ class BaseWorkflow(BaseStructure): ) ) - def load_workflow_state( - self, filepath: str = None, **kwargs - ) -> None: + def load_workflow_state(self, filepath: str = None, **kwargs) -> None: """ Loads the workflow state from a json file and restores the workflow state. diff --git a/swarms/structs/blockslist.py b/swarms/structs/blockslist.py index f1d739be..79dbdb89 100644 --- a/swarms/structs/blockslist.py +++ b/swarms/structs/blockslist.py @@ -102,15 +102,11 @@ class BlocksList(BaseStructure): return [block for block in self.blocks if block.id == id] def get_by_parent(self, parent: str): - return [ - block for block in self.blocks if block.parent == parent - ] + return [block for block in self.blocks if block.parent == parent] def get_by_parent_id(self, parent_id: str): return [ - block - for block in self.blocks - if block.parent_id == parent_id + block for block in self.blocks if block.parent_id == parent_id ] def get_by_parent_name(self, parent_name: str): diff --git a/swarms/structs/company.py b/swarms/structs/company.py index 06b7bdfe..3f304891 100644 --- a/swarms/structs/company.py +++ b/swarms/structs/company.py @@ -16,9 +16,7 @@ class Company: shared_instructions: str = None ceo: Optional[Agent] = None agents: List[Agent] = field(default_factory=list) - agent_interactions: Dict[str, List[str]] = field( - default_factory=dict - ) + agent_interactions: Dict[str, List[str]] = field(default_factory=dict) history: Conversation = field(default_factory=Conversation) def __post_init__(self): @@ -46,9 +44,7 @@ class Company: self.agents.append(agent) except Exception as error: - logger.error( - f"[ERROR][CLASS: Company][METHOD: add] {error}" - ) + logger.error(f"[ERROR][CLASS: Company][METHOD: add] {error}") raise error def get(self, agent_name: str) -> Agent: @@ -73,9 +69,7 @@ class Company: " company." 
) except Exception as error: - logger.error( - f"[ERROR][CLASS: Company][METHOD: get] {error}" - ) + logger.error(f"[ERROR][CLASS: Company][METHOD: get] {error}") raise error def remove(self, agent: Agent) -> None: @@ -118,9 +112,7 @@ class Company: elif isinstance(node, list): for agent in node: if not isinstance(agent, Agent): - raise ValueError( - "Invalid agent in org chart" - ) + raise ValueError("Invalid agent in org chart") self.add(agent) for i, agent in enumerate(node): @@ -153,9 +145,7 @@ class Company: """ if agent1.ai_name not in self.agents_interactions: self.agents_interactions[agent1.ai_name] = [] - self.agents_interactions[agent1.ai_name].append( - agent2.ai_name - ) + self.agents_interactions[agent1.ai_name].append(agent2.ai_name) def run(self): """ diff --git a/swarms/structs/concurrent_workflow.py b/swarms/structs/concurrent_workflow.py index eccc5ea5..2893fb9e 100644 --- a/swarms/structs/concurrent_workflow.py +++ b/swarms/structs/concurrent_workflow.py @@ -35,9 +35,7 @@ class ConcurrentWorkflow(BaseStructure): max_loops: int = 1 max_workers: int = 5 autosave: bool = False - saved_state_filepath: Optional[str] = ( - "runs/concurrent_workflow.json" - ) + saved_state_filepath: Optional[str] = "runs/concurrent_workflow.json" print_results: bool = False return_results: bool = False use_processes: bool = False @@ -89,9 +87,7 @@ class ConcurrentWorkflow(BaseStructure): } results = [] - for future in concurrent.futures.as_completed( - futures - ): + for future in concurrent.futures.as_completed(futures): task = futures[future] try: result = future.result() diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py index 003b023e..2a63ac13 100644 --- a/swarms/structs/conversation.py +++ b/swarms/structs/conversation.py @@ -339,9 +339,7 @@ class Conversation(BaseStructure): def update_from_database(self, *args, **kwargs): """Update the conversation history from the database""" - self.database.update( - "conversation", self.conversation_history - ) + self.database.update("conversation", self.conversation_history) def get_from_database(self, *args, **kwargs): """Get the conversation history from the database""" diff --git a/swarms/structs/debate.py b/swarms/structs/debate.py index 95c889d3..5a80265a 100644 --- a/swarms/structs/debate.py +++ b/swarms/structs/debate.py @@ -140,9 +140,7 @@ class Debate: self.affirmative.system_prompt( self.save_file["player_meta_prompt"] ) - self.negative.system_prompt( - self.save_file["player_meta_prompt"] - ) + self.negative.system_prompt(self.save_file["player_meta_prompt"]) self.moderator.system_prompt( self.save_file["moderator_meta_prompt"] ) @@ -191,14 +189,10 @@ class Debate: def save_file_to_json(self, id): now = datetime.now() current_time = now.strftime("%Y-%m-%d_%H:%M:%S") - save_file_path = os.path.join( - self.save_file_dir, f"{id}.json" - ) + save_file_path = os.path.join(self.save_file_dir, f"{id}.json") self.save_file["end_time"] = current_time - json_str = json.dumps( - self.save_file, ensure_ascii=False, indent=4 - ) + json_str = json.dumps(self.save_file, ensure_ascii=False, indent=4) with open(save_file_path, "w") as f: f.write(json_str) diff --git a/swarms/structs/graph_workflow.py b/swarms/structs/graph_workflow.py index 69ce0002..8e1320d0 100644 --- a/swarms/structs/graph_workflow.py +++ b/swarms/structs/graph_workflow.py @@ -126,15 +126,11 @@ class GraphWorkflow(BaseStructure): if from_node in self.graph: for condition_value, to_node in edge_dict.items(): if to_node in self.graph: - 
self.graph[from_node]["edges"][ - to_node - ] = condition + self.graph[from_node]["edges"][to_node] = condition else: raise ValueError("Node does not exist in graph") else: - raise ValueError( - f"Node {from_node} does not exist in graph" - ) + raise ValueError(f"Node {from_node} does not exist in graph") def run(self): """ @@ -160,9 +156,7 @@ class GraphWorkflow(BaseStructure): ValueError: _description_ """ if node_name not in self.graph: - raise ValueError( - f"Node {node_name} does not exist in graph" - ) + raise ValueError(f"Node {node_name} does not exist in graph") def _check_nodes_exist(self, from_node, to_node): """ diff --git a/swarms/structs/groupchat.py b/swarms/structs/groupchat.py index 57cb6472..fb2adc07 100644 --- a/swarms/structs/groupchat.py +++ b/swarms/structs/groupchat.py @@ -51,8 +51,7 @@ class GroupChat: def next_agent(self, agent: Agent) -> Agent: """Return the next agent in the list.""" return self.agents[ - (self.agent_names.index(agent.name) + 1) - % len(self.agents) + (self.agent_names.index(agent.name) + 1) % len(self.agents) ] def select_speaker_msg(self): @@ -122,9 +121,7 @@ class GroupChat: """ formatted_messages = [] for message in messages: - formatted_message = ( - f"'{message['role']}:{message['content']}" - ) + formatted_message = f"'{message['role']}:{message['content']}" formatted_messages.append(formatted_message) return "\n".join(formatted_messages) diff --git a/swarms/structs/majority_voting.py b/swarms/structs/majority_voting.py index 536b0787..af230f64 100644 --- a/swarms/structs/majority_voting.py +++ b/swarms/structs/majority_voting.py @@ -165,9 +165,7 @@ class MajorityVoting: # If autosave is enabled, save the conversation to a file if self.autosave: - create_file( - str(self.conversation), "majority_voting.json" - ) + create_file(str(self.conversation), "majority_voting.json") # Log the agents logger.info("Initializing majority voting system") @@ -224,9 +222,7 @@ class MajorityVoting: # If an output parser is provided, parse the responses if self.output_parser is not None: - majority_vote = self.output_parser( - responses, *args, **kwargs - ) + majority_vote = self.output_parser(responses, *args, **kwargs) else: majority_vote = majority_voting(responses) diff --git a/swarms/structs/message_pool.py b/swarms/structs/message_pool.py index 88766d06..40010bee 100644 --- a/swarms/structs/message_pool.py +++ b/swarms/structs/message_pool.py @@ -98,9 +98,7 @@ class MessagePool: logger.info("MessagePool initialized") logger.info(f"Number of agents: {len(agents)}") - logger.info( - f"Agents: {[agent.agent_name for agent in agents]}" - ) + logger.info(f"Agents: {[agent.agent_name for agent in agents]}") logger.info(f"moderator: {moderator.agent_name} is available") logger.info(f"Number of turns: {turns}") @@ -188,9 +186,7 @@ class MessagePool: """ # Get the messages before the current turn prev_messages = [ - message - for message in self.messages - if message["turn"] < turn + message for message in self.messages if message["turn"] < turn ] visible_messages = [] diff --git a/swarms/structs/model_parallizer.py b/swarms/structs/model_parallizer.py index 9d27f14c..b3c75b09 100644 --- a/swarms/structs/model_parallizer.py +++ b/swarms/structs/model_parallizer.py @@ -136,8 +136,7 @@ class ModelParallelizer: try: with ThreadPoolExecutor() as executor: future_to_llm = { - executor.submit(llm, task): llm - for llm in self.llms + executor.submit(llm, task): llm for llm in self.llms } responses = [] for future in as_completed(future_to_llm): diff --git 
a/swarms/structs/multi_agent_collab.py b/swarms/structs/multi_agent_collab.py index 8359068d..5010378b 100644 --- a/swarms/structs/multi_agent_collab.py +++ b/swarms/structs/multi_agent_collab.py @@ -123,9 +123,7 @@ class MultiAgentCollaboration: def step(self) -> tuple[str, str]: """Steps through the multi-agent collaboration.""" - speaker_idx = self.select_next_speaker( - self._step, self.agents - ) + speaker_idx = self.select_next_speaker(self._step, self.agents) speaker = self.agents[speaker_idx] message = speaker.send() message = speaker.send() @@ -170,9 +168,7 @@ class MultiAgentCollaboration: bid = self.ask_for_bid(agent) bids.append(bid) max_value = max(bids) - max_indices = [ - i for i, x in enumerate(bids) if x == max_value - ] + max_indices = [i for i, x in enumerate(bids) if x == max_value] idx = random.choice(max_indices) return idx @@ -262,9 +258,7 @@ class MultiAgentCollaboration: for _ in range(self.max_iters): for agent in self.agents: result = agent.run(conversation) - self.results.append( - {"agent": agent, "response": result} - ) + self.results.append({"agent": agent, "response": result}) conversation += result if self.autosave: @@ -317,9 +311,7 @@ class MultiAgentCollaboration: """Tracks and reports the performance of each agent""" performance_data = {} for agent in self.agents: - performance_data[agent.name] = ( - agent.get_performance_metrics() - ) + performance_data[agent.name] = agent.get_performance_metrics() return performance_data def set_interaction_rules(self, rules): diff --git a/swarms/structs/multi_threaded_workflow.py b/swarms/structs/multi_threaded_workflow.py index 475251ba..2617433e 100644 --- a/swarms/structs/multi_threaded_workflow.py +++ b/swarms/structs/multi_threaded_workflow.py @@ -88,9 +88,7 @@ class MultiThreadedWorkflow(BaseWorkflow): """ results = [] - with ThreadPoolExecutor( - max_workers=self.max_workers - ) as executor: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: future_to_task = {} for _ in range(self.tasks_queue.qsize()): priority_task = self.tasks_queue.get_nowait() @@ -127,9 +125,7 @@ class MultiThreadedWorkflow(BaseWorkflow): ) if attempt + 1 < self.retry_attempts: # Retry the task - retry_future = executor.submit( - task.execute - ) + retry_future = executor.submit(task.execute) future_to_task[retry_future] = ( task, attempt + 1, @@ -152,7 +148,5 @@ class MultiThreadedWorkflow(BaseWorkflow): """ with self.lock: - logging.info( - f"Autosaving result for task {task}: {result}" - ) + logging.info(f"Autosaving result for task {task}: {result}") # Actual autosave logic goes here diff --git a/swarms/structs/rearrange.py b/swarms/structs/rearrange.py index 71b77e82..b4614b33 100644 --- a/swarms/structs/rearrange.py +++ b/swarms/structs/rearrange.py @@ -59,9 +59,7 @@ class AgentRearrange: source_name, destinations_str = parts source = self.find_agent_by_name(source_name) if source is None: - logging.error( - f"Source agent {source_name} not found." - ) + logging.error(f"Source agent {source_name} not found.") return False destinations_names = destinations_str.split() @@ -69,13 +67,10 @@ class AgentRearrange: dest = self.find_agent_by_name(dest_name) if dest is None: logging.error( - f"Destination agent {dest_name} not" - " found." + f"Destination agent {dest_name} not" " found." 
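# --- Illustrative sketch (editor's example, not from the patch): the
# MultiAgentCollaboration hunks above select the next speaker by auction,
# taking the highest bid and breaking ties at random. The core selection
# step in isolation (function name is hypothetical):
import random


def pick_highest_bidder(bids: list[int]) -> int:
    max_value = max(bids)
    # Collect every index that achieved the top bid, then choose one at random.
    max_indices = [i for i, x in enumerate(bids) if x == max_value]
    return random.choice(max_indices)
# --- End of sketch.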
) return False - self.flows[source.agent_name].append( - dest.agent_name - ) + self.flows[source.agent_name].append(dest.agent_name) return True except Exception as e: logger.error(f"Error: {e}") @@ -124,9 +119,7 @@ class AgentRearrange: task = tasks.get(dest, task) if self.custom_prompt: - dest_agent.run( - f"{task} {self.custom_prompt}" - ) + dest_agent.run(f"{task} {self.custom_prompt}") else: dest_agent.run(f"{task} (from {source})") # else: @@ -136,8 +129,7 @@ class AgentRearrange: # ) except Exception as e: logger.error( - f"Error: {e} try again by providing agents and" - " pattern" + f"Error: {e} try again by providing agents and" " pattern" ) raise e diff --git a/swarms/structs/recursive_workflow.py b/swarms/structs/recursive_workflow.py index cc0d25b5..60c471a5 100644 --- a/swarms/structs/recursive_workflow.py +++ b/swarms/structs/recursive_workflow.py @@ -39,9 +39,7 @@ class RecursiveWorkflow(BaseStructure): self.stopping_conditions = stopping_conditions self.task_pool = [] - assert ( - self.stop_token is not None - ), "stop_token cannot be None" + assert self.stop_token is not None, "stop_token cannot be None" def add(self, task: Task = None, tasks: List[Task] = None): """Adds a task to the workflow. @@ -80,10 +78,7 @@ class RecursiveWorkflow(BaseStructure): for task in self.task_pool: while True: result = task.run() - if ( - result is not None - and self.stop_token in result - ): + if result is not None and self.stop_token in result: break print(f"{result}") except Exception as error: diff --git a/swarms/structs/schemas.py b/swarms/structs/schemas.py index a370334b..c465b4df 100644 --- a/swarms/structs/schemas.py +++ b/swarms/structs/schemas.py @@ -60,8 +60,7 @@ class StepInput(BaseModel): step: Any = Field( ..., description=( - "Input parameters for the task step. Any value is" - " allowed." + "Input parameters for the task step. Any value is" " allowed." ), examples=['{\n"file_to_refactor": "models.py"\n}'], ) @@ -82,9 +81,7 @@ class TaskRequestBody(BaseModel): input: str | None = Field( None, description="Input prompt for the task.", - examples=[ - "Write the words you receive to the file 'output.txt'." - ], + examples=["Write the words you receive to the file 'output.txt'."], ) additional_input: TaskInput | None = None @@ -138,9 +135,7 @@ class Step(StepRequestBody): description="The name of the task step.", examples=["Write to file"], ) - status: Status = Field( - ..., description="The status of the task step." 
- ) + status: Status = Field(..., description="The status of the task step.") output: str | None = Field( None, description="Output of the task step.", diff --git a/swarms/structs/sequential_workflow.py b/swarms/structs/sequential_workflow.py index 7c94f426..90082f0e 100644 --- a/swarms/structs/sequential_workflow.py +++ b/swarms/structs/sequential_workflow.py @@ -44,9 +44,7 @@ class SequentialWorkflow: task_pool: List[Task] = None max_loops: int = 1 autosave: bool = False - saved_state_filepath: Optional[str] = ( - "sequential_workflow_state.json" - ) + saved_state_filepath: Optional[str] = "sequential_workflow_state.json" restore_state_filepath: Optional[str] = None dashboard: bool = False agents: List[Agent] = None @@ -148,14 +146,11 @@ class SequentialWorkflow: """ try: return { - task.description: task.result - for task in self.task_pool + task.description: task.result for task in self.task_pool } except Exception as error: logger.error( - colored( - f"Error getting task results: {error}", "red" - ), + colored(f"Error getting task results: {error}", "red"), ) def remove_task(self, task: Task) -> None: diff --git a/swarms/structs/swarm_net.py b/swarms/structs/swarm_net.py index b5fe49fd..60e5ec7f 100644 --- a/swarms/structs/swarm_net.py +++ b/swarms/structs/swarm_net.py @@ -132,15 +132,11 @@ class SwarmNetwork(BaseStructure): >>> swarm.add_task("task") """ - self.logger.info( - f"Adding task {task} to queue asynchronously" - ) + self.logger.info(f"Adding task {task} to queue asynchronously") try: # Add task to queue asynchronously with asyncio loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.task_queue.put, task - ) + await loop.run_in_executor(None, self.task_queue.put, task) self.logger.info(f"Task {task} added to queue") except Exception as error: print( @@ -271,9 +267,7 @@ class SwarmNetwork(BaseStructure): try: # Remove agent from pool asynchronously with asyncio loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.remove_agent, agent_id - ) + await loop.run_in_executor(None, self.remove_agent, agent_id) except Exception as error: print(f"Error removing agent from pool: {error}") raise error diff --git a/swarms/structs/swarm_redis_registry.py b/swarms/structs/swarm_redis_registry.py index a15ce55f..45e2c1a9 100644 --- a/swarms/structs/swarm_redis_registry.py +++ b/swarms/structs/swarm_redis_registry.py @@ -84,7 +84,9 @@ class RedisSwarmRegistry(BaseSwarm): query = f""" {match_query} CREATE (a)-[r:joined]->(b) RETURN r - """.replace("\n", "") + """.replace( + "\n", "" + ) self.redis_graph.query(query) @@ -126,9 +128,7 @@ class RedisSwarmRegistry(BaseSwarm): from_node = self._entity_to_node(from_entity) to_node = self._entity_to_node(to_entity) - return self._add_edge( - from_node, to_node, SwarmRelationship.JOINED - ) + return self._add_edge(from_node, to_node, SwarmRelationship.JOINED) def _persist_node(self, node: Node): """ diff --git a/swarms/structs/swarming_architectures.py b/swarms/structs/swarming_architectures.py index 0f0ff885..80866d8f 100644 --- a/swarms/structs/swarming_architectures.py +++ b/swarms/structs/swarming_architectures.py @@ -222,9 +222,7 @@ async def broadcast( await asyncio.gather(*receive_tasks) except Exception as error: - logger.error( - f"[ERROR][CLASS: Agent][METHOD: broadcast] {error}" - ) + logger.error(f"[ERROR][CLASS: Agent][METHOD: broadcast] {error}") raise error @@ -250,7 +248,5 @@ async def one_to_one( try: await receiver.receive_message(sender.ai_name, task) except Exception as error: - 
logger.error( - f"[ERROR][CLASS: Agent][METHOD: one_to_one] {error}" - ) + logger.error(f"[ERROR][CLASS: Agent][METHOD: one_to_one] {error}") raise error diff --git a/swarms/structs/task.py b/swarms/structs/task.py index 5f25eedb..bcdf42c6 100644 --- a/swarms/structs/task.py +++ b/swarms/structs/task.py @@ -81,9 +81,7 @@ class Task: >>> task.result """ - logger.info( - f"[INFO][Task] Executing task: {self.description}" - ) + logger.info(f"[INFO][Task] Executing task: {self.description}") task = self.description try: if isinstance(self.agent, Agent): @@ -98,9 +96,7 @@ class Task: if self.action is not None: self.action() else: - self.result = self.agent.run( - *self.args, **self.kwargs - ) + self.result = self.agent.run(*self.args, **self.kwargs) self.history.append(self.result) except Exception as error: @@ -228,9 +224,7 @@ class Task: else "" ) - result = ( - task.result if task.result is not None else "" - ) + result = task.result if task.result is not None else "" # Add the context of the task to the conversation new_context.add( @@ -239,9 +233,7 @@ class Task: elif task: description = ( - task.description - if task.description is not None - else "" + task.description if task.description is not None else "" ) result = task.result if task.result is not None else "" new_context.add( diff --git a/swarms/structs/team.py b/swarms/structs/team.py index c3abfe1b..d3ee418d 100644 --- a/swarms/structs/team.py +++ b/swarms/structs/team.py @@ -20,9 +20,7 @@ class Team(BaseModel): config (Optional[Json]): Configuration of the Team. Default is None. """ - tasks: Optional[List[Task]] = Field( - None, description="List of tasks" - ) + tasks: Optional[List[Task]] = Field(None, description="List of tasks") agents: Optional[List[Agent]] = Field( None, description="List of agents in this Team." ) @@ -51,9 +49,7 @@ class Team(BaseModel): if values.get("config"): config = json.loads(values.get("config")) if not config.get("agents") or not config.get("tasks"): - raise ValueError( - "Config should have agents and tasks." - ) + raise ValueError("Config should have agents and tasks.") values["agents"] = [ Agent(**agent) for agent in config["agents"] diff --git a/swarms/structs/utils.py b/swarms/structs/utils.py index dd32b3df..dcebc7ed 100644 --- a/swarms/structs/utils.py +++ b/swarms/structs/utils.py @@ -19,9 +19,7 @@ def parse_tasks( """ tasks = {} for line in task.split("\n"): - if line.startswith("") and line.endwith( - "" - ): + if line.startswith("") and line.endwith(""): agent_id, task = line[10:-11].split("><") tasks[agent_id] = task return tasks @@ -89,9 +87,7 @@ def find_token_in_text(text: str, token: str = "") -> bool: return False -def extract_key_from_json( - json_response: str, key: str -) -> Optional[str]: +def extract_key_from_json(json_response: str, key: str) -> Optional[str]: """ Extract a specific key from a JSON response. @@ -106,9 +102,7 @@ def extract_key_from_json( return response_dict.get(key) -def extract_tokens_from_text( - text: str, tokens: List[str] -) -> List[str]: +def extract_tokens_from_text(text: str, tokens: List[str]) -> List[str]: """ Extract a list of tokens from a text response. diff --git a/swarms/structs/yaml_model.py b/swarms/structs/yaml_model.py index 5ec104ed..5e242867 100644 --- a/swarms/structs/yaml_model.py +++ b/swarms/structs/yaml_model.py @@ -53,8 +53,7 @@ def create_yaml_schema_from_dict( "type": get_type_name(model_field.outer_type_), "default": field_info.default, "description": ( - field_info.description - or "No description provided." 
diff --git a/swarms/structs/yaml_model.py b/swarms/structs/yaml_model.py
index 5ec104ed..5e242867 100644
--- a/swarms/structs/yaml_model.py
+++ b/swarms/structs/yaml_model.py
@@ -53,8 +53,7 @@ def create_yaml_schema_from_dict(
                 "type": get_type_name(model_field.outer_type_),
                 "default": field_info.default,
                 "description": (
-                    field_info.description
-                    or "No description provided."
+                    field_info.description or "No description provided."
                 ),
             }
         else:
@@ -127,7 +126,8 @@ class YamlModel(BaseModel):
         """
         return yaml.safe_dump(self.dict(), sort_keys=False)
 
-    def from_yaml(cls, yaml_str: str):
+    @classmethod
+    def from_yaml(cls, yaml_str: str):
         """
         Create an instance of the class from a YAML string.
@@ -146,13 +145,11 @@ class YamlModel(BaseModel):
         return None
 
     @staticmethod
     def json_to_yaml(json_str: str):
         """
         Convert a JSON string to a YAML string.
         """
-        data = json.loads(
-            json_str
-        )  # Convert JSON string to dictionary
+        data = json.loads(json_str)  # Convert JSON string to dictionary
         return yaml.dump(data)
 
     def save_to_yaml(self, filename: str):
@@ -189,7 +186,7 @@ class YamlModel(BaseModel):
     # return yaml.safe_dump(schema, sort_keys=False)
 
     def create_yaml_schema_from_dict(
-        data: Dict[str, Any], model_class: Type
+        self, data: Dict[str, Any], model_class: Type
     ) -> str:
         """
         Generate a YAML schema based on a dictionary and a class (can be a Pydantic model, regular class, or dataclass).
@@ -205,3 +202,22 @@ class YamlModel(BaseModel):
         >>> data = {'name': 'Alice', 'age': 30, 'is_active': True}
         """
         return create_yaml_schema_from_dict(data, model_class)
+
+    def yaml_to_dict(self, yaml_str: str):
+        """
+        Convert a YAML string to a Python dictionary.
+        """
+        return yaml.safe_load(yaml_str)
+
+    def dict_to_yaml(self, data: Dict[str, Any]):
+        """
+        Convert a Python dictionary to a YAML string.
+        """
+        return yaml.safe_dump(data, sort_keys=False)
+
+
+# data = {'name': 'Alice', 'age': 30, 'is_active': True}
+
+# # Convert the dictionary to a YAML string
+# yaml_model = YamlModel().dict_to_yaml(data)
+# print(yaml_model)
diff --git a/swarms/telemetry/check_update.py b/swarms/telemetry/check_update.py
index 2a5df8a9..f4dd9956 100644
--- a/swarms/telemetry/check_update.py
+++ b/swarms/telemetry/check_update.py
@@ -37,6 +37,4 @@ def check_for_update():
     # Get the current version using pkg_resources
     current_version = pkg_resources.get_distribution("swarms").version
 
-    return version.parse(latest_version) > version.parse(
-        current_version
-    )
+    return version.parse(latest_version) > version.parse(current_version)
diff --git a/swarms/telemetry/sys_info.py b/swarms/telemetry/sys_info.py
index 3669fbbd..ae59e792 100644
--- a/swarms/telemetry/sys_info.py
+++ b/swarms/telemetry/sys_info.py
@@ -31,9 +31,7 @@ def get_swarms_verison():
         )
     except Exception as e:
         swarms_verison_cmd = str(e)
-    swarms_verison_pkg = pkg_resources.get_distribution(
-        "swarms"
-    ).version
+    swarms_verison_pkg = pkg_resources.get_distribution("swarms").version
     swarms_verison = swarms_verison_cmd, swarms_verison_pkg
     return swarms_verison
diff --git a/swarms/utils/code_interpreter.py b/swarms/tools/code_interpreter.py
similarity index 98%
rename from swarms/utils/code_interpreter.py
rename to swarms/tools/code_interpreter.py
index a586a1eb..dba44f6a 100644
--- a/swarms/utils/code_interpreter.py
+++ b/swarms/tools/code_interpreter.py
@@ -22,7 +22,7 @@ class SubprocessCodeInterpreter:
     def __init__(
         self,
-        start_cmd: str = "",
+        start_cmd: str = "python3",
         debug_mode: bool = False,
     ):
         self.process = None
@@ -139,8 +139,7 @@ class SubprocessCodeInterpreter:
                 yield {"output": traceback.format_exc()}
                 yield {
                     "output": (
-                        "Retrying..."
-                        f" ({retry_count}/{max_retries})"
+                        "Retrying..."
f" ({retry_count}/{max_retries})" ) } yield {"output": "Restarting process."} diff --git a/swarms/tools/exec_tool.py b/swarms/tools/exec_tool.py index 558cb9b5..5a351a94 100644 --- a/swarms/tools/exec_tool.py +++ b/swarms/tools/exec_tool.py @@ -61,9 +61,7 @@ class AgentOutputParser(BaseAgentOutputParser): return AgentAction( name="ERROR", args={ - "error": ( - f"Could not parse invalid json: {text}" - ) + "error": (f"Could not parse invalid json: {text}") }, ) try: diff --git a/swarms/utils/execution_sandbox.py b/swarms/tools/execution_sandbox.py similarity index 97% rename from swarms/utils/execution_sandbox.py rename to swarms/tools/execution_sandbox.py index af6c3840..8396aba6 100644 --- a/swarms/utils/execution_sandbox.py +++ b/swarms/tools/execution_sandbox.py @@ -45,9 +45,7 @@ async def execute_code_async(code: str) -> Tuple[str, str]: # logging.info("Code executed successfully.") except Exception: error_message = traceback.format_exc() - logging.error( - "Code execution failed. Error: %s", error_message - ) + logging.error("Code execution failed. Error: %s", error_message) # Return the new code and the error message return out, error_message diff --git a/swarms/tools/format_tools.py b/swarms/tools/format_tools.py index ce760d14..13724d3b 100644 --- a/swarms/tools/format_tools.py +++ b/swarms/tools/format_tools.py @@ -124,9 +124,7 @@ class Jsonformer: return float(response) except ValueError: if iterations > 3: - raise ValueError( - "Failed to generate a valid number" - ) + raise ValueError("Failed to generate a valid number") return self.generate_number( temperature=self.temperature * 1.3, @@ -143,9 +141,7 @@ class Jsonformer: return float(response) except ValueError: if iterations > 3: - raise ValueError( - "Failed to generate a valid number" - ) + raise ValueError("Failed to generate a valid number") return self.generate_number( temperature=self.temperature * 1.3, @@ -169,20 +165,14 @@ class Jsonformer: input_tensor = self.tokenizer.encode( prompt, return_tensors="pt" ) - output = self.model.forward( - input_tensor.to(self.model.device) - ) + output = self.model.forward(input_tensor.to(self.model.device)) logits = output.logits[0, -1] # todo: this assumes that "true" and "false" are both tokenized to a single token # this is probably not true for all tokenizers # this can be fixed by looking at only the first token of both "true" and "false" - true_token_id = self.tokenizer.convert_tokens_to_ids( - "true" - ) - false_token_id = self.tokenizer.convert_tokens_to_ids( - "false" - ) + true_token_id = self.tokenizer.convert_tokens_to_ids("true") + false_token_id = self.tokenizer.convert_tokens_to_ids("false") result = logits[true_token_id] > logits[false_token_id] @@ -227,8 +217,7 @@ class Jsonformer: if ( len(response[0]) >= len(input_tokens[0]) and ( - response[0][: len(input_tokens[0])] - == input_tokens + response[0][: len(input_tokens[0])] == input_tokens ).all() ): response = response[0][len(input_tokens[0]) :] @@ -257,8 +246,7 @@ class Jsonformer: if ( len(response[0]) >= len(input_tokens[0]) and ( - response[0][: len(input_tokens[0])] - == input_tokens + response[0][: len(input_tokens[0])] == input_tokens ).all() ): response = response[0][len(input_tokens[0]) :] @@ -320,9 +308,7 @@ class Jsonformer: obj.append(new_obj) return self.generate_object(schema["properties"], new_obj) else: - raise ValueError( - f"Unsupported schema type: {schema_type}" - ) + raise ValueError(f"Unsupported schema type: {schema_type}") def generate_array( self, item_schema: Dict[str, Any], obj: 
Dict[str, Any] @@ -397,9 +383,7 @@ class Jsonformer: def get_prompt(self): template = """{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}""" progress = json.dumps(self.value) - gen_marker_index = progress.find( - f'"{self.generation_marker}"' - ) + gen_marker_index = progress.find(f'"{self.generation_marker}"') if gen_marker_index != -1: progress = progress[:gen_marker_index] else: diff --git a/swarms/utils/function_calling_utils.py b/swarms/tools/function_calling_utils.py similarity index 94% rename from swarms/utils/function_calling_utils.py rename to swarms/tools/function_calling_utils.py index 72aa487b..1bd29460 100644 --- a/swarms/utils/function_calling_utils.py +++ b/swarms/tools/function_calling_utils.py @@ -5,9 +5,7 @@ import asyncio # Helper function to run an asynchronous function in a synchronous way -def run_async_function_in_sync( - func: Callable, *args, **kwargs -) -> Any: +def run_async_function_in_sync(func: Callable, *args, **kwargs) -> Any: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) coroutine = func(*args, **kwargs) diff --git a/swarms/tools/interpreter.py b/swarms/tools/interpreter.py deleted file mode 100644 index b0661be2..00000000 --- a/swarms/tools/interpreter.py +++ /dev/null @@ -1,40 +0,0 @@ -import io -import sys - -from swarms.utils.loguru_logger import logger - - -def execute_command(code): - """ - Executes Python code and returns the output. - - Args: - code (str): The Python code to execute. - - Returns: - str: The output of the code. - """ - # Create a string buffer to capture stdout and stderr - buffer = io.StringIO() - - # Redirect stdout and stderr to the buffer - sys.stdout = buffer - sys.stderr = buffer - - try: - # Execute the code - exec(code) - except Exception as e: - # Log the error - logger.error(f"Error executing code: {code}\n{str(e)}") - return str(e) - finally: - # Restore stdout and stderr - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - # Get the output from the buffer - output = buffer.getvalue() - - # Return the output - return output diff --git a/swarms/utils/json_utils.py b/swarms/tools/json_utils.py similarity index 100% rename from swarms/utils/json_utils.py rename to swarms/tools/json_utils.py diff --git a/swarms/tools/logits_processor.py b/swarms/tools/logits_processor.py index f67ff451..256b77ea 100644 --- a/swarms/tools/logits_processor.py +++ b/swarms/tools/logits_processor.py @@ -7,9 +7,7 @@ from transformers import ( class StringStoppingCriteria(StoppingCriteria): - def __init__( - self, tokenizer: PreTrainedTokenizer, prompt_length: int - ): + def __init__(self, tokenizer: PreTrainedTokenizer, prompt_length: int): self.tokenizer = tokenizer self.prompt_length = prompt_length diff --git a/swarms/utils/math_eval.py b/swarms/tools/math_eval.py similarity index 100% rename from swarms/utils/math_eval.py rename to swarms/tools/math_eval.py diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py index 329d95ec..50d08bab 100644 --- a/swarms/utils/__init__.py +++ b/swarms/utils/__init__.py @@ -1,20 +1,16 @@ from swarms.utils.class_args_wrapper import print_class_parameters -from swarms.utils.code_interpreter import SubprocessCodeInterpreter -from swarms.utils.csv_and_pandas import ( - csv_to_dataframe, - dataframe_to_strings, -) +from swarms.tools.code_interpreter import SubprocessCodeInterpreter +# from swarms.utils.csv_and_pandas import ( +# csv_to_dataframe, +# dataframe_to_strings, +# ) from swarms.utils.data_to_text import ( csv_to_text, 
data_to_text, json_to_text, txt_to_text, ) -from swarms.utils.device_checker_cuda import check_device from swarms.utils.download_img import download_img_from_url -from swarms.utils.download_weights_from_url import ( - download_weights_from_url, -) from swarms.utils.exponential_backoff import ExponentialBackoffMixin from swarms.utils.file_processing import ( load_json, @@ -26,59 +22,45 @@ from swarms.utils.file_processing import ( from swarms.utils.find_img_path import find_image_path from swarms.utils.json_output_parser import JsonOutputParser from swarms.utils.llm_metrics_decorator import metrics_decorator -from swarms.utils.load_model_torch import load_model_torch from swarms.utils.markdown_message import display_markdown_message -from swarms.utils.math_eval import math_eval -from swarms.utils.pandas_to_str import dataframe_to_text +from swarms.tools.math_eval import math_eval +# from swarms.utils.pandas_to_str import dataframe_to_text from swarms.utils.parse_code import extract_code_from_markdown from swarms.utils.pdf_to_text import pdf_to_text -from swarms.utils.prep_torch_model_inference import ( - prep_torch_inference, -) from swarms.utils.remove_json_whitespace import ( remove_whitespace_from_json, remove_whitespace_from_yaml, ) from swarms.utils.save_logs import parse_log_file - -# from swarms.utils.supervision_visualizer import MarkVisualizer from swarms.utils.try_except_wrapper import try_except_wrapper from swarms.utils.yaml_output_parser import YamlOutputParser from swarms.utils.concurrent_utils import execute_concurrently - __all__ = [ - "print_class_parameters", + "download_img_from_url", + "ExponentialBackoffMixin", + "find_image_path", + "JsonOutputParser", + "metrics_decorator", + "display_markdown_message", + "math_eval", + "parse_log_file", "SubprocessCodeInterpreter", - "csv_to_dataframe", - "dataframe_to_strings", + "try_except_wrapper", + "YamlOutputParser", "csv_to_text", "data_to_text", "json_to_text", "txt_to_text", - "check_device", - "download_img_from_url", - "download_weights_from_url", - "ExponentialBackoffMixin", "load_json", "sanitize_file_path", "zip_workspace", "create_file_in_folder", "zip_folders", - "find_image_path", - "JsonOutputParser", - "metrics_decorator", - "load_model_torch", - "display_markdown_message", - "math_eval", - "dataframe_to_text", - "extract_code_from_markdown", - "pdf_to_text", - "prep_torch_inference", "remove_whitespace_from_json", "remove_whitespace_from_yaml", - "parse_log_file", - "try_except_wrapper", - "YamlOutputParser", + "extract_code_from_markdown", + "pdf_to_text", "execute_concurrently", -] + "print_class_parameters", +] \ No newline at end of file diff --git a/swarms/utils/apa.py b/swarms/utils/apa.py index 05b25c5c..ca036e99 100644 --- a/swarms/utils/apa.py +++ b/swarms/utils/apa.py @@ -100,9 +100,7 @@ class Action: tool_name: str = "" tool_input: dict = field(default_factory=lambda: {}) - tool_output_status: ToolCallStatus = ( - ToolCallStatus.ToolCallSuccess - ) + tool_output_status: ToolCallStatus = ToolCallStatus.ToolCallSuccess tool_output: str = "" def to_json(self): @@ -124,9 +122,7 @@ class Action: @dataclass class userQuery: task: str - additional_information: List[str] = field( - default_factory=lambda: [] - ) + additional_information: List[str] = field(default_factory=lambda: []) refine_prompt: str = field(default_factory=lambda: "") def print_self(self): diff --git a/swarms/utils/check_function_result.py b/swarms/utils/check_function_result.py index b3c88491..cb39d370 100644 --- 
a/swarms/utils/check_function_result.py +++ b/swarms/utils/check_function_result.py @@ -98,9 +98,7 @@ def time_limit(seconds: float): signal.setitimer(signal.ITIMER_REAL, 0) -def check_function_result( - python_code: str, timeout: float = 5.0 -) -> Dict: +def check_function_result(python_code: str, timeout: float = 5.0) -> Dict: """ Evaluates the functional correctness of a completion by running the test suite provided in the problem. diff --git a/swarms/utils/concurrent_utils.py b/swarms/utils/concurrent_utils.py index a7bb5fe2..3843c0af 100644 --- a/swarms/utils/concurrent_utils.py +++ b/swarms/utils/concurrent_utils.py @@ -28,9 +28,7 @@ def execute_concurrently(callable_functions, max_workers=5): ) as executor: futures = [] for i, (fn, args, kwargs) in enumerate(callable_functions): - futures.append( - executor.submit(worker, fn, args, kwargs, i) - ) + futures.append(executor.submit(worker, fn, args, kwargs, i)) # Wait for all threads to complete concurrent.futures.wait(futures) diff --git a/swarms/utils/data_to_text.py b/swarms/utils/data_to_text.py index d8d72986..64ab2e0d 100644 --- a/swarms/utils/data_to_text.py +++ b/swarms/utils/data_to_text.py @@ -26,9 +26,7 @@ def txt_to_text(file): def md_to_text(file): if not os.path.exists(file): - raise FileNotFoundError( - f"No such file or directory: '{file}'" - ) + raise FileNotFoundError(f"No such file or directory: '{file}'") with open(file) as file: data = file.read() return data diff --git a/swarms/utils/device_checker_cuda.py b/swarms/utils/device_checker_cuda.py deleted file mode 100644 index 11b4559c..00000000 --- a/swarms/utils/device_checker_cuda.py +++ /dev/null @@ -1,71 +0,0 @@ -import logging -from typing import Any, List, Union - -import torch -from torch.cuda import memory_allocated, memory_reserved - - -def check_device( - log_level: Any = logging.INFO, - memory_threshold: float = 0.8, - capability_threshold: float = 3.5, - return_type: str = "list", -) -> Union[torch.device, List[torch.device]]: - """ - Checks for the availability of CUDA and returns the appropriate device(s). - If CUDA is not available, returns a CPU device. - If CUDA is available, returns a list of all available GPU devices. - """ - logging.basicConfig(level=log_level) - - # Check for CUDA availability - try: - if not torch.cuda.is_available(): - logging.info("CUDA is not available. 
Using CPU...") - return torch.device("cpu") - except Exception as e: - logging.error("Error checking for CUDA availability: ", e) - return torch.device("cpu") - - logging.info("CUDA is available.") - - # Check for multiple GPUs - num_gpus = torch.cuda.device_count() - devices = [] - if num_gpus > 1: - logging.info(f"Multiple GPUs available: {num_gpus}") - devices = [torch.device(f"cuda:{i}") for i in range(num_gpus)] - else: - logging.info("Only one GPU is available.") - devices = [torch.device("cuda")] - - # Check additional properties for each device - for device in devices: - try: - torch.cuda.set_device(device) - capability = torch.cuda.get_device_capability(device) - total_memory = torch.cuda.get_device_properties( - device - ).total_memory - allocated_memory = memory_allocated(device) - reserved_memory = memory_reserved(device) - device_name = torch.cuda.get_device_name(device) - - logging.info( - f"Device: {device}, Name: {device_name}, Compute" - f" Capability: {capability}, Total Memory:" - f" {total_memory}, Allocated Memory:" - f" {allocated_memory}, Reserved Memory:" - f" {reserved_memory}" - ) - except Exception as e: - logging.error( - f"Error retrieving properties for device {device}: ", - e, - ) - - return devices - - -# devices = check_device() -# logging.info(f"Using device(s): {devices}") diff --git a/swarms/utils/download_weights_from_url.py b/swarms/utils/download_weights_from_url.py deleted file mode 100644 index b5fa1633..00000000 --- a/swarms/utils/download_weights_from_url.py +++ /dev/null @@ -1,22 +0,0 @@ -import requests - - -def download_weights_from_url( - url: str, save_path: str = "models/weights.pth" -): - """ - Downloads model weights from the given URL and saves them to the specified path. - - Args: - url (str): The URL from which to download the model weights. - save_path (str, optional): The path where the downloaded weights should be saved. - Defaults to "models/weights.pth". - """ - response = requests.get(url, stream=True) - response.raise_for_status() - - with open(save_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - - print(f"Model weights downloaded and saved to {save_path}") diff --git a/swarms/utils/file_processing.py b/swarms/utils/file_processing.py index 835e8734..dddaea13 100644 --- a/swarms/utils/file_processing.py +++ b/swarms/utils/file_processing.py @@ -15,9 +15,7 @@ def zip_workspace(workspace_path: str, output_filename: str): base_output_path = os.path.join( temp_dir, output_filename.replace(".zip", "") ) - zip_path = shutil.make_archive( - base_output_path, "zip", workspace_path - ) + zip_path = shutil.make_archive(base_output_path, "zip", workspace_path) return zip_path # make_archive already appends .zip @@ -62,9 +60,7 @@ def create_file( return file_path -def create_file_in_folder( - folder_path: str, file_name: str, content: str -): +def create_file_in_folder(folder_path: str, file_name: str, content: str): """ Creates a file in the specified folder with the given file name and content. 
diff --git a/swarms/utils/find_img_path.py b/swarms/utils/find_img_path.py index 2ca5d082..cecd12dc 100644 --- a/swarms/utils/find_img_path.py +++ b/swarms/utils/find_img_path.py @@ -18,7 +18,5 @@ def find_image_path(text): if match.group() ] matches += [match.replace("\\", "") for match in matches if match] - existing_paths = [ - match for match in matches if os.path.exists(match) - ] + existing_paths = [match for match in matches if os.path.exists(match)] return max(existing_paths, key=len) if existing_paths else None diff --git a/swarms/utils/hash.py b/swarms/utils/hash.py deleted file mode 100644 index 0e82766b..00000000 --- a/swarms/utils/hash.py +++ /dev/null @@ -1,17 +0,0 @@ -import hashlib - -import pandas as pd - - -def dataframe_to_hash(dataframe: pd.DataFrame) -> str: - return hashlib.sha256( - pd.util.hash_pandas_object(dataframe, index=True).values - ).hexdigest() - - -def str_to_hash(text: str, hash_algorithm: str = "sha256") -> str: - m = hashlib.new(hash_algorithm) - - m.update(text.encode()) - - return m.hexdigest() diff --git a/swarms/utils/inference_convert_utils.py b/swarms/utils/inference_convert_utils.py deleted file mode 100644 index 596d222b..00000000 --- a/swarms/utils/inference_convert_utils.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch - - -def continuous_tensor( - inputs: torch.Tensor, seq_length: torch.LongTensor -): - """Convert batched tensor to continuous tensor. - - Args: - inputs (Tensor): batched tensor. - seq_length (Tensor): length of each sequence. - - Return: - Tensor: continuoused tensor. - """ - assert inputs.dim() > 1 - if inputs.size(1) == 1: - return inputs.reshape(1, -1) - - inputs = [inp[:slen] for inp, slen in zip(inputs, seq_length)] - - inputs = torch.cat(inputs).unsqueeze(0) - return inputs - - -def batch_tensor(inputs: torch.Tensor, seq_length: torch.LongTensor): - """Convert continuoused tensor to batched tensor. - - Args: - inputs (Tensor): continuoused tensor. - seq_length (Tensor): length of each sequence. - - Return: - Tensor: batched tensor. - """ - from torch.nn.utils.rnn import pad_sequence - - end_loc = seq_length.cumsum(0) - start_loc = end_loc - seq_length - - inputs = [ - inputs[0, sloc:eloc] for sloc, eloc in zip(start_loc, end_loc) - ] - inputs = pad_sequence(inputs, batch_first=True) - return inputs - - -def page_cache( - paged_cache: torch.Tensor, - batched_cache: torch.Tensor, - cache_length: torch.Tensor, - block_offsets: torch.Tensor, - permute_head: bool = True, -): - """Convert batched cache to paged cache. - - Args: - paged_cache (Tensor): Output paged cache. - batched_cache (Tensor): Input batched cache. - cache_length (Tensor): length of the cache. - block_offsets (Tensor): Offset of each blocks. 
- """ - assert block_offsets.dim() == 2 - block_size = paged_cache.size(1) - batch_size = batched_cache.size(0) - if permute_head: - batched_cache = batched_cache.permute(0, 2, 1, 3) - - for b_idx in range(batch_size): - cache_len = cache_length[b_idx] - b_cache = batched_cache[b_idx] - block_off = block_offsets[b_idx] - block_off_idx = 0 - for s_start in range(0, cache_len, block_size): - s_end = min(s_start + block_size, cache_len) - s_len = s_end - s_start - b_off = block_off[block_off_idx] - paged_cache[b_off, :s_len] = b_cache[s_start:s_end] - block_off_idx += 1 diff --git a/swarms/utils/json_output_parser.py b/swarms/utils/json_output_parser.py index 4f76c3a5..9f0c9133 100644 --- a/swarms/utils/json_output_parser.py +++ b/swarms/utils/json_output_parser.py @@ -76,9 +76,7 @@ class JsonOutputParser: """ schema = self.pydantic_object.schema() reduced_schema = { - k: v - for k, v in schema.items() - if k not in ["title", "type"] + k: v for k, v in schema.items() if k not in ["title", "type"] } schema_str = json.dumps(reduced_schema, indent=4) diff --git a/swarms/utils/jsonl_utils.py b/swarms/utils/jsonl_utils.py index 95a0d9d6..a4aa06da 100644 --- a/swarms/utils/jsonl_utils.py +++ b/swarms/utils/jsonl_utils.py @@ -34,9 +34,7 @@ def stream_jsonl(filename: str) -> Iterable[Dict]: yield json.loads(line) -def write_jsonl( - filename: str, data: Iterable[Dict], append: bool = False -): +def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False): """ Write a list of dictionaries to a JSONL file. diff --git a/swarms/utils/logger.py b/swarms/utils/logger.py index 804a4fb1..af383216 100644 --- a/swarms/utils/logger.py +++ b/swarms/utils/logger.py @@ -31,9 +31,7 @@ def log_wrapper(func): ) try: result = func(*args, **kwargs) - logger.debug( - f"Function {func.__name__} returned {result}" - ) + logger.debug(f"Function {func.__name__} returned {result}") return result except Exception as e: logger.error( @@ -73,9 +71,7 @@ class Logger: task (str): The task associated with the message. message (str): The message to be logged. """ - timestamp = datetime.datetime.now().strftime( - "%d/%m/%y %H:%M:%S" - ) + timestamp = datetime.datetime.now().strftime("%d/%m/%y %H:%M:%S") formatted_message = ( f"[{timestamp}] {level:<8} {task}\n{' ' * 29}{message}" ) diff --git a/swarms/utils/loggers.py b/swarms/utils/loggers.py index 7ec3fcd2..cff8d1ac 100644 --- a/swarms/utils/loggers.py +++ b/swarms/utils/loggers.py @@ -16,9 +16,7 @@ from swarms.utils.apa import Action, ToolCallStatus # from autogpt.speech import say_text class JsonFileHandler(logging.FileHandler): - def __init__( - self, filename, mode="a", encoding=None, delay=False - ): + def __init__(self, filename, mode="a", encoding=None, delay=False): """ Initializes a new instance of the class with the specified file name, mode, encoding, and delay settings. 
@@ -88,9 +86,7 @@ class Logger: log_file = "activity.log" error_file = "error.log" - console_formatter = AutoGptFormatter( - "%(title_color)s %(message)s" - ) + console_formatter = AutoGptFormatter("%(title_color)s %(message)s") # Create a handler for console which simulate typing self.typing_console_handler = TypingConsoleHandler() @@ -381,9 +377,7 @@ class TypingConsoleHandler(logging.StreamHandler): " ", transfer_space ) words = msg_transfered.split() - words = [ - word.replace(transfer_enter, "\n") for word in words - ] + words = [word.replace(transfer_enter, "\n") for word in words] words = [ word.replace(transfer_space, " ") for word in words ] @@ -488,12 +482,8 @@ def print_action_base(action: Action): None """ if action.content != "": - logger.typewriter_log( - "content:", Fore.YELLOW, f"{action.content}" - ) - logger.typewriter_log( - "Thought:", Fore.YELLOW, f"{action.thought}" - ) + logger.typewriter_log("content:", Fore.YELLOW, f"{action.content}") + logger.typewriter_log("Thought:", Fore.YELLOW, f"{action.thought}") if len(action.plan) > 0: logger.typewriter_log( "Plan:", @@ -502,9 +492,7 @@ def print_action_base(action: Action): for line in action.plan: line = line.lstrip("- ") logger.typewriter_log("- ", Fore.GREEN, line.strip()) - logger.typewriter_log( - "Criticism:", Fore.YELLOW, f"{action.criticism}" - ) + logger.typewriter_log("Criticism:", Fore.YELLOW, f"{action.criticism}") def print_action_tool(action: Action): @@ -518,21 +506,15 @@ def print_action_tool(action: Action): None """ logger.typewriter_log("Tool:", Fore.BLUE, f"{action.tool_name}") - logger.typewriter_log( - "Tool Input:", Fore.BLUE, f"{action.tool_input}" - ) + logger.typewriter_log("Tool Input:", Fore.BLUE, f"{action.tool_input}") - output = ( - action.tool_output if action.tool_output != "" else "None" - ) + output = action.tool_output if action.tool_output != "" else "None" logger.typewriter_log("Tool Output:", Fore.BLUE, f"{output}") color = Fore.RED if action.tool_output_status == ToolCallStatus.ToolCallSuccess: color = Fore.GREEN - elif ( - action.tool_output_status == ToolCallStatus.InputCannotParsed - ): + elif action.tool_output_status == ToolCallStatus.InputCannotParsed: color = Fore.YELLOW logger.typewriter_log( diff --git a/swarms/utils/main.py b/swarms/utils/main.py deleted file mode 100644 index ffb496d1..00000000 --- a/swarms/utils/main.py +++ /dev/null @@ -1,227 +0,0 @@ -import os -import random -import shutil -import uuid -from abc import ABC, abstractmethod, abstractstaticmethod -from enum import Enum -from pathlib import Path -from typing import Dict - -import numpy as np -import requests - - -def seed_everything(seed): - random.seed(seed) - np.random.seed(seed) - try: - import torch - - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - except BaseException: - pass - return seed - - -def cut_dialogue_history(history_memory, keep_last_n_words=500): - tokens = history_memory.split() - n_tokens = len(tokens) - print(f"history_memory:{history_memory}, n_tokens: {n_tokens}") - if n_tokens < keep_last_n_words: - return history_memory - else: - paragraphs = history_memory.split("\n") - last_n_tokens = n_tokens - while last_n_tokens >= keep_last_n_words: - last_n_tokens = last_n_tokens - len( - paragraphs[0].split(" ") - ) - paragraphs = paragraphs[1:] - return "\n" + "\n".join(paragraphs) - - -def get_new_image_name(org_img_name, func_name="update"): - head_tail = os.path.split(org_img_name) - head = head_tail[0] - tail = head_tail[1] - name_split = tail.split(".")[0].split("_") - 
this_new_uuid = str(uuid.uuid4())[0:4] - if len(name_split) == 1: - most_org_file_name = name_split[0] - recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.png".format( - this_new_uuid, - func_name, - recent_prev_file_name, - most_org_file_name, - ) - else: - assert len(name_split) == 4 - most_org_file_name = name_split[3] - recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.png".format( - this_new_uuid, - func_name, - recent_prev_file_name, - most_org_file_name, - ) - return os.path.join(head, new_file_name) - - -def get_new_dataframe_name(org_img_name, func_name="update"): - head_tail = os.path.split(org_img_name) - head = head_tail[0] - tail = head_tail[1] - name_split = tail.split(".")[0].split("_") - this_new_uuid = str(uuid.uuid4())[0:4] - if len(name_split) == 1: - most_org_file_name = name_split[0] - recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.csv".format( - this_new_uuid, - func_name, - recent_prev_file_name, - most_org_file_name, - ) - else: - assert len(name_split) == 4 - most_org_file_name = name_split[3] - recent_prev_file_name = name_split[0] - new_file_name = "{}_{}_{}_{}.csv".format( - this_new_uuid, - func_name, - recent_prev_file_name, - most_org_file_name, - ) - return os.path.join(head, new_file_name) - - -STATIC_DIR = "static" - - -class AbstractUploader(ABC): - @abstractmethod - def upload(self, filepath: str) -> str: - pass - - @abstractstaticmethod - def from_settings() -> "AbstractUploader": - pass - - -class FileType(Enum): - IMAGE = "image" - AUDIO = "audio" - VIDEO = "video" - DATAFRAME = "dataframe" - UNKNOWN = "unknown" - - @staticmethod - def from_filename(url: str) -> "FileType": - filename = url.split("?")[0] - - if filename.endswith(".png") or filename.endswith(".jpg"): - return FileType.IMAGE - elif filename.endswith(".mp3") or filename.endswith(".wav"): - return FileType.AUDIO - elif filename.endswith(".mp4") or filename.endswith(".avi"): - return FileType.VIDEO - elif filename.endswith(".csv"): - return FileType.DATAFRAME - else: - return FileType.UNKNOWN - - @staticmethod - def from_url(url: str) -> "FileType": - return FileType.from_filename(url.split("?")[0]) - - def to_extension(self) -> str: - if self == FileType.IMAGE: - return ".png" - elif self == FileType.AUDIO: - return ".mp3" - elif self == FileType.VIDEO: - return ".mp4" - elif self == FileType.DATAFRAME: - return ".csv" - else: - return ".unknown" - - -class BaseHandler: - def handle(self, filename: str) -> str: - raise NotImplementedError - - -class FileHandler: - def __init__( - self, handlers: Dict[FileType, BaseHandler], path: Path - ): - self.handlers = handlers - self.path = path - - def register( - self, filetype: FileType, handler: BaseHandler - ) -> "FileHandler": - self.handlers[filetype] = handler - return self - - def download(self, url: str) -> str: - filetype = FileType.from_url(url) - data = requests.get(url).content - local_filename = os.path.join( - "file", str(uuid.uuid4())[0:8] + filetype.to_extension() - ) - os.makedirs(os.path.dirname(local_filename), exist_ok=True) - with open(local_filename, "wb") as f: - size = f.write(data) - print(f"Inputs: {url} ({size//1000}MB) => {local_filename}") - return local_filename - - def handle(self, url: str) -> str: - try: - if url.startswith( - os.environ.get("SERVER", "http://localhost:8000") - ): - local_filepath = url[ - len( - os.environ.get( - "SERVER", "http://localhost:8000" - ) - ) - + 1 : - ] - local_filename = ( - Path("file") / local_filepath.split("/")[-1] - ) - 
src = self.path / local_filepath - dst = ( - self.path - / os.environ.get("PLAYGROUND_DIR", "./playground") - / local_filename - ) - os.makedirs(os.path.dirname(dst), exist_ok=True) - shutil.copy(src, dst) - else: - local_filename = self.download(url) - handler = self.handlers.get(FileType.from_url(url)) - if handler is None: - if FileType.from_url(url) == FileType.IMAGE: - raise Exception( - f"No handler for {FileType.from_url(url)}." - " Please set USE_GPU to True in" - " env/settings.py" - ) - else: - raise Exception( - f"No handler for {FileType.from_url(url)}" - ) - return handler.handle(local_filename) - except Exception as e: - raise e - - -# => base end - -# ===========================> diff --git a/swarms/utils/pdf_to_text.py b/swarms/utils/pdf_to_text.py index 4877f3b1..36dfebcf 100644 --- a/swarms/utils/pdf_to_text.py +++ b/swarms/utils/pdf_to_text.py @@ -36,9 +36,7 @@ def pdf_to_text(pdf_path): return text except FileNotFoundError: - raise FileNotFoundError( - f"The file at {pdf_path} was not found." - ) + raise FileNotFoundError(f"The file at {pdf_path} was not found.") except Exception as e: raise Exception( f"An error occurred while reading the PDF file: {e}" diff --git a/swarms/utils/prep_torch_model_inference.py b/swarms/utils/prep_torch_model_inference.py deleted file mode 100644 index 1b88cab5..00000000 --- a/swarms/utils/prep_torch_model_inference.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from swarms.utils.load_model_torch import load_model_torch - - -def prep_torch_inference( - model_path: str = None, - device: torch.device = None, - *args, - **kwargs, -): - """ - Prepare a Torch model for inference. - - Args: - model_path (str): Path to the model file. - device (torch.device): Device to run the model on. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - torch.nn.Module: The prepared model. - """ - try: - model = load_model_torch(model_path, device) - model.eval() - return model - except Exception as e: - # Add error handling code here - print(f"Error occurred while preparing Torch model: {e}") - return None diff --git a/swarms/utils/save_logs.py b/swarms/utils/save_logs.py index dd8810b1..83357112 100644 --- a/swarms/utils/save_logs.py +++ b/swarms/utils/save_logs.py @@ -19,9 +19,7 @@ def parse_log_file(filename: str): """ # Check if the file exists if not os.path.exists(filename): - raise FileNotFoundError( - f"The file {filename} does not exist." - ) + raise FileNotFoundError(f"The file {filename} does not exist.") log_entries = [] diff --git a/swarms/utils/serializable.py b/swarms/utils/serializable.py index cb0fc791..3692a2bb 100644 --- a/swarms/utils/serializable.py +++ b/swarms/utils/serializable.py @@ -105,9 +105,7 @@ class Serializable(BaseModel, ABC): # include all secrets, even if not specified in kwargs # as these secrets may be passed as an environment variable instead for key in secrets.keys(): - secret_value = getattr(self, key, None) or lc_kwargs.get( - key - ) + secret_value = getattr(self, key, None) or lc_kwargs.get(key) if secret_value is not None: lc_kwargs.update({key: secret_value}) diff --git a/swarms/utils/supervision_visualizer.py b/swarms/utils/supervision_visualizer.py deleted file mode 100644 index 1515b709..00000000 --- a/swarms/utils/supervision_visualizer.py +++ /dev/null @@ -1,85 +0,0 @@ -import numpy as np -import supervision as sv - - -class MarkVisualizer: - """ - A class for visualizing different marks including bounding boxes, masks, polygons, - and labels. 
- - Parameters: - line_thickness (int): The thickness of the lines for boxes and polygons. - mask_opacity (float): The opacity level for masks. - text_scale (float): The scale of the text for labels. - """ - - def __init__( - self, - line_thickness: int = 2, - mask_opacity: float = 0.1, - text_scale: float = 0.6, - ) -> None: - self.box_annotator = sv.BoundingBoxAnnotator( - color_lookup=sv.ColorLookup.INDEX, - thickness=line_thickness, - ) - self.mask_annotator = sv.MaskAnnotator( - color_lookup=sv.ColorLookup.INDEX, opacity=mask_opacity - ) - self.polygon_annotator = sv.PolygonAnnotator( - color_lookup=sv.ColorLookup.INDEX, - thickness=line_thickness, - ) - self.label_annotator = sv.LabelAnnotator( - color=sv.Color.black(), - text_color=sv.Color.white(), - color_lookup=sv.ColorLookup.INDEX, - text_position=sv.Position.CENTER_OF_MASS, - text_scale=text_scale, - ) - - def visualize( - self, - image: np.ndarray, - marks: sv.Detections, - with_box: bool = False, - with_mask: bool = False, - with_polygon: bool = True, - with_label: bool = True, - ) -> np.ndarray: - """ - Visualizes annotations on an image. - - This method takes an image and an instance of sv.Detections, and overlays - the specified types of marks (boxes, masks, polygons, labels) on the image. - - Parameters: - image (np.ndarray): The image on which to overlay annotations. - marks (sv.Detections): The detection results containing the annotations. - with_box (bool): Whether to draw bounding boxes. Defaults to False. - with_mask (bool): Whether to overlay masks. Defaults to False. - with_polygon (bool): Whether to draw polygons. Defaults to True. - with_label (bool): Whether to add labels. Defaults to True. - - Returns: - np.ndarray: The annotated image. - """ - annotated_image = image.copy() - if with_box: - annotated_image = self.box_annotator.annotate( - scene=annotated_image, detections=marks - ) - if with_mask: - annotated_image = self.mask_annotator.annotate( - scene=annotated_image, detections=marks - ) - if with_polygon: - annotated_image = self.polygon_annotator.annotate( - scene=annotated_image, detections=marks - ) - if with_label: - labels = list(map(str, range(len(marks)))) - annotated_image = self.label_annotator.annotate( - scene=annotated_image, detections=marks, labels=labels - ) - return annotated_image diff --git a/swarms/utils/torch_utils.py b/swarms/utils/torch_utils.py deleted file mode 100644 index 41d2eb3f..00000000 --- a/swarms/utils/torch_utils.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch - - -def autodetect_device(): - """ - Autodetects the device to use for inference. - - Returns - ------- - str - The device to use for inference. 
- """ - return "cuda" if torch.cuda.is_available() else "cpu" diff --git a/swarms/utils/yaml_output_parser.py b/swarms/utils/yaml_output_parser.py index 5832bf16..2768a0e7 100644 --- a/swarms/utils/yaml_output_parser.py +++ b/swarms/utils/yaml_output_parser.py @@ -78,9 +78,7 @@ class YamlOutputParser: """ schema = self.pydantic_object.schema() reduced_schema = { - k: v - for k, v in schema.items() - if k not in ["title", "type"] + k: v for k, v in schema.items() if k not in ["title", "type"] } schema_str = json.dumps(reduced_schema, indent=4) diff --git a/tests/agents/test_tool_agent.py b/tests/agents/test_tool_agent.py index 691489c0..6e374466 100644 --- a/tests/agents/test_tool_agent.py +++ b/tests/agents/test_tool_agent.py @@ -20,9 +20,7 @@ def test_tool_agent_init(): name = "Test Agent" description = "This is a test agent" - agent = ToolAgent( - name, description, model, tokenizer, json_schema - ) + agent = ToolAgent(name, description, model, tokenizer, json_schema) assert agent.name == name assert agent.description == description @@ -47,13 +45,10 @@ def test_tool_agent_run(mock_run): name = "Test Agent" description = "This is a test agent" task = ( - "Generate a person's information based on the following" - " schema:" + "Generate a person's information based on the following" " schema:" ) - agent = ToolAgent( - name, description, model, tokenizer, json_schema - ) + agent = ToolAgent(name, description, model, tokenizer, json_schema) agent.run(task) mock_run.assert_called_once_with(task) @@ -96,6 +91,5 @@ def test_tool_agent_init_with_kwargs(): assert agent.max_number_tokens == kwargs["max_number_tokens"] assert agent.temperature == kwargs["temperature"] assert ( - agent.max_string_token_length - == kwargs["max_string_token_length"] + agent.max_string_token_length == kwargs["max_string_token_length"] ) diff --git a/tests/memory/test_dictsharedmemory.py b/tests/memory/test_dictsharedmemory.py index a41ccd8f..9aa6381e 100644 --- a/tests/memory/test_dictsharedmemory.py +++ b/tests/memory/test_dictsharedmemory.py @@ -63,9 +63,7 @@ def test_parametrized_get_top_n( memory_instance, scores, agent_ids, expected_top_score ): for score, agent_id in zip(scores, agent_ids): - memory_instance.add( - score, agent_id, 1, f"Entry by {agent_id}" - ) + memory_instance.add(score, agent_id, 1, f"Entry by {agent_id}") top_1 = memory_instance.get_top_n(1) top_score = next(iter(top_1.values()))["score"] assert ( @@ -78,9 +76,7 @@ def test_parametrized_get_top_n( def test_add_entry_invalid_input(memory_instance): with pytest.raises(ValueError): - memory_instance.add( - "invalid_score", "agent123", 1, "Test Entry" - ) + memory_instance.add("invalid_score", "agent123", 1, "Test Entry") # Mocks and monkey-patching diff --git a/tests/memory/test_langchainchromavectormemory.py b/tests/memory/test_langchainchromavectormemory.py index ee882c6c..11c4a5cf 100644 --- a/tests/memory/test_langchainchromavectormemory.py +++ b/tests/memory/test_langchainchromavectormemory.py @@ -42,9 +42,7 @@ def test_initialization_default_settings(vector_memory): def test_add_entry(vector_memory, embeddings_mock): - with patch.object( - vector_memory.db, "add_texts" - ) as add_texts_mock: + with patch.object(vector_memory.db, "add_texts") as add_texts_mock: vector_memory.add("Example text") add_texts_mock.assert_called() @@ -90,7 +88,5 @@ def test_search_memory_different_params( "similarity_search_with_score", return_value=expected, ): - result = vector_memory.search_memory( - query, k=k, type=type - ) + result = 
vector_memory.search_memory(query, k=k, type=type) assert len(result) == (k if k > 0 else 0) diff --git a/tests/memory/test_qdrant.py b/tests/memory/test_qdrant.py index 5f82814c..b9f4f142 100644 --- a/tests/memory/test_qdrant.py +++ b/tests/memory/test_qdrant.py @@ -29,9 +29,7 @@ def test_qdrant_init(qdrant_client, mock_qdrant_client): assert qdrant_client.client is not None -def test_load_embedding_model( - qdrant_client, mock_sentence_transformer -): +def test_load_embedding_model(qdrant_client, mock_sentence_transformer): qdrant_client._load_embedding_model("model_name") mock_sentence_transformer.assert_called_once_with("model_name") diff --git a/tests/memory/test_short_term_memory.py b/tests/memory/test_short_term_memory.py index 132da5f6..ec7cd0eb 100644 --- a/tests/memory/test_short_term_memory.py +++ b/tests/memory/test_short_term_memory.py @@ -71,9 +71,7 @@ def test_search_memory(): memory = ShortTermMemory() memory.add("user", "Hello, world!") assert memory.search_memory("Hello") == { - "short_term": [ - (0, {"role": "user", "message": "Hello, world!"}) - ], + "short_term": [(0, {"role": "user", "message": "Hello, world!"})], "medium_term": [], } @@ -114,9 +112,7 @@ def test_thread_safety(): for _ in range(1000): memory.add("user", "Hello, world!") - threads = [ - threading.Thread(target=add_messages) for _ in range(10) - ] + threads = [threading.Thread(target=add_messages) for _ in range(10)] for thread in threads: thread.start() for thread in threads: diff --git a/tests/memory/test_sqlite.py b/tests/memory/test_sqlite.py index 49d61ef7..78900199 100644 --- a/tests/memory/test_sqlite.py +++ b/tests/memory/test_sqlite.py @@ -8,9 +8,7 @@ from swarms.memory.sqlite import SQLiteDB @pytest.fixture def db(): conn = sqlite3.connect(":memory:") - conn.execute( - "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)" - ) + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") conn.commit() return SQLiteDB(":memory:") @@ -30,9 +28,7 @@ def test_delete(db): def test_update(db): db.add("INSERT INTO test (name) VALUES (?)", ("test",)) - db.update( - "UPDATE test SET name = ? WHERE name = ?", ("new", "test") - ) + db.update("UPDATE test SET name = ? 
WHERE name = ?", ("new", "test")) result = db.query("SELECT * FROM test") assert result == [(1, "new")] @@ -101,6 +97,4 @@ def test_query_with_wrong_query(db): def test_execute_query_with_wrong_query(db): with pytest.raises(sqlite3.OperationalError): - db.execute_query( - "SELECT * FROM wrong WHERE name = ?", ("test",) - ) + db.execute_query("SELECT * FROM wrong WHERE name = ?", ("test",)) diff --git a/tests/memory/test_weaviate.py b/tests/memory/test_weaviate.py index d1a69da0..dff41620 100644 --- a/tests/memory/test_weaviate.py +++ b/tests/memory/test_weaviate.py @@ -16,9 +16,7 @@ def weaviate_client_mock(): grpc_port="mock_grpc_port", grpc_secure=False, auth_client_secret="mock_api_key", - additional_headers={ - "X-OpenAI-Api-Key": "mock_openai_api_key" - }, + additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"}, additional_config=Mock(), ) @@ -74,9 +72,7 @@ def test_update_object(weaviate_client_mock): # Test updating an object object_id = "12345" properties = {"name": "Jane"} - weaviate_client_mock.update( - "test_collection", object_id, properties - ) + weaviate_client_mock.update("test_collection", object_id, properties) weaviate_client_mock.client.collections.get.assert_called_with( "test_collection" ) @@ -143,9 +139,7 @@ def test_create_collection_failure(weaviate_client_mock): "weaviate_client.weaviate.collections.create", side_effect=Exception("Create error"), ): - with pytest.raises( - Exception, match="Error creating collection" - ): + with pytest.raises(Exception, match="Error creating collection"): weaviate_client_mock.create_collection( "test_collection", [{"name": "property"}] ) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index cc48479a..816d47d2 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -11,9 +11,7 @@ class MockAnthropicClient: def __init__(self, *args, **kwargs): pass - def completions_create( - self, prompt, stop_sequences, stream, **kwargs - ): + def completions_create(self, prompt, stop_sequences, stream, **kwargs): return MockAnthropicResponse() @@ -199,9 +197,7 @@ def test_anthropic_wrap_prompt(anthropic_instance): def test_anthropic_convert_prompt(anthropic_instance): prompt = "What is the meaning of life?" converted_prompt = anthropic_instance.convert_prompt(prompt) - assert converted_prompt.startswith( - anthropic_instance.HUMAN_PROMPT - ) + assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT) assert converted_prompt.endswith(anthropic_instance.AI_PROMPT) diff --git a/tests/models/test_biogpt.py b/tests/models/test_biogpt.py index e6093729..d96dbfc0 100644 --- a/tests/models/test_biogpt.py +++ b/tests/models/test_biogpt.py @@ -83,9 +83,7 @@ def test_bioinformatics_response(biogpt_instance): # 44. Test for a neuroscience question def test_neuroscience_response(biogpt_instance): - question = ( - "Explain the function of synapses in the nervous system." - ) + question = "Explain the function of synapses in the nervous system." response = biogpt_instance(question) assert response assert isinstance(response, str) @@ -167,9 +165,7 @@ def test_get_config_return_type(biogpt_instance): # 28. 
Test saving model functionality by checking if files are created @patch.object(BioGptForCausalLM, "save_pretrained") @patch.object(BioGptTokenizer, "save_pretrained") -def test_save_model( - mock_save_model, mock_save_tokenizer, biogpt_instance -): +def test_save_model(mock_save_model, mock_save_tokenizer, biogpt_instance): path = "test_path" biogpt_instance.save_model(path) mock_save_model.assert_called_once_with(path) @@ -198,9 +194,7 @@ def test_print_model_metadata(biogpt_instance): # 31. Test that beam_search_decoding uses the correct number of beams @patch.object(BioGptForCausalLM, "generate") -def test_beam_search_decoding_num_beams( - mock_generate, biogpt_instance -): +def test_beam_search_decoding_num_beams(mock_generate, biogpt_instance): biogpt_instance.beam_search_decoding("test_sentence", num_beams=7) _, kwargs = mock_generate.call_args assert kwargs["num_beams"] == 7 diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index 8a1147d3..354f047d 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -42,9 +42,7 @@ def test_cohere_async_api_error_handling(cohere_instance): cohere_instance.model = "base" cohere_instance.cohere_api_key = "invalid-api-key" with pytest.raises(Exception): - cohere_instance.async_call( - "Error handling with invalid API key." - ) + cohere_instance.async_call("Error handling with invalid API key.") def test_cohere_stream_api_error_handling(cohere_instance): @@ -174,12 +172,8 @@ def test_base_cohere_validate_environment_without_cohere(): # Test cases for benchmarking generations with various models def test_cohere_generate_with_command_light(cohere_instance): cohere_instance.model = "command-light" - response = cohere_instance( - "Generate text with Command Light model." - ) - assert response.startswith( - "Generated text with Command Light model" - ) + response = cohere_instance("Generate text with Command Light model.") + assert response.startswith("Generated text with Command Light model") def test_cohere_generate_with_command(cohere_instance): @@ -329,9 +323,7 @@ def test_cohere_call_with_long_prompt(cohere_instance): def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance): cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit." - ) + prompt = "This is a test prompt that will exceed the max tokens limit." with pytest.raises(ValueError): cohere_instance(prompt) @@ -512,9 +504,7 @@ def test_cohere_representation_model_max_tokens_limit_exceeded( # Test handling max tokens limit exceeded error cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit." - ) + prompt = "This is a test prompt that will exceed the max tokens limit." with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -527,9 +517,7 @@ def test_cohere_representation_model_multilingual_embedding( ): # Test using the Representation model for multilingual text embedding cohere_instance.model = "embed-multilingual-v3.0" - embedding = cohere_instance.embed( - "Generate multilingual embeddings." 
- ) + embedding = cohere_instance.embed("Generate multilingual embeddings.") assert isinstance(embedding, list) assert len(embedding) > 0 @@ -625,18 +613,14 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede def test_cohere_command_light_model(cohere_instance): # Test using the Command Light model for text generation cohere_instance.model = "command-light" - response = cohere_instance( - "Generate text using Command Light model." - ) + response = cohere_instance("Generate text using Command Light model.") assert isinstance(response, str) def test_cohere_base_light_model(cohere_instance): # Test using the Base Light model for text generation cohere_instance.model = "base-light" - response = cohere_instance( - "Generate text using Base Light model." - ) + response = cohere_instance("Generate text using Base Light model.") assert isinstance(response, str) @@ -662,9 +646,7 @@ def test_cohere_representation_model_english_classification( ): # Test using the Representation model for English text classification cohere_instance.model = "embed-english-v3.0" - classification = cohere_instance.classify( - "Classify English text." - ) + classification = cohere_instance.classify("Classify English text.") assert isinstance(classification, dict) assert "class" in classification assert "score" in classification @@ -700,9 +682,7 @@ def test_cohere_representation_model_english_light_embedding( ): # Test using the Representation model for English light text embedding cohere_instance.model = "embed-english-light-v3.0" - embedding = cohere_instance.embed( - "Generate English light embeddings." - ) + embedding = cohere_instance.embed("Generate English light embeddings.") assert isinstance(embedding, list) assert len(embedding) > 0 @@ -748,9 +728,7 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( def test_cohere_command_model(cohere_instance): # Test using the Command model for text generation cohere_instance.model = "command" - response = cohere_instance( - "Generate text using the Command model." 
- ) + response = cohere_instance("Generate text using the Command model.") assert isinstance(response, str) diff --git a/tests/models/test_elevenlab.py b/tests/models/test_elevenlab.py index da41ca53..c209fd9d 100644 --- a/tests/models/test_elevenlab.py +++ b/tests/models/test_elevenlab.py @@ -30,16 +30,12 @@ def test_run_text_to_speech(eleven_labs_tool): def test_play_speech(eleven_labs_tool): - with patch( - "builtins.open", mock_open(read_data="fake_audio_data") - ): + with patch("builtins.open", mock_open(read_data="fake_audio_data")): eleven_labs_tool.play(EXPECTED_SPEECH_FILE) def test_stream_speech(eleven_labs_tool): - with patch( - "tempfile.NamedTemporaryFile", mock_open() - ) as mock_file: + with patch("tempfile.NamedTemporaryFile", mock_open()) as mock_file: eleven_labs_tool.stream_speech(SAMPLE_TEXT) mock_file.assert_called_with( mode="bx", suffix=".wav", delete=False @@ -52,9 +48,7 @@ def test_api_key_validation(eleven_labs_tool): "langchain.utils.get_from_dict_or_env", return_value=API_KEY ): values = {"eleven_api_key": None} - validated_values = eleven_labs_tool.validate_environment( - values - ) + validated_values = eleven_labs_tool.validate_environment(values) assert "eleven_api_key" in validated_values @@ -66,9 +60,7 @@ def test_run_text_to_speech_with_mock(eleven_labs_tool): "your_module._import_elevenlabs" ) as mock_elevenlabs: mock_elevenlabs_instance = mock_elevenlabs.return_value - mock_elevenlabs_instance.generate.return_value = ( - b"fake_audio_data" - ) + mock_elevenlabs_instance.generate.return_value = b"fake_audio_data" eleven_labs_tool.run(SAMPLE_TEXT) assert mock_file.call_args[1]["suffix"] == ".wav" assert mock_file.call_args[1]["delete"] is False @@ -97,9 +89,7 @@ def test_run_text_to_speech_error_handling(eleven_labs_tool): "model", [ElevenLabsModel.MULTI_LINGUAL, ElevenLabsModel.MONO_LINGUAL], ) -def test_run_text_to_speech_with_different_models( - eleven_labs_tool, model -): +def test_run_text_to_speech_with_different_models(eleven_labs_tool, model): eleven_labs_tool.model = model speech_file = eleven_labs_tool.run(SAMPLE_TEXT) assert isinstance(speech_file, str) diff --git a/tests/models/test_fire_function_caller.py b/tests/models/test_fire_function_caller.py index 082d954d..5e859272 100644 --- a/tests/models/test_fire_function_caller.py +++ b/tests/models/test_fire_function_caller.py @@ -39,6 +39,4 @@ def test_fire_function_caller_run(mocker): tokenizer.batch_decode.assert_called_once_with(generated_ids) # Assert the decoded output is printed - assert decoded_output in mocker.patch.object( - print, "call_args_list" - ) + assert decoded_output in mocker.patch.object(print, "call_args_list") diff --git a/tests/models/test_fuyu.py b/tests/models/test_fuyu.py index e76e11bb..0f880342 100644 --- a/tests/models/test_fuyu.py +++ b/tests/models/test_fuyu.py @@ -38,9 +38,7 @@ def fuyu_instance(): # Test using the fixture. def test_fuyu_processor_initialization(fuyu_instance): assert isinstance(fuyu_instance.processor, FuyuProcessor) - assert isinstance( - fuyu_instance.image_processor, FuyuImageProcessor - ) + assert isinstance(fuyu_instance.image_processor, FuyuImageProcessor) # Test exception when providing an invalid image path. 
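Most of the test edits in this stretch are pure line-length reflows around the same patch-and-assert pattern; a minimal sketch of that pattern with the fuyu_instance fixture (the failing dependency here is assumed):

    from unittest.mock import patch

    import pytest

    def test_run_surfaces_get_img_errors(fuyu_instance):
        # Force the mocked dependency to fail and assert the error propagates
        with patch.object(fuyu_instance, "get_img", side_effect=FileNotFoundError):
            with pytest.raises(FileNotFoundError):
                fuyu_instance.run("Hello, world!", "missing.png")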
@@ -83,9 +81,7 @@ def test_processor_has_image_processor_and_tokenizer(fuyu_instance): fuyu_instance.processor.image_processor == fuyu_instance.image_processor ) - assert ( - fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer - ) + assert fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer def test_model_device_map(fuyu_instance): @@ -186,9 +182,7 @@ def test_run_invalid_image_path(fuyu_instance): with patch.object(fuyu_instance, "get_img") as mock_get_img: mock_get_img.side_effect = FileNotFoundError with pytest.raises(FileNotFoundError): - fuyu_instance.run( - "Hello, world!", "invalid/path/to/image.png" - ) + fuyu_instance.run("Hello, world!", "invalid/path/to/image.png") # Test `__init__` method with default parameters diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a61d1676..db5c58f9 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -24,9 +24,7 @@ def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model): assert model.model is mock_genai_model -def test_gemini_init_custom_params( - mock_gemini_api_key, mock_genai_model -): +def test_gemini_init_custom_params(mock_gemini_api_key, mock_genai_model): model = Gemini( model_name="custom-model", gemini_api_key="custom-api-key" ) @@ -101,9 +99,7 @@ def test_gemini_process_img(mock_gemini_api_key, mock_genai_model): processed_img = model.process_img(img) - assert processed_img == [ - {"mime_type": "image/png", "data": img_data} - ] + assert processed_img == [{"mime_type": "image/png", "data": img_data}] open_mock.assert_called_with(img, "rb") @@ -117,9 +113,7 @@ def test_gemini_init_missing_api_key(): # Test Gemini initialization with missing model name def test_gemini_init_missing_model_name(): - with pytest.raises( - ValueError, match="Please provide a model name" - ): + with pytest.raises(ValueError, match="Please provide a model name"): Gemini(model_name=None) @@ -158,9 +152,7 @@ def test_gemini_process_img_missing_image_type( ): model = Gemini() img = "cat.png" - with pytest.raises( - ValueError, match="Please provide the image type" - ): + with pytest.raises(ValueError, match="Please provide the image type"): model.process_img(img=img, type=None) diff --git a/tests/models/test_gigabind.py b/tests/models/test_gigabind.py index 3aae0739..cf54604e 100644 --- a/tests/models/test_gigabind.py +++ b/tests/models/test_gigabind.py @@ -11,9 +11,7 @@ except ImportError: @pytest.fixture def api(): - return Gigabind( - host="localhost", port=8000, endpoint="embeddings" - ) + return Gigabind(host="localhost", port=8000, endpoint="embeddings") @pytest.fixture @@ -93,9 +91,7 @@ def test_proxy_url(api): def test_invalid_response(api, requests_mock): - requests_mock.post( - "http://localhost:8000/embeddings", text="not json" - ) + requests_mock.post("http://localhost:8000/embeddings", text="not json") response = api.run(text="Hello, world!") assert response is None @@ -110,9 +106,7 @@ def test_connection_error(api, requests_mock): def test_http_error(api, requests_mock): - requests_mock.post( - "http://localhost:8000/embeddings", status_code=500 - ) + requests_mock.post("http://localhost:8000/embeddings", status_code=500) response = api.run(text="Hello, world!") assert response is None diff --git a/tests/models/test_gpt4_vision_api.py b/tests/models/test_gpt4_vision_api.py index ac797280..99a61621 100644 --- a/tests/models/test_gpt4_vision_api.py +++ b/tests/models/test_gpt4_vision_api.py @@ -95,9 +95,7 @@ def test_initialization_with_custom_key(): def 
diff --git a/tests/models/test_gpt4_vision_api.py b/tests/models/test_gpt4_vision_api.py
index ac797280..99a61621 100644
--- a/tests/models/test_gpt4_vision_api.py
+++ b/tests/models/test_gpt4_vision_api.py
@@ -95,9 +93,7 @@ def test_initialization_with_custom_key():
 def test_run_with_exception(gpt_api):
     task = "What is in the image?"
     img_url = img
-    with patch(
-        "requests.post", side_effect=Exception("Test Exception")
-    ):
+    with patch("requests.post", side_effect=Exception("Test Exception")):
         with pytest.raises(Exception):
             gpt_api.run(task, img_url)
 
@@ -105,14 +103,10 @@ def test_run_with_exception(gpt_api):
 def test_call_method_successful_response(gpt_api):
     task = "What is in the image?"
     img_url = img
-    response_json = {
-        "choices": [{"text": "Answer from GPT-4 Vision"}]
-    }
+    response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]}
     mock_response = Mock()
     mock_response.json.return_value = response_json
-    with patch(
-        "requests.post", return_value=mock_response
-    ) as mock_post:
+    with patch("requests.post", return_value=mock_response) as mock_post:
         result = gpt_api(task, img_url)
         mock_post.assert_called_once()
     assert result == response_json
@@ -121,9 +115,7 @@ def test_call_method_successful_response(gpt_api):
 def test_call_method_with_exception(gpt_api):
     task = "What is in the image?"
     img_url = img
-    with patch(
-        "requests.post", side_effect=Exception("Test Exception")
-    ):
+    with patch("requests.post", side_effect=Exception("Test Exception")):
         with pytest.raises(Exception):
             gpt_api(task, img_url)
 
@@ -193,9 +185,7 @@ async def test_arun_json_decode_error(vision_api):
     with patch(
         "aiohttp.ClientSession.post",
         new_callable=AsyncMock,
-        return_value=AsyncMock(
-            json=AsyncMock(side_effect=ValueError)
-        ),
+        return_value=AsyncMock(json=AsyncMock(side_effect=ValueError)),
     ):
         with pytest.raises(ValueError):
             await vision_api.arun("What is this?", img)
diff --git a/tests/models/test_hf.py b/tests/models/test_hf.py
index cbbba940..169dff1a 100644
--- a/tests/models/test_hf.py
+++ b/tests/models/test_hf.py
@@ -133,9 +133,7 @@ def test_llm_set_repitition_penalty(llm_instance):
 def test_llm_set_no_repeat_ngram_size(llm_instance):
     new_no_repeat_ngram_size = 6
     llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size)
-    assert (
-        llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size
-    )
+    assert llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size
 
 
 # Test for setting temperature
@@ -185,9 +183,7 @@ def test_llm_set_model_id(llm_instance):
 
 
 # Test for setting model
-@patch(
-    "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained"
-)
+@patch("swarms.models.huggingface.AutoModelForCausalLM.from_pretrained")
 def test_llm_set_model(mock_model, llm_instance):
     mock_model.return_value = "mocked model"
     llm_instance.set_model(mock_model)
diff --git a/tests/models/test_hf_pipeline.py b/tests/models/test_hf_pipeline.py
index 8580dd56..17d244f7 100644
--- a/tests/models/test_hf_pipeline.py
+++ b/tests/models/test_hf_pipeline.py
@@ -22,11 +22,7 @@ def pipeline(mock_pipeline):
 def test_init(pipeline, mock_pipeline):
     assert pipeline.task_type == "text-generation"
     assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf"
-    assert (
-        pipeline.use_fp8 is True
-        if torch.cuda.is_available()
-        else False
-    )
+    assert pipeline.use_fp8 is True if torch.cuda.is_available() else False
     mock_pipeline.assert_called_once_with(
         "text-generation",
         "meta-llama/Llama-2-13b-chat-hf",
diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py
index 7e19a056..9a892a4e 100644
--- a/tests/models/test_huggingface.py
+++ b/tests/models/test_huggingface.py
@@ -19,8 +19,7 @@ def llm_instance():
 # Test for instantiation and attributes
 def test_llm_initialization(llm_instance):
     assert (
-        llm_instance.model_id
-        == "NousResearch/Nous-Hermes-2-Vision-Alpha"
+        llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha"
     )
     assert llm_instance.max_length == 500
     # ... add more assertions for all default attributes
@@ -88,9 +87,7 @@ def test_llm_memory_consumption(llm_instance):
 )
 def test_llm_initialization_params(model_id, max_length):
     if max_length:
-        instance = HuggingfaceLLM(
-            model_id=model_id, max_length=max_length
-        )
+        instance = HuggingfaceLLM(model_id=model_id, max_length=max_length)
         assert instance.max_length == max_length
     else:
         instance = HuggingfaceLLM(model_id=model_id)
@@ -197,9 +194,7 @@ def test_llm_run_model_exception(mock_generate, llm_instance):
 
 # Test the behavior when GPU is forced but not available
 @patch("torch.cuda.is_available", return_value=False)
-def test_llm_force_gpu_when_unavailable(
-    mock_is_available, llm_instance
-):
+def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance):
     with pytest.raises(EnvironmentError):
         llm_instance.set_device(
             "cuda"
diff --git a/tests/models/test_idefics.py b/tests/models/test_idefics.py
index 3bfee679..e94e397a 100644
--- a/tests/models/test_idefics.py
+++ b/tests/models/test_idefics.py
@@ -85,9 +85,7 @@ def test_call(idefics_instance):
 def test_chat(idefics_instance):
     user_input = "User: Hello"
     response = "Model: Hi there!"
-    with patch.object(
-        idefics_instance, "run", return_value=[response]
-    ):
+    with patch.object(idefics_instance, "run", return_value=[response]):
         result = idefics_instance.chat(user_input)
 
     assert result == response
@@ -163,9 +161,7 @@ def test_run_batched_mode_false(idefics_instance):
 # Test `run` method with an exception
 def test_run_with_exception(idefics_instance):
     task = "User: Test"
-    with patch.object(
-        idefics_instance, "processor"
-    ) as mock_processor:
+    with patch.object(idefics_instance, "processor") as mock_processor:
         mock_processor.side_effect = Exception("Test exception")
         with pytest.raises(Exception):
             idefics_instance.run(task)
diff --git a/tests/models/test_kosmos.py b/tests/models/test_kosmos.py
index 1219f895..44cad400 100644
--- a/tests/models/test_kosmos.py
+++ b/tests/models/test_kosmos.py
@@ -16,9 +16,7 @@ def mock_image_request():
     img_data = open(TEST_IMAGE_URL, "rb").read()
     mock_resp = Mock()
     mock_resp.raw = img_data
-    with patch.object(
-        requests, "get", return_value=mock_resp
-    ) as _fixture:
+    with patch.object(requests, "get", return_value=mock_resp) as _fixture:
         yield _fixture
 
 
@@ -132,9 +130,7 @@ def test_referring_expression_comprehension(kosmos):
 
 @pytest.mark.usefixtures("mock_request_get")
 def test_referring_expression_generation(kosmos):
-    kosmos.referring_expression_generation(
-        "It is on the table.", IMG_URL3
-    )
+    kosmos.referring_expression_generation("It is on the table.", IMG_URL3)
 
 
 @pytest.mark.usefixtures("mock_request_get")
diff --git a/tests/models/test_llama_function_caller.py b/tests/models/test_llama_function_caller.py
index 1e9df654..b6086c7a 100644
--- a/tests/models/test_llama_function_caller.py
+++ b/tests/models/test_llama_function_caller.py
@@ -86,9 +86,7 @@ def test_llama_custom_function_invalid_arguments(llama_caller):
     )
 
     with pytest.raises(TypeError):
-        llama_caller.call_function(
-            "sample_function", arg1="arg1_value"
-        )
+        llama_caller.call_function("sample_function", arg1="arg1_value")
 
 
 # Test streaming with custom runtime
diff --git a/tests/models/test_mixtral.py b/tests/models/test_mixtral.py
index a68a9026..3a47c87c 100644
--- a/tests/models/test_mixtral.py
+++ b/tests/models/test_mixtral.py
@@ -21,9 +21,7 @@ def test_mixtral_run(mock_model, mock_tokenizer):
     mixtral = Mixtral()
     mock_tokenizer_instance = MagicMock()
     mock_model_instance = MagicMock()
-    mock_tokenizer.from_pretrained.return_value = (
-        mock_tokenizer_instance
-    )
+    mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
     mock_model.from_pretrained.return_value = mock_model_instance
     mock_tokenizer_instance.return_tensors = "pt"
     mock_model_instance.generate.return_value = [101, 102, 103]
@@ -45,9 +43,7 @@ def test_mixtral_run_error(mock_model, mock_tokenizer):
     mixtral = Mixtral()
     mock_tokenizer_instance = MagicMock()
     mock_model_instance = MagicMock()
-    mock_tokenizer.from_pretrained.return_value = (
-        mock_tokenizer_instance
-    )
+    mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
     mock_model.from_pretrained.return_value = mock_model_instance
     mock_tokenizer_instance.return_tensors = "pt"
     mock_model_instance.generate.side_effect = Exception("Test error")
diff --git a/tests/models/test_mpt7b.py b/tests/models/test_mpt7b.py
index 92b6c254..0db9784a 100644
--- a/tests/models/test_mpt7b.py
+++ b/tests/models/test_mpt7b.py
@@ -30,9 +30,7 @@ def test_mpt7b_run():
     )
 
     assert isinstance(output, str)
-    assert output.startswith(
-        "Once upon a time in a land far, far away..."
-    )
+    assert output.startswith("Once upon a time in a land far, far away...")
 
 
 def test_mpt7b_run_invalid_task():
@@ -55,14 +53,10 @@ def test_mpt7b_generate():
         "EleutherAI/gpt-neox-20b",
         max_tokens=150,
     )
-    output = mpt.generate(
-        "Once upon a time in a land far, far away..."
-    )
+    output = mpt.generate("Once upon a time in a land far, far away...")
 
     assert isinstance(output, str)
-    assert output.startswith(
-        "Once upon a time in a land far, far away..."
-    )
+    assert output.startswith("Once upon a time in a land far, far away...")
 
 
 def test_mpt7b_batch_generate():
diff --git a/tests/models/test_nougat.py b/tests/models/test_nougat.py
index 858845a6..3c520510 100644
--- a/tests/models/test_nougat.py
+++ b/tests/models/test_nougat.py
@@ -74,9 +74,7 @@ def test_get_image_invalid_path(setup_nougat):
         (10, 50),
     ],
 )
-def test_model_call_with_diff_params(
-    setup_nougat, min_len, max_tokens
-):
+def test_model_call_with_diff_params(setup_nougat, min_len, max_tokens):
     setup_nougat.min_length = min_len
     setup_nougat.max_new_tokens = max_tokens
diff --git a/tests/models/test_speech_t5.py b/tests/models/test_speech_t5.py
index d32c21db..d9ed2a03 100644
--- a/tests/models/test_speech_t5.py
+++ b/tests/models/test_speech_t5.py
@@ -20,9 +20,7 @@ def test_speecht5_init(speecht5_model):
         speecht5_model.processor, SpeechT5.processor.__class__
     )
     assert isinstance(speecht5_model.model, SpeechT5.model.__class__)
-    assert isinstance(
-        speecht5_model.vocoder, SpeechT5.vocoder.__class__
-    )
+    assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__)
     assert isinstance(
         speecht5_model.embeddings_dataset, torch.utils.data.Dataset
     )
@@ -49,10 +47,7 @@ def test_speecht5_set_model(speecht5_model):
     speecht5_model.set_model(new_model_name)
     assert speecht5_model.model_name == new_model_name
     assert speecht5_model.processor.model_name == new_model_name
-    assert (
-        speecht5_model.model.config.model_name_or_path
-        == new_model_name
-    )
+    assert speecht5_model.model.config.model_name_or_path == new_model_name
     speecht5_model.set_model(old_model_name)  # Restore original model
diff --git a/tests/models/test_timm.py b/tests/models/test_timm.py
index 4af689e5..fb37d5ca 100644
--- a/tests/models/test_timm.py
+++ b/tests/models/test_timm.py
@@ -19,9 +19,7 @@ def test_timm_model_init():
 
 
 def test_timm_model_call():
-    with patch(
-        "swarms.models.timm.create_model"
-    ) as mock_create_model:
+    with patch("swarms.models.timm.create_model") as mock_create_model:
         model_name = "resnet18"
         pretrained = True
         in_chans = 3
diff --git a/tests/models/test_timm_model.py b/tests/models/test_timm_model.py
index b2f8f6c9..c4c42e0f 100644
--- a/tests/models/test_timm_model.py
+++ b/tests/models/test_timm_model.py
@@ -22,9 +22,7 @@ def test_create_model(sample_model_info):
 def test_call(sample_model_info):
     model_handler = TimmModel()
     input_tensor = torch.randn(1, 3, 224, 224)
-    output_shape = model_handler.__call__(
-        sample_model_info, input_tensor
-    )
+    output_shape = model_handler.__call__(sample_model_info, input_tensor)
     assert isinstance(output_shape, torch.Size)
diff --git a/tests/models/test_vilt.py b/tests/models/test_vilt.py
index d849f98e..a8b2d092 100644
--- a/tests/models/test_vilt.py
+++ b/tests/models/test_vilt.py
@@ -29,9 +29,7 @@ def test_vilt_prediction(
     mock_requests_get.return_value.raw = Mock()
 
     # It's a mock response, so no real answer expected
-    with pytest.raises(
-        Exception
-    ):  # Ensure exception is more specific
+    with pytest.raises(Exception):  # Ensure exception is more specific
         vilt_instance(
             "What is this image",
             "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
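
The test_yi_200k.py diff below reflows a parametrized test signature onto one line. As a reminder of what @pytest.mark.parametrize does, a minimal sketch (the test body is illustrative):

    import pytest


    @pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
    def test_temperature_is_positive(temperature):
        # pytest runs this test once per value in the list.
        assert temperature > 0
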
diff --git a/tests/models/test_yi_200k.py b/tests/models/test_yi_200k.py
index b31daa3e..155e0319 100644
--- a/tests/models/test_yi_200k.py
+++ b/tests/models/test_yi_200k.py
@@ -32,9 +32,7 @@ def test_yi34b_generate_text_with_length(yi34b_model, max_length):
 
 
 @pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5])
-def test_yi34b_generate_text_with_temperature(
-    yi34b_model, temperature
-):
+def test_yi34b_generate_text_with_temperature(yi34b_model, temperature):
     prompt = "There's a place where time stands still."
     generated_text = yi34b_model(prompt, temperature=temperature)
     assert isinstance(generated_text, str)
diff --git a/tests/models/test_zeroscope.py b/tests/models/test_zeroscope.py
index 25a4c597..372c4bd2 100644
--- a/tests/models/test_zeroscope.py
+++ b/tests/models/test_zeroscope.py
@@ -25,9 +25,7 @@ def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline):
 def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline):
     zeroscope = ZeroscopeTTV()
     mock_pipeline_instance = MagicMock()
-    mock_pipeline.from_pretrained.return_value = (
-        mock_pipeline_instance
-    )
+    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance
     mock_pipeline_instance.return_value = MagicMock(
         frames="Generated frames"
     )
@@ -51,9 +49,7 @@ def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline):
     zeroscope = ZeroscopeTTV()
     mock_pipeline_instance = MagicMock()
-    mock_pipeline.from_pretrained.return_value = (
-        mock_pipeline_instance
-    )
+    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance
     mock_pipeline_instance.return_value = MagicMock(
         frames="Generated frames"
     )
@@ -67,9 +63,7 @@ def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline):
     zeroscope = ZeroscopeTTV()
     mock_pipeline_instance = MagicMock()
-    mock_pipeline.from_pretrained.return_value = (
-        mock_pipeline_instance
-    )
+    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance
     mock_pipeline_instance.return_value = MagicMock(
         frames="Generated frames"
     )
@@ -89,9 +83,7 @@ def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline):
     zeroscope = ZeroscopeTTV()
     mock_pipeline_instance = MagicMock()
-    mock_pipeline.from_pretrained.return_value = (
-        mock_pipeline_instance
-    )
+    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance
     mock_pipeline_instance.return_value = MagicMock(
         frames="Generated frames"
     )
@@ -105,9 +97,7 @@ def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline):
     zeroscope = ZeroscopeTTV()
     mock_pipeline_instance = MagicMock()
-    mock_pipeline.from_pretrained.return_value = (
-        mock_pipeline_instance
-    )
+    mock_pipeline.from_pretrained.return_value = mock_pipeline_instance
     mock_pipeline_instance.return_value = MagicMock(
         frames="Generated frames"
     )
diff --git a/tests/structs/test_agent.py b/tests/structs/test_agent.py
index 5be7f31a..f29c40fd 100644
--- a/tests/structs/test_agent.py
+++ b/tests/structs/test_agent.py
@@ -71,9 +71,7 @@ def test_run_without_stopping_condition(mocked_sleep, basic_flow):
 
 
 @patch("time.sleep", return_value=None)  # to speed up tests
-def test_run_with_stopping_condition(
-    mocked_sleep, flow_with_condition
-):
+def test_run_with_stopping_condition(mocked_sleep, flow_with_condition):
     response = flow_with_condition.run("Stop")
     assert response == "Stop"
 
@@ -252,9 +250,7 @@ def test_different_retry_intervals(mocked_sleep, basic_flow):
 # Test invoking the agent with additional kwargs
 @patch("time.sleep", return_value=None)
 def test_flow_call_with_kwargs(mocked_sleep, basic_flow):
-    response = basic_flow(
-        "Test call", param1="value1", param2="value2"
-    )
+    response = basic_flow("Test call", param1="value1", param2="value2")
     assert response == "Test call"
@@ -402,9 +398,7 @@ def test_flow_response_length(flow_instance):
         "Generate a 10,000 word long blog on mental clarity and the"
         " benefits of meditation."
     )
-    assert (
-        len(response) > flow_instance.get_response_length_threshold()
-    )
+    assert len(response) > flow_instance.get_response_length_threshold()
 
 
 def test_flow_set_response_length_threshold(flow_instance):
@@ -493,9 +487,7 @@ def test_flow_get_conversation_log(flow_instance):
     flow_instance.run("Message 1")
     flow_instance.run("Message 2")
     conversation_log = flow_instance.get_conversation_log()
-    assert (
-        len(conversation_log) == 4
-    )  # Including system and user messages
+    assert len(conversation_log) == 4  # Including system and user messages
 
 
 def test_flow_clear_conversation_log(flow_instance):
@@ -579,20 +571,14 @@ def test_flow_rollback(flow_instance):
     flow_instance.change_prompt("New prompt")
     flow_instance.get_state()
     flow_instance.rollback_to_state(state1)
-    assert (
-        flow_instance.get_current_prompt() == state1["current_prompt"]
-    )
+    assert flow_instance.get_current_prompt() == state1["current_prompt"]
     assert flow_instance.get_instructions() == state1["instructions"]
+    assert flow_instance.get_user_messages() == state1["user_messages"]
     assert (
-        flow_instance.get_user_messages() == state1["user_messages"]
-    )
-    assert (
-        flow_instance.get_response_history()
-        == state1["response_history"]
+        flow_instance.get_response_history() == state1["response_history"]
     )
     assert (
-        flow_instance.get_conversation_log()
-        == state1["conversation_log"]
+        flow_instance.get_conversation_log() == state1["conversation_log"]
     )
     assert (
         flow_instance.is_dynamic_pacing_enabled()
@@ -603,13 +589,10 @@ def test_flow_rollback(flow_instance):
         == state1["response_length_threshold"]
     )
     assert (
-        flow_instance.get_response_filters()
-        == state1["response_filters"]
+        flow_instance.get_response_filters() == state1["response_filters"]
     )
     assert flow_instance.get_max_loops() == state1["max_loops"]
-    assert (
-        flow_instance.get_autosave_path() == state1["autosave_path"]
-    )
+    assert flow_instance.get_autosave_path() == state1["autosave_path"]
     assert flow_instance.get_state() == state1
 
 
@@ -627,13 +610,9 @@ def test_flow_contextual_intent(flow_instance):
 def test_flow_contextual_intent_override(flow_instance):
     # Test contextual intent override
     flow_instance.add_context("location", "New York")
-    response1 = flow_instance.run(
-        "What's the weather like in {location}?"
-    )
+    response1 = flow_instance.run("What's the weather like in {location}?")
     flow_instance.add_context("location", "Los Angeles")
-    response2 = flow_instance.run(
-        "What's the weather like in {location}?"
-    )
+    response2 = flow_instance.run("What's the weather like in {location}?")
     assert "New York" in response1
     assert "Los Angeles" in response2
 
@@ -641,13 +620,9 @@ def test_flow_contextual_intent_reset(flow_instance):
     # Test resetting contextual intent
     flow_instance.add_context("location", "New York")
-    response1 = flow_instance.run(
-        "What's the weather like in {location}?"
-    )
+    response1 = flow_instance.run("What's the weather like in {location}?")
     flow_instance.reset_context()
-    response2 = flow_instance.run(
-        "What's the weather like in {location}?"
-    )
+    response2 = flow_instance.run("What's the weather like in {location}?")
     assert "New York" in response1
     assert "New York" in response2
 
@@ -672,9 +647,7 @@ def test_flow_non_interruptible(flow_instance):
 def test_flow_timeout(flow_instance):
     # Test conversation timeout
     flow_instance.timeout = 60  # Set a timeout of 60 seconds
-    response = flow_instance.run(
-        "This should take some time to respond."
-    )
+    response = flow_instance.run("This should take some time to respond.")
     assert "Timed out" in response
     assert flow_instance.is_timed_out() is True
@@ -723,20 +696,14 @@ def test_flow_save_and_load_conversation(flow_instance):
 
 def test_flow_inject_custom_system_message(flow_instance):
     # Test injecting a custom system message into the conversation
-    flow_instance.inject_custom_system_message(
-        "Custom system message"
-    )
-    assert (
-        "Custom system message" in flow_instance.get_message_history()
-    )
+    flow_instance.inject_custom_system_message("Custom system message")
+    assert "Custom system message" in flow_instance.get_message_history()
 
 
 def test_flow_inject_custom_user_message(flow_instance):
     # Test injecting a custom user message into the conversation
     flow_instance.inject_custom_user_message("Custom user message")
-    assert (
-        "Custom user message" in flow_instance.get_message_history()
-    )
+    assert "Custom user message" in flow_instance.get_message_history()
 
 
 def test_flow_inject_custom_response(flow_instance):
@@ -747,23 +714,15 @@ def test_flow_clear_injected_messages(flow_instance):
     # Test clearing injected messages from the conversation
-    flow_instance.inject_custom_system_message(
-        "Custom system message"
-    )
+    flow_instance.inject_custom_system_message("Custom system message")
     flow_instance.inject_custom_user_message("Custom user message")
     flow_instance.inject_custom_response("Custom response")
     flow_instance.clear_injected_messages()
     assert (
-        "Custom system message"
-        not in flow_instance.get_message_history()
-    )
-    assert (
-        "Custom user message"
-        not in flow_instance.get_message_history()
-    )
-    assert (
-        "Custom response" not in flow_instance.get_message_history()
+        "Custom system message" not in flow_instance.get_message_history()
     )
+    assert "Custom user message" not in flow_instance.get_message_history()
+    assert "Custom response" not in flow_instance.get_message_history()
 
 
 def test_flow_disable_message_history(flow_instance):
@@ -772,9 +731,7 @@ def test_flow_disable_message_history(flow_instance):
     response = flow_instance.run(
         "This message should not be recorded in history."
     )
-    assert (
-        "This message should not be recorded in history." in response
-    )
+    assert "This message should not be recorded in history." in response
     assert (
         len(flow_instance.get_message_history()) == 0
     )  # History is empty
@@ -1066,13 +1023,9 @@ def test_flow_custom_response(flow_instance):
 
     assert flow_instance.run("Hello") == "Hi there!"
     assert (
-        flow_instance.run("How are you?")
-        == "I'm doing well, thank you."
-    )
-    assert (
-        flow_instance.run("What's your name?")
-        == "I don't understand."
+        flow_instance.run("How are you?") == "I'm doing well, thank you."
     )
+    assert flow_instance.run("What's your name?") == "I don't understand."
 
 
 def test_flow_message_validation(flow_instance):
@@ -1113,15 +1066,10 @@ def test_flow_complex_use_case(flow_instance):
     flow_instance.add_context("user_id", "12345")
     flow_instance.run("Hello")
     flow_instance.run("How can I help you?")
-    assert (
-        flow_instance.get_response() == "Please provide more details."
-    )
+    assert flow_instance.get_response() == "Please provide more details."
     flow_instance.update_context("user_id", "54321")
     flow_instance.run("I need help with my order")
-    assert (
-        flow_instance.get_response()
-        == "Sure, I can assist with that."
-    )
+    assert flow_instance.get_response() == "Sure, I can assist with that."
     flow_instance.reset_conversation()
     assert len(flow_instance.get_message_history()) == 0
     assert flow_instance.get_context("user_id") is None
@@ -1160,9 +1108,7 @@ def test_flow_concurrent_requests(flow_instance):
 
 def test_flow_custom_timeout(flow_instance):
     # Test custom timeout handling
-    flow_instance.set_timeout(
-        10
-    )  # Set a custom timeout of 10 seconds
+    flow_instance.set_timeout(10)  # Set a custom timeout of 10 seconds
     assert flow_instance.get_timeout() == 10
 
     import time
@@ -1217,12 +1163,9 @@ def test_flow_agent_history_prompt(flow_instance):
     )
 
     assert (
-        "SYSTEM_PROMPT: This is the system prompt."
-        in agent_history_prompt
-    )
-    assert (
-        "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt
+        "SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt
     )
+    assert "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt
 
 
 async def test_flow_run_concurrent(flow_instance):
@@ -1254,9 +1197,7 @@ def test_flow_from_llm_and_template():
     llm_instance = mocked_llm  # Replace with your LLM class
     template = "This is a template for testing."
 
-    flow_instance = Agent.from_llm_and_template(
-        llm_instance, template
-    )
+    flow_instance = Agent.from_llm_and_template(llm_instance, template)
 
     assert isinstance(flow_instance, Agent)
 
@@ -1264,9 +1205,7 @@ def test_flow_from_llm_and_template_file():
     # Test creating Agent instance from an LLM and a template file
     llm_instance = mocked_llm  # Replace with your LLM class
-    template_file = (  # Create a template file for testing
-        "template.txt"
-    )
+    template_file = "template.txt"  # Create a template file for testing
 
     flow_instance = Agent.from_llm_and_template_file(
         llm_instance, template_file
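
test_agent.py above leans heavily on @patch("time.sleep", return_value=None) to skip retry delays. The decorator injects the mock as the test's first argument, as in this standalone sketch (slow_operation is illustrative, not from the repository):

    import time
    from unittest.mock import patch


    def slow_operation():
        time.sleep(5)  # stands in for a retry/backoff delay
        return "done"


    @patch("time.sleep", return_value=None)  # injected as the first argument
    def test_slow_operation_is_fast(mocked_sleep):
        assert slow_operation() == "done"
        mocked_sleep.assert_called_once_with(5)
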
diff --git a/tests/structs/test_autoscaler.py b/tests/structs/test_autoscaler.py
index 2e5585bf..1db90efc 100644
--- a/tests/structs/test_autoscaler.py
+++ b/tests/structs/test_autoscaler.py
@@ -44,9 +44,7 @@ def test_autoscaler_run():
         agent.id,
         "Generate a 10,000 word blog on health and wellness.",
     )
-    assert (
-        out == "Generate a 10,000 word blog on health and wellness."
-    )
+    assert out == "Generate a 10,000 word blog on health and wellness."
 
 
 def test_autoscaler_add_agent():
diff --git a/tests/structs/test_base.py b/tests/structs/test_base.py
index 6ff05e16..06ff6bf5 100644
--- a/tests/structs/test_base.py
+++ b/tests/structs/test_base.py
@@ -89,8 +89,7 @@ class TestBaseStructure:
             lines = file.readlines()
             assert len(lines) == 1
             assert (
-                lines[0]
-                == f"[{base_structure._current_timestamp()}]"
+                lines[0] == f"[{base_structure._current_timestamp()}]"
                 f" [{event_type}] {event}\n"
             )
 
@@ -136,9 +135,7 @@ class TestBaseStructure:
         artifact = {"key": "value"}
         artifact_name = "test_artifact"
 
-        await base_structure.save_artifact_async(
-            artifact, artifact_name
-        )
+        await base_structure.save_artifact_async(artifact, artifact_name)
         loaded_artifact = base_structure.load_artifact(artifact_name)
 
         assert loaded_artifact == artifact
@@ -171,8 +168,7 @@ class TestBaseStructure:
             lines = file.readlines()
             assert len(lines) == 1
             assert (
-                lines[0]
-                == f"[{base_structure._current_timestamp()}]"
+                lines[0] == f"[{base_structure._current_timestamp()}]"
                 f" [{event_type}] {event}\n"
             )
 
@@ -201,18 +197,14 @@ class TestBaseStructure:
 
     def test_run_in_thread(self):
         base_structure = BaseStructure()
-        result = base_structure.run_in_thread(
-            lambda: "Thread Test Result"
-        )
+        result = base_structure.run_in_thread(lambda: "Thread Test Result")
         assert result.result() == "Thread Test Result"
 
     def test_save_and_decompress_data(self):
         base_structure = BaseStructure()
         data = {"key": "value"}
         compressed_data = base_structure.compress_data(data)
-        decompressed_data = base_structure.decompres_data(
-            compressed_data
-        )
+        decompressed_data = base_structure.decompres_data(compressed_data)
         assert decompressed_data == data
 
     def test_run_batched(self):
@@ -226,9 +218,7 @@ class TestBaseStructure:
             batched_data, batch_size=5, func=run_function
         )
 
-        expected_result = [
-            f"Processed {data}" for data in batched_data
-        ]
+        expected_result = [f"Processed {data}" for data in batched_data]
         assert result == expected_result
 
     def test_load_config(self, tmpdir):
@@ -246,9 +236,7 @@ class TestBaseStructure:
         tmp_dir = tmpdir.mkdir("test_dir")
         base_structure = BaseStructure()
         data_to_backup = {"key": "value"}
-        base_structure.backup_data(
-            data_to_backup, backup_path=tmp_dir
-        )
+        base_structure.backup_data(data_to_backup, backup_path=tmp_dir)
         backup_files = os.listdir(tmp_dir)
         assert len(backup_files) == 1
 
@@ -283,7 +271,5 @@ class TestBaseStructure:
             batched_data, batch_size=5, func=run_function
         )
 
-        expected_result = [
-            f"Processed {data}" for data in batched_data
-        ]
+        expected_result = [f"Processed {data}" for data in batched_data]
         assert result == expected_result
diff --git a/tests/structs/test_base_workflow.py b/tests/structs/test_base_workflow.py
index ccb7a563..eb029a87 100644
--- a/tests/structs/test_base_workflow.py
+++ b/tests/structs/test_base_workflow.py
@@ -30,12 +30,9 @@ def test_load_workflow_state():
     workflow.load_workflow_state("workflow_state.json")
     assert workflow.max_loops == 1
     assert len(workflow.tasks) == 2
+    assert workflow.tasks[0].description == "What's the weather in miami"
     assert (
-        workflow.tasks[0].description == "What's the weather in miami"
-    )
-    assert (
-        workflow.tasks[1].description
-        == "Create a report on these metrics"
+        workflow.tasks[1].description == "Create a report on these metrics"
     )
     teardown_workflow()
diff --git a/tests/structs/test_concurrent_workflow.py b/tests/structs/test_concurrent_workflow.py
index e3fabdd5..9a3f46da 100644
--- a/tests/structs/test_concurrent_workflow.py
+++ b/tests/structs/test_concurrent_workflow.py
@@ -18,9 +18,7 @@ def test_run():
     workflow.add(task1)
     workflow.add(task2)
 
-    with patch(
-        "concurrent.futures.ThreadPoolExecutor"
-    ) as mock_executor:
+    with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor:
         future1 = Future()
         future1.set_result(None)
         future2 = Future()
diff --git a/tests/structs/test_json.py b/tests/structs/test_json.py
index 9ba11072..519a6ee8 100644
--- a/tests/structs/test_json.py
+++ b/tests/structs/test_json.py
@@ -16,8 +16,7 @@ def valid_schema_path(tmp_path):
     d.mkdir()
     p = d / "schema.json"
     p.write_text(
-        '{"type": "object", "properties": {"name": {"type":'
-        ' "string"}}}'
+        '{"type": "object", "properties": {"name": {"type":' ' "string"}}}'
     )
     return str(p)
diff --git a/tests/structs/test_majority_voting.py b/tests/structs/test_majority_voting.py
index dcd25f0b..b6f09020 100644
--- a/tests/structs/test_majority_voting.py
+++ b/tests/structs/test_majority_voting.py
@@ -35,15 +35,9 @@ def test_majority_voting_run_concurrent(mocker):
     majority_vote = mv.run("What is the capital of France?")
 
     # Assert agent.run method was called with the correct task
-    agent1.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
-    agent2.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
-    agent3.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
+    agent1.run.assert_called_once_with("What is the capital of France?")
+    agent2.run.assert_called_once_with("What is the capital of France?")
+    agent3.run.assert_called_once_with("What is the capital of France?")
 
     # Assert conversation.add method was called with the correct responses
     conversation.add.assert_any_call(agent1.agent_name, results[0])
@@ -83,15 +77,9 @@ def test_majority_voting_run_multithreaded(mocker):
     majority_vote = mv.run("What is the capital of France?")
 
     # Assert agent.run method was called with the correct task
-    agent1.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
-    agent2.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
-    agent3.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
+    agent1.run.assert_called_once_with("What is the capital of France?")
+    agent2.run.assert_called_once_with("What is the capital of France?")
+    agent3.run.assert_called_once_with("What is the capital of France?")
 
     # Assert conversation.add method was called with the correct responses
     conversation.add.assert_any_call(agent1.agent_name, results[0])
@@ -133,15 +121,9 @@ async def test_majority_voting_run_asynchronous(mocker):
     majority_vote = await mv.run("What is the capital of France?")
 
     # Assert agent.run method was called with the correct task
-    agent1.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
-    agent2.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
-    agent3.run.assert_called_once_with(
-        "What is the capital of France?"
-    )
+    agent1.run.assert_called_once_with("What is the capital of France?")
+    agent2.run.assert_called_once_with("What is the capital of France?")
+    agent3.run.assert_called_once_with("What is the capital of France?")
 
     # Assert conversation.add method was called with the correct responses
     conversation.add.assert_any_call(agent1.agent_name, results[0])
diff --git a/tests/structs/test_message_pool.py b/tests/structs/test_message_pool.py
index cfbb4df5..7af78769 100644
--- a/tests/structs/test_message_pool.py
+++ b/tests/structs/test_message_pool.py
@@ -8,9 +8,7 @@ def test_message_pool_initialization():
     agent2 = Agent(llm=OpenAIChat(), agent_name="agent1")
     moderator = Agent(llm=OpenAIChat(), agent_name="agent1")
     agents = [agent1, agent2]
-    message_pool = MessagePool(
-        agents=agents, moderator=moderator, turns=5
-    )
+    message_pool = MessagePool(agents=agents, moderator=moderator, turns=5)
 
     assert message_pool.agent == agents
     assert message_pool.moderator == moderator
@@ -20,9 +18,7 @@ def test_message_pool_initialization():
 
 def test_message_pool_add():
     agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
-    message_pool = MessagePool(
-        agents=[agent1], moderator=agent1, turns=5
-    )
+    message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
     message_pool.add(agent=agent1, content="Hello, world!", turn=1)
 
     assert message_pool.messages == [
@@ -38,9 +34,7 @@ def test_message_pool_add():
 
 def test_message_pool_reset():
     agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
-    message_pool = MessagePool(
-        agents=[agent1], moderator=agent1, turns=5
-    )
+    message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
     message_pool.add(agent=agent1, content="Hello, world!", turn=1)
     message_pool.reset()
 
@@ -49,9 +43,7 @@ def test_message_pool_last_turn():
     agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
-    message_pool = MessagePool(
-        agents=[agent1], moderator=agent1, turns=5
-    )
+    message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
     message_pool.add(agent=agent1, content="Hello, world!", turn=1)
 
     assert message_pool.last_turn() == 1
 
@@ -59,9 +51,7 @@ def test_message_pool_last_message():
     agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
-    message_pool = MessagePool(
-        agents=[agent1], moderator=agent1, turns=5
-    )
+    message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
     message_pool.add(agent=agent1, content="Hello, world!", turn=1)
 
     assert message_pool.last_message == {
@@ -75,9 +65,7 @@ def test_message_pool_get_all_messages():
     agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
-    message_pool = MessagePool(
-        agents=[agent1], moderator=agent1, turns=5
-    )
+    message_pool = MessagePool(agents=[agent1], moderator=agent1, turns=5)
     message_pool.add(agent=agent1, content="Hello, world!", turn=1)
 
     assert message_pool.get_all_messages() == [
@@ -104,9 +92,7 @@ def test_message_pool_get_visible_messages():
         visible_to=[agent2.agent_name],
     )
 
-    assert message_pool.get_visible_messages(
-        agent=agent2, turn=2
-    ) == [
+    assert message_pool.get_visible_messages(agent=agent2, turn=2) == [
         {
             "agent": agent1,
             "content": "Hello, agent2!",
diff --git a/tests/structs/test_multi_agent_collab.py b/tests/structs/test_multi_agent_collab.py
index 555771e7..29c58dcf 100644
--- a/tests/structs/test_multi_agent_collab.py
+++ b/tests/structs/test_multi_agent_collab.py
@@ -73,12 +73,8 @@ def test_run(collaboration):
 
 
 def test_format_results(collaboration):
-    collaboration.results = [
-        {"agent": "Agent1", "response": "Response1"}
-    ]
-    formatted_results = collaboration.format_results(
-        collaboration.results
-    )
+    collaboration.results = [{"agent": "Agent1", "response": "Response1"}]
+    formatted_results = collaboration.format_results(collaboration.results)
     assert "Agent1 responded: Response1" in formatted_results
diff --git a/tests/structs/test_recursive_workflow.py b/tests/structs/test_recursive_workflow.py
index 5b24f921..618f955a 100644
--- a/tests/structs/test_recursive_workflow.py
+++ b/tests/structs/test_recursive_workflow.py
@@ -53,9 +53,7 @@ def test_run_stop_token_not_in_result():
     try:
         workflow.run()
     except RecursionError:
-        pytest.fail(
-            "RecursiveWorkflow.run caused a RecursionError"
-        )
+        pytest.fail("RecursiveWorkflow.run caused a RecursionError")
 
     assert agent.execute.call_count == max_iterations
diff --git a/tests/structs/test_sequential_workflow.py b/tests/structs/test_sequential_workflow.py
index 0d12991a..14faffe5 100644
--- a/tests/structs/test_sequential_workflow.py
+++ b/tests/structs/test_sequential_workflow.py
@@ -70,8 +70,7 @@ def test_sequential_workflow_initialization():
     assert workflow.max_loops == 1
     assert workflow.autosave is False
     assert (
-        workflow.saved_state_filepath
-        == "sequential_workflow_state.json"
+        workflow.saved_state_filepath == "sequential_workflow_state.json"
     )
     assert workflow.restore_state_filepath is None
     assert workflow.dashboard is False
diff --git a/tests/structs/test_swarmnetwork.py b/tests/structs/test_swarmnetwork.py
index 9dc6d903..146e27fb 100644
--- a/tests/structs/test_swarmnetwork.py
+++ b/tests/structs/test_swarmnetwork.py
@@ -20,9 +20,7 @@ def test_swarm_network_init(swarm_network):
 @patch("swarms.structs.swarm_net.SwarmNetwork.logger")
 def test_run(mock_logger, swarm_network):
     swarm_network.run()
-    assert (
-        mock_logger.info.call_count == 10
-    )  # 2 log messages per agent
+    assert mock_logger.info.call_count == 10  # 2 log messages per agent
 
 
 def test_run_with_mocked_agents(mocker, swarm_network):
diff --git a/tests/structs/test_task.py b/tests/structs/test_task.py
index de0352af..b1d0600f 100644
--- a/tests/structs/test_task.py
+++ b/tests/structs/test_task.py
@@ -213,9 +213,7 @@ def test_task_execute_with_action(mocker):
     mock_agent = mocker.Mock(spec=Agent)
     mock_agent.run.return_value = "result"
     action = mocker.Mock()
-    task = Task(
-        description="Test task", agent=mock_agent, action=action
-    )
+    task = Task(description="Test task", agent=mock_agent, action=action)
     task.execute()
     assert task.result == "result"
     assert task.history == ["result"]
diff --git a/tests/structs/test_tests_graph_workflow.py b/tests/structs/test_tests_graph_workflow.py
index cb5b17a7..d3ed88db 100644
--- a/tests/structs/test_tests_graph_workflow.py
+++ b/tests/structs/test_tests_graph_workflow.py
@@ -27,9 +27,7 @@ def test_set_entry_point(graph_workflow):
 
 
 def test_set_entry_point_nonexistent_node(graph_workflow):
-    with pytest.raises(
-        ValueError, match="Node does not exist in graph"
-    ):
+    with pytest.raises(ValueError, match="Node does not exist in graph"):
         graph_workflow.set_entry_point("nonexistent")
 
 
@@ -42,9 +40,7 @@ def test_add_edge(graph_workflow):
 
 def test_add_edge_nonexistent_node(graph_workflow):
     graph_workflow.add("node1", "value1")
-    with pytest.raises(
-        ValueError, match="Node does not exist in graph"
-    ):
+    with pytest.raises(ValueError, match="Node does not exist in graph"):
         graph_workflow.add_edge("node1", "nonexistent")
 
 
@@ -59,9 +55,7 @@ def test_add_conditional_edges(graph_workflow):
 def test_add_conditional_edges_nonexistent_node(graph_workflow):
     graph_workflow.add("node1", "value1")
-    with pytest.raises(
-        ValueError, match="Node does not exist in graph"
-    ):
+    with pytest.raises(ValueError, match="Node does not exist in graph"):
         graph_workflow.add_conditional_edges(
             "node1", "condition1", {"condition_value1": "nonexistent"}
         )
diff --git a/tests/telemetry/test_user_utils.py b/tests/telemetry/test_user_utils.py
index c7b5962c..a0936fbb 100644
--- a/tests/telemetry/test_user_utils.py
+++ b/tests/telemetry/test_user_utils.py
@@ -46,9 +46,7 @@ def test_generate_unique_identifier():
     # Generate unique identifiers and ensure they are valid UUID strings
     unique_id = generate_unique_identifier()
     assert isinstance(unique_id, str)
-    assert uuid.UUID(
-        unique_id, version=5, namespace=uuid.NAMESPACE_DNS
-    )
+    assert uuid.UUID(unique_id, version=5, namespace=uuid.NAMESPACE_DNS)
 
 
 def test_generate_user_id_edge_case():
diff --git a/tests/test_upload_tests_to_issues.py b/tests/test_upload_tests_to_issues.py
index 0857c58a..eaa01ca9 100644
--- a/tests/test_upload_tests_to_issues.py
+++ b/tests/test_upload_tests_to_issues.py
@@ -10,7 +10,9 @@ load_dotenv
 GITHUB_USERNAME = os.getenv("GITHUB_USERNAME")
 REPO_NAME = os.getenv("GITHUB_REPO_NAME")
 GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
-ISSUES_URL = f"https://api.github.com/repos/{GITHUB_USERNAME}/{REPO_NAME}/issues"
+ISSUES_URL = (
+    f"https://api.github.com/repos/{GITHUB_USERNAME}/{REPO_NAME}/issues"
+)
 
 # Headers for authentication
 headers = {
@@ -20,9 +22,7 @@ headers = {
 
 def run_pytest():
-    result = subprocess.run(
-        ["pytest"], capture_output=True, text=True
-    )
+    result = subprocess.run(["pytest"], capture_output=True, text=True)
     return result.stdout + result.stderr
 
 
@@ -56,9 +56,7 @@ def main():
     errors = parse_pytest_output(pytest_output)
 
     for error in errors:
-        issue_response = create_github_issue(
-            error["title"], error["body"]
-        )
+        issue_response = create_github_issue(error["title"], error["body"])
         print(f"Issue created: {issue_response.get('html_url')}")
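
Several hunks above single-line pytest.raises(..., match=...) calls. Note that match is treated as a regular expression applied with re.search against the string form of the exception, as this sketch shows (the validator is hypothetical, echoing the graph-workflow tests):

    import pytest


    def set_entry_point(node):
        # Hypothetical validator mirroring the graph-workflow tests.
        raise ValueError("Node does not exist in graph")


    def test_set_entry_point_nonexistent_node():
        # match= is a regex checked via re.search on str(exception).
        with pytest.raises(ValueError, match="Node does not exist in graph"):
            set_entry_point("nonexistent")
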
diff --git a/tests/tools/test_tools_base.py b/tests/tools/test_tools_base.py
index 9060f53f..ffbb8c6b 100644
--- a/tests/tools/test_tools_base.py
+++ b/tests/tools/test_tools_base.py
@@ -182,9 +182,7 @@ def test_tool_ainvoke_exception():
 
 
 def test_tool_ainvoke_with_coroutine_exception():
-    tool = Tool(
-        name="test_tool", coroutine=None, description="Test tool"
-    )
+    tool = Tool(name="test_tool", coroutine=None, description="Test tool")
     with pytest.raises(NotImplementedError):
         tool.ainvoke("input_data")
 
@@ -369,9 +367,7 @@ def test_structured_tool_ainvoke_with_new_argument():
         func=sample_function,
         args_schema=SampleArgsSchema,
     )
-    result = tool.ainvoke(
-        {"tool_input": "input_data"}, callbacks=None
-    )
+    result = tool.ainvoke({"tool_input": "input_data"}, callbacks=None)
     assert result == "input_data"
 
 
@@ -461,9 +457,7 @@ def test_tool_with_runnable(mock_runnable):
 def test_tool_with_invalid_argument():
     # Test passing an invalid argument type
     with pytest.raises(ValueError):
-        tool(
-            123
-        )  # Using an integer instead of a string/callable/Runnable
+        tool(123)  # Using an integer instead of a string/callable/Runnable
 
 
 def test_tool_with_multiple_arguments(mock_func):
@@ -525,9 +519,7 @@ class MockSchema(BaseModel):
 # Test suite starts here
 class TestTool:
     # Basic Functionality Tests
-    def test_tool_with_valid_callable_creates_base_tool(
-        self, mock_func
-    ):
+    def test_tool_with_valid_callable_creates_base_tool(self, mock_func):
         result = tool(mock_func)
         assert isinstance(result, BaseTool)
diff --git a/tests/utils/test_check_device.py b/tests/utils/test_check_device.py
index 503a3774..cf8d6ce2 100644
--- a/tests/utils/test_check_device.py
+++ b/tests/utils/test_check_device.py
@@ -33,12 +33,8 @@ def test_check_device_one_cuda(monkeypatch):
     # Mock torch.cuda.device_count to return 1
     monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
     # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0
-    monkeypatch.setattr(
-        torch.cuda, "memory_allocated", lambda device: 0
-    )
-    monkeypatch.setattr(
-        torch.cuda, "memory_reserved", lambda device: 0
-    )
+    monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0)
+    monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0)
 
     result = check_device(log_level=logging.DEBUG)
     assert len(result) == 1
@@ -52,12 +48,8 @@ def test_check_device_multiple_cuda(monkeypatch):
     # Mock torch.cuda.device_count to return 4
     monkeypatch.setattr(torch.cuda, "device_count", lambda: 4)
     # Mock torch.cuda.memory_allocated and torch.cuda.memory_reserved to return 0
-    monkeypatch.setattr(
-        torch.cuda, "memory_allocated", lambda device: 0
-    )
-    monkeypatch.setattr(
-        torch.cuda, "memory_reserved", lambda device: 0
-    )
+    monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device: 0)
+    monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device: 0)
 
     result = check_device(log_level=logging.DEBUG)
     assert len(result) == 4
diff --git a/tests/utils/test_class_args_wrapper.py b/tests/utils/test_class_args_wrapper.py
index 99d38b2c..884377fb 100644
--- a/tests/utils/test_class_args_wrapper.py
+++ b/tests/utils/test_class_args_wrapper.py
@@ -33,9 +33,7 @@ def test_print_class_parameters_error():
 def get_parameters(class_name: str):
     classes = {"Agent": Agent}
     if class_name in classes:
-        return print_class_parameters(
-            classes[class_name], api_format=True
-        )
+        return print_class_parameters(classes[class_name], api_format=True)
     else:
         return {"error": "Class not found"}
diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py
index 9be83be4..1fe98c4d 100644
--- a/tests/utils/test_device.py
+++ b/tests/utils/test_device.py
@@ -32,18 +32,14 @@ def test_multiple_gpus_available(mocker):
 def test_device_properties(mocker):
     mocker.patch("torch.cuda.is_available", return_value=True)
     mocker.patch("torch.cuda.device_count", return_value=1)
-    mocker.patch(
-        "torch.cuda.get_device_capability", return_value=(7, 5)
-    )
+    mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5))
     mocker.patch(
         "torch.cuda.get_device_properties",
         return_value=MagicMock(total_memory=1000),
     )
     mocker.patch("torch.cuda.memory_allocated", return_value=200)
     mocker.patch("torch.cuda.memory_reserved", return_value=300)
-    mocker.patch(
-        "torch.cuda.get_device_name", return_value="Tesla K80"
-    )
+    mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
     devices = check_device()
     assert len(devices) == 1
     assert str(devices[0]) == "cuda"
@@ -52,9 +48,7 @@ def test_memory_threshold(mocker):
     mocker.patch("torch.cuda.is_available", return_value=True)
     mocker.patch("torch.cuda.device_count", return_value=1)
-    mocker.patch(
-        "torch.cuda.get_device_capability", return_value=(7, 5)
-    )
+    mocker.patch("torch.cuda.get_device_capability", return_value=(7, 5))
     mocker.patch(
         "torch.cuda.get_device_properties",
         return_value=MagicMock(total_memory=1000),
     )
     mocker.patch(
         "torch.cuda.memory_allocated", return_value=900
     )  # 90% of total memory
     mocker.patch("torch.cuda.memory_reserved", return_value=300)
-    mocker.patch(
-        "torch.cuda.get_device_name", return_value="Tesla K80"
-    )
+    mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
     with pytest.warns(
         UserWarning,
         match=r"Memory usage for device cuda exceeds threshold",
@@ -89,14 +81,10 @@ def test_compute_capability_threshold(mocker):
     )
     mocker.patch("torch.cuda.memory_allocated", return_value=200)
     mocker.patch("torch.cuda.memory_reserved", return_value=300)
-    mocker.patch(
-        "torch.cuda.get_device_name", return_value="Tesla K80"
-    )
+    mocker.patch("torch.cuda.get_device_name", return_value="Tesla K80")
     with pytest.warns(
         UserWarning,
-        match=(
-            r"Compute capability for device cuda is below threshold"
-        ),
+        match=(r"Compute capability for device cuda is below threshold"),
     ):
         devices = check_device(
             capability_threshold=3.5
diff --git a/tests/utils/test_extract_code_from_markdown.py b/tests/utils/test_extract_code_from_markdown.py
index eb1a3e5d..24fc6109 100644
--- a/tests/utils/test_extract_code_from_markdown.py
+++ b/tests/utils/test_extract_code_from_markdown.py
@@ -25,9 +25,7 @@ def markdown_content_without_code():
 def test_extract_code_from_markdown_with_code(
     markdown_content_with_code,
 ):
-    extracted_code = extract_code_from_markdown(
-        markdown_content_with_code
-    )
+    extracted_code = extract_code_from_markdown(markdown_content_with_code)
     assert "def my_func():" in extracted_code
     assert 'print("This is my function.")' in extracted_code
     assert "class MyClass:" in extracted_code
diff --git a/tests/utils/test_find_image_path.py b/tests/utils/test_find_image_path.py
index 29b1c627..ae15343f 100644
--- a/tests/utils/test_find_image_path.py
+++ b/tests/utils/test_find_image_path.py
@@ -9,9 +9,7 @@ from swarms.utils import find_image_path
 def test_find_image_path_no_images():
     assert (
-        find_image_path(
-            "This is a test string without any image paths."
-        )
+        find_image_path("This is a test string without any image paths.")
         is None
     )
diff --git a/tests/utils/test_limit_tokens_from_string.py b/tests/utils/test_limit_tokens_from_string.py
index 4d68dccb..7cca8f0b 100644
--- a/tests/utils/test_limit_tokens_from_string.py
+++ b/tests/utils/test_limit_tokens_from_string.py
@@ -21,9 +21,7 @@ def test_limit_zero_tokens():
 
 def test_negative_token_limit():
-    sentence = (
-        "This test will raise an exception when limit is negative."
-    )
+    sentence = "This test will raise an exception when limit is negative."
     with pytest.raises(Exception):
         limit_tokens_from_string(sentence, limit=-1)
diff --git a/tests/utils/test_metrics_decorator.py b/tests/utils/test_metrics_decorator.py
index 719d50a7..8c3a8af9 100644
--- a/tests/utils/test_metrics_decorator.py
+++ b/tests/utils/test_metrics_decorator.py
@@ -55,11 +55,14 @@ def test_metrics_decorator_with_mocked_time(mocker):
         return ["tok_1", "tok_2"]
 
     metrics = decorated_func()
-    assert metrics == """
+    assert (
+        metrics
+        == """
 Time to First Token: 5
 Generation Latency: 20
 Throughput: 0.1
 """
+    )
     mocked_time.assert_any_call()
diff --git a/tests/utils/test_prep_torch_inference.py b/tests/utils/test_prep_torch_inference.py
index 6af4a9a7..da3a511a 100644
--- a/tests/utils/test_prep_torch_inference.py
+++ b/tests/utils/test_prep_torch_inference.py
@@ -9,9 +9,7 @@ from swarms.utils import prep_torch_inference
 def test_prep_torch_inference():
     model_path = "model_path"
-    device = torch.device(
-        "cuda" if torch.cuda.is_available() else "cpu"
-    )
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model_mock = Mock()
     model_mock.eval = Mock()
diff --git a/tests/utils/test_print_class_parameters.py b/tests/utils/test_print_class_parameters.py
index 9a133ae4..5ff8eb0b 100644
--- a/tests/utils/test_print_class_parameters.py
+++ b/tests/utils/test_print_class_parameters.py
@@ -19,9 +19,7 @@ def test_class_with_complex_parameters():
         pass
 
     output = {"value1": "", "value2": ""}
-    assert (
-        print_class_parameters(ComplexArgs, api_format=True) == output
-    )
+    assert print_class_parameters(ComplexArgs, api_format=True) == output
 
 
 def test_empty_class():
@@ -41,10 +39,7 @@ def test_class_with_no_annotations():
         "value1": "",
         "value2": "",
     }
-    assert (
-        print_class_parameters(NoAnnotations, api_format=True)
-        == output
-    )
+    assert print_class_parameters(NoAnnotations, api_format=True) == output
 
 
 def test_class_with_partial_annotations():
diff --git a/tests/utils/test_subprocess_code_interpreter.py b/tests/utils/test_subprocess_code_interpreter.py
index 3bb800f5..eb8da09f 100644
--- a/tests/utils/test_subprocess_code_interpreter.py
+++ b/tests/utils/test_subprocess_code_interpreter.py
@@ -4,7 +4,7 @@ import threading
 
 import pytest
 
-from swarms.utils.code_interpreter import (  # Adjust the import according to your project structure
+from swarms.tools.code_interpreter import (  # Adjust the import according to your project structure
     SubprocessCodeInterpreter,
 )
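
Beyond line-width reformatting, this patch also completes a module move: json_utils and code_interpreter now import from swarms.tools rather than swarms.utils. Downstream code that must run against releases on either side of the move can fall back at import time; this shim is a suggestion, not part of the diff:

    try:
        # New location introduced by this patch.
        from swarms.tools.code_interpreter import SubprocessCodeInterpreter
    except ImportError:
        # Older releases kept the module under swarms.utils.
        from swarms.utils.code_interpreter import SubprocessCodeInterpreter
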