GPT4Vision + Dalle3 -> modules + tests + documentation

pull/100/head
Kye 1 year ago
parent fe48ec1393
commit fd8919dde5

@ -0,0 +1,6 @@
from swarms.models.dalle3 import Dalle3
model = Dalle3()
task = "A painting of a dog"
img = model(task)

@ -0,0 +1,251 @@
# GPT4Vision Documentation
## Table of Contents
- [Overview](#overview)
- [Installation](#installation)
- [Initialization](#initialization)
- [Methods](#methods)
- [process_img](#process_img)
- [__call__](#__call__)
- [run](#run)
- [arun](#arun)
- [Configuration Options](#configuration-options)
- [Usage Examples](#usage-examples)
- [Additional Tips](#additional-tips)
- [References and Resources](#references-and-resources)
---
## Overview
The GPT4Vision Model API is designed to provide an easy-to-use interface for interacting with the OpenAI GPT-4 Vision model. This model can generate textual descriptions for images and answer questions related to visual content. Whether you want to describe images or perform other vision-related tasks, GPT4Vision makes it simple and efficient.
The library offers a straightforward way to send images and tasks to the GPT-4 Vision model and retrieve the generated responses. It handles API communication, authentication, and retries, making it a powerful tool for developers working with computer vision and natural language processing tasks.
## Installation
To use the GPT4Vision Model API, you need to install the required dependencies and configure your environment. Follow these steps to get started:
1. Install the required Python package:
```bash
pip3 install --upgrade swarms
```
2. Make sure you have an OpenAI API key. You can obtain one by signing up on the [OpenAI platform](https://beta.openai.com/signup/).
3. Set your OpenAI API key as an environment variable. You can do this in your code or your environment configuration. Alternatively, you can provide the API key directly when initializing the `GPT4Vision` class.
## Initialization
To start using the GPT4Vision Model API, you need to create an instance of the `GPT4Vision` class. You can customize its behavior by providing various configuration options, but it also comes with sensible defaults.
Here's how you can initialize the `GPT4Vision` class:
```python
from swarms.models.gpt4v import GPT4Vision
gpt4vision = GPT4Vision(
api_key="Your Key"
)
```
The above code initializes the `GPT4Vision` class with default settings. You can adjust these settings as needed.
## Methods
### `process_img`
The `process_img` method is used to preprocess an image before sending it to the GPT-4 Vision model. It takes the image path as input and returns the processed image in a format suitable for API requests.
```python
processed_img = gpt4vision.process_img(img_path)
```
- `img_path` (str): The file path or URL of the image to be processed.
### `__call__`
The `__call__` method is the main method for interacting with the GPT-4 Vision model. It sends the image and tasks to the model and returns the generated response.
```python
response = gpt4vision(img, tasks)
```
- `img` (Union[str, List[str]]): Either a single image URL or a list of image URLs to be used for the API request.
- `tasks` (List[str]): A list of tasks or questions related to the image(s).
This method returns a `GPT4VisionResponse` object, which contains the generated answer.
### `run`
The `run` method is an alternative way to interact with the GPT-4 Vision model. It takes a single task and image URL as input and returns the generated response.
```python
response = gpt4vision.run(task, img)
```
- `task` (str): The task or question related to the image.
- `img` (str): The image URL to be used for the API request.
This method simplifies interactions when dealing with a single task and image.
### `arun`
The `arun` method is an asynchronous version of the `run` method. It allows for asynchronous processing of API requests, which can be useful in certain scenarios.
```python
import asyncio
async def main():
response = await gpt4vision.arun(task, img)
print(response)
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
```
- `task` (str): The task or question related to the image.
- `img` (str): The image URL to be used for the API request.
## Configuration Options
The `GPT4Vision` class provides several configuration options that allow you to customize its behavior:
- `max_retries` (int): The maximum number of retries to make to the API. Default: 3
- `backoff_factor` (float): The backoff factor to use for exponential backoff. Default: 2.0
- `timeout_seconds` (int): The timeout in seconds for the API request. Default: 10
- `api_key` (str): The API key to use for the API request. Default: None (set via environment variable)
- `quality` (str): The quality of the image to generate. Options: 'low' or 'high'. Default: 'low'
- `max_tokens` (int): The maximum number of tokens to use for the API request. Default: 200
## Usage Examples
### Example 1: Generating Image Descriptions
```python
gpt4vision = GPT4Vision()
img = "https://example.com/image.jpg"
tasks = ["Describe this image."]
response = gpt4vision(img, tasks)
print(response.answer)
```
In this example, we create an instance of `GPT4Vision`, provide an image URL, and ask the model to describe the image. The response contains the generated description.
### Example 2: Custom Configuration
```python
custom_config = {
"max_retries": 5,
"timeout_seconds": 20,
"quality": "high",
"max_tokens": 300,
}
gpt4vision = GPT4Vision(**custom_config)
img = "https://example.com/another_image.jpg"
tasks = ["What objects can you identify in this image?"]
response = gpt4vision(img, tasks)
print(response.answer)
```
In this example, we create an instance of `GPT4Vision` with custom configuration options. We set a higher timeout, request high-quality images, and allow more tokens in the response.
### Example 3: Using the `run` Method
```python
gpt4vision = GPT4Vision()
img = "https://example.com/image.jpg"
task = "Describe this image in detail."
response = gpt4vision.run(task, img)
print(response)
```
In this example, we use the `run` method to simplify the interaction by providing a single task and image URL.
# Model Usage and Image Understanding
The GPT-4 Vision model processes images in a unique way, allowing it to answer questions about both or each of the images independently. Here's an overview:
| Purpose | Description |
| --------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
| Image Understanding | The model is shown two copies of the same image and can answer questions about both or each of the images independently. |
# Image Detail Control
You have control over how the model processes the image and generates textual understanding by using the `detail` parameter, which has two options: `low` and `high`.
| Detail | Description |
| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| low | Disables the "high-res" model. The model receives a low-res 512 x 512 version of the image and represents the image with a budget of 65 tokens. Ideal for use cases not requiring high detail. |
| high | Enables "high-res" mode. The model first sees the low-res image and then creates detailed crops of input images as 512px squares based on the input image size. Uses a total of 129 tokens. |
# Managing Images
To use the Chat Completions API effectively, you must manage the images you pass to the model. Here are some key considerations:
| Management Aspect | Description |
| ------------------------- | ------------------------------------------------------------------------------------------------- |
| Image Reuse | To pass the same image multiple times, include the image with each API request. |
| Image Size Optimization | Improve latency by downsizing images to meet the expected size requirements. |
| Image Deletion | After processing, images are deleted from OpenAI servers and not retained. No data is used for training. |
# Limitations
While GPT-4 with Vision is powerful, it has some limitations:
| Limitation | Description |
| -------------------------------------------- | --------------------------------------------------------------------------------------------------- |
| Medical Images | Not suitable for interpreting specialized medical images like CT scans. |
| Non-English Text | May not perform optimally when handling non-Latin alphabets, such as Japanese or Korean. |
| Large Text in Images | Enlarge text within images for readability, but avoid cropping important details. |
| Rotated or Upside-Down Text/Images | May misinterpret rotated or upside-down text or images. |
| Complex Visual Elements | May struggle to understand complex graphs or text with varying colors or styles. |
| Spatial Reasoning | Struggles with tasks requiring precise spatial localization, such as identifying chess positions. |
| Accuracy | May generate incorrect descriptions or captions in certain scenarios. |
| Panoramic and Fisheye Images | Struggles with panoramic and fisheye images. |
# Calculating Costs
Image inputs are metered and charged in tokens. The token cost depends on the image size and detail option.
| Example | Token Cost |
| --------------------------------------------- | ----------- |
| 1024 x 1024 square image in detail: high mode | 765 tokens |
| 2048 x 4096 image in detail: high mode | 1105 tokens |
| 4096 x 8192 image in detail: low mode | 85 tokens |
# FAQ
Here are some frequently asked questions about GPT-4 with Vision:
| Question | Answer |
| -------------------------------------------- | -------------------------------------------------------------------------------------------------- |
| Fine-Tuning Image Capabilities | No, fine-tuning the image capabilities of GPT-4 is not supported at this time. |
| Generating Images | GPT-4 is used for understanding images, not generating them. |
| Supported Image File Types | Supported image file types include PNG (.png), JPEG (.jpeg and .jpg), WEBP (.webp), and non-animated GIF (.gif). |
| Image Size Limitations | Image uploads are restricted to 20MB per image. |
| Image Deletion | Uploaded images are automatically deleted after processing by the model. |
| Learning More | For more details about GPT-4 with Vision, refer to the GPT-4 with Vision system card. |
| CAPTCHA Submission | CAPTCHAs are blocked for safety reasons. |
| Rate Limits | Image processing counts toward your tokens per minute (TPM) limit. Refer to the calculating costs section for details. |
| Image Metadata | The model does not receive image metadata. |
| Handling Unclear Images | If an image is unclear, the model will do its best to interpret it, but results may be less accurate. |
## Additional Tips
- Make sure to handle potential exceptions and errors when making API requests. The library includes retries and error handling, but it's essential to handle exceptions gracefully in your code.
- Experiment with different configuration options to optimize the trade-off between response quality and response time based on your specific requirements.
## References and Resources
- [OpenAI Platform](https://beta.openai.com/signup/): Sign up for an OpenAI API key.
- [OpenAI API Documentation](https://platform.openai.com/docs/api-reference/chat/create): Official API documentation for the GPT-4 Vision model.
Now you have a comprehensive understanding of the GPT4Vision Model API, its configuration options, and how to use it for various computer vision and natural language processing tasks. Start experimenting and integrating it into your projects to leverage the power of GPT-4 Vision for image-related tasks.
# Conclusion
With GPT-4 Vision, you have a powerful tool for understanding and generating textual descriptions for images. By considering its capabilities, limitations, and cost calculations, you can effectively leverage this model for various image-related tasks.

@ -0,0 +1,7 @@
from swarms.models.gpt4v import GPT4Vision
gpt4vision = GPT4Vision(api_key="")
task = "What is the following image about?"
img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
answer = gpt4vision.run(task, img)

@ -3,9 +3,7 @@ from swarms.structs import Flow
from swarms.structs.sequential_workflow import SequentialWorkflow from swarms.structs.sequential_workflow import SequentialWorkflow
# Example usage # Example usage
api_key = ( api_key = "" # Your actual API key here
"" # Your actual API key here
)
# Initialize the language flow # Initialize the language flow
llm = OpenAIChat( llm = OpenAIChat(

@ -0,0 +1,175 @@
import openai
import logging
import os
from dataclasses import dataclass
from functools import lru_cache
from termcolor import colored
from openai import OpenAI
from dotenv import load_dotenv
from pydantic import BaseModel, validator
from PIL import Image
from io import BytesIO
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
# Configure Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class Dalle3:
"""
Dalle3 model class
Attributes:
-----------
image_url: str
The image url generated by the Dalle3 API
Methods:
--------
__call__(self, task: str) -> Dalle3:
Makes a call to the Dalle3 API and returns the image url
Example:
--------
>>> dalle3 = Dalle3()
>>> task = "A painting of a dog"
>>> image_url = dalle3(task)
>>> print(image_url)
https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
"""
model: str = "dall-e-3"
img: str = None
size: str = "1024x1024"
max_retries: int = 3
quality: str = "standard"
n: int = 4
client = OpenAI(
api_key=api_key,
max_retries=max_retries,
)
class Config:
"""Config class for the Dalle3 model"""
arbitrary_types_allowed = True
@validator("max_retries", "time_seconds")
def must_be_positive(cls, value):
if value <= 0:
raise ValueError("Must be positive")
return value
def read_img(self, img: str):
"""Read the image using pil"""
img = Image.open(img)
return img
def set_width_height(self, img: str, width: int, height: int):
"""Set the width and height of the image"""
img = self.read_img(img)
img = img.resize((width, height))
return img
def convert_to_bytesio(self, img: str, format: str = "PNG"):
"""Convert the image to an bytes io object"""
byte_stream = BytesIO()
img.save(byte_stream, format=format)
byte_array = byte_stream.getvalue()
return byte_array
# @lru_cache(maxsize=32)
def __call__(self, task: str):
"""
Text to image conversion using the Dalle3 API
Parameters:
-----------
task: str
The task to be converted to an image
Returns:
--------
Dalle3:
An instance of the Dalle3 class with the image url generated by the Dalle3 API
Example:
--------
>>> dalle3 = Dalle3()
>>> task = "A painting of a dog"
>>> image_url = dalle3(task)
>>> print(image_url)
https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png
"""
try:
# Making a call to the the Dalle3 API
response = self.client.images.generate(
# model=self.model,
prompt=task,
# size=self.size,
# quality=self.quality,
n=self.n,
)
# Extracting the image url from the response
img = response.data[0].url
return img
except openai.OpenAIError as error:
# Handling exceptions and printing the errors details
print(
colored(
f"Error running Dalle3: {error} try optimizing your api key and or try again",
"red",
)
)
raise error
def create_variations(self, img: str):
"""
Create variations of an image using the Dalle3 API
Parameters:
-----------
img: str
The image to be used for the API request
Returns:
--------
img: str
The image url generated by the Dalle3 API
Example:
--------
>>> dalle3 = Dalle3()
>>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
>>> img = dalle3.create_variations(img)
>>> print(img)
"""
try:
response = self.client.images.create_variation(
img = open(img, "rb"),
n=self.n,
size=self.size
)
img = response.data[0].url
return img
except (Exception, openai.OpenAIError) as error:
print(
colored(
f"Error running Dalle3: {error} try optimizing your api key and or try again",
"red",
)
)
print(colored(f"Error running Dalle3: {error.http_status}", "red"))
print(colored(f"Error running Dalle3: {error.error}", "red"))
raise error

@ -0,0 +1,288 @@
import base64
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Union
import requests
from dotenv import load_dotenv
from openai import OpenAI
from termcolor import colored
# ENV
load_dotenv()
def logging_config():
"""Configures logging"""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
return logger
@dataclass
class GPT4VisionResponse:
"""A response structure for GPT-4"""
answer: str
@dataclass
class GPT4Vision:
"""
GPT4Vision model class
Attributes:
-----------
max_retries: int
The maximum number of retries to make to the API
backoff_factor: float
The backoff factor to use for exponential backoff
timeout_seconds: int
The timeout in seconds for the API request
api_key: str
The API key to use for the API request
quality: str
The quality of the image to generate
max_tokens: int
The maximum number of tokens to use for the API request
Methods:
--------
process_img(self, img_path: str) -> str:
Processes the image to be used for the API request
__call__(self, img: Union[str, List[str]], tasks: List[str]) -> GPT4VisionResponse:
Makes a call to the GPT-4 Vision API and returns the image url
Example:
>>> gpt4vision = GPT4Vision()
>>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
>>> tasks = ["A painting of a dog"]
>>> answer = gpt4vision(img, tasks)
>>> print(answer)
"""
max_retries: int = 3
model: str = "gpt-4-vision-preview"
backoff_factor: float = 2.0
timeout_seconds: int = 10
api_key: Optional[str] = None or os.getenv("OPENAI_API_KEY")
# 'Low' or 'High' for respesctively fast or high quality, but high more token usage
quality: str = "low"
# Max tokens to use for the API request, the maximum might be 3,000 but we don't know
max_tokens: int = 200
client = OpenAI(
api_key=api_key,
max_retries=max_retries,
)
logger = logging_config()
class Config:
"""Config class for the GPT4Vision model"""
arbitary_types_allowed = True
def process_img(self, img: str) -> str:
"""Processes the image to be used for the API request"""
with open(img, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def __call__(
self,
img: Union[str, List[str]],
tasks: List[str],
) -> GPT4VisionResponse:
"""
Calls the GPT-4 Vision API and returns the image url
Parameters:
-----------
img: Union[str, List[str]]
The image to be used for the API request
tasks: List[str]
The tasks to be used for the API request
Returns:
--------
answer: GPT4VisionResponse
The response from the API request
Example:
--------
>>> gpt4vision = GPT4Vision()
>>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
>>> tasks = ["A painting of a dog"]
>>> answer = gpt4vision(img, tasks)
>>> print(answer)
"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
# Image content
image_content = [
{"type": "imavge_url", "image_url": img}
if img.startswith("http")
else {"type": "image", "data": img}
for img in img
]
messages = [
{
"role": "user",
"content": image_content + [{"type": "text", "text": q} for q in tasks],
}
]
payload = {
"model": "gpt-4-vision-preview",
"messages": messages,
"max_tokens": self.max_tokens,
"detail": self.quality,
}
for attempt in range(self.max_retries):
try:
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
timeout=self.timeout_seconds,
)
response.raise_for_status()
answer = response.json()["choices"][0]["message"]["content"]["text"]
return GPT4VisionResponse(answer=answer)
except requests.exceptions.HTTPError as error:
self.logger.error(
f"HTTP error: {error.response.status_code}, {error.response.text}"
)
if error.response.status_code in [429, 500, 503]:
# Exponential backoff = 429(too many requesys)
# And 503 = (Service unavailable) errors
time.sleep(self.backoff_factor**attempt)
else:
break
except requests.exceptions.RequestException as error:
self.logger.error(f"Request error: {error}")
time.sleep(self.backoff_factor**attempt)
except Exception as error:
self.logger.error(
f"Unexpected Error: {error} try optimizing your api key and try again"
)
raise error from None
raise TimeoutError("API Request timed out after multiple retries")
def run(self, task: str, img: str) -> str:
"""
Runs the GPT-4 Vision API
Parameters:
-----------
task: str
The task to be used for the API request
img: str
The image to be used for the API request
Returns:
--------
out: str
The response from the API request
Example:
--------
>>> gpt4vision = GPT4Vision()
>>> task = "A painting of a dog"
>>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
>>> answer = gpt4vision.run(task, img)
>>> print(answer)
"""
try:
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": f"{task}"},
{
"type": "image_url",
"image_url": f"{img}",
},
],
}
],
max_tokens=self.max_tokens,
)
out = response.choices[0].text
return out
except Exception as error:
print(
colored(
f"Error when calling GPT4Vision, Error: {error} Try optimizing your key, and try again",
"red",
)
)
async def arun(self, task: str, img: str) -> str:
"""
Asynchronous run method for GPT-4 Vision
Parameters:
-----------
task: str
The task to be used for the API request
img: str
The image to be used for the API request
Returns:
--------
out: str
The response from the API request
Example:
--------
>>> gpt4vision = GPT4Vision()
>>> task = "A painting of a dog"
>>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
>>> answer = await gpt4vision.arun(task, img)
>>> print(answer)
"""
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": f"{task}"},
{
"type": "image_url",
"image_url": f"{img}",
},
],
}
],
max_tokens=self.max_tokens,
)
out = response.choices[0].text
return out
except Exception as error:
print(
colored(
f"Error when calling GPT4Vision, Error: {error} Try optimizing your key, and try again",
"red",
)
)

@ -0,0 +1,374 @@
import os
from unittest.mock import Mock
import pytest
from openai import OpenAIError
from PIL import Image
from termcolor import colored
from dalle3 import Dalle3
# Mocking the OpenAI client to avoid making actual API calls during testing
@pytest.fixture
def mock_openai_client():
return Mock()
@pytest.fixture
def dalle3(mock_openai_client):
return Dalle3(client=mock_openai_client)
def test_dalle3_call_success(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)])
# Act
img_url = dalle3(task)
# Assert
assert img_url == expected_img_url
mock_openai_client.images.generate.assert_called_once_with(prompt=task, n=4)
def test_dalle3_call_failure(dalle3, mock_openai_client, capsys):
# Arrange
task = "Invalid task"
expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError
mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
dalle3(task)
assert str(excinfo.value) == expected_error_message
mock_openai_client.images.generate.assert_called_once_with(prompt=task, n=4)
# Ensure the error message is printed in red
captured = capsys.readouterr()
assert colored(expected_error_message, "red") in captured.out
def test_dalle3_create_variations_success(dalle3, mock_openai_client):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
mock_openai_client.images.create_variation.return_value = Mock(data=[Mock(url=expected_variation_url)])
# Act
variation_img_url = dalle3.create_variations(img_url)
# Assert
assert variation_img_url == expected_variation_url
mock_openai_client.images.create_variation.assert_called_once()
_, kwargs = mock_openai_client.images.create_variation.call_args
assert kwargs["img"] is not None
assert kwargs["n"] == 4
assert kwargs["size"] == "1024x1024"
def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError
mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
dalle3.create_variations(img_url)
assert str(excinfo.value) == expected_error_message
mock_openai_client.images.create_variation.assert_called_once()
# Ensure the error message is printed in red
captured = capsys.readouterr()
assert colored(expected_error_message, "red") in captured.out
def test_dalle3_read_img():
# Arrange
img_path = "test_image.png"
img = Image.new("RGB", (512, 512))
# Save the image temporarily
img.save(img_path)
# Act
dalle3 = Dalle3()
img_loaded = dalle3.read_img(img_path)
# Assert
assert isinstance(img_loaded, Image.Image)
# Clean up
os.remove(img_path)
def test_dalle3_set_width_height():
# Arrange
img = Image.new("RGB", (512, 512))
width = 256
height = 256
# Act
dalle3 = Dalle3()
img_resized = dalle3.set_width_height(img, width, height)
# Assert
assert img_resized.size == (width, height)
def test_dalle3_convert_to_bytesio():
# Arrange
img = Image.new("RGB", (512, 512))
expected_format = "PNG"
# Act
dalle3 = Dalle3()
img_bytes = dalle3.convert_to_bytesio(img, format=expected_format)
# Assert
assert isinstance(img_bytes, bytes)
assert img_bytes.startswith(b"\x89PNG")
def test_dalle3_call_multiple_times(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)])
# Act
img_url1 = dalle3(task)
img_url2 = dalle3(task)
# Assert
assert img_url1 == expected_img_url
assert img_url2 == expected_img_url
assert mock_openai_client.images.generate.call_count == 2
def test_dalle3_call_with_large_input(dalle3, mock_openai_client):
# Arrange
task = "A" * 2048 # Input longer than API's limit
expected_error_message = "Error running Dalle3: API Error"
mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
dalle3(task)
assert str(excinfo.value) == expected_error_message
def test_dalle3_create_variations_with_invalid_image_url(dalle3, mock_openai_client):
# Arrange
img_url = "https://invalid-image-url.com"
expected_error_message = "Error running Dalle3: Invalid image URL"
# Act and assert
with pytest.raises(ValueError) as excinfo:
dalle3.create_variations(img_url)
assert str(excinfo.value) == expected_error_message
def test_dalle3_set_width_height_invalid_dimensions(dalle3):
# Arrange
img = dalle3.read_img("test_image.png")
width = 0
height = -1
# Act and assert
with pytest.raises(ValueError):
dalle3.set_width_height(img, width, height)
def test_dalle3_convert_to_bytesio_invalid_format(dalle3):
# Arrange
img = dalle3.read_img("test_image.png")
invalid_format = "invalid_format"
# Act and assert
with pytest.raises(ValueError):
dalle3.convert_to_bytesio(img, format=invalid_format)
def test_dalle3_call_with_retry(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
# Simulate a retry scenario
mock_openai_client.images.generate.side_effect = [
OpenAIError("Temporary error", http_status=500, error="Internal Server Error"),
Mock(data=[Mock(url=expected_img_url)]),
]
# Act
img_url = dalle3(task)
# Assert
assert img_url == expected_img_url
assert mock_openai_client.images.generate.call_count == 2
def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
# Simulate a retry scenario
mock_openai_client.images.create_variation.side_effect = [
OpenAIError("Temporary error", http_status=500, error="Internal Server Error"),
Mock(data=[Mock(url=expected_variation_url)]),
]
# Act
variation_img_url = dalle3.create_variations(img_url)
# Assert
assert variation_img_url == expected_variation_url
assert mock_openai_client.images.create_variation.call_count == 2
def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys):
# Arrange
task = "A painting of a dog"
expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError
mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
# Act
with pytest.raises(OpenAIError):
dalle3(task)
# Assert that the error message is logged
captured = capsys.readouterr()
assert expected_error_message in captured.err
def test_dalle3_create_variations_exception_logging(dalle3, mock_openai_client, capsys):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError
mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
# Act
with pytest.raises(OpenAIError):
dalle3.create_variations(img_url)
# Assert that the error message is logged
captured = capsys.readouterr()
assert expected_error_message in captured.err
def test_dalle3_read_img_invalid_path(dalle3):
# Arrange
invalid_img_path = "invalid_image_path.png"
# Act and assert
with pytest.raises(FileNotFoundError):
dalle3.read_img(invalid_img_path)
def test_dalle3_call_no_api_key():
# Arrange
task = "A painting of a dog"
dalle3 = Dalle3(api_key=None)
expected_error_message = "Error running Dalle3: API Key is missing"
# Act and assert
with pytest.raises(ValueError) as excinfo:
dalle3(task)
assert str(excinfo.value) == expected_error_message
def test_dalle3_create_variations_no_api_key():
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
dalle3 = Dalle3(api_key=None)
expected_error_message = "Error running Dalle3: API Key is missing"
# Act and assert
with pytest.raises(ValueError) as excinfo:
dalle3.create_variations(img_url)
assert str(excinfo.value) == expected_error_message
def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
# Simulate max retries exceeded
mock_openai_client.images.generate.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error")
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
dalle3(task)
assert "Retry limit exceeded" in str(excinfo.value)
def test_dalle3_create_variations_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
# Simulate max retries exceeded
mock_openai_client.images.create_variation.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error")
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
dalle3.create_variations(img_url)
assert "Retry limit exceeded" in str(excinfo.value)
def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
# Simulate success after a retry
mock_openai_client.images.generate.side_effect = [
OpenAIError("Temporary error", http_status=500, error="Internal Server Error"),
Mock(data=[Mock(url=expected_img_url)]),
]
# Act
img_url = dalle3(task)
# Assert
assert img_url == expected_img_url
assert mock_openai_client.images.generate.call_count == 2
def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
# Simulate success after a retry
mock_openai_client.images.create_variation.side_effect = [
OpenAIError("Temporary error", http_status=500, error="Internal Server Error"),
Mock(data=[Mock(url=expected_variation_url)]),
]
# Act
variation_img_url = dalle3.create_variations(img_url)
# Assert
assert variation_img_url == expected_variation_url
assert mock_openai_client.images.create_variation.call_count == 2

@ -0,0 +1,321 @@
import logging
import os
from unittest.mock import Mock
import pytest
from dotenv import load_dotenv
from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout
from swarms.models.gpt4v import GPT4Vision, GPT4VisionResponse
load_dotenv
api_key = os.getenv("OPENAI_API_KEY")
# Mock the OpenAI client
@pytest.fixture
def mock_openai_client():
return Mock()
@pytest.fixture
def gpt4vision(mock_openai_client):
return GPT4Vision(client=mock_openai_client)
def test_gpt4vision_default_values():
# Arrange and Act
gpt4vision = GPT4Vision()
# Assert
assert gpt4vision.max_retries == 3
assert gpt4vision.model == "gpt-4-vision-preview"
assert gpt4vision.backoff_factor == 2.0
assert gpt4vision.timeout_seconds == 10
assert gpt4vision.api_key is None
assert gpt4vision.quality == "low"
assert gpt4vision.max_tokens == 200
def test_gpt4vision_api_key_from_env_variable():
# Arrange
api_key = os.environ["OPENAI_API_KEY"]
# Act
gpt4vision = GPT4Vision()
# Assert
assert gpt4vision.api_key == api_key
def test_gpt4vision_set_api_key():
# Arrange
gpt4vision = GPT4Vision(api_key=api_key)
# Assert
assert gpt4vision.api_key == api_key
def test_gpt4vision_invalid_max_retries():
# Arrange and Act
with pytest.raises(ValueError):
GPT4Vision(max_retries=-1)
def test_gpt4vision_invalid_backoff_factor():
# Arrange and Act
with pytest.raises(ValueError):
GPT4Vision(backoff_factor=-1)
def test_gpt4vision_invalid_timeout_seconds():
# Arrange and Act
with pytest.raises(ValueError):
GPT4Vision(timeout_seconds=-1)
def test_gpt4vision_invalid_max_tokens():
# Arrange and Act
with pytest.raises(ValueError):
GPT4Vision(max_tokens=-1)
def test_gpt4vision_logger_initialized():
# Arrange
gpt4vision = GPT4Vision()
# Assert
assert isinstance(gpt4vision.logger, logging.Logger)
def test_gpt4vision_process_img_nonexistent_file():
# Arrange
gpt4vision = GPT4Vision()
img_path = "nonexistent_image.jpg"
# Act and Assert
with pytest.raises(FileNotFoundError):
gpt4vision.process_img(img_path)
def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
# Act and Assert
with pytest.raises(AttributeError):
gpt4vision(img_url, [task])
def test_gpt4vision_call_single_task_single_image_empty_response(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
mock_openai_client.chat.completions.create.return_value.choices = []
# Act
response = gpt4vision(img_url, [task])
# Assert
assert response.answer == ""
mock_openai_client.chat.completions.create.assert_called_once()
def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
tasks = ["Describe this image.", "What's in this picture?"]
mock_openai_client.chat.completions.create.return_value.choices = []
# Act
responses = gpt4vision(img_url, tasks)
# Assert
assert all(response.answer == "" for response in responses)
assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
def test_gpt4vision_call_single_task_single_image_timeout(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
mock_openai_client.chat.completions.create.side_effect = Timeout("Request timed out")
# Act and Assert
with pytest.raises(Timeout):
gpt4vision(img_url, [task])
def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
# Simulate success after a timeout and retry
mock_openai_client.chat.completions.create.side_effect = [
Timeout("Request timed out"),
{"choices": [{"message": {"content": {"text": "A description of the image."}}}],}
]
# Act
response = gpt4vision(img_url, [task])
# Assert
assert response.answer == "A description of the image."
assert mock_openai_client.chat.completions.create.call_count == 2 # Should be called twice
def test_gpt4vision_process_img():
# Arrange
img_path = "test_image.jpg"
gpt4vision = GPT4Vision()
# Act
img_data = gpt4vision.process_img(img_path)
# Assert
assert img_data.startswith("/9j/") # Base64-encoded image data
def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
expected_response = GPT4VisionResponse(answer="A description of the image.")
mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer
# Act
response = gpt4vision(img_url, [task])
# Assert
assert response == expected_response
mock_openai_client.chat.completions.create.assert_called_once()
def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_client):
# Arrange
img_urls = ["https://example.com/image1.jpg", "https://example.com/image2.jpg"]
task = "Describe these images."
expected_response = GPT4VisionResponse(answer="Descriptions of the images.")
mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer
# Act
response = gpt4vision(img_urls, [task])
# Assert
assert response == expected_response
mock_openai_client.chat.completions.create.assert_called_once()
def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
tasks = ["Describe this image.", "What's in this picture?"]
expected_responses = [
GPT4VisionResponse(answer="A description of the image."),
GPT4VisionResponse(answer="It contains various objects."),
]
def create_mock_response(response):
return {"choices": [{"message": {"content": {"text": response.answer}}}]}
mock_openai_client.chat.completions.create.side_effect = [create_mock_response(response) for response in expected_responses]
# Act
responses = gpt4vision(img_url, tasks)
# Assert
assert responses == expected_responses
assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
tasks = ["Describe this image.", "What's in this picture?"]
expected_responses = [
GPT4VisionResponse(answer="A description of the image."),
GPT4VisionResponse(answer="It contains various objects."),
]
mock_openai_client.chat.completions.create.side_effect = [
{"choices": [{"message": {"content": {"text": expected_responses[i].answer}}}] } for i in range(len(expected_responses))
]
# Act
responses = gpt4vision(img_url, tasks)
# Assert
assert responses == expected_responses
assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
def test_gpt4vision_call_multiple_tasks_multiple_images(gpt4vision, mock_openai_client):
# Arrange
img_urls = ["https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", "https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"]
tasks = ["Describe these images.", "What's in these pictures?"]
expected_responses = [
GPT4VisionResponse(answer="Descriptions of the images."),
GPT4VisionResponse(answer="They contain various objects.")
]
mock_openai_client.chat.completions.create.side_effect = [
{"choices": [{"message": {"content": {"text": response.answer}}}] } for response in expected_responses
]
# Act
responses = gpt4vision(img_urls, tasks)
# Assert
assert responses == expected_responses
assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
mock_openai_client.chat.completions.create.side_effect = HTTPError("HTTP Error")
# Act and Assert
with pytest.raises(HTTPError):
gpt4vision(img_url, [task])
def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
mock_openai_client.chat.completions.create.side_effect = RequestException("Request Error")
# Act and Assert
with pytest.raises(RequestException):
gpt4vision(img_url, [task])
def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
mock_openai_client.chat.completions.create.side_effect = ConnectionError("Connection Error")
# Act and Assert
with pytest.raises(ConnectionError):
gpt4vision(img_url, [task])
def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
# Simulate success after a retry
mock_openai_client.chat.completions.create.side_effect = [
RequestException("Temporary error"),
{"choices": [{"text": "A description of the image."}]} # fixed dictionary syntax
]
# Act
response = gpt4vision(img_url, [task])
# Assert
assert response.answer == "A description of the image."
assert mock_openai_client.chat.completions.create.call_count == 2 # Should be called twice

@ -12,7 +12,6 @@ from swarms.structs.sequential_workflow import SequentialWorkflow, Task
os.environ["OPENAI_API_KEY"] = "mocked_api_key" os.environ["OPENAI_API_KEY"] = "mocked_api_key"
# Mock OpenAIChat class for testing # Mock OpenAIChat class for testing
class MockOpenAIChat: class MockOpenAIChat:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -21,6 +20,7 @@ class MockOpenAIChat:
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
return "Mocked result" return "Mocked result"
# Mock Flow class for testing # Mock Flow class for testing
class MockFlow: class MockFlow:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -29,6 +29,7 @@ class MockFlow:
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
return "Mocked result" return "Mocked result"
# Mock SequentialWorkflow class for testing # Mock SequentialWorkflow class for testing
class MockSequentialWorkflow: class MockSequentialWorkflow:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -40,6 +41,7 @@ class MockSequentialWorkflow:
def run(self): def run(self):
pass pass
# Test Task class # Test Task class
def test_task_initialization(): def test_task_initialization():
description = "Sample Task" description = "Sample Task"
@ -48,6 +50,7 @@ def test_task_initialization():
assert task.description == description assert task.description == description
assert task.flow == flow assert task.flow == flow
def test_task_execute(): def test_task_execute():
description = "Sample Task" description = "Sample Task"
flow = MockOpenAIChat() flow = MockOpenAIChat()
@ -55,6 +58,7 @@ def test_task_execute():
task.execute() task.execute()
assert task.result == "Mocked result" assert task.result == "Mocked result"
# Test SequentialWorkflow class # Test SequentialWorkflow class
def test_sequential_workflow_initialization(): def test_sequential_workflow_initialization():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
@ -66,6 +70,7 @@ def test_sequential_workflow_initialization():
assert workflow.restore_state_filepath == None assert workflow.restore_state_filepath == None
assert workflow.dashboard == False assert workflow.dashboard == False
def test_sequential_workflow_add_task(): def test_sequential_workflow_add_task():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
task_description = "Sample Task" task_description = "Sample Task"
@ -75,6 +80,7 @@ def test_sequential_workflow_add_task():
assert workflow.tasks[0].description == task_description assert workflow.tasks[0].description == task_description
assert workflow.tasks[0].flow == task_flow assert workflow.tasks[0].flow == task_flow
def test_sequential_workflow_reset_workflow(): def test_sequential_workflow_reset_workflow():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
task_description = "Sample Task" task_description = "Sample Task"
@ -83,6 +89,7 @@ def test_sequential_workflow_reset_workflow():
workflow.reset_workflow() workflow.reset_workflow()
assert workflow.tasks[0].result == None assert workflow.tasks[0].result == None
def test_sequential_workflow_get_task_results(): def test_sequential_workflow_get_task_results():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
task_description = "Sample Task" task_description = "Sample Task"
@ -94,6 +101,7 @@ def test_sequential_workflow_get_task_results():
assert task_description in results assert task_description in results
assert results[task_description] == "Mocked result" assert results[task_description] == "Mocked result"
def test_sequential_workflow_remove_task(): def test_sequential_workflow_remove_task():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
task1_description = "Task 1" task1_description = "Task 1"
@ -106,6 +114,7 @@ def test_sequential_workflow_remove_task():
assert len(workflow.tasks) == 1 assert len(workflow.tasks) == 1
assert workflow.tasks[0].description == task2_description assert workflow.tasks[0].description == task2_description
def test_sequential_workflow_update_task(): def test_sequential_workflow_update_task():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
task_description = "Sample Task" task_description = "Sample Task"
@ -114,6 +123,7 @@ def test_sequential_workflow_update_task():
workflow.update_task(task_description, max_tokens=1000) workflow.update_task(task_description, max_tokens=1000)
assert workflow.tasks[0].kwargs["max_tokens"] == 1000 assert workflow.tasks[0].kwargs["max_tokens"] == 1000
def test_sequential_workflow_save_workflow_state(): def test_sequential_workflow_save_workflow_state():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
task_description = "Sample Task" task_description = "Sample Task"
@ -123,6 +133,7 @@ def test_sequential_workflow_save_workflow_state():
assert os.path.exists("test_state.json") assert os.path.exists("test_state.json")
os.remove("test_state.json") os.remove("test_state.json")
def test_sequential_workflow_load_workflow_state(): def test_sequential_workflow_load_workflow_state():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
task_description = "Sample Task" task_description = "Sample Task"
@ -134,6 +145,7 @@ def test_sequential_workflow_load_workflow_state():
assert workflow.tasks[0].description == task_description assert workflow.tasks[0].description == task_description
os.remove("test_state.json") os.remove("test_state.json")
def test_sequential_workflow_run(): def test_sequential_workflow_run():
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
task_description = "Sample Task" task_description = "Sample Task"
@ -142,18 +154,21 @@ def test_sequential_workflow_run():
workflow.run() workflow.run()
assert workflow.tasks[0].result == "Mocked result" assert workflow.tasks[0].result == "Mocked result"
def test_sequential_workflow_workflow_bootup(capfd): def test_sequential_workflow_workflow_bootup(capfd):
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
workflow.workflow_bootup() workflow.workflow_bootup()
out, _ = capfd.readouterr() out, _ = capfd.readouterr()
assert "Sequential Workflow Initializing..." in out assert "Sequential Workflow Initializing..." in out
def test_sequential_workflow_workflow_dashboard(capfd): def test_sequential_workflow_workflow_dashboard(capfd):
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
workflow.workflow_dashboard() workflow.workflow_dashboard()
out, _ = capfd.readouterr() out, _ = capfd.readouterr()
assert "Sequential Workflow Dashboard" in out assert "Sequential Workflow Dashboard" in out
# Mock Flow class for async testing # Mock Flow class for async testing
class MockAsyncFlow: class MockAsyncFlow:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -162,6 +177,7 @@ class MockAsyncFlow:
async def arun(self, *args, **kwargs): async def arun(self, *args, **kwargs):
return "Mocked result" return "Mocked result"
# Test async execution in SequentialWorkflow # Test async execution in SequentialWorkflow
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sequential_workflow_arun(): async def test_sequential_workflow_arun():
@ -173,23 +189,24 @@ async def test_sequential_workflow_arun():
assert workflow.tasks[0].result == "Mocked result" assert workflow.tasks[0].result == "Mocked result"
def test_real_world_usage_with_openai_key(): def test_real_world_usage_with_openai_key():
# Initialize the language model # Initialize the language model
llm = OpenAIChat() llm = OpenAIChat()
assert isinstance(llm, OpenAIChat) assert isinstance(llm, OpenAIChat)
def test_real_world_usage_with_flow_and_openai_key(): def test_real_world_usage_with_flow_and_openai_key():
# Initialize a flow with the language model # Initialize a flow with the language model
flow = Flow(llm=OpenAIChat()) flow = Flow(llm=OpenAIChat())
assert isinstance(flow, Flow) assert isinstance(flow, Flow)
def test_real_world_usage_with_sequential_workflow(): def test_real_world_usage_with_sequential_workflow():
# Initialize a sequential workflow # Initialize a sequential workflow
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
assert isinstance(workflow, SequentialWorkflow) assert isinstance(workflow, SequentialWorkflow)
def test_real_world_usage_add_tasks(): def test_real_world_usage_add_tasks():
# Create a sequential workflow and add tasks # Create a sequential workflow and add tasks
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
@ -203,6 +220,7 @@ def test_real_world_usage_add_tasks():
assert workflow.tasks[0].description == task1_description assert workflow.tasks[0].description == task1_description
assert workflow.tasks[1].description == task2_description assert workflow.tasks[1].description == task2_description
def test_real_world_usage_run_workflow(): def test_real_world_usage_run_workflow():
# Create a sequential workflow, add a task, and run the workflow # Create a sequential workflow, add a task, and run the workflow
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
@ -212,6 +230,7 @@ def test_real_world_usage_run_workflow():
workflow.run() workflow.run()
assert workflow.tasks[0].result is not None assert workflow.tasks[0].result is not None
def test_real_world_usage_dashboard_display(): def test_real_world_usage_dashboard_display():
# Create a sequential workflow, add tasks, and display the dashboard # Create a sequential workflow, add tasks, and display the dashboard
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
@ -225,6 +244,7 @@ def test_real_world_usage_dashboard_display():
workflow.workflow_dashboard() workflow.workflow_dashboard()
mock_print.assert_called() mock_print.assert_called()
def test_real_world_usage_async_execution(): def test_real_world_usage_async_execution():
# Create a sequential workflow, add an async task, and run the workflow asynchronously # Create a sequential workflow, add an async task, and run the workflow asynchronously
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
@ -238,6 +258,7 @@ def test_real_world_usage_async_execution():
asyncio.run(async_run_workflow()) asyncio.run(async_run_workflow())
assert workflow.tasks[0].result is not None assert workflow.tasks[0].result is not None
def test_real_world_usage_multiple_loops(): def test_real_world_usage_multiple_loops():
# Create a sequential workflow with multiple loops, add a task, and run the workflow # Create a sequential workflow with multiple loops, add a task, and run the workflow
workflow = SequentialWorkflow(max_loops=3) workflow = SequentialWorkflow(max_loops=3)
@ -247,6 +268,7 @@ def test_real_world_usage_multiple_loops():
workflow.run() workflow.run()
assert workflow.tasks[0].result is not None assert workflow.tasks[0].result is not None
def test_real_world_usage_autosave_state(): def test_real_world_usage_autosave_state():
# Create a sequential workflow with autosave, add a task, run the workflow, and check if state is saved # Create a sequential workflow with autosave, add a task, run the workflow, and check if state is saved
workflow = SequentialWorkflow(autosave=True) workflow = SequentialWorkflow(autosave=True)
@ -258,6 +280,7 @@ def test_real_world_usage_autosave_state():
assert os.path.exists("sequential_workflow_state.json") assert os.path.exists("sequential_workflow_state.json")
os.remove("sequential_workflow_state.json") os.remove("sequential_workflow_state.json")
def test_real_world_usage_load_state(): def test_real_world_usage_load_state():
# Create a sequential workflow, add a task, save state, load state, and run the workflow # Create a sequential workflow, add a task, save state, load state, and run the workflow
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
@ -271,6 +294,7 @@ def test_real_world_usage_load_state():
assert workflow.tasks[0].result is not None assert workflow.tasks[0].result is not None
os.remove("test_state.json") os.remove("test_state.json")
def test_real_world_usage_update_task_args(): def test_real_world_usage_update_task_args():
# Create a sequential workflow, add a task, and update task arguments # Create a sequential workflow, add a task, and update task arguments
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
@ -280,6 +304,7 @@ def test_real_world_usage_update_task_args():
workflow.update_task(task_description, max_tokens=1000) workflow.update_task(task_description, max_tokens=1000)
assert workflow.tasks[0].kwargs["max_tokens"] == 1000 assert workflow.tasks[0].kwargs["max_tokens"] == 1000
def test_real_world_usage_remove_task(): def test_real_world_usage_remove_task():
# Create a sequential workflow, add tasks, remove a task, and run the workflow # Create a sequential workflow, add tasks, remove a task, and run the workflow
workflow = SequentialWorkflow() workflow = SequentialWorkflow()
@ -294,12 +319,14 @@ def test_real_world_usage_remove_task():
assert len(workflow.tasks) == 1 assert len(workflow.tasks) == 1
assert workflow.tasks[0].description == task2_description assert workflow.tasks[0].description == task2_description
def test_real_world_usage_with_environment_variables(): def test_real_world_usage_with_environment_variables():
# Ensure that the OpenAI API key is set using environment variables # Ensure that the OpenAI API key is set using environment variables
assert "OPENAI_API_KEY" in os.environ assert "OPENAI_API_KEY" in os.environ
assert os.environ["OPENAI_API_KEY"] == "mocked_api_key" assert os.environ["OPENAI_API_KEY"] == "mocked_api_key"
del os.environ["OPENAI_API_KEY"] # Clean up after the test del os.environ["OPENAI_API_KEY"] # Clean up after the test
def test_real_world_usage_no_openai_key(): def test_real_world_usage_no_openai_key():
# Ensure that an exception is raised when the OpenAI API key is not set # Ensure that an exception is raised when the OpenAI API key is not set
with pytest.raises(ValueError): with pytest.raises(ValueError):

Loading…
Cancel
Save