Removed Open Interpreter, cleaned up docs, added add-message helpers to Flow + utils

pull/93/head^2
Kye 1 year ago
parent 62a413579c
commit 16176e8cad

@ -53,7 +53,7 @@ The `BaseChunker` class is the core component of the `BaseChunker` module. It is
#### Parameters: #### Parameters:
- `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the text into chunks. - `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the text into chunks.
- `tokenizer` (OpenAiTokenizer): Defines the tokenizer to be used for counting tokens in the text. - `tokenizer` (OpenAITokenizer): Defines the tokenizer to be used for counting tokens in the text.
- `max_tokens` (int): Sets the maximum token limit for each chunk. - `max_tokens` (int): Sets the maximum token limit for each chunk.
### 4.2. Examples <a name="examples"></a> ### 4.2. Examples <a name="examples"></a>

@ -52,7 +52,7 @@ The `PdfChunker` class is the core component of the `PdfChunker` module. It is u
#### Parameters: #### Parameters:
- `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the PDF text content into chunks. - `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the PDF text content into chunks.
- `tokenizer` (OpenAiTokenizer): Defines the tokenizer used for counting tokens in the text. - `tokenizer` (OpenAITokenizer): Defines the tokenizer used for counting tokens in the text.
- `max_tokens` (int): Sets the maximum token limit for each chunk. - `max_tokens` (int): Sets the maximum token limit for each chunk.
### 4.2. Examples <a name="examples"></a> ### 4.2. Examples <a name="examples"></a>

@ -29,7 +29,9 @@ flow = Flow(
# out = flow.load_state("flow_state.json") # out = flow.load_state("flow_state.json")
# temp = flow.dynamic_temperature() # temp = flow.dynamic_temperature()
# filter = flow.add_response_filter("Trump") # filter = flow.add_response_filter("Trump")
out = flow.run("Generate a 10,000 word blog on mental clarity and the benefits of meditation.") out = flow.run(
"Generate a 10,000 word blog on mental clarity and the benefits of meditation."
)
# out = flow.validate_response(out) # out = flow.validate_response(out)
# out = flow.analyze_feedback(out) # out = flow.analyze_feedback(out)
# out = flow.print_history_and_memory() # out = flow.print_history_and_memory()

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "2.0.1" version = "2.0.2"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]
@ -41,6 +41,7 @@ sentencepiece = "*"
wget = "*" wget = "*"
griptape = "*" griptape = "*"
httpx = "*" httpx = "*"
tiktoken = "*"
attrs = "*" attrs = "*"
ggl = "*" ggl = "*"
beautifulsoup4 = "*" beautifulsoup4 = "*"
@ -49,7 +50,6 @@ pydantic = "*"
tenacity = "*" tenacity = "*"
Pillow = "*" Pillow = "*"
chromadb = "*" chromadb = "*"
open-interpreter = "*"
tabulate = "*" tabulate = "*"
termcolor = "*" termcolor = "*"
black = "*" black = "*"

@ -29,6 +29,7 @@ sentencepiece
duckduckgo-search duckduckgo-search
agent-protocol agent-protocol
chromadb chromadb
tiktoken
open-interpreter open-interpreter
tabulate tabulate
colored colored

@ -5,6 +5,7 @@ from swarms.agents.message import Message
# from swarms.agents.stream_response import stream # from swarms.agents.stream_response import stream
from swarms.agents.base import AbstractAgent from swarms.agents.base import AbstractAgent
from swarms.agents.registry import Registry from swarms.agents.registry import Registry
# from swarms.agents.idea_to_image_agent import Idea2Image # from swarms.agents.idea_to_image_agent import Idea2Image
from swarms.agents.simple_agent import SimpleAgent from swarms.agents.simple_agent import SimpleAgent

@ -0,0 +1,4 @@
"""
Companion agents converse with the user about the agent the user wants to create, then create that agent with the desired attributes, traits, tools, and configurations.
"""

@ -16,7 +16,6 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma from langchain.vectorstores import Chroma
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from swarms.prompts.sales import SALES_AGENT_TOOLS_PROMPT, conversation_stages from swarms.prompts.sales import SALES_AGENT_TOOLS_PROMPT, conversation_stages
from swarms.tools.interpreter_tool import compile
# classes # classes
@ -166,12 +165,7 @@ def get_tools(product_catalog):
func=knowledge_base.run, func=knowledge_base.run,
description="useful for when you need to answer questions about product information", description="useful for when you need to answer questions about product information",
), ),
# Interpreter
Tool(
name="Code Interepeter",
func=compile,
description="Useful when you need to run code locally, such as Python, Javascript, Shell, and more.",
)
# omnimodal agent # omnimodal agent
] ]

