diff --git a/.github/workflows/generator-generic-ossf-slsa3-publish.yml b/.github/workflows/generator-generic-ossf-slsa3-publish.yml new file mode 100644 index 00000000..a36e782c --- /dev/null +++ b/.github/workflows/generator-generic-ossf-slsa3-publish.yml @@ -0,0 +1,66 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# This workflow lets you generate SLSA provenance file for your project. +# The generation satisfies level 3 for the provenance requirements - see https://slsa.dev/spec/v0.1/requirements +# The project is an initiative of the OpenSSF (openssf.org) and is developed at +# https://github.com/slsa-framework/slsa-github-generator. +# The provenance file can be verified using https://github.com/slsa-framework/slsa-verifier. +# For more information about SLSA and how it improves the supply-chain, visit slsa.dev. + +name: SLSA generic generator +on: + workflow_dispatch: + release: + types: [created] + +jobs: + build: + runs-on: ubuntu-latest + outputs: + digests: ${{ steps.hash.outputs.digests }} + + steps: + - uses: actions/checkout@v3 + + # ======================================================== + # + # Step 1: Build your artifacts. + # + # ======================================================== + - name: Build artifacts + run: | + # These are some amazing artifacts. + echo "artifact1" > artifact1 + echo "artifact2" > artifact2 + + # ======================================================== + # + # Step 2: Add a step to generate the provenance subjects + # as shown below. Update the sha256 sum arguments + # to include all binaries that you generate + # provenance for. + # + # ======================================================== + - name: Generate subject for provenance + id: hash + run: | + set -euo pipefail + + # List the artifacts the provenance will refer to. + files=$(ls artifact*) + # Generate the subjects (base64 encoded). + echo "hashes=$(sha256sum $files | base64 -w0)" >> "${GITHUB_OUTPUT}" + + provenance: + needs: [build] + permissions: + actions: read # To read the workflow path. + id-token: write # To sign the provenance. + contents: write # To add assets to a release. + uses: slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v1.4.0 + with: + base64-subjects: "${{ needs.build.outputs.digests }}" + upload-assets: true # Optional: Upload to a new release diff --git a/.github/workflows/makefile.yml b/.github/workflows/makefile.yml new file mode 100644 index 00000000..ab01451f --- /dev/null +++ b/.github/workflows/makefile.yml @@ -0,0 +1,27 @@ +name: Makefile CI + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: configure + run: ./configure + + - name: Install dependencies + run: make + + - name: Run check + run: make check + + - name: Run distcheck + run: make distcheck diff --git a/.github/workflows/pyre.yml b/.github/workflows/pyre.yml new file mode 100644 index 00000000..5ff88856 --- /dev/null +++ b/.github/workflows/pyre.yml @@ -0,0 +1,46 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# This workflow integrates Pyre with GitHub's +# Code Scanning feature. +# +# Pyre is a performant type checker for Python compliant with +# PEP 484. Pyre can analyze codebases with millions of lines +# of code incrementally – providing instantaneous feedback +# to developers as they write code. +# +# See https://pyre-check.org + +name: Pyre + +on: + workflow_dispatch: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +permissions: + contents: read + +jobs: + pyre: + permissions: + actions: read + contents: read + security-events: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: true + + - name: Run Pyre + uses: facebook/pyre-action@60697a7858f7cc8470d8cc494a3cf2ad6b06560d + with: + # To customize these inputs: + # See https://github.com/facebook/pyre-action#inputs + repo-directory: './' + requirements-path: 'requirements.txt' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ae0a4fc0..0c936705 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: rev: 'v0.0.255' hooks: - id: ruff - args: [--fix] + args: [----unsafe-fixes] - repo: https://github.com/nbQA-dev/nbQA rev: 1.6.3 hooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 04f0f593..8230322d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,6 +20,12 @@ Swarms is designed to provide modular building blocks to build scalable swarms o Before you contribute a new feature, consider submitting an Issue to discuss the feature so the community can weigh in and assist. +### Requirements: +- New class and or function Module with documentation in docstrings with error handling +- Tests using pytest in tests folder in the same module folder +- Documentation in the docs/swarms/module_name folder and then added into the mkdocs.yml + + ## How to Contribute Changes First, fork this repository to your own GitHub account. Click "fork" in the top corner of the `swarms` repository to get started: @@ -43,11 +49,30 @@ git push -u origin main ``` ## 🎨 Code quality +- Follow the following guide on code quality a python guide or your PR will most likely be overlooked: [CLICK HERE](https://google.github.io/styleguide/pyguide.html) + + ### Pre-commit tool This project utilizes the [pre-commit](https://pre-commit.com/) tool to maintain code quality and consistency. Before submitting a pull request or making any commits, it is important to run the pre-commit tool to ensure that your changes meet the project's guidelines. + +- Install pre-commit (https://pre-commit.com/) + +```bash +pip install pre-commit +``` + +- Check that it's installed + +```bash +pre-commit --version +``` + +Now when you make a git commit, the black code formatter and ruff linter will run. + + Furthermore, we have integrated a pre-commit GitHub Action into our workflow. This means that with every pull request opened, the pre-commit checks will be automatically enforced, streamlining the code review process and ensuring that all contributions adhere to our quality standards. To run the pre-commit tool, follow these steps: @@ -60,6 +85,7 @@ To run the pre-commit tool, follow these steps: 4. You can also install pre-commit as a git hook by execute `pre-commit install`. Every time you made `git commit` pre-commit run automatically for you. + ### Docstrings All new functions and classes in `swarms` should include docstrings. This is a prerequisite for any new functions and classes to be added to the library. diff --git a/Developers.md b/Developers.md deleted file mode 100644 index ca7fda93..00000000 --- a/Developers.md +++ /dev/null @@ -1,21 +0,0 @@ -Developers - -Install pre-commit (https://pre-commit.com/) - -```bash -pip install pre-commit -``` - -Check that it's installed - -```bash -pre-commit --versioni -``` - -This repository already has a pre-commit configuration. To install the hooks, run: - -```bash -pre-commit install -``` - -Now when you make a git commit, the black code formatter and ruff linter will run. diff --git a/README.md b/README.md index 5ef0678b..2b104caf 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,10 @@ Swarms is a modular framework that enables reliable and useful multi-agent colla --- ## Usage -### Example in Colab: - +Run example in Collab: Open In Colab - Run example in Colab, using your OpenAI API key. + ### `Flow` Example - Reliable Structure that provides LLMS autonomy @@ -42,10 +41,8 @@ api_key = "" # Initialize the language model, this model can be swapped out with Anthropic, ETC, Huggingface Models like Mistral, ETC llm = OpenAIChat( - # model_name="gpt-4" openai_api_key=api_key, temperature=0.5, - # max_tokens=100, ) ## Initialize the workflow @@ -53,24 +50,10 @@ flow = Flow( llm=llm, max_loops=2, dashboard=True, - # stopping_condition=None, # You can define a stopping condition as needed. - # loop_interval=1, - # retry_attempts=3, - # retry_interval=1, - # interactive=False, # Set to 'True' for interactive mode. - # dynamic_temperature=False, # Set to 'True' for dynamic temperature handling. + ) -# out = flow.load_state("flow_state.json") -# temp = flow.dynamic_temperature() -# filter = flow.add_response_filter("Trump") out = flow.run("Generate a 10,000 word blog on health and wellness.") -# out = flow.validate_response(out) -# out = flow.analyze_feedback(out) -# out = flow.print_history_and_memory() -# # out = flow.save_state("flow_state.json") -# print(out) - ``` diff --git a/example.py b/example.py index 46e8b33c..ab496b77 100644 --- a/example.py +++ b/example.py @@ -1,37 +1,14 @@ from swarms.models import OpenAIChat from swarms.structs import Flow -# Initialize the language model, this model can be swapped out with Anthropic, ETC, Huggingface Models like Mistral, ETC +# Initialize the language model llm = OpenAIChat( - # model_name="gpt-4" - # openai_api_key=api_key, temperature=0.5, - # max_tokens=100, ) ## Initialize the workflow -flow = Flow( - llm=llm, - max_loops=2, - dashboard=True, - # tools=[search_api] - # stopping_condition=None, # You can define a stopping condition as needed. - # loop_interval=1, - # retry_attempts=3, - # retry_interval=1, - # interactive=False, # Set to 'True' for interactive mode. - # dynamic_temperature=False, # Set to 'True' for dynamic temperature handling. -) +flow = Flow(llm=llm, max_loops=1, dashboard=True) -# out = flow.load_state("flow_state.json") -# temp = flow.dynamic_temperature() -# filter = flow.add_response_filter("Trump") -out = flow.run( - "Generate a 10,000 word blog on mental clarity and the benefits of meditation." -) -# out = flow.validate_response(out) -# out = flow.analyze_feedback(out) -# out = flow.print_history_and_memory() -# # out = flow.save_state("flow_state.json") -# print(out) +# Run the workflow on a task +out = flow.run("Generate a 10,000 word blog on health and wellness.") diff --git a/multi_agent_debate.py b/multi_agent_debate.py new file mode 100644 index 00000000..2bc67c8c --- /dev/null +++ b/multi_agent_debate.py @@ -0,0 +1,31 @@ +import os + +from dotenv import load_dotenv + +from swarms.models import OpenAIChat +from swarms.structs import Flow +from swarms.swarms.multi_agent_collab import MultiAgentCollaboration + +load_dotenv() + +api_key = os.environ.get("OPENAI_API_KEY") + +# Initialize the language model +llm = OpenAIChat( + temperature=0.5, + openai_api_key=api_key, +) + + +## Initialize the workflow +flow = Flow(llm=llm, max_loops=1, dashboard=True) +flow2 = Flow(llm=llm, max_loops=1, dashboard=True) +flow3 = Flow(llm=llm, max_loops=1, dashboard=True) + + +swarm = MultiAgentCollaboration( + agents=[flow, flow2, flow3], + max_iters=4, +) + +swarm.run("Generate a 10,000 word blog on health and wellness.") diff --git a/playground/agents/mm_agent_example.py b/playground/agents/mm_agent_example.py index 0da0d469..5326af6e 100644 --- a/playground/agents/mm_agent_example.py +++ b/playground/agents/mm_agent_example.py @@ -9,6 +9,9 @@ text = node.run_text("What is your name? Generate a picture of yourself") img = node.run_img("/image1", "What is this image about?") chat = node.chat( - "What is your name? Generate a picture of yourself. What is this image about?", + ( + "What is your name? Generate a picture of yourself. What is this image" + " about?" + ), streaming=True, ) diff --git a/playground/agents/revgpt_agent.py b/playground/agents/revgpt_agent.py index 42d95359..16a720e8 100644 --- a/playground/agents/revgpt_agent.py +++ b/playground/agents/revgpt_agent.py @@ -10,13 +10,19 @@ config = { "plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")], "disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True", "PUID": os.getenv("REVGPT_PUID"), - "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")], + "unverified_plugin_domains": [ + os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS") + ], } llm = RevChatGPTModel(access_token=os.getenv("ACCESS_TOKEN"), **config) worker = Worker(ai_name="Optimus Prime", llm=llm) -task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times." +task = ( + "What were the winning boston marathon times for the past 5 years (ending" + " in 2022)? Generate a table of the year, name, country of origin, and" + " times." +) response = worker.run(task) print(response) diff --git a/playground/demos/accountant_team/accountant_team.py b/playground/demos/accountant_team/accountant_team.py index 61cc2f7a..d9edc2f6 100644 --- a/playground/demos/accountant_team/accountant_team.py +++ b/playground/demos/accountant_team/accountant_team.py @@ -103,7 +103,8 @@ class AccountantSwarms: # Provide decision making support to the accountant decision_making_support_agent_output = decision_making_support_agent.run( - f"{self.decision_making_support_agent_instructions}: {summary_agent_output}" + f"{self.decision_making_support_agent_instructions}:" + f" {summary_agent_output}" ) return decision_making_support_agent_output @@ -113,5 +114,7 @@ swarm = AccountantSwarms( pdf_path="tesla.pdf", fraud_detection_instructions="Detect fraud in the document", summary_agent_instructions="Generate an actionable summary of the document", - decision_making_support_agent_instructions="Provide decision making support to the business owner:", + decision_making_support_agent_instructions=( + "Provide decision making support to the business owner:" + ), ) diff --git a/playground/demos/ai_research_team/main.py b/playground/demos/ai_research_team/main.py index a297bc0a..77d8dbdc 100644 --- a/playground/demos/ai_research_team/main.py +++ b/playground/demos/ai_research_team/main.py @@ -48,6 +48,7 @@ paper_implementor_agent = Flow( paper = pdf_to_text(PDF_PATH) algorithmic_psuedocode_agent = paper_summarizer_agent.run( - f"Focus on creating the algorithmic pseudocode for the novel method in this paper: {paper}" + "Focus on creating the algorithmic pseudocode for the novel method in this" + f" paper: {paper}" ) pytorch_code = paper_implementor_agent.run(algorithmic_psuedocode_agent) diff --git a/playground/demos/autotemp/autotemp.py b/playground/demos/autotemp/autotemp.py new file mode 100644 index 00000000..b136bad7 --- /dev/null +++ b/playground/demos/autotemp/autotemp.py @@ -0,0 +1,86 @@ +import re +from swarms.models.openai_models import OpenAIChat + + +class AutoTemp: + """ + AutoTemp is a tool for automatically selecting the best temperature setting for a given task. + It generates responses at different temperatures, evaluates them, and ranks them based on quality. + """ + + def __init__( + self, + api_key, + default_temp=0.0, + alt_temps=None, + auto_select=True, + max_workers=6, + ): + self.api_key = api_key + self.default_temp = default_temp + self.alt_temps = ( + alt_temps if alt_temps else [0.4, 0.6, 0.8, 1.0, 1.2, 1.4] + ) + self.auto_select = auto_select + self.max_workers = max_workers + self.llm = OpenAIChat( + openai_api_key=self.api_key, temperature=self.default_temp + ) + + def evaluate_output(self, output, temperature): + print(f"Evaluating output at temperature {temperature}...") + eval_prompt = f""" + Evaluate the following output which was generated at a temperature setting of {temperature}. Provide a precise score from 0.0 to 100.0, considering the following criteria: + + - Relevance: How well does the output address the prompt or task at hand? + - Clarity: Is the output easy to understand and free of ambiguity? + - Utility: How useful is the output for its intended purpose? + - Pride: If the user had to submit this output to the world for their career, would they be proud? + - Delight: Is the output likely to delight or positively surprise the user? + + Be sure to comprehensively evaluate the output, it is very important for my career. Please answer with just the score with one decimal place accuracy, such as 42.0 or 96.9. Be extremely critical. + + Output to evaluate: + --- + {output} + --- + """ + score_text = self.llm(eval_prompt, temperature=0.5) + score_match = re.search(r"\b\d+(\.\d)?\b", score_text) + return round(float(score_match.group()), 1) if score_match else 0.0 + + def run(self, prompt, temperature_string): + print("Starting generation process...") + temperature_list = [ + float(temp.strip()) + for temp in temperature_string.split(",") + if temp.strip() + ] + outputs = {} + scores = {} + for temp in temperature_list: + print(f"Generating at temperature {temp}...") + output_text = self.llm(prompt, temperature=temp) + if output_text: + outputs[temp] = output_text + scores[temp] = self.evaluate_output(output_text, temp) + + print("Generation process complete.") + if not scores: + return "No valid outputs generated.", None + + sorted_scores = sorted( + scores.items(), key=lambda item: item[1], reverse=True + ) + best_temp, best_score = sorted_scores[0] + best_output = outputs[best_temp] + + return ( + f"Best AutoTemp Output (Temp {best_temp} | Score:" + f" {best_score}):\n{best_output}" + if self.auto_select + else "\n".join( + f"Temp {temp} | Score: {score}:\n{outputs[temp]}" + for temp, score in sorted_scores + ) + ) diff --git a/playground/demos/autotemp/autotemp_example.py b/playground/demos/autotemp/autotemp_example.py new file mode 100644 index 00000000..9047268d --- /dev/null +++ b/playground/demos/autotemp/autotemp_example.py @@ -0,0 +1,22 @@ +from swarms.models import OpenAIChat +from swarms.models.autotemp import AutoTemp + +# Your OpenAI API key +api_key = "" + +autotemp_agent = AutoTemp( + api_key=api_key, + alt_temps=[0.4, 0.6, 0.8, 1.0, 1.2], + auto_select=False, + # model_version="gpt-3.5-turbo" # Specify the model version if needed +) + +# Define the task and temperature string +task = "Generate a short story about a lost civilization." +temperature_string = "0.4,0.6,0.8,1.0,1.2," + +# Run the AutoTempAgent +result = autotemp_agent.run(task, temperature_string) + +# Print the result +print(result) diff --git a/playground/demos/blog_gen/blog_gen.py b/playground/demos/blog_gen/blog_gen.py new file mode 100644 index 00000000..84ab240d --- /dev/null +++ b/playground/demos/blog_gen/blog_gen.py @@ -0,0 +1,128 @@ +import os +from termcolor import colored +from swarms.models import OpenAIChat +from swarms.models.autotemp import AutoTemp +from swarms.structs import SequentialWorkflow + + +class BlogGen: + def __init__( + self, + api_key, + blog_topic, + temperature_range: str = "0.4,0.6,0.8,1.0,1.2", + ): # Add blog_topic as an argument + self.openai_chat = OpenAIChat(openai_api_key=api_key, temperature=0.8) + self.auto_temp = AutoTemp(api_key) + self.temperature_range = temperature_range + self.workflow = SequentialWorkflow(max_loops=5) + + # Formatting the topic selection prompt with the user's topic + self.TOPIC_SELECTION_SYSTEM_PROMPT = f""" + Given the topic '{blog_topic}', generate an engaging and versatile blog topic. This topic should cover areas related to '{blog_topic}' and might include aspects such as current events, lifestyle, technology, health, and culture related to '{blog_topic}'. Identify trending subjects within this realm. The topic must be unique, thought-provoking, and have the potential to draw in readers interested in '{blog_topic}'. + """ + + self.DRAFT_WRITER_SYSTEM_PROMPT = """ + Create an engaging and comprehensive blog article of at least 1,000 words on '{{CHOSEN_TOPIC}}'. The content should be original, informative, and reflective of a human-like style, with a clear structure including headings and sub-headings. Incorporate a blend of narrative, factual data, expert insights, and anecdotes to enrich the article. Focus on SEO optimization by using relevant keywords, ensuring readability, and including meta descriptions and title tags. The article should provide value, appeal to both knowledgeable and general readers, and maintain a balance between depth and accessibility. Aim to make the article engaging and suitable for online audiences. + """ + + self.REVIEW_AGENT_SYSTEM_PROMPT = """ + Critically review the drafted blog article on '{{ARTICLE_TOPIC}}' to refine it to high-quality content suitable for online publication. Ensure the article is coherent, factually accurate, engaging, and optimized for search engines (SEO). Check for the effective use of keywords, readability, internal and external links, and the inclusion of meta descriptions and title tags. Edit the content to enhance clarity, impact, and maintain the authors voice. The goal is to polish the article into a professional, error-free piece that resonates with the target audience, adheres to publication standards, and is optimized for both search engines and social media sharing. + """ + + self.DISTRIBUTION_AGENT_SYSTEM_PROMPT = """ + Develop an autonomous distribution strategy for the blog article on '{{ARTICLE_TOPIC}}'. Utilize an API to post the article on a popular blog platform (e.g., WordPress, Blogger, Medium) commonly used by our target audience. Ensure the post includes all SEO elements like meta descriptions, title tags, and properly formatted content. Craft unique, engaging social media posts tailored to different platforms to promote the blog article. Schedule these posts to optimize reach and engagement, using data-driven insights. Monitor the performance of the distribution efforts, adjusting strategies based on engagement metrics and audience feedback. Aim to maximize the article's visibility, attract a diverse audience, and foster engagement across digital channels. + """ + + def run_workflow(self): + try: + # Topic generation using OpenAIChat + topic_result = self.openai_chat.generate( + [self.TOPIC_SELECTION_SYSTEM_PROMPT] + ) + topic_output = topic_result.generations[0][0].text + print( + colored( + ( + "\nTopic Selection Task" + f" Output:\n----------------------------\n{topic_output}\n" + ), + "white", + ) + ) + + chosen_topic = topic_output.split("\n")[0] + print(colored("Selected topic: " + chosen_topic, "yellow")) + + # Initial draft generation with AutoTemp + initial_draft_prompt = self.DRAFT_WRITER_SYSTEM_PROMPT.replace( + "{{CHOSEN_TOPIC}}", chosen_topic + ) + auto_temp_output = self.auto_temp.run( + initial_draft_prompt, self.temperature_range + ) + initial_draft_output = auto_temp_output # Assuming AutoTemp.run returns the best output directly + print( + colored( + ( + "\nInitial Draft" + f" Output:\n----------------------------\n{initial_draft_output}\n" + ), + "white", + ) + ) + + # Review process using OpenAIChat + review_prompt = self.REVIEW_AGENT_SYSTEM_PROMPT.replace( + "{{ARTICLE_TOPIC}}", chosen_topic + ) + review_result = self.openai_chat.generate([review_prompt]) + review_output = review_result.generations[0][0].text + print( + colored( + ( + "\nReview" + f" Output:\n----------------------------\n{review_output}\n" + ), + "white", + ) + ) + + # Distribution preparation using OpenAIChat + distribution_prompt = self.DISTRIBUTION_AGENT_SYSTEM_PROMPT.replace( + "{{ARTICLE_TOPIC}}", chosen_topic + ) + distribution_result = self.openai_chat.generate( + [distribution_prompt] + ) + distribution_output = distribution_result.generations[0][0].text + print( + colored( + ( + "\nDistribution" + f" Output:\n----------------------------\n{distribution_output}\n" + ), + "white", + ) + ) + + # Final compilation of the blog + final_blog_content = f"{initial_draft_output}\n\n{review_output}\n\n{distribution_output}" + print( + colored( + ( + "\nFinal Blog" + f" Content:\n----------------------------\n{final_blog_content}\n" + ), + "green", + ) + ) + + except Exception as e: + print(colored(f"An error occurred: {str(e)}", "red")) + + +if __name__ == "__main__": + api_key = os.environ["OPENAI_API_KEY"] + blog_generator = BlogGen(api_key) + blog_generator.run_workflow() diff --git a/playground/demos/blog_gen/blog_gen_example.py b/playground/demos/blog_gen/blog_gen_example.py new file mode 100644 index 00000000..7cf95535 --- /dev/null +++ b/playground/demos/blog_gen/blog_gen_example.py @@ -0,0 +1,23 @@ +import os +from swarms.swarms.blog_gen import BlogGen + + +def main(): + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY environment variable not set.") + + blog_topic = input("Enter the topic for the blog generation: ") + + blog_generator = BlogGen(api_key, blog_topic) + blog_generator.TOPIC_SELECTION_SYSTEM_PROMPT = ( + blog_generator.TOPIC_SELECTION_SYSTEM_PROMPT.replace( + "{{BLOG_TOPIC}}", blog_topic + ) + ) + + blog_generator.run_workflow() + + +if __name__ == "__main__": + main() diff --git a/playground/demos/ui_software_demo.py b/playground/demos/design_team/ui_software_demo.py similarity index 100% rename from playground/demos/ui_software_demo.py rename to playground/demos/design_team/ui_software_demo.py diff --git a/playground/demos/multi_modal_auto_agent.py b/playground/demos/multi_modal_autonomous_agents/multi_modal_auto_agent.py similarity index 87% rename from playground/demos/multi_modal_auto_agent.py rename to playground/demos/multi_modal_autonomous_agents/multi_modal_auto_agent.py index b462795f..a2602706 100644 --- a/playground/demos/multi_modal_auto_agent.py +++ b/playground/demos/multi_modal_autonomous_agents/multi_modal_auto_agent.py @@ -4,7 +4,10 @@ from swarms.models import Idefics # Multi Modality Auto Agent llm = Idefics(max_length=2000) -task = "User: What is in this image? https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" +task = ( + "User: What is in this image?" + " https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" +) ## Initialize the workflow flow = Flow( diff --git a/playground/demos/nutrition/full_fridge.jpg b/playground/demos/nutrition/full_fridge.jpg new file mode 100644 index 00000000..c1b249c5 Binary files /dev/null and b/playground/demos/nutrition/full_fridge.jpg differ diff --git a/playground/demos/nutrition/nutrition.py b/playground/demos/nutrition/nutrition.py new file mode 100644 index 00000000..ffdafd7c --- /dev/null +++ b/playground/demos/nutrition/nutrition.py @@ -0,0 +1,129 @@ +import os +import base64 +import requests +from dotenv import load_dotenv +from swarms.models import Anthropic, OpenAIChat +from swarms.structs import Flow + +# Load environment variables +load_dotenv() +openai_api_key = os.getenv("OPENAI_API_KEY") + +# Define prompts for various tasks +MEAL_PLAN_PROMPT = ( + "Based on the following user preferences: dietary restrictions as" + " vegetarian, preferred cuisines as Italian and Indian, a total caloric" + " intake of around 2000 calories per day, and an exclusion of legumes," + " create a detailed weekly meal plan. Include a variety of meals for" + " breakfast, lunch, dinner, and optional snacks." +) +IMAGE_ANALYSIS_PROMPT = ( + "Identify the items in this fridge, including their quantities and" + " condition." +) + + +# Function to encode image to base64 +def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +# Initialize Language Model (LLM) +llm = OpenAIChat( + openai_api_key=openai_api_key, + max_tokens=3000, +) + + +# Function to handle vision tasks +def create_vision_agent(image_path): + base64_image = encode_image(image_path) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {openai_api_key}", + } + payload = { + "model": "gpt-4-vision-preview", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": IMAGE_ANALYSIS_PROMPT}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + }, + }, + ], + } + ], + "max_tokens": 300, + } + response = requests.post( + "https://api.openai.com/v1/chat/completions", + headers=headers, + json=payload, + ) + return response.json() + + +# Function to generate an integrated shopping list considering meal plan and fridge contents +def generate_integrated_shopping_list( + meal_plan_output, image_analysis, user_preferences +): + # Prepare the prompt for the LLM + fridge_contents = image_analysis["choices"][0]["message"]["content"] + prompt = ( + f"Based on this meal plan: {meal_plan_output}, and the following items" + f" in the fridge: {fridge_contents}, considering dietary preferences as" + " vegetarian with a preference for Italian and Indian cuisines," + " generate a comprehensive shopping list that includes only the items" + " needed." + ) + + # Send the prompt to the LLM and return the response + response = llm(prompt) + return response # assuming the response is a string + + +# Define agent for meal planning +meal_plan_agent = Flow( + llm=llm, + sop=MEAL_PLAN_PROMPT, + max_loops=1, + autosave=True, + saved_state_path="meal_plan_agent.json", +) + +# User preferences for meal planning +user_preferences = { + "dietary_restrictions": "vegetarian", + "preferred_cuisines": ["Italian", "Indian"], + "caloric_intake": 2000, + "other notes": "Doesn't eat legumes", +} + +# Generate Meal Plan +meal_plan_output = meal_plan_agent.run( + f"Generate a meal plan: {user_preferences}" +) + +# Vision Agent - Analyze an Image +image_analysis_output = create_vision_agent("full_fridge.jpg") + +# Generate Integrated Shopping List +integrated_shopping_list = generate_integrated_shopping_list( + meal_plan_output, image_analysis_output, user_preferences +) + +# Print and save the outputs +print("Meal Plan:", meal_plan_output) +print("Integrated Shopping List:", integrated_shopping_list) + +with open("nutrition_output.txt", "w") as file: + file.write("Meal Plan:\n" + meal_plan_output + "\n\n") + file.write("Integrated Shopping List:\n" + integrated_shopping_list + "\n") + +print("Outputs have been saved to nutrition_output.txt") diff --git a/playground/demos/positive_med.py b/playground/demos/positive_med/positive_med.py similarity index 91% rename from playground/demos/positive_med.py rename to playground/demos/positive_med/positive_med.py index 6f7a2d3a..ea0c7c4e 100644 --- a/playground/demos/positive_med.py +++ b/playground/demos/positive_med/positive_med.py @@ -39,9 +39,9 @@ def get_review_prompt(article): def social_media_prompt(article: str, goal: str = "Clicks and engagement"): - prompt = SOCIAL_MEDIA_SYSTEM_PROMPT_AGENT.replace("{{ARTICLE}}", article).replace( - "{{GOAL}}", goal - ) + prompt = SOCIAL_MEDIA_SYSTEM_PROMPT_AGENT.replace( + "{{ARTICLE}}", article + ).replace("{{GOAL}}", goal) return prompt @@ -50,7 +50,8 @@ topic_selection_task = ( "Generate 10 topics on gaining mental clarity using ancient practices" ) topics = llm( - f"Your System Instructions: {TOPIC_GENERATOR}, Your current task: {topic_selection_task}" + f"Your System Instructions: {TOPIC_GENERATOR}, Your current task:" + f" {topic_selection_task}" ) dashboard = print( @@ -109,7 +110,9 @@ reviewed_draft = print( # Agent that publishes on social media -distribution_agent = llm(social_media_prompt(draft_blog, goal="Clicks and engagement")) +distribution_agent = llm( + social_media_prompt(draft_blog, goal="Clicks and engagement") +) distribution_agent_out = print( colored( f""" diff --git a/playground/models/bioclip.py b/playground/models/bioclip.py index dcdd309b..11fb9f27 100644 --- a/playground/models/bioclip.py +++ b/playground/models/bioclip.py @@ -1,6 +1,8 @@ from swarms.models.bioclip import BioClip -clip = BioClip("hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224") +clip = BioClip( + "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" +) labels = [ "adenocarcinoma histopathology", diff --git a/playground/models/idefics.py b/playground/models/idefics.py index 032e0f3b..39d6f4eb 100644 --- a/playground/models/idefics.py +++ b/playground/models/idefics.py @@ -2,11 +2,17 @@ from swarms.models import idefics model = idefics() -user_input = "User: What is in this image? https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" +user_input = ( + "User: What is in this image?" + " https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG" +) response = model.chat(user_input) print(response) -user_input = "User: And who is that? https://static.wikia.nocookie.net/asterix/images/2/25/R22b.gif/revision/latest?cb=20110815073052" +user_input = ( + "User: And who is that?" + " https://static.wikia.nocookie.net/asterix/images/2/25/R22b.gif/revision/latest?cb=20110815073052" +) response = model.chat(user_input) print(response) diff --git a/playground/models/llama_function_caller.py b/playground/models/llama_function_caller.py index 43bca3a5..201009a8 100644 --- a/playground/models/llama_function_caller.py +++ b/playground/models/llama_function_caller.py @@ -28,7 +28,9 @@ llama_caller.add_func( ) # Call the function -result = llama_caller.call_function("get_weather", location="Paris", format="Celsius") +result = llama_caller.call_function( + "get_weather", location="Paris", format="Celsius" +) print(result) # Stream a user prompt diff --git a/playground/models/vilt.py b/playground/models/vilt.py index 127514e0..8e40f59d 100644 --- a/playground/models/vilt.py +++ b/playground/models/vilt.py @@ -3,5 +3,6 @@ from swarms.models.vilt import Vilt model = Vilt() output = model( - "What is this image", "http://images.cocodataset.org/val2017/000000039769.jpg" + "What is this image", + "http://images.cocodataset.org/val2017/000000039769.jpg", ) diff --git a/playground/structs/flow_tools.py b/playground/structs/flow_tools.py index 647f6617..42ec0f72 100644 --- a/playground/structs/flow_tools.py +++ b/playground/structs/flow_tools.py @@ -30,7 +30,9 @@ async def async_load_playwright(url: str) -> str: text = soup.get_text() lines = (line.strip() for line in text.splitlines()) - chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) + chunks = ( + phrase.strip() for line in lines for phrase in line.split(" ") + ) results = "\n".join(chunk for chunk in chunks if chunk) except Exception as e: results = f"Error: {e}" @@ -58,5 +60,6 @@ flow = Flow( ) out = flow.run( - "Generate a 10,000 word blog on mental clarity and the benefits of meditation." + "Generate a 10,000 word blog on mental clarity and the benefits of" + " meditation." ) diff --git a/playground/swarms/debate.py b/playground/swarms/debate.py index 2c47ed8e..4c97817d 100644 --- a/playground/swarms/debate.py +++ b/playground/swarms/debate.py @@ -36,7 +36,9 @@ class DialogueAgent: message = self.model( [ self.system_message, - HumanMessage(content="\n".join(self.message_history + [self.prefix])), + HumanMessage( + content="\n".join(self.message_history + [self.prefix]) + ), ] ) return message.content @@ -124,19 +126,19 @@ game_description = f"""Here is the topic for the presidential debate: {topic}. The presidential candidates are: {', '.join(character_names)}.""" player_descriptor_system_message = SystemMessage( - content="You can add detail to the description of each presidential candidate." + content=( + "You can add detail to the description of each presidential candidate." + ) ) def generate_character_description(character_name): character_specifier_prompt = [ player_descriptor_system_message, - HumanMessage( - content=f"""{game_description} + HumanMessage(content=f"""{game_description} Please reply with a creative description of the presidential candidate, {character_name}, in {word_limit} words or less, that emphasizes their personalities. Speak directly to {character_name}. - Do not add anything else.""" - ), + Do not add anything else."""), ] character_description = ChatOpenAI(temperature=1.0)( character_specifier_prompt @@ -155,9 +157,7 @@ Your goal is to be as creative as possible and make the voters think you are the def generate_character_system_message(character_name, character_header): - return SystemMessage( - content=( - f"""{character_header} + return SystemMessage(content=f"""{character_header} You will speak in the style of {character_name}, and exaggerate their personality. You will come up with creative ideas related to {topic}. Do not say the same things over and over again. @@ -169,13 +169,12 @@ Speak only from the perspective of {character_name}. Stop speaking the moment you finish speaking from your perspective. Never forget to keep your response to {word_limit} words! Do not add anything else. - """ - ) - ) + """) character_descriptions = [ - generate_character_description(character_name) for character_name in character_names + generate_character_description(character_name) + for character_name in character_names ] character_headers = [ generate_character_header(character_name, character_description) @@ -185,7 +184,9 @@ character_headers = [ ] character_system_messages = [ generate_character_system_message(character_name, character_headers) - for character_name, character_headers in zip(character_names, character_headers) + for character_name, character_headers in zip( + character_names, character_headers + ) ] for ( @@ -207,7 +208,10 @@ for ( class BidOutputParser(RegexParser): def get_format_instructions(self) -> str: - return "Your response should be an integer delimited by angled brackets, like this: ." + return ( + "Your response should be an integer delimited by angled brackets," + " like this: ." + ) bid_parser = BidOutputParser( @@ -248,8 +252,7 @@ for character_name, bidding_template in zip( topic_specifier_prompt = [ SystemMessage(content="You can make a task more specific."), - HumanMessage( - content=f"""{game_description} + HumanMessage(content=f"""{game_description} You are the debate moderator. Please make the debate topic more specific. @@ -257,8 +260,7 @@ topic_specifier_prompt = [ Be creative and imaginative. Please reply with the specified topic in {word_limit} words or less. Speak directly to the presidential candidates: {*character_names,}. - Do not add anything else.""" - ), + Do not add anything else."""), ] specified_topic = ChatOpenAI(temperature=1.0)(topic_specifier_prompt).content @@ -321,7 +323,9 @@ for character_name, character_system_message, bidding_template in zip( max_iters = 10 n = 0 -simulator = DialogueSimulator(agents=characters, selection_function=select_next_speaker) +simulator = DialogueSimulator( + agents=characters, selection_function=select_next_speaker +) simulator.reset() simulator.inject("Debate Moderator", specified_topic) print(f"(Debate Moderator): {specified_topic}") diff --git a/playground/swarms/multi_agent_debate.py b/playground/swarms/multi_agent_debate.py index f0bec797..d5382e56 100644 --- a/playground/swarms/multi_agent_debate.py +++ b/playground/swarms/multi_agent_debate.py @@ -36,7 +36,11 @@ agents = [worker1, worker2, worker3] debate = MultiAgentDebate(agents, select_speaker) # Run task -task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times." +task = ( + "What were the winning boston marathon times for the past 5 years (ending" + " in 2022)? Generate a table of the year, name, country of origin, and" + " times." +) results = debate.run(task, max_iters=4) # Print results diff --git a/playground/swarms/orchestrate.py b/playground/swarms/orchestrate.py index e43b75e3..a90a72e8 100644 --- a/playground/swarms/orchestrate.py +++ b/playground/swarms/orchestrate.py @@ -10,4 +10,6 @@ node = Worker( orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[]) # Agent 7 sends a message to Agent 9 -orchestrator.chat(sender_id=7, receiver_id=9, message="Can you help me with this task?") +orchestrator.chat( + sender_id=7, receiver_id=9, message="Can you help me with this task?" +) diff --git a/playground/swarms/orchestrator.py b/playground/swarms/orchestrator.py index e43b75e3..a90a72e8 100644 --- a/playground/swarms/orchestrator.py +++ b/playground/swarms/orchestrator.py @@ -10,4 +10,6 @@ node = Worker( orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[]) # Agent 7 sends a message to Agent 9 -orchestrator.chat(sender_id=7, receiver_id=9, message="Can you help me with this task?") +orchestrator.chat( + sender_id=7, receiver_id=9, message="Can you help me with this task?" +) diff --git a/playground/swarms/swarms_example.py b/playground/swarms/swarms_example.py index 6dabe4a1..23b714d9 100644 --- a/playground/swarms/swarms_example.py +++ b/playground/swarms/swarms_example.py @@ -7,7 +7,10 @@ api_key = "" swarm = HierarchicalSwarm(api_key) # Define an objective -objective = "Find 20 potential customers for a HierarchicalSwarm based AI Agent automation infrastructure" +objective = ( + "Find 20 potential customers for a HierarchicalSwarm based AI Agent" + " automation infrastructure" +) # Run HierarchicalSwarm swarm.run(objective) diff --git a/pyproject.toml b/pyproject.toml index b9b0f89a..eea95362 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "2.3.9" +version = "2.4.0" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -23,16 +23,14 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.8.1" -sentence_transformers = "*" +torch = "2.1.1" transformers = "*" -qdrant_client = "*" openai = "0.28.0" langchain = "*" asyncio = "*" nest_asyncio = "*" einops = "*" google-generativeai = "*" -torch = "*" langchain-experimental = "*" playwright = "*" duckduckgo-search = "*" @@ -79,11 +77,18 @@ mypy-protobuf = "^3.0.0" [tool.autopep8] -max_line_length = 120 +max_line_length = 80 ignore = "E501,W6" # or ["E501", "W6"] in-place = true recursive = true aggressive = 3 [tool.ruff] -line-length = 120 \ No newline at end of file +line-length = 80 + +[tool.black] +line-length = 80 +target-version = ['py38'] +preview = true + + diff --git a/requirements.txt b/requirements.txt index 2330d399..067356df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,13 @@ -# faiss-gpu +torch==2.1.1 transformers -revChatGPT pandas langchain nest_asyncio -pegasusx -google-generativeai -EdgeGPT langchain-experimental playwright wget==3.2 simpleaichat httpx -torch open_clip_torch ggl beautifulsoup4 @@ -26,9 +21,7 @@ soundfile huggingface-hub google-generativeai sentencepiece -duckduckgo-search PyPDF2 -agent-protocol accelerate chromadb tiktoken @@ -56,16 +49,13 @@ openai opencv-python prettytable safetensors -streamlit test-tube timm torchmetrics -transformers webdataset marshmallow yapf autopep8 -dalle3 cohere torchvision rich @@ -74,5 +64,4 @@ rich mkdocs mkdocs-material mkdocs-glightbox - -pre-commit +pre-commit \ No newline at end of file diff --git a/swarms/agents/omni_modal_agent.py b/swarms/agents/omni_modal_agent.py index 007a2219..6a22c477 100644 --- a/swarms/agents/omni_modal_agent.py +++ b/swarms/agents/omni_modal_agent.py @@ -18,7 +18,12 @@ from swarms.agents.message import Message class Step: def __init__( - self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool + self, + task: str, + id: int, + dep: List[int], + args: Dict[str, str], + tool: BaseTool, ): self.task = task self.id = id diff --git a/swarms/memory/base.py b/swarms/memory/base.py index 7f71c4b9..3ca49617 100644 --- a/swarms/memory/base.py +++ b/swarms/memory/base.py @@ -37,7 +37,7 @@ class BaseVectorStore(ABC): self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> None: execute_futures_dict( { @@ -54,7 +54,7 @@ class BaseVectorStore(ABC): artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: if not meta: meta = {} @@ -67,7 +67,11 @@ class BaseVectorStore(ABC): vector = artifact.generate_embedding(self.embedding_driver) return self.upsert_vector( - vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs + vector, + vector_id=artifact.id, + namespace=namespace, + meta=meta, + **kwargs, ) def upsert_text( @@ -76,14 +80,14 @@ class BaseVectorStore(ABC): vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: return self.upsert_vector( self.embedding_driver.embed_string(string), vector_id=vector_id, namespace=namespace, meta=meta if meta else {}, - **kwargs + **kwargs, ) @abstractmethod @@ -93,12 +97,14 @@ class BaseVectorStore(ABC): vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: ... @abstractmethod - def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry: + def load_entry( + self, vector_id: str, namespace: Optional[str] = None + ) -> Entry: ... @abstractmethod @@ -112,6 +118,6 @@ class BaseVectorStore(ABC): count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, - **kwargs + **kwargs, ) -> list[QueryResult]: ... diff --git a/swarms/memory/chroma.py b/swarms/memory/chroma.py index 67ba4cb2..2f4e473f 100644 --- a/swarms/memory/chroma.py +++ b/swarms/memory/chroma.py @@ -111,7 +111,9 @@ class Chroma(VectorStore): chroma_db_impl="duckdb+parquet", ) else: - _client_settings = chromadb.config.Settings(is_persistent=True) + _client_settings = chromadb.config.Settings( + is_persistent=True + ) _client_settings.persist_directory = persist_directory else: _client_settings = chromadb.config.Settings() @@ -124,9 +126,11 @@ class Chroma(VectorStore): self._embedding_function = embedding_function self._collection = self._client.get_or_create_collection( name=collection_name, - embedding_function=self._embedding_function.embed_documents - if self._embedding_function is not None - else None, + embedding_function=( + self._embedding_function.embed_documents + if self._embedding_function is not None + else None + ), metadata=collection_metadata, ) self.override_relevance_score_fn = relevance_score_fn @@ -203,7 +207,9 @@ class Chroma(VectorStore): metadatas = [metadatas[idx] for idx in non_empty_ids] texts_with_metadatas = [texts[idx] for idx in non_empty_ids] embeddings_with_metadatas = ( - [embeddings[idx] for idx in non_empty_ids] if embeddings else None + [embeddings[idx] for idx in non_empty_ids] + if embeddings + else None ) ids_with_metadata = [ids[idx] for idx in non_empty_ids] try: @@ -216,7 +222,8 @@ class Chroma(VectorStore): except ValueError as e: if "Expected metadata value to be" in str(e): msg = ( - "Try filtering complex metadata from the document using " + "Try filtering complex metadata from the document" + " using " "langchain.vectorstores.utils.filter_complex_metadata." ) raise ValueError(e.args[0] + "\n\n" + msg) @@ -258,7 +265,9 @@ class Chroma(VectorStore): Returns: List[Document]: List of documents most similar to the query text. """ - docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) + docs_and_scores = self.similarity_search_with_score( + query, k, filter=filter + ) return [doc for doc, _ in docs_and_scores] def similarity_search_by_vector( @@ -428,7 +437,9 @@ class Chroma(VectorStore): candidates = _results_to_docs(results) - selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] + selected_results = [ + r for i, r in enumerate(candidates) if i in mmr_selected + ] return selected_results def max_marginal_relevance_search( @@ -460,7 +471,8 @@ class Chroma(VectorStore): """ if self._embedding_function is None: raise ValueError( - "For MMR search, you must specify an embedding function oncreation." + "For MMR search, you must specify an embedding function" + " oncreation." ) embedding = self._embedding_function.embed_query(query) @@ -543,7 +555,9 @@ class Chroma(VectorStore): """ return self.update_documents([document_id], [document]) - def update_documents(self, ids: List[str], documents: List[Document]) -> None: + def update_documents( + self, ids: List[str], documents: List[Document] + ) -> None: """Update a document in the collection. Args: @@ -554,7 +568,8 @@ class Chroma(VectorStore): metadata = [document.metadata for document in documents] if self._embedding_function is None: raise ValueError( - "For update, you must specify an embedding function on creation." + "For update, you must specify an embedding function on" + " creation." ) embeddings = self._embedding_function.embed_documents(text) @@ -645,7 +660,9 @@ class Chroma(VectorStore): ids=batch[0], ) else: - chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) + chroma_collection.add_texts( + texts=texts, metadatas=metadatas, ids=ids + ) return chroma_collection @classmethod diff --git a/swarms/memory/cosine_similarity.py b/swarms/memory/cosine_similarity.py index 99d47368..cdcd1a2b 100644 --- a/swarms/memory/cosine_similarity.py +++ b/swarms/memory/cosine_similarity.py @@ -18,8 +18,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: Y = np.array(Y) if X.shape[1] != Y.shape[1]: raise ValueError( - f"Number of columns in X and Y must be the same. X has shape {X.shape} " - f"and Y has shape {Y.shape}." + "Number of columns in X and Y must be the same. X has shape" + f" {X.shape} and Y has shape {Y.shape}." ) try: import simsimd as simd @@ -32,8 +32,9 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: return Z except ImportError: logger.info( - "Unable to import simsimd, defaulting to NumPy implementation. If you want " - "to use simsimd please install with `pip install simsimd`." + "Unable to import simsimd, defaulting to NumPy implementation. If" + " you want to use simsimd please install with `pip install" + " simsimd`." ) X_norm = np.linalg.norm(X, axis=1) Y_norm = np.linalg.norm(Y, axis=1) diff --git a/swarms/memory/db.py b/swarms/memory/db.py index 9f23b59f..4ffec16f 100644 --- a/swarms/memory/db.py +++ b/swarms/memory/db.py @@ -151,7 +151,9 @@ class InMemoryTaskDB(TaskDB): ) -> Artifact: artifact_id = str(uuid.uuid4()) artifact = Artifact( - artifact_id=artifact_id, file_name=file_name, relative_path=relative_path + artifact_id=artifact_id, + file_name=file_name, + relative_path=relative_path, ) task = await self.get_task(task_id) task.artifacts.append(artifact) diff --git a/swarms/memory/ocean.py b/swarms/memory/ocean.py index da58c81c..fb0873af 100644 --- a/swarms/memory/ocean.py +++ b/swarms/memory/ocean.py @@ -91,7 +91,9 @@ class OceanDB: try: return collection.add(documents=[document], ids=[id]) except Exception as e: - logging.error(f"Failed to append document to the collection. Error {e}") + logging.error( + f"Failed to append document to the collection. Error {e}" + ) raise def add_documents(self, collection, documents: List[str], ids: List[str]): @@ -137,7 +139,9 @@ class OceanDB: the results of the query """ try: - results = collection.query(query_texts=query_texts, n_results=n_results) + results = collection.query( + query_texts=query_texts, n_results=n_results + ) return results except Exception as e: logging.error(f"Failed to query the collection. Error {e}") diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index a421c887..ce591c6e 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -89,11 +89,15 @@ class PgVectorVectorStore(BaseVectorStore): engine: Optional[Engine] = field(default=None, kw_only=True) table_name: str = field(kw_only=True) _model: any = field( - default=Factory(lambda self: self.default_vector_model(), takes_self=True) + default=Factory( + lambda self: self.default_vector_model(), takes_self=True + ) ) @connection_string.validator - def validate_connection_string(self, _, connection_string: Optional[str]) -> None: + def validate_connection_string( + self, _, connection_string: Optional[str] + ) -> None: # If an engine is provided, the connection string is not used. if self.engine is not None: return @@ -104,7 +108,8 @@ class PgVectorVectorStore(BaseVectorStore): if not connection_string.startswith("postgresql://"): raise ValueError( - "The connection string must describe a Postgres database connection" + "The connection string must describe a Postgres database" + " connection" ) @engine.validator @@ -148,7 +153,7 @@ class PgVectorVectorStore(BaseVectorStore): vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: """Inserts or updates a vector in the collection.""" with Session(self.engine) as session: @@ -208,7 +213,7 @@ class PgVectorVectorStore(BaseVectorStore): namespace: Optional[str] = None, include_vectors: bool = False, distance_metric: str = "cosine_distance", - **kwargs + **kwargs, ) -> list[BaseVectorStore.QueryResult]: """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace. diff --git a/swarms/memory/pinecone.py b/swarms/memory/pinecone.py index 2374f12a..a7eb7442 100644 --- a/swarms/memory/pinecone.py +++ b/swarms/memory/pinecone.py @@ -108,7 +108,7 @@ class PineconeVectorStoreStore(BaseVector): vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs + **kwargs, ) -> str: """Upsert vector""" vector_id = vector_id if vector_id else str_to_hash(str(vector)) @@ -123,7 +123,9 @@ class PineconeVectorStoreStore(BaseVector): self, vector_id: str, namespace: Optional[str] = None ) -> Optional[BaseVector.Entry]: """Load entry""" - result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() + result = self.index.fetch( + ids=[vector_id], namespace=namespace + ).to_dict() vectors = list(result["vectors"].values()) if len(vectors) > 0: @@ -138,7 +140,9 @@ class PineconeVectorStoreStore(BaseVector): else: return None - def load_entries(self, namespace: Optional[str] = None) -> list[BaseVector.Entry]: + def load_entries( + self, namespace: Optional[str] = None + ) -> list[BaseVector.Entry]: """Load entries""" # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching # all values from a namespace: @@ -169,7 +173,7 @@ class PineconeVectorStoreStore(BaseVector): include_vectors: bool = False, # PineconeVectorStoreStorageDriver-specific params: include_metadata=True, - **kwargs + **kwargs, ) -> list[BaseVector.QueryResult]: """Query vectors""" vector = self.embedding_driver.embed_string(query) @@ -196,6 +200,9 @@ class PineconeVectorStoreStore(BaseVector): def create_index(self, name: str, **kwargs) -> None: """Create index""" - params = {"name": name, "dimension": self.embedding_driver.dimensions} | kwargs + params = { + "name": name, + "dimension": self.embedding_driver.dimensions, + } | kwargs pinecone.create_index(**params) diff --git a/swarms/memory/schemas.py b/swarms/memory/schemas.py index bbc71bc2..89f1453b 100644 --- a/swarms/memory/schemas.py +++ b/swarms/memory/schemas.py @@ -50,7 +50,9 @@ class StepInput(BaseModel): class StepOutput(BaseModel): __root__: Any = Field( ..., - description="Output that the task step has produced. Any value is allowed.", + description=( + "Output that the task step has produced. Any value is allowed." + ), example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', ) @@ -112,8 +114,9 @@ class Step(StepRequestBody): None, description="Output of the task step.", example=( - "I am going to use the write_to_file command and write Washington to a file" - " called output.txt List[Document]: """Filter out metadata types that are not supported for a vector store.""" updated_documents = [] diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index f509087c..10bf2fab 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -7,7 +7,11 @@ sys.stderr = log_file from swarms.models.anthropic import Anthropic # noqa: E402 from swarms.models.petals import Petals # noqa: E402 from swarms.models.mistral import Mistral # noqa: E402 -from swarms.models.openai_models import OpenAI, AzureOpenAI, OpenAIChat # noqa: E402 +from swarms.models.openai_models import ( + OpenAI, + AzureOpenAI, + OpenAIChat, +) # noqa: E402 from swarms.models.zephyr import Zephyr # noqa: E402 from swarms.models.biogpt import BioGPT # noqa: E402 from swarms.models.huggingface import HuggingfaceLLM # noqa: E402 diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py index edaae087..1f47e1bf 100644 --- a/swarms/models/anthropic.py +++ b/swarms/models/anthropic.py @@ -50,7 +50,9 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: ] invalid_groups = [i for i, count in enumerate(counts) if count != 1] if invalid_groups: - invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups] + invalid_group_names = [ + ", ".join(arg_groups[i]) for i in invalid_groups + ] raise ValueError( "Exactly one argument in each of the following" " groups must be defined:" @@ -106,7 +108,10 @@ def mock_now(dt_value): # type: ignore def guard_import( - module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None + module_name: str, + *, + pip_name: Optional[str] = None, + package: Optional[str] = None, ) -> Any: """Dynamically imports a module and raises a helpful exception if the module is not installed.""" @@ -180,18 +185,18 @@ def build_extra_kwargs( if field_name in extra_kwargs: raise ValueError(f"Found {field_name} supplied twice.") if field_name not in all_required_field_names: - warnings.warn( - f"""WARNING! {field_name} is not default parameter. + warnings.warn(f"""WARNING! {field_name} is not default parameter. {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""" - ) + Please confirm that {field_name} is what you intended.""") extra_kwargs[field_name] = values.pop(field_name) - invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys()) + invalid_model_kwargs = all_required_field_names.intersection( + extra_kwargs.keys() + ) if invalid_model_kwargs: raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. " - "Instead they were passed in as part of `model_kwargs` parameter." + f"Parameters {invalid_model_kwargs} should be specified explicitly." + " Instead they were passed in as part of `model_kwargs` parameter." ) return extra_kwargs @@ -250,7 +255,9 @@ class _AnthropicCommon(BaseLanguageModel): def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["anthropic_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") + get_from_dict_or_env( + values, "anthropic_api_key", "ANTHROPIC_API_KEY" + ) ) # Get custom api url from environment. values["anthropic_api_url"] = get_from_dict_or_env( @@ -305,7 +312,9 @@ class _AnthropicCommon(BaseLanguageModel): """Get the identifying parameters.""" return {**{}, **self._default_params} - def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]: + def _get_anthropic_stop( + self, stop: Optional[List[str]] = None + ) -> List[str]: if not self.HUMAN_PROMPT or not self.AI_PROMPT: raise NameError("Please ensure the anthropic package is loaded") @@ -354,8 +363,8 @@ class Anthropic(LLM, _AnthropicCommon): def raise_warning(cls, values: Dict) -> Dict: """Raise warning that this class is deprecated.""" warnings.warn( - "This Anthropic LLM is deprecated. " - "Please use `from langchain.chat_models import ChatAnthropic` instead" + "This Anthropic LLM is deprecated. Please use `from" + " langchain.chat_models import ChatAnthropic` instead" ) return values @@ -372,12 +381,16 @@ class Anthropic(LLM, _AnthropicCommon): return prompt # Already wrapped. # Guard against common errors in specifying wrong number of newlines. - corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt) + corrected_prompt, n_subs = re.subn( + r"^\n*Human:", self.HUMAN_PROMPT, prompt + ) if n_subs == 1: return corrected_prompt # As a last resort, wrap the prompt ourselves to emulate instruct-style. - return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" + return ( + f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" + ) def _call( self, @@ -476,7 +489,10 @@ class Anthropic(LLM, _AnthropicCommon): params = {**self._default_params, **kwargs} for token in self.client.completions.create( - prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + stream=True, + **params, ): chunk = GenerationChunk(text=token.completion) yield chunk diff --git a/swarms/models/autotemp.py b/swarms/models/autotemp.py deleted file mode 100644 index c3abb894..00000000 --- a/swarms/models/autotemp.py +++ /dev/null @@ -1,101 +0,0 @@ -import re -from concurrent.futures import ThreadPoolExecutor, as_completed -from swarms.models.openai_models import OpenAIChat - - -class AutoTempAgent: - """ - AutoTemp is a tool for automatically selecting the best temperature setting for a given task. - - Flow: - 1. Generate outputs at a range of temperature settings. - 2. Evaluate each output using the default temperature setting. - 3. Select the best output based on the evaluation score. - 4. Return the best output. - - - Args: - temperature (float, optional): The default temperature setting to use. Defaults to 0.5. - api_key (str, optional): Your OpenAI API key. Defaults to None. - alt_temps ([type], optional): A list of alternative temperature settings to try. Defaults to None. - auto_select (bool, optional): If True, the best temperature setting will be automatically selected. Defaults to True. - max_workers (int, optional): The maximum number of workers to use when generating outputs. Defaults to 6. - - Returns: - [type]: [description] - - Examples: - >>> from swarms.demos.autotemp import AutoTemp - >>> autotemp = AutoTemp() - >>> autotemp.run("Generate a 10,000 word blog on mental clarity and the benefits of meditation.", "0.4,0.6,0.8,1.0,1.2,1.4") - Best AutoTemp Output (Temp 0.4 | Score: 100.0): - Generate a 10,000 word blog on mental clarity and the benefits of meditation. - - """ - - def __init__( - self, - temperature: float = 0.5, - api_key: str = None, - alt_temps=None, - auto_select=True, - max_workers=6, - ): - self.alt_temps = alt_temps if alt_temps else [0.4, 0.6, 0.8, 1.0, 1.2, 1.4] - self.auto_select = auto_select - self.max_workers = max_workers - self.temperature = temperature - self.alt_temps = alt_temps - self.llm = OpenAIChat( - openai_api_key=api_key, - temperature=temperature, - ) - - def evaluate_output(self, output: str): - """Evaluate the output using the default temperature setting.""" - eval_prompt = f""" - Evaluate the following output which was generated at a temperature setting of {self.temperature}. - Provide a precise score from 0.0 to 100.0, considering the criteria of relevance, clarity, utility, pride, and delight. - - Output to evaluate: - --- - {output} - --- - """ - score_text = self.llm(prompt=eval_prompt) - score_match = re.search(r"\b\d+(\.\d)?\b", score_text) - return round(float(score_match.group()), 1) if score_match else 0.0 - - def run(self, task: str, temperature_string): - """Run the AutoTemp agent.""" - temperature_list = [ - float(temp.strip()) for temp in temperature_string.split(",") - ] - outputs = {} - scores = {} - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - future_to_temp = { - executor.submit(self.llm.generate, task, temp): temp - for temp in temperature_list - } - for future in as_completed(future_to_temp): - temp = future_to_temp[future] - output_text = future.result() - outputs[temp] = output_text - scores[temp] = self.evaluate_output(output_text, temp) - - if not scores: - return "No valid outputs generated.", None - - sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True) - best_temp, best_score = sorted_scores[0] - best_output = outputs[best_temp] - - return ( - f"Best AutoTemp Output (Temp {best_temp} | Score: {best_score}):\n{best_output}" - if self.auto_select - else "\n".join( - f"Temp {temp} | Score: {score}:\n{outputs[temp]}" - for temp, score in sorted_scores - ) - ) diff --git a/swarms/models/bioclip.py b/swarms/models/bioclip.py index c2b4bfa5..1c2627a6 100644 --- a/swarms/models/bioclip.py +++ b/swarms/models/bioclip.py @@ -98,7 +98,9 @@ class BioClip: ) = open_clip.create_model_and_transforms(model_path) self.tokenizer = open_clip.get_tokenizer(model_path) self.device = ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") ) self.model.to(self.device) self.model.eval() @@ -110,13 +112,17 @@ class BioClip: template: str = "this is a photo of ", context_length: int = 256, ): - image = torch.stack([self.preprocess_val(Image.open(img_path))]).to(self.device) + image = torch.stack([self.preprocess_val(Image.open(img_path))]).to( + self.device + ) texts = self.tokenizer( [template + l for l in labels], context_length=context_length ).to(self.device) with torch.no_grad(): - image_features, text_features, logit_scale = self.model(image, texts) + image_features, text_features, logit_scale = self.model( + image, texts + ) logits = ( (logit_scale * image_features @ text_features.t()) .detach() @@ -142,7 +148,9 @@ class BioClip: title = ( metadata["filename"] + "\n" - + "\n".join([f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()]) + + "\n".join( + [f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()] + ) ) ax.set_title(title, fontsize=14) plt.tight_layout() diff --git a/swarms/models/biogpt.py b/swarms/models/biogpt.py index 83c31e55..d5e692f2 100644 --- a/swarms/models/biogpt.py +++ b/swarms/models/biogpt.py @@ -154,7 +154,7 @@ class BioGPT: min_length=self.min_length, max_length=self.max_length, num_beams=num_beams, - early_stopping=early_stopping + early_stopping=early_stopping, ) return self.tokenizer.decode(beam_output[0], skip_special_tokens=True) diff --git a/swarms/models/cohere_chat.py b/swarms/models/cohere_chat.py index c583b827..508e9073 100644 --- a/swarms/models/cohere_chat.py +++ b/swarms/models/cohere_chat.py @@ -96,7 +96,9 @@ class BaseCohere(Serializable): values, "cohere_api_key", "COHERE_API_KEY" ) client_name = values["user_agent"] - values["client"] = cohere.Client(cohere_api_key, client_name=client_name) + values["client"] = cohere.Client( + cohere_api_key, client_name=client_name + ) values["async_client"] = cohere.AsyncClient( cohere_api_key, client_name=client_name ) @@ -172,17 +174,23 @@ class Cohere(LLM, BaseCohere): """Return type of llm.""" return "cohere" - def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict: + def _invocation_params( + self, stop: Optional[List[str]], **kwargs: Any + ) -> dict: params = self._default_params if self.stop is not None and stop is not None: - raise ValueError("`stop` found in both the input and default params.") + raise ValueError( + "`stop` found in both the input and default params." + ) elif self.stop is not None: params["stop_sequences"] = self.stop else: params["stop_sequences"] = stop return {**params, **kwargs} - def _process_response(self, response: Any, stop: Optional[List[str]]) -> str: + def _process_response( + self, response: Any, stop: Optional[List[str]] + ) -> str: text = response.generations[0].text # If stop tokens are provided, Cohere's endpoint returns them. # In order to make this consistent with other endpoints, we strip them. diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index 7d9bcf5d..3c130670 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -169,8 +169,8 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing your api key and" - " or try again" + f"Error running Dalle3: {error} try optimizing your api" + " key and or try again" ), "red", ) @@ -234,8 +234,8 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing your api key and" - " or try again" + f"Error running Dalle3: {error} try optimizing your api" + " key and or try again" ), "red", ) @@ -248,8 +248,7 @@ class Dalle3: """Print the Dalle3 dashboard""" print( colored( - ( - f"""Dalle3 Dashboard: + f"""Dalle3 Dashboard: -------------------- Model: {self.model} @@ -265,13 +264,14 @@ class Dalle3: -------------------- - """ - ), + """, "green", ) ) - def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5): + def process_batch_concurrently( + self, tasks: List[str], max_workers: int = 5 + ): """ Process a batch of tasks concurrently @@ -293,8 +293,12 @@ class Dalle3: ['https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png', """ - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_task = {executor.submit(self, task): task for task in tasks} + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + future_to_task = { + executor.submit(self, task): task for task in tasks + } results = [] for future in concurrent.futures.as_completed(future_to_task): task = future_to_task[future] @@ -307,14 +311,20 @@ class Dalle3: print( colored( ( - f"Error running Dalle3: {error} try optimizing your api key and" - " or try again" + f"Error running Dalle3: {error} try optimizing" + " your api key and or try again" ), "red", ) ) - print(colored(f"Error running Dalle3: {error.http_status}", "red")) - print(colored(f"Error running Dalle3: {error.error}", "red")) + print( + colored( + f"Error running Dalle3: {error.http_status}", "red" + ) + ) + print( + colored(f"Error running Dalle3: {error.error}", "red") + ) raise error def _generate_uuid(self): diff --git a/swarms/models/distilled_whisperx.py b/swarms/models/distilled_whisperx.py index 98b3660a..2b4fb5a5 100644 --- a/swarms/models/distilled_whisperx.py +++ b/swarms/models/distilled_whisperx.py @@ -28,7 +28,10 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1): retries -= 1 if retries <= 0: raise - print(f"Retry after exception: {e}, Attempts remaining: {retries}") + print( + f"Retry after exception: {e}, Attempts remaining:" + f" {retries}" + ) await asyncio.sleep(delay) return wrapper @@ -62,7 +65,9 @@ class DistilWhisperModel: def __init__(self, model_id="distil-whisper/distil-large-v2"): self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + self.torch_dtype = ( + torch.float16 if torch.cuda.is_available() else torch.float32 + ) self.model_id = model_id self.model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, @@ -119,7 +124,9 @@ class DistilWhisperModel: try: with torch.no_grad(): # Load the whole audio file, but process and transcribe it in chunks - audio_input = self.processor.audio_file_to_array(audio_file_path) + audio_input = self.processor.audio_file_to_array( + audio_file_path + ) sample_rate = audio_input.sampling_rate len(audio_input.array) / sample_rate chunks = [ @@ -139,7 +146,9 @@ class DistilWhisperModel: return_tensors="pt", padding=True, ) - processed_inputs = processed_inputs.input_values.to(self.device) + processed_inputs = processed_inputs.input_values.to( + self.device + ) # Generate transcription for the chunk logits = self.model.generate(processed_inputs) @@ -157,4 +166,6 @@ class DistilWhisperModel: time.sleep(chunk_duration) except Exception as e: - print(colored(f"An error occurred during transcription: {e}", "red")) + print( + colored(f"An error occurred during transcription: {e}", "red") + ) diff --git a/swarms/models/eleven_labs.py b/swarms/models/eleven_labs.py index 2fece5b6..42f4dae1 100644 --- a/swarms/models/eleven_labs.py +++ b/swarms/models/eleven_labs.py @@ -79,7 +79,9 @@ class ElevenLabsText2SpeechTool(BaseTool): f.write(speech) return f.name except Exception as e: - raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}") + raise RuntimeError( + f"Error while running ElevenLabsText2SpeechTool: {e}" + ) def play(self, speech_file: str) -> None: """Play the text as speech.""" @@ -93,7 +95,9 @@ class ElevenLabsText2SpeechTool(BaseTool): """Stream the text as speech as it is generated. Play the text in your speakers.""" elevenlabs = _import_elevenlabs() - speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True) + speech_stream = elevenlabs.generate( + text=query, model=self.model, stream=True + ) elevenlabs.stream(speech_stream) def save(self, speech_file: str, path: str) -> None: diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py index d0478777..c9a0d719 100644 --- a/swarms/models/fastvit.py +++ b/swarms/models/fastvit.py @@ -10,7 +10,9 @@ from pydantic import BaseModel, StrictFloat, StrictInt, validator DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load the classes for image classification -with open(os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")) as f: +with open( + os.path.join(os.path.dirname(__file__), "fast_vit_classes.json") +) as f: FASTVIT_IMAGENET_1K_CLASSES = json.load(f) @@ -20,7 +22,9 @@ class ClassificationResult(BaseModel): @validator("class_id", "confidence", pre=True, each_item=True) def check_list_contents(cls, v): - assert isinstance(v, int) or isinstance(v, float), "must be integer or float" + assert isinstance(v, int) or isinstance( + v, float + ), "must be integer or float" return v @@ -50,7 +54,9 @@ class FastViT: "hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True ).to(DEVICE) data_config = timm.data.resolve_model_data_config(self.model) - self.transforms = timm.data.create_transform(**data_config, is_training=False) + self.transforms = timm.data.create_transform( + **data_config, is_training=False + ) self.model.eval() def __call__( diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py index 02ab3a25..ed955260 100644 --- a/swarms/models/fuyu.py +++ b/swarms/models/fuyu.py @@ -46,7 +46,9 @@ class Fuyu: self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path) self.image_processor = FuyuImageProcessor() self.processor = FuyuProcessor( - image_processor=self.image_processor, tokenizer=self.tokenizer, **kwargs + image_processor=self.image_processor, + tokenizer=self.tokenizer, + **kwargs, ) self.model = FuyuForCausalLM.from_pretrained( pretrained_path, @@ -69,8 +71,12 @@ class Fuyu: for k, v in model_inputs.items(): model_inputs[k] = v.to(self.device_map) - output = self.model.generate(**model_inputs, max_new_tokens=self.max_new_tokens) - text = self.processor.batch_decode(output[:, -7:], skip_special_tokens=True) + output = self.model.generate( + **model_inputs, max_new_tokens=self.max_new_tokens + ) + text = self.processor.batch_decode( + output[:, -7:], skip_special_tokens=True + ) return print(str(text)) def get_img_from_web(self, img_url: str): diff --git a/swarms/models/gpt4v.py b/swarms/models/gpt4v.py index 8411cb14..8f2683e0 100644 --- a/swarms/models/gpt4v.py +++ b/swarms/models/gpt4v.py @@ -190,12 +190,15 @@ class GPT4Vision: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ - executor.submit(self.run, task, img) for task, img in tasks_images + executor.submit(self.run, task, img) + for task, img in tasks_images ] results = [future.result() for future in futures] return results - async def run_batch_async(self, tasks_images: List[Tuple[str, str]]) -> List[str]: + async def run_batch_async( + self, tasks_images: List[Tuple[str, str]] + ) -> List[str]: """Process a batch of tasks and images asynchronously""" loop = asyncio.get_event_loop() futures = [ diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py index 0f226740..1db435f5 100644 --- a/swarms/models/huggingface.py +++ b/swarms/models/huggingface.py @@ -133,7 +133,9 @@ class HuggingfaceLLM: ): self.logger = logging.getLogger(__name__) self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_id = model_id self.max_length = max_length @@ -178,7 +180,11 @@ class HuggingfaceLLM: except Exception as e: # self.logger.error(f"Failed to load the model or the tokenizer: {e}") # raise - print(colored(f"Failed to load the model and or the tokenizer: {e}", "red")) + print( + colored( + f"Failed to load the model and or the tokenizer: {e}", "red" + ) + ) def print_error(self, error: str): """Print error""" @@ -207,12 +213,16 @@ class HuggingfaceLLM: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}" + ) raise def concurrent_run(self, tasks: List[str], max_workers: int = 5): """Concurrently generate text for a list of prompts.""" - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: results = list(executor.map(self.run, tasks)) return results @@ -220,7 +230,8 @@ class HuggingfaceLLM: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ - executor.submit(self.run, task, img) for task, img in tasks_images + executor.submit(self.run, task, img) + for task, img in tasks_images ] results = [future.result() for future in futures] return results @@ -243,7 +254,9 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(task, return_tensors="pt").to( + self.device + ) # self.log.start() @@ -279,8 +292,8 @@ class HuggingfaceLLM: print( colored( ( - f"HuggingfaceLLM could not generate text because of error: {e}," - " try optimizing your arguments" + "HuggingfaceLLM could not generate text because of" + f" error: {e}, try optimizing your arguments" ), "red", ) @@ -305,7 +318,9 @@ class HuggingfaceLLM: self.print_dashboard(task) try: - inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) + inputs = self.tokenizer.encode(task, return_tensors="pt").to( + self.device + ) # self.log.start() diff --git a/swarms/models/idefics.py b/swarms/models/idefics.py index 73cb4991..0cfcf1af 100644 --- a/swarms/models/idefics.py +++ b/swarms/models/idefics.py @@ -66,7 +66,9 @@ class Idefics: max_length=100, ): self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model = IdeficsForVisionText2Text.from_pretrained( checkpoint, diff --git a/swarms/models/jina_embeds.py b/swarms/models/jina_embeds.py index a72b8a9e..1d1ac3e6 100644 --- a/swarms/models/jina_embeds.py +++ b/swarms/models/jina_embeds.py @@ -54,7 +54,9 @@ class JinaEmbeddings: ): self.logger = logging.getLogger(__name__) self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_id = model_id self.max_length = max_length @@ -83,7 +85,9 @@ class JinaEmbeddings: try: self.model = AutoModelForCausalLM.from_pretrained( - self.model_id, quantization_config=bnb_config, trust_remote_code=True + self.model_id, + quantization_config=bnb_config, + trust_remote_code=True, ) self.model # .to(self.device) @@ -112,7 +116,9 @@ class JinaEmbeddings: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}" + ) raise def run(self, task: str): diff --git a/swarms/models/kosmos2.py b/swarms/models/kosmos2.py index b0e1a9f6..f81e0fdf 100644 --- a/swarms/models/kosmos2.py +++ b/swarms/models/kosmos2.py @@ -70,11 +70,13 @@ class Kosmos2(BaseModel): prompt = "An image of" inputs = self.processor(text=prompt, images=image, return_tensors="pt") - outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64) + outputs = self.model.generate( + **inputs, use_cache=True, max_new_tokens=64 + ) - generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[ - 0 - ] + generated_text = self.processor.batch_decode( + outputs, skip_special_tokens=True + )[0] # The actual processing of generated_text to entities would go here # For the purpose of this example, assume a mock function 'extract_entities' exists: @@ -99,7 +101,9 @@ class Kosmos2(BaseModel): if not entities: return Detections.empty() - class_ids = [0] * len(entities) # Replace with actual class ID extraction logic + class_ids = [0] * len( + entities + ) # Replace with actual class ID extraction logic xyxys = [ ( e[1][0] * image.width, @@ -111,7 +115,9 @@ class Kosmos2(BaseModel): ] confidences = [1.0] * len(entities) # Placeholder confidence - return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences) + return Detections( + xyxy=xyxys, class_id=class_ids, confidence=confidences + ) # Usage: diff --git a/swarms/models/kosmos_two.py b/swarms/models/kosmos_two.py index 596886f3..c696ef34 100644 --- a/swarms/models/kosmos_two.py +++ b/swarms/models/kosmos_two.py @@ -145,12 +145,12 @@ class Kosmos: elif isinstance(image, torch.Tensor): # pdb.set_trace() image_tensor = image.cpu() - reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[ - :, None, None - ] - reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[ - :, None, None - ] + reverse_norm_mean = torch.tensor( + [0.48145466, 0.4578275, 0.40821073] + )[:, None, None] + reverse_norm_std = torch.tensor( + [0.26862954, 0.26130258, 0.27577711] + )[:, None, None] image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean pil_img = T.ToPILImage()(image_tensor) image_h = pil_img.height @@ -188,7 +188,11 @@ class Kosmos: # random color color = tuple(np.random.randint(0, 255, size=3).tolist()) new_image = cv2.rectangle( - new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line + new_image, + (orig_x1, orig_y1), + (orig_x2, orig_y2), + color, + box_line, ) l_o, r_o = ( @@ -211,7 +215,10 @@ class Kosmos: # add text background (text_width, text_height), _ = cv2.getTextSize( - f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line + f" {entity_name}", + cv2.FONT_HERSHEY_COMPLEX, + text_size, + text_line, ) text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = ( x1, @@ -222,7 +229,8 @@ class Kosmos: for prev_bbox in previous_bboxes: while is_overlapping( - (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox + (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), + prev_bbox, ): text_bg_y1 += ( text_height + text_offset_original + 2 * text_spaces @@ -230,14 +238,18 @@ class Kosmos: text_bg_y2 += ( text_height + text_offset_original + 2 * text_spaces ) - y1 += text_height + text_offset_original + 2 * text_spaces + y1 += ( + text_height + text_offset_original + 2 * text_spaces + ) if text_bg_y2 >= image_h: text_bg_y1 = max( 0, image_h - ( - text_height + text_offset_original + 2 * text_spaces + text_height + + text_offset_original + + 2 * text_spaces ), ) text_bg_y2 = image_h @@ -270,7 +282,9 @@ class Kosmos: cv2.LINE_AA, ) # previous_locations.append((x1, y1)) - previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2)) + previous_bboxes.append( + (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2) + ) pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]]) if save_path: diff --git a/swarms/models/llama_function_caller.py b/swarms/models/llama_function_caller.py index a991641a..ca5ee5d3 100644 --- a/swarms/models/llama_function_caller.py +++ b/swarms/models/llama_function_caller.py @@ -121,7 +121,11 @@ class LlamaFunctionCaller: ) def add_func( - self, name: str, function: Callable, description: str, arguments: List[Dict] + self, + name: str, + function: Callable, + description: str, + arguments: List[Dict], ): """ Adds a new function to the LlamaFunctionCaller. @@ -172,12 +176,17 @@ class LlamaFunctionCaller: if self.streaming: out = self.model.generate( - **inputs, streamer=streamer, max_new_tokens=self.max_tokens, **kwargs + **inputs, + streamer=streamer, + max_new_tokens=self.max_tokens, + **kwargs, ) return out else: - out = self.model.generate(**inputs, max_length=self.max_tokens, **kwargs) + out = self.model.generate( + **inputs, max_length=self.max_tokens, **kwargs + ) # return self.tokenizer.decode(out[0], skip_special_tokens=True) return out diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py index 7f48a0d6..056a31bb 100644 --- a/swarms/models/mistral.py +++ b/swarms/models/mistral.py @@ -49,7 +49,9 @@ class Mistral: # Check if the specified device is available if not torch.cuda.is_available() and device == "cuda": - raise ValueError("CUDA is not available. Please choose a different device.") + raise ValueError( + "CUDA is not available. Please choose a different device." + ) # Load the model and tokenizer self.model = None @@ -70,7 +72,9 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) + model_inputs = self.tokenizer([task], return_tensors="pt").to( + self.device + ) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, @@ -87,7 +91,9 @@ class Mistral: """Run the model on a given task.""" try: - model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) + model_inputs = self.tokenizer([task], return_tensors="pt").to( + self.device + ) generated_ids = self.model.generate( **model_inputs, max_length=self.max_length, diff --git a/swarms/models/mpt.py b/swarms/models/mpt.py index 46d1a357..c304355a 100644 --- a/swarms/models/mpt.py +++ b/swarms/models/mpt.py @@ -29,7 +29,9 @@ class MPT7B: """ - def __init__(self, model_name: str, tokenizer_name: str, max_tokens: int = 100): + def __init__( + self, model_name: str, tokenizer_name: str, max_tokens: int = 100 + ): # Loading model and tokenizer details self.model_name = model_name self.tokenizer_name = tokenizer_name @@ -118,7 +120,10 @@ class MPT7B: """ with torch.autocast("cuda", dtype=torch.bfloat16): return self.pipe( - prompt, max_new_tokens=self.max_tokens, do_sample=True, use_cache=True + prompt, + max_new_tokens=self.max_tokens, + do_sample=True, + use_cache=True, )[0]["generated_text"] async def generate_async(self, prompt: str) -> str: diff --git a/swarms/models/nougat.py b/swarms/models/nougat.py index f156981c..82bb95f5 100644 --- a/swarms/models/nougat.py +++ b/swarms/models/nougat.py @@ -41,8 +41,12 @@ class Nougat: self.min_length = min_length self.max_new_tokens = max_new_tokens - self.processor = NougatProcessor.from_pretrained(self.model_name_or_path) - self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path) + self.processor = NougatProcessor.from_pretrained( + self.model_name_or_path + ) + self.model = VisionEncoderDecoderModel.from_pretrained( + self.model_name_or_path + ) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) @@ -63,8 +67,12 @@ class Nougat: max_new_tokens=self.max_new_tokens, ) - sequence = self.processor.batch_decode(outputs, skip_special_tokens=True)[0] - sequence = self.processor.post_process_generation(sequence, fix_markdown=False) + sequence = self.processor.batch_decode( + outputs, skip_special_tokens=True + )[0] + sequence = self.processor.post_process_generation( + sequence, fix_markdown=False + ) out = print(sequence) return out diff --git a/swarms/models/openai_embeddings.py b/swarms/models/openai_embeddings.py index 81dea550..08919d45 100644 --- a/swarms/models/openai_embeddings.py +++ b/swarms/models/openai_embeddings.py @@ -43,7 +43,9 @@ def get_pydantic_field_names(cls: Any) -> Set[str]: logger = logging.getLogger(__name__) -def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]: +def _create_retry_decorator( + embeddings: OpenAIEmbeddings, +) -> Callable[[Any], Any]: import llm min_seconds = 4 @@ -118,7 +120,9 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: return _embed_with_retry(**kwargs) -async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: +async def async_embed_with_retry( + embeddings: OpenAIEmbeddings, **kwargs: Any +) -> Any: """Use tenacity to retry the embedding call.""" @_async_retry_decorator(embeddings) @@ -172,7 +176,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): client: Any #: :meta private: model: str = "text-embedding-ada-002" - deployment: str = model # to support Azure OpenAI Service custom deployment names + deployment: str = ( + model # to support Azure OpenAI Service custom deployment names + ) openai_api_version: Optional[str] = None # to support Azure OpenAI Service custom endpoints openai_api_base: Optional[str] = None @@ -229,11 +235,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) extra[field_name] = values.pop(field_name) - invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + invalid_model_kwargs = all_required_field_names.intersection( + extra.keys() + ) if invalid_model_kwargs: raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. " - "Instead they were passed in as part of `model_kwargs` parameter." + f"Parameters {invalid_model_kwargs} should be specified" + " explicitly. Instead they were passed in as part of" + " `model_kwargs` parameter." ) values["model_kwargs"] = extra @@ -333,7 +342,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. Using cl100k_base encoding." + ) model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, text in enumerate(texts): @@ -384,11 +395,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): self, input="", **self._invocation_params, - )[ - "data" - ][0]["embedding"] + )["data"][0]["embedding"] else: - average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) + average = np.average( + _result, axis=0, weights=num_tokens_in_batch[i] + ) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings @@ -414,7 +425,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): try: encoding = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. Using cl100k_base encoding." + ) model = "cl100k_base" encoding = tiktoken.get_encoding(model) for i, text in enumerate(texts): @@ -458,7 +471,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) )["data"][0]["embedding"] else: - average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) + average = np.average( + _result, axis=0, weights=num_tokens_in_batch[i] + ) embeddings[i] = (average / np.linalg.norm(average)).tolist() return embeddings @@ -495,7 +510,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """ # NOTE: to keep things simple, we assume the list may contain texts longer # than the maximum context and use length-safe embedding function. - return await self._aget_len_safe_embeddings(texts, engine=self.deployment) + return await self._aget_len_safe_embeddings( + texts, engine=self.deployment + ) def embed_query(self, text: str) -> List[float]: """Call out to OpenAI's embedding endpoint for embedding query text. diff --git a/swarms/models/openai_function_caller.py b/swarms/models/openai_function_caller.py index bac0f28d..f0c41f2a 100644 --- a/swarms/models/openai_function_caller.py +++ b/swarms/models/openai_function_caller.py @@ -146,7 +146,8 @@ class OpenAIFunctionCaller: self.messages.append({"role": role, "content": content}) @retry( - wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3) + wait=wait_random_exponential(multiplier=1, max=40), + stop=stop_after_attempt(3), ) def chat_completion_request( self, @@ -194,17 +195,22 @@ class OpenAIFunctionCaller: elif message["role"] == "user": print( colored( - f"user: {message['content']}\n", role_to_color[message["role"]] + f"user: {message['content']}\n", + role_to_color[message["role"]], ) ) - elif message["role"] == "assistant" and message.get("function_call"): + elif message["role"] == "assistant" and message.get( + "function_call" + ): print( colored( f"assistant: {message['function_call']}\n", role_to_color[message["role"]], ) ) - elif message["role"] == "assistant" and not message.get("function_call"): + elif message["role"] == "assistant" and not message.get( + "function_call" + ): print( colored( f"assistant: {message['content']}\n", diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index fcf4a223..0547a264 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -62,19 +62,25 @@ def _stream_response_to_generation_chunk( return GenerationChunk( text=stream_response["choices"][0]["text"], generation_info=dict( - finish_reason=stream_response["choices"][0].get("finish_reason", None), + finish_reason=stream_response["choices"][0].get( + "finish_reason", None + ), logprobs=stream_response["choices"][0].get("logprobs", None), ), ) -def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None: +def _update_response( + response: Dict[str, Any], stream_response: Dict[str, Any] +) -> None: """Update response from the stream response.""" response["choices"][0]["text"] += stream_response["choices"][0]["text"] response["choices"][0]["finish_reason"] = stream_response["choices"][0].get( "finish_reason", None ) - response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"] + response["choices"][0]["logprobs"] = stream_response["choices"][0][ + "logprobs" + ] def _streaming_response_template() -> Dict[str, Any]: @@ -315,9 +321,11 @@ class BaseOpenAI(BaseLLM): chunk.text, chunk=chunk, verbose=self.verbose, - logprobs=chunk.generation_info["logprobs"] - if chunk.generation_info - else None, + logprobs=( + chunk.generation_info["logprobs"] + if chunk.generation_info + else None + ), ) async def _astream( @@ -339,9 +347,11 @@ class BaseOpenAI(BaseLLM): chunk.text, chunk=chunk, verbose=self.verbose, - logprobs=chunk.generation_info["logprobs"] - if chunk.generation_info - else None, + logprobs=( + chunk.generation_info["logprobs"] + if chunk.generation_info + else None + ), ) def _generate( @@ -377,10 +387,14 @@ class BaseOpenAI(BaseLLM): for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") + raise ValueError( + "Cannot stream results with multiple prompts." + ) generation: Optional[GenerationChunk] = None - for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs): + for chunk in self._stream( + _prompts[0], stop, run_manager, **kwargs + ): if generation is None: generation = chunk else: @@ -389,12 +403,16 @@ class BaseOpenAI(BaseLLM): choices.append( { "text": generation.text, - "finish_reason": generation.generation_info.get("finish_reason") - if generation.generation_info - else None, - "logprobs": generation.generation_info.get("logprobs") - if generation.generation_info - else None, + "finish_reason": ( + generation.generation_info.get("finish_reason") + if generation.generation_info + else None + ), + "logprobs": ( + generation.generation_info.get("logprobs") + if generation.generation_info + else None + ), } ) else: @@ -424,7 +442,9 @@ class BaseOpenAI(BaseLLM): for _prompts in sub_prompts: if self.streaming: if len(_prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") + raise ValueError( + "Cannot stream results with multiple prompts." + ) generation: Optional[GenerationChunk] = None async for chunk in self._astream( @@ -438,12 +458,16 @@ class BaseOpenAI(BaseLLM): choices.append( { "text": generation.text, - "finish_reason": generation.generation_info.get("finish_reason") - if generation.generation_info - else None, - "logprobs": generation.generation_info.get("logprobs") - if generation.generation_info - else None, + "finish_reason": ( + generation.generation_info.get("finish_reason") + if generation.generation_info + else None + ), + "logprobs": ( + generation.generation_info.get("logprobs") + if generation.generation_info + else None + ), } ) else: @@ -463,7 +487,9 @@ class BaseOpenAI(BaseLLM): """Get the sub prompts for llm call.""" if stop is not None: if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") + raise ValueError( + "`stop` found in both the input and default params." + ) params["stop"] = stop if params["max_tokens"] == -1: if len(prompts) != 1: @@ -541,7 +567,9 @@ class BaseOpenAI(BaseLLM): try: enc = tiktoken.encoding_for_model(model_name) except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") + logger.warning( + "Warning: model not found. Using cl100k_base encoding." + ) model = "cl100k_base" enc = tiktoken.get_encoding(model) @@ -602,8 +630,9 @@ class BaseOpenAI(BaseLLM): if context_size is None: raise ValueError( - f"Unknown model: {modelname}. Please provide a valid OpenAI model name." - "Known models are: " + ", ".join(model_token_mapping.keys()) + f"Unknown model: {modelname}. Please provide a valid OpenAI" + " model name.Known models are: " + + ", ".join(model_token_mapping.keys()) ) return context_size @@ -753,7 +782,9 @@ class OpenAIChat(BaseLLM): @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} + all_required_field_names = { + field.alias for field in cls.__fields__.values() + } extra = values.get("model_kwargs", {}) for field_name in list(values): @@ -820,13 +851,21 @@ class OpenAIChat(BaseLLM): ) -> Tuple: if len(prompts) > 1: raise ValueError( - f"OpenAIChat currently only supports single prompt, got {prompts}" + "OpenAIChat currently only supports single prompt, got" + f" {prompts}" ) - messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] - params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} + messages = self.prefix_messages + [ + {"role": "user", "content": prompts[0]} + ] + params: Dict[str, Any] = { + **{"model": self.model_name}, + **self._default_params, + } if stop is not None: if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") + raise ValueError( + "`stop` found in both the input and default params." + ) params["stop"] = stop if params.get("max_tokens") == -1: # for ChatGPT api, omitting max_tokens is equivalent to having no limit @@ -897,7 +936,11 @@ class OpenAIChat(BaseLLM): } return LLMResult( generations=[ - [Generation(text=full_response["choices"][0]["message"]["content"])] + [ + Generation( + text=full_response["choices"][0]["message"]["content"] + ) + ] ], llm_output=llm_output, ) @@ -911,7 +954,9 @@ class OpenAIChat(BaseLLM): ) -> LLMResult: if self.streaming: generation: Optional[GenerationChunk] = None - async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs): + async for chunk in self._astream( + prompts[0], stop, run_manager, **kwargs + ): if generation is None: generation = chunk else: @@ -930,7 +975,11 @@ class OpenAIChat(BaseLLM): } return LLMResult( generations=[ - [Generation(text=full_response["choices"][0]["message"]["content"])] + [ + Generation( + text=full_response["choices"][0]["message"]["content"] + ) + ] ], llm_output=llm_output, ) diff --git a/swarms/models/palm.py b/swarms/models/palm.py index ec8aafd6..8c9277d7 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -37,10 +37,16 @@ def _create_retry_decorator() -> Callable[[Any], Any]: return retry( reraise=True, stop=stop_after_attempt(max_retries), - wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds), + wait=wait_exponential( + multiplier=multiplier, min=min_seconds, max=max_seconds + ), retry=( - retry_if_exception_type(google.api_core.exceptions.ResourceExhausted) - | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable) + retry_if_exception_type( + google.api_core.exceptions.ResourceExhausted + ) + | retry_if_exception_type( + google.api_core.exceptions.ServiceUnavailable + ) | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError) ), before_sleep=before_sleep_log(logger, logging.WARNING), @@ -64,7 +70,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str: The PaLM API will sometimes erroneously return a single leading space in all lines > 1. This function strips that space. """ - has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:]) + has_leading_space = all( + not line or line[0] == " " for line in text.split("\n")[1:] + ) if has_leading_space: return text.replace("\n ", "\n") else: @@ -112,7 +120,10 @@ class GooglePalm(BaseLLM, BaseModel): values["client"] = genai - if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: + if ( + values["temperature"] is not None + and not 0 <= values["temperature"] <= 1 + ): raise ValueError("temperature must be in the range [0.0, 1.0]") if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: @@ -121,7 +132,10 @@ class GooglePalm(BaseLLM, BaseModel): if values["top_k"] is not None and values["top_k"] <= 0: raise ValueError("top_k must be positive") - if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0: + if ( + values["max_output_tokens"] is not None + and values["max_output_tokens"] <= 0 + ): raise ValueError("max_output_tokens must be greater than zero") return values diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index 3662dda2..a4e99fe4 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -16,4 +16,6 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): text = text.replace("\n", " ") - return client.embeddings.create(input=[text], model=model)["data"][0]["embedding"] + return client.embeddings.create(input=[text], model=model)["data"][0][ + "embedding" + ] diff --git a/swarms/models/speecht5.py b/swarms/models/speecht5.py index e98036ac..143a7514 100644 --- a/swarms/models/speecht5.py +++ b/swarms/models/speecht5.py @@ -90,7 +90,9 @@ class SpeechT5: self.processor = SpeechT5Processor.from_pretrained(self.model_name) self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name) self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) - self.embeddings_dataset = load_dataset(self.dataset_name, split="validation") + self.embeddings_dataset = load_dataset( + self.dataset_name, split="validation" + ) def __call__(self, text: str, speaker_id: float = 7306): """Call the model on some text and return the speech.""" @@ -121,7 +123,9 @@ class SpeechT5: def set_embeddings_dataset(self, dataset_name): """Set the embeddings dataset to a new dataset.""" self.dataset_name = dataset_name - self.embeddings_dataset = load_dataset(self.dataset_name, split="validation") + self.embeddings_dataset = load_dataset( + self.dataset_name, split="validation" + ) # Feature 1: Get sampling rate def get_sampling_rate(self): diff --git a/swarms/models/ssd_1b.py b/swarms/models/ssd_1b.py index caeba3fc..406678ef 100644 --- a/swarms/models/ssd_1b.py +++ b/swarms/models/ssd_1b.py @@ -141,8 +141,8 @@ class SSD1B: print( colored( ( - f"Error running SSD1B: {error} try optimizing your api key and" - " or try again" + f"Error running SSD1B: {error} try optimizing your api" + " key and or try again" ), "red", ) @@ -167,8 +167,7 @@ class SSD1B: """Print the SSD1B dashboard""" print( colored( - ( - f"""SSD1B Dashboard: + f"""SSD1B Dashboard: -------------------- Model: {self.model} @@ -184,13 +183,14 @@ class SSD1B: -------------------- - """ - ), + """, "green", ) ) - def process_batch_concurrently(self, tasks: List[str], max_workers: int = 5): + def process_batch_concurrently( + self, tasks: List[str], max_workers: int = 5 + ): """ Process a batch of tasks concurrently @@ -211,8 +211,12 @@ class SSD1B: >>> print(results) """ - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_task = {executor.submit(self, task): task for task in tasks} + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + future_to_task = { + executor.submit(self, task): task for task in tasks + } results = [] for future in concurrent.futures.as_completed(future_to_task): task = future_to_task[future] @@ -225,13 +229,17 @@ class SSD1B: print( colored( ( - f"Error running SSD1B: {error} try optimizing your api key and" - " or try again" + f"Error running SSD1B: {error} try optimizing" + " your api key and or try again" ), "red", ) ) - print(colored(f"Error running SSD1B: {error.http_status}", "red")) + print( + colored( + f"Error running SSD1B: {error.http_status}", "red" + ) + ) print(colored(f"Error running SSD1B: {error.error}", "red")) raise error diff --git a/swarms/models/whisperx.py b/swarms/models/whisperx.py index ac592b35..338971da 100644 --- a/swarms/models/whisperx.py +++ b/swarms/models/whisperx.py @@ -66,7 +66,9 @@ class WhisperX: compute_type = "float16" # 1. Transcribe with original Whisper (batched) 🗣️ - model = whisperx.load_model("large-v2", device, compute_type=compute_type) + model = whisperx.load_model( + "large-v2", device, compute_type=compute_type + ) audio = whisperx.load_audio(audio_file) result = model.transcribe(audio, batch_size=batch_size) diff --git a/swarms/models/wizard_storytelling.py b/swarms/models/wizard_storytelling.py index 49ffb70d..a34f6ec7 100644 --- a/swarms/models/wizard_storytelling.py +++ b/swarms/models/wizard_storytelling.py @@ -45,7 +45,9 @@ class WizardLLMStoryTeller: ): self.logger = logging.getLogger(__name__) self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_id = model_id self.max_length = max_length @@ -101,7 +103,9 @@ class WizardLLMStoryTeller: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}" + ) raise def run(self, prompt_text: str): diff --git a/swarms/models/yarn_mistral.py b/swarms/models/yarn_mistral.py index ebe107a2..065e3140 100644 --- a/swarms/models/yarn_mistral.py +++ b/swarms/models/yarn_mistral.py @@ -45,7 +45,9 @@ class YarnMistral128: ): self.logger = logging.getLogger(__name__) self.device = ( - device if device else ("cuda" if torch.cuda.is_available() else "cpu") + device + if device + else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_id = model_id self.max_length = max_length @@ -106,7 +108,9 @@ class YarnMistral128: if self.distributed: self.model = DDP(self.model) except Exception as error: - self.logger.error(f"Failed to load the model or the tokenizer: {error}") + self.logger.error( + f"Failed to load the model or the tokenizer: {error}" + ) raise def run(self, prompt_text: str): diff --git a/swarms/prompts/agent_prompt.py b/swarms/prompts/agent_prompt.py index c4897193..b36aea19 100644 --- a/swarms/prompts/agent_prompt.py +++ b/swarms/prompts/agent_prompt.py @@ -15,7 +15,9 @@ class PromptGenerator: "thoughts": { "text": "thought", "reasoning": "reasoning", - "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "plan": ( + "- short bulleted\n- list that conveys\n- long-term plan" + ), "criticism": "constructive self-criticism", "speak": "thoughts summary to say to user", }, @@ -66,13 +68,11 @@ class PromptGenerator: """ formatted_response_format = json.dumps(self.response_format, indent=4) prompt_string = ( - f"Constraints:\n{''.join(self.constraints)}\n\n" - f"Commands:\n{''.join(self.commands)}\n\n" - f"Resources:\n{''.join(self.resources)}\n\n" - f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n" - "You should only respond in JSON format as described below " - f"\nResponse Format: \n{formatted_response_format} " - "\nEnsure the response can be parsed by Python json.loads" + f"Constraints:\n{''.join(self.constraints)}\n\nCommands:\n{''.join(self.commands)}\n\nResources:\n{''.join(self.resources)}\n\nPerformance" + f" Evaluation:\n{''.join(self.performance_evaluation)}\n\nYou" + " should only respond in JSON format as described below \nResponse" + f" Format: \n{formatted_response_format} \nEnsure the response can" + " be parsed by Python json.loads" ) return prompt_string diff --git a/swarms/prompts/agent_prompts.py b/swarms/prompts/agent_prompts.py index 8d145fc0..a8c3fca7 100644 --- a/swarms/prompts/agent_prompts.py +++ b/swarms/prompts/agent_prompts.py @@ -5,26 +5,26 @@ def generate_agent_role_prompt(agent): """ prompts = { "Finance Agent": ( - "You are a seasoned finance analyst AI assistant. Your primary goal is to" - " compose comprehensive, astute, impartial, and methodically arranged" - " financial reports based on provided data and trends." + "You are a seasoned finance analyst AI assistant. Your primary goal" + " is to compose comprehensive, astute, impartial, and methodically" + " arranged financial reports based on provided data and trends." ), "Travel Agent": ( - "You are a world-travelled AI tour guide assistant. Your main purpose is to" - " draft engaging, insightful, unbiased, and well-structured travel reports" - " on given locations, including history, attractions, and cultural" - " insights." + "You are a world-travelled AI tour guide assistant. Your main" + " purpose is to draft engaging, insightful, unbiased, and" + " well-structured travel reports on given locations, including" + " history, attractions, and cultural insights." ), "Academic Research Agent": ( - "You are an AI academic research assistant. Your primary responsibility is" - " to create thorough, academically rigorous, unbiased, and systematically" - " organized reports on a given research topic, following the standards of" - " scholarly work." + "You are an AI academic research assistant. Your primary" + " responsibility is to create thorough, academically rigorous," + " unbiased, and systematically organized reports on a given" + " research topic, following the standards of scholarly work." ), "Default Agent": ( - "You are an AI critical thinker research assistant. Your sole purpose is to" - " write well written, critically acclaimed, objective and structured" - " reports on given text." + "You are an AI critical thinker research assistant. Your sole" + " purpose is to write well written, critically acclaimed, objective" + " and structured reports on given text." ), } @@ -39,12 +39,12 @@ def generate_report_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, answer the following' - f' question or topic: "{question}" in a detailed report -- The report should' - " focus on the answer to the question, should be well structured, informative," - " in depth, with facts and numbers if available, a minimum of 1,200 words and" - " with markdown syntax and apa format. Write all source urls at the end of the" - " report in apa format" + f'"""{research_summary}""" Using the above information, answer the' + f' following question or topic: "{question}" in a detailed report --' + " The report should focus on the answer to the question, should be" + " well structured, informative, in depth, with facts and numbers if" + " available, a minimum of 1,200 words and with markdown syntax and apa" + " format. Write all source urls at the end of the report in apa format" ) @@ -55,9 +55,10 @@ def generate_search_queries_prompt(question): """ return ( - "Write 4 google search queries to search online that form an objective opinion" - f' from the following: "{question}"You must respond with a list of strings in' - ' the following format: ["query 1", "query 2", "query 3", "query 4"]' + "Write 4 google search queries to search online that form an objective" + f' opinion from the following: "{question}"You must respond with a list' + ' of strings in the following format: ["query 1", "query 2", "query' + ' 3", "query 4"]' ) @@ -73,14 +74,15 @@ def generate_resource_report_prompt(question, research_summary): """ return ( f'"""{research_summary}""" Based on the above information, generate a' - " bibliography recommendation report for the following question or topic:" - f' "{question}". The report should provide a detailed analysis of each' - " recommended resource, explaining how each source can contribute to finding" - " answers to the research question. Focus on the relevance, reliability, and" - " significance of each source. Ensure that the report is well-structured," - " informative, in-depth, and follows Markdown syntax. Include relevant facts," - " figures, and numbers whenever available. The report should have a minimum" - " length of 1,200 words." + " bibliography recommendation report for the following question or" + f' topic: "{question}". The report should provide a detailed analysis' + " of each recommended resource, explaining how each source can" + " contribute to finding answers to the research question. Focus on the" + " relevance, reliability, and significance of each source. Ensure that" + " the report is well-structured, informative, in-depth, and follows" + " Markdown syntax. Include relevant facts, figures, and numbers" + " whenever available. The report should have a minimum length of 1,200" + " words." ) @@ -92,13 +94,14 @@ def generate_outline_report_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, generate an outline for' - " a research report in Markdown syntax for the following question or topic:" - f' "{question}". The outline should provide a well-structured framework for the' - " research report, including the main sections, subsections, and key points to" - " be covered. The research report should be detailed, informative, in-depth," - " and a minimum of 1,200 words. Use appropriate Markdown syntax to format the" - " outline and ensure readability." + f'"""{research_summary}""" Using the above information, generate an' + " outline for a research report in Markdown syntax for the following" + f' question or topic: "{question}". The outline should provide a' + " well-structured framework for the research report, including the" + " main sections, subsections, and key points to be covered. The" + " research report should be detailed, informative, in-depth, and a" + " minimum of 1,200 words. Use appropriate Markdown syntax to format" + " the outline and ensure readability." ) @@ -110,11 +113,12 @@ def generate_concepts_prompt(question, research_summary): """ return ( - f'"""{research_summary}""" Using the above information, generate a list of 5' - " main concepts to learn for a research report on the following question or" - f' topic: "{question}". The outline should provide a well-structured' - " frameworkYou must respond with a list of strings in the following format:" - ' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]' + f'"""{research_summary}""" Using the above information, generate a list' + " of 5 main concepts to learn for a research report on the following" + f' question or topic: "{question}". The outline should provide a' + " well-structured frameworkYou must respond with a list of strings in" + ' the following format: ["concepts 1", "concepts 2", "concepts 3",' + ' "concepts 4, concepts 5"]' ) @@ -128,10 +132,10 @@ def generate_lesson_prompt(concept): """ prompt = ( - f"generate a comprehensive lesson about {concept} in Markdown syntax. This" - f" should include the definitionof {concept}, its historical background and" - " development, its applications or uses in differentfields, and notable events" - f" or facts related to {concept}." + f"generate a comprehensive lesson about {concept} in Markdown syntax." + f" This should include the definitionof {concept}, its historical" + " background and development, its applications or uses in" + f" differentfields, and notable events or facts related to {concept}." ) return prompt diff --git a/swarms/prompts/base.py b/swarms/prompts/base.py index 54a0bc3f..369063e6 100644 --- a/swarms/prompts/base.py +++ b/swarms/prompts/base.py @@ -12,7 +12,9 @@ if TYPE_CHECKING: def get_buffer_string( - messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" + messages: Sequence[BaseMessage], + human_prefix: str = "Human", + ai_prefix: str = "AI", ) -> str: """Convert sequence of Messages to strings and concatenate them into one string. diff --git a/swarms/prompts/chat_prompt.py b/swarms/prompts/chat_prompt.py index d1e08df9..bbdaa9c7 100644 --- a/swarms/prompts/chat_prompt.py +++ b/swarms/prompts/chat_prompt.py @@ -105,7 +105,9 @@ class ChatMessage(Message): def get_buffer_string( - messages: Sequence[Message], human_prefix: str = "Human", ai_prefix: str = "AI" + messages: Sequence[Message], + human_prefix: str = "Human", + ai_prefix: str = "AI", ) -> str: string_messages = [] for m in messages: diff --git a/swarms/prompts/multi_modal_prompts.py b/swarms/prompts/multi_modal_prompts.py index b552b68d..1c0830d6 100644 --- a/swarms/prompts/multi_modal_prompts.py +++ b/swarms/prompts/multi_modal_prompts.py @@ -1,6 +1,6 @@ ERROR_PROMPT = ( - "An error has occurred for the following text: \n{promptedQuery} Please explain" - " this error.\n {e}" + "An error has occurred for the following text: \n{promptedQuery} Please" + " explain this error.\n {e}" ) IMAGE_PROMPT = """ diff --git a/swarms/prompts/python.py b/swarms/prompts/python.py index 9d1f4a1e..46df5cdc 100644 --- a/swarms/prompts/python.py +++ b/swarms/prompts/python.py @@ -1,16 +1,17 @@ PY_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only." PY_REFLEXION_COMPLETION_INSTRUCTION = ( "You are a Python writing assistant. You will be given your past function" - " implementation, a series of unit tests, and a hint to change the implementation" - " appropriately. Write your full implementation (restate the function" - " signature).\n\n-----" + " implementation, a series of unit tests, and a hint to change the" + " implementation appropriately. Write your full implementation (restate the" + " function signature).\n\n-----" ) PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = ( - "You are a Python writing assistant. You will be given a function implementation" - " and a series of unit tests. Your goal is to write a few sentences to explain why" - " your implementation is wrong as indicated by the tests. You will need this as a" - " hint when you try again later. Only provide the few sentence description in your" - " answer, not the implementation.\n\n-----" + "You are a Python writing assistant. You will be given a function" + " implementation and a series of unit tests. Your goal is to write a few" + " sentences to explain why your implementation is wrong as indicated by the" + " tests. You will need this as a hint when you try again later. Only" + " provide the few sentence description in your answer, not the" + " implementation.\n\n-----" ) USE_PYTHON_CODEBLOCK_INSTRUCTION = ( "Use a Python code block to write your response. For" @@ -18,25 +19,26 @@ USE_PYTHON_CODEBLOCK_INSTRUCTION = ( ) PY_SIMPLE_CHAT_INSTRUCTION = ( - "You are an AI that only responds with python code, NOT ENGLISH. You will be given" - " a function signature and its docstring by the user. Write your full" - " implementation (restate the function signature)." + "You are an AI that only responds with python code, NOT ENGLISH. You will" + " be given a function signature and its docstring by the user. Write your" + " full implementation (restate the function signature)." ) PY_SIMPLE_CHAT_INSTRUCTION_V2 = ( - "You are an AI that only responds with only python code. You will be given a" - " function signature and its docstring by the user. Write your full implementation" - " (restate the function signature)." + "You are an AI that only responds with only python code. You will be given" + " a function signature and its docstring by the user. Write your full" + " implementation (restate the function signature)." ) PY_REFLEXION_CHAT_INSTRUCTION = ( "You are an AI Python assistant. You will be given your past function" - " implementation, a series of unit tests, and a hint to change the implementation" - " appropriately. Write your full implementation (restate the function signature)." + " implementation, a series of unit tests, and a hint to change the" + " implementation appropriately. Write your full implementation (restate the" + " function signature)." ) PY_REFLEXION_CHAT_INSTRUCTION_V2 = ( - "You are an AI Python assistant. You will be given your previous implementation of" - " a function, a series of unit tests results, and your self-reflection on your" - " previous implementation. Write your full implementation (restate the function" - " signature)." + "You are an AI Python assistant. You will be given your previous" + " implementation of a function, a series of unit tests results, and your" + " self-reflection on your previous implementation. Write your full" + " implementation (restate the function signature)." ) PY_REFLEXION_FEW_SHOT_ADD = '''Example 1: [previous impl]: @@ -172,18 +174,19 @@ END EXAMPLES ''' PY_SELF_REFLECTION_CHAT_INSTRUCTION = ( "You are a Python programming assistant. You will be given a function" - " implementation and a series of unit tests. Your goal is to write a few sentences" - " to explain why your implementation is wrong as indicated by the tests. You will" - " need this as a hint when you try again later. Only provide the few sentence" - " description in your answer, not the implementation." + " implementation and a series of unit tests. Your goal is to write a few" + " sentences to explain why your implementation is wrong as indicated by the" + " tests. You will need this as a hint when you try again later. Only" + " provide the few sentence description in your answer, not the" + " implementation." ) PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = ( "You are a Python programming assistant. You will be given a function" - " implementation and a series of unit test results. Your goal is to write a few" - " sentences to explain why your implementation is wrong as indicated by the tests." - " You will need this as guidance when you try again later. Only provide the few" - " sentence description in your answer, not the implementation. You will be given a" - " few examples by the user." + " implementation and a series of unit test results. Your goal is to write a" + " few sentences to explain why your implementation is wrong as indicated by" + " the tests. You will need this as guidance when you try again later. Only" + " provide the few sentence description in your answer, not the" + " implementation. You will be given a few examples by the user." ) PY_SELF_REFLECTION_FEW_SHOT = """Example 1: [function impl]: diff --git a/swarms/prompts/sales.py b/swarms/prompts/sales.py index 4f04f7fc..3a362174 100644 --- a/swarms/prompts/sales.py +++ b/swarms/prompts/sales.py @@ -1,23 +1,26 @@ conversation_stages = { "1": ( - "Introduction: Start the conversation by introducing yourself and your company." - " Be polite and respectful while keeping the tone of the conversation" - " professional. Your greeting should be welcoming. Always clarify in your" - " greeting the reason why you are contacting the prospect." + "Introduction: Start the conversation by introducing yourself and your" + " company. Be polite and respectful while keeping the tone of the" + " conversation professional. Your greeting should be welcoming. Always" + " clarify in your greeting the reason why you are contacting the" + " prospect." ), "2": ( - "Qualification: Qualify the prospect by confirming if they are the right person" - " to talk to regarding your product/service. Ensure that they have the" - " authority to make purchasing decisions." + "Qualification: Qualify the prospect by confirming if they are the" + " right person to talk to regarding your product/service. Ensure that" + " they have the authority to make purchasing decisions." ), "3": ( - "Value proposition: Briefly explain how your product/service can benefit the" - " prospect. Focus on the unique selling points and value proposition of your" - " product/service that sets it apart from competitors." + "Value proposition: Briefly explain how your product/service can" + " benefit the prospect. Focus on the unique selling points and value" + " proposition of your product/service that sets it apart from" + " competitors." ), "4": ( - "Needs analysis: Ask open-ended questions to uncover the prospect's needs and" - " pain points. Listen carefully to their responses and take notes." + "Needs analysis: Ask open-ended questions to uncover the prospect's" + " needs and pain points. Listen carefully to their responses and take" + " notes." ), "5": ( "Solution presentation: Based on the prospect's needs, present your" @@ -29,9 +32,9 @@ conversation_stages = { " testimonials to support your claims." ), "7": ( - "Close: Ask for the sale by proposing a next step. This could be a demo, a" - " trial or a meeting with decision-makers. Ensure to summarize what has been" - " discussed and reiterate the benefits." + "Close: Ask for the sale by proposing a next step. This could be a" + " demo, a trial or a meeting with decision-makers. Ensure to summarize" + " what has been discussed and reiterate the benefits." ), } diff --git a/swarms/prompts/sales_prompts.py b/swarms/prompts/sales_prompts.py index 3f2b9f2b..7c1f50ed 100644 --- a/swarms/prompts/sales_prompts.py +++ b/swarms/prompts/sales_prompts.py @@ -46,24 +46,27 @@ Conversation history: conversation_stages = { "1": ( - "Introduction: Start the conversation by introducing yourself and your company." - " Be polite and respectful while keeping the tone of the conversation" - " professional. Your greeting should be welcoming. Always clarify in your" - " greeting the reason why you are contacting the prospect." + "Introduction: Start the conversation by introducing yourself and your" + " company. Be polite and respectful while keeping the tone of the" + " conversation professional. Your greeting should be welcoming. Always" + " clarify in your greeting the reason why you are contacting the" + " prospect." ), "2": ( - "Qualification: Qualify the prospect by confirming if they are the right person" - " to talk to regarding your product/service. Ensure that they have the" - " authority to make purchasing decisions." + "Qualification: Qualify the prospect by confirming if they are the" + " right person to talk to regarding your product/service. Ensure that" + " they have the authority to make purchasing decisions." ), "3": ( - "Value proposition: Briefly explain how your product/service can benefit the" - " prospect. Focus on the unique selling points and value proposition of your" - " product/service that sets it apart from competitors." + "Value proposition: Briefly explain how your product/service can" + " benefit the prospect. Focus on the unique selling points and value" + " proposition of your product/service that sets it apart from" + " competitors." ), "4": ( - "Needs analysis: Ask open-ended questions to uncover the prospect's needs and" - " pain points. Listen carefully to their responses and take notes." + "Needs analysis: Ask open-ended questions to uncover the prospect's" + " needs and pain points. Listen carefully to their responses and take" + " notes." ), "5": ( "Solution presentation: Based on the prospect's needs, present your" @@ -75,8 +78,8 @@ conversation_stages = { " testimonials to support your claims." ), "7": ( - "Close: Ask for the sale by proposing a next step. This could be a demo, a" - " trial or a meeting with decision-makers. Ensure to summarize what has been" - " discussed and reiterate the benefits." + "Close: Ask for the sale by proposing a next step. This could be a" + " demo, a trial or a meeting with decision-makers. Ensure to summarize" + " what has been discussed and reiterate the benefits." ), } diff --git a/swarms/structs/autoscaler.py b/swarms/structs/autoscaler.py index be79a860..97e8a5ae 100644 --- a/swarms/structs/autoscaler.py +++ b/swarms/structs/autoscaler.py @@ -7,7 +7,11 @@ from typing import Callable, Dict, List from termcolor import colored from swarms.structs.flow import Flow -from swarms.utils.decorators import error_decorator, log_decorator, timing_decorator +from swarms.utils.decorators import ( + error_decorator, + log_decorator, + timing_decorator, +) class AutoScaler: @@ -69,7 +73,9 @@ class AutoScaler: try: self.tasks_queue.put(task) except Exception as error: - print(f"Error adding task to queue: {error} try again with a new task") + print( + f"Error adding task to queue: {error} try again with a new task" + ) @log_decorator @error_decorator @@ -108,10 +114,15 @@ class AutoScaler: if pending_tasks / len(self.agents_pool) > self.busy_threshold: self.scale_up() - elif active_agents / len(self.agents_pool) < self.idle_threshold: + elif ( + active_agents / len(self.agents_pool) < self.idle_threshold + ): self.scale_down() except Exception as error: - print(f"Error monitoring and scaling: {error} try again with a new task") + print( + f"Error monitoring and scaling: {error} try again with a new" + " task" + ) @log_decorator @error_decorator @@ -125,7 +136,9 @@ class AutoScaler: while True: task = self.task_queue.get() if task: - available_agent = next((agent for agent in self.agents_pool)) + available_agent = next( + (agent for agent in self.agents_pool) + ) if available_agent: available_agent.run(task) except Exception as error: diff --git a/swarms/structs/flow.py b/swarms/structs/flow.py index fd359592..18a141a3 100644 --- a/swarms/structs/flow.py +++ b/swarms/structs/flow.py @@ -11,14 +11,9 @@ from termcolor import colored from swarms.utils.code_interpreter import SubprocessCodeInterpreter from swarms.utils.parse_code import extract_code_in_backticks_in_string +from swarms.tools.tool import BaseTool -# Prompts -DYNAMIC_STOP_PROMPT = """ -When you have finished the task from the Human, output a special token: -This will enable you to leave the autonomous loop. -""" - -# Constants +# System prompt FLOW_SYSTEM_PROMPT = f""" You are an autonomous agent granted autonomy in a autonomous loop structure. Your role is to engage in multi-step conversations with your self or the user, @@ -30,6 +25,17 @@ to aid in these complex tasks. Your responses should be coherent, contextually r """ + +# Prompts +DYNAMIC_STOP_PROMPT = """ + +Now, when you 99% sure you have completed the task, you may follow the instructions below to escape the autonomous loop. + +When you have finished the task from the Human, output a special token: +This will enable you to leave the autonomous loop. +""" + + # Make it able to handle multi input tools DYNAMICAL_TOOL_USAGE = """ You have access to the following tools: @@ -46,6 +52,11 @@ commands: { "tool1": "inputs", "tool1": "inputs" } + "tool3: "tool_name", + "params": { + "tool1": "inputs", + "tool1": "inputs" + } } } @@ -53,6 +64,29 @@ commands: { {tools} """ +SCENARIOS = """ +commands: { + "tools": { + tool1: "tool_name", + "params": { + "tool1": "inputs", + "tool1": "inputs" + } + "tool2: "tool_name", + "params": { + "tool1": "inputs", + "tool1": "inputs" + } + "tool3: "tool_name", + "params": { + "tool1": "inputs", + "tool1": "inputs" + } + } +} + +""" + def autonomous_agent_prompt( tools_prompt: str = DYNAMICAL_TOOL_USAGE, @@ -101,9 +135,9 @@ def parse_done_token(response: str) -> bool: class Flow: """ - Flow is a chain like structure from langchain that provides the autonomy to language models - to generate sequential responses. - + Flow is the structure that provides autonomy to any llm in a reliable and effective fashion. + The flow structure is designed to be used with any llm and provides the following features: + Features: * Interactive, AI generates, then user input * Message history and performance history fed -> into context -> truncate if too long @@ -191,7 +225,7 @@ class Flow: def __init__( self, llm: Any, - # template: str, + template: Optional[str] = None, max_loops=5, stopping_condition: Optional[Callable[[str], bool]] = None, loop_interval: int = 1, @@ -205,7 +239,7 @@ class Flow: agent_name: str = " Autonomous Agent XYZ1B", agent_description: str = None, system_prompt: str = FLOW_SYSTEM_PROMPT, - # tools: List[Any] = None, + tools: List[BaseTool] = None, dynamic_temperature: bool = False, sop: str = None, saved_state_path: Optional[str] = "flow_state.json", @@ -217,6 +251,7 @@ class Flow: **kwargs: Any, ): self.llm = llm + self.template = template self.max_loops = max_loops self.stopping_condition = stopping_condition self.loop_interval = loop_interval @@ -238,7 +273,7 @@ class Flow: # The max_loops will be set dynamically if the dynamic_loop if self.dynamic_loops: self.max_loops = "auto" - # self.tools = tools or [] + self.tools = tools or [] self.system_prompt = system_prompt self.agent_name = agent_name self.agent_description = agent_description @@ -302,68 +337,82 @@ class Flow: # # Parse the text for tool usage # pass - # def get_tool_description(self): - # """Get the tool description""" - # tool_descriptions = [] - # for tool in self.tools: - # description = f"{tool.name}: {tool.description}" - # tool_descriptions.append(description) - # return "\n".join(tool_descriptions) - - # def find_tool_by_name(self, name: str): - # """Find a tool by name""" - # for tool in self.tools: - # if tool.name == name: - # return tool - # return None - - # def construct_dynamic_prompt(self): - # """Construct the dynamic prompt""" - # tools_description = self.get_tool_description() - # return DYNAMICAL_TOOL_USAGE.format(tools=tools_description) - - # def extract_tool_commands(self, text: str): - # """ - # Extract the tool commands from the text - - # Example: - # ```json - # { - # "tool": "tool_name", - # "params": { - # "tool1": "inputs", - # "param2": "value2" - # } - # } - # ``` + def get_tool_description(self): + """Get the tool description""" + if self.tools: + try: + tool_descriptions = [] + for tool in self.tools: + description = f"{tool.name}: {tool.description}" + tool_descriptions.append(description) + return "\n".join(tool_descriptions) + except Exception as error: + print( + f"Error getting tool description: {error} try adding a" + " description to the tool or removing the tool" + ) + else: + return "No tools available" - # """ - # # Regex to find JSON like strings - # pattern = r"```json(.+?)```" - # matches = re.findall(pattern, text, re.DOTALL) - # json_commands = [] - # for match in matches: - # try: - # json_commands = json.loads(match) - # json_commands.append(json_commands) - # except Exception as error: - # print(f"Error parsing JSON command: {error}") - - # def parse_and_execute_tools(self, response): - # """Parse and execute the tools""" - # json_commands = self.extract_tool_commands(response) - # for command in json_commands: - # tool_name = command.get("tool") - # params = command.get("parmas", {}) - # self.execute_tool(tool_name, params) - - # def execute_tools(self, tool_name, params): - # """Execute the tool with the provided params""" - # tool = self.tool_find_by_name(tool_name) - # if tool: - # # Execute the tool with the provided parameters - # tool_result = tool.run(**params) - # print(tool_result) + def find_tool_by_name(self, name: str): + """Find a tool by name""" + for tool in self.tools: + if tool.name == name: + return tool + return None + + def construct_dynamic_prompt(self): + """Construct the dynamic prompt""" + tools_description = self.get_tool_description() + + tool_prompt = self.tool_prompt_prep(tools_description, SCENARIOS) + + return tool_prompt + + # return DYNAMICAL_TOOL_USAGE.format(tools=tools_description) + + def extract_tool_commands(self, text: str): + """ + Extract the tool commands from the text + + Example: + ```json + { + "tool": "tool_name", + "params": { + "tool1": "inputs", + "param2": "value2" + } + } + ``` + + """ + # Regex to find JSON like strings + pattern = r"```json(.+?)```" + matches = re.findall(pattern, text, re.DOTALL) + json_commands = [] + for match in matches: + try: + json_commands = json.loads(match) + json_commands.append(json_commands) + except Exception as error: + print(f"Error parsing JSON command: {error}") + + def parse_and_execute_tools(self, response: str): + """Parse and execute the tools""" + json_commands = self.extract_tool_commands(response) + for command in json_commands: + tool_name = command.get("tool") + params = command.get("parmas", {}) + self.execute_tool(tool_name, params) + + def execute_tools(self, tool_name, params): + """Execute the tool with the provided params""" + tool = self.tool_find_by_name(tool_name) + if tool: + # Execute the tool with the provided parameters + tool_result = tool.run(**params) + print(tool_result) def truncate_history(self): """ @@ -431,8 +480,12 @@ class Flow: print(colored("Initializing Autonomous Agent...", "yellow")) # print(colored("Loading modules...", "yellow")) # print(colored("Modules loaded successfully.", "green")) - print(colored("Autonomous Agent Activated.", "cyan", attrs=["bold"])) - print(colored("All systems operational. Executing task...", "green")) + print( + colored("Autonomous Agent Activated.", "cyan", attrs=["bold"]) + ) + print( + colored("All systems operational. Executing task...", "green") + ) except Exception as error: print( colored( @@ -475,16 +528,18 @@ class Flow: self.print_dashboard(task) loop_count = 0 - # for i in range(self.max_loops): while self.max_loops == "auto" or loop_count < self.max_loops: loop_count += 1 - print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) + print( + colored(f"\nLoop {loop_count} of {self.max_loops}", "blue") + ) print("\n") + # Check to see if stopping token is in the output to stop the loop if self.stopping_token: - if self._check_stopping_condition(response) or parse_done_token( + if self._check_stopping_condition( response - ): + ) or parse_done_token(response): break # Adjust temperature, comment if no work @@ -502,111 +557,22 @@ class Flow: **kwargs, ) + # If code interpreter is enabled then run the code if self.code_interpreter: self.run_code(response) - # If there are any tools then parse and execute them - # if self.tools: - # self.parse_and_execute_tools(response) - - if self.interactive: - print(f"AI: {response}") - history.append(f"AI: {response}") - response = input("You: ") - history.append(f"Human: {response}") - else: - print(f"AI: {response}") - history.append(f"AI: {response}") - # print(response) - break - except Exception as e: - logging.error(f"Error generating response: {e}") - attempt += 1 - time.sleep(self.retry_interval) - history.append(response) - time.sleep(self.loop_interval) - self.memory.append(history) - - if self.autosave: - save_path = self.saved_state_path or "flow_state.json" - print(colored(f"Autosaving flow state to {save_path}", "green")) - self.save_state(save_path) - - if self.return_history: - return response, history - - return response - except Exception as error: - print(f"Error running flow: {error}") - raise - - def __call__(self, task: str, **kwargs): - """ - Run the autonomous agent loop - - Args: - task (str): The initial task to run - - Flow: - 1. Generate a response - 2. Check stopping condition - 3. If stopping condition is met, stop - 4. If stopping condition is not met, generate a response - 5. Repeat until stopping condition is met or max_loops is reached - - """ - try: - # dynamic_prompt = self.construct_dynamic_prompt() - # combined_prompt = f"{dynamic_prompt}\n{task}" - # Activate Autonomous agent message - self.activate_autonomous_agent() - - response = task # or combined_prompt - history = [f"{self.user_name}: {task}"] - - # If dashboard = True then print the dashboard - if self.dashboard: - self.print_dashboard(task) - - loop_count = 0 - # for i in range(self.max_loops): - while self.max_loops == "auto" or loop_count < self.max_loops: - loop_count += 1 - print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) - print("\n") - - if self.stopping_token: - if self._check_stopping_condition(response) or parse_done_token( - response - ): - break - - # Adjust temperature, comment if no work - if self.dynamic_temperature: - self.dynamic_temperature() - - # Preparing the prompt - task = self.agent_history_prompt(FLOW_SYSTEM_PROMPT, response) - - attempt = 0 - while attempt < self.retry_attempts: - try: - response = self.llm( - task, - **kwargs, - ) - - if self.code_interpreter: - self.run_code(response) # If there are any tools then parse and execute them - # if self.tools: - # self.parse_and_execute_tools(response) + if self.tools: + self.parse_and_execute_tools(response) + # If interactive mode is enabled then print the response and get user input if self.interactive: print(f"AI: {response}") history.append(f"AI: {response}") response = input("You: ") history.append(f"Human: {response}") + + # If interactive mode is not enabled then print the response else: print(f"AI: {response}") history.append(f"AI: {response}") @@ -616,15 +582,20 @@ class Flow: logging.error(f"Error generating response: {e}") attempt += 1 time.sleep(self.retry_interval) + # Add the response to the history history.append(response) + time.sleep(self.loop_interval) + # Add the history to the memory self.memory.append(history) + # If autosave is enabled then save the state if self.autosave: save_path = self.saved_state_path or "flow_state.json" print(colored(f"Autosaving flow state to {save_path}", "green")) self.save_state(save_path) + # If return history is enabled then return the response and history if self.return_history: return response, history @@ -665,7 +636,9 @@ class Flow: print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) print("\n") - if self._check_stopping_condition(response) or parse_done_token(response): + if self._check_stopping_condition(response) or parse_done_token( + response + ): break # Adjust temperature, comment if no work @@ -985,7 +958,8 @@ class Flow: if hasattr(self.llm, name): value = getattr(self.llm, name) if isinstance( - value, (str, int, float, bool, list, dict, tuple, type(None)) + value, + (str, int, float, bool, list, dict, tuple, type(None)), ): llm_params[name] = value else: @@ -1046,7 +1020,9 @@ class Flow: print(f"Flow state loaded from {file_path}") - def retry_on_failure(self, function, retries: int = 3, retry_delay: int = 1): + def retry_on_failure( + self, function, retries: int = 3, retry_delay: int = 1 + ): """Retry wrapper for LLM calls.""" attempt = 0 while attempt < retries: @@ -1105,7 +1081,7 @@ class Flow: run_code = self.code_executor.run(parsed_code) return run_code - def tool_prompt_prep(self, api_docs: str = None, required_api: str = None): + def tools_prompt_prep(self, docs: str = None, scenarios: str = None): """ Prepare the tool prompt """ @@ -1152,19 +1128,14 @@ class Flow: response. Deliver your response in this format: ‘‘‘ - - Scenario 1: - - Scenario 2: - - Scenario 3: + {scenarios} ‘‘‘ # APIs ‘‘‘ - {api_docs} + {docs} ‘‘‘ # Response - Required API: {required_api} - Scenarios with >=5 API calls: ‘‘‘ - - Scenario 1: """ def self_healing(self, **kwargs): diff --git a/swarms/structs/non_linear_workflow.py b/swarms/structs/non_linear_workflow.py new file mode 100644 index 00000000..79bc0af7 --- /dev/null +++ b/swarms/structs/non_linear_workflow.py @@ -0,0 +1,97 @@ +from swarms.models import OpenAIChat +from swarms.structs.flow import Flow + +import concurrent.futures +from typing import Callable, List, Dict, Any, Sequence + + +class Task: + def __init__( + self, + id: str, + task: str, + flows: Sequence[Flow], + dependencies: List[str] = [], + ): + self.id = id + self.task = task + self.flows = flows + self.dependencies = dependencies + self.results = [] + + def execute(self, parent_results: Dict[str, Any]): + args = [parent_results[dep] for dep in self.dependencies] + for flow in self.flows: + result = flow.run(self.task, *args) + self.results.append(result) + args = [ + result + ] # The output of one flow becomes the input to the next + + +class Workflow: + def __init__(self): + self.tasks: Dict[str, Task] = {} + self.executor = concurrent.futures.ThreadPoolExecutor() + + def add_task(self, task: Task): + self.tasks[task.id] = task + + def run(self): + completed_tasks = set() + while len(completed_tasks) < len(self.tasks): + futures = [] + for task in self.tasks.values(): + if task.id not in completed_tasks and all( + dep in completed_tasks for dep in task.dependencies + ): + future = self.executor.submit( + task.execute, + { + dep: self.tasks[dep].results + for dep in task.dependencies + }, + ) + futures.append((future, task.id)) + + for future, task_id in futures: + future.result() # Wait for task completion + completed_tasks.add(task_id) + + def get_results(self): + return {task_id: task.results for task_id, task in self.tasks.items()} + + +# create flows +llm = OpenAIChat(openai_api_key="sk-") + +flow1 = Flow(llm, max_loops=1) +flow2 = Flow(llm, max_loops=1) +flow3 = Flow(llm, max_loops=1) +flow4 = Flow(llm, max_loops=1) + + +# Create tasks with their respective Flows and task strings +task1 = Task("task1", "Generate a summary on Quantum field theory", [flow1]) +task2 = Task( + "task2", + "Elaborate on the summary of topic X", + [flow2, flow3], + dependencies=["task1"], +) +task3 = Task( + "task3", "Generate conclusions for topic X", [flow4], dependencies=["task1"] +) + +# Create a workflow and add tasks +workflow = Workflow() +workflow.add_task(task1) +workflow.add_task(task2) +workflow.add_task(task3) + +# Run the workflow +workflow.run() + +# Get results +results = workflow.get_results() +print(results) diff --git a/swarms/structs/sequential_workflow.py b/swarms/structs/sequential_workflow.py index d1c600f0..753ada15 100644 --- a/swarms/structs/sequential_workflow.py +++ b/swarms/structs/sequential_workflow.py @@ -113,7 +113,9 @@ class SequentialWorkflow: restore_state_filepath: Optional[str] = None dashboard: bool = False - def add(self, task: str, flow: Union[Callable, Flow], *args, **kwargs) -> None: + def add( + self, task: str, flow: Union[Callable, Flow], *args, **kwargs + ) -> None: """ Add a task to the workflow. @@ -147,6 +149,7 @@ class SequentialWorkflow: return {task.description: task.result for task in self.tasks} def remove_task(self, task_description: str) -> None: + """Remove tasks from sequential workflow""" self.tasks = [ task for task in self.tasks if task.description != task_description ] @@ -182,7 +185,9 @@ class SequentialWorkflow: raise ValueError(f"Task {task_description} not found in workflow.") def save_workflow_state( - self, filepath: Optional[str] = "sequential_workflow_state.json", **kwargs + self, + filepath: Optional[str] = "sequential_workflow_state.json", + **kwargs, ) -> None: """ Saves the workflow state to a json file. @@ -260,10 +265,6 @@ class SequentialWorkflow: -------------------------------- Metadata: kwargs: {kwargs} - - - - """, "cyan", attrs=["bold", "underline"], @@ -352,8 +353,9 @@ class SequentialWorkflow: # Ensure that 'task' is provided in the kwargs if "task" not in task.kwargs: raise ValueError( - "The 'task' argument is required for the Flow flow" - f" execution in '{task.description}'" + "The 'task' argument is required for the" + " Flow flow execution in" + f" '{task.description}'" ) # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") @@ -377,7 +379,9 @@ class SequentialWorkflow: # Autosave the workflow state if self.autosave: - self.save_workflow_state("sequential_workflow_state.json") + self.save_workflow_state( + "sequential_workflow_state.json" + ) except Exception as e: print( colored( @@ -408,8 +412,8 @@ class SequentialWorkflow: # Ensure that 'task' is provided in the kwargs if "task" not in task.kwargs: raise ValueError( - "The 'task' argument is required for the Flow flow" - f" execution in '{task.description}'" + "The 'task' argument is required for the Flow" + f" flow execution in '{task.description}'" ) # Separate the 'task' argument from other kwargs flow_task_arg = task.kwargs.pop("task") @@ -433,4 +437,6 @@ class SequentialWorkflow: # Autosave the workflow state if self.autosave: - self.save_workflow_state("sequential_workflow_state.json") + self.save_workflow_state( + "sequential_workflow_state.json" + ) diff --git a/swarms/swarms/autobloggen.py b/swarms/swarms/autobloggen.py index dec2620f..d732606b 100644 --- a/swarms/swarms/autobloggen.py +++ b/swarms/swarms/autobloggen.py @@ -103,7 +103,9 @@ class AutoBlogGenSwarm: review_agent = self.print_beautifully("Review Agent", review_agent) # Agent that publishes on social media - distribution_agent = self.llm(self.social_media_prompt(article=review_agent)) + distribution_agent = self.llm( + self.social_media_prompt(article=review_agent) + ) distribution_agent = self.print_beautifully( "Distribution Agent", distribution_agent ) @@ -115,7 +117,11 @@ class AutoBlogGenSwarm: for i in range(self.iterations): self.step() except Exception as error: - print(colored(f"Error while running AutoBlogGenSwarm {error}", "red")) + print( + colored( + f"Error while running AutoBlogGenSwarm {error}", "red" + ) + ) if attempt == self.retry_attempts - 1: raise diff --git a/swarms/swarms/base.py b/swarms/swarms/base.py index e99c9b38..1ccc819c 100644 --- a/swarms/swarms/base.py +++ b/swarms/swarms/base.py @@ -117,7 +117,9 @@ class AbstractSwarm(ABC): pass @abstractmethod - def broadcast(self, message: str, sender: Optional["AbstractWorker"] = None): + def broadcast( + self, message: str, sender: Optional["AbstractWorker"] = None + ): """Broadcast a message to all workers""" pass diff --git a/swarms/swarms/dialogue_simulator.py b/swarms/swarms/dialogue_simulator.py index ec86c414..2775daf0 100644 --- a/swarms/swarms/dialogue_simulator.py +++ b/swarms/swarms/dialogue_simulator.py @@ -23,7 +23,9 @@ class DialogueSimulator: >>> model.run("test") """ - def __init__(self, agents: List[Callable], max_iters: int = 10, name: str = None): + def __init__( + self, agents: List[Callable], max_iters: int = 10, name: str = None + ): self.agents = agents self.max_iters = max_iters self.name = name @@ -45,7 +47,8 @@ class DialogueSimulator: for receiver in self.agents: message_history = ( - f"Speaker Name: {speaker.name} and message: {speaker_message}" + f"Speaker Name: {speaker.name} and message:" + f" {speaker_message}" ) receiver.run(message_history) @@ -56,7 +59,9 @@ class DialogueSimulator: print(f"Error running dialogue simulator: {error}") def __repr__(self): - return f"DialogueSimulator({self.agents}, {self.max_iters}, {self.name})" + return ( + f"DialogueSimulator({self.agents}, {self.max_iters}, {self.name})" + ) def save_state(self): """Save the state of the dialogue simulator""" diff --git a/swarms/swarms/god_mode.py b/swarms/swarms/god_mode.py index e75d81d2..65377308 100644 --- a/swarms/swarms/god_mode.py +++ b/swarms/swarms/god_mode.py @@ -64,7 +64,8 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan", ) ) @@ -83,7 +84,8 @@ class GodMode: table.append([f"LLM {i+1}", response]) print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan", ) ) @@ -115,11 +117,13 @@ class GodMode: print(f"{i + 1}. {task}") print("\nLast Responses:") table = [ - [f"LLM {i+1}", response] for i, response in enumerate(self.last_responses) + [f"LLM {i+1}", response] + for i, response in enumerate(self.last_responses) ] print( colored( - tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" + tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), + "cyan", ) ) @@ -137,7 +141,8 @@ class GodMode: """Asynchronous run the task string""" loop = asyncio.get_event_loop() futures = [ - loop.run_in_executor(None, lambda llm: llm(task), llm) for llm in self.llms + loop.run_in_executor(None, lambda llm: llm(task), llm) + for llm in self.llms ] for response in await asyncio.gather(*futures): print(response) @@ -145,13 +150,18 @@ class GodMode: def concurrent_run(self, task: str) -> List[str]: """Synchronously run the task on all llms and collect responses""" with ThreadPoolExecutor() as executor: - future_to_llm = {executor.submit(llm, task): llm for llm in self.llms} + future_to_llm = { + executor.submit(llm, task): llm for llm in self.llms + } responses = [] for future in as_completed(future_to_llm): try: responses.append(future.result()) except Exception as error: - print(f"{future_to_llm[future]} generated an exception: {error}") + print( + f"{future_to_llm[future]} generated an exception:" + f" {error}" + ) self.last_responses = responses self.task_history.append(task) return responses diff --git a/swarms/swarms/groupchat.py b/swarms/swarms/groupchat.py index 5cff3263..76de7e16 100644 --- a/swarms/swarms/groupchat.py +++ b/swarms/swarms/groupchat.py @@ -47,7 +47,9 @@ class GroupChat: def next_agent(self, agent: Flow) -> Flow: """Return the next agent in the list.""" - return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)] + return self.agents[ + (self.agent_names.index(agent.name) + 1) % len(self.agents) + ] def select_speaker_msg(self): """Return the message for selecting the next speaker.""" @@ -78,9 +80,9 @@ class GroupChat: { "role": "system", "content": ( - "Read the above conversation. Then select the next most" - f" suitable role from {self.agent_names} to play. Only" - " return the role." + "Read the above conversation. Then select the next" + f" most suitable role from {self.agent_names} to" + " play. Only return the role." ), } ] @@ -126,7 +128,9 @@ class GroupChatManager: self.selector = selector def __call__(self, task: str): - self.groupchat.messages.append({"role": self.selector.name, "content": task}) + self.groupchat.messages.append( + {"role": self.selector.name, "content": task} + ) for i in range(self.groupchat.max_round): speaker = self.groupchat.select_speaker( last_speaker=self.selector, selector=self.selector diff --git a/swarms/swarms/multi_agent_collab.py b/swarms/swarms/multi_agent_collab.py index ce5a0dd6..98f32d47 100644 --- a/swarms/swarms/multi_agent_collab.py +++ b/swarms/swarms/multi_agent_collab.py @@ -13,8 +13,8 @@ from swarms.utils.logger import logger class BidOutputParser(RegexParser): def get_format_instructions(self) -> str: return ( - "Your response should be an integrater delimited by angled brackets like" - " this: " + "Your response should be an integrater delimited by angled brackets" + " like this: " ) @@ -23,22 +23,6 @@ bid_parser = BidOutputParser( ) -def select_next_speaker_director(step: int, agents, director) -> int: - # if the step if even => director - # => director selects next speaker - if step % 2 == 1: - idx = 0 - else: - idx = director.select_next_speaker() + 1 - return idx - - -# Define a selection function -def select_speaker_round_table(step: int, agents) -> int: - # This function selects the speaker in a round-robin fashion - return step % len(agents) - - # main class MultiAgentCollaboration: """ @@ -49,6 +33,15 @@ class MultiAgentCollaboration: selection_function (callable): The function that selects the next speaker. Defaults to select_next_speaker. max_iters (int): The maximum number of iterations. Defaults to 10. + autosave (bool): Whether to autosave the state of all agents. Defaults to True. + saved_file_path_name (str): The path to the saved file. Defaults to + "multi_agent_collab.json". + stopping_token (str): The token that stops the collaboration. Defaults to + "". + results (list): The results of the collaboration. Defaults to []. + logger (logging.Logger): The logger. Defaults to logger. + logging (bool): Whether to log the collaboration. Defaults to True. + Methods: reset: Resets the state of all agents. @@ -62,18 +55,40 @@ class MultiAgentCollaboration: Usage: - >>> from swarms.models import MultiAgentCollaboration - >>> from swarms.models import Flow >>> from swarms.models import OpenAIChat - >>> from swarms.models import Anthropic - + >>> from swarms.structs import Flow + >>> from swarms.swarms.multi_agent_collab import MultiAgentCollaboration + >>> + >>> # Initialize the language model + >>> llm = OpenAIChat( + >>> temperature=0.5, + >>> ) + >>> + >>> + >>> ## Initialize the workflow + >>> flow = Flow(llm=llm, max_loops=1, dashboard=True) + >>> + >>> # Run the workflow on a task + >>> out = flow.run("Generate a 10,000 word blog on health and wellness.") + >>> + >>> # Initialize the multi-agent collaboration + >>> swarm = MultiAgentCollaboration( + >>> agents=[flow], + >>> max_iters=4, + >>> ) + >>> + >>> # Run the multi-agent collaboration + >>> swarm.run() + >>> + >>> # Format the results of the multi-agent collaboration + >>> swarm.format_results(swarm.results) """ def __init__( self, agents: List[Flow], - selection_function: callable = select_next_speaker_director, + selection_function: callable = None, max_iters: int = 10, autosave: bool = True, saved_file_path_name: str = "multi_agent_collab.json", @@ -165,7 +180,7 @@ class MultiAgentCollaboration: ), retry_error_callback=lambda retry_state: 0, ) - def run(self): + def run_director(self, task: str): """Runs the multi-agent collaboration.""" n = 0 self.reset() @@ -179,10 +194,85 @@ class MultiAgentCollaboration: print("\n") n += 1 + def select_next_speaker_roundtable( + self, step: int, agents: List[Flow] + ) -> int: + """Selects the next speaker.""" + return step % len(agents) + + def select_next_speaker_director( + step: int, agents: List[Flow], director + ) -> int: + # if the step if even => director + # => director selects next speaker + if step % 2 == 1: + idx = 0 + else: + idx = director.select_next_speaker() + 1 + return idx + + # def run(self, task: str): + # """Runs the multi-agent collaboration.""" + # for step in range(self.max_iters): + # speaker_idx = self.select_next_speaker_roundtable(step, self.agents) + # speaker = self.agents[speaker_idx] + # result = speaker.run(task) + # self.results.append({"agent": speaker, "response": result}) + + # if self.autosave: + # self.save_state() + # if result == self.stopping_token: + # break + # return self.results + + # def run(self, task: str): + # for _ in range(self.max_iters): + # for step, agent, in enumerate(self.agents): + # result = agent.run(task) + # self.results.append({"agent": agent, "response": result}) + # if self.autosave: + # self.save_state() + # if result == self.stopping_token: + # break + + # return self.results + + # def run(self, task: str): + # conversation = task + # for _ in range(self.max_iters): + # for agent in self.agents: + # result = agent.run(conversation) + # self.results.append({"agent": agent, "response": result}) + # conversation = result + + # if self.autosave: + # self.save() + # if result == self.stopping_token: + # break + # return self.results + + def run(self, task: str): + conversation = task + for _ in range(self.max_iters): + for agent in self.agents: + result = agent.run(conversation) + self.results.append({"agent": agent, "response": result}) + conversation += result + + if self.autosave: + self.save_state() + if result == self.stopping_token: + break + + return self.results + def format_results(self, results): """Formats the results of the run method""" formatted_results = "\n".join( - [f"{result['agent']} responded: {result['response']}" for result in results] + [ + f"{result['agent']} responded: {result['response']}" + for result in results + ] ) return formatted_results @@ -208,7 +298,12 @@ class MultiAgentCollaboration: return state def __repr__(self): - return f"MultiAgentCollaboration(agents={self.agents}, selection_function={self.select_next_speaker}, max_iters={self.max_iters}, autosave={self.autosave}, saved_file_path_name={self.saved_file_path_name})" + return ( + f"MultiAgentCollaboration(agents={self.agents}," + f" selection_function={self.select_next_speaker}," + f" max_iters={self.max_iters}, autosave={self.autosave}," + f" saved_file_path_name={self.saved_file_path_name})" + ) def performance(self): """Tracks and reports the performance of each agent""" diff --git a/swarms/swarms/orchestrate.py b/swarms/swarms/orchestrate.py index f522911b..b7a7d0e0 100644 --- a/swarms/swarms/orchestrate.py +++ b/swarms/swarms/orchestrate.py @@ -111,7 +111,9 @@ class Orchestrator: self.chroma_client = chromadb.Client() - self.collection = self.chroma_client.create_collection(name=collection_name) + self.collection = self.chroma_client.create_collection( + name=collection_name + ) self.current_tasks = {} @@ -148,13 +150,14 @@ class Orchestrator: ) logging.info( - f"Task {id(str)} has been processed by agent {id(agent)} with" + f"Task {id(str)} has been processed by agent" + f" {id(agent)} with" ) except Exception as error: logging.error( - f"Failed to process task {id(task)} by agent {id(agent)}. Error:" - f" {error}" + f"Failed to process task {id(task)} by agent {id(agent)}." + f" Error: {error}" ) finally: with self.condition: @@ -175,7 +178,9 @@ class Orchestrator: try: # Query the vector database for documents created by the agents - results = self.collection.query(query_texts=[str(agent_id)], n_results=10) + results = self.collection.query( + query_texts=[str(agent_id)], n_results=10 + ) return results except Exception as e: @@ -212,7 +217,9 @@ class Orchestrator: self.collection.add(documents=[result], ids=[str(id(result))]) except Exception as e: - logging.error(f"Failed to append the agent output to database. Error: {e}") + logging.error( + f"Failed to append the agent output to database. Error: {e}" + ) raise def run(self, objective: str): @@ -226,7 +233,9 @@ class Orchestrator: results = [ self.assign_task(agent_id, task) - for agent_id, task in zip(range(len(self.agents)), self.task_queue) + for agent_id, task in zip( + range(len(self.agents)), self.task_queue + ) ] for result in results: diff --git a/swarms/tools/autogpt.py b/swarms/tools/autogpt.py index cf5450e6..07062d11 100644 --- a/swarms/tools/autogpt.py +++ b/swarms/tools/autogpt.py @@ -6,7 +6,9 @@ from typing import Optional import pandas as pd import torch from langchain.agents import tool -from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent +from langchain.agents.agent_toolkits.pandas.base import ( + create_pandas_dataframe_agent, +) from langchain.chains.qa_with_sources.loading import ( BaseCombineDocumentsChain, ) @@ -38,7 +40,10 @@ def pushd(new_dir): @tool def process_csv( - llm, csv_file_path: str, instructions: str, output_path: Optional[str] = None + llm, + csv_file_path: str, + instructions: str, + output_path: Optional[str] = None, ) -> str: """Process a CSV by with pandas in a limited REPL.\ Only use this after writing data to disk as a csv file.\ @@ -49,7 +54,9 @@ def process_csv( df = pd.read_csv(csv_file_path) except Exception as e: return f"Error: {e}" - agent = create_pandas_dataframe_agent(llm, df, max_iterations=30, verbose=False) + agent = create_pandas_dataframe_agent( + llm, df, max_iterations=30, verbose=False + ) if output_path is not None: instructions += f" Save output to disk at {output_path}" try: @@ -79,7 +86,9 @@ async def async_load_playwright(url: str) -> str: text = soup.get_text() lines = (line.strip() for line in text.splitlines()) - chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) + chunks = ( + phrase.strip() for line in lines for phrase in line.split(" ") + ) results = "\n".join(chunk for chunk in chunks if chunk) except Exception as e: results = f"Error: {e}" @@ -110,7 +119,8 @@ def _get_text_splitter(): class WebpageQATool(BaseTool): name = "query_webpage" description = ( - "Browse a webpage and retrieve the information relevant to the question." + "Browse a webpage and retrieve the information relevant to the" + " question." ) text_splitter: RecursiveCharacterTextSplitter = Field( default_factory=_get_text_splitter @@ -176,7 +186,9 @@ def VQAinference(self, inputs): image_path, question = inputs.split(",") raw_image = Image.open(image_path).convert("RGB") - inputs = processor(raw_image, question, return_tensors="pt").to(device, torch_dtype) + inputs = processor(raw_image, question, return_tensors="pt").to( + device, torch_dtype + ) out = model.generate(**inputs) answer = processor.decode(out[0], skip_special_tokens=True) diff --git a/swarms/tools/mm_models.py b/swarms/tools/mm_models.py index 58fe11e5..a218ff50 100644 --- a/swarms/tools/mm_models.py +++ b/swarms/tools/mm_models.py @@ -28,7 +28,9 @@ class MaskFormer: def __init__(self, device): print("Initializing MaskFormer to %s" % device) self.device = device - self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") + self.processor = CLIPSegProcessor.from_pretrained( + "CIDAS/clipseg-rd64-refined" + ) self.model = CLIPSegForImageSegmentation.from_pretrained( "CIDAS/clipseg-rd64-refined" ).to(device) @@ -76,23 +78,26 @@ class ImageEditing: @tool( name="Remove Something From The Photo", description=( - "useful when you want to remove and object or something from the photo " - "from its description or location. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the object need to be removed. " + "useful when you want to remove and object or something from the" + " photo from its description or location. The input to this tool" + " should be a comma separated string of two, representing the" + " image_path and the object need to be removed. " ), ) def inference_remove(self, inputs): image_path, to_be_removed_txt = inputs.split(",") - return self.inference_replace(f"{image_path},{to_be_removed_txt},background") + return self.inference_replace( + f"{image_path},{to_be_removed_txt},background" + ) @tool( name="Replace Something From The Photo", description=( - "useful when you want to replace an object from the object description or" - " location with another object from its description. The input to this tool" - " should be a comma separated string of three, representing the image_path," - " the object to be replaced, the object to be replaced with " + "useful when you want to replace an object from the object" + " description or location with another object from its description." + " The input to this tool should be a comma separated string of" + " three, representing the image_path, the object to be replaced," + " the object to be replaced with " ), ) def inference_replace(self, inputs): @@ -137,10 +142,10 @@ class InstructPix2Pix: @tool( name="Instruct Image Using Text", description=( - "useful when you want to the style of the image to be like the text. " - "like: make it look like a painting. or make it like a robot. " - "The input to this tool should be a comma separated string of two, " - "representing the image_path and the text. " + "useful when you want to the style of the image to be like the" + " text. like: make it look like a painting. or make it like a" + " robot. The input to this tool should be a comma separated string" + " of two, representing the image_path and the text. " ), ) def inference(self, inputs): @@ -149,14 +154,17 @@ class InstructPix2Pix: image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) original_image = Image.open(image_path) image = self.pipe( - text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2 + text, + image=original_image, + num_inference_steps=40, + image_guidance_scale=1.2, ).images[0] updated_image_path = get_new_image_name(image_path, func_name="pix2pix") image.save(updated_image_path) logger.debug( - f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:" - f" {text}, Output Image: {updated_image_path}" + f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct" + f" Text: {text}, Output Image: {updated_image_path}" ) return updated_image_path @@ -173,17 +181,18 @@ class Text2Image: self.pipe.to(device) self.a_prompt = "best quality, extremely detailed" self.n_prompt = ( - "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " - "fewer digits, cropped, worst quality, low quality" + "longbody, lowres, bad anatomy, bad hands, missing fingers, extra" + " digit, fewer digits, cropped, worst quality, low quality" ) @tool( name="Generate Image From User Input Text", description=( - "useful when you want to generate an image from a user input text and save" - " it to a file. like: generate an image of an object or something, or" - " generate an image that includes some objects. The input to this tool" - " should be a string, representing the text used to generate image. " + "useful when you want to generate an image from a user input text" + " and save it to a file. like: generate an image of an object or" + " something, or generate an image that includes some objects. The" + " input to this tool should be a string, representing the text used" + " to generate image. " ), ) def inference(self, text): @@ -205,7 +214,9 @@ class VisualQuestionAnswering: print("Initializing VisualQuestionAnswering to %s" % device) self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.device = device - self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") + self.processor = BlipProcessor.from_pretrained( + "Salesforce/blip-vqa-base" + ) self.model = BlipForQuestionAnswering.from_pretrained( "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype ).to(self.device) @@ -213,10 +224,11 @@ class VisualQuestionAnswering: @tool( name="Answer Question About The Image", description=( - "useful when you need an answer for a question based on an image. like:" - " what is the background color of the last image, how many cats in this" - " figure, what is in this figure. The input to this tool should be a comma" - " separated string of two, representing the image_path and the question" + "useful when you need an answer for a question based on an image." + " like: what is the background color of the last image, how many" + " cats in this figure, what is in this figure. The input to this" + " tool should be a comma separated string of two, representing the" + " image_path and the question" ), ) def inference(self, inputs): @@ -229,8 +241,8 @@ class VisualQuestionAnswering: answer = self.processor.decode(out[0], skip_special_tokens=True) logger.debug( - f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" - f" Question: {question}, Output Answer: {answer}" + f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}," + f" Input Question: {question}, Output Answer: {answer}" ) return answer @@ -245,7 +257,8 @@ class ImageCaptioning(BaseHandler): "Salesforce/blip-image-captioning-base" ) self.model = BlipForConditionalGeneration.from_pretrained( - "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype + "Salesforce/blip-image-captioning-base", + torch_dtype=self.torch_dtype, ).to(self.device) def handle(self, filename: str): @@ -264,8 +277,8 @@ class ImageCaptioning(BaseHandler): out = self.model.generate(**inputs) description = self.processor.decode(out[0], skip_special_tokens=True) print( - f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text:" - f" {description}" + f"\nProcessed ImageCaptioning, Input Image: {filename}, Output" + f" Text: {description}" ) return IMAGE_PROMPT.format(filename=filename, description=description) diff --git a/swarms/tools/tool.py b/swarms/tools/tool.py index a5ad3f75..8ae3b7cd 100644 --- a/swarms/tools/tool.py +++ b/swarms/tools/tool.py @@ -7,7 +7,17 @@ import warnings from abc import abstractmethod from functools import partial from inspect import signature -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( @@ -27,7 +37,11 @@ from pydantic import ( root_validator, validate_arguments, ) -from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable +from langchain.schema.runnable import ( + Runnable, + RunnableConfig, + RunnableSerializable, +) class SchemaAnnotationError(TypeError): @@ -52,7 +66,11 @@ def _get_filtered_args( """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} + return { + k: schema[k] + for k in valid_keys + if k not in ("run_manager", "callbacks") + } class _SchemaConfig: @@ -120,12 +138,11 @@ class ChildTool(BaseTool): ...""" name = cls.__name__ raise SchemaAnnotationError( - f"Tool definition for {name} must include valid type annotations" - " for argument 'args_schema' to behave as expected.\n" - "Expected annotation of 'Type[BaseModel]'" - f" but got '{args_schema_type}'.\n" - "Expected class looks like:\n" - f"{typehint_mandate}" + f"Tool definition for {name} must include valid type" + " annotations for argument 'args_schema' to behave as" + " expected.\nExpected annotation of 'Type[BaseModel]' but" + f" got '{args_schema_type}'.\nExpected class looks" + f" like:\n{typehint_mandate}" ) name: str @@ -147,7 +164,9 @@ class ChildTool(BaseTool): callbacks: Callbacks = Field(default=None, exclude=True) """Callbacks to be called during tool execution.""" - callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) + callback_manager: Optional[BaseCallbackManager] = Field( + default=None, exclude=True + ) """Deprecated. Please use callbacks instead.""" tags: Optional[List[str]] = None """Optional list of tags associated with the tool. Defaults to None @@ -244,7 +263,9 @@ class ChildTool(BaseTool): else: if input_args is not None: result = input_args.parse_obj(tool_input) - return {k: v for k, v in result.dict().items() if k in tool_input} + return { + k: v for k, v in result.dict().items() if k in tool_input + } return tool_input @root_validator() @@ -286,7 +307,9 @@ class ChildTool(BaseTool): *args, ) - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs( + self, tool_input: Union[str, Dict] + ) -> Tuple[Tuple, Dict]: # For backwards compatibility, if run_input is a string, # pass as a positional argument. if isinstance(tool_input, str): @@ -353,8 +376,9 @@ class ChildTool(BaseTool): observation = self.handle_tool_error(e) else: raise ValueError( - "Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}" + "Got unexpected type of `handle_tool_error`. Expected" + " bool, str or callable. Received:" + f" {self.handle_tool_error}" ) run_manager.on_tool_end( str(observation), color="red", name=self.name, **kwargs @@ -409,7 +433,9 @@ class ChildTool(BaseTool): # We then call the tool on the tool input to get an observation tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) observation = ( - await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) + await self._arun( + *tool_args, run_manager=run_manager, **tool_kwargs + ) if new_arg_supported else await self._arun(*tool_args, **tool_kwargs) ) @@ -428,8 +454,9 @@ class ChildTool(BaseTool): observation = self.handle_tool_error(e) else: raise ValueError( - "Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}" + "Got unexpected type of `handle_tool_error`. Expected" + " bool, str or callable. Received:" + f" {self.handle_tool_error}" ) await run_manager.on_tool_end( str(observation), color="red", name=self.name, **kwargs @@ -484,14 +511,17 @@ class Tool(BaseTool): # assume it takes a single string input. return {"tool_input": {"type": "string"}} - def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs( + self, tool_input: Union[str, Dict] + ) -> Tuple[Tuple, Dict]: """Convert tool input to pydantic model.""" args, kwargs = super()._to_args_and_kwargs(tool_input) # For backwards compatibility. The tool must be run with a single input all_args = list(args) + list(kwargs.values()) if len(all_args) != 1: raise ToolException( - f"Too many arguments to single-input tool {self.name}. Args: {all_args}" + f"Too many arguments to single-input tool {self.name}. Args:" + f" {all_args}" ) return tuple(all_args), {} @@ -503,7 +533,9 @@ class Tool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") + new_argument_supported = signature(self.func).parameters.get( + "callbacks" + ) return ( self.func( *args, @@ -537,12 +569,18 @@ class Tool(BaseTool): ) else: return await asyncio.get_running_loop().run_in_executor( - None, partial(self._run, run_manager=run_manager, **kwargs), *args + None, + partial(self._run, run_manager=run_manager, **kwargs), + *args, ) # TODO: this is for backwards compatibility, remove in future def __init__( - self, name: str, func: Optional[Callable], description: str, **kwargs: Any + self, + name: str, + func: Optional[Callable], + description: str, + **kwargs: Any, ) -> None: """Initialize tool.""" super(Tool, self).__init__( @@ -617,7 +655,9 @@ class StructuredTool(BaseTool): ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") + new_argument_supported = signature(self.func).parameters.get( + "callbacks" + ) return ( self.func( *args, @@ -714,7 +754,9 @@ class StructuredTool(BaseTool): description = f"{name}{sig} - {description.strip()}" _args_schema = args_schema if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{name}Schema", source_function) + _args_schema = create_schema_from_function( + f"{name}Schema", source_function + ) return cls( name=name, func=func, @@ -772,7 +814,9 @@ def tool( async def ainvoke_wrapper( callbacks: Optional[Callbacks] = None, **kwargs: Any ) -> Any: - return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) + return await runnable.ainvoke( + kwargs, {"callbacks": callbacks} + ) def invoke_wrapper( callbacks: Optional[Callbacks] = None, **kwargs: Any @@ -821,7 +865,11 @@ def tool( return _make_tool - if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): + if ( + len(args) == 2 + and isinstance(args[0], str) + and isinstance(args[1], Runnable) + ): return _make_with_name(args[0])(args[1]) elif len(args) == 1 and isinstance(args[0], str): # if the argument is a string, then we use the string as the tool name diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py index d5ce3583..b8aca925 100644 --- a/swarms/utils/__init__.py +++ b/swarms/utils/__init__.py @@ -1,4 +1,4 @@ -from swarms.utils.display_markdown import display_markdown_message +from swarms.utils.markdown_message import display_markdown_message from swarms.utils.futures import execute_futures_dict from swarms.utils.code_interpreter import SubprocessCodeInterpreter from swarms.utils.parse_code import extract_code_in_backticks_in_string diff --git a/swarms/utils/apa.py b/swarms/utils/apa.py index 94c6f158..4adcb5cf 100644 --- a/swarms/utils/apa.py +++ b/swarms/utils/apa.py @@ -144,7 +144,9 @@ class Singleton(abc.ABCMeta, type): def __call__(cls, *args, **kwargs): """Call method for the singleton metaclass.""" if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + cls._instances[cls] = super(Singleton, cls).__call__( + *args, **kwargs + ) return cls._instances[cls] diff --git a/swarms/utils/code_interpreter.py b/swarms/utils/code_interpreter.py index 86059a83..fc2f95f7 100644 --- a/swarms/utils/code_interpreter.py +++ b/swarms/utils/code_interpreter.py @@ -116,14 +116,20 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): # Most of the time it doesn't matter, but we should figure out why it happens frequently with: # applescript yield {"output": traceback.format_exc()} - yield {"output": f"Retrying... ({retry_count}/{max_retries})"} + yield { + "output": f"Retrying... ({retry_count}/{max_retries})" + } yield {"output": "Restarting process."} self.start_process() retry_count += 1 if retry_count > max_retries: - yield {"output": "Maximum retries reached. Could not execute code."} + yield { + "output": ( + "Maximum retries reached. Could not execute code." + ) + } return while True: @@ -132,7 +138,9 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): else: time.sleep(0.1) try: - output = self.output_queue.get(timeout=0.3) # Waits for 0.3 seconds + output = self.output_queue.get( + timeout=0.3 + ) # Waits for 0.3 seconds yield output except queue.Empty: if self.done.is_set(): diff --git a/swarms/utils/decorators.py b/swarms/utils/decorators.py index 8a5a5d56..cf4a774c 100644 --- a/swarms/utils/decorators.py +++ b/swarms/utils/decorators.py @@ -31,7 +31,9 @@ def timing_decorator(func): start_time = time.time() result = func(*args, **kwargs) end_time = time.time() - logging.info(f"{func.__name__} executed in {end_time - start_time} seconds") + logging.info( + f"{func.__name__} executed in {end_time - start_time} seconds" + ) return result return wrapper @@ -79,7 +81,9 @@ def synchronized_decorator(func): def deprecated_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - warnings.warn(f"{func.__name__} is deprecated", category=DeprecationWarning) + warnings.warn( + f"{func.__name__} is deprecated", category=DeprecationWarning + ) return func(*args, **kwargs) return wrapper diff --git a/swarms/utils/futures.py b/swarms/utils/futures.py index 55a4e5d5..a5ffdf51 100644 --- a/swarms/utils/futures.py +++ b/swarms/utils/futures.py @@ -5,6 +5,8 @@ T = TypeVar("T") def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]: - futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED) + futures.wait( + fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED + ) return {key: future.result() for key, future in fs_dict.items()} diff --git a/swarms/utils/loggers.py b/swarms/utils/loggers.py index da822d1a..d9845543 100644 --- a/swarms/utils/loggers.py +++ b/swarms/utils/loggers.py @@ -113,8 +113,8 @@ class Logger: ) error_handler.setLevel(logging.ERROR) error_formatter = AutoGptFormatter( - "%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d %(title)s" - " %(message_no_color)s" + "%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d" + " %(title)s %(message_no_color)s" ) error_handler.setFormatter(error_formatter) @@ -140,7 +140,12 @@ class Logger: self.chat_plugins = [] def typewriter_log( - self, title="", title_color="", content="", speak_text=False, level=logging.INFO + self, + title="", + title_color="", + content="", + speak_text=False, + level=logging.INFO, ): """ Logs a message to the typewriter. @@ -255,7 +260,9 @@ class Logger: if isinstance(message, list): message = " ".join(message) self.logger.log( - level, message, extra={"title": str(title), "color": str(title_color)} + level, + message, + extra={"title": str(title), "color": str(title_color)}, ) def set_level(self, level): @@ -284,12 +291,15 @@ class Logger: if not additionalText: additionalText = ( "Please ensure you've setup and configured everything" - " correctly. Read https://github.com/Torantulino/Auto-GPT#readme to " - "double check. You can also create a github issue or join the discord" + " correctly. Read" + " https://github.com/Torantulino/Auto-GPT#readme to double" + " check. You can also create a github issue or join the discord" " and ask there!" ) - self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText) + self.typewriter_log( + "DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText + ) def log_json(self, data: Any, file_name: str) -> None: """ @@ -367,7 +377,9 @@ class TypingConsoleHandler(logging.StreamHandler): print(word, end="", flush=True) if i < len(words) - 1: print(" ", end="", flush=True) - typing_speed = random.uniform(min_typing_speed, max_typing_speed) + typing_speed = random.uniform( + min_typing_speed, max_typing_speed + ) time.sleep(typing_speed) # type faster after each word min_typing_speed = min_typing_speed * 0.95 diff --git a/swarms/utils/main.py b/swarms/utils/main.py index a6c4fc34..73704552 100644 --- a/swarms/utils/main.py +++ b/swarms/utils/main.py @@ -201,7 +201,9 @@ def dim_multiline(message: str) -> str: lines = message.split("\n") if len(lines) <= 1: return lines[0] - return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to(Color.black().bright()) + return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to( + Color.black().bright() + ) # +=============================> ANSI Ending @@ -227,7 +229,9 @@ class AbstractUploader(ABC): class S3Uploader(AbstractUploader): - def __init__(self, accessKey: str, secretKey: str, region: str, bucket: str): + def __init__( + self, accessKey: str, secretKey: str, region: str, bucket: str + ): self.accessKey = accessKey self.secretKey = secretKey self.region = region @@ -338,7 +342,9 @@ class FileHandler: self.handlers = handlers self.path = path - def register(self, filetype: FileType, handler: BaseHandler) -> "FileHandler": + def register( + self, filetype: FileType, handler: BaseHandler + ) -> "FileHandler": self.handlers[filetype] = handler return self @@ -356,7 +362,9 @@ class FileHandler: def handle(self, url: str) -> str: try: - if url.startswith(os.environ.get("SERVER", "http://localhost:8000")): + if url.startswith( + os.environ.get("SERVER", "http://localhost:8000") + ): local_filepath = url[ len(os.environ.get("SERVER", "http://localhost:8000")) + 1 : ] @@ -387,4 +395,4 @@ class FileHandler: # => base end -# ===========================> \ No newline at end of file +# ===========================> diff --git a/swarms/utils/parse_code.py b/swarms/utils/parse_code.py index a2f346ea..9e3b8cb4 100644 --- a/swarms/utils/parse_code.py +++ b/swarms/utils/parse_code.py @@ -7,5 +7,7 @@ def extract_code_in_backticks_in_string(message: str) -> str: """ pattern = r"`` ``(.*?)`` " # Non-greedy match between six backticks - match = re.search(pattern, message, re.DOTALL) # re.DOTALL to match newline chars + match = re.search( + pattern, message, re.DOTALL + ) # re.DOTALL to match newline chars return match.group(1).strip() if match else None diff --git a/swarms/utils/serializable.py b/swarms/utils/serializable.py index 8f0e5ccf..c7f9bc2c 100644 --- a/swarms/utils/serializable.py +++ b/swarms/utils/serializable.py @@ -109,9 +109,11 @@ class Serializable(BaseModel, ABC): "lc": 1, "type": "constructor", "id": [*self.lc_namespace, self.__class__.__name__], - "kwargs": lc_kwargs - if not secrets - else _replace_secrets(lc_kwargs, secrets), + "kwargs": ( + lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets) + ), } def to_json_not_implemented(self) -> SerializedNotImplemented: diff --git a/tests/agents/omni_modal.py b/tests/agents/omni_modal.py index d106f66c..41aa050b 100644 --- a/tests/agents/omni_modal.py +++ b/tests/agents/omni_modal.py @@ -35,4 +35,6 @@ def test_omnimodalagent_run(omni_agent): def test_task_executor_initialization(omni_agent): - assert omni_agent.task_executor is not None, "TaskExecutor initialization failed" + assert ( + omni_agent.task_executor is not None + ), "TaskExecutor initialization failed" diff --git a/tests/memory/oceandb.py b/tests/memory/oceandb.py index 3e31afab..c74b7c15 100644 --- a/tests/memory/oceandb.py +++ b/tests/memory/oceandb.py @@ -30,7 +30,9 @@ def test_create_collection(): def test_create_collection_exception(): with patch("oceandb.Client") as MockClient: - MockClient.create_collection.side_effect = Exception("Create collection error") + MockClient.create_collection.side_effect = Exception( + "Create collection error" + ) db = OceanDB(MockClient) with pytest.raises(Exception) as e: db.create_collection("test", "modality") diff --git a/tests/memory/pinecone.py b/tests/memory/pinecone.py index bd037bef..106a6e81 100644 --- a/tests/memory/pinecone.py +++ b/tests/memory/pinecone.py @@ -6,7 +6,9 @@ api_key = os.getenv("PINECONE_API_KEY") or "" def test_init(): - with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex: + with patch("pinecone.init") as MockInit, patch( + "pinecone.Index" + ) as MockIndex: store = PineconeVectorStore( api_key=api_key, index_name="test_index", environment="test_env" ) diff --git a/tests/models/LLM.py b/tests/models/LLM.py index 20493519..a7ca149f 100644 --- a/tests/models/LLM.py +++ b/tests/models/LLM.py @@ -11,7 +11,9 @@ class TestLLM(unittest.TestCase): @patch.object(ChatOpenAI, "__init__", return_value=None) def setUp(self, mock_hf_init, mock_openai_init): self.llm_openai = LLM(openai_api_key="mock_openai_key") - self.llm_hf = LLM(hf_repo_id="mock_repo_id", hf_api_token="mock_hf_token") + self.llm_hf = LLM( + hf_repo_id="mock_repo_id", hf_api_token="mock_hf_token" + ) self.prompt = "Who won the FIFA World Cup in 1998?" def test_init(self): diff --git a/tests/models/anthropic.py b/tests/models/anthropic.py index e2447614..fecd3585 100644 --- a/tests/models/anthropic.py +++ b/tests/models/anthropic.py @@ -74,7 +74,9 @@ def test_anthropic_default_params(anthropic_instance): } -def test_anthropic_run(mock_anthropic_env, mock_requests_post, anthropic_instance): +def test_anthropic_run( + mock_anthropic_env, mock_requests_post, anthropic_instance +): mock_response = Mock() mock_response.json.return_value = {"completion": "Generated text"} mock_requests_post.return_value = mock_response @@ -98,7 +100,9 @@ def test_anthropic_run(mock_anthropic_env, mock_requests_post, anthropic_instanc ) -def test_anthropic_call(mock_anthropic_env, mock_requests_post, anthropic_instance): +def test_anthropic_call( + mock_anthropic_env, mock_requests_post, anthropic_instance +): mock_response = Mock() mock_response.json.return_value = {"completion": "Generated text"} mock_requests_post.return_value = mock_response @@ -193,18 +197,24 @@ def test_anthropic_convert_prompt(anthropic_instance): def test_anthropic_call_with_stop(anthropic_instance): - response = anthropic_instance("Translate to French.", stop=["stop1", "stop2"]) + response = anthropic_instance( + "Translate to French.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Anthropic" def test_anthropic_stream_with_stop(anthropic_instance): - generator = anthropic_instance.stream("Write a story.", stop=["stop1", "stop2"]) + generator = anthropic_instance.stream( + "Write a story.", stop=["stop1", "stop2"] + ) for token in generator: assert isinstance(token, str) def test_anthropic_async_call_with_stop(anthropic_instance): - response = anthropic_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"]) + response = anthropic_instance.async_call( + "Tell me a joke.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Anthropic" diff --git a/tests/models/auto_temp.py b/tests/models/auto_temp.py index bd37e5bb..76cdc7c3 100644 --- a/tests/models/auto_temp.py +++ b/tests/models/auto_temp.py @@ -47,7 +47,9 @@ def test_run_auto_select(auto_temp_agent): def test_run_no_scores(auto_temp_agent): task = "Invalid task." temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4" - with ThreadPoolExecutor(max_workers=auto_temp_agent.max_workers) as executor: + with ThreadPoolExecutor( + max_workers=auto_temp_agent.max_workers + ) as executor: with patch.object( executor, "submit", side_effect=[None, None, None, None, None, None] ): diff --git a/tests/models/bingchat.py b/tests/models/bingchat.py index ce3af99d..8f29f905 100644 --- a/tests/models/bingchat.py +++ b/tests/models/bingchat.py @@ -44,7 +44,9 @@ class TestBingChat(unittest.TestCase): original_image_gen = BingChat.ImageGen BingChat.ImageGen = MockImageGen - img_path = self.chat.create_img("Test prompt", auth_cookie="mock_auth_cookie") + img_path = self.chat.create_img( + "Test prompt", auth_cookie="mock_auth_cookie" + ) self.assertEqual(img_path, "./output/mock_image.png") BingChat.ImageGen = original_image_gen diff --git a/tests/models/bioclip.py b/tests/models/bioclip.py index 50a65570..54ab5bb9 100644 --- a/tests/models/bioclip.py +++ b/tests/models/bioclip.py @@ -127,7 +127,9 @@ def test_clip_multiple_images(clip_instance, sample_image_path): # Test model inference performance -def test_clip_inference_performance(clip_instance, sample_image_path, benchmark): +def test_clip_inference_performance( + clip_instance, sample_image_path, benchmark +): labels = [ "adenocarcinoma histopathology", "brain MRI", diff --git a/tests/models/biogpt.py b/tests/models/biogpt.py index f420292b..e1daa14e 100644 --- a/tests/models/biogpt.py +++ b/tests/models/biogpt.py @@ -46,7 +46,10 @@ def test_cell_biology_response(biogpt_instance): # 40. Test for a question about protein structure def test_protein_structure_response(biogpt_instance): - question = "What's the difference between alpha helix and beta sheet structures in proteins?" + question = ( + "What's the difference between alpha helix and beta sheet structures in" + " proteins?" + ) response = biogpt_instance(question) assert response and isinstance(response, str) diff --git a/tests/models/cohere.py b/tests/models/cohere.py index d1bea935..08a0e39d 100644 --- a/tests/models/cohere.py +++ b/tests/models/cohere.py @@ -49,7 +49,9 @@ def test_cohere_stream_api_error_handling(cohere_instance): cohere_instance.model = "base" cohere_instance.cohere_api_key = "invalid-api-key" with pytest.raises(Exception): - generator = cohere_instance.stream("Error handling with invalid API key.") + generator = cohere_instance.stream( + "Error handling with invalid API key." + ) for token in generator: pass @@ -94,13 +96,17 @@ def test_cohere_call_with_stop(cohere_instance): def test_cohere_stream_with_stop(cohere_instance): - generator = cohere_instance.stream("Write a story.", stop=["stop1", "stop2"]) + generator = cohere_instance.stream( + "Write a story.", stop=["stop1", "stop2"] + ) for token in generator: assert isinstance(token, str) def test_cohere_async_call_with_stop(cohere_instance): - response = cohere_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"]) + response = cohere_instance.async_call( + "Tell me a joke.", stop=["stop1", "stop2"] + ) assert response == "Mocked Response from Cohere" @@ -187,14 +193,22 @@ def test_cohere_generate_with_embed_english_v2(cohere_instance): def test_cohere_generate_with_embed_english_light_v2(cohere_instance): cohere_instance.model = "embed-english-light-v2.0" - response = cohere_instance("Generate embeddings with English Light v2.0 model.") - assert response.startswith("Generated embeddings with English Light v2.0 model") + response = cohere_instance( + "Generate embeddings with English Light v2.0 model." + ) + assert response.startswith( + "Generated embeddings with English Light v2.0 model" + ) def test_cohere_generate_with_embed_multilingual_v2(cohere_instance): cohere_instance.model = "embed-multilingual-v2.0" - response = cohere_instance("Generate embeddings with Multilingual v2.0 model.") - assert response.startswith("Generated embeddings with Multilingual v2.0 model") + response = cohere_instance( + "Generate embeddings with Multilingual v2.0 model." + ) + assert response.startswith( + "Generated embeddings with Multilingual v2.0 model" + ) def test_cohere_generate_with_embed_english_v3(cohere_instance): @@ -205,14 +219,22 @@ def test_cohere_generate_with_embed_english_v3(cohere_instance): def test_cohere_generate_with_embed_english_light_v3(cohere_instance): cohere_instance.model = "embed-english-light-v3.0" - response = cohere_instance("Generate embeddings with English Light v3.0 model.") - assert response.startswith("Generated embeddings with English Light v3.0 model") + response = cohere_instance( + "Generate embeddings with English Light v3.0 model." + ) + assert response.startswith( + "Generated embeddings with English Light v3.0 model" + ) def test_cohere_generate_with_embed_multilingual_v3(cohere_instance): cohere_instance.model = "embed-multilingual-v3.0" - response = cohere_instance("Generate embeddings with Multilingual v3.0 model.") - assert response.startswith("Generated embeddings with Multilingual v3.0 model") + response = cohere_instance( + "Generate embeddings with Multilingual v3.0 model." + ) + assert response.startswith( + "Generated embeddings with Multilingual v3.0 model" + ) def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance): @@ -423,7 +445,9 @@ def test_cohere_representation_model_classification(cohere_instance): def test_cohere_representation_model_language_detection(cohere_instance): # Test using the Representation model for language detection cohere_instance.model = "embed-english-v3.0" - language = cohere_instance.detect_language("Detect the language of this text.") + language = cohere_instance.detect_language( + "Detect the language of this text." + ) assert isinstance(language, str) @@ -447,7 +471,9 @@ def test_cohere_representation_model_multilingual_embedding(cohere_instance): assert len(embedding) > 0 -def test_cohere_representation_model_multilingual_classification(cohere_instance): +def test_cohere_representation_model_multilingual_classification( + cohere_instance, +): # Test using the Representation model for multilingual text classification cohere_instance.model = "embed-multilingual-v3.0" classification = cohere_instance.classify("Classify multilingual text.") @@ -456,7 +482,9 @@ def test_cohere_representation_model_multilingual_classification(cohere_instance assert "score" in classification -def test_cohere_representation_model_multilingual_language_detection(cohere_instance): +def test_cohere_representation_model_multilingual_language_detection( + cohere_instance, +): # Test using the Representation model for multilingual language detection cohere_instance.model = "embed-multilingual-v3.0" language = cohere_instance.detect_language( @@ -471,12 +499,17 @@ def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded( # Test handling max tokens limit exceeded error for multilingual model cohere_instance.model = "embed-multilingual-v3.0" cohere_instance.max_tokens = 10 - prompt = "This is a test prompt that will exceed the max tokens limit for multilingual model." + prompt = ( + "This is a test prompt that will exceed the max tokens limit for" + " multilingual model." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) -def test_cohere_representation_model_multilingual_light_embedding(cohere_instance): +def test_cohere_representation_model_multilingual_light_embedding( + cohere_instance, +): # Test using the Representation model for multilingual light text embedding cohere_instance.model = "embed-multilingual-light-v3.0" embedding = cohere_instance.embed("Generate multilingual light embeddings.") @@ -484,10 +517,14 @@ def test_cohere_representation_model_multilingual_light_embedding(cohere_instanc assert len(embedding) > 0 -def test_cohere_representation_model_multilingual_light_classification(cohere_instance): +def test_cohere_representation_model_multilingual_light_classification( + cohere_instance, +): # Test using the Representation model for multilingual light text classification cohere_instance.model = "embed-multilingual-light-v3.0" - classification = cohere_instance.classify("Classify multilingual light text.") + classification = cohere_instance.classify( + "Classify multilingual light text." + ) assert isinstance(classification, dict) assert "class" in classification assert "score" in classification @@ -510,7 +547,10 @@ def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceede # Test handling max tokens limit exceeded error for multilingual light model cohere_instance.model = "embed-multilingual-light-v3.0" cohere_instance.max_tokens = 10 - prompt = "This is a test prompt that will exceed the max tokens limit for multilingual light model." + prompt = ( + "This is a test prompt that will exceed the max tokens limit for" + " multilingual light model." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -553,19 +593,26 @@ def test_cohere_representation_model_english_classification(cohere_instance): assert "score" in classification -def test_cohere_representation_model_english_language_detection(cohere_instance): +def test_cohere_representation_model_english_language_detection( + cohere_instance, +): # Test using the Representation model for English language detection cohere_instance.model = "embed-english-v3.0" - language = cohere_instance.detect_language("Detect the language of English text.") + language = cohere_instance.detect_language( + "Detect the language of English text." + ) assert isinstance(language, str) -def test_cohere_representation_model_english_max_tokens_limit_exceeded(cohere_instance): +def test_cohere_representation_model_english_max_tokens_limit_exceeded( + cohere_instance, +): # Test handling max tokens limit exceeded error for English model cohere_instance.model = "embed-english-v3.0" cohere_instance.max_tokens = 10 prompt = ( - "This is a test prompt that will exceed the max tokens limit for English model." + "This is a test prompt that will exceed the max tokens limit for" + " English model." ) with pytest.raises(ValueError): cohere_instance.embed(prompt) @@ -579,7 +626,9 @@ def test_cohere_representation_model_english_light_embedding(cohere_instance): assert len(embedding) > 0 -def test_cohere_representation_model_english_light_classification(cohere_instance): +def test_cohere_representation_model_english_light_classification( + cohere_instance, +): # Test using the Representation model for English light text classification cohere_instance.model = "embed-english-light-v3.0" classification = cohere_instance.classify("Classify English light text.") @@ -588,7 +637,9 @@ def test_cohere_representation_model_english_light_classification(cohere_instanc assert "score" in classification -def test_cohere_representation_model_english_light_language_detection(cohere_instance): +def test_cohere_representation_model_english_light_language_detection( + cohere_instance, +): # Test using the Representation model for English light language detection cohere_instance.model = "embed-english-light-v3.0" language = cohere_instance.detect_language( @@ -603,7 +654,10 @@ def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( # Test handling max tokens limit exceeded error for English light model cohere_instance.model = "embed-english-light-v3.0" cohere_instance.max_tokens = 10 - prompt = "This is a test prompt that will exceed the max tokens limit for English light model." + prompt = ( + "This is a test prompt that will exceed the max tokens limit for" + " English light model." + ) with pytest.raises(ValueError): cohere_instance.embed(prompt) diff --git a/tests/models/dalle3.py b/tests/models/dalle3.py index a23d077e..9b7cf0e1 100644 --- a/tests/models/dalle3.py +++ b/tests/models/dalle3.py @@ -62,7 +62,9 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): def test_dalle3_create_variations_success(dalle3, mock_openai_client): # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + img_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + ) expected_variation_url = ( "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" ) @@ -84,7 +86,9 @@ def test_dalle3_create_variations_success(dalle3, mock_openai_client): def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys): # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + img_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + ) expected_error_message = "Error running Dalle3: API Error" # Mocking OpenAIError @@ -186,7 +190,9 @@ def test_dalle3_call_with_large_input(dalle3, mock_openai_client): assert str(excinfo.value) == expected_error_message -def test_dalle3_create_variations_with_invalid_image_url(dalle3, mock_openai_client): +def test_dalle3_create_variations_with_invalid_image_url( + dalle3, mock_openai_client +): # Arrange img_url = "https://invalid-image-url.com" expected_error_message = "Error running Dalle3: Invalid image URL" @@ -228,7 +234,9 @@ def test_dalle3_call_with_retry(dalle3, mock_openai_client): # Simulate a retry scenario mock_openai_client.images.generate.side_effect = [ - OpenAIError("Temporary error", http_status=500, error="Internal Server Error"), + OpenAIError( + "Temporary error", http_status=500, error="Internal Server Error" + ), Mock(data=[Mock(url=expected_img_url)]), ] @@ -242,14 +250,18 @@ def test_dalle3_call_with_retry(dalle3, mock_openai_client): def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client): # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + img_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + ) expected_variation_url = ( "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" ) # Simulate a retry scenario mock_openai_client.images.create_variation.side_effect = [ - OpenAIError("Temporary error", http_status=500, error="Internal Server Error"), + OpenAIError( + "Temporary error", http_status=500, error="Internal Server Error" + ), Mock(data=[Mock(url=expected_variation_url)]), ] @@ -280,9 +292,13 @@ def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys): assert expected_error_message in captured.err -def test_dalle3_create_variations_exception_logging(dalle3, mock_openai_client, capsys): +def test_dalle3_create_variations_exception_logging( + dalle3, mock_openai_client, capsys +): # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + img_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + ) expected_error_message = "Error running Dalle3: API Error" # Mocking OpenAIError @@ -323,7 +339,9 @@ def test_dalle3_call_no_api_key(): def test_dalle3_create_variations_no_api_key(): # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + img_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + ) dalle3 = Dalle3(api_key=None) expected_error_message = "Error running Dalle3: API Key is missing" @@ -334,7 +352,9 @@ def test_dalle3_create_variations_no_api_key(): assert str(excinfo.value) == expected_error_message -def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client): +def test_dalle3_call_with_retry_max_retries_exceeded( + dalle3, mock_openai_client +): # Arrange task = "A painting of a dog" @@ -354,7 +374,9 @@ def test_dalle3_create_variations_with_retry_max_retries_exceeded( dalle3, mock_openai_client ): # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + img_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + ) # Simulate max retries exceeded mock_openai_client.images.create_variation.side_effect = OpenAIError( @@ -377,7 +399,9 @@ def test_dalle3_call_retry_with_success(dalle3, mock_openai_client): # Simulate success after a retry mock_openai_client.images.generate.side_effect = [ - OpenAIError("Temporary error", http_status=500, error="Internal Server Error"), + OpenAIError( + "Temporary error", http_status=500, error="Internal Server Error" + ), Mock(data=[Mock(url=expected_img_url)]), ] @@ -389,16 +413,22 @@ def test_dalle3_call_retry_with_success(dalle3, mock_openai_client): assert mock_openai_client.images.generate.call_count == 2 -def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client): +def test_dalle3_create_variations_retry_with_success( + dalle3, mock_openai_client +): # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + img_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + ) expected_variation_url = ( "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" ) # Simulate success after a retry mock_openai_client.images.create_variation.side_effect = [ - OpenAIError("Temporary error", http_status=500, error="Internal Server Error"), + OpenAIError( + "Temporary error", http_status=500, error="Internal Server Error" + ), Mock(data=[Mock(url=expected_variation_url)]), ] diff --git a/tests/models/distill_whisper.py b/tests/models/distill_whisper.py index d83caf62..6f95a0e3 100644 --- a/tests/models/distill_whisper.py +++ b/tests/models/distill_whisper.py @@ -44,7 +44,9 @@ async def test_async_transcribe_audio_file(distil_whisper_model): test_data = np.random.rand(16000) # Simulated audio data (1 second) with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file: audio_file_path = create_audio_file(test_data, 16000, audio_file.name) - transcription = await distil_whisper_model.async_transcribe(audio_file_path) + transcription = await distil_whisper_model.async_transcribe( + audio_file_path + ) os.remove(audio_file_path) assert isinstance(transcription, str) @@ -62,7 +64,9 @@ def test_transcribe_audio_data(distil_whisper_model): @pytest.mark.asyncio async def test_async_transcribe_audio_data(distil_whisper_model): test_data = np.random.rand(16000) # Simulated audio data (1 second) - transcription = await distil_whisper_model.async_transcribe(test_data.tobytes()) + transcription = await distil_whisper_model.async_transcribe( + test_data.tobytes() + ) assert isinstance(transcription, str) assert transcription.strip() != "" @@ -73,7 +77,9 @@ def test_real_time_transcribe(distil_whisper_model, capsys): with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file: audio_file_path = create_audio_file(test_data, 16000, audio_file.name) - distil_whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1) + distil_whisper_model.real_time_transcribe( + audio_file_path, chunk_duration=1 + ) os.remove(audio_file_path) @@ -82,7 +88,9 @@ def test_real_time_transcribe(distil_whisper_model, capsys): assert "Chunk" in captured.out -def test_real_time_transcribe_audio_file_not_found(distil_whisper_model, capsys): +def test_real_time_transcribe_audio_file_not_found( + distil_whisper_model, capsys +): audio_file_path = "non_existent_audio.wav" distil_whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1) @@ -145,7 +153,9 @@ def test_create_audio_file(): test_data = np.random.rand(16000) # Simulated audio data (1 second) sample_rate = 16000 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file: - audio_file_path = create_audio_file(test_data, sample_rate, audio_file.name) + audio_file_path = create_audio_file( + test_data, sample_rate, audio_file.name + ) assert os.path.exists(audio_file_path) os.remove(audio_file_path) @@ -215,7 +225,9 @@ async def test_async_transcription_success(whisper_model, audio_file_path): @pytest.mark.asyncio -async def test_async_transcription_failure(whisper_model, invalid_audio_file_path): +async def test_async_transcription_failure( + whisper_model, invalid_audio_file_path +): with pytest.raises(Exception): await whisper_model.async_transcribe(invalid_audio_file_path) @@ -248,17 +260,22 @@ def mocked_model(monkeypatch): model_mock, ) monkeypatch.setattr( - "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", processor_mock + "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", + processor_mock, ) return model_mock, processor_mock @pytest.mark.asyncio -async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path): +async def test_async_transcribe_with_mocked_model( + mocked_model, audio_file_path +): model_mock, processor_mock = mocked_model # Set up what the mock should return when it's called model_mock.return_value.generate.return_value = torch.tensor([[0]]) - processor_mock.return_value.batch_decode.return_value = ["mocked transcription"] + processor_mock.return_value.batch_decode.return_value = [ + "mocked transcription" + ] model_wrapper = DistilWhisperModel() transcription = await model_wrapper.async_transcribe(audio_file_path) assert transcription == "mocked transcription" diff --git a/tests/models/elevenlab.py b/tests/models/elevenlab.py index 7dbcf2ea..986ce937 100644 --- a/tests/models/elevenlab.py +++ b/tests/models/elevenlab.py @@ -61,10 +61,14 @@ def test_run_text_to_speech_with_mock(eleven_labs_tool): def test_run_text_to_speech_error_handling(eleven_labs_tool): with patch("your_module._import_elevenlabs") as mock_elevenlabs: mock_elevenlabs_instance = mock_elevenlabs.return_value - mock_elevenlabs_instance.generate.side_effect = Exception("Test Exception") + mock_elevenlabs_instance.generate.side_effect = Exception( + "Test Exception" + ) with pytest.raises( RuntimeError, - match="Error while running ElevenLabsText2SpeechTool: Test Exception", + match=( + "Error while running ElevenLabsText2SpeechTool: Test Exception" + ), ): eleven_labs_tool.run(SAMPLE_TEXT) diff --git a/tests/models/fuyu.py b/tests/models/fuyu.py index 9a26dbfb..a70cb42a 100644 --- a/tests/models/fuyu.py +++ b/tests/models/fuyu.py @@ -75,7 +75,9 @@ def test_tokenizer_type(fuyu_instance): def test_processor_has_image_processor_and_tokenizer(fuyu_instance): - assert fuyu_instance.processor.image_processor == fuyu_instance.image_processor + assert ( + fuyu_instance.processor.image_processor == fuyu_instance.image_processor + ) assert fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer diff --git a/tests/models/gpt4v.py b/tests/models/gpt4v.py index 23e97d03..8532d313 100644 --- a/tests/models/gpt4v.py +++ b/tests/models/gpt4v.py @@ -4,7 +4,12 @@ from unittest.mock import Mock import pytest from dotenv import load_dotenv -from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout +from requests.exceptions import ( + ConnectionError, + HTTPError, + RequestException, + Timeout, +) from swarms.models.gpt4v import GPT4Vision, GPT4VisionResponse @@ -173,7 +178,11 @@ def test_gpt4vision_call_retry_with_success_after_timeout( Timeout("Request timed out"), { "choices": [ - {"message": {"content": {"text": "A description of the image."}}} + { + "message": { + "content": {"text": "A description of the image."} + } + } ], }, ] @@ -200,16 +209,18 @@ def test_gpt4vision_process_img(): assert img_data.startswith("/9j/") # Base64-encoded image data -def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client): +def test_gpt4vision_call_single_task_single_image( + gpt4vision, mock_openai_client +): # Arrange img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." expected_response = GPT4VisionResponse(answer="A description of the image.") - mock_openai_client.chat.completions.create.return_value.choices[ - 0 - ].text = expected_response.answer + mock_openai_client.chat.completions.create.return_value.choices[0].text = ( + expected_response.answer + ) # Act response = gpt4vision(img_url, [task]) @@ -219,16 +230,21 @@ def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client mock_openai_client.chat.completions.create.assert_called_once() -def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_client): +def test_gpt4vision_call_single_task_multiple_images( + gpt4vision, mock_openai_client +): # Arrange - img_urls = ["https://example.com/image1.jpg", "https://example.com/image2.jpg"] + img_urls = [ + "https://example.com/image1.jpg", + "https://example.com/image2.jpg", + ] task = "Describe these images." expected_response = GPT4VisionResponse(answer="Descriptions of the images.") - mock_openai_client.chat.completions.create.return_value.choices[ - 0 - ].text = expected_response.answer + mock_openai_client.chat.completions.create.return_value.choices[0].text = ( + expected_response.answer + ) # Act response = gpt4vision(img_urls, [task]) @@ -238,7 +254,9 @@ def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_cli mock_openai_client.chat.completions.create.assert_called_once() -def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client): +def test_gpt4vision_call_multiple_tasks_single_image( + gpt4vision, mock_openai_client +): # Arrange img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" tasks = ["Describe this image.", "What's in this picture?"] @@ -249,7 +267,9 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_cli ] def create_mock_response(response): - return {"choices": [{"message": {"content": {"text": response.answer}}}]} + return { + "choices": [{"message": {"content": {"text": response.answer}}}] + } mock_openai_client.chat.completions.create.side_effect = [ create_mock_response(response) for response in expected_responses @@ -279,7 +299,11 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_cli mock_openai_client.chat.completions.create.side_effect = [ { "choices": [ - {"message": {"content": {"text": expected_responses[i].answer}}} + { + "message": { + "content": {"text": expected_responses[i].answer} + } + } ] } for i in range(len(expected_responses)) @@ -295,7 +319,9 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_cli ) # Should be called only once -def test_gpt4vision_call_multiple_tasks_multiple_images(gpt4vision, mock_openai_client): +def test_gpt4vision_call_multiple_tasks_multiple_images( + gpt4vision, mock_openai_client +): # Arrange img_urls = [ "https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", @@ -328,7 +354,9 @@ def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client): img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" task = "Describe this image." - mock_openai_client.chat.completions.create.side_effect = HTTPError("HTTP Error") + mock_openai_client.chat.completions.create.side_effect = HTTPError( + "HTTP Error" + ) # Act and Assert with pytest.raises(HTTPError): diff --git a/tests/models/hf.py b/tests/models/hf.py index ab3b648d..d3ff9a04 100644 --- a/tests/models/hf.py +++ b/tests/models/hf.py @@ -26,7 +26,10 @@ def mock_bitsandbytesconfig(): @pytest.fixture def hugging_face_llm( - mock_torch, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig + mock_torch, + mock_autotokenizer, + mock_automodelforcausallm, + mock_bitsandbytesconfig, ): HuggingFaceLLM.torch = mock_torch HuggingFaceLLM.AutoTokenizer = mock_autotokenizer @@ -70,8 +73,12 @@ def test_init_with_quantize( def test_generate_text(hugging_face_llm): prompt_text = "test prompt" expected_output = "test output" - hugging_face_llm.tokenizer.encode.return_value = torch.tensor([0]) # Mock tensor - hugging_face_llm.model.generate.return_value = torch.tensor([0]) # Mock tensor + hugging_face_llm.tokenizer.encode.return_value = torch.tensor( + [0] + ) # Mock tensor + hugging_face_llm.model.generate.return_value = torch.tensor( + [0] + ) # Mock tensor hugging_face_llm.tokenizer.decode.return_value = expected_output output = hugging_face_llm.generate_text(prompt_text) diff --git a/tests/models/huggingface.py b/tests/models/huggingface.py index 9a27054a..62261b9c 100644 --- a/tests/models/huggingface.py +++ b/tests/models/huggingface.py @@ -84,7 +84,9 @@ def test_llm_initialization_params(model_id, max_length): assert instance.max_length == max_length else: instance = HuggingfaceLLM(model_id=model_id) - assert instance.max_length == 500 # Assuming 500 is the default max_length + assert ( + instance.max_length == 500 + ) # Assuming 500 is the default max_length # Test for setting an invalid device @@ -180,7 +182,8 @@ def test_llm_long_input_warning(mock_warning, llm_instance): # Test for run method behavior when model raises an exception @patch( - "swarms.models.huggingface.HuggingfaceLLM._model.generate", side_effect=RuntimeError + "swarms.models.huggingface.HuggingfaceLLM._model.generate", + side_effect=RuntimeError, ) def test_llm_run_model_exception(mock_generate, llm_instance): with pytest.raises(RuntimeError): @@ -191,7 +194,9 @@ def test_llm_run_model_exception(mock_generate, llm_instance): @patch("torch.cuda.is_available", return_value=False) def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance): with pytest.raises(EnvironmentError): - llm_instance.set_device("cuda") # Attempt to set CUDA when it's not available + llm_instance.set_device( + "cuda" + ) # Attempt to set CUDA when it's not available # Test for proper cleanup after model use (releasing resources) @@ -225,7 +230,9 @@ def test_llm_multilingual_input(mock_run, llm_instance): mock_run.return_value = "mocked multilingual output" multilingual_input = "Bonjour, ceci est un test multilingue." result = llm_instance.run(multilingual_input) - assert isinstance(result, str) # Simple check to ensure output is string type + assert isinstance( + result, str + ) # Simple check to ensure output is string type # Test caching mechanism to prevent re-running the same inputs diff --git a/tests/models/idefics.py b/tests/models/idefics.py index 610657bd..2ee9f010 100644 --- a/tests/models/idefics.py +++ b/tests/models/idefics.py @@ -1,7 +1,11 @@ import pytest from unittest.mock import patch import torch -from swarms.models.idefics import Idefics, IdeficsForVisionText2Text, AutoProcessor +from swarms.models.idefics import ( + Idefics, + IdeficsForVisionText2Text, + AutoProcessor, +) @pytest.fixture @@ -30,7 +34,8 @@ def test_init_default(idefics_instance): ) def test_init_device(device, expected): with patch( - "torch.cuda.is_available", return_value=True if expected == "cuda" else False + "torch.cuda.is_available", + return_value=True if expected == "cuda" else False, ): instance = Idefics(device=device) assert instance.device == expected @@ -39,9 +44,9 @@ def test_init_device(device, expected): # Test `run` method def test_run(idefics_instance): prompts = [["User: Test"]] - with patch.object(idefics_instance, "processor") as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: + with patch.object( + idefics_instance, "processor" + ) as mock_processor, patch.object(idefics_instance, "model") as mock_model: mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -54,9 +59,9 @@ def test_run(idefics_instance): # Test `__call__` method (using the same logic as run for simplicity) def test_call(idefics_instance): prompts = [["User: Test"]] - with patch.object(idefics_instance, "processor") as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: + with patch.object( + idefics_instance, "processor" + ) as mock_processor, patch.object(idefics_instance, "model") as mock_model: mock_processor.return_value = {"input_ids": torch.tensor([1, 2, 3])} mock_model.generate.return_value = torch.tensor([1, 2, 3]) mock_processor.batch_decode.return_value = ["Test"] @@ -85,7 +90,9 @@ def test_set_checkpoint(idefics_instance): ) as mock_from_pretrained, patch.object(AutoProcessor, "from_pretrained"): idefics_instance.set_checkpoint(new_checkpoint) - mock_from_pretrained.assert_called_with(new_checkpoint, torch_dtype=torch.bfloat16) + mock_from_pretrained.assert_called_with( + new_checkpoint, torch_dtype=torch.bfloat16 + ) # Test `set_device` method diff --git a/tests/models/kosmos.py b/tests/models/kosmos.py index 11d224d1..aaa756a3 100644 --- a/tests/models/kosmos.py +++ b/tests/models/kosmos.py @@ -45,7 +45,9 @@ def test_get_image(mock_image_request): # Test multimodal grounding def test_multimodal_grounding(mock_image_request): kosmos = Kosmos() - kosmos.multimodal_grounding("Find the red apple in the image.", TEST_IMAGE_URL) + kosmos.multimodal_grounding( + "Find the red apple in the image.", TEST_IMAGE_URL + ) # TODO: Validate the result if possible @@ -117,7 +119,9 @@ def test_multimodal_grounding(kosmos): @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension(kosmos): - kosmos.referring_expression_comprehension("Show me the green bottle.", IMG_URL2) + kosmos.referring_expression_comprehension( + "Show me the green bottle.", IMG_URL2 + ) @pytest.mark.usefixtures("mock_request_get") @@ -147,7 +151,9 @@ def test_multimodal_grounding_2(kosmos): @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension_2(kosmos): - kosmos.referring_expression_comprehension("Where is the water bottle?", IMG_URL3) + kosmos.referring_expression_comprehension( + "Where is the water bottle?", IMG_URL3 + ) @pytest.mark.usefixtures("mock_request_get") diff --git a/tests/models/kosmos2.py b/tests/models/kosmos2.py index 2ff01092..1ad824cc 100644 --- a/tests/models/kosmos2.py +++ b/tests/models/kosmos2.py @@ -36,7 +36,10 @@ def test_kosmos2_with_sample_image(kosmos2, sample_image): # Mocked extract_entities function for testing def mock_extract_entities(text): - return [("entity1", (0.1, 0.2, 0.3, 0.4)), ("entity2", (0.5, 0.6, 0.7, 0.8))] + return [ + ("entity1", (0.1, 0.2, 0.3, 0.4)), + ("entity2", (0.5, 0.6, 0.7, 0.8)), + ] # Mocked process_entities_to_detections function for testing @@ -54,7 +57,9 @@ def test_kosmos2_with_mocked_extraction_and_detection( ): monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) monkeypatch.setattr( - kosmos2, "process_entities_to_detections", mock_process_entities_to_detections + kosmos2, + "process_entities_to_detections", + mock_process_entities_to_detections, ) detections = kosmos2(img=sample_image) @@ -234,7 +239,12 @@ def test_kosmos2_with_entities_containing_special_characters( kosmos2, sample_image, monkeypatch ): def mock_extract_entities(text): - return [("entity1 with special characters (ü, ö, etc.)", (0.1, 0.2, 0.3, 0.4))] + return [ + ( + "entity1 with special characters (ü, ö, etc.)", + (0.1, 0.2, 0.3, 0.4), + ) + ] monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) detections = kosmos2(img=sample_image) @@ -252,7 +262,10 @@ def test_kosmos2_with_image_containing_multiple_objects( kosmos2, sample_image, monkeypatch ): def mock_extract_entities(text): - return [("entity1", (0.1, 0.2, 0.3, 0.4)), ("entity2", (0.5, 0.6, 0.7, 0.8))] + return [ + ("entity1", (0.1, 0.2, 0.3, 0.4)), + ("entity2", (0.5, 0.6, 0.7, 0.8)), + ] monkeypatch.setattr(kosmos2, "extract_entities", mock_extract_entities) detections = kosmos2(img=sample_image) @@ -266,7 +279,9 @@ def test_kosmos2_with_image_containing_multiple_objects( # Test Kosmos2 with image containing no objects -def test_kosmos2_with_image_containing_no_objects(kosmos2, sample_image, monkeypatch): +def test_kosmos2_with_image_containing_no_objects( + kosmos2, sample_image, monkeypatch +): def mock_extract_entities(text): return [] diff --git a/tests/models/llama_function_caller.py b/tests/models/llama_function_caller.py index c54c264b..c38b2267 100644 --- a/tests/models/llama_function_caller.py +++ b/tests/models/llama_function_caller.py @@ -72,7 +72,9 @@ def test_llama_custom_function_invalid_arguments(llama_caller): # Test streaming with custom runtime def test_llama_custom_runtime(): llama_caller = LlamaFunctionCaller( - model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + model_id="Your-Model-ID", + cache_dir="Your-Cache-Directory", + runtime="cuda", ) user_prompt = "Tell me about the tallest mountain in the world." response = llama_caller(user_prompt) @@ -83,7 +85,9 @@ def test_llama_custom_runtime(): # Test caching functionality def test_llama_cache(): llama_caller = LlamaFunctionCaller( - model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + model_id="Your-Model-ID", + cache_dir="Your-Cache-Directory", + runtime="cuda", ) # Perform a request to populate the cache @@ -99,7 +103,9 @@ def test_llama_cache(): # Test response length within max_tokens limit def test_llama_response_length(): llama_caller = LlamaFunctionCaller( - model_id="Your-Model-ID", cache_dir="Your-Cache-Directory", runtime="cuda" + model_id="Your-Model-ID", + cache_dir="Your-Cache-Directory", + runtime="cuda", ) # Generate a long prompt diff --git a/tests/models/nougat.py b/tests/models/nougat.py index e61a45af..ac972e07 100644 --- a/tests/models/nougat.py +++ b/tests/models/nougat.py @@ -21,7 +21,9 @@ def test_nougat_default_initialization(setup_nougat): def test_nougat_custom_initialization(): - nougat = Nougat(model_name_or_path="custom_path", min_length=10, max_new_tokens=50) + nougat = Nougat( + model_name_or_path="custom_path", min_length=10, max_new_tokens=50 + ) assert nougat.model_name_or_path == "custom_path" assert nougat.min_length == 10 assert nougat.max_new_tokens == 50 @@ -98,7 +100,8 @@ def mock_processor_and_model(): with patch( "transformers.NougatProcessor.from_pretrained", return_value=Mock() ), patch( - "transformers.VisionEncoderDecoderModel.from_pretrained", return_value=Mock() + "transformers.VisionEncoderDecoderModel.from_pretrained", + return_value=Mock(), ): yield diff --git a/tests/models/revgptv1.py b/tests/models/revgptv1.py index 12ceeea0..5908b64e 100644 --- a/tests/models/revgptv1.py +++ b/tests/models/revgptv1.py @@ -19,7 +19,10 @@ class TestRevChatGPT(unittest.TestCase): self.assertLess(self.model.end_time - self.model.start_time, 60) def test_generate_summary(self): - text = "This is a sample text to summarize. It has multiple sentences and details. The summary should be concise." + text = ( + "This is a sample text to summarize. It has multiple sentences and" + " details. The summary should be concise." + ) summary = self.model.generate_summary(text) self.assertLess(len(summary), len(text) / 2) @@ -60,7 +63,9 @@ class TestRevChatGPT(unittest.TestCase): convo_id = "123" title = "New Title" self.model.chatbot.change_title(convo_id, title) - self.assertEqual(self.model.chatbot.get_msg_history(convo_id)["title"], title) + self.assertEqual( + self.model.chatbot.get_msg_history(convo_id)["title"], title + ) def test_delete_conversation(self): convo_id = "123" @@ -76,7 +81,9 @@ class TestRevChatGPT(unittest.TestCase): def test_rollback_conversation(self): original_convo_id = self.model.chatbot.conversation_id self.model.chatbot.rollback_conversation(1) - self.assertNotEqual(original_convo_id, self.model.chatbot.conversation_id) + self.assertNotEqual( + original_convo_id, self.model.chatbot.conversation_id + ) if __name__ == "__main__": diff --git a/tests/models/speech_t5.py b/tests/models/speech_t5.py index 4e5f4cb1..f4d21a30 100644 --- a/tests/models/speech_t5.py +++ b/tests/models/speech_t5.py @@ -17,7 +17,9 @@ def test_speecht5_init(speecht5_model): assert isinstance(speecht5_model.processor, SpeechT5.processor.__class__) assert isinstance(speecht5_model.model, SpeechT5.model.__class__) assert isinstance(speecht5_model.vocoder, SpeechT5.vocoder.__class__) - assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset) + assert isinstance( + speecht5_model.embeddings_dataset, torch.utils.data.Dataset + ) def test_speecht5_call(speecht5_model): @@ -59,8 +61,12 @@ def test_speecht5_set_embeddings_dataset(speecht5_model): new_dataset_name = "Matthijs/cmu-arctic-xvectors-test" speecht5_model.set_embeddings_dataset(new_dataset_name) assert speecht5_model.dataset_name == new_dataset_name - assert isinstance(speecht5_model.embeddings_dataset, torch.utils.data.Dataset) - speecht5_model.set_embeddings_dataset(old_dataset_name) # Restore original dataset + assert isinstance( + speecht5_model.embeddings_dataset, torch.utils.data.Dataset + ) + speecht5_model.set_embeddings_dataset( + old_dataset_name + ) # Restore original dataset def test_speecht5_get_sampling_rate(speecht5_model): diff --git a/tests/models/ssd_1b.py b/tests/models/ssd_1b.py index 7bd3154c..7a7a897f 100644 --- a/tests/models/ssd_1b.py +++ b/tests/models/ssd_1b.py @@ -19,18 +19,24 @@ def test_ssd1b_call(ssd1b_model): neg_prompt = "ugly, blurry, poor quality" image_url = ssd1b_model(task, neg_prompt) assert isinstance(image_url, str) - assert image_url.startswith("https://") # Assuming it starts with "https://" + assert image_url.startswith( + "https://" + ) # Assuming it starts with "https://" # Add more tests for various aspects of the class and methods # Example of a parameterized test for different tasks -@pytest.mark.parametrize("task", ["A painting of a cat", "A painting of a tree"]) +@pytest.mark.parametrize( + "task", ["A painting of a cat", "A painting of a tree"] +) def test_ssd1b_parameterized_task(ssd1b_model, task): image_url = ssd1b_model(task) assert isinstance(image_url, str) - assert image_url.startswith("https://") # Assuming it starts with "https://" + assert image_url.startswith( + "https://" + ) # Assuming it starts with "https://" # Example of a test using mocks to isolate units of code @@ -39,7 +45,9 @@ def test_ssd1b_with_mock(ssd1b_model, mocker): task = "A painting of a cat" image_url = ssd1b_model(task) assert isinstance(image_url, str) - assert image_url.startswith("https://") # Assuming it starts with "https://" + assert image_url.startswith( + "https://" + ) # Assuming it starts with "https://" def test_ssd1b_call_with_cache(ssd1b_model): diff --git a/tests/models/timm_model.py b/tests/models/timm_model.py index a3e62605..07f68b05 100644 --- a/tests/models/timm_model.py +++ b/tests/models/timm_model.py @@ -66,7 +66,9 @@ def test_call_parameterized(model_name, pretrained, in_chans): def test_get_supported_models_mock(): model_handler = TimmModel() - model_handler._get_supported_models = Mock(return_value=["resnet18", "resnet50"]) + model_handler._get_supported_models = Mock( + return_value=["resnet18", "resnet50"] + ) supported_models = model_handler._get_supported_models() assert supported_models == ["resnet18", "resnet50"] @@ -80,7 +82,9 @@ def test_create_model_mock(sample_model_info): def test_call_exception(): model_handler = TimmModel() - model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="invalid_model", pretrained=True, in_chans=3 + ) input_tensor = torch.randn(1, 3, 224, 224) with pytest.raises(Exception): model_handler.__call__(model_info, input_tensor) @@ -111,7 +115,9 @@ def test_environment_variable(): @pytest.mark.slow def test_marked_slow(): model_handler = TimmModel() - model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="resnet18", pretrained=True, in_chans=3 + ) input_tensor = torch.randn(1, 3, 224, 224) output_shape = model_handler(model_info, input_tensor) assert isinstance(output_shape, torch.Size) @@ -136,7 +142,9 @@ def test_marked_parameterized(model_name, pretrained, in_chans): def test_exception_testing(): model_handler = TimmModel() - model_info = TimmModelInfo(model_name="invalid_model", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="invalid_model", pretrained=True, in_chans=3 + ) input_tensor = torch.randn(1, 3, 224, 224) with pytest.raises(Exception): model_handler.__call__(model_info, input_tensor) @@ -144,7 +152,9 @@ def test_exception_testing(): def test_parameterized_testing(): model_handler = TimmModel() - model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="resnet18", pretrained=True, in_chans=3 + ) input_tensor = torch.randn(1, 3, 224, 224) output_shape = model_handler.__call__(model_info, input_tensor) assert isinstance(output_shape, torch.Size) @@ -153,7 +163,9 @@ def test_parameterized_testing(): def test_use_mocks_and_monkeypatching(): model_handler = TimmModel() model_handler._create_model = Mock(return_value=torch.nn.Module()) - model_info = TimmModelInfo(model_name="resnet18", pretrained=True, in_chans=3) + model_info = TimmModelInfo( + model_name="resnet18", pretrained=True, in_chans=3 + ) model = model_handler._create_model(model_info) assert isinstance(model, torch.nn.Module) diff --git a/tests/models/vilt.py b/tests/models/vilt.py index b376f41b..8dcdce88 100644 --- a/tests/models/vilt.py +++ b/tests/models/vilt.py @@ -33,7 +33,9 @@ def test_vilt_prediction(mock_image_open, mock_requests_get, vilt_instance): # 3. Test Exception Handling for network -@patch.object(requests, "get", side_effect=requests.RequestException("Network error")) +@patch.object( + requests, "get", side_effect=requests.RequestException("Network error") +) def test_vilt_network_exception(vilt_instance): with pytest.raises(requests.RequestException): vilt_instance( diff --git a/tests/models/whisperx.py b/tests/models/whisperx.py index bcbd02e9..5fad3431 100644 --- a/tests/models/whisperx.py +++ b/tests/models/whisperx.py @@ -28,7 +28,9 @@ def test_speech_to_text_install(mock_run): # Mock pytube.YouTube and pytube.Streams for download tests @patch("pytube.YouTube") @patch.object(YouTube, "streams") -def test_speech_to_text_download_youtube_video(mock_streams, mock_youtube, temp_dir): +def test_speech_to_text_download_youtube_video( + mock_streams, mock_youtube, temp_dir +): # Mock YouTube and streams video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" mock_stream = mock_streams().filter().first() @@ -116,7 +118,11 @@ def test_speech_to_text_transcribe_whisperx_failure( @patch("whisperx.align") @patch.object(whisperx.DiarizationPipeline, "__call__") def test_speech_to_text_transcribe_missing_segments( - mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model + mock_diarization, + mock_align, + mock_align_model, + mock_load_audio, + mock_load_model, ): # Mock whisperx functions to return incomplete output mock_load_model.return_value = mock_load_model @@ -142,7 +148,11 @@ def test_speech_to_text_transcribe_missing_segments( @patch("whisperx.align") @patch.object(whisperx.DiarizationPipeline, "__call__") def test_speech_to_text_transcribe_align_failure( - mock_diarization, mock_align, mock_align_model, mock_load_audio, mock_load_model + mock_diarization, + mock_align, + mock_align_model, + mock_load_audio, + mock_load_model, ): # Mock whisperx functions to raise an exception during align mock_load_model.return_value = mock_load_model diff --git a/tests/models/yi_200k.py b/tests/models/yi_200k.py index 72a6d1b2..6b179ca1 100644 --- a/tests/models/yi_200k.py +++ b/tests/models/yi_200k.py @@ -39,21 +39,27 @@ def test_yi34b_generate_text_with_temperature(yi34b_model, temperature): def test_yi34b_generate_text_with_invalid_prompt(yi34b_model): prompt = None # Invalid prompt - with pytest.raises(ValueError, match="Input prompt must be a non-empty string"): + with pytest.raises( + ValueError, match="Input prompt must be a non-empty string" + ): yi34b_model(prompt) def test_yi34b_generate_text_with_invalid_max_length(yi34b_model): prompt = "There's a place where time stands still." max_length = -1 # Invalid max_length - with pytest.raises(ValueError, match="max_length must be a positive integer"): + with pytest.raises( + ValueError, match="max_length must be a positive integer" + ): yi34b_model(prompt, max_length=max_length) def test_yi34b_generate_text_with_invalid_temperature(yi34b_model): prompt = "There's a place where time stands still." temperature = 2.0 # Invalid temperature - with pytest.raises(ValueError, match="temperature must be between 0.01 and 1.0"): + with pytest.raises( + ValueError, match="temperature must be between 0.01 and 1.0" + ): yi34b_model(prompt, temperature=temperature) @@ -74,7 +80,9 @@ def test_yi34b_generate_text_with_top_p(yi34b_model, top_p): def test_yi34b_generate_text_with_invalid_top_k(yi34b_model): prompt = "There's a place where time stands still." top_k = -1 # Invalid top_k - with pytest.raises(ValueError, match="top_k must be a non-negative integer"): + with pytest.raises( + ValueError, match="top_k must be a non-negative integer" + ): yi34b_model(prompt, top_k=top_k) @@ -86,7 +94,9 @@ def test_yi34b_generate_text_with_invalid_top_p(yi34b_model): @pytest.mark.parametrize("repitition_penalty", [1.0, 1.2, 1.5]) -def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, repitition_penalty): +def test_yi34b_generate_text_with_repitition_penalty( + yi34b_model, repitition_penalty +): prompt = "There's a place where time stands still." generated_text = yi34b_model(prompt, repitition_penalty=repitition_penalty) assert isinstance(generated_text, str) @@ -95,7 +105,9 @@ def test_yi34b_generate_text_with_repitition_penalty(yi34b_model, repitition_pen def test_yi34b_generate_text_with_invalid_repitition_penalty(yi34b_model): prompt = "There's a place where time stands still." repitition_penalty = 0.0 # Invalid repitition_penalty - with pytest.raises(ValueError, match="repitition_penalty must be a positive float"): + with pytest.raises( + ValueError, match="repitition_penalty must be a positive float" + ): yi34b_model(prompt, repitition_penalty=repitition_penalty) diff --git a/tests/structs/flow.py b/tests/structs/flow.py index edc4b9c7..84034a08 100644 --- a/tests/structs/flow.py +++ b/tests/structs/flow.py @@ -30,7 +30,9 @@ def basic_flow(mocked_llm): @pytest.fixture def flow_with_condition(mocked_llm): - return Flow(llm=mocked_llm, max_loops=5, stopping_condition=stop_when_repeats) + return Flow( + llm=mocked_llm, max_loops=5, stopping_condition=stop_when_repeats + ) # Basic Tests @@ -61,7 +63,9 @@ def test_provide_feedback(basic_flow): @patch("time.sleep", return_value=None) # to speed up tests def test_run_without_stopping_condition(mocked_sleep, basic_flow): response = basic_flow.run("Test task") - assert response == "Test task" # since our mocked llm doesn't modify the response + assert ( + response == "Test task" + ) # since our mocked llm doesn't modify the response @patch("time.sleep", return_value=None) # to speed up tests @@ -108,7 +112,9 @@ def test_flow_with_custom_stopping_condition(mocked_llm): def stopping_condition(x): return "terminate" in x.lower() - flow = Flow(llm=mocked_llm, max_loops=5, stopping_condition=stopping_condition) + flow = Flow( + llm=mocked_llm, max_loops=5, stopping_condition=stopping_condition + ) assert flow.stopping_condition("Please terminate now") assert not flow.stopping_condition("Continue the process") @@ -174,7 +180,9 @@ def test_save_different_memory(basic_flow, tmp_path): # Test the stopping condition check def test_check_stopping_condition(flow_with_condition): assert flow_with_condition._check_stopping_condition("Stop this process") - assert not flow_with_condition._check_stopping_condition("Continue the task") + assert not flow_with_condition._check_stopping_condition( + "Continue the task" + ) # Test without providing max loops (default value should be 5) @@ -370,7 +378,8 @@ def test_flow_autosave_path(flow_instance): def test_flow_response_length(flow_instance): # Test checking the length of the response response = flow_instance.run( - "Generate a 10,000 word long blog on mental clarity and the benefits of meditation." + "Generate a 10,000 word long blog on mental clarity and the benefits of" + " meditation." ) assert len(response) > flow_instance.get_response_length_threshold() @@ -550,7 +559,10 @@ def test_flow_rollback(flow_instance): assert flow_instance.get_user_messages() == state1["user_messages"] assert flow_instance.get_response_history() == state1["response_history"] assert flow_instance.get_conversation_log() == state1["conversation_log"] - assert flow_instance.is_dynamic_pacing_enabled() == state1["dynamic_pacing_enabled"] + assert ( + flow_instance.is_dynamic_pacing_enabled() + == state1["dynamic_pacing_enabled"] + ) assert ( flow_instance.get_response_length_threshold() == state1["response_length_threshold"] @@ -565,7 +577,9 @@ def test_flow_contextual_intent(flow_instance): # Test contextual intent handling flow_instance.add_context("location", "New York") flow_instance.add_context("time", "tomorrow") - response = flow_instance.run("What's the weather like in {location} at {time}?") + response = flow_instance.run( + "What's the weather like in {location} at {time}?" + ) assert "New York" in response assert "tomorrow" in response @@ -689,7 +703,9 @@ def test_flow_clear_injected_messages(flow_instance): def test_flow_disable_message_history(flow_instance): # Test disabling message history recording flow_instance.disable_message_history() - response = flow_instance.run("This message should not be recorded in history.") + response = flow_instance.run( + "This message should not be recorded in history." + ) assert "This message should not be recorded in history." in response assert len(flow_instance.get_message_history()) == 0 # History is empty @@ -800,16 +816,24 @@ def test_flow_input_validation(flow_instance): ) # Invalid logger type, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context(None, "value") # None key, should raise ValueError + flow_instance.add_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context("key", None) # None value, should raise ValueError + flow_instance.add_context( + "key", None + ) # None value, should raise ValueError with pytest.raises(ValueError): - flow_instance.update_context(None, "value") # None key, should raise ValueError + flow_instance.update_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.update_context("key", None) # None value, should raise ValueError + flow_instance.update_context( + "key", None + ) # None value, should raise ValueError def test_flow_conversation_reset(flow_instance): @@ -913,16 +937,24 @@ def test_flow_error_handling(flow_instance): ) # Invalid logger type, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context(None, "value") # None key, should raise ValueError + flow_instance.add_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.add_context("key", None) # None value, should raise ValueError + flow_instance.add_context( + "key", None + ) # None value, should raise ValueError with pytest.raises(ValueError): - flow_instance.update_context(None, "value") # None key, should raise ValueError + flow_instance.update_context( + None, "value" + ) # None key, should raise ValueError with pytest.raises(ValueError): - flow_instance.update_context("key", None) # None value, should raise ValueError + flow_instance.update_context( + "key", None + ) # None value, should raise ValueError def test_flow_context_operations(flow_instance): @@ -1089,7 +1121,9 @@ def test_flow_agent_history_prompt(flow_instance): system_prompt = "This is the system prompt." history = ["User: Hi", "AI: Hello"] - agent_history_prompt = flow_instance.agent_history_prompt(system_prompt, history) + agent_history_prompt = flow_instance.agent_history_prompt( + system_prompt, history + ) assert "SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt assert "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt diff --git a/tests/swarms/godmode.py b/tests/swarms/godmode.py index fa9e0c13..8f528026 100644 --- a/tests/swarms/godmode.py +++ b/tests/swarms/godmode.py @@ -16,7 +16,13 @@ def test_godmode_run(monkeypatch): godmode = GodMode(llms=[LLM] * 5) responses = godmode.run("task1") assert len(responses) == 5 - assert responses == ["response", "response", "response", "response", "response"] + assert responses == [ + "response", + "response", + "response", + "response", + "response", + ] @patch("builtins.print") diff --git a/tests/swarms/groupchat.py b/tests/swarms/groupchat.py index b25e7f91..56979d52 100644 --- a/tests/swarms/groupchat.py +++ b/tests/swarms/groupchat.py @@ -165,7 +165,9 @@ def test_groupchat_select_speaker(): # Simulate selecting the next speaker last_speaker = agent1 - next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is agent2 assert next_speaker == agent2 @@ -183,7 +185,9 @@ def test_groupchat_underpopulated_group(): # Simulate selecting the next speaker in an underpopulated group last_speaker = agent1 - next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is the same as the last speaker in an underpopulated group assert next_speaker == last_speaker @@ -207,7 +211,9 @@ def test_groupchat_max_rounds(): last_speaker = next_speaker # Try one more round, should stay with the last speaker - next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) # Ensure the next speaker is the same as the last speaker after reaching max rounds assert next_speaker == last_speaker diff --git a/tests/swarms/multi_agent_collab.py b/tests/swarms/multi_agent_collab.py index 3f7a0729..bea2c795 100644 --- a/tests/swarms/multi_agent_collab.py +++ b/tests/swarms/multi_agent_collab.py @@ -115,7 +115,10 @@ def test_repr(collaboration): def test_load(collaboration): - state = {"step": 5, "results": [{"agent": "Agent1", "response": "Response1"}]} + state = { + "step": 5, + "results": [{"agent": "Agent1", "response": "Response1"}], + } with open(collaboration.saved_file_path_name, "w") as file: json.dump(state, file) @@ -140,7 +143,8 @@ def test_save(collaboration, tmp_path): # Example of parameterized test for different selection functions @pytest.mark.parametrize( - "selection_function", [select_next_speaker_director, select_speaker_round_table] + "selection_function", + [select_next_speaker_director, select_speaker_round_table], ) def test_selection_functions(collaboration, selection_function): collaboration.select_next_speaker = selection_function diff --git a/tests/swarms/multi_agent_debate.py b/tests/swarms/multi_agent_debate.py index a2687f54..25e15ae5 100644 --- a/tests/swarms/multi_agent_debate.py +++ b/tests/swarms/multi_agent_debate.py @@ -1,5 +1,9 @@ from unittest.mock import patch -from swarms.swarms.multi_agent_debate import MultiAgentDebate, Worker, select_speaker +from swarms.swarms.multi_agent_debate import ( + MultiAgentDebate, + Worker, + select_speaker, +) def test_multiagentdebate_initialization(): @@ -57,5 +61,6 @@ def test_multiagentdebate_format_results(): formatted_results = multiagentdebate.format_results(results) assert ( formatted_results - == "Agent Agent 1 responded: Hello, world!\nAgent Agent 2 responded: Goodbye, world!" + == "Agent Agent 1 responded: Hello, world!\nAgent Agent 2 responded:" + " Goodbye, world!" ) diff --git a/tests/tools/base.py b/tests/tools/base.py index 4f7e2b4b..007719b2 100644 --- a/tests/tools/base.py +++ b/tests/tools/base.py @@ -75,7 +75,9 @@ def test_tool_ainvoke_with_coroutine(): async def async_function(input_data): return input_data - tool = Tool(name="test_tool", coroutine=async_function, description="Test tool") + tool = Tool( + name="test_tool", coroutine=async_function, description="Test tool" + ) result = tool.ainvoke("input_data") assert result == "input_data" @@ -560,7 +562,9 @@ class TestTool: with pytest.raises(ValueError): tool(no_doc_func) - def test_tool_raises_error_runnable_without_object_schema(self, mock_runnable): + def test_tool_raises_error_runnable_without_object_schema( + self, mock_runnable + ): with pytest.raises(ValueError): tool(mock_runnable) diff --git a/tests/utils/subprocess_code_interpreter.py b/tests/utils/subprocess_code_interpreter.py index ab7c748f..c15c0e16 100644 --- a/tests/utils/subprocess_code_interpreter.py +++ b/tests/utils/subprocess_code_interpreter.py @@ -4,7 +4,10 @@ import time import pytest -from swarms.utils.code_interpreter import BaseCodeInterpreter, SubprocessCodeInterpreter +from swarms.utils.code_interpreter import ( + BaseCodeInterpreter, + SubprocessCodeInterpreter, +) @pytest.fixture @@ -53,7 +56,9 @@ def test_subprocess_code_interpreter_run_success(subprocess_code_interpreter): assert any("Hello, World!" in output.get("output", "") for output in result) -def test_subprocess_code_interpreter_run_with_error(subprocess_code_interpreter): +def test_subprocess_code_interpreter_run_with_error( + subprocess_code_interpreter, +): code = 'print("Hello, World")\nraise ValueError("Error!")' result = list(subprocess_code_interpreter.run(code)) assert any("Error!" in output.get("output", "") for output in result) @@ -62,9 +67,14 @@ def test_subprocess_code_interpreter_run_with_error(subprocess_code_interpreter) def test_subprocess_code_interpreter_run_with_keyboard_interrupt( subprocess_code_interpreter, ): - code = 'import time\ntime.sleep(2)\nprint("Hello, World")\nraise KeyboardInterrupt' + code = ( + 'import time\ntime.sleep(2)\nprint("Hello, World")\nraise' + " KeyboardInterrupt" + ) result = list(subprocess_code_interpreter.run(code)) - assert any("KeyboardInterrupt" in output.get("output", "") for output in result) + assert any( + "KeyboardInterrupt" in output.get("output", "") for output in result + ) def test_subprocess_code_interpreter_run_max_retries( @@ -78,7 +88,8 @@ def test_subprocess_code_interpreter_run_max_retries( code = 'print("Hello, World!")' result = list(subprocess_code_interpreter.run(code)) assert any( - "Maximum retries reached. Could not execute code." in output.get("output", "") + "Maximum retries reached. Could not execute code." + in output.get("output", "") for output in result ) @@ -112,19 +123,25 @@ def test_subprocess_code_interpreter_run_retry_on_error( # Import statements and fixtures from the previous code block -def test_subprocess_code_interpreter_line_postprocessor(subprocess_code_interpreter): +def test_subprocess_code_interpreter_line_postprocessor( + subprocess_code_interpreter, +): line = "This is a test line" processed_line = subprocess_code_interpreter.line_postprocessor(line) assert processed_line == line # No processing, should remain the same -def test_subprocess_code_interpreter_preprocess_code(subprocess_code_interpreter): +def test_subprocess_code_interpreter_preprocess_code( + subprocess_code_interpreter, +): code = 'print("Hello, World!")' preprocessed_code = subprocess_code_interpreter.preprocess_code(code) assert preprocessed_code == code # No preprocessing, should remain the same -def test_subprocess_code_interpreter_detect_active_line(subprocess_code_interpreter): +def test_subprocess_code_interpreter_detect_active_line( + subprocess_code_interpreter, +): line = "Active line: 5" active_line = subprocess_code_interpreter.detect_active_line(line) assert active_line == 5 @@ -172,7 +189,9 @@ def test_subprocess_code_interpreter_handle_stream_output_stdout( subprocess_code_interpreter, ): line = "This is a test line" - subprocess_code_interpreter.handle_stream_output(threading.current_thread(), False) + subprocess_code_interpreter.handle_stream_output( + threading.current_thread(), False + ) subprocess_code_interpreter.process.stdout.write(line + "\n") subprocess_code_interpreter.process.stdout.flush() time.sleep(0.1) @@ -184,7 +203,9 @@ def test_subprocess_code_interpreter_handle_stream_output_stderr( subprocess_code_interpreter, ): line = "This is an error line" - subprocess_code_interpreter.handle_stream_output(threading.current_thread(), True) + subprocess_code_interpreter.handle_stream_output( + threading.current_thread(), True + ) subprocess_code_interpreter.process.stderr.write(line + "\n") subprocess_code_interpreter.process.stderr.flush() time.sleep(0.1) @@ -207,12 +228,13 @@ def test_subprocess_code_interpreter_run_with_exception( subprocess_code_interpreter, capsys ): code = 'print("Hello, World!")' - subprocess_code_interpreter.start_cmd = ( - "nonexistent_command" # Force an exception during subprocess creation + subprocess_code_interpreter.start_cmd = ( # Force an exception during subprocess creation + "nonexistent_command" ) result = list(subprocess_code_interpreter.run(code)) assert any( - "Maximum retries reached" in output.get("output", "") for output in result + "Maximum retries reached" in output.get("output", "") + for output in result ) @@ -245,4 +267,6 @@ def test_subprocess_code_interpreter_run_with_unicode_characters( ): code = 'print("こんにちは、世界")' # Contains unicode characters result = list(subprocess_code_interpreter.run(code)) - assert any("こんにちは、世界" in output.get("output", "") for output in result) + assert any( + "こんにちは、世界" in output.get("output", "") for output in result + ) diff --git a/tests/workers/multi_model_worker.py b/tests/workers/multi_model_worker.py deleted file mode 100644 index f011d642..00000000 --- a/tests/workers/multi_model_worker.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest -from unittest.mock import Mock -from swarms.agents.multi_modal_agent import ( - MultiModalVisualAgent, - MultiModalVisualAgentTool, -) - - -@pytest.fixture -def multimodal_agent(): - # Mock the MultiModalVisualAgent - mock_agent = Mock(spec=MultiModalVisualAgent) - mock_agent.run_text.return_value = "Expected output from agent" - return mock_agent - - -@pytest.fixture -def multimodal_agent_tool(multimodal_agent): - # Use the mocked MultiModalVisualAgent in the MultiModalVisualAgentTool - return MultiModalVisualAgentTool(multimodal_agent) - - -@pytest.mark.parametrize( - "text_input, expected_output", - [ - ("Hello, world!", "Expected output from agent"), - ("Another task", "Expected output from agent"), - ], -) -def test_run(multimodal_agent_tool, text_input, expected_output): - assert multimodal_agent_tool._run(text_input) == expected_output - - # You can also test if the MultiModalVisualAgent's run_text method was called with the right argument - multimodal_agent_tool.agent.run_text.assert_called_with(text_input) diff --git a/tests/workers/omni_worker.py b/tests/workers/omni_worker.py deleted file mode 100644 index 0557285d..00000000 --- a/tests/workers/omni_worker.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest - -from swarms.worker.omni_worker import OmniWorkerAgent - - -@pytest.fixture -def omni_worker(): - api_key = "test-key" - api_endpoint = "test-endpoint" - api_type = "test-type" - return OmniWorkerAgent(api_key, api_endpoint, api_type) - - -@pytest.mark.parametrize( - "data, expected_response", - [ - ( - { - "messages": ["Hello"], - "api_key": "key1", - "api_type": "type1", - "api_endpoint": "endpoint1", - }, - {"response": "Hello back from Huggingface!"}, - ), - ( - { - "messages": ["Goodbye"], - "api_key": "key2", - "api_type": "type2", - "api_endpoint": "endpoint2", - }, - {"response": "Goodbye from Huggingface!"}, - ), - ], -) -def test_chat_valid_data(mocker, omni_worker, data, expected_response): - mocker.patch( - "yourmodule.chat_huggingface", return_value=expected_response - ) # replace 'yourmodule' with actual module name - assert omni_worker.chat(data) == expected_response - - -@pytest.mark.parametrize( - "invalid_data", - [ - {"messages": ["Hello"]}, # missing api_key, api_type and api_endpoint - {"messages": ["Hello"], "api_key": "key1"}, # missing api_type and api_endpoint - { - "messages": ["Hello"], - "api_key": "key1", - "api_type": "type1", - }, # missing api_endpoint - ], -) -def test_chat_invalid_data(omni_worker, invalid_data): - with pytest.raises(ValueError): - omni_worker.chat(invalid_data) diff --git a/tests/workers/worker_agent_ultra.py b/tests/workers/worker_agent_ultra.py deleted file mode 100644 index 3cf112a2..00000000 --- a/tests/workers/worker_agent_ultra.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -from unittest.mock import Mock -from swarms.workers.worker_agent_ultra import WorkerUltraNode # import your module here - - -def test_create_agent(): - mock_llm = Mock() - mock_toolset = {"test_toolset": Mock()} - mock_vectorstore = Mock() - worker = WorkerUltraNode(mock_llm, mock_toolset, mock_vectorstore) - worker.create_agent() - assert worker.agent is not None - - -@pytest.mark.parametrize("invalid_toolset", [123, "string", 0.45]) -def test_add_toolset_invalid(invalid_toolset): - mock_llm = Mock() - mock_toolset = {"test_toolset": Mock()} - mock_vectorstore = Mock() - worker = WorkerUltraNode(mock_llm, mock_toolset, mock_vectorstore) - with pytest.raises(TypeError): - worker.add_toolset(invalid_toolset) - - -@pytest.mark.parametrize("invalid_prompt", [123, None, "", []]) -def test_run_invalid_prompt(invalid_prompt): - mock_llm = Mock() - mock_toolset = {"test_toolset": Mock()} - mock_vectorstore = Mock() - worker = WorkerUltraNode(mock_llm, mock_toolset, mock_vectorstore) - with pytest.raises((TypeError, ValueError)): - worker.run(invalid_prompt) - - -def test_run_valid_prompt(mocker): - mock_llm = Mock() - mock_toolset = {"test_toolset": Mock()} - mock_vectorstore = Mock() - worker = WorkerUltraNode(mock_llm, mock_toolset, mock_vectorstore) - mocker.patch.object(worker, "create_agent") - assert worker.run("Test prompt") == "Task completed by WorkerNode" - - -def test_worker_node(): - worker = worker_ultra_node("test-key") - assert isinstance(worker, WorkerUltraNode) - - -def test_worker_node_no_key(): - with pytest.raises(ValueError): - worker_ultra_node(None) diff --git a/tests/workers/worker_node.py b/tests/workers/worker_node.py deleted file mode 100644 index e97b5023..00000000 --- a/tests/workers/worker_node.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest -from unittest.mock import MagicMock, patch -from swarms.worker.worker_node import ( - WorkerNodeInitializer, - WorkerNode, -) # replace your_module with actual module name - - -# Mock Tool for testing -class MockTool(Tool): - pass - - -# Fixture for llm -@pytest.fixture -def mock_llm(): - return MagicMock() - - -# Fixture for vectorstore -@pytest.fixture -def mock_vectorstore(): - return MagicMock() - - -# Fixture for Tools -@pytest.fixture -def mock_tools(): - return [MockTool(), MockTool(), MockTool()] - - -# Fixture for WorkerNodeInitializer -@pytest.fixture -def worker_node(mock_llm, mock_tools, mock_vectorstore): - return WorkerNodeInitializer( - llm=mock_llm, tools=mock_tools, vectorstore=mock_vectorstore - ) - - -# Fixture for WorkerNode -@pytest.fixture -def mock_worker_node(): - return WorkerNode(openai_api_key="test_api_key") - - -# WorkerNodeInitializer Tests -def test_worker_node_init(worker_node): - assert worker_node.llm is not None - assert worker_node.tools is not None - assert worker_node.vectorstore is not None - - -def test_worker_node_create_agent(worker_node): - with patch.object(AutoGPT, "from_llm_and_tools") as mock_method: - worker_node.create_agent() - mock_method.assert_called_once() - - -def test_worker_node_add_tool(worker_node): - initial_tools_count = len(worker_node.tools) - new_tool = MockTool() - worker_node.add_tool(new_tool) - assert len(worker_node.tools) == initial_tools_count + 1 - - -def test_worker_node_run(worker_node): - with patch.object(worker_node.agent, "run") as mock_run: - worker_node.run(prompt="test prompt") - mock_run.assert_called_once() - - -# WorkerNode Tests -def test_worker_node_llm(mock_worker_node): - with patch.object(mock_worker_node, "initialize_llm") as mock_method: - mock_worker_node.initialize_llm(llm_class=MagicMock(), temperature=0.5) - mock_method.assert_called_once() - - -def test_worker_node_tools(mock_worker_node): - with patch.object(mock_worker_node, "initialize_tools") as mock_method: - mock_worker_node.initialize_tools(llm_class=MagicMock()) - mock_method.assert_called_once() - - -def test_worker_node_vectorstore(mock_worker_node): - with patch.object(mock_worker_node, "initialize_vectorstore") as mock_method: - mock_worker_node.initialize_vectorstore() - mock_method.assert_called_once() - - -def test_worker_node_create_worker_node(mock_worker_node): - with patch.object(mock_worker_node, "create_worker_node") as mock_method: - mock_worker_node.create_worker_node() - mock_method.assert_called_once() diff --git a/tests/workers/worker_ultra.py b/tests/workers/worker_ultra.py deleted file mode 100644 index b1485a28..00000000 --- a/tests/workers/worker_ultra.py +++ /dev/null @@ -1,91 +0,0 @@ -import pytest -from unittest.mock import Mock, patch -from swarms.workers.worker_agent_ultra import ( - WorkerUltraNode, - WorkerUltraNodeInitializer, -) - - -@pytest.fixture -def llm_mock(): - return Mock() - - -@pytest.fixture -def toolsets_mock(): - return Mock() - - -@pytest.fixture -def vectorstore_mock(): - return Mock() - - -@pytest.fixture -def worker_ultra_node(llm_mock, toolsets_mock, vectorstore_mock): - return WorkerUltraNode(llm_mock, toolsets_mock, vectorstore_mock) - - -def test_worker_ultra_node_create_agent(worker_ultra_node): - with patch("yourmodule.AutoGPT.from_llm_and_tools") as mock_method: - worker_ultra_node.create_agent() - mock_method.assert_called_once() - - -def test_worker_ultra_node_add_toolset(worker_ultra_node): - with pytest.raises(TypeError): - worker_ultra_node.add_toolset("wrong_toolset") - - -def test_worker_ultra_node_run(worker_ultra_node): - with patch.object(worker_ultra_node, "agent") as mock_agent: - mock_agent.run.return_value = None - result = worker_ultra_node.run("some prompt") - assert result == "Task completed by WorkerNode" - mock_agent.run.assert_called_once() - - -def test_worker_ultra_node_run_no_prompt(worker_ultra_node): - with pytest.raises(ValueError): - worker_ultra_node.run("") - - -@pytest.fixture -def worker_ultra_node_initializer(): - return WorkerUltraNodeInitializer("openai_api_key") - - -def test_worker_ultra_node_initializer_initialize_llm(worker_ultra_node_initializer): - with patch("yourmodule.ChatOpenAI") as mock_llm: - worker_ultra_node_initializer.initialize_llm(mock_llm) - mock_llm.assert_called_once() - - -def test_worker_ultra_node_initializer_initialize_toolsets( - worker_ultra_node_initializer, -): - with patch("yourmodule.Terminal"), patch("yourmodule.CodeEditor"), patch( - "yourmodule.RequestsGet" - ), patch("yourmodule.ExitConversation"): - toolsets = worker_ultra_node_initializer.initialize_toolsets() - assert len(toolsets) == 4 - - -def test_worker_ultra_node_initializer_initialize_vectorstore( - worker_ultra_node_initializer, -): - with patch("yourmodule.OpenAIEmbeddings"), patch( - "yourmodule.fauss.IndexFlatL2" - ), patch("yourmodule.FAISS"), patch("yourmodule.InMemoryDocstore"): - vectorstore = worker_ultra_node_initializer.initialize_vectorstore() - assert vectorstore is not None - - -def test_worker_ultra_node_initializer_create_worker_node( - worker_ultra_node_initializer, -): - with patch.object(worker_ultra_node_initializer, "initialize_llm"), patch.object( - worker_ultra_node_initializer, "initialize_toolsets" - ), patch.object(worker_ultra_node_initializer, "initialize_vectorstore"): - worker_node = worker_ultra_node_initializer.create_worker_node() - assert worker_node is not None