diff --git a/.github/actions/init_environment/action.yml b/.github/actions/init_environment/action.yml new file mode 100644 index 00000000..f2f9016c --- /dev/null +++ b/.github/actions/init_environment/action.yml @@ -0,0 +1,37 @@ +name: "Init Environment" +description: "Initialize environment for tests" +runs: + using: "composite" + steps: + - name: Checkout actions + uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install and configure Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} + + - name: Install dependencies + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + run: poetry install --no-interaction --no-root --with test --with dev --all-extras + shell: bash + + - name: Activate venv + run: | + source .venv/bin/activate + echo PATH=$PATH >> $GITHUB_ENV + shell: bash \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8a6f374c..b38491c2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,5 +1,3 @@ ---- -# This is a basic workflow to help you get started with Actions name: Lint diff --git a/.github/workflows/pr_request_checks.yml b/.github/workflows/pr_request_checks.yml index 6c9cb0b2..cfc01afb 100644 --- a/.github/workflows/pr_request_checks.yml +++ b/.github/workflows/pr_request_checks.yml @@ -1,4 +1,3 @@ ---- name: Pull Request Checks on: @@ -22,6 +21,7 @@ jobs: - name: Install dependencies run: | pip install -r requirements.txt + pip install swarms pip install pytest - name: Run tests and checks diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index aa3edc3e..adc0c5ef 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest + pip install flake8 pytest swarms if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index ae647f57..31cdfb93 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -28,7 +28,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install --upgrade swarms - python -m pip install flake8 pytest + python -m pip install flake8 pytest swarms if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 6a0f06c5..889774f0 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -1,4 +1,3 @@ ---- name: Upload Python Package on: # yamllint disable-line rule:truthy diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d9dafc76..baaadfc8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,3 @@ ---- name: test on: @@ -10,6 +9,7 @@ on: env: POETRY_VERSION: "1.4.2" +jobs: test: runs-on: ubuntu-latest strategy: @@ -30,7 +30,7 @@ env: python-version: ${{ matrix.python-version }} poetry-version: "1.4.2" cache-key: ${{ matrix.test_type }} - install-command: | + install-command: if [ "${{ matrix.test_type }}" == "core" ]; then echo "Running core tests, installing dependencies with poetry..." poetry install diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index ccb5a6b9..c858b5fe 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -16,7 +16,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.x + python-version: "3.10" - name: Install dependencies run: | @@ -24,4 +24,4 @@ jobs: pip install pytest - name: Run unit tests - run: pytest \ No newline at end of file + run: pytest diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 1e4f0f1e..a858dae4 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -24,6 +24,7 @@ jobs: run: | pip install -r requirements.txt pip install pytest + pip install swarms - name: Run Python unit tests run: pytest diff --git a/README.md b/README.md index e52b872d..4bf0fc06 100644 --- a/README.md +++ b/README.md @@ -312,6 +312,97 @@ efficiency_analysis = efficiency_agent.run( factory_image, ) ``` + +### Gemini +- Deploy Gemini from Google with utmost reliability with our visual chain of thought prompt that enables more reliable responses +```python +import os + +from dotenv import load_dotenv + +from swarms.models import Gemini +from swarms.prompts.visual_cot import VISUAL_CHAIN_OF_THOUGHT + +# Load the environment variables +load_dotenv() + +# Get the API key from the environment +api_key = os.environ.get("GEMINI_API_KEY") + +# Initialize the language model +llm = Gemini( + gemini_api_key=api_key, + temperature=0.5, + max_tokens=1000, + system_prompt=VISUAL_CHAIN_OF_THOUGHT, +) + +# Initialize the task +task = "This is an eye test. What do you see?" +img = "playground/demos/multi_modal_chain_of_thought/eyetest.jpg" + +# Run the workflow on a task +out = llm.run(task=task, img=img) +print(out) +``` + + +### `Anthropic` +```python +# Import necessary modules and classes +from swarms.models import Anthropic + +# Initialize an instance of the Anthropic class +model = Anthropic( + anthropic_api_key="" +) + +# Using the run method +completion_1 = model.run("What is the capital of France?") +print(completion_1) + +# Using the __call__ method +completion_2 = model("How far is the moon from the earth?", stop=["miles", "km"]) +print(completion_2) + +``` + + +### `HuggingFaceLLM` +```python +from swarms.models import HuggingfaceLLM + +# Initialize with custom configuration +custom_config = { + "quantize": True, + "quantization_config": {"load_in_4bit": True}, + "verbose": True +} +inference = HuggingfaceLLM(model_id="NousResearch/Nous-Hermes-2-Vision-Alpha", **custom_config) + +# Generate text based on a prompt +prompt_text = "Create a list of known biggest risks of structural collapse with references" +generated_text = inference(prompt_text) +print(generated_text) +``` + +### Mixtral +- Utilize Mixtral in a very simple API, +- Utilize 4bit quantization for a increased speed and less memory usage +- Use Flash Attention 2.0 for increased speed and less memory usage +```python +from swarms.models import Mixtral + +# Initialize the Mixtral model with 4 bit and flash attention! +mixtral = Mixtral(load_in_4bit=True, use_flash_attention_2=True) + +# Generate text for a simple task +generated_text = mixtral.run("Generate a creative story.") + +# Print the generated text +print(generated_text) +``` + --- # Features 🤖 diff --git a/docs/swarms/memory/pinecone.md b/docs/swarms/memory/pinecone.md index 11f9a018..830d10fe 100644 --- a/docs/swarms/memory/pinecone.md +++ b/docs/swarms/memory/pinecone.md @@ -1,4 +1,4 @@ -# `PineconDB` Documentation +# `PineconeDB` Documentation ## Table of Contents diff --git a/docs/swarms/memory/weaviate.md b/docs/swarms/memory/weaviate.md index 7044f1ad..b23baedf 100644 --- a/docs/swarms/memory/weaviate.md +++ b/docs/swarms/memory/weaviate.md @@ -28,7 +28,7 @@ pip install swarms ## Initialization -To use the Weaviate API Client, you need to initialize an instance of the `WeaviateClient` class. Here are the parameters you can pass to the constructor: +To use the Weaviate API Client, you need to initialize an instance of the `WeaviateDB` class. Here are the parameters you can pass to the constructor: | Parameter | Type | Description | |----------------------|----------------|----------------------------------------------------------------------------------------------------------------------------------| @@ -43,12 +43,12 @@ To use the Weaviate API Client, you need to initialize an instance of the `Weavi | `additional_config` | Optional[weaviate.AdditionalConfig] | Additional configuration for the client. (Optional) | | `connection_params` | Dict[str, Any] | Dictionary containing connection parameters. This parameter is used internally and can be ignored in most cases. | -Here's an example of how to initialize a WeaviateClient: +Here's an example of how to initialize a WeaviateDB: ```python -from swarms.memory import WeaviateClient +from swarms.memory import WeaviateDB -weaviate_client = WeaviateClient( +weaviate_client = WeaviateDB( http_host="YOUR_HTTP_HOST", http_port="YOUR_HTTP_PORT", http_secure=True, diff --git a/docs/swarms/models/mixtral.md b/docs/swarms/models/mixtral.md new file mode 100644 index 00000000..aa1b64d3 --- /dev/null +++ b/docs/swarms/models/mixtral.md @@ -0,0 +1,76 @@ +# Module Name: Mixtral + +## Introduction +The Mixtral module is a powerful language model designed for text generation tasks. It leverages the MistralAI Mixtral-8x7B pre-trained model to generate high-quality text based on user-defined tasks or prompts. In this documentation, we will provide a comprehensive overview of the Mixtral module, including its architecture, purpose, arguments, and detailed usage examples. + +## Purpose +The Mixtral module is designed to facilitate text generation tasks using state-of-the-art language models. Whether you need to generate creative content, draft text for various applications, or simply explore the capabilities of Mixtral, this module serves as a versatile and efficient solution. With its easy-to-use interface, you can quickly generate text for a wide range of applications. + +## Architecture +The Mixtral module is built on top of the MistralAI Mixtral-8x7B pre-trained model. It utilizes a deep neural network architecture with 8 layers and 7 attention heads to generate coherent and contextually relevant text. The model is capable of handling a variety of text generation tasks, from simple prompts to more complex content generation. + +## Class Definition +### `Mixtral(model_name: str = "mistralai/Mixtral-8x7B-v0.1", max_new_tokens: int = 500)` + +#### Parameters +- `model_name` (str, optional): The name or path of the pre-trained Mixtral model. Default is "mistralai/Mixtral-8x7B-v0.1". +- `max_new_tokens` (int, optional): The maximum number of new tokens to generate. Default is 500. + +## Functionality and Usage +The Mixtral module offers a straightforward interface for text generation. It accepts a task or prompt as input and returns generated text based on the provided input. + +### `run(task: Optional[str] = None, **kwargs) -> str` + +#### Parameters +- `task` (str, optional): The task or prompt for text generation. + +#### Returns +- `str`: The generated text. + +## Usage Examples +### Example 1: Basic Usage + +```python +from swarms.models import Mixtral + +# Initialize the Mixtral model +mixtral = Mixtral() + +# Generate text for a simple task +generated_text = mixtral.run("Generate a creative story.") +print(generated_text) +``` + +### Example 2: Custom Model + +You can specify a custom pre-trained model by providing the `model_name` parameter. + +```python +custom_model_name = "model_name" +mixtral_custom = Mixtral(model_name=custom_model_name) + +generated_text = mixtral_custom.run("Generate text with a custom model.") +print(generated_text) +``` + +### Example 3: Controlling Output Length + +You can control the length of the generated text by adjusting the `max_new_tokens` parameter. + +```python +mixtral_length = Mixtral(max_new_tokens=100) + +generated_text = mixtral_length.run("Generate a short text.") +print(generated_text) +``` + +## Additional Information and Tips +- It's recommended to use a descriptive task or prompt to guide the text generation process. +- Experiment with different prompt styles and lengths to achieve the desired output. +- You can fine-tune Mixtral on specific tasks if needed, although pre-trained models often work well out of the box. +- Monitor the `max_new_tokens` parameter to control the length of the generated text. + +## Conclusion +The Mixtral module is a versatile tool for text generation tasks, powered by the MistralAI Mixtral-8x7B pre-trained model. Whether you need creative writing, content generation, or assistance with text-based tasks, Mixtral can help you achieve your goals. With a simple interface and flexible parameters, it's a valuable addition to your text generation toolkit. + +If you encounter any issues or have questions about using Mixtral, please refer to the MistralAI documentation or reach out to their support team for further assistance. Happy text generation with Mixtral! \ No newline at end of file diff --git a/docs/swarms/structs/conversation.md b/docs/swarms/structs/conversation.md new file mode 100644 index 00000000..be9ceffa --- /dev/null +++ b/docs/swarms/structs/conversation.md @@ -0,0 +1,265 @@ +# Module/Class Name: Conversation + +## Introduction + +The `Conversation` class is a powerful tool for managing and structuring conversation data in a Python program. It enables you to create, manipulate, and analyze conversations easily. This documentation will provide you with a comprehensive understanding of the `Conversation` class, its attributes, methods, and how to effectively use it. + +## Table of Contents + +1. **Class Definition** + - Overview + - Attributes + +2. **Methods** + - `__init__(self, time_enabled: bool = False, *args, **kwargs)` + - `add(self, role: str, content: str, *args, **kwargs)` + - `delete(self, index: str)` + - `update(self, index: str, role, content)` + - `query(self, index: str)` + - `search(self, keyword: str)` + - `display_conversation(self, detailed: bool = False)` + - `export_conversation(self, filename: str)` + - `import_conversation(self, filename: str)` + - `count_messages_by_role(self)` + - `return_history_as_string(self)` + - `save_as_json(self, filename: str)` + - `load_from_json(self, filename: str)` + - `search_keyword_in_conversation(self, keyword: str)` + - `pretty_print_conversation(self, messages)` + +--- + +### 1. Class Definition + +#### Overview + +The `Conversation` class is designed to manage conversations by keeping track of messages and their attributes. It offers methods for adding, deleting, updating, querying, and displaying messages within the conversation. Additionally, it supports exporting and importing conversations, searching for specific keywords, and more. + +#### Attributes + +- `time_enabled (bool)`: A flag indicating whether to enable timestamp recording for messages. +- `conversation_history (list)`: A list that stores messages in the conversation. + +### 2. Methods + +#### `__init__(self, time_enabled: bool = False, *args, **kwargs)` + +- **Description**: Initializes a new Conversation object. +- **Parameters**: + - `time_enabled (bool)`: If `True`, timestamps will be recorded for each message. Default is `False`. + +#### `add(self, role: str, content: str, *args, **kwargs)` + +- **Description**: Adds a message to the conversation history. +- **Parameters**: + - `role (str)`: The role of the speaker (e.g., "user," "assistant"). + - `content (str)`: The content of the message. + +#### `delete(self, index: str)` + +- **Description**: Deletes a message from the conversation history. +- **Parameters**: + - `index (str)`: The index of the message to delete. + +#### `update(self, index: str, role, content)` + +- **Description**: Updates a message in the conversation history. +- **Parameters**: + - `index (str)`: The index of the message to update. + - `role (_type_)`: The new role of the speaker. + - `content (_type_)`: The new content of the message. + +#### `query(self, index: str)` + +- **Description**: Retrieves a message from the conversation history. +- **Parameters**: + - `index (str)`: The index of the message to query. +- **Returns**: The message as a string. + +#### `search(self, keyword: str)` + +- **Description**: Searches for messages containing a specific keyword in the conversation history. +- **Parameters**: + - `keyword (str)`: The keyword to search for. +- **Returns**: A list of messages that contain the keyword. + +#### `display_conversation(self, detailed: bool = False)` + +- **Description**: Displays the conversation history. +- **Parameters**: + - `detailed (bool, optional)`: If `True`, provides detailed information about each message. Default is `False`. + +#### `export_conversation(self, filename: str)` + +- **Description**: Exports the conversation history to a text file. +- **Parameters**: + - `filename (str)`: The name of the file to export to. + +#### `import_conversation(self, filename: str)` + +- **Description**: Imports a conversation history from a text file. +- **Parameters**: + - `filename (str)`: The name of the file to import from. + +#### `count_messages_by_role(self)` + +- **Description**: Counts the number of messages by role in the conversation. +- **Returns**: A dictionary containing the count of messages for each role. + +#### `return_history_as_string(self)` + +- **Description**: Returns the entire conversation history as a single string. +- **Returns**: The conversation history as a string. + +#### `save_as_json(self, filename: str)` + +- **Description**: Saves the conversation history as a JSON file. +- **Parameters**: + - `filename (str)`: The name of the JSON file to save. + +#### `load_from_json(self, filename: str)` + +- **Description**: Loads a conversation history from a JSON file. +- **Parameters**: + - `filename (str)`: The name of the JSON file to load. + +#### `search_keyword_in_conversation(self, keyword: str)` + +- **Description**: Searches for a keyword in the conversation history and returns matching messages. +- **Parameters**: + - `keyword (str)`: The keyword to search for. +- **Returns**: A list of messages containing the keyword. + +#### `pretty_print_conversation(self, messages)` + +- **Description**: Pretty prints a list of messages with colored role indicators. +- **Parameters**: + - `messages (list)`: A list of messages to print. + +## Examples + +Here are some usage examples of the `Conversation` class: + +### Creating a Conversation + +```python +from swarms.structs import Conversation + +conv = Conversation() +``` + +### Adding Messages + +```python +conv.add("user", "Hello, world!") +conv.add("assistant", "Hello, user!") +``` + +### Displaying the Conversation + +```python +conv.display_conversation() +``` + +### Searching for Messages + +```python +result = conv.search("Hello") +``` + +### Exporting and Importing Conversations + +```python +conv.export_conversation("conversation.txt") +conv.import_conversation("conversation.txt") +``` + +### Counting Messages by Role + +```python +counts = conv.count_messages_by_role() +``` + +### Loading and Saving as JSON + +```python +conv.save_as_json("conversation.json") +conv.load_from_json("conversation.json") +``` + +Certainly! Let's continue with more examples and additional information about the `Conversation` class. + +### Querying a Specific Message + +You can retrieve a specific message from the conversation by its index: + +```python +message = conv.query(0) # Retrieves the first message +``` + +### Updating a Message + +You can update a message's content or role within the conversation: + +```python +conv.update(0, "user", "Hi there!") # Updates the first message +``` + +### Deleting a Message + +If you want to remove a message from the conversation, you can use the `delete` method: + +```python +conv.delete(0) # Deletes the first message +``` + +### Counting Messages by Role + +You can count the number of messages by role in the conversation: + +```python +counts = conv.count_messages_by_role() +# Example result: {'user': 2, 'assistant': 2} +``` + +### Exporting and Importing as Text + +You can export the conversation to a text file and later import it: + +```python +conv.export_conversation("conversation.txt") # Export +conv.import_conversation("conversation.txt") # Import +``` + +### Exporting and Importing as JSON + +Conversations can also be saved and loaded as JSON files: + +```python +conv.save_as_json("conversation.json") # Save as JSON +conv.load_from_json("conversation.json") # Load from JSON +``` + +### Searching for a Keyword + +You can search for messages containing a specific keyword within the conversation: + +```python +results = conv.search_keyword_in_conversation("Hello") +``` + +### Pretty Printing + +The `pretty_print_conversation` method provides a visually appealing way to display messages with colored role indicators: + +```python +conv.pretty_print_conversation(conv.conversation_history) +``` + +These examples demonstrate the versatility of the `Conversation` class in managing and interacting with conversation data. Whether you're building a chatbot, conducting analysis, or simply organizing dialogues, this class offers a robust set of tools to help you accomplish your goals. + +## Conclusion + +The `Conversation` class is a valuable utility for handling conversation data in Python. With its ability to add, update, delete, search, export, and import messages, you have the flexibility to work with conversations in various ways. Feel free to explore its features and adapt them to your specific projects and applications. + +If you have any further questions or need additional assistance, please don't hesitate to ask! \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index eeb64c04..de263ac6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -83,6 +83,7 @@ nav: - vLLM: "swarms/models/vllm.md" - MPT7B: "swarms/models/mpt.md" - Mistral: "swarms/models/mistral.md" + - Mixtral: "swarms/models/mixtral.md" - MultiModal: - BaseMultiModalModel: "swarms/models/base_multimodal_model.md" - Fuyu: "swarms/models/fuyu.md" @@ -103,9 +104,10 @@ nav: - AutoScaler: "swarms/swarms/autoscaler.md" - Agent: "swarms/structs/agent.md" - SequentialWorkflow: 'swarms/structs/sequential_workflow.md' + - Conversation: "swarms/structs/conversation.md" - swarms.memory: - Weaviate: "swarms/memory/weaviate.md" - - PineconDB: "swarms/memory/pinecone.md" + - PineconeDB: "swarms/memory/pinecone.md" - PGVectorStore: "swarms/memory/pg.md" - ShortTermMemory: "swarms/memory/short_term_memory.md" - swarms.utils: diff --git a/playground/demos/assembly/assembly.py b/playground/demos/assembly/assembly.py index b82e075c..704c80d4 100644 --- a/playground/demos/assembly/assembly.py +++ b/playground/demos/assembly/assembly.py @@ -1,8 +1,5 @@ from swarms.structs import Agent from swarms.models.gpt4_vision_api import GPT4VisionAPI -from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( - MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, -) llm = GPT4VisionAPI() diff --git a/playground/demos/autotemp/autotemp_example.py b/playground/demos/autotemp/autotemp_example.py index c5f86416..ccbd54c3 100644 --- a/playground/demos/autotemp/autotemp_example.py +++ b/playground/demos/autotemp/autotemp_example.py @@ -1,4 +1,3 @@ -from swarms.models import OpenAIChat from autotemp import AutoTemp # Your OpenAI API key diff --git a/playground/demos/gemini_benchmarking/gemini_chat.py b/playground/demos/gemini_benchmarking/gemini_chat.py index b1f12ee7..6d9dc7ae 100644 --- a/playground/demos/gemini_benchmarking/gemini_chat.py +++ b/playground/demos/gemini_benchmarking/gemini_chat.py @@ -21,5 +21,8 @@ model = Gemini( ) -out = model.chat("Create the code for a react component that displays a name", img=img) -print(out) \ No newline at end of file +out = model.chat( + "Create the code for a react component that displays a name", + img=img, +) +print(out) diff --git a/playground/demos/gemini_benchmarking/gemini_react.py b/playground/demos/gemini_benchmarking/gemini_react.py index 76caf974..022405e9 100644 --- a/playground/demos/gemini_benchmarking/gemini_react.py +++ b/playground/demos/gemini_benchmarking/gemini_react.py @@ -22,5 +22,7 @@ model = Gemini( # Run the model -out = model.run("Create the code for a react component that displays a name") +out = model.run( + "Create the code for a react component that displays a name" +) print(out) diff --git a/playground/demos/llm_with_conversation/main.py b/playground/demos/llm_with_conversation/main.py new file mode 100644 index 00000000..2bb28b4b --- /dev/null +++ b/playground/demos/llm_with_conversation/main.py @@ -0,0 +1,21 @@ +import os + +from dotenv import load_dotenv + +# Import the OpenAIChat model and the Agent struct +from swarms.models import OpenAIChat +from swarms.structs import Agent + +# Load the environment variables +load_dotenv() + +# Get the API key from the environment +api_key = os.environ.get("OPENAI_API_KEY") + +# Initialize the language model +llm = OpenAIChat( + temperature=0.5, + model_name="gpt-4", + openai_api_key=api_key, + max_tokens=1000, +) diff --git a/playground/demos/nutrition/nutrition.py b/playground/demos/nutrition/nutrition.py index aca079ba..428560e3 100644 --- a/playground/demos/nutrition/nutrition.py +++ b/playground/demos/nutrition/nutrition.py @@ -2,7 +2,7 @@ import os import base64 import requests from dotenv import load_dotenv -from swarms.models import Anthropic, OpenAIChat +from swarms.models import OpenAIChat from swarms.structs import Agent # Load environment variables diff --git a/playground/demos/optimize_llm_stack/vortex.py b/playground/demos/optimize_llm_stack/vortex.py index 438c1451..a40c29b9 100644 --- a/playground/demos/optimize_llm_stack/vortex.py +++ b/playground/demos/optimize_llm_stack/vortex.py @@ -1,5 +1,4 @@ import os -import subprocess from dotenv import load_dotenv diff --git a/playground/demos/optimize_llm_stack/weaviate.py b/playground/demos/optimize_llm_stack/weaviate.py index fa3bf96e..ad594547 100644 --- a/playground/demos/optimize_llm_stack/weaviate.py +++ b/playground/demos/optimize_llm_stack/weaviate.py @@ -1,6 +1,6 @@ -from swarms.memory import WeaviateClient +from swarms.memory import WeaviateDB -weaviate_client = WeaviateClient( +weaviate_client = WeaviateDB( http_host="YOUR_HTTP_HOST", http_port="YOUR_HTTP_PORT", http_secure=True, diff --git a/playground/demos/personal_assistant/better_communication.py b/playground/demos/personal_assistant/better_communication.py new file mode 100644 index 00000000..c6e79eb7 --- /dev/null +++ b/playground/demos/personal_assistant/better_communication.py @@ -0,0 +1,96 @@ +import time +import os + +import pygame +import speech_recognition as sr +from dotenv import load_dotenv +from playsound import playsound + +from swarms import OpenAIChat, OpenAITTS + +# Load the environment variables +load_dotenv() + +# Get the API key from the environment +openai_api_key = os.environ.get("OPENAI_API_KEY") + +# Initialize the language model +llm = OpenAIChat( + openai_api_key=openai_api_key, +) + +# Initialize the text-to-speech model +tts = OpenAITTS( + model_name="tts-1-1106", + voice="onyx", + openai_api_key=openai_api_key, + saved_filepath="runs/tts_speech.wav", +) + +# Initialize the speech recognition model +r = sr.Recognizer() + + +def play_audio(file_path): + # Check if the file exists + if not os.path.isfile(file_path): + print(f"Audio file {file_path} not found.") + return + + # Initialize the mixer module + pygame.mixer.init() + + try: + # Load the mp3 file + pygame.mixer.music.load(file_path) + + # Play the mp3 file + pygame.mixer.music.play() + + # Wait for the audio to finish playing + while pygame.mixer.music.get_busy(): + pygame.time.Clock().tick(10) + except pygame.error as e: + print(f"Couldn't play {file_path}: {e}") + finally: + # Stop the mixer module and free resources + pygame.mixer.quit() + + +while True: + # Listen for user speech + with sr.Microphone() as source: + print("Listening...") + audio = r.listen(source) + + # Convert speech to text + try: + print("Recognizing...") + task = r.recognize_google(audio) + print(f"User said: {task}") + except sr.UnknownValueError: + print("Could not understand audio") + continue + except Exception as e: + print(f"Error: {e}") + continue + + # Run the Gemini model on the task + print("Running GPT4 model...") + out = llm(task) + print(f"Gemini output: {out}") + + # Convert the Gemini output to speech + print("Running text-to-speech model...") + out = tts.run_and_save(out) + print(f"Text-to-speech output: {out}") + + # Ask the user if they want to play the audio + # play_audio = input("Do you want to play the audio? (yes/no): ") + # if play_audio.lower() == "yes": + # Initialize the mixer module + # Play the audio file + + time.sleep(5) + + playsound("runs/tts_speech.wav") diff --git a/playground/demos/swarm_of_mma_manufacturing/main.py b/playground/demos/swarm_of_mma_manufacturing/main.py index 37938608..05b0e8e5 100644 --- a/playground/demos/swarm_of_mma_manufacturing/main.py +++ b/playground/demos/swarm_of_mma_manufacturing/main.py @@ -20,7 +20,6 @@ from termcolor import colored from swarms.models import GPT4VisionAPI from swarms.structs import Agent -from swarms.utils.phoenix_handler import phoenix_trace_decorator load_dotenv() api_key = os.getenv("OPENAI_API_KEY") diff --git a/playground/models/kosmos2.py b/playground/models/kosmos2.py index ce39a710..6fc4df02 100644 --- a/playground/models/kosmos2.py +++ b/playground/models/kosmos2.py @@ -1,4 +1,4 @@ -from swarms.models.kosmos2 import Kosmos2, Detections +from swarms.models.kosmos2 import Kosmos2 from PIL import Image diff --git a/playground/tools/agent_with_tools.py b/playground/tools/agent_with_tools.py index ee4a8ef7..3bad0b1d 100644 --- a/playground/tools/agent_with_tools.py +++ b/playground/tools/agent_with_tools.py @@ -1,7 +1,6 @@ import os from swarms.models import OpenAIChat from swarms.structs import Agent -from swarms.tools.tool import tool from dotenv import load_dotenv load_dotenv() diff --git a/pyproject.toml b/pyproject.toml index e2bed3de..d29c59ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "2.9.8" +version = "2.4.2" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -98,3 +98,5 @@ target-version = ['py38'] preview = true +[tool.poetry.scripts] +swarms = 'swarms.cli._cli:cli' \ No newline at end of file diff --git a/scripts/delete_pycache.sh b/scripts/delete_pycache.sh new file mode 100644 index 00000000..db11f239 --- /dev/null +++ b/scripts/delete_pycache.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +# Find all __pycache__ directories and delete them +find . -type d -name "__pycache__" -exec rm -rf {} + \ No newline at end of file diff --git a/swarms/cli/_cli.py b/swarms/cli/_cli.py index b4be4e02..8dee387f 100644 --- a/swarms/cli/_cli.py +++ b/swarms/cli/_cli.py @@ -2,7 +2,7 @@ import argparse import sys -def run_file(): +def cli(): parser = argparse.ArgumentParser(description="Swarms CLI") parser.add_argument( "file_name", help="Python file containing Swarms code to run" diff --git a/swarms/cli/run_file.py b/swarms/cli/run_file.py index de035b1e..171c6c56 100644 --- a/swarms/cli/run_file.py +++ b/swarms/cli/run_file.py @@ -2,7 +2,7 @@ import sys import subprocess -def run_file(): +def run_file(filename: str): """Run a given file. Usage: swarms run file_name.py diff --git a/swarms/memory/__init__.py b/swarms/memory/__init__.py index 4f92880a..a63a9553 100644 --- a/swarms/memory/__init__.py +++ b/swarms/memory/__init__.py @@ -1,4 +1,11 @@ from swarms.memory.base_vectordb import VectorDatabase from swarms.memory.short_term_memory import ShortTermMemory +from swarms.memory.sqlite import SQLiteDB +from swarms.memory.weaviate_db import WeaviateDB -__all__ = ["VectorDatabase", "ShortTermMemory"] +__all__ = [ + "VectorDatabase", + "ShortTermMemory", + "SQLiteDB", + "WeaviateDB", +] diff --git a/swarms/memory/base_db.py b/swarms/memory/base_db.py new file mode 100644 index 00000000..0501def7 --- /dev/null +++ b/swarms/memory/base_db.py @@ -0,0 +1,159 @@ +from abc import ABC, abstractmethod + + +class AbstractDatabase(ABC): + """ + Abstract base class for a database. + + This class defines the interface for interacting with a database. + Subclasses must implement the abstract methods to provide the + specific implementation details for connecting to a database, + executing queries, and performing CRUD operations. + + """ + + @abstractmethod + def connect(self): + """ + Connect to the database. + + This method establishes a connection to the database. + + """ + + pass + + @abstractmethod + def close(self): + """ + Close the database connection. + + This method closes the connection to the database. + + """ + + pass + + @abstractmethod + def execute_query(self, query): + """ + Execute a database query. + + This method executes the given query on the database. + + Parameters: + query (str): The query to be executed. + + """ + + pass + + @abstractmethod + def fetch_all(self): + """ + Fetch all rows from the result set. + + This method retrieves all rows from the result set of a query. + + Returns: + list: A list of dictionaries representing the rows. + + """ + + pass + + @abstractmethod + def fetch_one(self): + """ + Fetch one row from the result set. + + This method retrieves one row from the result set of a query. + + Returns: + dict: A dictionary representing the row. + + """ + + pass + + @abstractmethod + def add(self, table, data): + """ + Add a new record to the database. + + This method adds a new record to the specified table in the database. + + Parameters: + table (str): The name of the table. + data (dict): A dictionary representing the data to be added. + + """ + + pass + + @abstractmethod + def query(self, table, condition): + """ + Query the database. + + This method queries the specified table in the database based on the given condition. + + Parameters: + table (str): The name of the table. + condition (str): The condition to be applied in the query. + + Returns: + list: A list of dictionaries representing the query results. + + """ + + pass + + @abstractmethod + def get(self, table, id): + """ + Get a record from the database. + + This method retrieves a record from the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be retrieved. + + Returns: + dict: A dictionary representing the retrieved record. + + """ + + pass + + @abstractmethod + def update(self, table, id, data): + """ + Update a record in the database. + + This method updates a record in the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be updated. + data (dict): A dictionary representing the updated data. + + """ + + pass + + @abstractmethod + def delete(self, table, id): + """ + Delete a record from the database. + + This method deletes a record from the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be deleted. + + """ + + pass diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index 50972d98..d96b475d 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -1,302 +1,140 @@ -import subprocess import uuid -from typing import Optional -from attr import define, field, Factory -from dataclasses import dataclass -from swarms.memory.base import BaseVectorStore +from typing import Any, List, Optional +from sqlalchemy import JSON, Column, String, create_engine +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session -try: - from sqlalchemy.engine import Engine - from sqlalchemy import create_engine, Column, String, JSON - from sqlalchemy.ext.declarative import declarative_base - from sqlalchemy.dialects.postgresql import UUID - from sqlalchemy.orm import Session -except ImportError: - print( - "The PgVectorVectorStore requires sqlalchemy to be installed" - ) - print("pip install sqlalchemy") - subprocess.run(["pip", "install", "sqlalchemy"]) - -try: - from pgvector.sqlalchemy import Vector -except ImportError: - print("The PgVectorVectorStore requires pgvector to be installed") - print("pip install pgvector") - subprocess.run(["pip", "install", "pgvector"]) +class PostgresDB: + """ + A class representing a Postgres database. -@define -class PgVectorVectorStore(BaseVectorStore): - """A vector store driver to Postgres using the PGVector extension. + Args: + connection_string (str): The connection string for the Postgres database. + table_name (str): The name of the table in the database. Attributes: - connection_string: An optional string describing the target Postgres database instance. - create_engine_params: Additional configuration params passed when creating the database connection. - engine: An optional sqlalchemy Postgres engine to use. - table_name: Optionally specify the name of the table to used to store vectors. - - Methods: - upsert_vector(vector: list[float], vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs) -> str: - Upserts a vector into the index. - load_entry(vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVector.Entry]: - Loads a single vector from the index. - load_entries(namespace: Optional[str] = None) -> list[BaseVector.Entry]: - Loads all vectors from the index. - query(query: str, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, include_metadata=True, **kwargs) -> list[BaseVector.QueryResult]: - Queries the index for vectors similar to the given query string. - setup(create_schema: bool = True, install_uuid_extension: bool = True, install_vector_extension: bool = True) -> None: - Provides a mechanism to initialize the database schema and extensions. - - Usage: - >>> from swarms.memory.vector_stores.pgvector import PgVectorVectorStore - >>> from swarms.utils.embeddings import USEEmbedding - >>> from swarms.utils.hash import str_to_hash - >>> from swarms.utils.dataframe import dataframe_to_hash - >>> import pandas as pd - >>> - >>> # Create a new PgVectorVectorStore instance: - >>> pv = PgVectorVectorStore( - >>> connection_string="postgresql://postgres:password@localhost:5432/postgres", - >>> table_name="your-table-name" - >>> ) - >>> # Create a new index: - >>> pv.setup() - >>> # Create a new USEEmbedding instance: - >>> use = USEEmbedding() - >>> # Create a new dataframe: - >>> df = pd.DataFrame({ - >>> "text": [ - >>> "This is a test", - >>> "This is another test", - >>> "This is a third test" - >>> ] - >>> }) - >>> # Embed the dataframe: - >>> df["embedding"] = df["text"].apply(use.embed_string) - >>> # Upsert the dataframe into the index: - >>> pv.upsert_vector( - >>> vector=df["embedding"].tolist(), - >>> vector_id=dataframe_to_hash(df), - >>> namespace="your-namespace" - >>> ) - >>> # Query the index: - >>> pv.query( - >>> query="This is a test", - >>> count=10, - >>> namespace="your-namespace" - >>> ) - >>> # Load a single entry from the index: - >>> pv.load_entry( - >>> vector_id=dataframe_to_hash(df), - >>> namespace="your-namespace" - >>> ) - >>> # Load all entries from the index: - >>> pv.load_entries( - >>> namespace="your-namespace" - >>> ) - + engine: The SQLAlchemy engine for connecting to the database. + table_name (str): The name of the table in the database. + VectorModel: The SQLAlchemy model representing the vector table. """ - connection_string: Optional[str] = field( - default=None, kw_only=True - ) - create_engine_params: dict = field(factory=dict, kw_only=True) - engine: Optional[Engine] = field(default=None, kw_only=True) - table_name: str = field(kw_only=True) - _model: any = field( - default=Factory( - lambda self: self.default_vector_model(), takes_self=True - ) - ) + def __init__( + self, connection_string: str, table_name: str, *args, **kwargs + ): + """ + Initializes a new instance of the PostgresDB class. - @connection_string.validator - def validate_connection_string( - self, _, connection_string: Optional[str] - ) -> None: - # If an engine is provided, the connection string is not used. - if self.engine is not None: - return + Args: + connection_string (str): The connection string for the Postgres database. + table_name (str): The name of the table in the database. - # If an engine is not provided, a connection string is required. - if connection_string is None: - raise ValueError( - "An engine or connection string is required" - ) - - if not connection_string.startswith("postgresql://"): - raise ValueError( - "The connection string must describe a Postgres" - " database connection" - ) + """ + self.engine = create_engine( + connection_string, *args, **kwargs + ) + self.table_name = table_name + self.VectorModel = self._create_vector_model() - @engine.validator - def validate_engine(self, _, engine: Optional[Engine]) -> None: - # If a connection string is provided, an engine does not need to be provided. - if self.connection_string is not None: - return + def _create_vector_model(self): + """ + Creates the SQLAlchemy model for the vector table. - # If a connection string is not provided, an engine is required. - if engine is None: - raise ValueError( - "An engine or connection string is required" - ) + Returns: + The SQLAlchemy model representing the vector table. - def __attrs_post_init__(self) -> None: - """If a an engine is provided, it will be used to connect to the database. - If not, a connection string is used to create a new database connection here. """ - if self.engine is None: - self.engine = create_engine( - self.connection_string, **self.create_engine_params - ) + Base = declarative_base() - def setup( - self, - create_schema: bool = True, - install_uuid_extension: bool = True, - install_vector_extension: bool = True, - ) -> None: - """Provides a mechanism to initialize the database schema and extensions.""" - if install_uuid_extension: - self.engine.execute( - 'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";' - ) + class VectorModel(Base): + __tablename__ = self.table_name - if install_vector_extension: - self.engine.execute( - 'CREATE EXTENSION IF NOT EXISTS "vector";' + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + unique=True, + nullable=False, ) + vector = Column( + String + ) # Assuming vector is stored as a string + namespace = Column(String) + meta = Column(JSON) - if create_schema: - self._model.metadata.create_all(self.engine) + return VectorModel - def upsert_vector( + def add_or_update_vector( self, - vector: list[float], + vector: str, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, - **kwargs, - ) -> str: - """Inserts or updates a vector in the collection.""" - with Session(self.engine) as session: - obj = self._model( - id=vector_id, - vector=vector, - namespace=namespace, - meta=meta, - ) - - obj = session.merge(obj) - session.commit() - - return str(obj.id) - - def load_entry( - self, vector_id: str, namespace: Optional[str] = None - ) -> BaseVectorStore.Entry: - """Retrieves a specific vector entry from the collection based on its identifier and optional namespace.""" - with Session(self.engine) as session: - result = session.get(self._model, vector_id) - - return BaseVectorStore.Entry( - id=result.id, - vector=result.vector, - namespace=result.namespace, - meta=result.meta, - ) - - def load_entries( - self, namespace: Optional[str] = None - ) -> list[BaseVectorStore.Entry]: - """Retrieves all vector entries from the collection, optionally filtering to only - those that match the provided namespace. + ) -> None: """ - with Session(self.engine) as session: - query = session.query(self._model) - if namespace: - query = query.filter_by(namespace=namespace) + Adds or updates a vector in the database. - results = query.all() + Args: + vector (str): The vector to be added or updated. + vector_id (str, optional): The ID of the vector. If not provided, a new ID will be generated. + namespace (str, optional): The namespace of the vector. + meta (dict, optional): Additional metadata associated with the vector. - return [ - BaseVectorStore.Entry( - id=str(result.id), - vector=result.vector, - namespace=result.namespace, - meta=result.meta, + """ + try: + with Session(self.engine) as session: + obj = self.VectorModel( + id=vector_id, + vector=vector, + namespace=namespace, + meta=meta, ) - for result in results - ] - - def query( - self, - query: str, - count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT, - namespace: Optional[str] = None, - include_vectors: bool = False, - distance_metric: str = "cosine_distance", - **kwargs, - ) -> list[BaseVectorStore.QueryResult]: - """Performs a search on the collection to find vectors similar to the provided input vector, - optionally filtering to only those that match the provided namespace. + session.merge(obj) + session.commit() + except Exception as e: + print(f"Error adding or updating vector: {e}") + + def query_vectors( + self, query: Any, namespace: Optional[str] = None + ) -> List[Any]: """ - distance_metrics = { - "cosine_distance": self._model.vector.cosine_distance, - "l2_distance": self._model.vector.l2_distance, - "inner_product": self._model.vector.max_inner_product, - } - - if distance_metric not in distance_metrics: - raise ValueError("Invalid distance metric provided") - - op = distance_metrics[distance_metric] - - with Session(self.engine) as session: - vector = self.embedding_driver.embed_string(query) + Queries vectors from the database based on the given query and namespace. - # The query should return both the vector and the distance metric score. - query = session.query( - self._model, - op(vector).label("score"), - ).order_by(op(vector)) + Args: + query (Any): The query or condition to filter the vectors. + namespace (str, optional): The namespace of the vectors to be queried. - if namespace: - query = query.filter_by(namespace=namespace) + Returns: + List[Any]: A list of vectors that match the query and namespace. - results = query.limit(count).all() - - return [ - BaseVectorStore.QueryResult( - id=str(result[0].id), - vector=( - result[0].vector if include_vectors else None - ), - score=result[1], - meta=result[0].meta, - namespace=result[0].namespace, - ) - for result in results - ] - - def default_vector_model(self) -> any: - Base = declarative_base() - - @dataclass - class VectorModel(Base): - __tablename__ = self.table_name + """ + try: + with Session(self.engine) as session: + q = session.query(self.VectorModel) + if namespace: + q = q.filter_by(namespace=namespace) + # Assuming 'query' is a condition or filter + q = q.filter(query) + return q.all() + except Exception as e: + print(f"Error querying vectors: {e}") + return [] + + def delete_vector(self, vector_id): + """ + Deletes a vector from the database based on the given vector ID. - id = Column( - UUID(as_uuid=True), - primary_key=True, - default=uuid.uuid4, - unique=True, - nullable=False, - ) - vector = Column(Vector()) - namespace = Column(String) - meta = Column(JSON) + Args: + vector_id: The ID of the vector to be deleted. - return VectorModel + """ + try: + with Session(self.engine) as session: + obj = session.get(self.VectorModel, vector_id) + if obj: + session.delete(obj) + session.commit() + except Exception as e: + print(f"Error deleting vector: {e}") diff --git a/swarms/memory/pinecone.py b/swarms/memory/pinecone.py index f48bb627..164cb334 100644 --- a/swarms/memory/pinecone.py +++ b/swarms/memory/pinecone.py @@ -6,9 +6,9 @@ from swarms.utils.hash import str_to_hash @define -class PineconDB(VectorDatabase): +class PineconeDB(VectorDatabase): """ - PineconDB is a vector storage driver that uses Pinecone as the underlying storage engine. + PineconeDB is a vector storage driver that uses Pinecone as the underlying storage engine. Pinecone is a vector database that allows you to store, search, and retrieve high-dimensional vectors with blazing speed and low latency. It is a managed service that is easy to use and scales effortlessly, so you can @@ -34,14 +34,14 @@ class PineconDB(VectorDatabase): Creates a new index. Usage: - >>> from swarms.memory.vector_stores.pinecone import PineconDB + >>> from swarms.memory.vector_stores.pinecone import PineconeDB >>> from swarms.utils.embeddings import USEEmbedding >>> from swarms.utils.hash import str_to_hash >>> from swarms.utils.dataframe import dataframe_to_hash >>> import pandas as pd >>> - >>> # Create a new PineconDB instance: - >>> pv = PineconDB( + >>> # Create a new PineconeDB instance: + >>> pv = PineconeDB( >>> api_key="your-api-key", >>> index_name="your-index-name", >>> environment="us-west1-gcp", @@ -166,7 +166,7 @@ class PineconDB(VectorDatabase): count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, - # PineconDBStorageDriver-specific params: + # PineconeDBStorageDriver-specific params: include_metadata=True, **kwargs, ): diff --git a/swarms/memory/qdrant.py b/swarms/memory/qdrant.py index 83ff5593..40f9979c 100644 --- a/swarms/memory/qdrant.py +++ b/swarms/memory/qdrant.py @@ -82,7 +82,7 @@ class Qdrant: f"Collection '{self.collection_name}' already" " exists." ) - except Exception as e: + except Exception: self.client.create_collection( collection_name=self.collection_name, vectors_config=VectorParams( diff --git a/swarms/memory/short_term_memory.py b/swarms/memory/short_term_memory.py index 53daf332..d380fba5 100644 --- a/swarms/memory/short_term_memory.py +++ b/swarms/memory/short_term_memory.py @@ -12,8 +12,8 @@ class ShortTermMemory(BaseStructure): autosave (bool, optional): _description_. Defaults to True. *args: _description_ **kwargs: _description_ - - + + Example: >>> from swarms.memory.short_term_memory import ShortTermMemory >>> stm = ShortTermMemory() @@ -22,9 +22,10 @@ class ShortTermMemory(BaseStructure): >>> stm.add(role="agent", message="I am fine.") >>> stm.add(role="agent", message="How are you?") >>> stm.add(role="agent", message="I am fine.") - - + + """ + def __init__( self, return_str: bool = True, @@ -93,7 +94,7 @@ class ShortTermMemory(BaseStructure): index (_type_): _description_ role (str): _description_ message (str): _description_ - + """ self.short_term_memory[index] = { "role": role, diff --git a/swarms/memory/sqlite.py b/swarms/memory/sqlite.py new file mode 100644 index 00000000..eed4ee2c --- /dev/null +++ b/swarms/memory/sqlite.py @@ -0,0 +1,120 @@ +from typing import List, Tuple, Any, Optional +from swarms.memory.base_vectordb import VectorDatabase + +try: + import sqlite3 +except ImportError: + raise ImportError( + "Please install sqlite3 to use the SQLiteDB class." + ) + + +class SQLiteDB(VectorDatabase): + """ + A reusable class for SQLite database operations with methods for adding, + deleting, updating, and querying data. + + Attributes: + db_path (str): The file path to the SQLite database. + """ + + def __init__(self, db_path: str): + """ + Initializes the SQLiteDB class with the given database path. + + Args: + db_path (str): The file path to the SQLite database. + """ + self.db_path = db_path + + def execute_query( + self, query: str, params: Optional[Tuple[Any, ...]] = None + ) -> List[Tuple]: + """ + Executes a SQL query and returns fetched results. + + Args: + query (str): The SQL query to execute. + params (Tuple[Any, ...], optional): The parameters to substitute into the query. + + Returns: + List[Tuple]: The results fetched from the database. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params or ()) + return cursor.fetchall() + except Exception as error: + print(f"Error executing query: {error}") + raise error + + def add(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Adds a new entry to the database. + + Args: + query (str): The SQL query for insertion. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error adding new entry: {error}") + raise error + + def delete(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Deletes an entry from the database. + + Args: + query (str): The SQL query for deletion. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error deleting entry: {error}") + raise error + + def update(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Updates an entry in the database. + + Args: + query (str): The SQL query for updating. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error updating entry: {error}") + raise error + + def query( + self, query: str, params: Optional[Tuple[Any, ...]] = None + ) -> List[Tuple]: + """ + Fetches data from the database based on a query. + + Args: + query (str): The SQL query to execute. + params (Tuple[Any, ...], optional): The parameters to substitute into the query. + + Returns: + List[Tuple]: The results fetched from the database. + """ + try: + return self.execute_query(query, params) + except Exception as error: + print(f"Error querying database: {error}") + raise error diff --git a/swarms/memory/utils.py b/swarms/memory/utils.py index 46c7b020..42801237 100644 --- a/swarms/memory/utils.py +++ b/swarms/memory/utils.py @@ -26,7 +26,18 @@ def maximal_marginal_relevance( lambda_mult: float = 0.5, k: int = 4, ) -> List[int]: - """Calculate maximal marginal relevance.""" + """ + Calculate maximal marginal relevance. + + Args: + query_embedding (np.ndarray): The embedding of the query. + embedding_list (list): List of embeddings to select from. + lambda_mult (float, optional): The weight for query score. Defaults to 0.5. + k (int, optional): The number of embeddings to select. Defaults to 4. + + Returns: + List[int]: List of indices of selected embeddings. + """ if min(k, len(embedding_list)) <= 0: return [] if query_embedding.ndim == 1: diff --git a/swarms/memory/weaviate_db.py b/swarms/memory/weaviate_db.py new file mode 100644 index 00000000..0c0b09a2 --- /dev/null +++ b/swarms/memory/weaviate_db.py @@ -0,0 +1,182 @@ +""" +Weaviate API Client +""" + +from typing import Any, Dict, List, Optional + +from swarms.memory.base_vectordb import VectorDatabase + +try: + import weaviate +except ImportError: + print("pip install weaviate-client") + + +class WeaviateDB(VectorDatabase): + """ + + Weaviate API Client + Interface to Weaviate, a vector database with a GraphQL API. + + Args: + http_host (str): The HTTP host of the Weaviate server. + http_port (str): The HTTP port of the Weaviate server. + http_secure (bool): Whether to use HTTPS. + grpc_host (Optional[str]): The gRPC host of the Weaviate server. + grpc_port (Optional[str]): The gRPC port of the Weaviate server. + grpc_secure (Optional[bool]): Whether to use gRPC over TLS. + auth_client_secret (Optional[Any]): The authentication client secret. + additional_headers (Optional[Dict[str, str]]): Additional headers to send with requests. + additional_config (Optional[weaviate.AdditionalConfig]): Additional configuration for the client. + + Methods: + create_collection: Create a new collection in Weaviate. + add: Add an object to a specified collection. + query: Query objects from a specified collection. + update: Update an object in a specified collection. + delete: Delete an object from a specified collection. + + Examples: + >>> from swarms.memory import WeaviateDB + """ + + def __init__( + self, + http_host: str, + http_port: str, + http_secure: bool, + grpc_host: Optional[str] = None, + grpc_port: Optional[str] = None, + grpc_secure: Optional[bool] = None, + auth_client_secret: Optional[Any] = None, + additional_headers: Optional[Dict[str, str]] = None, + additional_config: Optional[Any] = None, + connection_params: Dict[str, Any] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.http_host = http_host + self.http_port = http_port + self.http_secure = http_secure + self.grpc_host = grpc_host + self.grpc_port = grpc_port + self.grpc_secure = grpc_secure + self.auth_client_secret = auth_client_secret + self.additional_headers = additional_headers + self.additional_config = additional_config + self.connection_params = connection_params + + # If connection_params are provided, use them to initialize the client. + connection_params = weaviate.ConnectionParams.from_params( + http_host=http_host, + http_port=http_port, + http_secure=http_secure, + grpc_host=grpc_host, + grpc_port=grpc_port, + grpc_secure=grpc_secure, + ) + + # If additional headers are provided, add them to the connection params. + self.client = weaviate.WeaviateDB( + connection_params=connection_params, + auth_client_secret=auth_client_secret, + additional_headers=additional_headers, + additional_config=additional_config, + ) + + def create_collection( + self, + name: str, + properties: List[Dict[str, Any]], + vectorizer_config: Any = None, + ): + """Create a new collection in Weaviate. + + Args: + name (str): _description_ + properties (List[Dict[str, Any]]): _description_ + vectorizer_config (Any, optional): _description_. Defaults to None. + """ + try: + out = self.client.collections.create( + name=name, + vectorizer_config=vectorizer_config, + properties=properties, + ) + print(out) + except Exception as error: + print(f"Error creating collection: {error}") + raise + + def add(self, collection_name: str, properties: Dict[str, Any]): + """Add an object to a specified collection. + + Args: + collection_name (str): _description_ + properties (Dict[str, Any]): _description_ + + Returns: + _type_: _description_ + """ + try: + collection = self.client.collections.get(collection_name) + return collection.data.insert(properties) + except Exception as error: + print(f"Error adding object: {error}") + raise + + def query( + self, collection_name: str, query: str, limit: int = 10 + ): + """Query objects from a specified collection. + + Args: + collection_name (str): _description_ + query (str): _description_ + limit (int, optional): _description_. Defaults to 10. + + Returns: + _type_: _description_ + """ + try: + collection = self.client.collections.get(collection_name) + response = collection.query.bm25(query=query, limit=limit) + return [o.properties for o in response.objects] + except Exception as error: + print(f"Error querying objects: {error}") + raise + + def update( + self, + collection_name: str, + object_id: str, + properties: Dict[str, Any], + ): + """UPdate an object in a specified collection. + + Args: + collection_name (str): _description_ + object_id (str): _description_ + properties (Dict[str, Any]): _description_ + """ + try: + collection = self.client.collections.get(collection_name) + collection.data.update(object_id, properties) + except Exception as error: + print(f"Error updating object: {error}") + raise + + def delete(self, collection_name: str, object_id: str): + """Delete an object from a specified collection. + + Args: + collection_name (str): _description_ + object_id (str): _description_ + """ + try: + collection = self.client.collections.get(collection_name) + collection.data.delete_by_id(object_id) + except Exception as error: + print(f"Error deleting object: {error}") + raise diff --git a/swarms/models/base_llm.py b/swarms/models/base_llm.py index 0409b867..bc1f67c7 100644 --- a/swarms/models/base_llm.py +++ b/swarms/models/base_llm.py @@ -1,9 +1,11 @@ -import os +import asyncio import logging +import os import time from abc import ABC, abstractmethod -from typing import Optional, List -import asyncio +from typing import List, Optional + +from swarms.utils.llm_metrics_decorator import metrics_decorator def count_tokens(text: str) -> int: @@ -118,6 +120,7 @@ class AbstractLLM(ABC): } @abstractmethod + @metrics_decorator def run(self, task: Optional[str] = None, *args, **kwargs) -> str: """generate text using language model""" pass @@ -381,3 +384,48 @@ class AbstractLLM(ABC): TOKENS: {_num_tokens} Tokens/SEC: {_time_for_generation} """ + + def time_to_first_token(self, prompt: str) -> float: + """Time to first token + + Args: + prompt (str): _description_ + + Returns: + float: _description_ + """ + start_time = time.time() + self.track_resource_utilization( + prompt + ) # assuming `generate` is a method that generates tokens + first_token_time = time.time() + return first_token_time - start_time + + def generation_latency(self, prompt: str) -> float: + """generation latency + + Args: + prompt (str): _description_ + + Returns: + float: _description_ + """ + start_time = time.time() + self.run(prompt) + end_time = time.time() + return end_time - start_time + + def throughput(self, prompts: List[str]) -> float: + """throughput + + Args: + prompts (): _description_ + + Returns: + float: _description_ + """ + start_time = time.time() + for prompt in prompts: + self.run(prompt) + end_time = time.time() + return len(prompts) / (end_time - start_time) diff --git a/swarms/models/base_multimodal_model.py b/swarms/models/base_multimodal_model.py index 2eb8c389..c4a5890a 100644 --- a/swarms/models/base_multimodal_model.py +++ b/swarms/models/base_multimodal_model.py @@ -108,7 +108,11 @@ class BaseMultiModalModel: pass def __call__( - self, task: str = None, img: str = None, *args, **kwargs + self, + task: Optional[str] = None, + img: Optional[str] = None, + *args, + **kwargs, ): """Call the model diff --git a/swarms/models/base_tts.py b/swarms/models/base_tts.py index 0faaf6ff..60896856 100644 --- a/swarms/models/base_tts.py +++ b/swarms/models/base_tts.py @@ -1,7 +1,7 @@ import wave from typing import Optional from swarms.models.base_llm import AbstractLLM -from abc import ABC, abstractmethod +from abc import abstractmethod class BaseTTSModel(AbstractLLM): diff --git a/tests/models/test_distilled_whisperx.py b/swarms/models/base_vision_model.py similarity index 100% rename from tests/models/test_distilled_whisperx.py rename to swarms/models/base_vision_model.py diff --git a/swarms/models/bioclip.py b/swarms/models/bioclip.py deleted file mode 100644 index e2d070af..00000000 --- a/swarms/models/bioclip.py +++ /dev/null @@ -1,183 +0,0 @@ -""" - - -BiomedCLIP-PubMedBERT_256-vit_base_patch16_224 -https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224 -BiomedCLIP is a biomedical vision-language foundation model that is pretrained on PMC-15M, -a dataset of 15 million figure-caption pairs extracted from biomedical research articles in PubMed Central, using contrastive learning. It uses PubMedBERT as the text encoder and Vision Transformer as the image encoder, with domain-specific adaptations. It can perform various vision-language processing (VLP) tasks such as cross-modal retrieval, image classification, and visual question answering. BiomedCLIP establishes new state of the art in a wide range of standard datasets, and substantially outperforms prior VLP approaches: - - - -Citation -@misc{https://doi.org/10.48550/arXiv.2303.00915, - doi = {10.48550/ARXIV.2303.00915}, - url = {https://arxiv.org/abs/2303.00915}, - author = {Zhang, Sheng and Xu, Yanbo and Usuyama, Naoto and Bagga, Jaspreet and Tinn, Robert and Preston, Sam and Rao, Rajesh and Wei, Mu and Valluri, Naveen and Wong, Cliff and Lungren, Matthew and Naumann, Tristan and Poon, Hoifung}, - title = {Large-Scale Domain-Specific Pretraining for Biomedical Vision-Language Processing}, - publisher = {arXiv}, - year = {2023}, -} - -Model Use -How to use -Please refer to this example notebook. - -Intended Use -This model is intended to be used solely for (I) future research on visual-language processing and (II) reproducibility of the experimental results reported in the reference paper. - -Primary Intended Use -The primary intended use is to support AI researchers building on top of this work. BiomedCLIP and its associated models should be helpful for exploring various biomedical VLP research questions, especially in the radiology domain. - -Out-of-Scope Use -Any deployed use case of the model --- commercial or otherwise --- is currently out of scope. Although we evaluated the models using a broad set of publicly-available research benchmarks, the models and evaluations are not intended for deployed use cases. Please refer to the associated paper for more details. - -Data -This model builds upon PMC-15M dataset, which is a large-scale parallel image-text dataset for biomedical vision-language processing. It contains 15 million figure-caption pairs extracted from biomedical research articles in PubMed Central. It covers a diverse range of biomedical image types, such as microscopy, radiography, histology, and more. - -Limitations -This model was developed using English corpora, and thus can be considered English-only. - -Further information -Please refer to the corresponding paper, "Large-Scale Domain-Specific Pretraining for Biomedical Vision-Language Processing" for additional details on the model training and evaluation. -""" - -import open_clip -import torch -from PIL import Image -import matplotlib.pyplot as plt - - -class BioClip: - """ - BioClip - - Args: - model_path (str): path to the model - - Attributes: - model_path (str): path to the model - model (torch.nn.Module): the model - preprocess_train (torchvision.transforms.Compose): the preprocessing pipeline for training - preprocess_val (torchvision.transforms.Compose): the preprocessing pipeline for validation - tokenizer (open_clip.Tokenizer): the tokenizer - device (torch.device): the device to run the model on - - Methods: - __call__(self, img_path: str, labels: list, template: str = 'this is a photo of ', context_length: int = 256): - returns a dictionary of labels and their probabilities - plot_image_with_metadata(img_path: str, metadata: dict): plots the image with the metadata - - Usage: - clip = BioClip('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224') - - labels = [ - 'adenocarcinoma histopathology', - 'brain MRI', - 'covid line chart', - 'squamous cell carcinoma histopathology', - 'immunohistochemistry histopathology', - 'bone X-ray', - 'chest X-ray', - 'pie chart', - 'hematoxylin and eosin histopathology' - ] - - result = clip("your_image_path.jpg", labels) - metadata = {'filename': "your_image_path.jpg".split('/')[-1], 'top_probs': result} - clip.plot_image_with_metadata("your_image_path.jpg", metadata) - - - """ - - def __init__(self, model_path: str): - self.model_path = model_path - ( - self.model, - self.preprocess_train, - self.preprocess_val, - ) = open_clip.create_model_and_transforms(model_path) - self.tokenizer = open_clip.get_tokenizer(model_path) - self.device = ( - torch.device("cuda") - if torch.cuda.is_available() - else torch.device("cpu") - ) - self.model.to(self.device) - self.model.eval() - - def __call__( - self, - img_path: str, - labels: list, - template: str = "this is a photo of ", - context_length: int = 256, - ): - image = torch.stack( - [self.preprocess_val(Image.open(img_path))] - ).to(self.device) - texts = self.tokenizer( - [template + l for l in labels], - context_length=context_length, - ).to(self.device) - - with torch.no_grad(): - image_features, text_features, logit_scale = self.model( - image, texts - ) - logits = ( - (logit_scale * image_features @ text_features.t()) - .detach() - .softmax(dim=-1) - ) - sorted_indices = torch.argsort( - logits, dim=-1, descending=True - ) - logits = logits.cpu().numpy() - sorted_indices = sorted_indices.cpu().numpy() - - results = {} - for idx in sorted_indices[0]: - label = labels[idx] - prob = logits[0][idx] - results[label] = prob - return results - - @staticmethod - def plot_image_with_metadata(img_path: str, metadata: dict): - img = Image.open(img_path) - fig, ax = plt.subplots(figsize=(5, 5)) - ax.imshow(img) - ax.axis("off") - title = ( - metadata["filename"] - + "\n" - + "\n".join( - [ - f"{k}: {v*100:.1f}" - for k, v in metadata["top_probs"].items() - ] - ) - ) - ax.set_title(title, fontsize=14) - plt.tight_layout() - plt.show() - - -# Usage -# clip = BioClip('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224') - -# labels = [ -# 'adenocarcinoma histopathology', -# 'brain MRI', -# 'covid line chart', -# 'squamous cell carcinoma histopathology', -# 'immunohistochemistry histopathology', -# 'bone X-ray', -# 'chest X-ray', -# 'pie chart', -# 'hematoxylin and eosin histopathology' -# ] - -# result = clip("your_image_path.jpg", labels) -# metadata = {'filename': "your_image_path.jpg".split('/')[-1], 'top_probs': result} -# clip.plot_image_with_metadata("your_image_path.jpg", metadata) diff --git a/swarms/models/fastvit.py b/swarms/models/fastvit.py index a6fc31f8..e97fb496 100644 --- a/swarms/models/fastvit.py +++ b/swarms/models/fastvit.py @@ -39,14 +39,11 @@ class FastViT: Returns: ClassificationResult: a pydantic BaseModel containing the class ids and confidences of the model's predictions - Example: >>> fastvit = FastViT() >>> result = fastvit(img="path_to_image.jpg", confidence_threshold=0.5) - To use, create a json file called: fast_vit_classes.json - """ def __init__(self): @@ -62,7 +59,7 @@ class FastViT: def __call__( self, img: str, confidence_threshold: float = 0.5 ) -> ClassificationResult: - """classifies the input image and returns the top k classes and their probabilities""" + """Classifies the input image and returns the top k classes and their probabilities""" img = Image.open(img).convert("RGB") img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE) with torch.no_grad(): @@ -81,7 +78,6 @@ class FastViT: # Convert to Python lists and map class indices to labels if needed top_probs = top_probs.cpu().numpy().tolist() top_classes = top_classes.cpu().numpy().tolist() - # top_class_labels = [FASTVIT_IMAGENET_1K_CLASSES[i] for i in top_classes] # Uncomment if class labels are needed return ClassificationResult( class_id=top_classes, confidence=top_probs diff --git a/swarms/models/gemini.py b/swarms/models/gemini.py index 8cb09ca5..d12ea7d9 100644 --- a/swarms/models/gemini.py +++ b/swarms/models/gemini.py @@ -174,7 +174,10 @@ class Gemini(BaseMultiModalModel): return response.text else: response = self.model.generate_content( - prepare_prompt, stream=self.stream, *args, **kwargs + prepare_prompt, + stream=self.stream, + *args, + **kwargs, ) return response.text except Exception as error: diff --git a/swarms/models/kosmos2.py b/swarms/models/kosmos2.py deleted file mode 100644 index 9a9a0de3..00000000 --- a/swarms/models/kosmos2.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import List, Tuple - -from PIL import Image -from pydantic import BaseModel, model_validator, validator -from transformers import AutoModelForVision2Seq, AutoProcessor - - -# Assuming the Detections class represents the output of the model prediction -class Detections(BaseModel): - xyxy: List[Tuple[float, float, float, float]] - class_id: List[int] - confidence: List[float] - - @model_validator - def check_length(cls, values): - assert ( - len(values.get("xyxy")) - == len(values.get("class_id")) - == len(values.get("confidence")) - ), "All fields must have the same length." - return values - - @validator( - "xyxy", "class_id", "confidence", pre=True, each_item=True - ) - def check_not_empty(cls, v): - if isinstance(v, list) and len(v) == 0: - raise ValueError("List must not be empty") - return v - - @classmethod - def empty(cls): - return cls(xyxy=[], class_id=[], confidence=[]) - - -class Kosmos2(BaseModel): - """ - Kosmos2 - - Args: - ------ - model: AutoModelForVision2Seq - processor: AutoProcessor - - Usage: - ------ - >>> from swarms import Kosmos2 - >>> from swarms.models.kosmos2 import Detections - >>> from PIL import Image - >>> model = Kosmos2.initialize() - >>> image = Image.open("path_to_image.jpg") - >>> detections = model(image) - >>> print(detections) - - """ - - model: AutoModelForVision2Seq - processor: AutoProcessor - - @classmethod - def initialize(cls): - model = AutoModelForVision2Seq.from_pretrained( - "ydshieh/kosmos-2-patch14-224", trust_remote_code=True - ) - processor = AutoProcessor.from_pretrained( - "ydshieh/kosmos-2-patch14-224", trust_remote_code=True - ) - return cls(model=model, processor=processor) - - def __call__(self, img: str) -> Detections: - image = Image.open(img) - prompt = "An image of" - - inputs = self.processor( - text=prompt, images=image, return_tensors="pt" - ) - outputs = self.model.generate( - **inputs, use_cache=True, max_new_tokens=64 - ) - - generated_text = self.processor.batch_decode( - outputs, skip_special_tokens=True - )[0] - - # The actual processing of generated_text to entities would go here - # For the purpose of this example, assume a mock function 'extract_entities' exists: - entities = self.extract_entities(generated_text) - - # Convert entities to detections format - detections = self.process_entities_to_detections( - entities, image - ) - return detections - - def extract_entities( - self, text: str - ) -> List[Tuple[str, Tuple[float, float, float, float]]]: - # Placeholder function for entity extraction - # This should be replaced with the actual method of extracting entities - return [] - - def process_entities_to_detections( - self, - entities: List[Tuple[str, Tuple[float, float, float, float]]], - image: Image.Image, - ) -> Detections: - if not entities: - return Detections.empty() - - class_ids = [0] * len( - entities - ) # Replace with actual class ID extraction logic - xyxys = [ - ( - e[1][0] * image.width, - e[1][1] * image.height, - e[1][2] * image.width, - e[1][3] * image.height, - ) - for e in entities - ] - confidences = [1.0] * len(entities) # Placeholder confidence - - return Detections( - xyxy=xyxys, class_id=class_ids, confidence=confidences - ) - - -# Usage: -# kosmos2 = Kosmos2.initialize() -# detections = kosmos2(img="path_to_image.jpg") diff --git a/swarms/models/kosmos_two.py b/swarms/models/kosmos_two.py index 3b1d4233..a0c5a86a 100644 --- a/swarms/models/kosmos_two.py +++ b/swarms/models/kosmos_two.py @@ -8,6 +8,8 @@ import torchvision.transforms as T from PIL import Image from transformers import AutoModelForVision2Seq, AutoProcessor +from swarms.models.base_multimodal_model import BaseMultimodalModel + # utils def is_overlapping(rect1, rect2): @@ -16,7 +18,7 @@ def is_overlapping(rect1, rect2): return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4) -class Kosmos: +class Kosmos(BaseMultimodalModel): """ Kosmos model by Yen-Chun Shieh @@ -35,9 +37,14 @@ class Kosmos: def __init__( self, model_name="ydshieh/kosmos-2-patch14-224", + max_new_tokens: int = 64, *args, **kwargs, ): + super(Kosmos, self).__init__(*args, **kwargs) + + self.max_new_tokens = max_new_tokens + self.model = AutoModelForVision2Seq.from_pretrained( model_name, trust_remote_code=True, *args, **kwargs ) @@ -45,81 +52,75 @@ class Kosmos: model_name, trust_remote_code=True, *args, **kwargs ) - def get_image(self, url): - """Image""" + def get_image(self, url: str): + """Get image from url + + Args: + url (str): url of image + + Returns: + _type_: _description_ + """ return Image.open(requests.get(url, stream=True).raw) - def run(self, prompt, image): - """Run Kosmos""" - inputs = self.processor( - text=prompt, images=image, return_tensors="pt" - ) - generated_ids = self.model.generate( - pixel_values=inputs["pixel_values"], - input_ids=inputs["input_ids"][:, :-1], - attention_mask=inputs["attention_mask"][:, :-1], - img_features=None, - img_attn_mask=inputs["img_attn_mask"][:, :-1], - use_cache=True, - max_new_tokens=64, - ) - generated_texts = self.processor.batch_decode( - generated_ids, - skip_special_tokens=True, - )[0] - processed_text, entities = ( - self.processor.post_process_generation(generated_texts) - ) + def run(self, task: str, image: str, *args, **kwargs): + """Run the model - def __call__(self, prompt, image): - """Run call""" + Args: + task (str): task to run + image (str): img url + """ inputs = self.processor( - text=prompt, images=image, return_tensors="pt" + text=task, images=image, return_tensors="pt" ) generated_ids = self.model.generate( pixel_values=inputs["pixel_values"], input_ids=inputs["input_ids"][:, :-1], attention_mask=inputs["attention_mask"][:, :-1], - img_features=None, + image_embeds=None, img_attn_mask=inputs["img_attn_mask"][:, :-1], use_cache=True, - max_new_tokens=64, + max_new_tokens=self.max_new_tokens, ) + generated_texts = self.processor.batch_decode( generated_ids, skip_special_tokens=True, )[0] + processed_text, entities = ( self.processor.post_process_generation(generated_texts) ) + return processed_text, entities + # tasks def multimodal_grounding(self, phrase, image_url): - prompt = f" {phrase} " - self.run(prompt, image_url) + task = f" {phrase} " + self.run(task, image_url) def referring_expression_comprehension(self, phrase, image_url): - prompt = f" {phrase} " - self.run(prompt, image_url) + task = f" {phrase} " + self.run(task, image_url) def referring_expression_generation(self, phrase, image_url): - prompt = ( + task = ( "" " It is" ) - self.run(prompt, image_url) + self.run(task, image_url) def grounded_vqa(self, question, image_url): - prompt = f" Question: {question} Answer:" - self.run(prompt, image_url) + task = f" Question: {question} Answer:" + self.run(task, image_url) def grounded_image_captioning(self, image_url): - prompt = " An image of" - self.run(prompt, image_url) + task = " An image of" + self.run(task, image_url) def grounded_image_captioning_detailed(self, image_url): - prompt = " Describe this image in detail" - self.run(prompt, image_url) + task = " Describe this image in detail" + self.run(task, image_url) def draw_entity_boxes_on_image( image, entities, show=False, save_path=None @@ -320,7 +321,7 @@ class Kosmos: return new_image - def generate_boxees(self, prompt, image_url): + def generate_boxees(self, task, image_url): image = self.get_image(image_url) - processed_text, entities = self.process_prompt(prompt, image) + processed_text, entities = self.process_task(task, image) self.draw_entity_boxes_on_image(image, entities, show=True) diff --git a/swarms/models/mixtral.py b/swarms/models/mixtral.py new file mode 100644 index 00000000..6f3a9c7d --- /dev/null +++ b/swarms/models/mixtral.py @@ -0,0 +1,73 @@ +from typing import Optional +from transformers import AutoModelForCausalLM, AutoTokenizer +from swarms.models.base_llm import AbstractLLM + + +class Mixtral(AbstractLLM): + """Mixtral model. + + Args: + model_name (str): The name or path of the pre-trained Mixtral model. + max_new_tokens (int): The maximum number of new tokens to generate. + *args: Variable length argument list. + + + Examples: + >>> from swarms.models import Mixtral + >>> mixtral = Mixtral() + >>> mixtral.run("Test task") + 'Generated text' + """ + + def __init__( + self, + model_name: str = "mistralai/Mixtral-8x7B-v0.1", + max_new_tokens: int = 500, + *args, + **kwargs, + ): + """ + Initializes a Mixtral model. + + Args: + model_name (str): The name or path of the pre-trained Mixtral model. + max_new_tokens (int): The maximum number of new tokens to generate. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + super().__init__(*args, **kwargs) + self.model_name = model_name + self.max_new_tokens = max_new_tokens + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, *args, **kwargs + ) + + def run(self, task: Optional[str] = None, **kwargs): + """ + Generates text based on the given task. + + Args: + task (str, optional): The task or prompt for text generation. + + Returns: + str: The generated text. + """ + try: + inputs = self.tokenizer(task, return_tensors="pt") + + outputs = self.model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + **kwargs, + ) + + out = self.tokenizer.decode( + outputs[0], + skip_special_tokens=True, + ) + + return out + except Exception as error: + print(f"There is an error: {error} in Mixtral model.") + raise error diff --git a/swarms/models/openai_assistant.py b/swarms/models/openai_assistant.py deleted file mode 100644 index 6d0c518f..00000000 --- a/swarms/models/openai_assistant.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import Dict, List, Optional -from dataclass import dataclass - -from swarms.models import OpenAI - - -@dataclass -class OpenAIAssistant: - name: str = "OpenAI Assistant" - instructions: str = None - tools: List[Dict] = None - model: str = None - openai_api_key: str = None - temperature: float = 0.5 - max_tokens: int = 100 - stop: List[str] = None - echo: bool = False - stream: bool = False - log: bool = False - presence: bool = False - dashboard: bool = False - debug: bool = False - max_loops: int = 5 - stopping_condition: Optional[str] = None - loop_interval: int = 1 - retry_attempts: int = 3 - retry_interval: int = 1 - interactive: bool = False - dynamic_temperature: bool = False - state: Dict = None - response_filters: List = None - response_filter: Dict = None - response_filter_name: str = None - response_filter_value: str = None - response_filter_type: str = None - response_filter_action: str = None - response_filter_action_value: str = None - response_filter_action_type: str = None - response_filter_action_name: str = None - client = OpenAI() - role: str = "user" - instructions: str = None - - def create_assistant(self, task: str): - assistant = self.client.create_assistant( - name=self.name, - instructions=self.instructions, - tools=self.tools, - model=self.model, - ) - return assistant - - def create_thread(self): - thread = self.client.beta.threads.create() - return thread - - def add_message_to_thread(self, thread_id: str, message: str): - message = self.client.beta.threads.add_message( - thread_id=thread_id, role=self.user, content=message - ) - return message - - def run(self, task: str): - run = self.client.beta.threads.runs.create( - thread_id=self.create_thread().id, - assistant_id=self.create_assistant().id, - instructions=self.instructions, - ) - - out = self.client.beta.threads.runs.retrieve( - thread_id=run.thread_id, run_id=run.id - ) - - return out diff --git a/swarms/models/openai_function_caller.py b/swarms/models/openai_function_caller.py index 6542e457..e6822793 100644 --- a/swarms/models/openai_function_caller.py +++ b/swarms/models/openai_function_caller.py @@ -234,10 +234,10 @@ class OpenAIFunctionCaller: ) ) - def call(self, prompt: str) -> Dict: - response = openai.Completion.create( + def call(self, task: str, *args, **kwargs) -> Dict: + return openai.Completion.create( engine=self.model, - prompt=prompt, + prompt=task, max_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, @@ -253,9 +253,10 @@ class OpenAIFunctionCaller: user=self.user, messages=self.messages, timeout_sec=self.timeout_sec, + *args, + **kwargs, ) - return response - def run(self, prompt: str) -> str: - response = self.call(prompt) + def run(self, task: str, *args, **kwargs) -> str: + response = self.call(task, *args, **kwargs) return response["choices"][0]["text"].strip() diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py deleted file mode 100644 index e9a599d0..00000000 --- a/swarms/models/simple_ada.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -from openai import OpenAI - -client = OpenAI() - - -def get_ada_embeddings( - text: str, model: str = "text-embedding-ada-002" -): - """ - Simple function to get embeddings from ada - - Usage: - >>> get_ada_embeddings("Hello World") - >>> get_ada_embeddings("Hello World", model="text-embedding-ada-001") - - """ - - text = text.replace("\n", " ") - - return client.embeddings.create(input=[text], model=model)[ - "data" - ][0]["embedding"] diff --git a/swarms/models/stable_diffusion.py b/swarms/models/stable_diffusion.py index 7b363d02..a0068531 100644 --- a/swarms/models/stable_diffusion.py +++ b/swarms/models/stable_diffusion.py @@ -140,6 +140,18 @@ class StableDiffusion: return image_paths def generate_and_move_image(self, prompt, iteration, folder_path): + """ + Generates an image based on the given prompt and moves it to the specified folder. + + Args: + prompt (str): The prompt used to generate the image. + iteration (int): The iteration number. + folder_path (str): The path to the folder where the image will be moved. + + Returns: + str: The path of the moved image. + + """ # Generate the image image_paths = self.run(prompt) if not image_paths: diff --git a/swarms/models/vllm.py b/swarms/models/vllm.py index 0ea4be4a..58745a75 100644 --- a/swarms/models/vllm.py +++ b/swarms/models/vllm.py @@ -1,12 +1,19 @@ +import torch from swarms.models.base_llm import AbstractLLM +import subprocess -try: - from vllm import LLM, SamplingParams -except ImportError as error: - print(f"[ERROR] [vLLM] {error}") - # subprocess.run(["pip", "install", "vllm"]) - # raise error - raise error +if torch.cuda.is_available() or torch.cuda.device_count() > 0: + # Download vllm with pip + try: + subprocess.run(["pip", "install", "vllm"]) + from vllm import LLM, SamplingParams + except Exception as error: + print(f"[ERROR] [vLLM] {error}") + raise error +else: + from swarms.models.huggingface import HuggingfaceLLM as LLM + + SamplingParams = None class vLLM(AbstractLLM): @@ -83,8 +90,9 @@ class vLLM(AbstractLLM): _type_: _description_ """ try: - outputs = self.llm.generate(task, self.sampling_params) - return outputs + return self.llm.generate( + task, self.sampling_params, *args, **kwargs + ) except Exception as error: print(f"[ERROR] [vLLM] [run] {error}") raise error diff --git a/swarms/models/whisperx_model.py b/swarms/models/whisperx_model.py deleted file mode 100644 index e3b76fae..00000000 --- a/swarms/models/whisperx_model.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -import subprocess - -try: - import whisperx - from pydub import AudioSegment - from pytube import YouTube -except Exception as error: - print("Error importing pytube. Please install pytube manually.") - print("pip install pytube") - print("pip install pydub") - print("pip install whisperx") - print(f"Pytube error: {error}") - - -class WhisperX: - def __init__( - self, - video_url, - audio_format="mp3", - device="cuda", - batch_size=16, - compute_type="float16", - hf_api_key=None, - ): - """ - # Example usage - video_url = "url" - speech_to_text = WhisperX(video_url) - transcription = speech_to_text.transcribe_youtube_video() - print(transcription) - - """ - self.video_url = video_url - self.audio_format = audio_format - self.device = device - self.batch_size = batch_size - self.compute_type = compute_type - self.hf_api_key = hf_api_key - - def install(self): - subprocess.run(["pip", "install", "whisperx"]) - subprocess.run(["pip", "install", "pytube"]) - subprocess.run(["pip", "install", "pydub"]) - - def download_youtube_video(self): - audio_file = f"video.{self.audio_format}" - - # Download video 📥 - yt = YouTube(self.video_url) - yt_stream = yt.streams.filter(only_audio=True).first() - yt_stream.download(filename="video.mp4") - - # Convert video to audio 🎧 - video = AudioSegment.from_file("video.mp4", format="mp4") - video.export(audio_file, format=self.audio_format) - os.remove("video.mp4") - - return audio_file - - def transcribe_youtube_video(self): - audio_file = self.download_youtube_video() - - device = "cuda" - batch_size = 16 - compute_type = "float16" - - # 1. Transcribe with original Whisper (batched) 🗣️ - model = whisperx.load_model( - "large-v2", device, compute_type=compute_type - ) - audio = whisperx.load_audio(audio_file) - result = model.transcribe(audio, batch_size=batch_size) - - # 2. Align Whisper output 🔍 - model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=device - ) - result = whisperx.align( - result["segments"], - model_a, - metadata, - audio, - device, - return_char_alignments=False, - ) - - # 3. Assign speaker labels 🏷️ - diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=device - ) - diarize_model(audio_file) - - try: - segments = result["segments"] - transcription = " ".join( - segment["text"] for segment in segments - ) - return transcription - except KeyError: - print("The key 'segments' is not found in the result.") - - def transcribe(self, audio_file): - model = whisperx.load_model( - "large-v2", self.device, self.compute_type - ) - audio = whisperx.load_audio(audio_file) - result = model.transcribe(audio, batch_size=self.batch_size) - - # 2. Align Whisper output 🔍 - model_a, metadata = whisperx.load_align_model( - language_code=result["language"], device=self.device - ) - - result = whisperx.align( - result["segments"], - model_a, - metadata, - audio, - self.device, - return_char_alignments=False, - ) - - # 3. Assign speaker labels 🏷️ - diarize_model = whisperx.DiarizationPipeline( - use_auth_token=self.hf_api_key, device=self.device - ) - - diarize_model(audio_file) - - try: - segments = result["segments"] - transcription = " ".join( - segment["text"] for segment in segments - ) - return transcription - except KeyError: - print("The key 'segments' is not found in the result.") diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index e389ed76..4a58ea8d 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -1,5 +1,21 @@ from swarms.structs.agent import Agent from swarms.structs.sequential_workflow import SequentialWorkflow from swarms.structs.autoscaler import AutoScaler +from swarms.structs.conversation import Conversation +from swarms.structs.schemas import ( + TaskInput, + Artifact, + ArtifactUpload, + StepInput, +) -__all__ = ["Agent", "SequentialWorkflow", "AutoScaler"] +__all__ = [ + "Agent", + "SequentialWorkflow", + "AutoScaler", + "Conversation", + "TaskInput", + "Artifact", + "ArtifactUpload", + "StepInput", +] diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 9d48791e..be5c7121 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -12,7 +12,6 @@ from termcolor import colored from swarms.memory.base_vectordb import VectorDatabase from swarms.prompts.agent_system_prompts import ( - FLOW_SYSTEM_PROMPT, AGENT_SYSTEM_PROMPT_3, agent_system_prompt_2, ) diff --git a/swarms/structs/autoscaler.py b/swarms/structs/autoscaler.py index 1cb31333..f26247d5 100644 --- a/swarms/structs/autoscaler.py +++ b/swarms/structs/autoscaler.py @@ -3,8 +3,6 @@ import queue import threading from time import sleep from typing import Callable, Dict, List, Optional -import asyncio -import concurrent.futures from termcolor import colored @@ -14,9 +12,10 @@ from swarms.utils.decorators import ( log_decorator, timing_decorator, ) +from swarms.structs.base import BaseStructure -class AutoScaler: +class AutoScaler(BaseStructure): """ AutoScaler class @@ -262,11 +261,17 @@ class AutoScaler: def balance_load(self): """Distributes tasks among agents based on their current load.""" - while not self.task_queue.empty(): - for agent in self.agents_pool: - if agent.can_accept_task(): - task = self.task_queue.get() - agent.run(task) + try: + while not self.task_queue.empty(): + for agent in self.agents_pool: + if agent.can_accept_task(): + task = self.task_queue.get() + agent.run(task) + except Exception as error: + print( + f"Error balancing load: {error} try again with a new" + " task" + ) def set_scaling_strategy( self, strategy: Callable[[int, int], int] @@ -276,17 +281,23 @@ class AutoScaler: def execute_scaling_strategy(self): """Execute the custom scaling strategy if defined.""" - if hasattr(self, "custom_scale_strategy"): - scale_amount = self.custom_scale_strategy( - self.task_queue.qsize(), len(self.agents_pool) + try: + if hasattr(self, "custom_scale_strategy"): + scale_amount = self.custom_scale_strategy( + self.task_queue.qsize(), len(self.agents_pool) + ) + if scale_amount > 0: + for _ in range(scale_amount): + self.agents_pool.append(self.agent()) + elif scale_amount < 0: + for _ in range(abs(scale_amount)): + if len(self.agents_pool) > 10: + del self.agents_pool[-1] + except Exception as error: + print( + f"Error executing scaling strategy: {error} try again" + " with a new task" ) - if scale_amount > 0: - for _ in range(scale_amount): - self.agents_pool.append(self.agent()) - elif scale_amount < 0: - for _ in range(abs(scale_amount)): - if len(self.agents_pool) > 10: - del self.agents_pool[-1] def report_agent_metrics(self) -> Dict[str, List[float]]: """Collects and reports metrics from each agent.""" diff --git a/swarms/structs/base.py b/swarms/structs/base.py index 7d365b23..adfa974d 100644 --- a/swarms/structs/base.py +++ b/swarms/structs/base.py @@ -1,6 +1,6 @@ import json import os -from abc import ABC, abstractmethod +from abc import ABC from typing import Optional, Any, Dict, List from datetime import datetime import asyncio diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py new file mode 100644 index 00000000..ccb346e6 --- /dev/null +++ b/swarms/structs/conversation.py @@ -0,0 +1,309 @@ +import datetime +import json + +from termcolor import colored + +from swarms.memory.base_db import AbstractDatabase +from swarms.structs.base import BaseStructure + + +class Conversation(BaseStructure): + """ + Conversation class + + + Attributes: + time_enabled (bool): whether to enable time + conversation_history (list): list of messages in the conversation + + + Examples: + >>> conv = Conversation() + >>> conv.add("user", "Hello, world!") + >>> conv.add("assistant", "Hello, user!") + >>> conv.display_conversation() + user: Hello, world! + + """ + + def __init__( + self, + time_enabled: bool = False, + database: AbstractDatabase = None, + autosave: bool = True, + save_filepath: str = "/runs/conversation.json", + *args, + **kwargs, + ): + super().__init__() + self.time_enabled = time_enabled + self.database = database + self.autosave = autosave + self.save_filepath = save_filepath + self.conversation_history = [] + + def add(self, role: str, content: str, *args, **kwargs): + """Add a message to the conversation history + + Args: + role (str): The role of the speaker + content (str): The content of the message + + """ + if self.time_enabled: + now = datetime.datetime.now() + timestamp = now.strftime("%Y-%m-%d %H:%M:%S") + message = { + "role": role, + "content": content, + "timestamp": timestamp, + } + else: + message = { + "role": role, + "content": content, + } + + self.conversation_history.append(message) + + if self.autosave: + self.save_as_json(self.save_filepath) + + def delete(self, index: str): + """Delete a message from the conversation history + + Args: + index (str): index of the message to delete + """ + self.conversation_history.pop(index) + + def update(self, index: str, role, content): + """Update a message in the conversation history + + Args: + index (str): index of the message to update + role (_type_): role of the speaker + content (_type_): content of the message + """ + self.conversation_history[index] = { + "role": role, + "content": content, + } + + def query(self, index: str): + """Query a message in the conversation history + + Args: + index (str): index of the message to query + + Returns: + str: the message + """ + return self.conversation_history[index] + + def search(self, keyword: str): + """Search for a message in the conversation history + + Args: + keyword (str): Keyword to search for + + Returns: + str: description + """ + return [ + msg + for msg in self.conversation_history + if keyword in msg["content"] + ] + + def display_conversation(self, detailed: bool = False): + """Display the conversation history + + Args: + detailed (bool, optional): detailed. Defaults to False. + """ + role_to_color = { + "system": "red", + "user": "green", + "assistant": "blue", + "function": "magenta", + } + for message in self.conversation_history: + print( + colored( + f"{message['role']}: {message['content']}\n\n", + role_to_color[message["role"]], + ) + ) + + def export_conversation(self, filename: str, *args, **kwargs): + """Export the conversation history to a file + + Args: + filename (str): filename to export to + """ + with open(filename, "w") as f: + for message in self.conversation_history: + f.write(f"{message['role']}: {message['content']}\n") + + def import_conversation(self, filename: str): + """Import a conversation history from a file + + Args: + filename (str): filename to import from + """ + with open(filename, "r") as f: + for line in f: + role, content = line.split(": ", 1) + self.add(role, content.strip()) + + def count_messages_by_role(self): + """Count the number of messages by role""" + counts = { + "system": 0, + "user": 0, + "assistant": 0, + "function": 0, + } + for message in self.conversation_history: + counts[message["role"]] += 1 + return counts + + def return_history_as_string(self): + """Return the conversation history as a string + + Returns: + str: the conversation history + """ + return "\n".join( + [ + f"{message['role']}: {message['content']}\n\n" + for message in self.conversation_history + ] + ) + + def save_as_json(self, filename: str): + """Save the conversation history as a JSON file + + Args: + filename (str): Save the conversation history as a JSON file + """ + # Save the conversation history as a JSON file + with open(filename, "w") as f: + json.dump(self.conversation_history, f) + + def load_from_json(self, filename: str): + """Load the conversation history from a JSON file + + Args: + filename (str): filename to load from + """ + # Load the conversation history from a JSON file + with open(filename, "r") as f: + self.conversation_history = json.load(f) + + def search_keyword_in_conversation(self, keyword: str): + """Search for a keyword in the conversation history + + Args: + keyword (str): keyword to search for + + Returns: + str: description + """ + return [ + msg + for msg in self.conversation_history + if keyword in msg["content"] + ] + + def pretty_print_conversation(self, messages): + """Pretty print the conversation history + + Args: + messages (str): messages to print + """ + role_to_color = { + "system": "red", + "user": "green", + "assistant": "blue", + "tool": "magenta", + } + + for message in messages: + if message["role"] == "system": + print( + colored( + f"system: {message['content']}\n", + role_to_color[message["role"]], + ) + ) + elif message["role"] == "user": + print( + colored( + f"user: {message['content']}\n", + role_to_color[message["role"]], + ) + ) + elif message["role"] == "assistant" and message.get( + "function_call" + ): + print( + colored( + f"assistant: {message['function_call']}\n", + role_to_color[message["role"]], + ) + ) + elif message["role"] == "assistant" and not message.get( + "function_call" + ): + print( + colored( + f"assistant: {message['content']}\n", + role_to_color[message["role"]], + ) + ) + elif message["role"] == "tool": + print( + colored( + ( + f"function ({message['name']}):" + f" {message['content']}\n" + ), + role_to_color[message["role"]], + ) + ) + + def add_to_database(self, *args, **kwargs): + """Add the conversation history to the database""" + self.database.add("conversation", self.conversation_history) + + def query_from_database(self, query, *args, **kwargs): + """Query the conversation history from the database""" + return self.database.query("conversation", query) + + def delete_from_database(self, *args, **kwargs): + """Delete the conversation history from the database""" + self.database.delete("conversation") + + def update_from_database(self, *args, **kwargs): + """Update the conversation history from the database""" + self.database.update( + "conversation", self.conversation_history + ) + + def get_from_database(self, *args, **kwargs): + """Get the conversation history from the database""" + return self.database.get("conversation") + + def execute_query_from_database(self, query, *args, **kwargs): + """Execute a query on the database""" + return self.database.execute_query(query) + + def fetch_all_from_database(self, *args, **kwargs): + """Fetch all from the database""" + return self.database.fetch_all() + + def fetch_one_from_database(self, *args, **kwargs): + """Fetch one from the database""" + return self.database.fetch_one() diff --git a/swarms/memory/schemas.py b/swarms/structs/schemas.py similarity index 93% rename from swarms/memory/schemas.py rename to swarms/structs/schemas.py index 9147a909..f7f5441e 100644 --- a/swarms/memory/schemas.py +++ b/swarms/structs/schemas.py @@ -17,6 +17,15 @@ class TaskInput(BaseModel): class Artifact(BaseModel): + """ + Represents an artifact. + + Attributes: + artifact_id (str): Id of the artifact. + file_name (str): Filename of the artifact. + relative_path (str, optional): Relative path of the artifact in the agent's workspace. + """ + artifact_id: str = Field( ..., description="Id of the artifact", diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py index 9ddbd324..72fc7199 100644 --- a/swarms/utils/__init__.py +++ b/swarms/utils/__init__.py @@ -5,14 +5,21 @@ from swarms.utils.parse_code import ( ) from swarms.utils.pdf_to_text import pdf_to_text from swarms.utils.math_eval import math_eval - -# from swarms.utils.phoenix_handler import phoenix_trace_decorator +from swarms.utils.llm_metrics_decorator import metrics_decorator +from swarms.utils.device_checker_cuda import check_device +from swarms.utils.load_model_torch import load_model_torch +from swarms.utils.prep_torch_model_inference import ( + prep_torch_inference, +) __all__ = [ "display_markdown_message", "SubprocessCodeInterpreter", "extract_code_in_backticks_in_string", "pdf_to_text", - # "phoenix_trace_decorator", "math_eval", + "metrics_decorator", + "check_device", + "load_model_torch", + "prep_torch_inference", ] diff --git a/swarms/utils/apa.py b/swarms/utils/apa.py index f2e1bb38..fa73b7b4 100644 --- a/swarms/utils/apa.py +++ b/swarms/utils/apa.py @@ -1,7 +1,5 @@ from enum import Enum, unique, auto import abc -import hashlib -import re from typing import List, Optional import json from dataclasses import dataclass, field diff --git a/swarms/utils/class_args_wrapper.py b/swarms/utils/class_args_wrapper.py new file mode 100644 index 00000000..f24932cf --- /dev/null +++ b/swarms/utils/class_args_wrapper.py @@ -0,0 +1,36 @@ +import inspect + + +def print_class_parameters(cls, api_format: bool = False): + """ + Print the parameters of a class constructor. + + Parameters: + cls (type): The class to inspect. + + Example: + >>> print_class_parameters(Agent) + Parameter: x, Type: + Parameter: y, Type: + """ + try: + # Get the parameters of the class constructor + sig = inspect.signature(cls.__init__) + params = sig.parameters + + if api_format: + param_dict = {} + for name, param in params.items(): + if name == "self": + continue + param_dict[name] = str(param.annotation) + return param_dict + + # Print the parameters + for name, param in params.items(): + if name == "self": + continue + print(f"Parameter: {name}, Type: {param.annotation}") + + except Exception as e: + print(f"An error occurred while inspecting the class: {e}") diff --git a/swarms/utils/device_checker_cuda.py b/swarms/utils/device_checker_cuda.py new file mode 100644 index 00000000..dbf2191c --- /dev/null +++ b/swarms/utils/device_checker_cuda.py @@ -0,0 +1,70 @@ +import torch +import logging +from typing import Union, List, Any +from torch.cuda import memory_allocated, memory_reserved + + +def check_device( + log_level: Any = logging.INFO, + memory_threshold: float = 0.8, + capability_threshold: float = 3.5, + return_type: str = "list", +) -> Union[torch.device, List[torch.device]]: + """ + Checks for the availability of CUDA and returns the appropriate device(s). + If CUDA is not available, returns a CPU device. + If CUDA is available, returns a list of all available GPU devices. + """ + logging.basicConfig(level=log_level) + + # Check for CUDA availability + try: + if not torch.cuda.is_available(): + logging.info("CUDA is not available. Using CPU...") + return torch.device("cpu") + except Exception as e: + logging.error("Error checking for CUDA availability: ", e) + return torch.device("cpu") + + logging.info("CUDA is available.") + + # Check for multiple GPUs + num_gpus = torch.cuda.device_count() + devices = [] + if num_gpus > 1: + logging.info(f"Multiple GPUs available: {num_gpus}") + devices = [torch.device(f"cuda:{i}") for i in range(num_gpus)] + else: + logging.info("Only one GPU is available.") + devices = [torch.device("cuda")] + + # Check additional properties for each device + for device in devices: + try: + torch.cuda.set_device(device) + capability = torch.cuda.get_device_capability(device) + total_memory = torch.cuda.get_device_properties( + device + ).total_memory + allocated_memory = memory_allocated(device) + reserved_memory = memory_reserved(device) + device_name = torch.cuda.get_device_name(device) + + logging.info( + f"Device: {device}, Name: {device_name}, Compute" + f" Capability: {capability}, Total Memory:" + f" {total_memory}, Allocated Memory:" + f" {allocated_memory}, Reserved Memory:" + f" {reserved_memory}" + ) + except Exception as e: + logging.error( + f"Error retrieving properties for device {device}: ", + e, + ) + + return devices + + +# devices = check_device() +# logging.info(f"Using device(s): {devices}") diff --git a/swarms/utils/llm_metrics_decorator.py b/swarms/utils/llm_metrics_decorator.py new file mode 100644 index 00000000..a915623a --- /dev/null +++ b/swarms/utils/llm_metrics_decorator.py @@ -0,0 +1,39 @@ +import time +from functools import wraps +from typing import Callable + + +def metrics_decorator(func: Callable): + """Metrics decorator for LLM + + Args: + func (Callable): The function to decorate + + Example: + >>> @metrics_decorator + >>> def my_function(): + >>> return "Hello, world!" + >>> my_function() + + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + # Time to First Token + start_time = time.time() + result = func(self, *args, **kwargs) + first_token_time = time.time() + + # Generation Latency + end_time = time.time() + + # Throughput (assuming the function returns a list of tokens) + throughput = len(result) / (end_time - start_time) + + return f""" + Time to First Token: {first_token_time - start_time} + Generation Latency: {end_time - start_time} + Throughput: {throughput} + """ + + return wrapper diff --git a/swarms/utils/load_model_torch.py b/swarms/utils/load_model_torch.py new file mode 100644 index 00000000..53649e93 --- /dev/null +++ b/swarms/utils/load_model_torch.py @@ -0,0 +1,57 @@ +import torch +from torch import nn + + +def load_model_torch( + model_path: str = None, + device: torch.device = None, + model: nn.Module = None, + strict: bool = True, + map_location=None, + *args, + **kwargs, +) -> nn.Module: + """ + Load a PyTorch model from a given path and move it to the specified device. + + Args: + model_path (str): Path to the saved model file. + device (torch.device): Device to move the model to. + model (nn.Module): The model architecture, if the model file only contains the state dictionary. + strict (bool): Whether to strictly enforce that the keys in the state dictionary match the keys returned by the model's `state_dict()` function. + map_location (callable): A function to remap the storage locations of the loaded model. + *args: Additional arguments to pass to `torch.load`. + **kwargs: Additional keyword arguments to pass to `torch.load`. + + Returns: + nn.Module: The loaded model. + + Raises: + FileNotFoundError: If the model file is not found. + RuntimeError: If there is an error while loading the model. + """ + if device is None: + device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + try: + if model is None: + model = torch.load( + model_path, map_location=map_location, *args, **kwargs + ) + else: + model.load_state_dict( + torch.load( + model_path, + map_location=map_location, + *args, + **kwargs, + ), + strict=strict, + ) + return model.to(device) + except FileNotFoundError: + raise FileNotFoundError(f"Model file not found: {model_path}") + except RuntimeError as e: + raise RuntimeError(f"Error loading model: {str(e)}") diff --git a/swarms/utils/loggers.py b/swarms/utils/loggers.py index a0dec94d..68477132 100644 --- a/swarms/utils/loggers.py +++ b/swarms/utils/loggers.py @@ -487,21 +487,21 @@ def print_action_base(action: Action): """ if action.content != "": logger.typewriter_log( - f"content:", Fore.YELLOW, f"{action.content}" + "content:", Fore.YELLOW, f"{action.content}" ) logger.typewriter_log( - f"Thought:", Fore.YELLOW, f"{action.thought}" + "Thought:", Fore.YELLOW, f"{action.thought}" ) if len(action.plan) > 0: logger.typewriter_log( - f"Plan:", + "Plan:", Fore.YELLOW, ) for line in action.plan: line = line.lstrip("- ") logger.typewriter_log("- ", Fore.GREEN, line.strip()) logger.typewriter_log( - f"Criticism:", Fore.YELLOW, f"{action.criticism}" + "Criticism:", Fore.YELLOW, f"{action.criticism}" ) @@ -515,15 +515,15 @@ def print_action_tool(action: Action): Returns: None """ - logger.typewriter_log(f"Tool:", Fore.BLUE, f"{action.tool_name}") + logger.typewriter_log("Tool:", Fore.BLUE, f"{action.tool_name}") logger.typewriter_log( - f"Tool Input:", Fore.BLUE, f"{action.tool_input}" + "Tool Input:", Fore.BLUE, f"{action.tool_input}" ) output = ( action.tool_output if action.tool_output != "" else "None" ) - logger.typewriter_log(f"Tool Output:", Fore.BLUE, f"{output}") + logger.typewriter_log("Tool Output:", Fore.BLUE, f"{output}") color = Fore.RED if action.tool_output_status == ToolCallStatus.ToolCallSuccess: @@ -534,7 +534,7 @@ def print_action_tool(action: Action): color = Fore.YELLOW logger.typewriter_log( - f"Tool Call Status:", + "Tool Call Status:", Fore.BLUE, f"{color}{action.tool_output_status.name}{Style.RESET_ALL}", ) diff --git a/swarms/utils/main.py b/swarms/utils/main.py index c9c0f380..b94fae11 100644 --- a/swarms/utils/main.py +++ b/swarms/utils/main.py @@ -9,7 +9,6 @@ from typing import Dict import boto3 import numpy as np -import pandas as pd import requests diff --git a/swarms/utils/pdf_to_text.py b/swarms/utils/pdf_to_text.py index 35309dd3..6d589ad5 100644 --- a/swarms/utils/pdf_to_text.py +++ b/swarms/utils/pdf_to_text.py @@ -1,5 +1,4 @@ import sys -import os try: import PyPDF2 diff --git a/swarms/utils/prep_torch_model_inference.py b/swarms/utils/prep_torch_model_inference.py new file mode 100644 index 00000000..41bc07cc --- /dev/null +++ b/swarms/utils/prep_torch_model_inference.py @@ -0,0 +1,30 @@ +import torch +from swarms.utils.load_model_torch import load_model_torch + + +def prep_torch_inference( + model_path: str = None, + device: torch.device = None, + *args, + **kwargs, +): + """ + Prepare a Torch model for inference. + + Args: + model_path (str): Path to the model file. + device (torch.device): Device to run the model on. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + torch.nn.Module: The prepared model. + """ + try: + model = load_model_torch(model_path, device) + model.eval() + return model + except Exception as e: + # Add error handling code here + print(f"Error occurred while preparing Torch model: {e}") + return None diff --git a/tests/memory/test_pinecone.py b/tests/memory/test_pinecone.py index f43cd6ea..f385f058 100644 --- a/tests/memory/test_pinecone.py +++ b/tests/memory/test_pinecone.py @@ -1,6 +1,6 @@ import os from unittest.mock import patch -from swarms.memory.pinecone import PineconDB +from swarms.memory.pinecone import PineconeDB api_key = os.getenv("PINECONE_API_KEY") or "" @@ -9,7 +9,7 @@ def test_init(): with patch("pinecone.init") as MockInit, patch( "pinecone.Index" ) as MockIndex: - store = PineconDB( + store = PineconeDB( api_key=api_key, index_name="test_index", environment="test_env", @@ -21,7 +21,7 @@ def test_init(): def test_upsert_vector(): with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: - store = PineconDB( + store = PineconeDB( api_key=api_key, index_name="test_index", environment="test_env", @@ -37,7 +37,7 @@ def test_upsert_vector(): def test_load_entry(): with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: - store = PineconDB( + store = PineconeDB( api_key=api_key, index_name="test_index", environment="test_env", @@ -48,7 +48,7 @@ def test_load_entry(): def test_load_entries(): with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: - store = PineconDB( + store = PineconeDB( api_key=api_key, index_name="test_index", environment="test_env", @@ -59,7 +59,7 @@ def test_load_entries(): def test_query(): with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: - store = PineconDB( + store = PineconeDB( api_key=api_key, index_name="test_index", environment="test_env", @@ -72,7 +72,7 @@ def test_create_index(): with patch("pinecone.init"), patch("pinecone.Index"), patch( "pinecone.create_index" ) as MockCreateIndex: - store = PineconDB( + store = PineconeDB( api_key=api_key, index_name="test_index", environment="test_env", diff --git a/tests/memory/test_pg.py b/tests/memory/test_pq_db.py similarity index 51% rename from tests/memory/test_pg.py rename to tests/memory/test_pq_db.py index 2bddfb27..5e44f0ba 100644 --- a/tests/memory/test_pg.py +++ b/tests/memory/test_pq_db.py @@ -1,55 +1,45 @@ -import pytest +import os from unittest.mock import patch -from swarms.memory.pg import PgVectorVectorStore + from dotenv import load_dotenv -import os -load_dotenv() +from swarms.memory.pg import PostgresDB +load_dotenv() PSG_CONNECTION_STRING = os.getenv("PSG_CONNECTION_STRING") def test_init(): with patch("sqlalchemy.create_engine") as MockEngine: - store = PgVectorVectorStore( + db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", ) MockEngine.assert_called_once() - assert store.engine == MockEngine.return_value + assert db.engine == MockEngine.return_value -def test_init_exception(): - with pytest.raises(ValueError): - PgVectorVectorStore( - connection_string=( - "mysql://root:password@localhost:3306/test" - ), - table_name="test", - ) - - -def test_setup(): - with patch("sqlalchemy.create_engine") as MockEngine: - store = PgVectorVectorStore( +def test_create_vector_model(): + with patch("sqlalchemy.create_engine"): + db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", ) - store.setup() - MockEngine.execute.assert_called() + model = db._create_vector_model() + assert model.__tablename__ == "test" -def test_upsert_vector(): +def test_add_or_update_vector(): with patch("sqlalchemy.create_engine"), patch( "sqlalchemy.orm.Session" ) as MockSession: - store = PgVectorVectorStore( + db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", ) - store.upsert_vector( - [1.0, 2.0, 3.0], + db.add_or_update_vector( + "test_vector", "test_id", "test_namespace", {"meta": "data"}, @@ -59,45 +49,32 @@ def test_upsert_vector(): MockSession.return_value.commit.assert_called() -def test_load_entry(): +def test_query_vectors(): with patch("sqlalchemy.create_engine"), patch( "sqlalchemy.orm.Session" ) as MockSession: - store = PgVectorVectorStore( + db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", ) - store.load_entry("test_id", "test_namespace") - MockSession.assert_called() - MockSession.return_value.get.assert_called() - - -def test_load_entries(): - with patch("sqlalchemy.create_engine"), patch( - "sqlalchemy.orm.Session" - ) as MockSession: - store = PgVectorVectorStore( - connection_string=PSG_CONNECTION_STRING, - table_name="test", - ) - store.load_entries("test_namespace") + db.query_vectors("test_query", "test_namespace") MockSession.assert_called() MockSession.return_value.query.assert_called() MockSession.return_value.query.return_value.filter_by.assert_called() + MockSession.return_value.query.return_value.filter.assert_called() MockSession.return_value.query.return_value.all.assert_called() -def test_query(): +def test_delete_vector(): with patch("sqlalchemy.create_engine"), patch( "sqlalchemy.orm.Session" ) as MockSession: - store = PgVectorVectorStore( + db = PostgresDB( connection_string=PSG_CONNECTION_STRING, table_name="test", ) - store.query("test_query", 10, "test_namespace") + db.delete_vector("test_id") MockSession.assert_called() - MockSession.return_value.query.assert_called() - MockSession.return_value.query.return_value.filter_by.assert_called() - MockSession.return_value.query.return_value.limit.assert_called() - MockSession.return_value.query.return_value.all.assert_called() + MockSession.return_value.get.assert_called() + MockSession.return_value.delete.assert_called() + MockSession.return_value.commit.assert_called() diff --git a/tests/memory/test_short_term_memory.py b/tests/memory/test_short_term_memory.py index 903c3a0e..0b66b749 100644 --- a/tests/memory/test_short_term_memory.py +++ b/tests/memory/test_short_term_memory.py @@ -1,4 +1,3 @@ -import pytest from swarms.memory.short_term_memory import ShortTermMemory import threading diff --git a/tests/memory/test_sqlite.py b/tests/memory/test_sqlite.py new file mode 100644 index 00000000..6b4213b0 --- /dev/null +++ b/tests/memory/test_sqlite.py @@ -0,0 +1,104 @@ +import pytest +import sqlite3 +from swarms.memory.sqlite import SQLiteDB + + +@pytest.fixture +def db(): + conn = sqlite3.connect(":memory:") + conn.execute( + "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)" + ) + conn.commit() + return SQLiteDB(":memory:") + + +def test_add(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.query("SELECT * FROM test") + assert result == [(1, "test")] + + +def test_delete(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + db.delete("DELETE FROM test WHERE name = ?", ("test",)) + result = db.query("SELECT * FROM test") + assert result == [] + + +def test_update(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + db.update( + "UPDATE test SET name = ? WHERE name = ?", ("new", "test") + ) + result = db.query("SELECT * FROM test") + assert result == [(1, "new")] + + +def test_query(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.query("SELECT * FROM test WHERE name = ?", ("test",)) + assert result == [(1, "test")] + + +def test_execute_query(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.execute_query( + "SELECT * FROM test WHERE name = ?", ("test",) + ) + assert result == [(1, "test")] + + +def test_add_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.add("INSERT INTO test (name) VALUES (?)") + + +def test_delete_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.delete("DELETE FROM test WHERE name = ?") + + +def test_update_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.update("UPDATE test SET name = ? WHERE name = ?") + + +def test_query_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.query("SELECT * FROM test WHERE name = ?") + + +def test_execute_query_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.execute_query("SELECT * FROM test WHERE name = ?") + + +def test_add_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.add("INSERT INTO wrong (name) VALUES (?)", ("test",)) + + +def test_delete_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.delete("DELETE FROM wrong WHERE name = ?", ("test",)) + + +def test_update_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.update( + "UPDATE wrong SET name = ? WHERE name = ?", + ("new", "test"), + ) + + +def test_query_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.query("SELECT * FROM wrong WHERE name = ?", ("test",)) + + +def test_execute_query_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.execute_query( + "SELECT * FROM wrong WHERE name = ?", ("test",) + ) diff --git a/tests/memory/test_weaviate.py b/tests/memory/test_weaviate.py index 09dc6d45..f9e61c8f 100644 --- a/tests/memory/test_weaviate.py +++ b/tests/memory/test_weaviate.py @@ -1,12 +1,12 @@ import pytest from unittest.mock import Mock, patch -from swarms.memory.weaviate import WeaviateClient +from swarms.memory import WeaviateDB -# Define fixture for a WeaviateClient instance with mocked methods +# Define fixture for a WeaviateDB instance with mocked methods @pytest.fixture def weaviate_client_mock(): - client = WeaviateClient( + client = WeaviateDB( http_host="mock_host", http_port="mock_port", http_secure=False, @@ -31,7 +31,7 @@ def weaviate_client_mock(): return client -# Define tests for the WeaviateClient class +# Define tests for the WeaviateDB class def test_create_collection(weaviate_client_mock): # Test creating a collection weaviate_client_mock.create_collection( diff --git a/tests/models/test_LLM.py b/tests/models/test_LLM.py deleted file mode 100644 index 44bc29c9..00000000 --- a/tests/models/test_LLM.py +++ /dev/null @@ -1,56 +0,0 @@ -import unittest -import os -from unittest.mock import patch -from langchain import HuggingFaceHub -from langchain.chat_models import ChatOpenAI - -from swarms.models.llm import LLM - - -class TestLLM(unittest.TestCase): - @patch.object(HuggingFaceHub, "__init__", return_value=None) - @patch.object(ChatOpenAI, "__init__", return_value=None) - def setUp(self, mock_hf_init, mock_openai_init): - self.llm_openai = LLM(openai_api_key="mock_openai_key") - self.llm_hf = LLM( - hf_repo_id="mock_repo_id", hf_api_token="mock_hf_token" - ) - self.prompt = "Who won the FIFA World Cup in 1998?" - - def test_init(self): - self.assertEqual( - self.llm_openai.openai_api_key, "mock_openai_key" - ) - self.assertEqual(self.llm_hf.hf_repo_id, "mock_repo_id") - self.assertEqual(self.llm_hf.hf_api_token, "mock_hf_token") - - @patch.object(HuggingFaceHub, "run", return_value="France") - @patch.object(ChatOpenAI, "run", return_value="France") - def test_run(self, mock_hf_run, mock_openai_run): - result_openai = self.llm_openai.run(self.prompt) - mock_openai_run.assert_called_once() - self.assertEqual(result_openai, "France") - - result_hf = self.llm_hf.run(self.prompt) - mock_hf_run.assert_called_once() - self.assertEqual(result_hf, "France") - - def test_error_on_no_keys(self): - with self.assertRaises(ValueError): - LLM() - - @patch.object(os, "environ", {}) - def test_error_on_missing_hf_token(self): - with self.assertRaises(ValueError): - LLM(hf_repo_id="mock_repo_id") - - @patch.dict( - os.environ, {"HUGGINGFACEHUB_API_TOKEN": "mock_hf_token"} - ) - def test_hf_token_from_env(self): - llm = LLM(hf_repo_id="mock_repo_id") - self.assertEqual(llm.hf_api_token, "mock_hf_token") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/models/test_ada.py b/tests/models/test_ada.py deleted file mode 100644 index 43895e79..00000000 --- a/tests/models/test_ada.py +++ /dev/null @@ -1,91 +0,0 @@ -# test_embeddings.py - -import pytest -import openai -from unittest.mock import patch -from swarms.models.simple_ada import ( - get_ada_embeddings, -) # Adjust this import path to your project structure -from os import getenv -from dotenv import load_dotenv - -load_dotenv() - - -# Fixture for test texts -@pytest.fixture -def test_texts(): - return [ - "Hello World", - "This is a test string with newline\ncharacters", - "A quick brown fox jumps over the lazy dog", - ] - - -# Basic Test -def test_get_ada_embeddings_basic(test_texts): - with patch("openai.resources.Embeddings.create") as mock_create: - # Mocking the OpenAI API call - mock_create.return_value = { - "data": [{"embedding": [0.1, 0.2, 0.3]}] - } - - for text in test_texts: - embedding = get_ada_embeddings(text) - assert embedding == [ - 0.1, - 0.2, - 0.3, - ], "Embedding does not match expected output" - mock_create.assert_called_with( - input=[text.replace("\n", " ")], - model="text-embedding-ada-002", - ) - - -# Parameterized Test -@pytest.mark.parametrize( - "text, model, expected_call_model", - [ - ( - "Hello World", - "text-embedding-ada-002", - "text-embedding-ada-002", - ), - ( - "Hello World", - "text-embedding-ada-001", - "text-embedding-ada-001", - ), - ], -) -def test_get_ada_embeddings_models(text, model, expected_call_model): - with patch("openai.resources.Embeddings.create") as mock_create: - mock_create.return_value = { - "data": [{"embedding": [0.1, 0.2, 0.3]}] - } - - _ = get_ada_embeddings(text, model=model) - mock_create.assert_called_with( - input=[text], model=expected_call_model - ) - - -# Exception Test -def test_get_ada_embeddings_exception(): - with patch("openai.resources.Embeddings.create") as mock_create: - mock_create.side_effect = openai.OpenAIError("Test error") - with pytest.raises(openai.OpenAIError): - get_ada_embeddings("Some text") - - -# Tests for environment variable loading -def test_env_var_loading(monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "testkey123") - with patch("openai.resources.Embeddings.create"): - assert ( - getenv("OPENAI_API_KEY") == "testkey123" - ), "Environment variable for API key is not set correctly" - - -# ... more tests to cover other aspects such as different input types, large inputs, invalid inputs, etc. diff --git a/tests/models/test_auto_temp.py b/tests/models/test_auto_temp.py deleted file mode 100644 index 7937d0dc..00000000 --- a/tests/models/test_auto_temp.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -from concurrent.futures import ThreadPoolExecutor -from unittest.mock import patch - -import pytest -from dotenv import load_dotenv - -from swarms.models.autotemp import AutoTempAgent - -api_key = os.getenv("OPENAI_API_KEY") - -load_dotenv() - - -@pytest.fixture -def auto_temp_agent(): - return AutoTempAgent(api_key=api_key) - - -def test_initialization(auto_temp_agent): - assert isinstance(auto_temp_agent, AutoTempAgent) - assert auto_temp_agent.auto_select is True - assert auto_temp_agent.max_workers == 6 - assert auto_temp_agent.temperature == 0.5 - assert auto_temp_agent.alt_temps == [0.4, 0.6, 0.8, 1.0, 1.2, 1.4] - - -def test_evaluate_output(auto_temp_agent): - output = "This is a test output." - with patch("swarms.models.OpenAIChat") as MockOpenAIChat: - mock_instance = MockOpenAIChat.return_value - mock_instance.return_value = "Score: 95.5" - score = auto_temp_agent.evaluate_output(output) - assert score == 95.5 - mock_instance.assert_called_once() - - -def test_run_auto_select(auto_temp_agent): - task = "Generate a blog post." - temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4" - result = auto_temp_agent.run(task, temperature_string) - assert "Best AutoTemp Output" in result - assert "Temp" in result - assert "Score" in result - - -def test_run_no_scores(auto_temp_agent): - task = "Invalid task." - temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4" - with ThreadPoolExecutor( - max_workers=auto_temp_agent.max_workers - ) as executor: - with patch.object( - executor, - "submit", - side_effect=[None, None, None, None, None, None], - ): - result = auto_temp_agent.run(task, temperature_string) - assert result == "No valid outputs generated." - - -def test_run_manual_select(auto_temp_agent): - auto_temp_agent.auto_select = False - task = "Generate a blog post." - temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4" - result = auto_temp_agent.run(task, temperature_string) - assert "Best AutoTemp Output" not in result - assert "Temp" in result - assert "Score" in result - - -def test_failed_initialization(): - with pytest.raises(Exception): - AutoTempAgent() - - -def test_failed_evaluate_output(auto_temp_agent): - output = "This is a test output." - with patch("swarms.models.OpenAIChat") as MockOpenAIChat: - mock_instance = MockOpenAIChat.return_value - mock_instance.return_value = "Invalid score text" - score = auto_temp_agent.evaluate_output(output) - assert score == 0.0 diff --git a/tests/models/test_bioclip.py b/tests/models/test_bioclip.py deleted file mode 100644 index 99e1e343..00000000 --- a/tests/models/test_bioclip.py +++ /dev/null @@ -1,171 +0,0 @@ -# Import necessary modules and define fixtures if needed -import os -import pytest -import torch -from PIL import Image -from swarms.models.bioclip import BioClip - - -# Define fixtures if needed -@pytest.fixture -def sample_image_path(): - return "path_to_sample_image.jpg" - - -@pytest.fixture -def clip_instance(): - return BioClip( - "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" - ) - - -# Basic tests for the BioClip class -def test_clip_initialization(clip_instance): - assert isinstance(clip_instance.model, torch.nn.Module) - assert hasattr(clip_instance, "model_path") - assert hasattr(clip_instance, "preprocess_train") - assert hasattr(clip_instance, "preprocess_val") - assert hasattr(clip_instance, "tokenizer") - assert hasattr(clip_instance, "device") - - -def test_clip_call_method(clip_instance, sample_image_path): - labels = [ - "adenocarcinoma histopathology", - "brain MRI", - "covid line chart", - "squamous cell carcinoma histopathology", - "immunohistochemistry histopathology", - "bone X-ray", - "chest X-ray", - "pie chart", - "hematoxylin and eosin histopathology", - ] - result = clip_instance(sample_image_path, labels) - assert isinstance(result, dict) - assert len(result) == len(labels) - - -def test_clip_plot_image_with_metadata( - clip_instance, sample_image_path -): - metadata = { - "filename": "sample_image.jpg", - "top_probs": {"label1": 0.75, "label2": 0.65}, - } - clip_instance.plot_image_with_metadata( - sample_image_path, metadata - ) - - -# More test cases can be added to cover additional functionality and edge cases - - -# Parameterized tests for different image and label combinations -@pytest.mark.parametrize( - "image_path, labels", - [ - ("image1.jpg", ["label1", "label2"]), - ("image2.jpg", ["label3", "label4"]), - # Add more image and label combinations - ], -) -def test_clip_parameterized_calls(clip_instance, image_path, labels): - result = clip_instance(image_path, labels) - assert isinstance(result, dict) - assert len(result) == len(labels) - - -# Test image preprocessing -def test_clip_image_preprocessing(clip_instance, sample_image_path): - image = Image.open(sample_image_path) - processed_image = clip_instance.preprocess_val(image) - assert isinstance(processed_image, torch.Tensor) - - -# Test label tokenization -def test_clip_label_tokenization(clip_instance): - labels = ["label1", "label2"] - tokenized_labels = clip_instance.tokenizer(labels) - assert isinstance(tokenized_labels, torch.Tensor) - assert tokenized_labels.shape[0] == len(labels) - - -# More tests can be added to cover other methods and edge cases - - -# End-to-end tests with actual images and labels -def test_clip_end_to_end(clip_instance, sample_image_path): - labels = [ - "adenocarcinoma histopathology", - "brain MRI", - "covid line chart", - "squamous cell carcinoma histopathology", - "immunohistochemistry histopathology", - "bone X-ray", - "chest X-ray", - "pie chart", - "hematoxylin and eosin histopathology", - ] - result = clip_instance(sample_image_path, labels) - assert isinstance(result, dict) - assert len(result) == len(labels) - - -# Test label tokenization with long labels -def test_clip_long_labels(clip_instance): - labels = ["label" + str(i) for i in range(100)] - tokenized_labels = clip_instance.tokenizer(labels) - assert isinstance(tokenized_labels, torch.Tensor) - assert tokenized_labels.shape[0] == len(labels) - - -# Test handling of multiple image files -def test_clip_multiple_images(clip_instance, sample_image_path): - labels = ["label1", "label2"] - image_paths = [sample_image_path, "image2.jpg"] - results = clip_instance(image_paths, labels) - assert isinstance(results, list) - assert len(results) == len(image_paths) - for result in results: - assert isinstance(result, dict) - assert len(result) == len(labels) - - -# Test model inference performance -def test_clip_inference_performance( - clip_instance, sample_image_path, benchmark -): - labels = [ - "adenocarcinoma histopathology", - "brain MRI", - "covid line chart", - "squamous cell carcinoma histopathology", - "immunohistochemistry histopathology", - "bone X-ray", - "chest X-ray", - "pie chart", - "hematoxylin and eosin histopathology", - ] - result = benchmark(clip_instance, sample_image_path, labels) - assert isinstance(result, dict) - assert len(result) == len(labels) - - -# Test different preprocessing pipelines -def test_clip_preprocessing_pipelines( - clip_instance, sample_image_path -): - labels = ["label1", "label2"] - image = Image.open(sample_image_path) - - # Test preprocessing for training - processed_image_train = clip_instance.preprocess_train(image) - assert isinstance(processed_image_train, torch.Tensor) - - # Test preprocessing for validation - processed_image_val = clip_instance.preprocess_val(image) - assert isinstance(processed_image_val, torch.Tensor) - - -# ... diff --git a/tests/models/test_dalle3.py b/tests/models/test_dalle3.py deleted file mode 100644 index 00ba7bc9..00000000 --- a/tests/models/test_dalle3.py +++ /dev/null @@ -1,454 +0,0 @@ -import os -from unittest.mock import Mock - -import pytest -from openai import OpenAIError -from PIL import Image -from termcolor import colored - -from swarms.models.dalle3 import Dalle3 - - -# Mocking the OpenAI client to avoid making actual API calls during testing -@pytest.fixture -def mock_openai_client(): - return Mock() - - -@pytest.fixture -def dalle3(mock_openai_client): - return Dalle3(client=mock_openai_client) - - -def test_dalle3_call_success(dalle3, mock_openai_client): - # Arrange - task = "A painting of a dog" - expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - mock_openai_client.images.generate.return_value = Mock( - data=[Mock(url=expected_img_url)] - ) - - # Act - img_url = dalle3(task) - - # Assert - assert img_url == expected_img_url - mock_openai_client.images.generate.assert_called_once_with( - prompt=task, n=4 - ) - - -def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): - # Arrange - task = "Invalid task" - expected_error_message = "Error running Dalle3: API Error" - - # Mocking OpenAIError - mock_openai_client.images.generate.side_effect = OpenAIError( - expected_error_message, - http_status=500, - error="Internal Server Error", - ) - - # Act and assert - with pytest.raises(OpenAIError) as excinfo: - dalle3(task) - - assert str(excinfo.value) == expected_error_message - mock_openai_client.images.generate.assert_called_once_with( - prompt=task, n=4 - ) - - # Ensure the error message is printed in red - captured = capsys.readouterr() - assert colored(expected_error_message, "red") in captured.out - - -def test_dalle3_create_variations_success(dalle3, mock_openai_client): - # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" - mock_openai_client.images.create_variation.return_value = Mock( - data=[Mock(url=expected_variation_url)] - ) - - # Act - variation_img_url = dalle3.create_variations(img_url) - - # Assert - assert variation_img_url == expected_variation_url - mock_openai_client.images.create_variation.assert_called_once() - _, kwargs = mock_openai_client.images.create_variation.call_args - assert kwargs["img"] is not None - assert kwargs["n"] == 4 - assert kwargs["size"] == "1024x1024" - - -def test_dalle3_create_variations_failure( - dalle3, mock_openai_client, capsys -): - # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - expected_error_message = "Error running Dalle3: API Error" - - # Mocking OpenAIError - mock_openai_client.images.create_variation.side_effect = ( - OpenAIError( - expected_error_message, - http_status=500, - error="Internal Server Error", - ) - ) - - # Act and assert - with pytest.raises(OpenAIError) as excinfo: - dalle3.create_variations(img_url) - - assert str(excinfo.value) == expected_error_message - mock_openai_client.images.create_variation.assert_called_once() - - # Ensure the error message is printed in red - captured = capsys.readouterr() - assert colored(expected_error_message, "red") in captured.out - - -def test_dalle3_read_img(): - # Arrange - img_path = "test_image.png" - img = Image.new("RGB", (512, 512)) - - # Save the image temporarily - img.save(img_path) - - # Act - dalle3 = Dalle3() - img_loaded = dalle3.read_img(img_path) - - # Assert - assert isinstance(img_loaded, Image.Image) - - # Clean up - os.remove(img_path) - - -def test_dalle3_set_width_height(): - # Arrange - img = Image.new("RGB", (512, 512)) - width = 256 - height = 256 - - # Act - dalle3 = Dalle3() - img_resized = dalle3.set_width_height(img, width, height) - - # Assert - assert img_resized.size == (width, height) - - -def test_dalle3_convert_to_bytesio(): - # Arrange - img = Image.new("RGB", (512, 512)) - expected_format = "PNG" - - # Act - dalle3 = Dalle3() - img_bytes = dalle3.convert_to_bytesio(img, format=expected_format) - - # Assert - assert isinstance(img_bytes, bytes) - assert img_bytes.startswith(b"\x89PNG") - - -def test_dalle3_call_multiple_times(dalle3, mock_openai_client): - # Arrange - task = "A painting of a dog" - expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - mock_openai_client.images.generate.return_value = Mock( - data=[Mock(url=expected_img_url)] - ) - - # Act - img_url1 = dalle3(task) - img_url2 = dalle3(task) - - # Assert - assert img_url1 == expected_img_url - assert img_url2 == expected_img_url - assert mock_openai_client.images.generate.call_count == 2 - - -def test_dalle3_call_with_large_input(dalle3, mock_openai_client): - # Arrange - task = "A" * 2048 # Input longer than API's limit - expected_error_message = "Error running Dalle3: API Error" - mock_openai_client.images.generate.side_effect = OpenAIError( - expected_error_message, - http_status=500, - error="Internal Server Error", - ) - - # Act and assert - with pytest.raises(OpenAIError) as excinfo: - dalle3(task) - - assert str(excinfo.value) == expected_error_message - - -def test_dalle3_create_variations_with_invalid_image_url( - dalle3, mock_openai_client -): - # Arrange - img_url = "https://invalid-image-url.com" - expected_error_message = "Error running Dalle3: Invalid image URL" - - # Act and assert - with pytest.raises(ValueError) as excinfo: - dalle3.create_variations(img_url) - - assert str(excinfo.value) == expected_error_message - - -def test_dalle3_set_width_height_invalid_dimensions(dalle3): - # Arrange - img = dalle3.read_img("test_image.png") - width = 0 - height = -1 - - # Act and assert - with pytest.raises(ValueError): - dalle3.set_width_height(img, width, height) - - -def test_dalle3_convert_to_bytesio_invalid_format(dalle3): - # Arrange - img = dalle3.read_img("test_image.png") - invalid_format = "invalid_format" - - # Act and assert - with pytest.raises(ValueError): - dalle3.convert_to_bytesio(img, format=invalid_format) - - -def test_dalle3_call_with_retry(dalle3, mock_openai_client): - # Arrange - task = "A painting of a dog" - expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - - # Simulate a retry scenario - mock_openai_client.images.generate.side_effect = [ - OpenAIError( - "Temporary error", - http_status=500, - error="Internal Server Error", - ), - Mock(data=[Mock(url=expected_img_url)]), - ] - - # Act - img_url = dalle3(task) - - # Assert - assert img_url == expected_img_url - assert mock_openai_client.images.generate.call_count == 2 - - -def test_dalle3_create_variations_with_retry( - dalle3, mock_openai_client -): - # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" - - # Simulate a retry scenario - mock_openai_client.images.create_variation.side_effect = [ - OpenAIError( - "Temporary error", - http_status=500, - error="Internal Server Error", - ), - Mock(data=[Mock(url=expected_variation_url)]), - ] - - # Act - variation_img_url = dalle3.create_variations(img_url) - - # Assert - assert variation_img_url == expected_variation_url - assert mock_openai_client.images.create_variation.call_count == 2 - - -def test_dalle3_call_exception_logging( - dalle3, mock_openai_client, capsys -): - # Arrange - task = "A painting of a dog" - expected_error_message = "Error running Dalle3: API Error" - - # Mocking OpenAIError - mock_openai_client.images.generate.side_effect = OpenAIError( - expected_error_message, - http_status=500, - error="Internal Server Error", - ) - - # Act - with pytest.raises(OpenAIError): - dalle3(task) - - # Assert that the error message is logged - captured = capsys.readouterr() - assert expected_error_message in captured.err - - -def test_dalle3_create_variations_exception_logging( - dalle3, mock_openai_client, capsys -): - # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - expected_error_message = "Error running Dalle3: API Error" - - # Mocking OpenAIError - mock_openai_client.images.create_variation.side_effect = ( - OpenAIError( - expected_error_message, - http_status=500, - error="Internal Server Error", - ) - ) - - # Act - with pytest.raises(OpenAIError): - dalle3.create_variations(img_url) - - # Assert that the error message is logged - captured = capsys.readouterr() - assert expected_error_message in captured.err - - -def test_dalle3_read_img_invalid_path(dalle3): - # Arrange - invalid_img_path = "invalid_image_path.png" - - # Act and assert - with pytest.raises(FileNotFoundError): - dalle3.read_img(invalid_img_path) - - -def test_dalle3_call_no_api_key(): - # Arrange - task = "A painting of a dog" - dalle3 = Dalle3(api_key=None) - expected_error_message = ( - "Error running Dalle3: API Key is missing" - ) - - # Act and assert - with pytest.raises(ValueError) as excinfo: - dalle3(task) - - assert str(excinfo.value) == expected_error_message - - -def test_dalle3_create_variations_no_api_key(): - # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - dalle3 = Dalle3(api_key=None) - expected_error_message = ( - "Error running Dalle3: API Key is missing" - ) - - # Act and assert - with pytest.raises(ValueError) as excinfo: - dalle3.create_variations(img_url) - - assert str(excinfo.value) == expected_error_message - - -def test_dalle3_call_with_retry_max_retries_exceeded( - dalle3, mock_openai_client -): - # Arrange - task = "A painting of a dog" - - # Simulate max retries exceeded - mock_openai_client.images.generate.side_effect = OpenAIError( - "Temporary error", - http_status=500, - error="Internal Server Error", - ) - - # Act and assert - with pytest.raises(OpenAIError) as excinfo: - dalle3(task) - - assert "Retry limit exceeded" in str(excinfo.value) - - -def test_dalle3_create_variations_with_retry_max_retries_exceeded( - dalle3, mock_openai_client -): - # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - - # Simulate max retries exceeded - mock_openai_client.images.create_variation.side_effect = ( - OpenAIError( - "Temporary error", - http_status=500, - error="Internal Server Error", - ) - ) - - # Act and assert - with pytest.raises(OpenAIError) as excinfo: - dalle3.create_variations(img_url) - - assert "Retry limit exceeded" in str(excinfo.value) - - -def test_dalle3_call_retry_with_success(dalle3, mock_openai_client): - # Arrange - task = "A painting of a dog" - expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - - # Simulate success after a retry - mock_openai_client.images.generate.side_effect = [ - OpenAIError( - "Temporary error", - http_status=500, - error="Internal Server Error", - ), - Mock(data=[Mock(url=expected_img_url)]), - ] - - # Act - img_url = dalle3(task) - - # Assert - assert img_url == expected_img_url - assert mock_openai_client.images.generate.call_count == 2 - - -def test_dalle3_create_variations_retry_with_success( - dalle3, mock_openai_client -): - # Arrange - img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" - - # Simulate success after a retry - mock_openai_client.images.create_variation.side_effect = [ - OpenAIError( - "Temporary error", - http_status=500, - error="Internal Server Error", - ), - Mock(data=[Mock(url=expected_variation_url)]), - ] - - # Act - variation_img_url = dalle3.create_variations(img_url) - - # Assert - assert variation_img_url == expected_variation_url - assert mock_openai_client.images.create_variation.call_count == 2 diff --git a/tests/models/test_distill_whisper.py b/tests/models/test_distill_whisper.py deleted file mode 100644 index 775bb896..00000000 --- a/tests/models/test_distill_whisper.py +++ /dev/null @@ -1,336 +0,0 @@ -import os -import tempfile -from functools import wraps -from unittest.mock import AsyncMock, MagicMock, patch - -import numpy as np -import pytest -import torch -from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor - -from swarms.models.distilled_whisperx import ( - DistilWhisperModel, - async_retry, -) - - -@pytest.fixture -def distil_whisper_model(): - return DistilWhisperModel() - - -def create_audio_file( - data: np.ndarray, sample_rate: int, file_path: str -): - data.tofile(file_path) - return file_path - - -def test_initialization(distil_whisper_model): - assert isinstance(distil_whisper_model, DistilWhisperModel) - assert isinstance(distil_whisper_model.model, torch.nn.Module) - assert isinstance(distil_whisper_model.processor, torch.nn.Module) - assert distil_whisper_model.device in ["cpu", "cuda:0"] - - -def test_transcribe_audio_file(distil_whisper_model): - test_data = np.random.rand( - 16000 - ) # Simulated audio data (1 second) - with tempfile.NamedTemporaryFile( - suffix=".wav", delete=False - ) as audio_file: - audio_file_path = create_audio_file( - test_data, 16000, audio_file.name - ) - transcription = distil_whisper_model.transcribe( - audio_file_path - ) - os.remove(audio_file_path) - - assert isinstance(transcription, str) - assert transcription.strip() != "" - - -@pytest.mark.asyncio -async def test_async_transcribe_audio_file(distil_whisper_model): - test_data = np.random.rand( - 16000 - ) # Simulated audio data (1 second) - with tempfile.NamedTemporaryFile( - suffix=".wav", delete=False - ) as audio_file: - audio_file_path = create_audio_file( - test_data, 16000, audio_file.name - ) - transcription = await distil_whisper_model.async_transcribe( - audio_file_path - ) - os.remove(audio_file_path) - - assert isinstance(transcription, str) - assert transcription.strip() != "" - - -def test_transcribe_audio_data(distil_whisper_model): - test_data = np.random.rand( - 16000 - ) # Simulated audio data (1 second) - transcription = distil_whisper_model.transcribe( - test_data.tobytes() - ) - - assert isinstance(transcription, str) - assert transcription.strip() != "" - - -@pytest.mark.asyncio -async def test_async_transcribe_audio_data(distil_whisper_model): - test_data = np.random.rand( - 16000 - ) # Simulated audio data (1 second) - transcription = await distil_whisper_model.async_transcribe( - test_data.tobytes() - ) - - assert isinstance(transcription, str) - assert transcription.strip() != "" - - -def test_real_time_transcribe(distil_whisper_model, capsys): - test_data = np.random.rand( - 16000 * 5 - ) # Simulated audio data (5 seconds) - with tempfile.NamedTemporaryFile( - suffix=".wav", delete=False - ) as audio_file: - audio_file_path = create_audio_file( - test_data, 16000, audio_file.name - ) - - distil_whisper_model.real_time_transcribe( - audio_file_path, chunk_duration=1 - ) - - os.remove(audio_file_path) - - captured = capsys.readouterr() - assert "Starting real-time transcription..." in captured.out - assert "Chunk" in captured.out - - -def test_real_time_transcribe_audio_file_not_found( - distil_whisper_model, capsys -): - audio_file_path = "non_existent_audio.wav" - distil_whisper_model.real_time_transcribe( - audio_file_path, chunk_duration=1 - ) - - captured = capsys.readouterr() - assert "The audio file was not found." in captured.out - - -@pytest.fixture -def mock_async_retry(): - def _mock_async_retry( - retries=3, exceptions=(Exception,), delay=1 - ): - def decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - return await func(*args, **kwargs) - - return wrapper - - return decorator - - with patch( - "distil_whisper_model.async_retry", new=_mock_async_retry() - ): - yield - - -@pytest.mark.asyncio -async def test_async_retry_decorator_success(): - async def mock_async_function(): - return "Success" - - decorated_function = async_retry()(mock_async_function) - result = await decorated_function() - assert result == "Success" - - -@pytest.mark.asyncio -async def test_async_retry_decorator_failure(): - async def mock_async_function(): - raise Exception("Error") - - decorated_function = async_retry()(mock_async_function) - with pytest.raises(Exception, match="Error"): - await decorated_function() - - -@pytest.mark.asyncio -async def test_async_retry_decorator_multiple_attempts(): - async def mock_async_function(): - if mock_async_function.attempts == 0: - mock_async_function.attempts += 1 - raise Exception("Error") - else: - return "Success" - - mock_async_function.attempts = 0 - decorated_function = async_retry(max_retries=2)( - mock_async_function - ) - result = await decorated_function() - assert result == "Success" - - -def test_create_audio_file(): - test_data = np.random.rand( - 16000 - ) # Simulated audio data (1 second) - sample_rate = 16000 - with tempfile.NamedTemporaryFile( - suffix=".wav", delete=False - ) as audio_file: - audio_file_path = create_audio_file( - test_data, sample_rate, audio_file.name - ) - - assert os.path.exists(audio_file_path) - os.remove(audio_file_path) - - -# test_distilled_whisperx.py - - -# Fixtures for setting up model, processor, and audio files -@pytest.fixture(scope="module") -def model_id(): - return "distil-whisper/distil-large-v2" - - -@pytest.fixture(scope="module") -def whisper_model(model_id): - return DistilWhisperModel(model_id) - - -@pytest.fixture(scope="session") -def audio_file_path(tmp_path_factory): - # You would create a small temporary MP3 file here for testing - # or use a public domain MP3 file's path - return "path/to/valid_audio.mp3" - - -@pytest.fixture(scope="session") -def invalid_audio_file_path(): - return "path/to/invalid_audio.mp3" - - -@pytest.fixture(scope="session") -def audio_dict(): - # This should represent a valid audio dictionary as expected by the model - return {"array": torch.randn(1, 16000), "sampling_rate": 16000} - - -# Test initialization -def test_initialization(whisper_model): - assert whisper_model.model is not None - assert whisper_model.processor is not None - - -# Test successful transcription with file path -def test_transcribe_with_file_path(whisper_model, audio_file_path): - transcription = whisper_model.transcribe(audio_file_path) - assert isinstance(transcription, str) - - -# Test successful transcription with audio dict -def test_transcribe_with_audio_dict(whisper_model, audio_dict): - transcription = whisper_model.transcribe(audio_dict) - assert isinstance(transcription, str) - - -# Test for file not found error -def test_file_not_found(whisper_model, invalid_audio_file_path): - with pytest.raises(Exception): - whisper_model.transcribe(invalid_audio_file_path) - - -# Asynchronous tests -@pytest.mark.asyncio -async def test_async_transcription_success( - whisper_model, audio_file_path -): - transcription = await whisper_model.async_transcribe( - audio_file_path - ) - assert isinstance(transcription, str) - - -@pytest.mark.asyncio -async def test_async_transcription_failure( - whisper_model, invalid_audio_file_path -): - with pytest.raises(Exception): - await whisper_model.async_transcribe(invalid_audio_file_path) - - -# Testing real-time transcription simulation -def test_real_time_transcription( - whisper_model, audio_file_path, capsys -): - whisper_model.real_time_transcribe( - audio_file_path, chunk_duration=1 - ) - captured = capsys.readouterr() - assert "Starting real-time transcription..." in captured.out - - -# Testing retry decorator for asynchronous function -@pytest.mark.asyncio -async def test_async_retry(): - @async_retry(max_retries=2, exceptions=(ValueError,), delay=0) - async def failing_func(): - raise ValueError("Test") - - with pytest.raises(ValueError): - await failing_func() - - -# Mocking the actual model to avoid GPU/CPU intensive operations during test -@pytest.fixture -def mocked_model(monkeypatch): - model_mock = AsyncMock(AutoModelForSpeechSeq2Seq) - processor_mock = MagicMock(AutoProcessor) - monkeypatch.setattr( - "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained", - model_mock, - ) - monkeypatch.setattr( - "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained", - processor_mock, - ) - return model_mock, processor_mock - - -@pytest.mark.asyncio -async def test_async_transcribe_with_mocked_model( - mocked_model, audio_file_path -): - model_mock, processor_mock = mocked_model - # Set up what the mock should return when it's called - model_mock.return_value.generate.return_value = torch.tensor( - [[0]] - ) - processor_mock.return_value.batch_decode.return_value = [ - "mocked transcription" - ] - model_wrapper = DistilWhisperModel() - transcription = await model_wrapper.async_transcribe( - audio_file_path - ) - assert transcription == "mocked transcription" diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index c6f3e023..2a1d4ad4 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -110,7 +110,7 @@ def test_gemini_init_missing_api_key(): with pytest.raises( ValueError, match="Please provide a Gemini API key" ): - model = Gemini(gemini_api_key=None) + Gemini(gemini_api_key=None) # Test Gemini initialization with missing model name @@ -118,7 +118,7 @@ def test_gemini_init_missing_model_name(): with pytest.raises( ValueError, match="Please provide a model name" ): - model = Gemini(model_name=None) + Gemini(model_name=None) # Test Gemini run method with empty task diff --git a/tests/models/test_gpt4_vision_api.py b/tests/models/test_gpt4_vision_api.py index dfd03e27..c7758a36 100644 --- a/tests/models/test_gpt4_vision_api.py +++ b/tests/models/test_gpt4_vision_api.py @@ -48,7 +48,7 @@ def test_run_success(vision_api): def test_run_request_error(vision_api): with patch( "requests.post", side_effect=RequestException("Request Error") - ) as mock_post: + ): with pytest.raises(RequestException): vision_api.run("What is this?", img) @@ -58,7 +58,7 @@ def test_run_response_error(vision_api): with patch( "requests.post", return_value=Mock(json=lambda: expected_response), - ) as mock_post: + ): with pytest.raises(RuntimeError): vision_api.run("What is this?", img) @@ -153,7 +153,7 @@ async def test_arun_request_error(vision_api): "aiohttp.ClientSession.post", new_callable=AsyncMock, side_effect=Exception("Request Error"), - ) as mock_post: + ): with pytest.raises(Exception): await vision_api.arun("What is this?", img) @@ -181,7 +181,7 @@ def test_run_many_success(vision_api): def test_run_many_request_error(vision_api): with patch( "requests.post", side_effect=RequestException("Request Error") - ) as mock_post: + ): tasks = ["What is this?", "What is that?"] imgs = [img, img] with pytest.raises(RequestException): @@ -196,7 +196,7 @@ async def test_arun_json_decode_error(vision_api): return_value=AsyncMock( json=AsyncMock(side_effect=ValueError) ), - ) as mock_post: + ): with pytest.raises(ValueError): await vision_api.arun("What is this?", img) @@ -210,7 +210,7 @@ async def test_arun_api_error(vision_api): return_value=AsyncMock( json=AsyncMock(return_value=error_response) ), - ) as mock_post: + ): with pytest.raises(Exception, match="API Error"): await vision_api.arun("What is this?", img) @@ -224,7 +224,7 @@ async def test_arun_unexpected_response(vision_api): return_value=AsyncMock( json=AsyncMock(return_value=unexpected_response) ), - ) as mock_post: + ): with pytest.raises(Exception, match="Unexpected response"): await vision_api.arun("What is this?", img) @@ -247,6 +247,6 @@ async def test_arun_timeout(vision_api): "aiohttp.ClientSession.post", new_callable=AsyncMock, side_effect=asyncio.TimeoutError, - ) as mock_post: + ): with pytest.raises(asyncio.TimeoutError): await vision_api.arun("What is this?", img) diff --git a/tests/models/test_hf.py b/tests/models/test_hf.py index dce13338..48dcd008 100644 --- a/tests/models/test_hf.py +++ b/tests/models/test_hf.py @@ -1,90 +1,237 @@ -import pytest import torch -from unittest.mock import Mock -from swarms.models.huggingface import HuggingFaceLLM +import logging +from unittest.mock import patch +import pytest -@pytest.fixture -def mock_torch(): - return Mock() +from swarms.models.huggingface import HuggingfaceLLM +# Mock some functions and objects for testing @pytest.fixture -def mock_autotokenizer(): - return Mock() +def mock_huggingface_llm(monkeypatch): + # Mock the model and tokenizer creation + def mock_init( + self, + model_id, + device="cpu", + max_length=500, + quantize=False, + quantization_config=None, + verbose=False, + distributed=False, + decoding=False, + max_workers=5, + repitition_penalty=1.3, + no_repeat_ngram_size=5, + temperature=0.7, + top_k=40, + top_p=0.8, + ): + pass + + # Mock the model loading + def mock_load_model(self): + pass + + # Mock the model generation + def mock_run(self, task): + pass + + monkeypatch.setattr(HuggingfaceLLM, "__init__", mock_init) + monkeypatch.setattr(HuggingfaceLLM, "load_model", mock_load_model) + monkeypatch.setattr(HuggingfaceLLM, "run", mock_run) + + +# Basic tests for initialization and attribute settings +def test_init_huggingface_llm(): + llm = HuggingfaceLLM( + model_id="test_model", + device="cuda", + max_length=1000, + quantize=True, + quantization_config={"config_key": "config_value"}, + verbose=True, + distributed=True, + decoding=True, + max_workers=3, + repitition_penalty=1.5, + no_repeat_ngram_size=4, + temperature=0.8, + top_k=50, + top_p=0.7, + ) + assert llm.model_id == "test_model" + assert llm.device == "cuda" + assert llm.max_length == 1000 + assert llm.quantize is True + assert llm.quantization_config == {"config_key": "config_value"} + assert llm.verbose is True + assert llm.distributed is True + assert llm.decoding is True + assert llm.max_workers == 3 + assert llm.repitition_penalty == 1.5 + assert llm.no_repeat_ngram_size == 4 + assert llm.temperature == 0.8 + assert llm.top_k == 50 + assert llm.top_p == 0.7 -@pytest.fixture -def mock_automodelforcausallm(): - return Mock() +# Test loading the model +def test_load_model(mock_huggingface_llm): + llm = HuggingfaceLLM(model_id="test_model") + llm.load_model() -@pytest.fixture -def mock_bitsandbytesconfig(): - return Mock() + # Ensure that the load_model function is called + assert True -@pytest.fixture -def hugging_face_llm( - mock_torch, - mock_autotokenizer, - mock_automodelforcausallm, - mock_bitsandbytesconfig, -): - HuggingFaceLLM.torch = mock_torch - HuggingFaceLLM.AutoTokenizer = mock_autotokenizer - HuggingFaceLLM.AutoModelForCausalLM = mock_automodelforcausallm - HuggingFaceLLM.BitsAndBytesConfig = mock_bitsandbytesconfig - - return HuggingFaceLLM(model_id="test") - - -def test_init( - hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm -): - assert hugging_face_llm.model_id == "test" - mock_autotokenizer.from_pretrained.assert_called_once_with("test") - mock_automodelforcausallm.from_pretrained.assert_called_once_with( - "test", quantization_config=None +# Test running the model +def test_run(mock_huggingface_llm): + llm = HuggingfaceLLM(model_id="test_model") + llm.run("Test prompt") + + # Ensure that the run function is called + assert True + + +# Test for setting max_length +def test_llm_set_max_length(llm_instance): + new_max_length = 1000 + llm_instance.set_max_length(new_max_length) + assert llm_instance.max_length == new_max_length + + +# Test for setting verbose +def test_llm_set_verbose(llm_instance): + llm_instance.set_verbose(True) + assert llm_instance.verbose is True + + +# Test for setting distributed +def test_llm_set_distributed(llm_instance): + llm_instance.set_distributed(True) + assert llm_instance.distributed is True + + +# Test for setting decoding +def test_llm_set_decoding(llm_instance): + llm_instance.set_decoding(True) + assert llm_instance.decoding is True + + +# Test for setting max_workers +def test_llm_set_max_workers(llm_instance): + new_max_workers = 10 + llm_instance.set_max_workers(new_max_workers) + assert llm_instance.max_workers == new_max_workers + + +# Test for setting repitition_penalty +def test_llm_set_repitition_penalty(llm_instance): + new_repitition_penalty = 1.5 + llm_instance.set_repitition_penalty(new_repitition_penalty) + assert llm_instance.repitition_penalty == new_repitition_penalty + + +# Test for setting no_repeat_ngram_size +def test_llm_set_no_repeat_ngram_size(llm_instance): + new_no_repeat_ngram_size = 6 + llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size) + assert ( + llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size ) -def test_init_with_quantize( - hugging_face_llm, - mock_autotokenizer, - mock_automodelforcausallm, - mock_bitsandbytesconfig, -): - quantization_config = { - "load_in_4bit": True, - "bnb_4bit_use_double_quant": True, +# Test for setting temperature +def test_llm_set_temperature(llm_instance): + new_temperature = 0.8 + llm_instance.set_temperature(new_temperature) + assert llm_instance.temperature == new_temperature + + +# Test for setting top_k +def test_llm_set_top_k(llm_instance): + new_top_k = 50 + llm_instance.set_top_k(new_top_k) + assert llm_instance.top_k == new_top_k + + +# Test for setting top_p +def test_llm_set_top_p(llm_instance): + new_top_p = 0.9 + llm_instance.set_top_p(new_top_p) + assert llm_instance.top_p == new_top_p + + +# Test for setting quantize +def test_llm_set_quantize(llm_instance): + llm_instance.set_quantize(True) + assert llm_instance.quantize is True + + +# Test for setting quantization_config +def test_llm_set_quantization_config(llm_instance): + new_quantization_config = { + "load_in_4bit": False, + "bnb_4bit_use_double_quant": False, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16, } - mock_bitsandbytesconfig.return_value = quantization_config + llm_instance.set_quantization_config(new_quantization_config) + assert llm_instance.quantization_config == new_quantization_config - HuggingFaceLLM(model_id="test", quantize=True) - mock_bitsandbytesconfig.assert_called_once_with( - **quantization_config - ) - mock_autotokenizer.from_pretrained.assert_called_once_with("test") - mock_automodelforcausallm.from_pretrained.assert_called_once_with( - "test", quantization_config=quantization_config - ) +# Test for setting model_id +def test_llm_set_model_id(llm_instance): + new_model_id = "EleutherAI/gpt-neo-2.7B" + llm_instance.set_model_id(new_model_id) + assert llm_instance.model_id == new_model_id + + +# Test for setting model +@patch( + "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained" +) +def test_llm_set_model(mock_model, llm_instance): + mock_model.return_value = "mocked model" + llm_instance.set_model(mock_model) + assert llm_instance.model == "mocked model" + + +# Test for setting tokenizer +@patch("swarms.models.huggingface.AutoTokenizer.from_pretrained") +def test_llm_set_tokenizer(mock_tokenizer, llm_instance): + mock_tokenizer.return_value = "mocked tokenizer" + llm_instance.set_tokenizer(mock_tokenizer) + assert llm_instance.tokenizer == "mocked tokenizer" + + +# Test for setting logger +def test_llm_set_logger(llm_instance): + new_logger = logging.getLogger("test_logger") + llm_instance.set_logger(new_logger) + assert llm_instance.logger == new_logger + + +# Test for saving model +@patch("torch.save") +def test_llm_save_model(mock_save, llm_instance): + llm_instance.save_model("path/to/save") + mock_save.assert_called_once() -def test_generate_text(hugging_face_llm): - prompt_text = "test prompt" - expected_output = "test output" - hugging_face_llm.tokenizer.encode.return_value = torch.tensor( - [0] - ) # Mock tensor - hugging_face_llm.model.generate.return_value = torch.tensor( - [0] - ) # Mock tensor - hugging_face_llm.tokenizer.decode.return_value = expected_output +# Test for print_dashboard +@patch("builtins.print") +def test_llm_print_dashboard(mock_print, llm_instance): + llm_instance.print_dashboard("test task") + mock_print.assert_called() - output = hugging_face_llm.generate_text(prompt_text) - assert output == expected_output +# Test for __call__ method +@patch("swarms.models.huggingface.HuggingfaceLLM.run") +def test_llm_call(mock_run, llm_instance): + mock_run.return_value = "mocked output" + result = llm_instance("test task") + assert result == "mocked output" diff --git a/tests/models/test_kosmos2.py b/tests/models/test_kosmos2.py deleted file mode 100644 index 7e4f0e5f..00000000 --- a/tests/models/test_kosmos2.py +++ /dev/null @@ -1,394 +0,0 @@ -import pytest -import os -from PIL import Image -from swarms.models.kosmos2 import Kosmos2, Detections - - -# Fixture for a sample image -@pytest.fixture -def sample_image(): - image = Image.new("RGB", (224, 224)) - return image - - -# Fixture for initializing Kosmos2 -@pytest.fixture -def kosmos2(): - return Kosmos2.initialize() - - -# Test Kosmos2 initialization -def test_kosmos2_initialization(kosmos2): - assert kosmos2 is not None - - -# Test Kosmos2 with a sample image -def test_kosmos2_with_sample_image(kosmos2, sample_image): - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 0 - ) - - -# Mocked extract_entities function for testing -def mock_extract_entities(text): - return [ - ("entity1", (0.1, 0.2, 0.3, 0.4)), - ("entity2", (0.5, 0.6, 0.7, 0.8)), - ] - - -# Mocked process_entities_to_detections function for testing -def mock_process_entities_to_detections(entities, image): - return Detections( - xyxy=[(10, 20, 30, 40), (50, 60, 70, 80)], - class_id=[0, 0], - confidence=[1.0, 1.0], - ) - - -# Test Kosmos2 with mocked entity extraction and detection -def test_kosmos2_with_mocked_extraction_and_detection( - kosmos2, sample_image, monkeypatch -): - monkeypatch.setattr( - kosmos2, "extract_entities", mock_extract_entities - ) - monkeypatch.setattr( - kosmos2, - "process_entities_to_detections", - mock_process_entities_to_detections, - ) - - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 2 - ) - - -# Test Kosmos2 with empty entity extraction -def test_kosmos2_with_empty_extraction( - kosmos2, sample_image, monkeypatch -): - monkeypatch.setattr(kosmos2, "extract_entities", lambda x: []) - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 0 - ) - - -# Test Kosmos2 with invalid image path -def test_kosmos2_with_invalid_image_path(kosmos2): - with pytest.raises(Exception): - kosmos2(img="invalid_image_path.jpg") - - -# Additional tests can be added for various scenarios and edge cases - - -# Test Kosmos2 with a larger image -def test_kosmos2_with_large_image(kosmos2): - large_image = Image.new("RGB", (1024, 768)) - detections = kosmos2(img=large_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 0 - ) - - -# Test Kosmos2 with different image formats -def test_kosmos2_with_different_image_formats(kosmos2, tmp_path): - # Create a temporary directory - temp_dir = tmp_path / "images" - temp_dir.mkdir() - - # Create sample images in different formats - image_formats = ["jpeg", "png", "gif", "bmp"] - for format in image_formats: - image_path = temp_dir / f"sample_image.{format}" - Image.new("RGB", (224, 224)).save(image_path) - - # Test Kosmos2 with each image format - for format in image_formats: - image_path = temp_dir / f"sample_image.{format}" - detections = kosmos2(img=image_path) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 0 - ) - - -# Test Kosmos2 with a non-existent model -def test_kosmos2_with_non_existent_model(kosmos2): - with pytest.raises(Exception): - kosmos2.model = None - kosmos2(img="sample_image.jpg") - - -# Test Kosmos2 with a non-existent processor -def test_kosmos2_with_non_existent_processor(kosmos2): - with pytest.raises(Exception): - kosmos2.processor = None - kosmos2(img="sample_image.jpg") - - -# Test Kosmos2 with missing image -def test_kosmos2_with_missing_image(kosmos2): - with pytest.raises(Exception): - kosmos2(img="non_existent_image.jpg") - - -# ... (previous tests) - - -# Test Kosmos2 with a non-existent model and processor -def test_kosmos2_with_non_existent_model_and_processor(kosmos2): - with pytest.raises(Exception): - kosmos2.model = None - kosmos2.processor = None - kosmos2(img="sample_image.jpg") - - -# Test Kosmos2 with a corrupted image -def test_kosmos2_with_corrupted_image(kosmos2, tmp_path): - # Create a temporary directory - temp_dir = tmp_path / "images" - temp_dir.mkdir() - - # Create a corrupted image - corrupted_image_path = temp_dir / "corrupted_image.jpg" - with open(corrupted_image_path, "wb") as f: - f.write(b"corrupted data") - - with pytest.raises(Exception): - kosmos2(img=corrupted_image_path) - - -# Test Kosmos2 with a large batch size -def test_kosmos2_with_large_batch_size(kosmos2, sample_image): - kosmos2.batch_size = 32 - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 0 - ) - - -# Test Kosmos2 with an invalid compute type -def test_kosmos2_with_invalid_compute_type(kosmos2, sample_image): - kosmos2.compute_type = "invalid_compute_type" - with pytest.raises(Exception): - kosmos2(img=sample_image) - - -# Test Kosmos2 with a valid HF API key -def test_kosmos2_with_valid_hf_api_key(kosmos2, sample_image): - kosmos2.hf_api_key = "valid_api_key" - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 2 - ) - - -# Test Kosmos2 with an invalid HF API key -def test_kosmos2_with_invalid_hf_api_key(kosmos2, sample_image): - kosmos2.hf_api_key = "invalid_api_key" - with pytest.raises(Exception): - kosmos2(img=sample_image) - - -# Test Kosmos2 with a very long generated text -def test_kosmos2_with_long_generated_text( - kosmos2, sample_image, monkeypatch -): - def mock_generate_text(*args, **kwargs): - return "A" * 10000 - - monkeypatch.setattr(kosmos2.model, "generate", mock_generate_text) - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 0 - ) - - -# Test Kosmos2 with entities containing special characters -def test_kosmos2_with_entities_containing_special_characters( - kosmos2, sample_image, monkeypatch -): - def mock_extract_entities(text): - return [ - ( - "entity1 with special characters (ü, ö, etc.)", - (0.1, 0.2, 0.3, 0.4), - ) - ] - - monkeypatch.setattr( - kosmos2, "extract_entities", mock_extract_entities - ) - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 1 - ) - - -# Test Kosmos2 with image containing multiple objects -def test_kosmos2_with_image_containing_multiple_objects( - kosmos2, sample_image, monkeypatch -): - def mock_extract_entities(text): - return [ - ("entity1", (0.1, 0.2, 0.3, 0.4)), - ("entity2", (0.5, 0.6, 0.7, 0.8)), - ] - - monkeypatch.setattr( - kosmos2, "extract_entities", mock_extract_entities - ) - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 2 - ) - - -# Test Kosmos2 with image containing no objects -def test_kosmos2_with_image_containing_no_objects( - kosmos2, sample_image, monkeypatch -): - def mock_extract_entities(text): - return [] - - monkeypatch.setattr( - kosmos2, "extract_entities", mock_extract_entities - ) - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 0 - ) - - -# Test Kosmos2 with a valid YouTube video URL -def test_kosmos2_with_valid_youtube_video_url(kosmos2): - youtube_video_url = "https://www.youtube.com/watch?v=VIDEO_ID" - detections = kosmos2(video_url=youtube_video_url) - assert isinstance(detections, Detections) - assert ( - len(detections.xyxy) - == len(detections.class_id) - == len(detections.confidence) - == 2 - ) - - -# Test Kosmos2 with an invalid YouTube video URL -def test_kosmos2_with_invalid_youtube_video_url(kosmos2): - invalid_youtube_video_url = ( - "https://www.youtube.com/invalid_video" - ) - with pytest.raises(Exception): - kosmos2(video_url=invalid_youtube_video_url) - - -# Test Kosmos2 with no YouTube video URL provided -def test_kosmos2_with_no_youtube_video_url(kosmos2): - with pytest.raises(Exception): - kosmos2(video_url=None) - - -# Test Kosmos2 installation -def test_kosmos2_installation(): - kosmos2 = Kosmos2() - kosmos2.install() - assert os.path.exists("video.mp4") - assert os.path.exists("video.mp3") - os.remove("video.mp4") - os.remove("video.mp3") - - -# Test Kosmos2 termination -def test_kosmos2_termination(kosmos2): - kosmos2.terminate() - assert kosmos2.process is None - - -# Test Kosmos2 start_process method -def test_kosmos2_start_process(kosmos2): - kosmos2.start_process() - assert kosmos2.process is not None - - -# Test Kosmos2 preprocess_code method -def test_kosmos2_preprocess_code(kosmos2): - code = "print('Hello, World!')" - preprocessed_code = kosmos2.preprocess_code(code) - assert isinstance(preprocessed_code, str) - assert "end_of_execution" in preprocessed_code - - -# Test Kosmos2 run method with debug mode -def test_kosmos2_run_with_debug_mode(kosmos2, sample_image): - kosmos2.debug_mode = True - detections = kosmos2(img=sample_image) - assert isinstance(detections, Detections) - - -# Test Kosmos2 handle_stream_output method -def test_kosmos2_handle_stream_output(kosmos2): - stream_output = "Sample output" - kosmos2.handle_stream_output(stream_output, is_error=False) - - -# Test Kosmos2 run method with invalid image path -def test_kosmos2_run_with_invalid_image_path(kosmos2): - with pytest.raises(Exception): - kosmos2.run(img="invalid_image_path.jpg") - - -# Test Kosmos2 run method with invalid video URL -def test_kosmos2_run_with_invalid_video_url(kosmos2): - with pytest.raises(Exception): - kosmos2.run(video_url="invalid_video_url") - - -# ... (more tests) diff --git a/tests/models/test_mixtral.py b/tests/models/test_mixtral.py new file mode 100644 index 00000000..9eb31af0 --- /dev/null +++ b/tests/models/test_mixtral.py @@ -0,0 +1,53 @@ +import pytest +from unittest.mock import patch, MagicMock +from swarms.models.mixtral import Mixtral + + +@patch("swarms.models.mixtral.AutoTokenizer") +@patch("swarms.models.mixtral.AutoModelForCausalLM") +def test_mixtral_init(mock_model, mock_tokenizer): + mixtral = Mixtral() + mock_tokenizer.from_pretrained.assert_called_once() + mock_model.from_pretrained.assert_called_once() + assert mixtral.model_name == "mistralai/Mixtral-8x7B-v0.1" + assert mixtral.max_new_tokens == 20 + + +@patch("swarms.models.mixtral.AutoTokenizer") +@patch("swarms.models.mixtral.AutoModelForCausalLM") +def test_mixtral_run(mock_model, mock_tokenizer): + mixtral = Mixtral() + mock_tokenizer_instance = MagicMock() + mock_model_instance = MagicMock() + mock_tokenizer.from_pretrained.return_value = ( + mock_tokenizer_instance + ) + mock_model.from_pretrained.return_value = mock_model_instance + mock_tokenizer_instance.return_tensors = "pt" + mock_model_instance.generate.return_value = [101, 102, 103] + mock_tokenizer_instance.decode.return_value = "Generated text" + result = mixtral.run("Test task") + assert result == "Generated text" + mock_tokenizer_instance.assert_called_once_with( + "Test task", return_tensors="pt" + ) + mock_model_instance.generate.assert_called_once() + mock_tokenizer_instance.decode.assert_called_once_with( + [101, 102, 103], skip_special_tokens=True + ) + + +@patch("swarms.models.mixtral.AutoTokenizer") +@patch("swarms.models.mixtral.AutoModelForCausalLM") +def test_mixtral_run_error(mock_model, mock_tokenizer): + mixtral = Mixtral() + mock_tokenizer_instance = MagicMock() + mock_model_instance = MagicMock() + mock_tokenizer.from_pretrained.return_value = ( + mock_tokenizer_instance + ) + mock_model.from_pretrained.return_value = mock_model_instance + mock_tokenizer_instance.return_tensors = "pt" + mock_model_instance.generate.side_effect = Exception("Test error") + with pytest.raises(Exception, match="Test error"): + mixtral.run("Test task") diff --git a/tests/models/test_multion.py b/tests/models/test_multion.py index 416e6dc3..cc91b421 100644 --- a/tests/models/test_multion.py +++ b/tests/models/test_multion.py @@ -15,7 +15,7 @@ def mock_multion(): def test_multion_import(): with pytest.raises(ImportError): - import multion + pass def test_multion_init(): diff --git a/tests/models/test_togther.py b/tests/models/test_togther.py index 75313a45..c28e69ae 100644 --- a/tests/models/test_togther.py +++ b/tests/models/test_togther.py @@ -1,4 +1,3 @@ -import os import requests import pytest from unittest.mock import patch, Mock diff --git a/tests/models/test_vllm.py b/tests/models/test_vllm.py index d15a13b9..6eec8f27 100644 --- a/tests/models/test_vllm.py +++ b/tests/models/test_vllm.py @@ -113,7 +113,7 @@ def test_vllm_run_empty_task(vllm_instance): # Test initialization with invalid parameters def test_vllm_invalid_init(): with pytest.raises(ValueError): - vllm_instance = vLLM( + vLLM( model_name=None, tensor_parallel_size=-1, trust_remote_code="invalid", diff --git a/tests/models/test_whisperx.py b/tests/models/test_whisperx.py deleted file mode 100644 index 4b0e4120..00000000 --- a/tests/models/test_whisperx.py +++ /dev/null @@ -1,222 +0,0 @@ -import os -import subprocess -import tempfile -from unittest.mock import patch - -import pytest -import whisperx -from pydub import AudioSegment -from pytube import YouTube -from swarms.models.whisperx_model import WhisperX - - -# Fixture to create a temporary directory for testing -@pytest.fixture -def temp_dir(): - with tempfile.TemporaryDirectory() as tempdir: - yield tempdir - - -# Mock subprocess.run to prevent actual installation during tests -@patch.object(subprocess, "run") -def test_speech_to_text_install(mock_run): - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - stt.install() - mock_run.assert_called_with(["pip", "install", "whisperx"]) - - -# Mock pytube.YouTube and pytube.Streams for download tests -@patch("pytube.YouTube") -@patch.object(YouTube, "streams") -def test_speech_to_text_download_youtube_video( - mock_streams, mock_youtube, temp_dir -): - # Mock YouTube and streams - video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" - mock_stream = mock_streams().filter().first() - mock_stream.download.return_value = os.path.join( - temp_dir, "video.mp4" - ) - mock_youtube.return_value = mock_youtube - mock_youtube.streams = mock_streams - - stt = WhisperX(video_url) - audio_file = stt.download_youtube_video() - - assert os.path.exists(audio_file) - assert audio_file.endswith(".mp3") - - -# Mock whisperx.load_model and whisperx.load_audio for transcribe tests -@patch("whisperx.load_model") -@patch("whisperx.load_audio") -@patch("whisperx.load_align_model") -@patch("whisperx.align") -@patch.object(whisperx.DiarizationPipeline, "__call__") -def test_speech_to_text_transcribe_youtube_video( - mock_diarization, - mock_align, - mock_align_model, - mock_load_audio, - mock_load_model, - temp_dir, -): - # Mock whisperx functions - mock_load_model.return_value = mock_load_model - mock_load_model.transcribe.return_value = { - "language": "en", - "segments": [{"text": "Hello, World!"}], - } - - mock_load_audio.return_value = "audio_path" - mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.return_value = { - "segments": [{"text": "Hello, World!"}] - } - - # Mock diarization pipeline - mock_diarization.return_value = None - - video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM/video" - stt = WhisperX(video_url) - transcription = stt.transcribe_youtube_video() - - assert transcription == "Hello, World!" - - -# More tests for different scenarios and edge cases can be added here. - - -# Test transcribe method with provided audio file -def test_speech_to_text_transcribe_audio_file(temp_dir): - # Create a temporary audio file - audio_file = os.path.join(temp_dir, "test_audio.mp3") - AudioSegment.silent(duration=500).export(audio_file, format="mp3") - - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - transcription = stt.transcribe(audio_file) - - assert transcription == "" - - -# Test transcribe method when Whisperx fails -@patch("whisperx.load_model") -@patch("whisperx.load_audio") -def test_speech_to_text_transcribe_whisperx_failure( - mock_load_audio, mock_load_model, temp_dir -): - # Mock whisperx functions to raise an exception - mock_load_model.side_effect = Exception("Whisperx failed") - mock_load_audio.return_value = "audio_path" - - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - transcription = stt.transcribe("audio_path") - - assert transcription == "Whisperx failed" - - -# Test transcribe method with missing 'segments' key in Whisperx output -@patch("whisperx.load_model") -@patch("whisperx.load_audio") -@patch("whisperx.load_align_model") -@patch("whisperx.align") -@patch.object(whisperx.DiarizationPipeline, "__call__") -def test_speech_to_text_transcribe_missing_segments( - mock_diarization, - mock_align, - mock_align_model, - mock_load_audio, - mock_load_model, -): - # Mock whisperx functions to return incomplete output - mock_load_model.return_value = mock_load_model - mock_load_model.transcribe.return_value = {"language": "en"} - - mock_load_audio.return_value = "audio_path" - mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.return_value = {} - - # Mock diarization pipeline - mock_diarization.return_value = None - - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - transcription = stt.transcribe("audio_path") - - assert transcription == "" - - -# Test transcribe method with Whisperx align failure -@patch("whisperx.load_model") -@patch("whisperx.load_audio") -@patch("whisperx.load_align_model") -@patch("whisperx.align") -@patch.object(whisperx.DiarizationPipeline, "__call__") -def test_speech_to_text_transcribe_align_failure( - mock_diarization, - mock_align, - mock_align_model, - mock_load_audio, - mock_load_model, -): - # Mock whisperx functions to raise an exception during align - mock_load_model.return_value = mock_load_model - mock_load_model.transcribe.return_value = { - "language": "en", - "segments": [{"text": "Hello, World!"}], - } - - mock_load_audio.return_value = "audio_path" - mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.side_effect = Exception("Align failed") - - # Mock diarization pipeline - mock_diarization.return_value = None - - stt = WhisperX("https://www.youtube.com/watch?v=MJd6pr16LRM") - transcription = stt.transcribe("audio_path") - - assert transcription == "Align failed" - - -# Test transcribe_youtube_video when Whisperx diarization fails -@patch("pytube.YouTube") -@patch.object(YouTube, "streams") -@patch("whisperx.DiarizationPipeline") -@patch("whisperx.load_audio") -@patch("whisperx.load_align_model") -@patch("whisperx.align") -def test_speech_to_text_transcribe_diarization_failure( - mock_align, - mock_align_model, - mock_load_audio, - mock_diarization, - mock_streams, - mock_youtube, - temp_dir, -): - # Mock YouTube and streams - video_url = "https://www.youtube.com/watch?v=MJd6pr16LRM" - mock_stream = mock_streams().filter().first() - mock_stream.download.return_value = os.path.join( - temp_dir, "video.mp4" - ) - mock_youtube.return_value = mock_youtube - mock_youtube.streams = mock_streams - - # Mock whisperx functions - mock_load_audio.return_value = "audio_path" - mock_align_model.return_value = (mock_align_model, "metadata") - mock_align.return_value = { - "segments": [{"text": "Hello, World!"}] - } - - # Mock diarization pipeline to raise an exception - mock_diarization.side_effect = Exception("Diarization failed") - - stt = WhisperX(video_url) - transcription = stt.transcribe_youtube_video() - - assert transcription == "Diarization failed" - - -# Add more tests for other scenarios and edge cases as needed. diff --git a/tests/structs/test_agent.py b/tests/structs/test_agent.py index a8e1cf92..8e5b11be 100644 --- a/tests/structs/test_agent.py +++ b/tests/structs/test_agent.py @@ -347,7 +347,7 @@ def test_flow_response_filtering(flow_instance): def test_flow_undo_last(flow_instance): # Test the undo functionality response1 = flow_instance.run("Task 1") - response2 = flow_instance.run("Task 2") + flow_instance.run("Task 2") previous_state, message = flow_instance.undo_last() assert response1 == previous_state assert "Restored to" in message @@ -577,7 +577,7 @@ def test_flow_rollback(flow_instance): # Test rolling back to a previous state state1 = flow_instance.get_state() flow_instance.change_prompt("New prompt") - state2 = flow_instance.get_state() + flow_instance.get_state() flow_instance.rollback_to_state(state1) assert ( flow_instance.get_current_prompt() == state1["current_prompt"] diff --git a/tests/structs/test_autoscaler.py b/tests/structs/test_autoscaler.py index 92d013b7..62abeede 100644 --- a/tests/structs/test_autoscaler.py +++ b/tests/structs/test_autoscaler.py @@ -1,7 +1,9 @@ import os from dotenv import load_dotenv -from unittest.mock import patch +from unittest.mock import MagicMock, patch + +import pytest from swarms.models import OpenAIChat from swarms.structs import Agent @@ -23,18 +25,18 @@ def test_autoscaler_init(): assert autoscaler.scale_up_factor == 1 assert autoscaler.idle_threshold == 0.2 assert autoscaler.busy_threshold == 0.7 - assert autoscaler.autoscale == True + assert autoscaler.autoscale is True assert autoscaler.min_agents == 1 assert autoscaler.max_agents == 5 - assert autoscaler.custom_scale_strategy == None + assert autoscaler.custom_scale_strategy is None assert len(autoscaler.agents_pool) == 5 - assert autoscaler.task_queue.empty() == True + assert autoscaler.task_queue.empty() is True def test_autoscaler_add_task(): autoscaler = AutoScaler(initial_agents=5, agent=agent) autoscaler.add_task("task1") - assert autoscaler.task_queue.empty() == False + assert autoscaler.task_queue.empty() is False def test_autoscaler_run(): @@ -75,7 +77,7 @@ def test_autoscaler_get_agent_by_id(): def test_autoscaler_get_agent_by_id_not_found(): autoscaler = AutoScaler(initial_agents=5, agent=agent) agent = autoscaler.get_agent_by_id("fake_id") - assert agent == None + assert agent is None @patch("swarms.swarms.Agent.is_healthy") @@ -138,3 +140,79 @@ def test_autoscaler_print_dashboard(mock_print): autoscaler = AutoScaler(initial_agents=5, agent=agent) autoscaler.print_dashboard() mock_print.assert_called() + + +@patch("swarms.structs.autoscaler.logging") +def test_check_agent_health_all_healthy(mock_logging): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + for agent in autoscaler.agents_pool: + agent.is_healthy = MagicMock(return_value=True) + autoscaler.check_agent_health() + mock_logging.warning.assert_not_called() + + +@patch("swarms.structs.autoscaler.logging") +def test_check_agent_health_some_unhealthy(mock_logging): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + for i, agent in enumerate(autoscaler.agents_pool): + agent.is_healthy = MagicMock(return_value=(i % 2 == 0)) + autoscaler.check_agent_health() + assert mock_logging.warning.call_count == 2 + + +@patch("swarms.structs.autoscaler.logging") +def test_check_agent_health_all_unhealthy(mock_logging): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + for agent in autoscaler.agents_pool: + agent.is_healthy = MagicMock(return_value=False) + autoscaler.check_agent_health() + assert mock_logging.warning.call_count == 5 + + +@patch("swarms.structs.autoscaler.Agent") +def test_add_agent(mock_agent): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + initial_count = len(autoscaler.agents_pool) + autoscaler.add_agent() + assert len(autoscaler.agents_pool) == initial_count + 1 + mock_agent.assert_called_once() + + +@patch("swarms.structs.autoscaler.Agent") +def test_remove_agent(mock_agent): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + initial_count = len(autoscaler.agents_pool) + autoscaler.remove_agent() + assert len(autoscaler.agents_pool) == initial_count - 1 + + +@patch("swarms.structs.autoscaler.AutoScaler.add_agent") +@patch("swarms.structs.autoscaler.AutoScaler.remove_agent") +def test_scale(mock_remove_agent, mock_add_agent): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + autoscaler.scale(10) + assert mock_add_agent.call_count == 5 + assert mock_remove_agent.call_count == 0 + + mock_add_agent.reset_mock() + mock_remove_agent.reset_mock() + + autoscaler.scale(3) + assert mock_add_agent.call_count == 0 + assert mock_remove_agent.call_count == 2 + + +def test_add_task_success(): + autoscaler = AutoScaler(initial_agents=5) + initial_queue_size = autoscaler.task_queue.qsize() + autoscaler.add_task("test_task") + assert autoscaler.task_queue.qsize() == initial_queue_size + 1 + + +@patch("swarms.structs.autoscaler.queue.Queue.put") +def test_add_task_exception(mock_put): + mock_put.side_effect = Exception("test error") + autoscaler = AutoScaler(initial_agents=5) + with pytest.raises(Exception) as e: + autoscaler.add_task("test_task") + assert str(e.value) == "test error" diff --git a/tests/swarms/test_base.py b/tests/structs/test_base.py similarity index 99% rename from tests/swarms/test_base.py rename to tests/structs/test_base.py index 9641ed7e..8b54dec0 100644 --- a/tests/swarms/test_base.py +++ b/tests/structs/test_base.py @@ -1,7 +1,7 @@ import pytest import os from datetime import datetime -from swarms.swarms.base import BaseStructure +from swarms.structs.base import BaseStructure class TestBaseStructure: diff --git a/tests/structs/test_conversation.py b/tests/structs/test_conversation.py new file mode 100644 index 00000000..84673a42 --- /dev/null +++ b/tests/structs/test_conversation.py @@ -0,0 +1,241 @@ +import pytest +from swarms.structs.conversation import Conversation + + +@pytest.fixture +def conversation(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.add("assistant", "Hello, user!") + return conv + + +def test_add_message(): + conv = Conversation() + conv.add("user", "Hello, world!") + assert len(conv.conversation_history) == 1 + assert conv.conversation_history[0]["role"] == "user" + assert conv.conversation_history[0]["content"] == "Hello, world!" + + +def test_add_message_with_time(): + conv = Conversation(time_enabled=True) + conv.add("user", "Hello, world!") + assert len(conv.conversation_history) == 1 + assert conv.conversation_history[0]["role"] == "user" + assert conv.conversation_history[0]["content"] == "Hello, world!" + assert "timestamp" in conv.conversation_history[0] + + +def test_delete_message(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.delete(0) + assert len(conv.conversation_history) == 0 + + +def test_delete_message_out_of_bounds(): + conv = Conversation() + conv.add("user", "Hello, world!") + with pytest.raises(IndexError): + conv.delete(1) + + +def test_update_message(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.update(0, "assistant", "Hello, user!") + assert len(conv.conversation_history) == 1 + assert conv.conversation_history[0]["role"] == "assistant" + assert conv.conversation_history[0]["content"] == "Hello, user!" + + +def test_update_message_out_of_bounds(): + conv = Conversation() + conv.add("user", "Hello, world!") + with pytest.raises(IndexError): + conv.update(1, "assistant", "Hello, user!") + + +def test_return_history_as_string_with_messages(conversation): + result = conversation.return_history_as_string() + assert result is not None + + +def test_return_history_as_string_with_no_messages(): + conv = Conversation() + result = conv.return_history_as_string() + assert result == "" + + +@pytest.mark.parametrize( + "role, content", + [ + ("user", "Hello, world!"), + ("assistant", "Hello, user!"), + ("system", "System message"), + ("function", "Function message"), + ], +) +def test_return_history_as_string_with_different_roles(role, content): + conv = Conversation() + conv.add(role, content) + result = conv.return_history_as_string() + expected = f"{role}: {content}\n\n" + assert result == expected + + +@pytest.mark.parametrize("message_count", range(1, 11)) +def test_return_history_as_string_with_multiple_messages( + message_count, +): + conv = Conversation() + for i in range(message_count): + conv.add("user", f"Message {i + 1}") + result = conv.return_history_as_string() + expected = "".join( + [f"user: Message {i + 1}\n\n" for i in range(message_count)] + ) + assert result == expected + + +@pytest.mark.parametrize( + "content", + [ + "Hello, world!", + "This is a longer message with multiple words.", + "This message\nhas multiple\nlines.", + "This message has special characters: !@#$%^&*()", + "This message has unicode characters: 你好,世界!", + ], +) +def test_return_history_as_string_with_different_contents(content): + conv = Conversation() + conv.add("user", content) + result = conv.return_history_as_string() + expected = f"user: {content}\n\n" + assert result == expected + + +def test_return_history_as_string_with_large_message(conversation): + large_message = "Hello, world! " * 10000 # 10,000 repetitions + conversation.add("user", large_message) + result = conversation.return_history_as_string() + expected = ( + "user: Hello, world!\n\nassistant: Hello, user!\n\nuser:" + f" {large_message}\n\n" + ) + assert result == expected + + +def test_search_keyword_in_conversation(conversation): + result = conversation.search_keyword_in_conversation("Hello") + assert len(result) == 2 + assert result[0]["content"] == "Hello, world!" + assert result[1]["content"] == "Hello, user!" + + +def test_export_import_conversation(conversation, tmp_path): + filename = tmp_path / "conversation.txt" + conversation.export_conversation(filename) + new_conversation = Conversation() + new_conversation.import_conversation(filename) + assert ( + new_conversation.return_history_as_string() + == conversation.return_history_as_string() + ) + + +def test_count_messages_by_role(conversation): + counts = conversation.count_messages_by_role() + assert counts["user"] == 1 + assert counts["assistant"] == 1 + + +def test_display_conversation(capsys, conversation): + conversation.display_conversation() + captured = capsys.readouterr() + assert "user: Hello, world!\n\n" in captured.out + assert "assistant: Hello, user!\n\n" in captured.out + + +def test_display_conversation_detailed(capsys, conversation): + conversation.display_conversation(detailed=True) + captured = capsys.readouterr() + assert "user: Hello, world!\n\n" in captured.out + assert "assistant: Hello, user!\n\n" in captured.out + + +def test_search(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.add("assistant", "Hello, user!") + results = conv.search("Hello") + assert len(results) == 2 + assert results[0]["content"] == "Hello, world!" + assert results[1]["content"] == "Hello, user!" + + +def test_return_history_as_string(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.add("assistant", "Hello, user!") + result = conv.return_history_as_string() + expected = "user: Hello, world!\n\nassistant: Hello, user!\n\n" + assert result == expected + + +def test_search_no_results(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.add("assistant", "Hello, user!") + results = conv.search("Goodbye") + assert len(results) == 0 + + +def test_search_case_insensitive(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.add("assistant", "Hello, user!") + results = conv.search("hello") + assert len(results) == 2 + assert results[0]["content"] == "Hello, world!" + assert results[1]["content"] == "Hello, user!" + + +def test_search_multiple_occurrences(): + conv = Conversation() + conv.add("user", "Hello, world! Hello, world!") + conv.add("assistant", "Hello, user!") + results = conv.search("Hello") + assert len(results) == 2 + assert results[0]["content"] == "Hello, world! Hello, world!" + assert results[1]["content"] == "Hello, user!" + + +def test_query_no_results(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.add("assistant", "Hello, user!") + results = conv.query("Goodbye") + assert len(results) == 0 + + +def test_query_case_insensitive(): + conv = Conversation() + conv.add("user", "Hello, world!") + conv.add("assistant", "Hello, user!") + results = conv.query("hello") + assert len(results) == 2 + assert results[0]["content"] == "Hello, world!" + assert results[1]["content"] == "Hello, user!" + + +def test_query_multiple_occurrences(): + conv = Conversation() + conv.add("user", "Hello, world! Hello, world!") + conv.add("assistant", "Hello, user!") + results = conv.query("Hello") + assert len(results) == 2 + assert results[0]["content"] == "Hello, world! Hello, world!" + assert results[1]["content"] == "Hello, user!" diff --git a/tests/structs/test_task.py b/tests/structs/test_task.py index 2c116402..fada564a 100644 --- a/tests/structs/test_task.py +++ b/tests/structs/test_task.py @@ -1,4 +1,3 @@ -import os from unittest.mock import Mock import pytest @@ -163,4 +162,4 @@ def test_execute(): agent = Agent() task = Task(id="5", task="Task5", result=None, agents=[agent]) # Assuming execute method returns True on successful execution - assert task.execute() == True + assert task.execute() is True diff --git a/tests/swarms/test_godmode.py b/tests/swarms/test_godmode.py deleted file mode 100644 index 8f528026..00000000 --- a/tests/swarms/test_godmode.py +++ /dev/null @@ -1,36 +0,0 @@ -from unittest.mock import patch -from swarms.swarms.god_mode import GodMode, LLM - - -def test_godmode_initialization(): - godmode = GodMode(llms=[LLM] * 5) - assert isinstance(godmode, GodMode) - assert len(godmode.llms) == 5 - - -def test_godmode_run(monkeypatch): - def mock_llm_run(self, task): - return "response" - - monkeypatch.setattr(LLM, "run", mock_llm_run) - godmode = GodMode(llms=[LLM] * 5) - responses = godmode.run("task1") - assert len(responses) == 5 - assert responses == [ - "response", - "response", - "response", - "response", - "response", - ] - - -@patch("builtins.print") -def test_godmode_print_responses(mock_print, monkeypatch): - def mock_llm_run(self, task): - return "response" - - monkeypatch.setattr(LLM, "run", mock_llm_run) - godmode = GodMode(llms=[LLM] * 5) - godmode.print_responses("task1") - assert mock_print.call_count == 1 diff --git a/tests/swarms/test_multi_agent_collab.py b/tests/swarms/test_multi_agent_collab.py index e30358aa..4d85a436 100644 --- a/tests/swarms/test_multi_agent_collab.py +++ b/tests/swarms/test_multi_agent_collab.py @@ -6,8 +6,6 @@ from swarms.structs import Agent from swarms.models import OpenAIChat from swarms.swarms.multi_agent_collab import ( MultiAgentCollaboration, - select_next_speaker_director, - select_speaker_round_table, ) # Sample agents for testing @@ -26,7 +24,7 @@ def test_collaboration_initialization(collaboration): assert callable(collaboration.select_next_speaker) assert collaboration.max_iters == 10 assert collaboration.results == [] - assert collaboration.logging == True + assert collaboration.logging is True def test_reset(collaboration): @@ -105,13 +103,6 @@ def test_set_interaction_rules(collaboration): assert collaboration.interaction_rules == rules -def test_set_interaction_rules(collaboration): - rules = {"rule1": "action1", "rule2": "action2"} - collaboration.set_interaction_rules(rules) - assert hasattr(collaboration, "interaction_rules") - assert collaboration.interaction_rules == rules - - def test_repr(collaboration): repr_str = repr(collaboration) assert isinstance(repr_str, str) @@ -145,16 +136,6 @@ def test_save(collaboration, tmp_path): # Add more tests here... -# Example of parameterized test for different selection functions -@pytest.mark.parametrize( - "selection_function", - [select_next_speaker_director, select_speaker_round_table], -) -def test_selection_functions(collaboration, selection_function): - collaboration.select_next_speaker = selection_function - assert callable(collaboration.select_next_speaker) - - # Add more parameterized tests for different scenarios... diff --git a/tests/upload_tests_to_issues.py b/tests/test_upload_tests_to_issues.py similarity index 98% rename from tests/upload_tests_to_issues.py rename to tests/test_upload_tests_to_issues.py index 864fee29..15de1245 100644 --- a/tests/upload_tests_to_issues.py +++ b/tests/test_upload_tests_to_issues.py @@ -1,7 +1,5 @@ import os import subprocess -import json -import re import requests from dotenv import load_dotenv diff --git a/tests/tools/test_base.py b/tests/tools/test_base.py index 9f9c700f..9060f53f 100644 --- a/tests/tools/test_base.py +++ b/tests/tools/test_base.py @@ -391,21 +391,6 @@ def test_structured_tool_ainvoke_with_exceptions(): tool.ainvoke({"tool_input": "input_data"}) -# Test additional functionality and edge cases -def test_tool_with_fixture(some_fixture): - # Test Tool with a fixture - tool = Tool() - result = tool.invoke(test_input) - assert result == expected_output - - -def test_structured_tool_with_fixture(some_fixture): - # Test StructuredTool with a fixture - tool = StructuredTool() - result = tool.invoke(test_input) - assert result == expected_output - - def test_base_tool_verbose_logging(caplog): # Test verbose logging in BaseTool tool = BaseTool(verbose=True) @@ -428,13 +413,6 @@ def test_structured_tool_async_invoke(): assert result == expected_output -def test_tool_async_invoke_with_fixture(some_fixture): - # Test asynchronous invoke with a fixture in Tool - tool = Tool() - result = tool.ainvoke(test_input) - assert result == expected_output - - # Add more tests for specific functionalities and edge cases as needed # Import necessary libraries and modules diff --git a/tests/utils/test_class_args_wrapper.py b/tests/utils/test_class_args_wrapper.py new file mode 100644 index 00000000..d846f786 --- /dev/null +++ b/tests/utils/test_class_args_wrapper.py @@ -0,0 +1,81 @@ +import pytest +from io import StringIO +from contextlib import redirect_stdout +from swarms.utils.class_args_wrapper import print_class_parameters +from swarms.structs import Agent, Autoscaler +from fastapi import FastAPI +from fastapi.testclient import TestClient +from swarms.utils.class_args_wrapper import print_class_parameters +from swarms.structs import Agent, Autoscaler + +app = FastAPI() + + +def test_print_class_parameters_agent(): + f = StringIO() + with redirect_stdout(f): + print_class_parameters(Agent) + output = f.getvalue().strip() + # Replace with the expected output for Agent class + expected_output = ( + "Parameter: name, Type: \nParameter: age, Type:" + " " + ) + assert output == expected_output + + +def test_print_class_parameters_autoscaler(): + f = StringIO() + with redirect_stdout(f): + print_class_parameters(Autoscaler) + output = f.getvalue().strip() + # Replace with the expected output for Autoscaler class + expected_output = ( + "Parameter: min_agents, Type: \nParameter:" + " max_agents, Type: " + ) + assert output == expected_output + + +def test_print_class_parameters_error(): + with pytest.raises(TypeError): + print_class_parameters("Not a class") + + +@app.get("/parameters/{class_name}") +def get_parameters(class_name: str): + classes = {"Agent": Agent, "Autoscaler": Autoscaler} + if class_name in classes: + return print_class_parameters( + classes[class_name], api_format=True + ) + else: + return {"error": "Class not found"} + + +client = TestClient(app) + + +def test_get_parameters_agent(): + response = client.get("/parameters/Agent") + assert response.status_code == 200 + # Replace with the expected output for Agent class + expected_output = {"x": "", "y": ""} + assert response.json() == expected_output + + +def test_get_parameters_autoscaler(): + response = client.get("/parameters/Autoscaler") + assert response.status_code == 200 + # Replace with the expected output for Autoscaler class + expected_output = { + "min_agents": "", + "max_agents": "", + } + assert response.json() == expected_output + + +def test_get_parameters_not_found(): + response = client.get("/parameters/NonexistentClass") + assert response.status_code == 200 + assert response.json() == {"error": "Class not found"} diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py new file mode 100644 index 00000000..14399de9 --- /dev/null +++ b/tests/utils/test_device.py @@ -0,0 +1,111 @@ +import torch +from unittest.mock import MagicMock +import pytest +from swarms.utils.device_checker_cuda import check_device + + +def test_cuda_not_available(mocker): + mocker.patch("torch.cuda.is_available", return_value=False) + device = check_device() + assert str(device) == "cpu" + + +def test_single_gpu_available(mocker): + mocker.patch("torch.cuda.is_available", return_value=True) + mocker.patch("torch.cuda.device_count", return_value=1) + devices = check_device() + assert len(devices) == 1 + assert str(devices[0]) == "cuda" + + +def test_multiple_gpus_available(mocker): + mocker.patch("torch.cuda.is_available", return_value=True) + mocker.patch("torch.cuda.device_count", return_value=2) + devices = check_device() + assert len(devices) == 2 + assert str(devices[0]) == "cuda:0" + assert str(devices[1]) == "cuda:1" + + +def test_device_properties(mocker): + mocker.patch("torch.cuda.is_available", return_value=True) + mocker.patch("torch.cuda.device_count", return_value=1) + mocker.patch( + "torch.cuda.get_device_capability", return_value=(7, 5) + ) + mocker.patch( + "torch.cuda.get_device_properties", + return_value=MagicMock(total_memory=1000), + ) + mocker.patch("torch.cuda.memory_allocated", return_value=200) + mocker.patch("torch.cuda.memory_reserved", return_value=300) + mocker.patch( + "torch.cuda.get_device_name", return_value="Tesla K80" + ) + devices = check_device() + assert len(devices) == 1 + assert str(devices[0]) == "cuda" + + +def test_memory_threshold(mocker): + mocker.patch("torch.cuda.is_available", return_value=True) + mocker.patch("torch.cuda.device_count", return_value=1) + mocker.patch( + "torch.cuda.get_device_capability", return_value=(7, 5) + ) + mocker.patch( + "torch.cuda.get_device_properties", + return_value=MagicMock(total_memory=1000), + ) + mocker.patch( + "torch.cuda.memory_allocated", return_value=900 + ) # 90% of total memory + mocker.patch("torch.cuda.memory_reserved", return_value=300) + mocker.patch( + "torch.cuda.get_device_name", return_value="Tesla K80" + ) + with pytest.warns( + UserWarning, + match=r"Memory usage for device cuda exceeds threshold", + ): + devices = check_device( + memory_threshold=0.8 + ) # Set memory threshold to 80% + assert len(devices) == 1 + assert str(devices[0]) == "cuda" + + +def test_compute_capability_threshold(mocker): + mocker.patch("torch.cuda.is_available", return_value=True) + mocker.patch("torch.cuda.device_count", return_value=1) + mocker.patch( + "torch.cuda.get_device_capability", return_value=(3, 0) + ) # Compute capability 3.0 + mocker.patch( + "torch.cuda.get_device_properties", + return_value=MagicMock(total_memory=1000), + ) + mocker.patch("torch.cuda.memory_allocated", return_value=200) + mocker.patch("torch.cuda.memory_reserved", return_value=300) + mocker.patch( + "torch.cuda.get_device_name", return_value="Tesla K80" + ) + with pytest.warns( + UserWarning, + match=( + r"Compute capability for device cuda is below threshold" + ), + ): + devices = check_device( + capability_threshold=3.5 + ) # Set compute capability threshold to 3.5 + assert len(devices) == 1 + assert str(devices[0]) == "cuda" + + +def test_return_single_device(mocker): + mocker.patch("torch.cuda.is_available", return_value=True) + mocker.patch("torch.cuda.device_count", return_value=2) + device = check_device(return_type="single") + assert isinstance(device, torch.device) + assert str(device) == "cuda:0" diff --git a/tests/utils/test_load_models_torch.py b/tests/utils/test_load_models_torch.py new file mode 100644 index 00000000..707f1ce4 --- /dev/null +++ b/tests/utils/test_load_models_torch.py @@ -0,0 +1,54 @@ +import pytest +import torch +from unittest.mock import MagicMock +from swarms.utils.load_model_torch import load_model_torch + + +def test_load_model_torch_no_model_path(): + with pytest.raises(FileNotFoundError): + load_model_torch() + + +def test_load_model_torch_model_not_found(mocker): + mocker.patch("torch.load", side_effect=FileNotFoundError) + with pytest.raises(FileNotFoundError): + load_model_torch("non_existent_model_path") + + +def test_load_model_torch_runtime_error(mocker): + mocker.patch("torch.load", side_effect=RuntimeError) + with pytest.raises(RuntimeError): + load_model_torch("model_path") + + +def test_load_model_torch_no_device_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch("torch.load", return_value=mock_model) + mocker.patch("torch.cuda.is_available", return_value=False) + load_model_torch("model_path") + mock_model.to.assert_called_once_with(torch.device("cpu")) + + +def test_load_model_torch_device_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch("torch.load", return_value=mock_model) + load_model_torch("model_path", device=torch.device("cuda")) + mock_model.to.assert_called_once_with(torch.device("cuda")) + + +def test_load_model_torch_model_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch("torch.load", return_value={"key": "value"}) + load_model_torch("model_path", model=mock_model) + mock_model.load_state_dict.assert_called_once_with( + {"key": "value"}, strict=True + ) + + +def test_load_model_torch_model_specified_strict_false(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch("torch.load", return_value={"key": "value"}) + load_model_torch("model_path", model=mock_model, strict=False) + mock_model.load_state_dict.assert_called_once_with( + {"key": "value"}, strict=False + ) diff --git a/tests/utils/test_phoenix_handler.py b/tests/utils/test_phoenix_handler.py deleted file mode 100644 index 3b6915b9..00000000 --- a/tests/utils/test_phoenix_handler.py +++ /dev/null @@ -1,152 +0,0 @@ -# Import necessary modules and functions for testing -import functools -import subprocess -import sys -import traceback - -import pytest - -# Try importing phoenix and handle exceptions -try: - import phoenix as px -except Exception as error: - print(f"Error importing phoenix: {error}") - print("Please install phoenix: pip install phoenix") - subprocess.run( - [sys.executable, "-m", "pip", "install", "arize-mlflow"] - ) - -# Import the code to be tested -from swarms.utils.phoenix_handler import phoenix_trace_decorator - - -# Define a fixture for Phoenix session -@pytest.fixture(scope="function") -def phoenix_session(): - session = px.active_session() or px.launch_app() - yield session - session.stop() - - -# Define test cases for the phoenix_trace_decorator function -def test_phoenix_trace_decorator_documentation(): - """Test if phoenix_trace_decorator has a docstring.""" - assert phoenix_trace_decorator.__doc__ is not None - - -def test_phoenix_trace_decorator_functionality( - capsys, phoenix_session -): - """Test the functionality of phoenix_trace_decorator.""" - - # Define a function to be decorated - @phoenix_trace_decorator("This is a test function.") - def test_function(): - print("Hello, Phoenix!") - - # Execute the decorated function - test_function() - - # Capture the printed output - captured = capsys.readouterr() - assert captured.out == "Hello, Phoenix!\n" - - -def test_phoenix_trace_decorator_exception_handling(phoenix_session): - """Test if phoenix_trace_decorator handles exceptions correctly.""" - - # Define a function that raises an exception - @phoenix_trace_decorator("This function raises an exception.") - def exception_function(): - raise ValueError("An error occurred.") - - # Execute the decorated function - with pytest.raises(ValueError): - exception_function() - - # Check if the exception was traced by Phoenix - traces = phoenix_session.get_traces() - assert len(traces) == 1 - assert traces[0].get("error") is not None - assert traces[0].get("error_info") is not None - - -# Define test cases for phoenix_trace_decorator -def test_phoenix_trace_decorator_docstring(): - """Test if phoenix_trace_decorator's inner function has a docstring.""" - - @phoenix_trace_decorator("This is a test function.") - def test_function(): - """Test function docstring.""" - pass - - assert test_function.__doc__ is not None - - -def test_phoenix_trace_decorator_functionality_with_params( - capsys, phoenix_session -): - """Test the functionality of phoenix_trace_decorator with parameters.""" - - # Define a function with parameters to be decorated - @phoenix_trace_decorator("This function takes parameters.") - def param_function(a, b): - result = a + b - print(f"Result: {result}") - - # Execute the decorated function with parameters - param_function(2, 3) - - # Capture the printed output - captured = capsys.readouterr() - assert captured.out == "Result: 5\n" - - -def test_phoenix_trace_decorator_nested_calls( - capsys, phoenix_session -): - """Test nested calls of phoenix_trace_decorator.""" - - # Define a nested function with decorators - @phoenix_trace_decorator("Outer function") - def outer_function(): - print("Outer function") - - @phoenix_trace_decorator("Inner function") - def inner_function(): - print("Inner function") - - inner_function() - - # Execute the decorated functions - outer_function() - - # Capture the printed output - captured = capsys.readouterr() - assert "Outer function" in captured.out - assert "Inner function" in captured.out - - -def test_phoenix_trace_decorator_nested_exception_handling( - phoenix_session, -): - """Test exception handling with nested phoenix_trace_decorators.""" - - # Define a function with nested decorators and an exception - @phoenix_trace_decorator("Outer function") - def outer_function(): - @phoenix_trace_decorator("Inner function") - def inner_function(): - raise ValueError("Inner error") - - inner_function() - - # Execute the decorated functions - with pytest.raises(ValueError): - outer_function() - - # Check if both exceptions were traced by Phoenix - traces = phoenix_session.get_traces() - assert len(traces) == 2 - assert "Outer function" in traces[0].get("error_info") - assert "Inner function" in traces[1].get("error_info") diff --git a/tests/utils/test_prep_torch_model_inference.py b/tests/utils/test_prep_torch_model_inference.py new file mode 100644 index 00000000..4a13bee1 --- /dev/null +++ b/tests/utils/test_prep_torch_model_inference.py @@ -0,0 +1,48 @@ +import torch +from unittest.mock import MagicMock +from swarms.utils.prep_torch_model_inference import ( + prep_torch_inference, +) + + +def test_prep_torch_inference_no_model_path(): + result = prep_torch_inference() + assert result is None + + +def test_prep_torch_inference_model_not_found(mocker): + mocker.patch( + "swarms.utils.prep_torch_model_inference.load_model_torch", + side_effect=FileNotFoundError, + ) + result = prep_torch_inference("non_existent_model_path") + assert result is None + + +def test_prep_torch_inference_runtime_error(mocker): + mocker.patch( + "swarms.utils.prep_torch_model_inference.load_model_torch", + side_effect=RuntimeError, + ) + result = prep_torch_inference("model_path") + assert result is None + + +def test_prep_torch_inference_no_device_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch( + "swarms.utils.prep_torch_model_inference.load_model_torch", + return_value=mock_model, + ) + prep_torch_inference("model_path") + mock_model.eval.assert_called_once() + + +def test_prep_torch_inference_device_specified(mocker): + mock_model = MagicMock(spec=torch.nn.Module) + mocker.patch( + "swarms.utils.prep_torch_model_inference.load_model_torch", + return_value=mock_model, + ) + prep_torch_inference("model_path", device=torch.device("cuda")) + mock_model.eval.assert_called_once()