@ -1,10 +1,13 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC from abc import ABC
from typing import Optional from typing import Optional
from attr import define, field, Factory
from attr import Factory, define, field
from griptape.artifacts import TextArtifact from griptape.artifacts import TextArtifact
from swarms.chunkers.chunk_seperators import ChunkSeparator
from griptape.tokenizers import OpenAiTokenizer from swarms.chunkers.chunk_seperator import ChunkSeparator
from swarms.models.openai_tokenizer import OpenAITokenizer
@define @define
@ -16,6 +19,24 @@ class BaseChunker(ABC):
Usage: Usage:
-------------- --------------
from swarms.chunkers.base import BaseChunker
from swarms.chunkers.chunk_seperator import ChunkSeparator
class PdfChunker(BaseChunker):
DEFAULT_SEPARATORS = [
ChunkSeparator("\n\n"),
ChunkSeparator(". "),
ChunkSeparator("! "),
ChunkSeparator("? "),
ChunkSeparator(" "),
]
# Example
pdf = "swarmdeck.pdf"
chunker = PdfChunker()
chunks = chunker.chunk(pdf)
print(chunks)
""" """
@ -26,10 +47,10 @@ class BaseChunker(ABC):
default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True), default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True),
kw_only=True, kw_only=True,
) )
tokenizer: OpenAiTokenizer = field( tokenizer: OpenAITokenizer = field(
default=Factory( default=Factory(
lambda: OpenAiTokenizer( lambda: OpenAITokenizer(
model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL
) )
), ),
kw_only=True, kw_only=True,
@ -47,7 +68,7 @@ class BaseChunker(ABC):
def _chunk_recursively( def _chunk_recursively(
self, chunk: str, current_separator: Optional[ChunkSeparator] = None self, chunk: str, current_separator: Optional[ChunkSeparator] = None
) -> list[str]: ) -> list[str]:
token_count = self.tokenizer.token_count(chunk) token_count = self.tokenizer.count_tokens(chunk)
if token_count <= self.max_tokens: if token_count <= self.max_tokens:
return [chunk] return [chunk]

@ -15,3 +15,10 @@ class MarkdownChunker(BaseChunker):
ChunkSeparator("? "), ChunkSeparator("? "),
ChunkSeparator(" "), ChunkSeparator(" "),
] ]
# # Example using chunker to chunk a markdown file
# file = open("README.md", "r")
# text = file.read()
# chunker = MarkdownChunker()
# chunks = chunker.chunk(text)

@ -0,0 +1,124 @@
"""
Omni Chunker is a chunker that reads any file as text and splits its contents into string chunks of a fixed size.
Usage:
--------------
from swarms.chunkers.omni_chunker import OmniChunker
# Example
pdf = "swarmdeck.pdf"
chunker = OmniChunker(chunk_size=1000, beautify=True)
chunks = chunker(pdf)
print(chunks)
"""
from dataclasses import dataclass
from typing import List, Optional, Callable
from termcolor import colored
import os
import sys
@dataclass
class OmniChunker:
    """Split the UTF-8 decoded contents of a file into fixed-size chunks.

    Attributes:
        chunk_size: Number of characters per chunk; must be positive.
        beautify: Reported by ``__str__`` and ``metrics``; not otherwise
            used by this class.
        use_tokenizer: Stored but not used by this class.
        tokenizer: Optional callable mapping text to a list of strings;
            stored but not used by this class.
    """

    chunk_size: int = 1000
    beautify: bool = False
    use_tokenizer: bool = False
    tokenizer: Optional[Callable[[str], List[str]]] = None

    def __call__(self, file_path: str) -> List[str]:
        """
        Chunk the given file into parts of size `chunk_size`.

        This is a best-effort entry point: any failure is printed and
        reported as an empty list rather than raised.

        Args:
            file_path (str): The path to the file to chunk.

        Returns:
            List[str]: A list of string chunks from the file, or an empty
            list if the file is missing or could not be processed.
        """
        if not os.path.isfile(file_path):
            print(colored("The file does not exist.", "red"))
            return []

        file_extension = os.path.splitext(file_path)[1]

        try:
            with open(file_path, "rb") as file:
                content = file.read()
            # Decode content based on MIME type or file extension
            decoded_content = self.decode_content(content, file_extension)
            return self.chunk_content(decoded_content)
        except Exception as e:
            print(colored(f"Error reading file: {e}", "red"))
            return []

    def decode_content(self, content: bytes, file_extension: str) -> str:
        """
        Decode the raw bytes of a file into text.

        Only UTF-8 is attempted; `file_extension` is used solely in the
        warning printed when decoding fails.

        Args:
            content (bytes): The content of the file.
            file_extension (str): The file extension of the file.

        Returns:
            str: The decoded content, or an empty string if the bytes are
            not valid UTF-8.
        """
        try:
            return content.decode("utf-8")
        except UnicodeDecodeError as e:
            print(
                colored(
                    f"Could not decode file with extension {file_extension}: {e}",
                    "yellow",
                )
            )
            return ""

    def chunk_content(self, content: str) -> List[str]:
        """
        Split the content into chunks of size `chunk_size`.

        Args:
            content (str): The content to chunk.

        Returns:
            List[str]: The list of chunks; the last one may be shorter.

        Raises:
            ValueError: If `chunk_size` is not a positive integer.
        """
        # A negative step would make range() empty and silently drop all
        # content; a zero step fails with an unrelated range() message.
        if self.chunk_size <= 0:
            raise ValueError(
                f"chunk_size must be a positive integer, got {self.chunk_size}"
            )
        return [
            content[start : start + self.chunk_size]
            for start in range(0, len(content), self.chunk_size)
        ]

    def __str__(self):
        return f"OmniChunker(chunk_size={self.chunk_size}, beautify={self.beautify})"

    def metrics(self):
        """Return the chunker settings shown on the dashboard as a dict."""
        return {
            "chunk_size": self.chunk_size,
            "beautify": self.beautify,
        }

    def print_dashboard(self):
        """Print the chunker metrics as a cyan banner."""
        print(
            colored(
                f"""
                Omni Chunker
                ------------
                {self.metrics()}
                """,
                "cyan",
            )
        )

@ -10,3 +10,10 @@ class PdfChunker(BaseChunker):
ChunkSeparator("? "), ChunkSeparator("? "),
ChunkSeparator(" "), ChunkSeparator(" "),
] ]
# # Example
# pdf = "swarmdeck.pdf"
# chunker = PdfChunker()
# chunks = chunker.chunk(pdf)
# print(chunks)

@ -16,6 +16,7 @@ from swarms.models.kosmos_two import Kosmos
from swarms.models.vilt import Vilt from swarms.models.vilt import Vilt
from swarms.models.nougat import Nougat from swarms.models.nougat import Nougat
from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA
# from swarms.models.gpt4v import GPT4Vision # from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3 # from swarms.models.dalle3 import Dalle3

