diff --git a/docs/swarms/chunkers/basechunker.md b/docs/swarms/chunkers/basechunker.md
index fed03277..33b03312 100644
--- a/docs/swarms/chunkers/basechunker.md
+++ b/docs/swarms/chunkers/basechunker.md
@@ -53,7 +53,7 @@ The `BaseChunker` class is the core component of the `BaseChunker` module. It is
 #### Parameters:
 
 - `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the text into chunks.
-- `tokenizer` (OpenAiTokenizer): Defines the tokenizer to be used for counting tokens in the text.
+- `tokenizer` (OpenAITokenizer): Defines the tokenizer to be used for counting tokens in the text.
 - `max_tokens` (int): Sets the maximum token limit for each chunk.
 
 ### 4.2. Examples
diff --git a/docs/swarms/chunkers/pdf_chunker.md b/docs/swarms/chunkers/pdf_chunker.md
index 5b97a551..8c92060d 100644
--- a/docs/swarms/chunkers/pdf_chunker.md
+++ b/docs/swarms/chunkers/pdf_chunker.md
@@ -52,7 +52,7 @@ The `PdfChunker` class is the core component of the `PdfChunker` module. It is u
 #### Parameters:
 
 - `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the PDF text content into chunks.
-- `tokenizer` (OpenAiTokenizer): Defines the tokenizer used for counting tokens in the text.
+- `tokenizer` (OpenAITokenizer): Defines the tokenizer used for counting tokens in the text.
 - `max_tokens` (int): Sets the maximum token limit for each chunk.
 
 ### 4.2. Examples
diff --git a/example.py b/example.py
index b3740aa2..6c27bceb 100644
--- a/example.py
+++ b/example.py
@@ -29,7 +29,9 @@ flow = Flow(
 # out = flow.load_state("flow_state.json")
 # temp = flow.dynamic_temperature()
 # filter = flow.add_response_filter("Trump")
-out = flow.run("Generate a 10,000 word blog on mental clarity and the benefits of meditation.")
+out = flow.run(
+    "Generate a 10,000 word blog on mental clarity and the benefits of meditation."
+)
 # out = flow.validate_response(out)
 # out = flow.analyze_feedback(out)
 # out = flow.print_history_and_memory()
diff --git a/pyproject.toml b/pyproject.toml
index a80b6389..3cb153c4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "swarms"
-version = "2.0.1"
+version = "2.0.2"
 description = "Swarms - Pytorch"
 license = "MIT"
 authors = ["Kye Gomez "]
@@ -41,6 +41,7 @@ sentencepiece = "*"
 wget = "*"
 griptape = "*"
 httpx = "*"
+tiktoken = "*"
 attrs = "*"
 ggl = "*"
 beautifulsoup4 = "*"
@@ -49,7 +50,6 @@ pydantic = "*"
 tenacity = "*"
 Pillow = "*"
 chromadb = "*"
-open-interpreter = "*"
 tabulate = "*"
 termcolor = "*"
 black = "*"
diff --git a/requirements.txt b/requirements.txt
index 7ff9d362..cb0c65b8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,6 +29,7 @@ sentencepiece
 duckduckgo-search
 agent-protocol
 chromadb
+tiktoken
 open-interpreter
 tabulate
 colored
diff --git a/swarms/agents/__init__.py b/swarms/agents/__init__.py
index 34dc0f1d..355f0ad1 100644
--- a/swarms/agents/__init__.py
+++ b/swarms/agents/__init__.py
@@ -5,6 +5,7 @@ from swarms.agents.message import Message
 # from swarms.agents.stream_response import stream
 from swarms.agents.base import AbstractAgent
 from swarms.agents.registry import Registry
+
 # from swarms.agents.idea_to_image_agent import Idea2Image
 from swarms.agents.simple_agent import SimpleAgent
diff --git a/swarms/agents/companion.py b/swarms/agents/companion.py
new file mode 100644
index 00000000..a630895e
--- /dev/null
+++ b/swarms/agents/companion.py
@@ -0,0 +1,4 @@
+"""
+Companion agents converse with the user about the agent the user wants to create,
+then create that agent with the desired attributes, traits, tools, and configurations.
+"""
diff --git a/swarms/agents/profitpilot.py b/swarms/agents/profitpilot.py
index 8f6927c4..ac1d0b44 100644
--- a/swarms/agents/profitpilot.py
+++ b/swarms/agents/profitpilot.py
@@ -16,7 +16,6 @@ from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Chroma
 from pydantic import BaseModel, Field
 from swarms.prompts.sales import SALES_AGENT_TOOLS_PROMPT, conversation_stages
-from swarms.tools.interpreter_tool import compile
 
 
 # classes
@@ -166,12 +165,7 @@ def get_tools(product_catalog):
             func=knowledge_base.run,
             description="useful for when you need to answer questions about product information",
         ),
-        # Interpreter
-        Tool(
-            name="Code Interepeter",
-            func=compile,
-            description="Useful when you need to run code locally, such as Python, Javascript, Shell, and more.",
-        )
+        # omnimodal agent
     ]
diff --git a/swarms/chunkers/base.py b/swarms/chunkers/base.py
index 464f51e4..0fabdcef 100644
--- a/swarms/chunkers/base.py
+++ b/swarms/chunkers/base.py
@@ -1,10 +1,13 @@
 from __future__ import annotations
+
 from abc import ABC
 from typing import Optional
-from attr import define, field, Factory
+
+from attr import Factory, define, field
 from griptape.artifacts import TextArtifact
-from swarms.chunkers.chunk_seperators import ChunkSeparator
-from griptape.tokenizers import OpenAiTokenizer
+
+from swarms.chunkers.chunk_seperator import ChunkSeparator
+from swarms.models.openai_tokenizer import OpenAITokenizer
 
 
 @define
@@ -16,6 +19,24 @@ class BaseChunker(ABC):
 
     Usage:
     --------------
+    from swarms.chunkers.base import BaseChunker
+    from swarms.chunkers.chunk_seperator import ChunkSeparator
+
+    class PdfChunker(BaseChunker):
+        DEFAULT_SEPARATORS = [
+            ChunkSeparator("\n\n"),
+            ChunkSeparator(". "),
"), + ChunkSeparator("! "), + ChunkSeparator("? "), + ChunkSeparator(" "), + ] + + # Example + pdf = "swarmdeck.pdf" + chunker = PdfChunker() + chunks = chunker.chunk(pdf) + print(chunks) + """ @@ -26,10 +47,10 @@ class BaseChunker(ABC): default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True), kw_only=True, ) - tokenizer: OpenAiTokenizer = field( + tokenizer: OpenAITokenizer = field( default=Factory( - lambda: OpenAiTokenizer( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL + lambda: OpenAITokenizer( + model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL ) ), kw_only=True, @@ -47,7 +68,7 @@ class BaseChunker(ABC): def _chunk_recursively( self, chunk: str, current_separator: Optional[ChunkSeparator] = None ) -> list[str]: - token_count = self.tokenizer.token_count(chunk) + token_count = self.tokenizer.count_tokens(chunk) if token_count <= self.max_tokens: return [chunk] diff --git a/swarms/chunkers/markdown.py b/swarms/chunkers/markdown.py index 6c0e755f..7836b0a7 100644 --- a/swarms/chunkers/markdown.py +++ b/swarms/chunkers/markdown.py @@ -15,3 +15,10 @@ class MarkdownChunker(BaseChunker): ChunkSeparator("? "), ChunkSeparator(" "), ] + + +# # Example using chunker to chunk a markdown file +# file = open("README.md", "r") +# text = file.read() +# chunker = MarkdownChunker() +# chunks = chunker.chunk(text) diff --git a/swarms/chunkers/omni_chunker.py b/swarms/chunkers/omni_chunker.py new file mode 100644 index 00000000..dca569ea --- /dev/null +++ b/swarms/chunkers/omni_chunker.py @@ -0,0 +1,124 @@ +""" +Omni Chunker is a chunker that chunks all files into select chunks of size x strings + +Usage: +-------------- +from swarms.chunkers.omni_chunker import OmniChunker + +# Example +pdf = "swarmdeck.pdf" +chunker = OmniChunker(chunk_size=1000, beautify=True) +chunks = chunker(pdf) +print(chunks) + + +""" +from dataclasses import dataclass +from typing import List, Optional, Callable +from termcolor import colored +import os +import sys + + + + +@dataclass +class OmniChunker: + """ + + + """ + chunk_size: int = 1000 + beautify: bool = False + use_tokenizer: bool = False + tokenizer: Optional[Callable[[str], List[str]]] = None + + + + def __call__(self, file_path: str) -> List[str]: + """ + Chunk the given file into parts of size `chunk_size`. + + Args: + file_path (str): The path to the file to chunk. + + Returns: + List[str]: A list of string chunks from the file. + """ + if not os.path.isfile(file_path): + print(colored("The file does not exist.", "red")) + return [] + + file_extension = os.path.splitext(file_path)[1] + try: + with open(file_path, "rb") as file: + content = file.read() + # Decode content based on MIME type or file extension + decoded_content = self.decode_content(content, file_extension) + chunks = self.chunk_content(decoded_content) + return chunks + + except Exception as e: + print(colored(f"Error reading file: {e}", "red")) + return [] + + def decode_content(self, content: bytes, file_extension: str) -> str: + """ + Decode the content of the file based on its MIME type or file extension. + + Args: + content (bytes): The content of the file. + file_extension (str): The file extension of the file. + + Returns: + str: The decoded content of the file. 
+ """ + # Add logic to handle different file types based on the extension + # For simplicity, this example assumes text files encoded in utf-8 + try: + return content.decode("utf-8") + except UnicodeDecodeError as e: + print( + colored( + f"Could not decode file with extension {file_extension}: {e}", + "yellow", + ) + ) + return "" + + def chunk_content(self, content: str) -> List[str]: + """ + Split the content into chunks of size `chunk_size`. + + Args: + content (str): The content to chunk. + + Returns: + List[str]: The list of chunks. + """ + return [ + content[i : i + self.chunk_size] + for i in range(0, len(content), self.chunk_size) + ] + + def __str__(self): + return f"OmniChunker(chunk_size={self.chunk_size}, beautify={self.beautify})" + + def metrics(self): + return { + "chunk_size": self.chunk_size, + "beautify": self.beautify, + } + + def print_dashboard(self): + print( + colored( + f""" + Omni Chunker + ------------ + {self.metrics()} + """, + "cyan", + ) + ) + diff --git a/swarms/chunkers/pdf.py b/swarms/chunkers/pdf.py index 206c74f3..710134a0 100644 --- a/swarms/chunkers/pdf.py +++ b/swarms/chunkers/pdf.py @@ -10,3 +10,10 @@ class PdfChunker(BaseChunker): ChunkSeparator("? "), ChunkSeparator(" "), ] + + +# # Example +# pdf = "swarmdeck.pdf" +# chunker = PdfChunker() +# chunks = chunker.chunk(pdf) +# print(chunks) diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index dd21ba80..26c06066 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -16,6 +16,7 @@ from swarms.models.kosmos_two import Kosmos from swarms.models.vilt import Vilt from swarms.models.nougat import Nougat from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA + # from swarms.models.gpt4v import GPT4Vision # from swarms.models.dalle3 import Dalle3 diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py index 9914fce9..cc3931bb 100644 --- a/swarms/models/anthropic.py +++ b/swarms/models/anthropic.py @@ -44,7 +44,7 @@ class Anthropic: top_p=None, streaming=False, default_request_timeout=None, - api_key: str = None + api_key: str = None, ): self.model = model self.max_tokens_to_sample = max_tokens_to_sample diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index 2ac5d403..899564fc 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -129,7 +129,7 @@ class Dalle3: ) ) raise error - + def create_variations(self, img: str): """ Create variations of an image using the Dalle3 API @@ -151,14 +151,11 @@ class Dalle3: >>> img = dalle3.create_variations(img) >>> print(img) - + """ try: - response = self.client.images.create_variation( - img = open(img, "rb"), - n=self.n, - size=self.size + img=open(img, "rb"), n=self.n, size=self.size ) img = response.data[0].url @@ -172,4 +169,4 @@ class Dalle3: ) print(colored(f"Error running Dalle3: {error.http_status}", "red")) print(colored(f"Error running Dalle3: {error.error}", "red")) - raise error \ No newline at end of file + raise error diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py index 0c5bf2c7..f11bf3df 100644 --- a/swarms/models/huggingface.py +++ b/swarms/models/huggingface.py @@ -74,7 +74,9 @@ class HuggingfaceLLM: bnb_config = BitsAndBytesConfig(**quantization_config) try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, *args, **kwargs) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_id, *args, **kwargs + ) self.model = AutoModelForCausalLM.from_pretrained( self.model_id, quantization_config=bnb_config, *args, **kwargs ) 
@@ -162,7 +164,12 @@ class HuggingfaceLLM:
             del inputs
             return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         except Exception as e:
-            print(colored(f"HuggingfaceLLM could not generate text because of error: {e}, try optimizing your arguments", "red"))
+            print(
+                colored(
+                    f"HuggingfaceLLM could not generate text because of error: {e}, try optimizing your arguments",
+                    "red",
+                )
+            )
             raise
 
     async def run_async(self, task: str, *args, **kwargs) -> str:
diff --git a/swarms/models/openai_assistant.py b/swarms/models/openai_assistant.py
new file mode 100644
index 00000000..6d0c518f
--- /dev/null
+++ b/swarms/models/openai_assistant.py
@@ -0,0 +1,74 @@
+from typing import Dict, List, Optional
+from dataclasses import dataclass
+
+# The Assistants (beta) endpoints used below live on the official OpenAI client.
+from openai import OpenAI
+
+
+@dataclass
+class OpenAIAssistant:
+    name: str = "OpenAI Assistant"
+    instructions: str = None
+    tools: List[Dict] = None
+    model: str = None
+    openai_api_key: str = None
+    temperature: float = 0.5
+    max_tokens: int = 100
+    stop: List[str] = None
+    echo: bool = False
+    stream: bool = False
+    log: bool = False
+    presence: bool = False
+    dashboard: bool = False
+    debug: bool = False
+    max_loops: int = 5
+    stopping_condition: Optional[str] = None
+    loop_interval: int = 1
+    retry_attempts: int = 3
+    retry_interval: int = 1
+    interactive: bool = False
+    dynamic_temperature: bool = False
+    state: Dict = None
+    response_filters: List = None
+    response_filter: Dict = None
+    response_filter_name: str = None
+    response_filter_value: str = None
+    response_filter_type: str = None
+    response_filter_action: str = None
+    response_filter_action_value: str = None
+    response_filter_action_type: str = None
+    response_filter_action_name: str = None
+    client = OpenAI()
+    role: str = "user"
+
+    def create_assistant(self, task: str):
+        assistant = self.client.beta.assistants.create(
+            name=self.name,
+            instructions=self.instructions,
+            tools=self.tools,
+            model=self.model,
+        )
+        return assistant
+
+    def create_thread(self):
+        thread = self.client.beta.threads.create()
+        return thread
+
+    def add_message_to_thread(self, thread_id: str, message: str):
+        message = self.client.beta.threads.messages.create(
+            thread_id=thread_id, role=self.role, content=message
+        )
+        return message
+
+    def run(self, task: str):
+        run = self.client.beta.threads.runs.create(
+            thread_id=self.create_thread().id,
+            assistant_id=self.create_assistant(task).id,
+            instructions=self.instructions,
+        )
+
+        out = self.client.beta.threads.runs.retrieve(
+            thread_id=run.thread_id, run_id=run.id
+        )
+
+        return out
diff --git a/swarms/models/openai_tokenizer.py b/swarms/models/openai_tokenizer.py
new file mode 100644
index 00000000..b4e375cc
--- /dev/null
+++ b/swarms/models/openai_tokenizer.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import tiktoken
+from attr import Factory, define, field
+
+
+@define(frozen=True)
+class BaseTokenizer(ABC):
+    DEFAULT_STOP_SEQUENCES = ["Observation:"]
+
+    stop_sequences: list[str] = field(
+        default=Factory(lambda: BaseTokenizer.DEFAULT_STOP_SEQUENCES),
+        kw_only=True,
+    )
+
+    @property
+    @abstractmethod
+    def max_tokens(self) -> int:
+        ...
+
+    def count_tokens_left(self, text: str) -> int:
+        diff = self.max_tokens - self.count_tokens(text)
+
+        if diff > 0:
+            return diff
+        else:
+            return 0
+
+    @abstractmethod
+    def count_tokens(self, text: str) -> int:
+        ...
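+
+    # Illustrative arithmetic (hypothetical numbers): for a concrete tokenizer
+    # with max_tokens == 4096, a prompt that encodes to 10 tokens leaves
+    # count_tokens_left(prompt) == 4096 - 10 == 4086 tokens for a completion.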
+
+
+@define(frozen=True)
+class OpenAITokenizer(BaseTokenizer):
+    DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "text-davinci-003"
+    DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo"
+    DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4"
+    DEFAULT_ENCODING = "cl100k_base"
+    DEFAULT_MAX_TOKENS = 2049
+    TOKEN_OFFSET = 8
+
+    MODEL_PREFIXES_TO_MAX_TOKENS = {
+        "gpt-4-32k": 32768,
+        "gpt-4": 8192,
+        "gpt-3.5-turbo-16k": 16384,
+        "gpt-3.5-turbo": 4096,
+        "gpt-35-turbo-16k": 16384,
+        "gpt-35-turbo": 4096,
+        "text-davinci-003": 4097,
+        "text-davinci-002": 4097,
+        "code-davinci-002": 8001,
+        "text-embedding-ada-002": 8191,
+        "text-embedding-ada-001": 2046,
+    }
+
+    EMBEDDING_MODELS = ["text-embedding-ada-002", "text-embedding-ada-001"]
+
+    model: str = field(kw_only=True)
+
+    @property
+    def encoding(self) -> tiktoken.Encoding:
+        try:
+            return tiktoken.encoding_for_model(self.model)
+        except KeyError:
+            return tiktoken.get_encoding(self.DEFAULT_ENCODING)
+
+    @property
+    def max_tokens(self) -> int:
+        tokens = next(
+            v
+            for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items()
+            if self.model.startswith(k)
+        )
+        offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET
+
+        return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset
+
+    def count_tokens(
+        self, text: str | list, model: Optional[str] = None
+    ) -> int:
+        """
+        Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook:
+        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+        """
+        if isinstance(text, list):
+            model = model if model else self.model
+
+            try:
+                encoding = tiktoken.encoding_for_model(model)
+            except KeyError:
+                logging.warning("model not found. Using cl100k_base encoding.")
+
+                encoding = tiktoken.get_encoding("cl100k_base")
+
+            if model in {
+                "gpt-3.5-turbo-0613",
+                "gpt-3.5-turbo-16k-0613",
+                "gpt-4-0314",
+                "gpt-4-32k-0314",
+                "gpt-4-0613",
+                "gpt-4-32k-0613",
+            }:
+                tokens_per_message = 3
+                tokens_per_name = 1
+            elif model == "gpt-3.5-turbo-0301":
+                # every message follows <|start|>{role/name}\n{content}<|end|>\n
+                tokens_per_message = 4
+                # if there's a name, the role is omitted
+                tokens_per_name = -1
+            elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
+                logging.info(
+                    "gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
+                )
+                return self.count_tokens(text, model="gpt-3.5-turbo-0613")
+            elif "gpt-4" in model:
+                logging.info(
+                    "gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
+                )
+                return self.count_tokens(text, model="gpt-4-0613")
+            else:
+                raise NotImplementedError(
+                    f"""token_count() is not implemented for model {model}.
+                    See https://github.com/openai/openai-python/blob/main/chatml.md for
+                    information on how messages are converted to tokens."""
+                )
+
+            num_tokens = 0
+
+            for message in text:
+                num_tokens += tokens_per_message
+                for key, value in message.items():
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":
+                        num_tokens += tokens_per_name
+
+            # every reply is primed with <|start|>assistant<|message|>
+            num_tokens += 3
+
+            return num_tokens
+        else:
+            return len(
+                self.encoding.encode(
+                    text, allowed_special=set(self.stop_sequences)
+                )
+            )
\ No newline at end of file
diff --git a/swarms/structs/flow.py b/swarms/structs/flow.py
index 4e21c3df..9ff021f4 100644
--- a/swarms/structs/flow.py
+++ b/swarms/structs/flow.py
@@ -116,6 +116,7 @@
         dynamic_temperature: bool = False,
         saved_state_path: Optional[str] = "flow_state.json",
         autosave: bool = False,
+        context_length: int = 8192,
         **kwargs: Any,
     ):
         self.llm = llm
@@ -188,6 +189,26 @@
 
         return "\n".join(params_str_list)
 
+    def truncate_history(self):
+        """
+        Keep only the last `context_length` messages of the current conversation
+        """
+        truncated_history = self.memory[-1][-self.context_length :]
+        self.memory[-1] = truncated_history
+
+    def add_task_to_memory(self, task: str):
+        """Add the task to the memory"""
+        self.memory.append([f"Human: {task}"])
+
+    def add_message_to_memory(self, message: str):
+        """Add the message to the memory"""
+        self.memory[-1].append(message)
+
+    def add_message_to_memory_and_truncate(self, message: str):
+        """Add the message to the memory and truncate"""
+        self.memory[-1].append(message)
+        self.truncate_history()
+
     def print_dashboard(self, task: str):
         """Print dashboard"""
         model_config = self.get_llm_init_params()
diff --git a/swarms/tools/interpreter_tool.py b/swarms/tools/interpreter_tool.py
deleted file mode 100644
index 22758de6..00000000
--- a/swarms/tools/interpreter_tool.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import os
-import interpreter
-
-
-def compile(task: str):
-    """
-    Open Interpreter lets LLMs run code (Python, Javascript, Shell, and more) locally. You can chat with Open Interpreter through a ChatGPT-like interface in your terminal by running $ interpreter after installing.
-
-    This provides a natural-language interface to your computer's general-purpose capabilities:
-
-    Create and edit photos, videos, PDFs, etc.
-    Control a Chrome browser to perform research
-    Plot, clean, and analyze large datasets
-    ...etc.
-    ⚠️ Note: You'll be asked to approve code before it's run.
- """ - - task = interpreter.chat(task, return_messages=True) - interpreter.chat() - interpreter.reset(task) - - os.environ["INTERPRETER_CLI_AUTO_RUN"] = True - os.environ["INTERPRETER_CLI_FAST_MODE"] = True - os.environ["INTERPRETER_CLI_DEBUG"] = True diff --git a/swarms/workers/__init__.py b/swarms/workers/__init__.py index 2a7cc4f1..9dabe94d 100644 --- a/swarms/workers/__init__.py +++ b/swarms/workers/__init__.py @@ -1,2 +1,2 @@ -from swarms.workers.worker import Worker +# from swarms.workers.worker import Worker from swarms.workers.base import AbstractWorker diff --git a/tests/chunkers/basechunker.py b/tests/chunkers/basechunker.py index f70705bc..4fd92da1 100644 --- a/tests/chunkers/basechunker.py +++ b/tests/chunkers/basechunker.py @@ -3,7 +3,7 @@ from swarms.chunkers.base import ( BaseChunker, TextArtifact, ChunkSeparator, - OpenAiTokenizer, + OpenAITokenizer, ) # adjust the import paths accordingly @@ -21,7 +21,7 @@ def test_default_separators(): def test_default_tokenizer(): chunker = BaseChunker() - assert isinstance(chunker.tokenizer, OpenAiTokenizer) + assert isinstance(chunker.tokenizer, OpenAITokenizer) # 2. Test Basic Chunking diff --git a/tests/models/dalle3.py b/tests/models/dalle3.py index 42b851b7..f9a2f8cf 100644 --- a/tests/models/dalle3.py +++ b/tests/models/dalle3.py @@ -23,8 +23,12 @@ def dalle3(mock_openai_client): def test_dalle3_call_success(dalle3, mock_openai_client): # Arrange task = "A painting of a dog" - expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)]) + expected_img_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" + ) + mock_openai_client.images.generate.return_value = Mock( + data=[Mock(url=expected_img_url)] + ) # Act img_url = dalle3(task) @@ -40,7 +44,9 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): expected_error_message = "Error running Dalle3: API Error" # Mocking OpenAIError - mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") + mock_openai_client.images.generate.side_effect = OpenAIError( + expected_error_message, http_status=500, error="Internal Server Error" + ) # Act and assert with pytest.raises(OpenAIError) as excinfo: @@ -57,8 +63,12 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys): def test_dalle3_create_variations_success(dalle3, mock_openai_client): # Arrange img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png" - expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" - mock_openai_client.images.create_variation.return_value = Mock(data=[Mock(url=expected_variation_url)]) + expected_variation_url = ( + "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png" + ) + mock_openai_client.images.create_variation.return_value = Mock( + data=[Mock(url=expected_variation_url)] + ) # Act variation_img_url = dalle3.create_variations(img_url) @@ -78,7 +88,9 @@ def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys): expected_error_message = "Error running Dalle3: API Error" # Mocking OpenAIError - mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error") + mock_openai_client.images.create_variation.side_effect = OpenAIError( + expected_error_message, http_status=500, error="Internal Server Error" + ) 
     # Act and assert
     with pytest.raises(OpenAIError) as excinfo:
@@ -86,7 +98,7 @@ def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys):
 
     assert str(excinfo.value) == expected_error_message
     mock_openai_client.images.create_variation.assert_called_once()
-    
+
     # Ensure the error message is printed in red
     captured = capsys.readouterr()
     assert colored(expected_error_message, "red") in captured.out
@@ -142,8 +154,12 @@ def test_dalle3_convert_to_bytesio():
 def test_dalle3_call_multiple_times(dalle3, mock_openai_client):
     # Arrange
     task = "A painting of a dog"
-    expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
-    mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)])
+    expected_img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
+    mock_openai_client.images.generate.return_value = Mock(
+        data=[Mock(url=expected_img_url)]
+    )
 
     # Act
     img_url1 = dalle3(task)
@@ -159,7 +175,9 @@ def test_dalle3_call_with_large_input(dalle3, mock_openai_client):
     # Arrange
     task = "A" * 2048  # Input longer than API's limit
     expected_error_message = "Error running Dalle3: API Error"
-    mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
+    mock_openai_client.images.generate.side_effect = OpenAIError(
+        expected_error_message, http_status=500, error="Internal Server Error"
+    )
 
     # Act and assert
     with pytest.raises(OpenAIError) as excinfo:
@@ -204,7 +222,9 @@ def test_dalle3_convert_to_bytesio_invalid_format(dalle3):
 def test_dalle3_call_with_retry(dalle3, mock_openai_client):
     # Arrange
     task = "A painting of a dog"
-    expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    expected_img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
 
     # Simulate a retry scenario
     mock_openai_client.images.generate.side_effect = [
@@ -223,7 +243,9 @@ def test_dalle3_call_with_retry(dalle3, mock_openai_client):
 def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client):
     # Arrange
     img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
-    expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+    expected_variation_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+    )
 
     # Simulate a retry scenario
     mock_openai_client.images.create_variation.side_effect = [
@@ -245,7 +267,9 @@ def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys):
     expected_error_message = "Error running Dalle3: API Error"
 
     # Mocking OpenAIError
-    mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
+    mock_openai_client.images.generate.side_effect = OpenAIError(
+        expected_error_message, http_status=500, error="Internal Server Error"
+    )
 
     # Act
     with pytest.raises(OpenAIError):
@@ -262,7 +286,9 @@ def test_dalle3_create_variations_exception_logging(dalle3, mock_openai_client, capsys):
     expected_error_message = "Error running Dalle3: API Error"
 
     # Mocking OpenAIError
-    mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
+    mock_openai_client.images.create_variation.side_effect = OpenAIError(
+        expected_error_message, http_status=500, error="Internal Server Error"
+    )
 
     # Act
     with pytest.raises(OpenAIError):
@@ -313,7 +339,9 @@ def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
     task = "A painting of a dog"
 
     # Simulate max retries exceeded
-    mock_openai_client.images.generate.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error")
+    mock_openai_client.images.generate.side_effect = OpenAIError(
+        "Temporary error", http_status=500, error="Internal Server Error"
+    )
 
     # Act and assert
     with pytest.raises(OpenAIError) as excinfo:
@@ -322,12 +350,16 @@ def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
 
     assert "Retry limit exceeded" in str(excinfo.value)
 
 
-def test_dalle3_create_variations_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
+def test_dalle3_create_variations_with_retry_max_retries_exceeded(
+    dalle3, mock_openai_client
+):
     # Arrange
     img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
 
     # Simulate max retries exceeded
-    mock_openai_client.images.create_variation.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error")
+    mock_openai_client.images.create_variation.side_effect = OpenAIError(
+        "Temporary error", http_status=500, error="Internal Server Error"
+    )
 
     # Act and assert
     with pytest.raises(OpenAIError) as excinfo:
@@ -339,7 +371,9 @@ def test_dalle3_create_variations_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
 def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
     # Arrange
     task = "A painting of a dog"
-    expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    expected_img_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+    )
 
     # Simulate success after a retry
     mock_openai_client.images.generate.side_effect = [
@@ -358,7 +392,9 @@ def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
 def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client):
     # Arrange
     img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
-    expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+    expected_variation_url = (
+        "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+    )
 
     # Simulate success after a retry
     mock_openai_client.images.create_variation.side_effect = [
diff --git a/tests/models/gpt4v.py b/tests/models/gpt4v.py
index 40ccc7f5..23e97d03 100644
--- a/tests/models/gpt4v.py
+++ b/tests/models/gpt4v.py
@@ -12,19 +12,22 @@ load_dotenv
 api_key = os.getenv("OPENAI_API_KEY")
 
+
 # Mock the OpenAI client
 @pytest.fixture
 def mock_openai_client():
     return Mock()
 
+
 @pytest.fixture
 def gpt4vision(mock_openai_client):
     return GPT4Vision(client=mock_openai_client)
 
+
 def test_gpt4vision_default_values():
     # Arrange and Act
     gpt4vision = GPT4Vision()
-    
+
     # Assert
     assert gpt4vision.max_retries == 3
     assert gpt4vision.model == "gpt-4-vision-preview"
@@ -34,59 +37,68 @@ def test_gpt4vision_default_values():
     assert gpt4vision.quality == "low"
     assert gpt4vision.max_tokens == 200
 
+
 def test_gpt4vision_api_key_from_env_variable():
     # Arrange
-    api_key = os.environ["OPENAI_API_KEY"]
-    
+    api_key = os.environ["OPENAI_API_KEY"]
+
     # Act
     gpt4vision = GPT4Vision()
-    
+
     # Assert
     assert gpt4vision.api_key == api_key
 
+
 def test_gpt4vision_set_api_key():
     # Arrange
     gpt4vision = GPT4Vision(api_key=api_key)
-    
+
     # Assert
     assert gpt4vision.api_key == api_key
 
+
 def test_gpt4vision_invalid_max_retries():
     # Arrange and Act
     with pytest.raises(ValueError):
         GPT4Vision(max_retries=-1)
 
+
 def test_gpt4vision_invalid_backoff_factor():
     # Arrange and Act
     with pytest.raises(ValueError):
         GPT4Vision(backoff_factor=-1)
 
+
 def test_gpt4vision_invalid_timeout_seconds():
     # Arrange and Act
     with pytest.raises(ValueError):
         GPT4Vision(timeout_seconds=-1)
 
+
 def test_gpt4vision_invalid_max_tokens():
     # Arrange and Act
     with pytest.raises(ValueError):
         GPT4Vision(max_tokens=-1)
 
+
 def test_gpt4vision_logger_initialized():
     # Arrange
     gpt4vision = GPT4Vision()
-    
+
     # Assert
     assert isinstance(gpt4vision.logger, logging.Logger)
 
+
 def test_gpt4vision_process_img_nonexistent_file():
     # Arrange
     gpt4vision = GPT4Vision()
     img_path = "nonexistent_image.jpg"
-    
+
     # Act and Assert
     with pytest.raises(FileNotFoundError):
         gpt4vision.process_img(img_path)
 
+
 def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
     # Arrange
     img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
@@ -96,7 +108,10 @@ def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
     with pytest.raises(AttributeError):
         gpt4vision(img_url, [task])
 
-def test_gpt4vision_call_single_task_single_image_empty_response(gpt4vision, mock_openai_client):
+
+def test_gpt4vision_call_single_task_single_image_empty_response(
+    gpt4vision, mock_openai_client
+):
     # Arrange
     img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
     task = "Describe this image."
@@ -110,7 +125,10 @@ def test_gpt4vision_call_single_task_single_image_empty_response(gpt4vision, mock_openai_client):
     assert response.answer == ""
     mock_openai_client.chat.completions.create.assert_called_once()
 
-def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(gpt4vision, mock_openai_client):
+
+def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(
+    gpt4vision, mock_openai_client
+):
     # Arrange
     img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
     tasks = ["Describe this image.", "What's in this picture?"]
@@ -122,20 +140,30 @@ def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(gpt4vision, mock_openai_client):
 
     # Assert
     assert all(response.answer == "" for response in responses)
-    assert mock_openai_client.chat.completions.create.call_count == 1  # Should be called only once
+    assert (
+        mock_openai_client.chat.completions.create.call_count == 1
+    )  # Should be called only once
+
 
-def test_gpt4vision_call_single_task_single_image_timeout(gpt4vision, mock_openai_client):
+def test_gpt4vision_call_single_task_single_image_timeout(
+    gpt4vision, mock_openai_client
+):
     # Arrange
     img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
     task = "Describe this image."
-    mock_openai_client.chat.completions.create.side_effect = Timeout("Request timed out")
+    mock_openai_client.chat.completions.create.side_effect = Timeout(
+        "Request timed out"
+    )
 
     # Act and Assert
     with pytest.raises(Timeout):
         gpt4vision(img_url, [task])
 
-def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_openai_client):
+
+def test_gpt4vision_call_retry_with_success_after_timeout(
+    gpt4vision, mock_openai_client
+):
     # Arrange
     img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
     task = "Describe this image."
@@ -143,7 +171,11 @@ def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_openai_client):
     # Simulate success after a timeout and retry
     mock_openai_client.chat.completions.create.side_effect = [
         Timeout("Request timed out"),
-        {"choices": [{"message": {"content": {"text": "A description of the image."}}}],}
+        {
+            "choices": [
+                {"message": {"content": {"text": "A description of the image."}}}
+            ],
+        },
     ]
 
     # Act
@@ -151,7 +183,9 @@ def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_openai_client):
 
     # Assert
     assert response.answer == "A description of the image."
-    assert mock_openai_client.chat.completions.create.call_count == 2  # Should be called twice
+    assert (
+        mock_openai_client.chat.completions.create.call_count == 2
+    )  # Should be called twice
 
 
 def test_gpt4vision_process_img():
@@ -173,7 +207,9 @@ def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client):
 
     expected_response = GPT4VisionResponse(answer="A description of the image.")
 
-    mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer
+    mock_openai_client.chat.completions.create.return_value.choices[
+        0
+    ].text = expected_response.answer
 
     # Act
     response = gpt4vision(img_url, [task])
@@ -190,7 +226,9 @@ def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_client):
 
     expected_response = GPT4VisionResponse(answer="Descriptions of the images.")
 
-    mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer
+    mock_openai_client.chat.completions.create.return_value.choices[
+        0
+    ].text = expected_response.answer
 
     # Act
     response = gpt4vision(img_urls, [task])
@@ -213,57 +251,76 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client):
     def create_mock_response(response):
         return {"choices": [{"message": {"content": {"text": response.answer}}}]}
 
-    mock_openai_client.chat.completions.create.side_effect = [create_mock_response(response) for response in expected_responses]
+    mock_openai_client.chat.completions.create.side_effect = [
+        create_mock_response(response) for response in expected_responses
+    ]
 
     # Act
     responses = gpt4vision(img_url, tasks)
 
     # Assert
     assert responses == expected_responses
-    assert mock_openai_client.chat.completions.create.call_count == 1  # Should be called only once
-    def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client):
-        # Arrange
-        img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
-        tasks = ["Describe this image.", "What's in this picture?"]
-
-        expected_responses = [
-            GPT4VisionResponse(answer="A description of the image."),
-            GPT4VisionResponse(answer="It contains various objects."),
+    assert (
+        mock_openai_client.chat.completions.create.call_count == 1
+    )  # Should be called only once
+
+    def test_gpt4vision_call_multiple_tasks_single_image(
+        gpt4vision, mock_openai_client
+    ):
+        # Arrange
+        img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
+        tasks = ["Describe this image.", "What's in this picture?"]
+
+        expected_responses = [
+            GPT4VisionResponse(answer="A description of the image."),
+            GPT4VisionResponse(answer="It contains various objects."),
+        ]
+
+        mock_openai_client.chat.completions.create.side_effect = [
+            {
+                "choices": [
+                    {"message": {"content": {"text": expected_responses[i].answer}}}
+                ]
+            }
+            for i in range(len(expected_responses))
+        ]
 
-        mock_openai_client.chat.completions.create.side_effect = [
-            {"choices": [{"message": {"content": {"text": expected_responses[i].answer}}}] } for i in range(len(expected_responses))
-        ]
+        # Act
+        responses = gpt4vision(img_url, tasks)
 
-        # Act
-        responses = gpt4vision(img_url, tasks)
-
-        # Assert
-        assert responses == expected_responses
-        assert mock_openai_client.chat.completions.create.call_count == 1  # Should be called only once
+        # Assert
+        assert responses == expected_responses
+        assert (
+            mock_openai_client.chat.completions.create.call_count == 1
+        )  # Should be called only once
 
 
 def test_gpt4vision_call_multiple_tasks_multiple_images(gpt4vision, mock_openai_client):
     # Arrange
-    img_urls = ["https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", "https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"]
+    img_urls = [
+        "https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
+        "https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
+    ]
     tasks = ["Describe these images.", "What's in these pictures?"]
 
     expected_responses = [
         GPT4VisionResponse(answer="Descriptions of the images."),
-        GPT4VisionResponse(answer="They contain various objects.")
+        GPT4VisionResponse(answer="They contain various objects."),
     ]
 
     mock_openai_client.chat.completions.create.side_effect = [
-        {"choices": [{"message": {"content": {"text": response.answer}}}] } for response in expected_responses
+        {"choices": [{"message": {"content": {"text": response.answer}}}]}
+        for response in expected_responses
     ]
 
     # Act
     responses = gpt4vision(img_urls, tasks)
 
-    # Assert
     assert responses == expected_responses
-    assert mock_openai_client.chat.completions.create.call_count == 1  # Should be called only once
+    assert (
+        mock_openai_client.chat.completions.create.call_count == 1
+    )  # Should be called only once
 
 
 def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client):
@@ -283,7 +340,9 @@ def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client):
     img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
     task = "Describe this image."
-    mock_openai_client.chat.completions.create.side_effect = RequestException("Request Error")
+    mock_openai_client.chat.completions.create.side_effect = RequestException(
+        "Request Error"
+    )
 
     # Act and Assert
     with pytest.raises(RequestException):
         gpt4vision(img_url, [task])
 
@@ -295,7 +354,9 @@ def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client):
     img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
     task = "Describe this image."
 
-    mock_openai_client.chat.completions.create.side_effect = ConnectionError("Connection Error")
+    mock_openai_client.chat.completions.create.side_effect = ConnectionError(
+        "Connection Error"
+    )
 
     # Act and Assert
     with pytest.raises(ConnectionError):
         gpt4vision(img_url, [task])
 
@@ -310,7 +371,9 @@ def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
 
     # Simulate success after a retry
     mock_openai_client.chat.completions.create.side_effect = [
         RequestException("Temporary error"),
-        {"choices": [{"text": "A description of the image."}]}  # fixed dictionary syntax
+        {
+            "choices": [{"text": "A description of the image."}]
+        },  # fixed dictionary syntax
     ]
 
     # Act
@@ -318,4 +381,6 @@ def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
 
     # Assert
     assert response.answer == "A description of the image."
-    assert mock_openai_client.chat.completions.create.call_count == 2  # Should be called twice
+    assert (
+        mock_openai_client.chat.completions.create.call_count == 2
+    )  # Should be called twice
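+
+
+# # Example (illustrative, not part of the test suite): exercising GPT4Vision
+# # manually, assuming a valid OPENAI_API_KEY is set in the environment.
+# gpt4vision = GPT4Vision()
+# task = "Describe this image."
+# img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578"
+# print(gpt4vision(img_url, [task]))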