From fd8919dde5f1b38b823fd4862046f051154abb63 Mon Sep 17 00:00:00 2001 From: Kye Date: Mon, 6 Nov 2023 16:23:49 -0500 Subject: [PATCH] GPT4Vision + Dalle3 -> modules + tests + documentation --- dalle3.py | 6 + docs/swarms/models/gpt4v.md | 251 ++++++++++++++++++ gpt4vision_example.py | 7 + sequential_workflow_example.py | 4 +- swarms/models/dalle3.py | 175 +++++++++++++ swarms/models/gpt4v.py | 288 +++++++++++++++++++++ swarms/structs/base.py | 2 +- swarms/utils/code_interpreter.py | 2 +- tests/models/dalle3.py | 374 +++++++++++++++++++++++++++ tests/models/gpt4v.py | 321 +++++++++++++++++++++++ tests/structs/sequential_workflow.py | 35 ++- 11 files changed, 1456 insertions(+), 9 deletions(-) create mode 100644 dalle3.py create mode 100644 docs/swarms/models/gpt4v.md create mode 100644 gpt4vision_example.py create mode 100644 swarms/models/dalle3.py create mode 100644 swarms/models/gpt4v.py create mode 100644 tests/models/dalle3.py create mode 100644 tests/models/gpt4v.py diff --git a/dalle3.py b/dalle3.py new file mode 100644 index 00000000..ac9ba760 --- /dev/null +++ b/dalle3.py @@ -0,0 +1,6 @@ +from swarms.models.dalle3 import Dalle3 + +model = Dalle3() + +task = "A painting of a dog" +img = model(task) diff --git a/docs/swarms/models/gpt4v.md b/docs/swarms/models/gpt4v.md new file mode 100644 index 00000000..2af4348b --- /dev/null +++ b/docs/swarms/models/gpt4v.md @@ -0,0 +1,251 @@ +# GPT4Vision Documentation + +## Table of Contents +- [Overview](#overview) +- [Installation](#installation) +- [Initialization](#initialization) +- [Methods](#methods) + - [process_img](#process_img) + - [__call__](#__call__) + - [run](#run) + - [arun](#arun) +- [Configuration Options](#configuration-options) +- [Usage Examples](#usage-examples) +- [Additional Tips](#additional-tips) +- [References and Resources](#references-and-resources) + +--- + +## Overview + +The GPT4Vision Model API is designed to provide an easy-to-use interface for interacting with the OpenAI GPT-4 Vision model. This model can generate textual descriptions for images and answer questions related to visual content. Whether you want to describe images or perform other vision-related tasks, GPT4Vision makes it simple and efficient. + +The library offers a straightforward way to send images and tasks to the GPT-4 Vision model and retrieve the generated responses. It handles API communication, authentication, and retries, making it a powerful tool for developers working with computer vision and natural language processing tasks. + +## Installation + +To use the GPT4Vision Model API, you need to install the required dependencies and configure your environment. Follow these steps to get started: + +1. Install the required Python package: + + ```bash + pip3 install --upgrade swarms + ``` + +2. Make sure you have an OpenAI API key. You can obtain one by signing up on the [OpenAI platform](https://beta.openai.com/signup/). + +3. Set your OpenAI API key as an environment variable. You can do this in your code or your environment configuration. Alternatively, you can provide the API key directly when initializing the `GPT4Vision` class. + +## Initialization + +To start using the GPT4Vision Model API, you need to create an instance of the `GPT4Vision` class. You can customize its behavior by providing various configuration options, but it also comes with sensible defaults. + +Here's how you can initialize the `GPT4Vision` class: + +```python +from swarms.models.gpt4v import GPT4Vision + +gpt4vision = GPT4Vision( + api_key="Your Key" +) +``` + +The above code initializes the `GPT4Vision` class with default settings. You can adjust these settings as needed. + +## Methods + +### `process_img` + +The `process_img` method is used to preprocess an image before sending it to the GPT-4 Vision model. It takes the image path as input and returns the processed image in a format suitable for API requests. + +```python +processed_img = gpt4vision.process_img(img_path) +``` + +- `img_path` (str): The file path or URL of the image to be processed. + +### `__call__` + +The `__call__` method is the main method for interacting with the GPT-4 Vision model. It sends the image and tasks to the model and returns the generated response. + +```python +response = gpt4vision(img, tasks) +``` + +- `img` (Union[str, List[str]]): Either a single image URL or a list of image URLs to be used for the API request. +- `tasks` (List[str]): A list of tasks or questions related to the image(s). + +This method returns a `GPT4VisionResponse` object, which contains the generated answer. + +### `run` + +The `run` method is an alternative way to interact with the GPT-4 Vision model. It takes a single task and image URL as input and returns the generated response. + +```python +response = gpt4vision.run(task, img) +``` + +- `task` (str): The task or question related to the image. +- `img` (str): The image URL to be used for the API request. + +This method simplifies interactions when dealing with a single task and image. + +### `arun` + +The `arun` method is an asynchronous version of the `run` method. It allows for asynchronous processing of API requests, which can be useful in certain scenarios. + +```python +import asyncio + +async def main(): + response = await gpt4vision.arun(task, img) + print(response) + +loop = asyncio.get_event_loop() +loop.run_until_complete(main()) +``` + +- `task` (str): The task or question related to the image. +- `img` (str): The image URL to be used for the API request. + +## Configuration Options + +The `GPT4Vision` class provides several configuration options that allow you to customize its behavior: + +- `max_retries` (int): The maximum number of retries to make to the API. Default: 3 +- `backoff_factor` (float): The backoff factor to use for exponential backoff. Default: 2.0 +- `timeout_seconds` (int): The timeout in seconds for the API request. Default: 10 +- `api_key` (str): The API key to use for the API request. Default: None (set via environment variable) +- `quality` (str): The quality of the image to generate. Options: 'low' or 'high'. Default: 'low' +- `max_tokens` (int): The maximum number of tokens to use for the API request. Default: 200 + +## Usage Examples + +### Example 1: Generating Image Descriptions + +```python +gpt4vision = GPT4Vision() +img = "https://example.com/image.jpg" +tasks = ["Describe this image."] +response = gpt4vision(img, tasks) +print(response.answer) +``` + +In this example, we create an instance of `GPT4Vision`, provide an image URL, and ask the model to describe the image. The response contains the generated description. + +### Example 2: Custom Configuration + +```python +custom_config = { + "max_retries": 5, + "timeout_seconds": 20, + "quality": "high", + "max_tokens": 300, +} +gpt4vision = GPT4Vision(**custom_config) +img = "https://example.com/another_image.jpg" +tasks = ["What objects can you identify in this image?"] +response = gpt4vision(img, tasks) +print(response.answer) +``` + +In this example, we create an instance of `GPT4Vision` with custom configuration options. We set a higher timeout, request high-quality images, and allow more tokens in the response. + +### Example 3: Using the `run` Method + +```python +gpt4vision = GPT4Vision() +img = "https://example.com/image.jpg" +task = "Describe this image in detail." +response = gpt4vision.run(task, img) +print(response) +``` + +In this example, we use the `run` method to simplify the interaction by providing a single task and image URL. + +# Model Usage and Image Understanding + +The GPT-4 Vision model processes images in a unique way, allowing it to answer questions about both or each of the images independently. Here's an overview: + +| Purpose | Description | +| --------------------------------------- | ---------------------------------------------------------------------------------------------------------------- | +| Image Understanding | The model is shown two copies of the same image and can answer questions about both or each of the images independently. | + +# Image Detail Control + +You have control over how the model processes the image and generates textual understanding by using the `detail` parameter, which has two options: `low` and `high`. + +| Detail | Description | +| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| low | Disables the "high-res" model. The model receives a low-res 512 x 512 version of the image and represents the image with a budget of 65 tokens. Ideal for use cases not requiring high detail. | +| high | Enables "high-res" mode. The model first sees the low-res image and then creates detailed crops of input images as 512px squares based on the input image size. Uses a total of 129 tokens. | + +# Managing Images + +To use the Chat Completions API effectively, you must manage the images you pass to the model. Here are some key considerations: + +| Management Aspect | Description | +| ------------------------- | ------------------------------------------------------------------------------------------------- | +| Image Reuse | To pass the same image multiple times, include the image with each API request. | +| Image Size Optimization | Improve latency by downsizing images to meet the expected size requirements. | +| Image Deletion | After processing, images are deleted from OpenAI servers and not retained. No data is used for training. | + +# Limitations + +While GPT-4 with Vision is powerful, it has some limitations: + +| Limitation | Description | +| -------------------------------------------- | --------------------------------------------------------------------------------------------------- | +| Medical Images | Not suitable for interpreting specialized medical images like CT scans. | +| Non-English Text | May not perform optimally when handling non-Latin alphabets, such as Japanese or Korean. | +| Large Text in Images | Enlarge text within images for readability, but avoid cropping important details. | +| Rotated or Upside-Down Text/Images | May misinterpret rotated or upside-down text or images. | +| Complex Visual Elements | May struggle to understand complex graphs or text with varying colors or styles. | +| Spatial Reasoning | Struggles with tasks requiring precise spatial localization, such as identifying chess positions. | +| Accuracy | May generate incorrect descriptions or captions in certain scenarios. | +| Panoramic and Fisheye Images | Struggles with panoramic and fisheye images. | + +# Calculating Costs + +Image inputs are metered and charged in tokens. The token cost depends on the image size and detail option. + +| Example | Token Cost | +| --------------------------------------------- | ----------- | +| 1024 x 1024 square image in detail: high mode | 765 tokens | +| 2048 x 4096 image in detail: high mode | 1105 tokens | +| 4096 x 8192 image in detail: low mode | 85 tokens | + +# FAQ + +Here are some frequently asked questions about GPT-4 with Vision: + +| Question | Answer | +| -------------------------------------------- | -------------------------------------------------------------------------------------------------- | +| Fine-Tuning Image Capabilities | No, fine-tuning the image capabilities of GPT-4 is not supported at this time. | +| Generating Images | GPT-4 is used for understanding images, not generating them. | +| Supported Image File Types | Supported image file types include PNG (.png), JPEG (.jpeg and .jpg), WEBP (.webp), and non-animated GIF (.gif). | +| Image Size Limitations | Image uploads are restricted to 20MB per image. | +| Image Deletion | Uploaded images are automatically deleted after processing by the model. | +| Learning More | For more details about GPT-4 with Vision, refer to the GPT-4 with Vision system card. | +| CAPTCHA Submission | CAPTCHAs are blocked for safety reasons. | +| Rate Limits | Image processing counts toward your tokens per minute (TPM) limit. Refer to the calculating costs section for details. | +| Image Metadata | The model does not receive image metadata. | +| Handling Unclear Images | If an image is unclear, the model will do its best to interpret it, but results may be less accurate. | + + + +## Additional Tips + +- Make sure to handle potential exceptions and errors when making API requests. The library includes retries and error handling, but it's essential to handle exceptions gracefully in your code. +- Experiment with different configuration options to optimize the trade-off between response quality and response time based on your specific requirements. + +## References and Resources + +- [OpenAI Platform](https://beta.openai.com/signup/): Sign up for an OpenAI API key. +- [OpenAI API Documentation](https://platform.openai.com/docs/api-reference/chat/create): Official API documentation for the GPT-4 Vision model. + +Now you have a comprehensive understanding of the GPT4Vision Model API, its configuration options, and how to use it for various computer vision and natural language processing tasks. Start experimenting and integrating it into your projects to leverage the power of GPT-4 Vision for image-related tasks. + +# Conclusion + +With GPT-4 Vision, you have a powerful tool for understanding and generating textual descriptions for images. By considering its capabilities, limitations, and cost calculations, you can effectively leverage this model for various image-related tasks. \ No newline at end of file diff --git a/gpt4vision_example.py b/gpt4vision_example.py new file mode 100644 index 00000000..7306fc56 --- /dev/null +++ b/gpt4vision_example.py @@ -0,0 +1,7 @@ +from swarms.models.gpt4v import GPT4Vision + +gpt4vision = GPT4Vision(api_key="") +task = "What is the following image about?" +img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + +answer = gpt4vision.run(task, img) diff --git a/sequential_workflow_example.py b/sequential_workflow_example.py index b9ab8196..feb6c748 100644 --- a/sequential_workflow_example.py +++ b/sequential_workflow_example.py @@ -3,9 +3,7 @@ from swarms.structs import Flow from swarms.structs.sequential_workflow import SequentialWorkflow # Example usage -api_key = ( - "" # Your actual API key here -) +api_key = "" # Your actual API key here # Initialize the language flow llm = OpenAIChat( diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py new file mode 100644 index 00000000..f22b11e0 --- /dev/null +++ b/swarms/models/dalle3.py @@ -0,0 +1,175 @@ +import openai +import logging +import os +from dataclasses import dataclass +from functools import lru_cache +from termcolor import colored +from openai import OpenAI +from dotenv import load_dotenv +from pydantic import BaseModel, validator +from PIL import Image +from io import BytesIO + + +load_dotenv() + +api_key = os.getenv("OPENAI_API_KEY") + +# Configure Logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@dataclass +class Dalle3: + """ + Dalle3 model class + + Attributes: + ----------- + image_url: str + The image url generated by the Dalle3 API + + Methods: + -------- + __call__(self, task: str) -> Dalle3: + Makes a call to the Dalle3 API and returns the image url + + Example: + -------- + >>> dalle3 = Dalle3() + >>> task = "A painting of a dog" + >>> image_url = dalle3(task) + >>> print(image_url) + https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png + + """ + + model: str = "dall-e-3" + img: str = None + size: str = "1024x1024" + max_retries: int = 3 + quality: str = "standard" + n: int = 4 + client = OpenAI( + api_key=api_key, + max_retries=max_retries, + ) + + class Config: + """Config class for the Dalle3 model""" + + arbitrary_types_allowed = True + + @validator("max_retries", "time_seconds") + def must_be_positive(cls, value): + if value <= 0: + raise ValueError("Must be positive") + return value + + def read_img(self, img: str): + """Read the image using pil""" + img = Image.open(img) + return img + + def set_width_height(self, img: str, width: int, height: int): + """Set the width and height of the image""" + img = self.read_img(img) + img = img.resize((width, height)) + return img + + def convert_to_bytesio(self, img: str, format: str = "PNG"): + """Convert the image to an bytes io object""" + byte_stream = BytesIO() + img.save(byte_stream, format=format) + byte_array = byte_stream.getvalue() + return byte_array + + # @lru_cache(maxsize=32) + def __call__(self, task: str): + """ + Text to image conversion using the Dalle3 API + + Parameters: + ----------- + task: str + The task to be converted to an image + + Returns: + -------- + Dalle3: + An instance of the Dalle3 class with the image url generated by the Dalle3 API + + Example: + -------- + >>> dalle3 = Dalle3() + >>> task = "A painting of a dog" + >>> image_url = dalle3(task) + >>> print(image_url) + https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png + """ + try: + # Making a call to the the Dalle3 API + response = self.client.images.generate( + # model=self.model, + prompt=task, + # size=self.size, + # quality=self.quality, + n=self.n, + ) + # Extracting the image url from the response + img = response.data[0].url + return img + except openai.OpenAIError as error: + # Handling exceptions and printing the errors details + print( + colored( + f"Error running Dalle3: {error} try optimizing your api key and or try again", + "red", + ) + ) + raise error + + def create_variations(self, img: str): + """ + Create variations of an image using the Dalle3 API + + Parameters: + ----------- + img: str + The image to be used for the API request + + Returns: + -------- + img: str + The image url generated by the Dalle3 API + + Example: + -------- + >>> dalle3 = Dalle3() + >>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + >>> img = dalle3.create_variations(img) + >>> print(img) + + + """ + try: + + response = self.client.images.create_variation( + img = open(img, "rb"), + n=self.n, + size=self.size + ) + img = response.data[0].url + + return img + except (Exception, openai.OpenAIError) as error: + print( + colored( + f"Error running Dalle3: {error} try optimizing your api key and or try again", + "red", + ) + ) + print(colored(f"Error running Dalle3: {error.http_status}", "red")) + print(colored(f"Error running Dalle3: {error.error}", "red")) + raise error \ No newline at end of file diff --git a/swarms/models/gpt4v.py b/swarms/models/gpt4v.py new file mode 100644 index 00000000..a7f8f1c1 --- /dev/null +++ b/swarms/models/gpt4v.py @@ -0,0 +1,288 @@ +import base64 +import logging +import os +import time +from dataclasses import dataclass +from typing import List, Optional, Union + +import requests +from dotenv import load_dotenv +from openai import OpenAI +from termcolor import colored + +# ENV +load_dotenv() + + +def logging_config(): + """Configures logging""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logger = logging.getLogger(__name__) + + return logger + + +@dataclass +class GPT4VisionResponse: + """A response structure for GPT-4""" + + answer: str + + +@dataclass +class GPT4Vision: + """ + GPT4Vision model class + + Attributes: + ----------- + max_retries: int + The maximum number of retries to make to the API + backoff_factor: float + The backoff factor to use for exponential backoff + timeout_seconds: int + The timeout in seconds for the API request + api_key: str + The API key to use for the API request + quality: str + The quality of the image to generate + max_tokens: int + The maximum number of tokens to use for the API request + + Methods: + -------- + process_img(self, img_path: str) -> str: + Processes the image to be used for the API request + __call__(self, img: Union[str, List[str]], tasks: List[str]) -> GPT4VisionResponse: + Makes a call to the GPT-4 Vision API and returns the image url + + Example: + >>> gpt4vision = GPT4Vision() + >>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + >>> tasks = ["A painting of a dog"] + >>> answer = gpt4vision(img, tasks) + >>> print(answer) + + + """ + + max_retries: int = 3 + model: str = "gpt-4-vision-preview" + backoff_factor: float = 2.0 + timeout_seconds: int = 10 + api_key: Optional[str] = None or os.getenv("OPENAI_API_KEY") + # 'Low' or 'High' for respesctively fast or high quality, but high more token usage + quality: str = "low" + # Max tokens to use for the API request, the maximum might be 3,000 but we don't know + max_tokens: int = 200 + client = OpenAI( + api_key=api_key, + max_retries=max_retries, + ) + logger = logging_config() + + class Config: + """Config class for the GPT4Vision model""" + + arbitary_types_allowed = True + + def process_img(self, img: str) -> str: + """Processes the image to be used for the API request""" + with open(img, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + def __call__( + self, + img: Union[str, List[str]], + tasks: List[str], + ) -> GPT4VisionResponse: + """ + Calls the GPT-4 Vision API and returns the image url + + Parameters: + ----------- + img: Union[str, List[str]] + The image to be used for the API request + tasks: List[str] + The tasks to be used for the API request + + Returns: + -------- + answer: GPT4VisionResponse + The response from the API request + + Example: + -------- + >>> gpt4vision = GPT4Vision() + >>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + >>> tasks = ["A painting of a dog"] + >>> answer = gpt4vision(img, tasks) + >>> print(answer) + + + """ + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + + # Image content + image_content = [ + {"type": "imavge_url", "image_url": img} + if img.startswith("http") + else {"type": "image", "data": img} + for img in img + ] + + messages = [ + { + "role": "user", + "content": image_content + [{"type": "text", "text": q} for q in tasks], + } + ] + + payload = { + "model": "gpt-4-vision-preview", + "messages": messages, + "max_tokens": self.max_tokens, + "detail": self.quality, + } + + for attempt in range(self.max_retries): + try: + response = requests.post( + "https://api.openai.com/v1/chat/completions", + headers=headers, + json=payload, + timeout=self.timeout_seconds, + ) + response.raise_for_status() + answer = response.json()["choices"][0]["message"]["content"]["text"] + return GPT4VisionResponse(answer=answer) + except requests.exceptions.HTTPError as error: + self.logger.error( + f"HTTP error: {error.response.status_code}, {error.response.text}" + ) + if error.response.status_code in [429, 500, 503]: + # Exponential backoff = 429(too many requesys) + # And 503 = (Service unavailable) errors + time.sleep(self.backoff_factor**attempt) + else: + break + + except requests.exceptions.RequestException as error: + self.logger.error(f"Request error: {error}") + time.sleep(self.backoff_factor**attempt) + except Exception as error: + self.logger.error( + f"Unexpected Error: {error} try optimizing your api key and try again" + ) + raise error from None + + raise TimeoutError("API Request timed out after multiple retries") + + def run(self, task: str, img: str) -> str: + """ + Runs the GPT-4 Vision API + + Parameters: + ----------- + task: str + The task to be used for the API request + img: str + The image to be used for the API request + + Returns: + -------- + out: str + The response from the API request + + Example: + -------- + >>> gpt4vision = GPT4Vision() + >>> task = "A painting of a dog" + >>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + >>> answer = gpt4vision.run(task, img) + >>> print(answer) + """ + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": f"{task}"}, + { + "type": "image_url", + "image_url": f"{img}", + }, + ], + } + ], + max_tokens=self.max_tokens, + ) + + out = response.choices[0].text + return out + except Exception as error: + print( + colored( + f"Error when calling GPT4Vision, Error: {error} Try optimizing your key, and try again", + "red", + ) + ) + + async def arun(self, task: str, img: str) -> str: + """ + Asynchronous run method for GPT-4 Vision + + Parameters: + ----------- + task: str + The task to be used for the API request + img: str + The image to be used for the API request + + Returns: + -------- + out: str + The response from the API request + + Example: + -------- + >>> gpt4vision = GPT4Vision() + >>> task = "A painting of a dog" + >>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + >>> answer = await gpt4vision.arun(task, img) + >>> print(answer) + """ + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": f"{task}"}, + { + "type": "image_url", + "image_url": f"{img}", + }, + ], + } + ], + max_tokens=self.max_tokens, + ) + out = response.choices[0].text + return out + except Exception as error: + print( + colored( + f"Error when calling GPT4Vision, Error: {error} Try optimizing your key, and try again", + "red", + ) + ) diff --git a/swarms/structs/base.py b/swarms/structs/base.py index f33a204e..4208ba39 100644 --- a/swarms/structs/base.py +++ b/swarms/structs/base.py @@ -2,4 +2,4 @@ Base Structure for all Swarm Structures -""" \ No newline at end of file +""" diff --git a/swarms/utils/code_interpreter.py b/swarms/utils/code_interpreter.py index af6eb327..2448edc7 100644 --- a/swarms/utils/code_interpreter.py +++ b/swarms/utils/code_interpreter.py @@ -24,7 +24,7 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter): """ SubprocessCodeinterpreter is a base class for code interpreters that run code in a subprocess. - + Attributes: start_cmd (str): The command to start the subprocess. Should be a string that can be split by spaces. process (subprocess.Popen): The subprocess that is running the code. diff --git a/tests/models/dalle3.py b/tests/models/dalle3.py new file mode 100644 index 00000000..ff1489ea --- /dev/null +++ b/tests/models/dalle3.py @@ -0,0 +1,374 @@ +import os +from unittest.mock import Mock + +import pytest +from openai import OpenAIError +from PIL import Image +from termcolor import colored + +from dalle3 import Dalle3 + + +# Mocking the OpenAI client to avoid making actual API calls during testing +@pytest.fixture +def mock_openai_client(): + return Mock() + + +@pytest.fixture +def dalle3(mock_openai_client): + return Dalle3(client=mock_openai_client) + + +def test_dalle3_call_success(dalle3, mock_openai_client): + # Arrange + task = "A painting of a dog" + expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)]) + + # Act + img_url = dalle3(task) + + # Assert + assert img_url == expected_img_url + mock_openai_client.images.generate.assert_called_once_with(prompt=task, n=4) + + +def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): + # Arrange + task = "Invalid task" + expected_error_message = "Error running Dalle3: API Error" + + # Mocking OpenAIError + mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") + + # Act and assert + with pytest.raises(OpenAIError) as excinfo: + dalle3(task) + + assert str(excinfo.value) == expected_error_message + mock_openai_client.images.generate.assert_called_once_with(prompt=task, n=4) + + # Ensure the error message is printed in red + captured = capsys.readouterr() + assert colored(expected_error_message, "red") in captured.out + + +def test_dalle3_create_variations_success(dalle3, mock_openai_client): + # Arrange + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" + mock_openai_client.images.create_variation.return_value = Mock(data=[Mock(url=expected_variation_url)]) + + # Act + variation_img_url = dalle3.create_variations(img_url) + + # Assert + assert variation_img_url == expected_variation_url + mock_openai_client.images.create_variation.assert_called_once() + _, kwargs = mock_openai_client.images.create_variation.call_args + assert kwargs["img"] is not None + assert kwargs["n"] == 4 + assert kwargs["size"] == "1024x1024" + + +def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys): + # Arrange + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + expected_error_message = "Error running Dalle3: API Error" + + # Mocking OpenAIError + mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") + + # Act and assert + with pytest.raises(OpenAIError) as excinfo: + dalle3.create_variations(img_url) + + assert str(excinfo.value) == expected_error_message + mock_openai_client.images.create_variation.assert_called_once() + + # Ensure the error message is printed in red + captured = capsys.readouterr() + assert colored(expected_error_message, "red") in captured.out + + +def test_dalle3_read_img(): + # Arrange + img_path = "test_image.png" + img = Image.new("RGB", (512, 512)) + + # Save the image temporarily + img.save(img_path) + + # Act + dalle3 = Dalle3() + img_loaded = dalle3.read_img(img_path) + + # Assert + assert isinstance(img_loaded, Image.Image) + + # Clean up + os.remove(img_path) + + +def test_dalle3_set_width_height(): + # Arrange + img = Image.new("RGB", (512, 512)) + width = 256 + height = 256 + + # Act + dalle3 = Dalle3() + img_resized = dalle3.set_width_height(img, width, height) + + # Assert + assert img_resized.size == (width, height) + + +def test_dalle3_convert_to_bytesio(): + # Arrange + img = Image.new("RGB", (512, 512)) + expected_format = "PNG" + + # Act + dalle3 = Dalle3() + img_bytes = dalle3.convert_to_bytesio(img, format=expected_format) + + # Assert + assert isinstance(img_bytes, bytes) + assert img_bytes.startswith(b"\x89PNG") + + +def test_dalle3_call_multiple_times(dalle3, mock_openai_client): + # Arrange + task = "A painting of a dog" + expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)]) + + # Act + img_url1 = dalle3(task) + img_url2 = dalle3(task) + + # Assert + assert img_url1 == expected_img_url + assert img_url2 == expected_img_url + assert mock_openai_client.images.generate.call_count == 2 + + +def test_dalle3_call_with_large_input(dalle3, mock_openai_client): + # Arrange + task = "A" * 2048 # Input longer than API's limit + expected_error_message = "Error running Dalle3: API Error" + mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") + + # Act and assert + with pytest.raises(OpenAIError) as excinfo: + dalle3(task) + + assert str(excinfo.value) == expected_error_message + + +def test_dalle3_create_variations_with_invalid_image_url(dalle3, mock_openai_client): + # Arrange + img_url = "https://invalid-image-url.com" + expected_error_message = "Error running Dalle3: Invalid image URL" + + # Act and assert + with pytest.raises(ValueError) as excinfo: + dalle3.create_variations(img_url) + + assert str(excinfo.value) == expected_error_message + + +def test_dalle3_set_width_height_invalid_dimensions(dalle3): + # Arrange + img = dalle3.read_img("test_image.png") + width = 0 + height = -1 + + # Act and assert + with pytest.raises(ValueError): + dalle3.set_width_height(img, width, height) + + +def test_dalle3_convert_to_bytesio_invalid_format(dalle3): + # Arrange + img = dalle3.read_img("test_image.png") + invalid_format = "invalid_format" + + # Act and assert + with pytest.raises(ValueError): + dalle3.convert_to_bytesio(img, format=invalid_format) + + +def test_dalle3_call_with_retry(dalle3, mock_openai_client): + # Arrange + task = "A painting of a dog" + expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + + # Simulate a retry scenario + mock_openai_client.images.generate.side_effect = [ + OpenAIError("Temporary error", http_status=500, error="Internal Server Error"), + Mock(data=[Mock(url=expected_img_url)]), + ] + + # Act + img_url = dalle3(task) + + # Assert + assert img_url == expected_img_url + assert mock_openai_client.images.generate.call_count == 2 + + +def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client): + # Arrange + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" + + # Simulate a retry scenario + mock_openai_client.images.create_variation.side_effect = [ + OpenAIError("Temporary error", http_status=500, error="Internal Server Error"), + Mock(data=[Mock(url=expected_variation_url)]), + ] + + # Act + variation_img_url = dalle3.create_variations(img_url) + + # Assert + assert variation_img_url == expected_variation_url + assert mock_openai_client.images.create_variation.call_count == 2 + + +def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys): + # Arrange + task = "A painting of a dog" + expected_error_message = "Error running Dalle3: API Error" + + # Mocking OpenAIError + mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") + + # Act + with pytest.raises(OpenAIError): + dalle3(task) + + # Assert that the error message is logged + captured = capsys.readouterr() + assert expected_error_message in captured.err + + +def test_dalle3_create_variations_exception_logging(dalle3, mock_openai_client, capsys): + # Arrange + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + expected_error_message = "Error running Dalle3: API Error" + + # Mocking OpenAIError + mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") + + # Act + with pytest.raises(OpenAIError): + dalle3.create_variations(img_url) + + # Assert that the error message is logged + captured = capsys.readouterr() + assert expected_error_message in captured.err + + +def test_dalle3_read_img_invalid_path(dalle3): + # Arrange + invalid_img_path = "invalid_image_path.png" + + # Act and assert + with pytest.raises(FileNotFoundError): + dalle3.read_img(invalid_img_path) + + +def test_dalle3_call_no_api_key(): + # Arrange + task = "A painting of a dog" + dalle3 = Dalle3(api_key=None) + expected_error_message = "Error running Dalle3: API Key is missing" + + # Act and assert + with pytest.raises(ValueError) as excinfo: + dalle3(task) + + assert str(excinfo.value) == expected_error_message + + +def test_dalle3_create_variations_no_api_key(): + # Arrange + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + dalle3 = Dalle3(api_key=None) + expected_error_message = "Error running Dalle3: API Key is missing" + + # Act and assert + with pytest.raises(ValueError) as excinfo: + dalle3.create_variations(img_url) + + assert str(excinfo.value) == expected_error_message + + +def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client): + # Arrange + task = "A painting of a dog" + + # Simulate max retries exceeded + mock_openai_client.images.generate.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error") + + # Act and assert + with pytest.raises(OpenAIError) as excinfo: + dalle3(task) + + assert "Retry limit exceeded" in str(excinfo.value) + + +def test_dalle3_create_variations_with_retry_max_retries_exceeded(dalle3, mock_openai_client): + # Arrange + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + + # Simulate max retries exceeded + mock_openai_client.images.create_variation.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error") + + # Act and assert + with pytest.raises(OpenAIError) as excinfo: + dalle3.create_variations(img_url) + + assert "Retry limit exceeded" in str(excinfo.value) + + +def test_dalle3_call_retry_with_success(dalle3, mock_openai_client): + # Arrange + task = "A painting of a dog" + expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + + # Simulate success after a retry + mock_openai_client.images.generate.side_effect = [ + OpenAIError("Temporary error", http_status=500, error="Internal Server Error"), + Mock(data=[Mock(url=expected_img_url)]), + ] + + # Act + img_url = dalle3(task) + + # Assert + assert img_url == expected_img_url + assert mock_openai_client.images.generate.call_count == 2 + + +def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client): + # Arrange + img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" + + # Simulate success after a retry + mock_openai_client.images.create_variation.side_effect = [ + OpenAIError("Temporary error", http_status=500, error="Internal Server Error"), + Mock(data=[Mock(url=expected_variation_url)]), + ] + + # Act + variation_img_url = dalle3.create_variations(img_url) + + # Assert + assert variation_img_url == expected_variation_url + assert mock_openai_client.images.create_variation.call_count == 2 diff --git a/tests/models/gpt4v.py b/tests/models/gpt4v.py new file mode 100644 index 00000000..40ccc7f5 --- /dev/null +++ b/tests/models/gpt4v.py @@ -0,0 +1,321 @@ +import logging +import os +from unittest.mock import Mock + +import pytest +from dotenv import load_dotenv +from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout + +from swarms.models.gpt4v import GPT4Vision, GPT4VisionResponse + +load_dotenv + +api_key = os.getenv("OPENAI_API_KEY") + +# Mock the OpenAI client +@pytest.fixture +def mock_openai_client(): + return Mock() + +@pytest.fixture +def gpt4vision(mock_openai_client): + return GPT4Vision(client=mock_openai_client) + +def test_gpt4vision_default_values(): + # Arrange and Act + gpt4vision = GPT4Vision() + + # Assert + assert gpt4vision.max_retries == 3 + assert gpt4vision.model == "gpt-4-vision-preview" + assert gpt4vision.backoff_factor == 2.0 + assert gpt4vision.timeout_seconds == 10 + assert gpt4vision.api_key is None + assert gpt4vision.quality == "low" + assert gpt4vision.max_tokens == 200 + +def test_gpt4vision_api_key_from_env_variable(): + # Arrange + api_key = os.environ["OPENAI_API_KEY"] + + # Act + gpt4vision = GPT4Vision() + + # Assert + assert gpt4vision.api_key == api_key + +def test_gpt4vision_set_api_key(): + # Arrange + gpt4vision = GPT4Vision(api_key=api_key) + + # Assert + assert gpt4vision.api_key == api_key + +def test_gpt4vision_invalid_max_retries(): + # Arrange and Act + with pytest.raises(ValueError): + GPT4Vision(max_retries=-1) + +def test_gpt4vision_invalid_backoff_factor(): + # Arrange and Act + with pytest.raises(ValueError): + GPT4Vision(backoff_factor=-1) + +def test_gpt4vision_invalid_timeout_seconds(): + # Arrange and Act + with pytest.raises(ValueError): + GPT4Vision(timeout_seconds=-1) + +def test_gpt4vision_invalid_max_tokens(): + # Arrange and Act + with pytest.raises(ValueError): + GPT4Vision(max_tokens=-1) + +def test_gpt4vision_logger_initialized(): + # Arrange + gpt4vision = GPT4Vision() + + # Assert + assert isinstance(gpt4vision.logger, logging.Logger) + +def test_gpt4vision_process_img_nonexistent_file(): + # Arrange + gpt4vision = GPT4Vision() + img_path = "nonexistent_image.jpg" + + # Act and Assert + with pytest.raises(FileNotFoundError): + gpt4vision.process_img(img_path) + +def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + # Act and Assert + with pytest.raises(AttributeError): + gpt4vision(img_url, [task]) + +def test_gpt4vision_call_single_task_single_image_empty_response(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + mock_openai_client.chat.completions.create.return_value.choices = [] + + # Act + response = gpt4vision(img_url, [task]) + + # Assert + assert response.answer == "" + mock_openai_client.chat.completions.create.assert_called_once() + +def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + tasks = ["Describe this image.", "What's in this picture?"] + + mock_openai_client.chat.completions.create.return_value.choices = [] + + # Act + responses = gpt4vision(img_url, tasks) + + # Assert + assert all(response.answer == "" for response in responses) + assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once + +def test_gpt4vision_call_single_task_single_image_timeout(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + mock_openai_client.chat.completions.create.side_effect = Timeout("Request timed out") + + # Act and Assert + with pytest.raises(Timeout): + gpt4vision(img_url, [task]) + +def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + # Simulate success after a timeout and retry + mock_openai_client.chat.completions.create.side_effect = [ + Timeout("Request timed out"), + {"choices": [{"message": {"content": {"text": "A description of the image."}}}],} + ] + + # Act + response = gpt4vision(img_url, [task]) + + # Assert + assert response.answer == "A description of the image." + assert mock_openai_client.chat.completions.create.call_count == 2 # Should be called twice + + +def test_gpt4vision_process_img(): + # Arrange + img_path = "test_image.jpg" + gpt4vision = GPT4Vision() + + # Act + img_data = gpt4vision.process_img(img_path) + + # Assert + assert img_data.startswith("/9j/") # Base64-encoded image data + + +def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + expected_response = GPT4VisionResponse(answer="A description of the image.") + + mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer + + # Act + response = gpt4vision(img_url, [task]) + + # Assert + assert response == expected_response + mock_openai_client.chat.completions.create.assert_called_once() + + +def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_client): + # Arrange + img_urls = ["https://example.com/image1.jpg", "https://example.com/image2.jpg"] + task = "Describe these images." + + expected_response = GPT4VisionResponse(answer="Descriptions of the images.") + + mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer + + # Act + response = gpt4vision(img_urls, [task]) + + # Assert + assert response == expected_response + mock_openai_client.chat.completions.create.assert_called_once() + + +def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + tasks = ["Describe this image.", "What's in this picture?"] + + expected_responses = [ + GPT4VisionResponse(answer="A description of the image."), + GPT4VisionResponse(answer="It contains various objects."), + ] + + def create_mock_response(response): + return {"choices": [{"message": {"content": {"text": response.answer}}}]} + + mock_openai_client.chat.completions.create.side_effect = [create_mock_response(response) for response in expected_responses] + + # Act + responses = gpt4vision(img_url, tasks) + + # Assert + assert responses == expected_responses + assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once + def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + tasks = ["Describe this image.", "What's in this picture?"] + + expected_responses = [ + GPT4VisionResponse(answer="A description of the image."), + GPT4VisionResponse(answer="It contains various objects."), + ] + + mock_openai_client.chat.completions.create.side_effect = [ + {"choices": [{"message": {"content": {"text": expected_responses[i].answer}}}] } for i in range(len(expected_responses)) + ] + + # Act + responses = gpt4vision(img_url, tasks) + + # Assert + assert responses == expected_responses + assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once + + +def test_gpt4vision_call_multiple_tasks_multiple_images(gpt4vision, mock_openai_client): + # Arrange + img_urls = ["https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", "https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"] + tasks = ["Describe these images.", "What's in these pictures?"] + + expected_responses = [ + GPT4VisionResponse(answer="Descriptions of the images."), + GPT4VisionResponse(answer="They contain various objects.") + ] + + mock_openai_client.chat.completions.create.side_effect = [ + {"choices": [{"message": {"content": {"text": response.answer}}}] } for response in expected_responses + ] + + # Act + responses = gpt4vision(img_urls, tasks) + + + # Assert + assert responses == expected_responses + assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once + + +def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + mock_openai_client.chat.completions.create.side_effect = HTTPError("HTTP Error") + + # Act and Assert + with pytest.raises(HTTPError): + gpt4vision(img_url, [task]) + + +def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + mock_openai_client.chat.completions.create.side_effect = RequestException("Request Error") + + # Act and Assert + with pytest.raises(RequestException): + gpt4vision(img_url, [task]) + + +def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + mock_openai_client.chat.completions.create.side_effect = ConnectionError("Connection Error") + + # Act and Assert + with pytest.raises(ConnectionError): + gpt4vision(img_url, [task]) + + +def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client): + # Arrange + img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + task = "Describe this image." + + # Simulate success after a retry + mock_openai_client.chat.completions.create.side_effect = [ + RequestException("Temporary error"), + {"choices": [{"text": "A description of the image."}]} # fixed dictionary syntax + ] + + # Act + response = gpt4vision(img_url, [task]) + + # Assert + assert response.answer == "A description of the image." + assert mock_openai_client.chat.completions.create.call_count == 2 # Should be called twice diff --git a/tests/structs/sequential_workflow.py b/tests/structs/sequential_workflow.py index 64b51f28..7bd3e4a4 100644 --- a/tests/structs/sequential_workflow.py +++ b/tests/structs/sequential_workflow.py @@ -12,7 +12,6 @@ from swarms.structs.sequential_workflow import SequentialWorkflow, Task os.environ["OPENAI_API_KEY"] = "mocked_api_key" - # Mock OpenAIChat class for testing class MockOpenAIChat: def __init__(self, *args, **kwargs): @@ -21,6 +20,7 @@ class MockOpenAIChat: def run(self, *args, **kwargs): return "Mocked result" + # Mock Flow class for testing class MockFlow: def __init__(self, *args, **kwargs): @@ -29,6 +29,7 @@ class MockFlow: def run(self, *args, **kwargs): return "Mocked result" + # Mock SequentialWorkflow class for testing class MockSequentialWorkflow: def __init__(self, *args, **kwargs): @@ -40,6 +41,7 @@ class MockSequentialWorkflow: def run(self): pass + # Test Task class def test_task_initialization(): description = "Sample Task" @@ -48,6 +50,7 @@ def test_task_initialization(): assert task.description == description assert task.flow == flow + def test_task_execute(): description = "Sample Task" flow = MockOpenAIChat() @@ -55,6 +58,7 @@ def test_task_execute(): task.execute() assert task.result == "Mocked result" + # Test SequentialWorkflow class def test_sequential_workflow_initialization(): workflow = SequentialWorkflow() @@ -66,6 +70,7 @@ def test_sequential_workflow_initialization(): assert workflow.restore_state_filepath == None assert workflow.dashboard == False + def test_sequential_workflow_add_task(): workflow = SequentialWorkflow() task_description = "Sample Task" @@ -75,6 +80,7 @@ def test_sequential_workflow_add_task(): assert workflow.tasks[0].description == task_description assert workflow.tasks[0].flow == task_flow + def test_sequential_workflow_reset_workflow(): workflow = SequentialWorkflow() task_description = "Sample Task" @@ -83,6 +89,7 @@ def test_sequential_workflow_reset_workflow(): workflow.reset_workflow() assert workflow.tasks[0].result == None + def test_sequential_workflow_get_task_results(): workflow = SequentialWorkflow() task_description = "Sample Task" @@ -94,6 +101,7 @@ def test_sequential_workflow_get_task_results(): assert task_description in results assert results[task_description] == "Mocked result" + def test_sequential_workflow_remove_task(): workflow = SequentialWorkflow() task1_description = "Task 1" @@ -106,6 +114,7 @@ def test_sequential_workflow_remove_task(): assert len(workflow.tasks) == 1 assert workflow.tasks[0].description == task2_description + def test_sequential_workflow_update_task(): workflow = SequentialWorkflow() task_description = "Sample Task" @@ -114,6 +123,7 @@ def test_sequential_workflow_update_task(): workflow.update_task(task_description, max_tokens=1000) assert workflow.tasks[0].kwargs["max_tokens"] == 1000 + def test_sequential_workflow_save_workflow_state(): workflow = SequentialWorkflow() task_description = "Sample Task" @@ -123,6 +133,7 @@ def test_sequential_workflow_save_workflow_state(): assert os.path.exists("test_state.json") os.remove("test_state.json") + def test_sequential_workflow_load_workflow_state(): workflow = SequentialWorkflow() task_description = "Sample Task" @@ -134,6 +145,7 @@ def test_sequential_workflow_load_workflow_state(): assert workflow.tasks[0].description == task_description os.remove("test_state.json") + def test_sequential_workflow_run(): workflow = SequentialWorkflow() task_description = "Sample Task" @@ -142,18 +154,21 @@ def test_sequential_workflow_run(): workflow.run() assert workflow.tasks[0].result == "Mocked result" + def test_sequential_workflow_workflow_bootup(capfd): workflow = SequentialWorkflow() workflow.workflow_bootup() out, _ = capfd.readouterr() assert "Sequential Workflow Initializing..." in out + def test_sequential_workflow_workflow_dashboard(capfd): workflow = SequentialWorkflow() workflow.workflow_dashboard() out, _ = capfd.readouterr() assert "Sequential Workflow Dashboard" in out + # Mock Flow class for async testing class MockAsyncFlow: def __init__(self, *args, **kwargs): @@ -162,6 +177,7 @@ class MockAsyncFlow: async def arun(self, *args, **kwargs): return "Mocked result" + # Test async execution in SequentialWorkflow @pytest.mark.asyncio async def test_sequential_workflow_arun(): @@ -173,23 +189,24 @@ async def test_sequential_workflow_arun(): assert workflow.tasks[0].result == "Mocked result" - - def test_real_world_usage_with_openai_key(): # Initialize the language model llm = OpenAIChat() assert isinstance(llm, OpenAIChat) + def test_real_world_usage_with_flow_and_openai_key(): # Initialize a flow with the language model flow = Flow(llm=OpenAIChat()) assert isinstance(flow, Flow) + def test_real_world_usage_with_sequential_workflow(): # Initialize a sequential workflow workflow = SequentialWorkflow() assert isinstance(workflow, SequentialWorkflow) + def test_real_world_usage_add_tasks(): # Create a sequential workflow and add tasks workflow = SequentialWorkflow() @@ -203,6 +220,7 @@ def test_real_world_usage_add_tasks(): assert workflow.tasks[0].description == task1_description assert workflow.tasks[1].description == task2_description + def test_real_world_usage_run_workflow(): # Create a sequential workflow, add a task, and run the workflow workflow = SequentialWorkflow() @@ -212,6 +230,7 @@ def test_real_world_usage_run_workflow(): workflow.run() assert workflow.tasks[0].result is not None + def test_real_world_usage_dashboard_display(): # Create a sequential workflow, add tasks, and display the dashboard workflow = SequentialWorkflow() @@ -225,6 +244,7 @@ def test_real_world_usage_dashboard_display(): workflow.workflow_dashboard() mock_print.assert_called() + def test_real_world_usage_async_execution(): # Create a sequential workflow, add an async task, and run the workflow asynchronously workflow = SequentialWorkflow() @@ -238,6 +258,7 @@ def test_real_world_usage_async_execution(): asyncio.run(async_run_workflow()) assert workflow.tasks[0].result is not None + def test_real_world_usage_multiple_loops(): # Create a sequential workflow with multiple loops, add a task, and run the workflow workflow = SequentialWorkflow(max_loops=3) @@ -247,6 +268,7 @@ def test_real_world_usage_multiple_loops(): workflow.run() assert workflow.tasks[0].result is not None + def test_real_world_usage_autosave_state(): # Create a sequential workflow with autosave, add a task, run the workflow, and check if state is saved workflow = SequentialWorkflow(autosave=True) @@ -258,6 +280,7 @@ def test_real_world_usage_autosave_state(): assert os.path.exists("sequential_workflow_state.json") os.remove("sequential_workflow_state.json") + def test_real_world_usage_load_state(): # Create a sequential workflow, add a task, save state, load state, and run the workflow workflow = SequentialWorkflow() @@ -271,6 +294,7 @@ def test_real_world_usage_load_state(): assert workflow.tasks[0].result is not None os.remove("test_state.json") + def test_real_world_usage_update_task_args(): # Create a sequential workflow, add a task, and update task arguments workflow = SequentialWorkflow() @@ -280,6 +304,7 @@ def test_real_world_usage_update_task_args(): workflow.update_task(task_description, max_tokens=1000) assert workflow.tasks[0].kwargs["max_tokens"] == 1000 + def test_real_world_usage_remove_task(): # Create a sequential workflow, add tasks, remove a task, and run the workflow workflow = SequentialWorkflow() @@ -294,13 +319,15 @@ def test_real_world_usage_remove_task(): assert len(workflow.tasks) == 1 assert workflow.tasks[0].description == task2_description + def test_real_world_usage_with_environment_variables(): # Ensure that the OpenAI API key is set using environment variables assert "OPENAI_API_KEY" in os.environ assert os.environ["OPENAI_API_KEY"] == "mocked_api_key" del os.environ["OPENAI_API_KEY"] # Clean up after the test + def test_real_world_usage_no_openai_key(): # Ensure that an exception is raised when the OpenAI API key is not set with pytest.raises(ValueError): - llm = OpenAIChat() # API key not provided, should raise an exception \ No newline at end of file + llm = OpenAIChat() # API key not provided, should raise an exception