@ -44,7 +44,7 @@ class Anthropic:
top_p=None, top_p=None,
streaming=False, streaming=False,
default_request_timeout=None, default_request_timeout=None,
api_key: str = None api_key: str = None,
): ):
self.model = model self.model = model
self.max_tokens_to_sample = max_tokens_to_sample self.max_tokens_to_sample = max_tokens_to_sample

@ -129,7 +129,7 @@ class Dalle3:
) )
) )
raise error raise error
def create_variations(self, img: str): def create_variations(self, img: str):
""" """
Create variations of an image using the Dalle3 API Create variations of an image using the Dalle3 API
@ -151,14 +151,11 @@ class Dalle3:
>>> img = dalle3.create_variations(img) >>> img = dalle3.create_variations(img)
>>> print(img) >>> print(img)
""" """
try: try:
response = self.client.images.create_variation( response = self.client.images.create_variation(
img = open(img, "rb"), img=open(img, "rb"), n=self.n, size=self.size
n=self.n,
size=self.size
) )
img = response.data[0].url img = response.data[0].url
@ -172,4 +169,4 @@ class Dalle3:
) )
print(colored(f"Error running Dalle3: {error.http_status}", "red")) print(colored(f"Error running Dalle3: {error.http_status}", "red"))
print(colored(f"Error running Dalle3: {error.error}", "red")) print(colored(f"Error running Dalle3: {error.error}", "red"))
raise error raise error

@ -74,7 +74,9 @@ class HuggingfaceLLM:
bnb_config = BitsAndBytesConfig(**quantization_config) bnb_config = BitsAndBytesConfig(**quantization_config)
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, *args, **kwargs) self.tokenizer = AutoTokenizer.from_pretrained(
self.model_id, *args, **kwargs
)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, quantization_config=bnb_config, *args, **kwargs self.model_id, quantization_config=bnb_config, *args, **kwargs
) )
@ -162,7 +164,12 @@ class HuggingfaceLLM:
del inputs del inputs
return self.tokenizer.decode(outputs[0], skip_special_tokens=True) return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e: except Exception as e:
print(colored(f"HuggingfaceLLM could not generate text because of error: {e}, try optimizing your arguments", "red")) print(
colored(
f"HuggingfaceLLM could not generate text because of error: {e}, try optimizing your arguments",
"red",
)
)
raise raise
async def run_async(self, task: str, *args, **kwargs) -> str: async def run_async(self, task: str, *args, **kwargs) -> str:

@ -0,0 +1,74 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

from swarms.models import OpenAI


@dataclass
class OpenAIAssistant:
    """Configuration plus helpers for an assistant / thread / run cycle.

    Of the fields below, only `name`, `instructions`, `tools`, `model` and
    `role` (together with the shared `client`) are read by the methods in
    this class; the remaining fields are stored as plain configuration.

    NOTE(review): the client method names used here (`create_assistant`,
    `beta.threads.add_message`, ...) are kept exactly as originally
    written; confirm they exist on `swarms.models.OpenAI`.
    """

    name: str = "OpenAI Assistant"
    instructions: Optional[str] = None
    tools: Optional[List[Dict]] = None
    model: Optional[str] = None
    openai_api_key: Optional[str] = None
    temperature: float = 0.5
    max_tokens: int = 100
    stop: Optional[List[str]] = None
    echo: bool = False
    stream: bool = False
    log: bool = False
    presence: bool = False
    dashboard: bool = False
    debug: bool = False
    max_loops: int = 5
    stopping_condition: Optional[str] = None
    loop_interval: int = 1
    retry_attempts: int = 3
    retry_interval: int = 1
    interactive: bool = False
    dynamic_temperature: bool = False
    state: Optional[Dict] = None
    response_filters: Optional[List] = None
    response_filter: Optional[Dict] = None
    response_filter_name: Optional[str] = None
    response_filter_value: Optional[str] = None
    response_filter_type: Optional[str] = None
    response_filter_action: Optional[str] = None
    response_filter_action_value: Optional[str] = None
    response_filter_action_type: Optional[str] = None
    response_filter_action_name: Optional[str] = None
    # Unannotated, so this is a plain class attribute rather than a
    # dataclass field: one client is built when the class body executes
    # and is shared by every instance.
    client = OpenAI()
    role: str = "user"

    def create_assistant(self, task: str):
        """Create an assistant on the client from this object's settings.

        Args:
            task: Accepted for interface compatibility; not used here.

        Returns:
            Whatever `client.create_assistant` returns.
        """
        return self.client.create_assistant(
            name=self.name,
            instructions=self.instructions,
            tools=self.tools,
            model=self.model,
        )

    def create_thread(self):
        """Create and return a new thread on the client."""
        return self.client.beta.threads.create()

    def add_message_to_thread(self, thread_id: str, message: str):
        """Add `message` to the thread `thread_id` using this object's role.

        Returns:
            Whatever `client.beta.threads.add_message` returns.
        """
        # The role lives in the `role` field; there is no `user` attribute.
        return self.client.beta.threads.add_message(
            thread_id=thread_id, role=self.role, content=message
        )

    def run(self, task: str):
        """Run `task` on a fresh thread with a freshly created assistant.

        A thread is created, the task is posted to it, an assistant is
        created, and a run is started and then retrieved once.

        Returns:
            The run object returned by `client.beta.threads.runs.retrieve`.
        """
        thread = self.create_thread()
        # Post the task so the run has something to act on.
        self.add_message_to_thread(thread.id, task)
        assistant = self.create_assistant(task)
        run = self.client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant.id,
            instructions=self.instructions,
        )
        return self.client.beta.threads.runs.retrieve(
            thread_id=run.thread_id, run_id=run.id
        )

@ -0,0 +1,150 @@
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Optional
import tiktoken
from attr import Factory, define, field
@define(frozen=True)
class BaseTokenizer(ABC):
    """Abstract base for tokenizers that can count tokens in a text.

    Subclasses provide `max_tokens` and `count_tokens`.

    Attributes:
        stop_sequences: Keyword-only list of stop sequences; defaults to a
            copy of `DEFAULT_STOP_SEQUENCES`.
    """

    DEFAULT_STOP_SEQUENCES = ["Observation:"]

    stop_sequences: list[str] = field(
        # Hand each instance its own list so that mutating one instance's
        # stop sequences cannot leak into the class default or into other
        # instances.
        default=Factory(lambda: list(BaseTokenizer.DEFAULT_STOP_SEQUENCES)),
        kw_only=True,
    )

    @property
    @abstractmethod
    def max_tokens(self) -> int:
        """Maximum number of tokens available, as defined by the subclass."""
        ...

    def count_tokens_left(self, text: str) -> int:
        """Return how many tokens remain after `text`, never below zero."""
        return max(self.max_tokens - self.count_tokens(text), 0)

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in `text`."""
        ...
@define(frozen=True)
class OpenAITokenizer(BaseTokenizer):
    """Token counter for OpenAI models, backed by `tiktoken`.

    Attributes:
        model: Keyword-only name of the model whose encoding and token
            limit are used.
    """

    DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "text-davinci-003"
    DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo"
    DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4"
    DEFAULT_ENCODING = "cl100k_base"
    DEFAULT_MAX_TOKENS = 2049
    TOKEN_OFFSET = 8

    # Looked up by prefix in insertion order (first match wins), so a more
    # specific prefix such as "gpt-4-32k" must stay ahead of "gpt-4".
    MODEL_PREFIXES_TO_MAX_TOKENS = {
        "gpt-4-32k": 32768,
        "gpt-4": 8192,
        "gpt-3.5-turbo-16k": 16384,
        "gpt-3.5-turbo": 4096,
        "gpt-35-turbo-16k": 16384,
        "gpt-35-turbo": 4096,
        "text-davinci-003": 4097,
        "text-davinci-002": 4097,
        "code-davinci-002": 8001,
        "text-embedding-ada-002": 8191,
        "text-embedding-ada-001": 2046,
    }

    EMBEDDING_MODELS = ["text-embedding-ada-002", "text-embedding-ada-001"]

    model: str = field(kw_only=True)

    @property
    def encoding(self) -> tiktoken.Encoding:
        """Encoding for `model`, or `DEFAULT_ENCODING` if tiktoken lacks it."""
        try:
            return tiktoken.encoding_for_model(self.model)
        except KeyError:
            return tiktoken.get_encoding(self.DEFAULT_ENCODING)

    @property
    def max_tokens(self) -> int:
        """Token limit for `model`, minus `TOKEN_OFFSET` for non-embedding models.

        Falls back to `DEFAULT_MAX_TOKENS` when no known prefix matches.
        """
        # Without a default, next() raises StopIteration for an unknown
        # model and the DEFAULT_MAX_TOKENS fallback below is never reached.
        limit = next(
            (
                max_len
                for prefix, max_len in self.MODEL_PREFIXES_TO_MAX_TOKENS.items()
                if self.model.startswith(prefix)
            ),
            None,
        )
        offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET

        return (limit if limit else self.DEFAULT_MAX_TOKENS) - offset

    def count_tokens(
        self, text: str | list, model: Optional[str] = None
    ) -> int:
        """
        Count the tokens in a plain string or in a list of chat messages.

        Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook:
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

        Args:
            text: A string, or a list of message dicts whose values are
                encoded one by one.
            model: Model used for the message-list case; defaults to
                `self.model`. Ignored for plain strings.

        Returns:
            The number of tokens.

        Raises:
            NotImplementedError: If `text` is a list and the model is not
                one of the supported chat model families.
        """
        if not isinstance(text, list):
            # Plain text: stop sequences are allowed as special tokens.
            return len(
                self.encoding.encode(
                    text, allowed_special=set(self.stop_sequences)
                )
            )

        model = model if model else self.model

        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            logging.warning("model not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")

        if model in {
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613",
            "gpt-4-0314",
            "gpt-4-32k-0314",
            "gpt-4-0613",
            "gpt-4-32k-0613",
        }:
            tokens_per_message = 3
            tokens_per_name = 1
        elif model == "gpt-3.5-turbo-0301":
            # every message follows <|start|>{role/name}\n{content}<|end|>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
            logging.info(
                "gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
            )
            return self.count_tokens(text, model="gpt-3.5-turbo-0613")
        elif "gpt-4" in model:
            logging.info(
                "gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
            )
            return self.count_tokens(text, model="gpt-4-0613")
        else:
            raise NotImplementedError(
                f"""count_tokens() is not implemented for model {model}.
                See https://github.com/openai/openai-python/blob/main/chatml.md for
                information on how messages are converted to tokens."""
            )

        num_tokens = 0
        for message in text:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <|start|>assistant<|message|>
        num_tokens += 3

        return num_tokens

@ -116,6 +116,7 @@ class Flow:
dynamic_temperature: bool = False, dynamic_temperature: bool = False,
saved_state_path: Optional[str] = "flow_state.json", saved_state_path: Optional[str] = "flow_state.json",
autosave: bool = False, autosave: bool = False,
context_length: int = 8192,
**kwargs: Any, **kwargs: Any,
): ):
self.llm = llm self.llm = llm
@ -188,6 +189,26 @@ class Flow:
return "\n".join(params_str_list) return "\n".join(params_str_list)
def truncate_history(self):
    """
    Trim the most recent conversation to its last `context_length` entries.

    The limit is applied to the number of stored entries, not to a token
    count. Does nothing when memory is empty.
    """
    if not self.memory:
        return

    history = self.memory[-1]
    # Compute the start index explicitly: a plain `[-n:]` slice keeps the
    # whole history when n is 0, because `-0` is just `0`.
    keep_from = max(len(history) - self.context_length, 0)
    self.memory[-1] = history[keep_from:]
def add_task_to_memory(self, task: str):
    """Start a new conversation in memory, seeded with the human task."""
    conversation = [f"Human: {task}"]
    self.memory.append(conversation)
def add_message_to_memory(self, message: str):
    """Append a message to the most recent conversation in memory."""
    latest_conversation = self.memory[-1]
    latest_conversation.append(message)
def add_message_to_memory_and_truncate(self, message: str):
    """Append a message to the latest conversation, then truncate it."""
    self.add_message_to_memory(message)
    self.truncate_history()
def print_dashboard(self, task: str): def print_dashboard(self, task: str):
"""Print dashboard""" """Print dashboard"""
model_config = self.get_llm_init_params() model_config = self.get_llm_init_params()

@ -1,24 +0,0 @@
import os
import interpreter
def compile(task: str):
    """
    Send a task to Open Interpreter and configure its CLI environment flags.

    Open Interpreter lets LLMs run code (Python, Javascript, Shell, and more) locally. You can chat with Open Interpreter through a ChatGPT-like interface in your terminal by running $ interpreter after installing.

    This provides a natural-language interface to your computer's general-purpose capabilities:

    Create and edit photos, videos, PDFs, etc.
    Control a Chrome browser to perform research
    Plot, clean, and analyze large datasets
    ...etc.

    Note: You'll be asked to approve code before it's run.

    Args:
        task (str): The instruction passed to `interpreter.chat`.
    """
    # NOTE: this function's name shadows the built-in `compile` in any
    # module that imports it by name.
    task = interpreter.chat(task, return_messages=True)
    interpreter.chat()
    interpreter.reset(task)

    # Environment values must be strings; assigning a bool raises TypeError.
    os.environ["INTERPRETER_CLI_AUTO_RUN"] = "True"
    os.environ["INTERPRETER_CLI_FAST_MODE"] = "True"
    os.environ["INTERPRETER_CLI_DEBUG"] = "True"

@ -1,2 +1,2 @@
from swarms.workers.worker import Worker # from swarms.workers.worker import Worker
from swarms.workers.base import AbstractWorker from swarms.workers.base import AbstractWorker

@ -3,7 +3,7 @@ from swarms.chunkers.base import (
BaseChunker, BaseChunker,
TextArtifact, TextArtifact,
ChunkSeparator, ChunkSeparator,
OpenAiTokenizer, OpenAITokenizer,
) # adjust the import paths accordingly ) # adjust the import paths accordingly
@ -21,7 +21,7 @@ def test_default_separators():
def test_default_tokenizer(): def test_default_tokenizer():
chunker = BaseChunker() chunker = BaseChunker()
assert isinstance(chunker.tokenizer, OpenAiTokenizer) assert isinstance(chunker.tokenizer, OpenAITokenizer)
# 2. Test Basic Chunking # 2. Test Basic Chunking

@ -23,8 +23,12 @@ def dalle3(mock_openai_client):
def test_dalle3_call_success(dalle3, mock_openai_client): def test_dalle3_call_success(dalle3, mock_openai_client):
# Arrange # Arrange
task = "A painting of a dog" task = "A painting of a dog"
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" expected_img_url = (
mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)]) "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
)
mock_openai_client.images.generate.return_value = Mock(
data=[Mock(url=expected_img_url)]
)
# Act # Act
img_url = dalle3(task) img_url = dalle3(task)
@ -40,7 +44,9 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys):
expected_error_message = "Error running Dalle3: API Error" expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError # Mocking OpenAIError
mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") mock_openai_client.images.generate.side_effect = OpenAIError(
expected_error_message, http_status=500, error="Internal Server Error"
)
# Act and assert # Act and assert
with pytest.raises(OpenAIError) as excinfo: with pytest.raises(OpenAIError) as excinfo:
@ -57,8 +63,12 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys):
def test_dalle3_create_variations_success(dalle3, mock_openai_client): def test_dalle3_create_variations_success(dalle3, mock_openai_client):
# Arrange # Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" expected_variation_url = (
mock_openai_client.images.create_variation.return_value = Mock(data=[Mock(url=expected_variation_url)]) "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
)
mock_openai_client.images.create_variation.return_value = Mock(
data=[Mock(url=expected_variation_url)]
)
# Act # Act
variation_img_url = dalle3.create_variations(img_url) variation_img_url = dalle3.create_variations(img_url)
@ -78,7 +88,9 @@ def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys):
expected_error_message = "Error running Dalle3: API Error" expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError # Mocking OpenAIError
mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") mock_openai_client.images.create_variation.side_effect = OpenAIError(
expected_error_message, http_status=500, error="Internal Server Error"
)
# Act and assert # Act and assert
with pytest.raises(OpenAIError) as excinfo: with pytest.raises(OpenAIError) as excinfo:
@ -86,7 +98,7 @@ def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys):
assert str(excinfo.value) == expected_error_message assert str(excinfo.value) == expected_error_message
mock_openai_client.images.create_variation.assert_called_once() mock_openai_client.images.create_variation.assert_called_once()
# Ensure the error message is printed in red # Ensure the error message is printed in red
captured = capsys.readouterr() captured = capsys.readouterr()
assert colored(expected_error_message, "red") in captured.out assert colored(expected_error_message, "red") in captured.out
@ -142,8 +154,12 @@ def test_dalle3_convert_to_bytesio():
def test_dalle3_call_multiple_times(dalle3, mock_openai_client): def test_dalle3_call_multiple_times(dalle3, mock_openai_client):
# Arrange # Arrange
task = "A painting of a dog" task = "A painting of a dog"
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" expected_img_url = (
mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)]) "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
)
mock_openai_client.images.generate.return_value = Mock(
data=[Mock(url=expected_img_url)]
)
# Act # Act
img_url1 = dalle3(task) img_url1 = dalle3(task)
@ -159,7 +175,9 @@ def test_dalle3_call_with_large_input(dalle3, mock_openai_client):
# Arrange # Arrange
task = "A" * 2048 # Input longer than API's limit task = "A" * 2048 # Input longer than API's limit
expected_error_message = "Error running Dalle3: API Error" expected_error_message = "Error running Dalle3: API Error"
mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") mock_openai_client.images.generate.side_effect = OpenAIError(
expected_error_message, http_status=500, error="Internal Server Error"
)
# Act and assert # Act and assert
with pytest.raises(OpenAIError) as excinfo: with pytest.raises(OpenAIError) as excinfo:
@ -204,7 +222,9 @@ def test_dalle3_convert_to_bytesio_invalid_format(dalle3):
def test_dalle3_call_with_retry(dalle3, mock_openai_client): def test_dalle3_call_with_retry(dalle3, mock_openai_client):
# Arrange # Arrange
task = "A painting of a dog" task = "A painting of a dog"
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" expected_img_url = (
"https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
)
# Simulate a retry scenario # Simulate a retry scenario
mock_openai_client.images.generate.side_effect = [ mock_openai_client.images.generate.side_effect = [
@ -223,7 +243,9 @@ def test_dalle3_call_with_retry(dalle3, mock_openai_client):
def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client): def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client):
# Arrange # Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" expected_variation_url = (
"https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
)
# Simulate a retry scenario # Simulate a retry scenario
mock_openai_client.images.create_variation.side_effect = [ mock_openai_client.images.create_variation.side_effect = [
@ -245,7 +267,9 @@ def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys):
expected_error_message = "Error running Dalle3: API Error" expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError # Mocking OpenAIError
mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") mock_openai_client.images.generate.side_effect = OpenAIError(
expected_error_message, http_status=500, error="Internal Server Error"
)
# Act # Act
with pytest.raises(OpenAIError): with pytest.raises(OpenAIError):
@ -262,7 +286,9 @@ def test_dalle3_create_variations_exception_logging(dalle3, mock_openai_client,
expected_error_message = "Error running Dalle3: API Error" expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError # Mocking OpenAIError
mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") mock_openai_client.images.create_variation.side_effect = OpenAIError(
expected_error_message, http_status=500, error="Internal Server Error"
)
# Act # Act
with pytest.raises(OpenAIError): with pytest.raises(OpenAIError):
@ -313,7 +339,9 @@ def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client)
task = "A painting of a dog" task = "A painting of a dog"
# Simulate max retries exceeded # Simulate max retries exceeded
mock_openai_client.images.generate.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error") mock_openai_client.images.generate.side_effect = OpenAIError(
"Temporary error", http_status=500, error="Internal Server Error"
)
# Act and assert # Act and assert
with pytest.raises(OpenAIError) as excinfo: with pytest.raises(OpenAIError) as excinfo:
@ -322,12 +350,16 @@ def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client)
assert "Retry limit exceeded" in str(excinfo.value) assert "Retry limit exceeded" in str(excinfo.value)
def test_dalle3_create_variations_with_retry_max_retries_exceeded(dalle3, mock_openai_client): def test_dalle3_create_variations_with_retry_max_retries_exceeded(
dalle3, mock_openai_client
):
# Arrange # Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
# Simulate max retries exceeded # Simulate max retries exceeded
mock_openai_client.images.create_variation.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error") mock_openai_client.images.create_variation.side_effect = OpenAIError(
"Temporary error", http_status=500, error="Internal Server Error"
)
# Act and assert # Act and assert
with pytest.raises(OpenAIError) as excinfo: with pytest.raises(OpenAIError) as excinfo:
@ -339,7 +371,9 @@ def test_dalle3_create_variations_with_retry_max_retries_exceeded(dalle3, mock_o
def test_dalle3_call_retry_with_success(dalle3, mock_openai_client): def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
# Arrange # Arrange
task = "A painting of a dog" task = "A painting of a dog"
expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" expected_img_url = (
"https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
)
# Simulate success after a retry # Simulate success after a retry
mock_openai_client.images.generate.side_effect = [ mock_openai_client.images.generate.side_effect = [
@ -358,7 +392,9 @@ def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client): def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client):
# Arrange # Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" expected_variation_url = (
"https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
)
# Simulate success after a retry # Simulate success after a retry
mock_openai_client.images.create_variation.side_effect = [ mock_openai_client.images.create_variation.side_effect = [

@ -12,19 +12,22 @@ load_dotenv
api_key = os.getenv("OPENAI_API_KEY") api_key = os.getenv("OPENAI_API_KEY")
# Mock the OpenAI client # Mock the OpenAI client
@pytest.fixture @pytest.fixture
def mock_openai_client(): def mock_openai_client():
return Mock() return Mock()
@pytest.fixture @pytest.fixture
def gpt4vision(mock_openai_client): def gpt4vision(mock_openai_client):
return GPT4Vision(client=mock_openai_client) return GPT4Vision(client=mock_openai_client)
def test_gpt4vision_default_values(): def test_gpt4vision_default_values():
# Arrange and Act # Arrange and Act
gpt4vision = GPT4Vision() gpt4vision = GPT4Vision()
# Assert # Assert
assert gpt4vision.max_retries == 3 assert gpt4vision.max_retries == 3
assert gpt4vision.model == "gpt-4-vision-preview" assert gpt4vision.model == "gpt-4-vision-preview"
@ -34,59 +37,68 @@ def test_gpt4vision_default_values():
assert gpt4vision.quality == "low" assert gpt4vision.quality == "low"
assert gpt4vision.max_tokens == 200 assert gpt4vision.max_tokens == 200
def test_gpt4vision_api_key_from_env_variable(): def test_gpt4vision_api_key_from_env_variable():
# Arrange # Arrange
api_key = os.environ["OPENAI_API_KEY"] api_key = os.environ["OPENAI_API_KEY"]
# Act # Act
gpt4vision = GPT4Vision() gpt4vision = GPT4Vision()
# Assert # Assert
assert gpt4vision.api_key == api_key assert gpt4vision.api_key == api_key
def test_gpt4vision_set_api_key(): def test_gpt4vision_set_api_key():
# Arrange # Arrange
gpt4vision = GPT4Vision(api_key=api_key) gpt4vision = GPT4Vision(api_key=api_key)
# Assert # Assert
assert gpt4vision.api_key == api_key assert gpt4vision.api_key == api_key
def test_gpt4vision_invalid_max_retries(): def test_gpt4vision_invalid_max_retries():
# Arrange and Act # Arrange and Act
with pytest.raises(ValueError): with pytest.raises(ValueError):
GPT4Vision(max_retries=-1) GPT4Vision(max_retries=-1)
def test_gpt4vision_invalid_backoff_factor(): def test_gpt4vision_invalid_backoff_factor():
# Arrange and Act # Arrange and Act
with pytest.raises(ValueError): with pytest.raises(ValueError):
GPT4Vision(backoff_factor=-1) GPT4Vision(backoff_factor=-1)
def test_gpt4vision_invalid_timeout_seconds(): def test_gpt4vision_invalid_timeout_seconds():
# Arrange and Act # Arrange and Act
with pytest.raises(ValueError): with pytest.raises(ValueError):
GPT4Vision(timeout_seconds=-1) GPT4Vision(timeout_seconds=-1)
def test_gpt4vision_invalid_max_tokens(): def test_gpt4vision_invalid_max_tokens():
# Arrange and Act # Arrange and Act
with pytest.raises(ValueError): with pytest.raises(ValueError):
GPT4Vision(max_tokens=-1) GPT4Vision(max_tokens=-1)
def test_gpt4vision_logger_initialized(): def test_gpt4vision_logger_initialized():
# Arrange # Arrange
gpt4vision = GPT4Vision() gpt4vision = GPT4Vision()
# Assert # Assert
assert isinstance(gpt4vision.logger, logging.Logger) assert isinstance(gpt4vision.logger, logging.Logger)
def test_gpt4vision_process_img_nonexistent_file(): def test_gpt4vision_process_img_nonexistent_file():
# Arrange # Arrange
gpt4vision = GPT4Vision() gpt4vision = GPT4Vision()
img_path = "nonexistent_image.jpg" img_path = "nonexistent_image.jpg"
# Act and Assert # Act and Assert
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
gpt4vision.process_img(img_path) gpt4vision.process_img(img_path)
def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision): def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
# Arrange # Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
@ -96,7 +108,10 @@ def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
gpt4vision(img_url, [task]) gpt4vision(img_url, [task])
def test_gpt4vision_call_single_task_single_image_empty_response(gpt4vision, mock_openai_client):
def test_gpt4vision_call_single_task_single_image_empty_response(
gpt4vision, mock_openai_client
):
# Arrange # Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image." task = "Describe this image."
@ -110,7 +125,10 @@ def test_gpt4vision_call_single_task_single_image_empty_response(gpt4vision, moc
assert response.answer == "" assert response.answer == ""
mock_openai_client.chat.completions.create.assert_called_once() mock_openai_client.chat.completions.create.assert_called_once()
def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(gpt4vision, mock_openai_client):
def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(
gpt4vision, mock_openai_client
):
# Arrange # Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
tasks = ["Describe this image.", "What's in this picture?"] tasks = ["Describe this image.", "What's in this picture?"]
@ -122,20 +140,30 @@ def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(gpt4vision,
# Assert # Assert
assert all(response.answer == "" for response in responses) assert all(response.answer == "" for response in responses)
assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once assert (
mock_openai_client.chat.completions.create.call_count == 1
) # Should be called only once
def test_gpt4vision_call_single_task_single_image_timeout(gpt4vision, mock_openai_client): def test_gpt4vision_call_single_task_single_image_timeout(
gpt4vision, mock_openai_client
):
# Arrange # Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image." task = "Describe this image."
mock_openai_client.chat.completions.create.side_effect = Timeout("Request timed out") mock_openai_client.chat.completions.create.side_effect = Timeout(
"Request timed out"
)
# Act and Assert # Act and Assert
with pytest.raises(Timeout): with pytest.raises(Timeout):
gpt4vision(img_url, [task]) gpt4vision(img_url, [task])
def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_openai_client):
def test_gpt4vision_call_retry_with_success_after_timeout(
gpt4vision, mock_openai_client
):
# Arrange # Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image." task = "Describe this image."
@ -143,7 +171,11 @@ def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_opena
# Simulate success after a timeout and retry # Simulate success after a timeout and retry
mock_openai_client.chat.completions.create.side_effect = [ mock_openai_client.chat.completions.create.side_effect = [
Timeout("Request timed out"), Timeout("Request timed out"),
{"choices": [{"message": {"content": {"text": "A description of the image."}}}],} {
"choices": [
{"message": {"content": {"text": "A description of the image."}}}
],
},
] ]
# Act # Act
@ -151,7 +183,9 @@ def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_opena
# Assert # Assert
assert response.answer == "A description of the image." assert response.answer == "A description of the image."
assert mock_openai_client.chat.completions.create.call_count == 2 # Should be called twice assert (
mock_openai_client.chat.completions.create.call_count == 2
) # Should be called twice
def test_gpt4vision_process_img(): def test_gpt4vision_process_img():
@ -173,7 +207,9 @@ def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client
expected_response = GPT4VisionResponse(answer="A description of the image.") expected_response = GPT4VisionResponse(answer="A description of the image.")
mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer mock_openai_client.chat.completions.create.return_value.choices[
0
].text = expected_response.answer
# Act # Act
response = gpt4vision(img_url, [task]) response = gpt4vision(img_url, [task])
@ -190,7 +226,9 @@ def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_cli
expected_response = GPT4VisionResponse(answer="Descriptions of the images.") expected_response = GPT4VisionResponse(answer="Descriptions of the images.")
mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer mock_openai_client.chat.completions.create.return_value.choices[
0
].text = expected_response.answer
# Act # Act
response = gpt4vision(img_urls, [task]) response = gpt4vision(img_urls, [task])
@ -213,57 +251,76 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_cli
def create_mock_response(response): def create_mock_response(response):
return {"choices": [{"message": {"content": {"text": response.answer}}}]} return {"choices": [{"message": {"content": {"text": response.answer}}}]}
mock_openai_client.chat.completions.create.side_effect = [create_mock_response(response) for response in expected_responses] mock_openai_client.chat.completions.create.side_effect = [
create_mock_response(response) for response in expected_responses
]
# Act # Act
responses = gpt4vision(img_url, tasks) responses = gpt4vision(img_url, tasks)
# Assert # Assert
assert responses == expected_responses assert responses == expected_responses
assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once assert (
def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client): mock_openai_client.chat.completions.create.call_count == 1
# Arrange ) # Should be called only once
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
tasks = ["Describe this image.", "What's in this picture?"] def test_gpt4vision_call_multiple_tasks_single_image(
gpt4vision, mock_openai_client
expected_responses = [ ):
GPT4VisionResponse(answer="A description of the image."), # Arrange
GPT4VisionResponse(answer="It contains various objects."), img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
tasks = ["Describe this image.", "What's in this picture?"]
expected_responses = [
GPT4VisionResponse(answer="A description of the image."),
GPT4VisionResponse(answer="It contains various objects."),
]
mock_openai_client.chat.completions.create.side_effect = [
{
"choices": [
{"message": {"content": {"text": expected_responses[i].answer}}}
] ]
}
for i in range(len(expected_responses))
]
mock_openai_client.chat.completions.create.side_effect = [ # Act
{"choices": [{"message": {"content": {"text": expected_responses[i].answer}}}] } for i in range(len(expected_responses)) responses = gpt4vision(img_url, tasks)
]
# Act # Assert
responses = gpt4vision(img_url, tasks) assert responses == expected_responses
assert (
# Assert mock_openai_client.chat.completions.create.call_count == 1
assert responses == expected_responses ) # Should be called only once
assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
def test_gpt4vision_call_multiple_tasks_multiple_images(gpt4vision, mock_openai_client): def test_gpt4vision_call_multiple_tasks_multiple_images(gpt4vision, mock_openai_client):
# Arrange # Arrange
img_urls = ["https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", "https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"] img_urls = [
"https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
"https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
]
tasks = ["Describe these images.", "What's in these pictures?"] tasks = ["Describe these images.", "What's in these pictures?"]
expected_responses = [ expected_responses = [
GPT4VisionResponse(answer="Descriptions of the images."), GPT4VisionResponse(answer="Descriptions of the images."),
GPT4VisionResponse(answer="They contain various objects.") GPT4VisionResponse(answer="They contain various objects."),
] ]
mock_openai_client.chat.completions.create.side_effect = [ mock_openai_client.chat.completions.create.side_effect = [
{"choices": [{"message": {"content": {"text": response.answer}}}] } for response in expected_responses {"choices": [{"message": {"content": {"text": response.answer}}}]}
for response in expected_responses
] ]
# Act # Act
responses = gpt4vision(img_urls, tasks) responses = gpt4vision(img_urls, tasks)
# Assert # Assert
assert responses == expected_responses assert responses == expected_responses
assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once assert (
mock_openai_client.chat.completions.create.call_count == 1
) # Should be called only once
def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client): def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client):
@ -283,7 +340,9 @@ def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client):
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image." task = "Describe this image."
mock_openai_client.chat.completions.create.side_effect = RequestException("Request Error") mock_openai_client.chat.completions.create.side_effect = RequestException(
"Request Error"
)
# Act and Assert # Act and Assert
with pytest.raises(RequestException): with pytest.raises(RequestException):
@ -295,7 +354,9 @@ def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client):
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image." task = "Describe this image."
mock_openai_client.chat.completions.create.side_effect = ConnectionError("Connection Error") mock_openai_client.chat.completions.create.side_effect = ConnectionError(
"Connection Error"
)
# Act and Assert # Act and Assert
with pytest.raises(ConnectionError): with pytest.raises(ConnectionError):
@ -310,7 +371,9 @@ def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
# Simulate success after a retry # Simulate success after a retry
mock_openai_client.chat.completions.create.side_effect = [ mock_openai_client.chat.completions.create.side_effect = [
RequestException("Temporary error"), RequestException("Temporary error"),
{"choices": [{"text": "A description of the image."}]} # fixed dictionary syntax {
"choices": [{"text": "A description of the image."}]
}, # fixed dictionary syntax
] ]
# Act # Act
@ -318,4 +381,6 @@ def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
# Assert # Assert
assert response.answer == "A description of the image." assert response.answer == "A description of the image."
assert mock_openai_client.chat.completions.create.call_count == 2 # Should be called twice assert (
mock_openai_client.chat.completions.create.call_count == 2
) # Should be called twice

Loading…
Cancel
Save