diff --git a/docs/swarms/chunkers/basechunker.md b/docs/swarms/chunkers/basechunker.md
index fed03277..33b03312 100644
--- a/docs/swarms/chunkers/basechunker.md
+++ b/docs/swarms/chunkers/basechunker.md
@@ -53,7 +53,7 @@ The `BaseChunker` class is the core component of the `BaseChunker` module. It is
#### Parameters:
- `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the text into chunks.
-- `tokenizer` (OpenAiTokenizer): Defines the tokenizer to be used for counting tokens in the text.
+- `tokenizer` (OpenAITokenizer): Defines the tokenizer to be used for counting tokens in the text.
- `max_tokens` (int): Sets the maximum token limit for each chunk.
### 4.2. Examples
diff --git a/docs/swarms/chunkers/pdf_chunker.md b/docs/swarms/chunkers/pdf_chunker.md
index 5b97a551..8c92060d 100644
--- a/docs/swarms/chunkers/pdf_chunker.md
+++ b/docs/swarms/chunkers/pdf_chunker.md
@@ -52,7 +52,7 @@ The `PdfChunker` class is the core component of the `PdfChunker` module. It is u
#### Parameters:
- `separators` (list[ChunkSeparator]): Specifies a list of `ChunkSeparator` objects used to split the PDF text content into chunks.
-- `tokenizer` (OpenAiTokenizer): Defines the tokenizer used for counting tokens in the text.
+- `tokenizer` (OpenAITokenizer): Defines the tokenizer used for counting tokens in the text.
- `max_tokens` (int): Sets the maximum token limit for each chunk.
### 4.2. Examples
diff --git a/example.py b/example.py
index b3740aa2..6c27bceb 100644
--- a/example.py
+++ b/example.py
@@ -29,7 +29,9 @@ flow = Flow(
# out = flow.load_state("flow_state.json")
# temp = flow.dynamic_temperature()
# filter = flow.add_response_filter("Trump")
-out = flow.run("Generate a 10,000 word blog on mental clarity and the benefits of meditation.")
+out = flow.run(
+ "Generate a 10,000 word blog on mental clarity and the benefits of meditation."
+)
# out = flow.validate_response(out)
# out = flow.analyze_feedback(out)
# out = flow.print_history_and_memory()
diff --git a/pyproject.toml b/pyproject.toml
index a80b6389..3cb153c4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
-version = "2.0.1"
+version = "2.0.2"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez "]
@@ -41,6 +41,7 @@ sentencepiece = "*"
wget = "*"
griptape = "*"
httpx = "*"
+tiktoken = "*"
attrs = "*"
ggl = "*"
beautifulsoup4 = "*"
@@ -49,7 +50,6 @@ pydantic = "*"
tenacity = "*"
Pillow = "*"
chromadb = "*"
-open-interpreter = "*"
tabulate = "*"
termcolor = "*"
black = "*"
diff --git a/requirements.txt b/requirements.txt
index 7ff9d362..cb0c65b8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,6 +29,6 @@ sentencepiece
duckduckgo-search
agent-protocol
chromadb
+tiktoken
-open-interpreter
tabulate
colored
diff --git a/swarms/agents/__init__.py b/swarms/agents/__init__.py
index 34dc0f1d..355f0ad1 100644
--- a/swarms/agents/__init__.py
+++ b/swarms/agents/__init__.py
@@ -5,6 +5,7 @@ from swarms.agents.message import Message
# from swarms.agents.stream_response import stream
from swarms.agents.base import AbstractAgent
from swarms.agents.registry import Registry
+
# from swarms.agents.idea_to_image_agent import Idea2Image
from swarms.agents.simple_agent import SimpleAgent
diff --git a/swarms/agents/companion.py b/swarms/agents/companion.py
new file mode 100644
index 00000000..a630895e
--- /dev/null
+++ b/swarms/agents/companion.py
@@ -0,0 +1,4 @@
+"""
+Companion agents converse with the user about the agent they want to create, then build that agent with the desired attributes, traits, tools, and configuration.
+
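+Intended usage (hypothetical sketch; the Companion class is not implemented yet,
+so these names are placeholders):
+
+    companion = Companion()
+    agent = companion.converse("I want a research agent that can search the web")
+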
+"""
diff --git a/swarms/agents/profitpilot.py b/swarms/agents/profitpilot.py
index 8f6927c4..ac1d0b44 100644
--- a/swarms/agents/profitpilot.py
+++ b/swarms/agents/profitpilot.py
@@ -16,7 +16,6 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from pydantic import BaseModel, Field
from swarms.prompts.sales import SALES_AGENT_TOOLS_PROMPT, conversation_stages
-from swarms.tools.interpreter_tool import compile
# classes
@@ -166,12 +165,7 @@ def get_tools(product_catalog):
func=knowledge_base.run,
description="useful for when you need to answer questions about product information",
),
- # Interpreter
- Tool(
- name="Code Interepeter",
- func=compile,
- description="Useful when you need to run code locally, such as Python, Javascript, Shell, and more.",
- )
+
# omnimodal agent
]
diff --git a/swarms/chunkers/base.py b/swarms/chunkers/base.py
index 464f51e4..0fabdcef 100644
--- a/swarms/chunkers/base.py
+++ b/swarms/chunkers/base.py
@@ -1,10 +1,13 @@
from __future__ import annotations
+
from abc import ABC
from typing import Optional
-from attr import define, field, Factory
+
+from attr import Factory, define, field
from griptape.artifacts import TextArtifact
-from swarms.chunkers.chunk_seperators import ChunkSeparator
-from griptape.tokenizers import OpenAiTokenizer
+
+from swarms.chunkers.chunk_seperator import ChunkSeparator
+from swarms.models.openai_tokenizer import OpenAITokenizer
@define
@@ -16,6 +19,24 @@ class BaseChunker(ABC):
Usage:
--------------
+ from swarms.chunkers.base import BaseChunker
+ from swarms.chunkers.chunk_seperator import ChunkSeparator
+
+ class PdfChunker(BaseChunker):
+ DEFAULT_SEPARATORS = [
+ ChunkSeparator("\n\n"),
+ ChunkSeparator(". "),
+ ChunkSeparator("! "),
+ ChunkSeparator("? "),
+ ChunkSeparator(" "),
+ ]
+
+    # Example (pass extracted text, not a file path)
+    pdf_text = "...text extracted from swarmdeck.pdf..."
+    chunker = PdfChunker()
+    chunks = chunker.chunk(pdf_text)
+    print(chunks)
+
"""
@@ -26,10 +47,10 @@ class BaseChunker(ABC):
default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True),
kw_only=True,
)
- tokenizer: OpenAiTokenizer = field(
+ tokenizer: OpenAITokenizer = field(
default=Factory(
- lambda: OpenAiTokenizer(
- model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL
+ lambda: OpenAITokenizer(
+ model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL
)
),
kw_only=True,
@@ -47,7 +68,7 @@ class BaseChunker(ABC):
def _chunk_recursively(
self, chunk: str, current_separator: Optional[ChunkSeparator] = None
) -> list[str]:
- token_count = self.tokenizer.token_count(chunk)
+ token_count = self.tokenizer.count_tokens(chunk)
if token_count <= self.max_tokens:
return [chunk]
diff --git a/swarms/chunkers/markdown.py b/swarms/chunkers/markdown.py
index 6c0e755f..7836b0a7 100644
--- a/swarms/chunkers/markdown.py
+++ b/swarms/chunkers/markdown.py
@@ -15,3 +15,10 @@ class MarkdownChunker(BaseChunker):
ChunkSeparator("? "),
ChunkSeparator(" "),
]
+
+
+# # Example using the chunker on a markdown file
+# with open("README.md", "r") as f:
+#     text = f.read()
+# chunker = MarkdownChunker()
+# chunks = chunker.chunk(text)
+# print(chunks)
diff --git a/swarms/chunkers/omni_chunker.py b/swarms/chunkers/omni_chunker.py
new file mode 100644
index 00000000..dca569ea
--- /dev/null
+++ b/swarms/chunkers/omni_chunker.py
@@ -0,0 +1,124 @@
+"""
+Omni Chunker splits any file into fixed-size string chunks of `chunk_size` characters.
+
+Usage:
+--------------
+from swarms.chunkers.omni_chunker import OmniChunker
+
+# Example (note: only UTF-8 decodable files are handled right now)
+pdf = "swarmdeck.pdf"
+chunker = OmniChunker(chunk_size=1000, beautify=True)
+chunks = chunker(pdf)
+print(chunks)
+
+
+"""
+import os
+from dataclasses import dataclass
+from typing import Callable, List, Optional
+
+from termcolor import colored
+
+
+@dataclass
+class OmniChunker:
+ """
+
+
+ """
+ chunk_size: int = 1000
+ beautify: bool = False
+ use_tokenizer: bool = False
+ tokenizer: Optional[Callable[[str], List[str]]] = None
+
+ def __call__(self, file_path: str) -> List[str]:
+ """
+ Chunk the given file into parts of size `chunk_size`.
+
+ Args:
+ file_path (str): The path to the file to chunk.
+
+ Returns:
+ List[str]: A list of string chunks from the file.
+ """
+ if not os.path.isfile(file_path):
+ print(colored("The file does not exist.", "red"))
+ return []
+
+ file_extension = os.path.splitext(file_path)[1]
+ try:
+ with open(file_path, "rb") as file:
+ content = file.read()
+ # Decode content based on MIME type or file extension
+ decoded_content = self.decode_content(content, file_extension)
+ chunks = self.chunk_content(decoded_content)
+ return chunks
+
+ except Exception as e:
+ print(colored(f"Error reading file: {e}", "red"))
+ return []
+
+ def decode_content(self, content: bytes, file_extension: str) -> str:
+ """
+ Decode the content of the file based on its MIME type or file extension.
+
+ Args:
+ content (bytes): The content of the file.
+ file_extension (str): The file extension of the file.
+
+ Returns:
+ str: The decoded content of the file.
+ """
+ # Add logic to handle different file types based on the extension
+ # For simplicity, this example assumes text files encoded in utf-8
+ try:
+ return content.decode("utf-8")
+ except UnicodeDecodeError as e:
+ print(
+ colored(
+ f"Could not decode file with extension {file_extension}: {e}",
+ "yellow",
+ )
+ )
+ return ""
+
+ def chunk_content(self, content: str) -> List[str]:
+ """
+ Split the content into chunks of size `chunk_size`.
+
+ Args:
+ content (str): The content to chunk.
+
+ Returns:
+ List[str]: The list of chunks.
+ """
+ return [
+ content[i : i + self.chunk_size]
+ for i in range(0, len(content), self.chunk_size)
+ ]
+
+ def __str__(self):
+ return f"OmniChunker(chunk_size={self.chunk_size}, beautify={self.beautify})"
+
+ def metrics(self):
+ return {
+ "chunk_size": self.chunk_size,
+ "beautify": self.beautify,
+ }
+
+ def print_dashboard(self):
+ print(
+ colored(
+ f"""
+ Omni Chunker
+ ------------
+ {self.metrics()}
+ """,
+ "cyan",
+ )
+ )
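+
+
+# Example (illustrative):
+# chunker = OmniChunker(chunk_size=500)
+# chunker.print_dashboard()  # prints the chunk_size/beautify settings in cyan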
diff --git a/swarms/chunkers/pdf.py b/swarms/chunkers/pdf.py
index 206c74f3..710134a0 100644
--- a/swarms/chunkers/pdf.py
+++ b/swarms/chunkers/pdf.py
@@ -10,3 +10,10 @@ class PdfChunker(BaseChunker):
ChunkSeparator("? "),
ChunkSeparator(" "),
]
+
+
+# # Example (pass extracted text, not a file path)
+# pdf_text = "...text extracted from swarmdeck.pdf..."
+# chunker = PdfChunker()
+# chunks = chunker.chunk(pdf_text)
+# print(chunks)
diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py
index dd21ba80..26c06066 100644
--- a/swarms/models/__init__.py
+++ b/swarms/models/__init__.py
@@ -16,6 +16,7 @@ from swarms.models.kosmos_two import Kosmos
from swarms.models.vilt import Vilt
from swarms.models.nougat import Nougat
from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA
+
# from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3
diff --git a/swarms/models/anthropic.py b/swarms/models/anthropic.py
index 9914fce9..cc3931bb 100644
--- a/swarms/models/anthropic.py
+++ b/swarms/models/anthropic.py
@@ -44,7 +44,7 @@ class Anthropic:
top_p=None,
streaming=False,
default_request_timeout=None,
- api_key: str = None
+ api_key: str = None,
):
self.model = model
self.max_tokens_to_sample = max_tokens_to_sample
diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py
index 2ac5d403..899564fc 100644
--- a/swarms/models/dalle3.py
+++ b/swarms/models/dalle3.py
@@ -129,7 +129,7 @@ class Dalle3:
)
)
raise error
-
+
def create_variations(self, img: str):
"""
Create variations of an image using the Dalle3 API
@@ -151,14 +151,11 @@ class Dalle3:
>>> img = dalle3.create_variations(img)
>>> print(img)
-
+
"""
try:
-
response = self.client.images.create_variation(
- img = open(img, "rb"),
- n=self.n,
- size=self.size
+            image=open(img, "rb"), n=self.n, size=self.size
)
img = response.data[0].url
@@ -172,4 +169,4 @@ class Dalle3:
)
print(colored(f"Error running Dalle3: {error.http_status}", "red"))
print(colored(f"Error running Dalle3: {error.error}", "red"))
- raise error
\ No newline at end of file
+ raise error
diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py
index 0c5bf2c7..f11bf3df 100644
--- a/swarms/models/huggingface.py
+++ b/swarms/models/huggingface.py
@@ -74,7 +74,9 @@ class HuggingfaceLLM:
bnb_config = BitsAndBytesConfig(**quantization_config)
try:
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, *args, **kwargs)
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.model_id, *args, **kwargs
+ )
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, quantization_config=bnb_config, *args, **kwargs
)
@@ -162,7 +164,12 @@ class HuggingfaceLLM:
del inputs
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e:
- print(colored(f"HuggingfaceLLM could not generate text because of error: {e}, try optimizing your arguments", "red"))
+ print(
+ colored(
+ f"HuggingfaceLLM could not generate text because of error: {e}, try optimizing your arguments",
+ "red",
+ )
+ )
raise
async def run_async(self, task: str, *args, **kwargs) -> str:
diff --git a/swarms/models/openai_assistant.py b/swarms/models/openai_assistant.py
new file mode 100644
index 00000000..6d0c518f
--- /dev/null
+++ b/swarms/models/openai_assistant.py
@@ -0,0 +1,74 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+from openai import OpenAI
+
+
+@dataclass
+class OpenAIAssistant:
+ name: str = "OpenAI Assistant"
+ instructions: str = None
+ tools: List[Dict] = None
+ model: str = None
+ openai_api_key: str = None
+ temperature: float = 0.5
+ max_tokens: int = 100
+ stop: List[str] = None
+ echo: bool = False
+ stream: bool = False
+ log: bool = False
+ presence: bool = False
+ dashboard: bool = False
+ debug: bool = False
+ max_loops: int = 5
+ stopping_condition: Optional[str] = None
+ loop_interval: int = 1
+ retry_attempts: int = 3
+ retry_interval: int = 1
+ interactive: bool = False
+ dynamic_temperature: bool = False
+ state: Dict = None
+ response_filters: List = None
+ response_filter: Dict = None
+ response_filter_name: str = None
+ response_filter_value: str = None
+ response_filter_type: str = None
+ response_filter_action: str = None
+ response_filter_action_value: str = None
+ response_filter_action_type: str = None
+ response_filter_action_name: str = None
+    # OpenAI SDK client; created per instance so the API key is read at runtime.
+    client: OpenAI = field(default_factory=OpenAI)
+    role: str = "user"
+
+    def create_assistant(self):
+        assistant = self.client.beta.assistants.create(
+ name=self.name,
+ instructions=self.instructions,
+ tools=self.tools,
+ model=self.model,
+ )
+ return assistant
+
+ def create_thread(self):
+ thread = self.client.beta.threads.create()
+ return thread
+
+ def add_message_to_thread(self, thread_id: str, message: str):
+        message = self.client.beta.threads.messages.create(
+            thread_id=thread_id, role=self.role, content=message
+ )
+ return message
+
+    def run(self, task: str):
+        thread = self.create_thread()
+        self.add_message_to_thread(thread.id, task)
+        run = self.client.beta.threads.runs.create(
+            thread_id=thread.id,
+            assistant_id=self.create_assistant().id,
+            instructions=self.instructions,
+        )
+
+ out = self.client.beta.threads.runs.retrieve(
+ thread_id=run.thread_id, run_id=run.id
+ )
+
+ return out
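+
+
+# Example (illustrative; assumes OPENAI_API_KEY is set and the Assistants API is available):
+# assistant = OpenAIAssistant(
+#     name="Research Assistant",
+#     instructions="Answer research questions concisely.",
+#     model="gpt-4",
+# )
+# print(assistant.run("What is a mixture-of-experts model?"))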
diff --git a/swarms/models/openai_tokenizer.py b/swarms/models/openai_tokenizer.py
new file mode 100644
index 00000000..b4e375cc
--- /dev/null
+++ b/swarms/models/openai_tokenizer.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import tiktoken
+from attr import Factory, define, field
+
+
+@define(frozen=True)
+class BaseTokenizer(ABC):
+ DEFAULT_STOP_SEQUENCES = ["Observation:"]
+
+ stop_sequences: list[str] = field(
+ default=Factory(lambda: BaseTokenizer.DEFAULT_STOP_SEQUENCES),
+ kw_only=True,
+ )
+
+ @property
+ @abstractmethod
+ def max_tokens(self) -> int:
+ ...
+
+ def count_tokens_left(self, text: str) -> int:
+ diff = self.max_tokens - self.count_tokens(text)
+
+ if diff > 0:
+ return diff
+ else:
+ return 0
+
+ @abstractmethod
+ def count_tokens(self, text: str) -> int:
+ ...
+
+
+@define(frozen=True)
+class OpenAITokenizer(BaseTokenizer):
+ DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "text-davinci-003"
+ DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo"
+ DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4"
+ DEFAULT_ENCODING = "cl100k_base"
+ DEFAULT_MAX_TOKENS = 2049
+ TOKEN_OFFSET = 8
+
+ MODEL_PREFIXES_TO_MAX_TOKENS = {
+ "gpt-4-32k": 32768,
+ "gpt-4": 8192,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-3.5-turbo": 4096,
+ "gpt-35-turbo-16k": 16384,
+ "gpt-35-turbo": 4096,
+ "text-davinci-003": 4097,
+ "text-davinci-002": 4097,
+ "code-davinci-002": 8001,
+ "text-embedding-ada-002": 8191,
+ "text-embedding-ada-001": 2046,
+ }
+
+ EMBEDDING_MODELS = ["text-embedding-ada-002", "text-embedding-ada-001"]
+
+ model: str = field(kw_only=True)
+
+ @property
+ def encoding(self) -> tiktoken.Encoding:
+ try:
+ return tiktoken.encoding_for_model(self.model)
+ except KeyError:
+ return tiktoken.get_encoding(self.DEFAULT_ENCODING)
+
+ @property
+ def max_tokens(self) -> int:
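+        # Prefix lookup relies on dict insertion order: longer prefixes like
+        # "gpt-4-32k" precede "gpt-4", so "gpt-4-32k-0613" -> 32768 - TOKEN_OFFSET.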
+        tokens = next(
+            (
+                v
+                for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items()
+                if self.model.startswith(k)
+            ),
+            None,
+        )
+        offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET
+
+        return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset
+
+ def count_tokens(
+ self, text: str | list, model: Optional[str] = None
+ ) -> int:
+ """
+        Handles the special case of ChatML. Implementation adapted from the official OpenAI notebook:
+ https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ """
+ if isinstance(text, list):
+ model = model if model else self.model
+
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+                logging.warning("Model not found. Using cl100k_base encoding.")
+                encoding = tiktoken.get_encoding("cl100k_base")
+
+ if model in {
+ "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k-0613",
+ "gpt-4-0314",
+ "gpt-4-32k-0314",
+ "gpt-4-0613",
+ "gpt-4-32k-0613",
+ }:
+ tokens_per_message = 3
+ tokens_per_name = 1
+ elif model == "gpt-3.5-turbo-0301":
+ # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ tokens_per_message = 4
+ # if there's a name, the role is omitted
+ tokens_per_name = -1
+ elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
+ logging.info(
+ "gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
+ )
+ return self.count_tokens(text, model="gpt-3.5-turbo-0613")
+ elif "gpt-4" in model:
+ logging.info(
+ "gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
+ )
+ return self.count_tokens(text, model="gpt-4-0613")
+ else:
+ raise NotImplementedError(
+ f"""token_count() is not implemented for model {model}.
+ See https://github.com/openai/openai-python/blob/main/chatml.md for
+ information on how messages are converted to tokens."""
+ )
+
+ num_tokens = 0
+
+ for message in text:
+ num_tokens += tokens_per_message
+ for key, value in message.items():
+ num_tokens += len(encoding.encode(value))
+ if key == "name":
+ num_tokens += tokens_per_name
+
+ # every reply is primed with <|start|>assistant<|message|>
+ num_tokens += 3
+
+ return num_tokens
+ else:
+ return len(
+ self.encoding.encode(
+ text, allowed_special=set(self.stop_sequences)
+ )
+ )
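+
+
+# Example (illustrative):
+# tokenizer = OpenAITokenizer(model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)
+# print(tokenizer.count_tokens("Hello, world!"))  # plain-text token count
+# print(tokenizer.count_tokens([{"role": "user", "content": "Hello!"}]))  # ChatML-aware count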
diff --git a/swarms/structs/flow.py b/swarms/structs/flow.py
index 4e21c3df..9ff021f4 100644
--- a/swarms/structs/flow.py
+++ b/swarms/structs/flow.py
@@ -116,6 +116,7 @@ class Flow:
dynamic_temperature: bool = False,
saved_state_path: Optional[str] = "flow_state.json",
autosave: bool = False,
+ context_length: int = 8192,
**kwargs: Any,
):
self.llm = llm
@@ -188,6 +189,26 @@ class Flow:
return "\n".join(params_str_list)
+    def truncate_history(self):
+        """
+        Keep only the last `context_length` entries of the latest conversation
+        so the history fits within the model's context window.
+        """
+        self.memory[-1] = self.memory[-1][-self.context_length :]
+
+ def add_task_to_memory(self, task: str):
+ """Add the task to the memory"""
+ self.memory.append([f"Human: {task}"])
+
+ def add_message_to_memory(self, message: str):
+ """Add the message to the memory"""
+ self.memory[-1].append(message)
+
+ def add_message_to_memory_and_truncate(self, message: str):
+ """Add the message to the memory and truncate"""
+ self.memory[-1].append(message)
+ self.truncate_history()
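+
+    # Example (illustrative; assumes `llm` is an initialized model):
+    #   flow = Flow(llm=llm, context_length=8192)
+    #   flow.add_task_to_memory("Summarize this document")
+    #   flow.add_message_to_memory_and_truncate("AI: Here is the summary...")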
+
def print_dashboard(self, task: str):
"""Print dashboard"""
model_config = self.get_llm_init_params()
diff --git a/swarms/tools/interpreter_tool.py b/swarms/tools/interpreter_tool.py
deleted file mode 100644
index 22758de6..00000000
--- a/swarms/tools/interpreter_tool.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import os
-import interpreter
-
-
-def compile(task: str):
- """
- Open Interpreter lets LLMs run code (Python, Javascript, Shell, and more) locally. You can chat with Open Interpreter through a ChatGPT-like interface in your terminal by running $ interpreter after installing.
-
- This provides a natural-language interface to your computer's general-purpose capabilities:
-
- Create and edit photos, videos, PDFs, etc.
- Control a Chrome browser to perform research
- Plot, clean, and analyze large datasets
- ...etc.
- ⚠️ Note: You'll be asked to approve code before it's run.
- """
-
- task = interpreter.chat(task, return_messages=True)
- interpreter.chat()
- interpreter.reset(task)
-
- os.environ["INTERPRETER_CLI_AUTO_RUN"] = True
- os.environ["INTERPRETER_CLI_FAST_MODE"] = True
- os.environ["INTERPRETER_CLI_DEBUG"] = True
diff --git a/swarms/workers/__init__.py b/swarms/workers/__init__.py
index 2a7cc4f1..9dabe94d 100644
--- a/swarms/workers/__init__.py
+++ b/swarms/workers/__init__.py
@@ -1,2 +1,2 @@
-from swarms.workers.worker import Worker
+# from swarms.workers.worker import Worker
from swarms.workers.base import AbstractWorker
diff --git a/tests/chunkers/basechunker.py b/tests/chunkers/basechunker.py
index f70705bc..4fd92da1 100644
--- a/tests/chunkers/basechunker.py
+++ b/tests/chunkers/basechunker.py
@@ -3,7 +3,7 @@ from swarms.chunkers.base import (
BaseChunker,
TextArtifact,
ChunkSeparator,
- OpenAiTokenizer,
+ OpenAITokenizer,
) # adjust the import paths accordingly
@@ -21,7 +21,7 @@ def test_default_separators():
def test_default_tokenizer():
chunker = BaseChunker()
- assert isinstance(chunker.tokenizer, OpenAiTokenizer)
+ assert isinstance(chunker.tokenizer, OpenAITokenizer)
# 2. Test Basic Chunking
diff --git a/tests/models/dalle3.py b/tests/models/dalle3.py
index 42b851b7..f9a2f8cf 100644
--- a/tests/models/dalle3.py
+++ b/tests/models/dalle3.py
@@ -23,8 +23,12 @@ def dalle3(mock_openai_client):
def test_dalle3_call_success(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
- expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
- mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)])
+ expected_img_url = (
+ "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+ )
+ mock_openai_client.images.generate.return_value = Mock(
+ data=[Mock(url=expected_img_url)]
+ )
# Act
img_url = dalle3(task)
@@ -40,7 +44,9 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys):
expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError
- mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
+ mock_openai_client.images.generate.side_effect = OpenAIError(
+ expected_error_message, http_status=500, error="Internal Server Error"
+ )
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
@@ -57,8 +63,12 @@ def test_dalle3_call_failure(dalle3, mock_openai_client, capsys):
def test_dalle3_create_variations_success(dalle3, mock_openai_client):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
- expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
- mock_openai_client.images.create_variation.return_value = Mock(data=[Mock(url=expected_variation_url)])
+ expected_variation_url = (
+ "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+ )
+ mock_openai_client.images.create_variation.return_value = Mock(
+ data=[Mock(url=expected_variation_url)]
+ )
# Act
variation_img_url = dalle3.create_variations(img_url)
@@ -78,7 +88,9 @@ def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys):
expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError
- mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
+ mock_openai_client.images.create_variation.side_effect = OpenAIError(
+ expected_error_message, http_status=500, error="Internal Server Error"
+ )
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
@@ -86,7 +98,7 @@ def test_dalle3_create_variations_failure(dalle3, mock_openai_client, capsys):
assert str(excinfo.value) == expected_error_message
mock_openai_client.images.create_variation.assert_called_once()
-
+
# Ensure the error message is printed in red
captured = capsys.readouterr()
assert colored(expected_error_message, "red") in captured.out
@@ -142,8 +154,12 @@ def test_dalle3_convert_to_bytesio():
def test_dalle3_call_multiple_times(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
- expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
- mock_openai_client.images.generate.return_value = Mock(data=[Mock(url=expected_img_url)])
+ expected_img_url = (
+ "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+ )
+ mock_openai_client.images.generate.return_value = Mock(
+ data=[Mock(url=expected_img_url)]
+ )
# Act
img_url1 = dalle3(task)
@@ -159,7 +175,9 @@ def test_dalle3_call_with_large_input(dalle3, mock_openai_client):
# Arrange
task = "A" * 2048 # Input longer than API's limit
expected_error_message = "Error running Dalle3: API Error"
- mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
+ mock_openai_client.images.generate.side_effect = OpenAIError(
+ expected_error_message, http_status=500, error="Internal Server Error"
+ )
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
@@ -204,7 +222,9 @@ def test_dalle3_convert_to_bytesio_invalid_format(dalle3):
def test_dalle3_call_with_retry(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
- expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+ expected_img_url = (
+ "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+ )
# Simulate a retry scenario
mock_openai_client.images.generate.side_effect = [
@@ -223,7 +243,9 @@ def test_dalle3_call_with_retry(dalle3, mock_openai_client):
def test_dalle3_create_variations_with_retry(dalle3, mock_openai_client):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
- expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+ expected_variation_url = (
+ "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+ )
# Simulate a retry scenario
mock_openai_client.images.create_variation.side_effect = [
@@ -245,7 +267,9 @@ def test_dalle3_call_exception_logging(dalle3, mock_openai_client, capsys):
expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError
- mock_openai_client.images.generate.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
+ mock_openai_client.images.generate.side_effect = OpenAIError(
+ expected_error_message, http_status=500, error="Internal Server Error"
+ )
# Act
with pytest.raises(OpenAIError):
@@ -262,7 +286,9 @@ def test_dalle3_create_variations_exception_logging(dalle3, mock_openai_client,
expected_error_message = "Error running Dalle3: API Error"
# Mocking OpenAIError
- mock_openai_client.images.create_variation.side_effect = OpenAIError(expected_error_message, http_status=500, error="Internal Server Error")
+ mock_openai_client.images.create_variation.side_effect = OpenAIError(
+ expected_error_message, http_status=500, error="Internal Server Error"
+ )
# Act
with pytest.raises(OpenAIError):
@@ -313,7 +339,9 @@ def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client)
task = "A painting of a dog"
# Simulate max retries exceeded
- mock_openai_client.images.generate.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error")
+ mock_openai_client.images.generate.side_effect = OpenAIError(
+ "Temporary error", http_status=500, error="Internal Server Error"
+ )
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
@@ -322,12 +350,16 @@ def test_dalle3_call_with_retry_max_retries_exceeded(dalle3, mock_openai_client)
assert "Retry limit exceeded" in str(excinfo.value)
-def test_dalle3_create_variations_with_retry_max_retries_exceeded(dalle3, mock_openai_client):
+def test_dalle3_create_variations_with_retry_max_retries_exceeded(
+ dalle3, mock_openai_client
+):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
# Simulate max retries exceeded
- mock_openai_client.images.create_variation.side_effect = OpenAIError("Temporary error", http_status=500, error="Internal Server Error")
+ mock_openai_client.images.create_variation.side_effect = OpenAIError(
+ "Temporary error", http_status=500, error="Internal Server Error"
+ )
# Act and assert
with pytest.raises(OpenAIError) as excinfo:
@@ -339,7 +371,9 @@ def test_dalle3_create_variations_with_retry_max_retries_exceeded(dalle3, mock_o
def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
# Arrange
task = "A painting of a dog"
- expected_img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+ expected_img_url = (
+ "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
+ )
# Simulate success after a retry
mock_openai_client.images.generate.side_effect = [
@@ -358,7 +392,9 @@ def test_dalle3_call_retry_with_success(dalle3, mock_openai_client):
def test_dalle3_create_variations_retry_with_success(dalle3, mock_openai_client):
# Arrange
img_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
- expected_variation_url = "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+ expected_variation_url = (
+ "https://cdn.openai.com/dall-e/encoded/feats/feats_02ABCDE.png"
+ )
# Simulate success after a retry
mock_openai_client.images.create_variation.side_effect = [
diff --git a/tests/models/gpt4v.py b/tests/models/gpt4v.py
index 40ccc7f5..23e97d03 100644
--- a/tests/models/gpt4v.py
+++ b/tests/models/gpt4v.py
@@ -12,19 +12,22 @@ load_dotenv
api_key = os.getenv("OPENAI_API_KEY")
+
# Mock the OpenAI client
@pytest.fixture
def mock_openai_client():
return Mock()
+
@pytest.fixture
def gpt4vision(mock_openai_client):
return GPT4Vision(client=mock_openai_client)
+
def test_gpt4vision_default_values():
# Arrange and Act
gpt4vision = GPT4Vision()
-
+
# Assert
assert gpt4vision.max_retries == 3
assert gpt4vision.model == "gpt-4-vision-preview"
@@ -34,59 +37,68 @@ def test_gpt4vision_default_values():
assert gpt4vision.quality == "low"
assert gpt4vision.max_tokens == 200
+
def test_gpt4vision_api_key_from_env_variable():
# Arrange
- api_key = os.environ["OPENAI_API_KEY"]
-
+ api_key = os.environ["OPENAI_API_KEY"]
+
# Act
gpt4vision = GPT4Vision()
-
+
# Assert
assert gpt4vision.api_key == api_key
+
def test_gpt4vision_set_api_key():
# Arrange
gpt4vision = GPT4Vision(api_key=api_key)
-
+
# Assert
assert gpt4vision.api_key == api_key
+
def test_gpt4vision_invalid_max_retries():
# Arrange and Act
with pytest.raises(ValueError):
GPT4Vision(max_retries=-1)
+
def test_gpt4vision_invalid_backoff_factor():
# Arrange and Act
with pytest.raises(ValueError):
GPT4Vision(backoff_factor=-1)
+
def test_gpt4vision_invalid_timeout_seconds():
# Arrange and Act
with pytest.raises(ValueError):
GPT4Vision(timeout_seconds=-1)
+
def test_gpt4vision_invalid_max_tokens():
# Arrange and Act
with pytest.raises(ValueError):
GPT4Vision(max_tokens=-1)
+
def test_gpt4vision_logger_initialized():
# Arrange
gpt4vision = GPT4Vision()
-
+
# Assert
assert isinstance(gpt4vision.logger, logging.Logger)
+
def test_gpt4vision_process_img_nonexistent_file():
# Arrange
gpt4vision = GPT4Vision()
img_path = "nonexistent_image.jpg"
-
+
# Act and Assert
with pytest.raises(FileNotFoundError):
gpt4vision.process_img(img_path)
+
def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
@@ -96,7 +108,10 @@ def test_gpt4vision_call_single_task_single_image_no_openai_client(gpt4vision):
with pytest.raises(AttributeError):
gpt4vision(img_url, [task])
-def test_gpt4vision_call_single_task_single_image_empty_response(gpt4vision, mock_openai_client):
+
+def test_gpt4vision_call_single_task_single_image_empty_response(
+ gpt4vision, mock_openai_client
+):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
@@ -110,7 +125,10 @@ def test_gpt4vision_call_single_task_single_image_empty_response(gpt4vision, moc
assert response.answer == ""
mock_openai_client.chat.completions.create.assert_called_once()
-def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(gpt4vision, mock_openai_client):
+
+def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(
+ gpt4vision, mock_openai_client
+):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
tasks = ["Describe this image.", "What's in this picture?"]
@@ -122,20 +140,30 @@ def test_gpt4vision_call_multiple_tasks_single_image_empty_responses(gpt4vision,
# Assert
assert all(response.answer == "" for response in responses)
- assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
+ assert (
+ mock_openai_client.chat.completions.create.call_count == 1
+ ) # Should be called only once
+
-def test_gpt4vision_call_single_task_single_image_timeout(gpt4vision, mock_openai_client):
+def test_gpt4vision_call_single_task_single_image_timeout(
+ gpt4vision, mock_openai_client
+):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
- mock_openai_client.chat.completions.create.side_effect = Timeout("Request timed out")
+ mock_openai_client.chat.completions.create.side_effect = Timeout(
+ "Request timed out"
+ )
# Act and Assert
with pytest.raises(Timeout):
gpt4vision(img_url, [task])
-def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_openai_client):
+
+def test_gpt4vision_call_retry_with_success_after_timeout(
+ gpt4vision, mock_openai_client
+):
# Arrange
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
@@ -143,7 +171,11 @@ def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_opena
# Simulate success after a timeout and retry
mock_openai_client.chat.completions.create.side_effect = [
Timeout("Request timed out"),
- {"choices": [{"message": {"content": {"text": "A description of the image."}}}],}
+ {
+ "choices": [
+ {"message": {"content": {"text": "A description of the image."}}}
+ ],
+ },
]
# Act
@@ -151,7 +183,9 @@ def test_gpt4vision_call_retry_with_success_after_timeout(gpt4vision, mock_opena
# Assert
assert response.answer == "A description of the image."
- assert mock_openai_client.chat.completions.create.call_count == 2 # Should be called twice
+ assert (
+ mock_openai_client.chat.completions.create.call_count == 2
+ ) # Should be called twice
def test_gpt4vision_process_img():
@@ -173,7 +207,9 @@ def test_gpt4vision_call_single_task_single_image(gpt4vision, mock_openai_client
expected_response = GPT4VisionResponse(answer="A description of the image.")
- mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer
+ mock_openai_client.chat.completions.create.return_value.choices[
+ 0
+ ].text = expected_response.answer
# Act
response = gpt4vision(img_url, [task])
@@ -190,7 +226,9 @@ def test_gpt4vision_call_single_task_multiple_images(gpt4vision, mock_openai_cli
expected_response = GPT4VisionResponse(answer="Descriptions of the images.")
- mock_openai_client.chat.completions.create.return_value.choices[0].text = expected_response.answer
+ mock_openai_client.chat.completions.create.return_value.choices[
+ 0
+ ].text = expected_response.answer
# Act
response = gpt4vision(img_urls, [task])
@@ -213,57 +251,76 @@ def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_cli
def create_mock_response(response):
return {"choices": [{"message": {"content": {"text": response.answer}}}]}
- mock_openai_client.chat.completions.create.side_effect = [create_mock_response(response) for response in expected_responses]
+ mock_openai_client.chat.completions.create.side_effect = [
+ create_mock_response(response) for response in expected_responses
+ ]
# Act
responses = gpt4vision(img_url, tasks)
# Assert
assert responses == expected_responses
- assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
- def test_gpt4vision_call_multiple_tasks_single_image(gpt4vision, mock_openai_client):
- # Arrange
- img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
- tasks = ["Describe this image.", "What's in this picture?"]
-
- expected_responses = [
- GPT4VisionResponse(answer="A description of the image."),
- GPT4VisionResponse(answer="It contains various objects."),
+ assert (
+ mock_openai_client.chat.completions.create.call_count == 1
+ ) # Should be called only once
+
+ def test_gpt4vision_call_multiple_tasks_single_image(
+ gpt4vision, mock_openai_client
+ ):
+ # Arrange
+ img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
+ tasks = ["Describe this image.", "What's in this picture?"]
+
+ expected_responses = [
+ GPT4VisionResponse(answer="A description of the image."),
+ GPT4VisionResponse(answer="It contains various objects."),
+ ]
+
+ mock_openai_client.chat.completions.create.side_effect = [
+ {
+ "choices": [
+ {"message": {"content": {"text": expected_responses[i].answer}}}
]
+ }
+ for i in range(len(expected_responses))
+ ]
- mock_openai_client.chat.completions.create.side_effect = [
- {"choices": [{"message": {"content": {"text": expected_responses[i].answer}}}] } for i in range(len(expected_responses))
- ]
+ # Act
+ responses = gpt4vision(img_url, tasks)
- # Act
- responses = gpt4vision(img_url, tasks)
-
- # Assert
- assert responses == expected_responses
- assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
+ # Assert
+ assert responses == expected_responses
+ assert (
+ mock_openai_client.chat.completions.create.call_count == 1
+ ) # Should be called only once
def test_gpt4vision_call_multiple_tasks_multiple_images(gpt4vision, mock_openai_client):
# Arrange
- img_urls = ["https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", "https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"]
+ img_urls = [
+ "https://images.unsplash.com/photo-1694734479857-626882b6db37?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
+ "https://images.unsplash.com/photo-1694734479898-6ac4633158ac?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
+ ]
tasks = ["Describe these images.", "What's in these pictures?"]
expected_responses = [
GPT4VisionResponse(answer="Descriptions of the images."),
- GPT4VisionResponse(answer="They contain various objects.")
+ GPT4VisionResponse(answer="They contain various objects."),
]
mock_openai_client.chat.completions.create.side_effect = [
- {"choices": [{"message": {"content": {"text": response.answer}}}] } for response in expected_responses
+ {"choices": [{"message": {"content": {"text": response.answer}}}]}
+ for response in expected_responses
]
# Act
responses = gpt4vision(img_urls, tasks)
-
# Assert
assert responses == expected_responses
- assert mock_openai_client.chat.completions.create.call_count == 1 # Should be called only once
+ assert (
+ mock_openai_client.chat.completions.create.call_count == 1
+ ) # Should be called only once
def test_gpt4vision_call_http_error(gpt4vision, mock_openai_client):
@@ -283,7 +340,9 @@ def test_gpt4vision_call_request_error(gpt4vision, mock_openai_client):
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
- mock_openai_client.chat.completions.create.side_effect = RequestException("Request Error")
+ mock_openai_client.chat.completions.create.side_effect = RequestException(
+ "Request Error"
+ )
# Act and Assert
with pytest.raises(RequestException):
@@ -295,7 +354,9 @@ def test_gpt4vision_call_connection_error(gpt4vision, mock_openai_client):
img_url = "https://images.unsplash.com/photo-1694734479942-8cc7f4660578?q=80&w=1287&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
task = "Describe this image."
- mock_openai_client.chat.completions.create.side_effect = ConnectionError("Connection Error")
+ mock_openai_client.chat.completions.create.side_effect = ConnectionError(
+ "Connection Error"
+ )
# Act and Assert
with pytest.raises(ConnectionError):
@@ -310,7 +371,9 @@ def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
# Simulate success after a retry
mock_openai_client.chat.completions.create.side_effect = [
RequestException("Temporary error"),
- {"choices": [{"text": "A description of the image."}]} # fixed dictionary syntax
+ {
+ "choices": [{"text": "A description of the image."}]
+ }, # fixed dictionary syntax
]
# Act
@@ -318,4 +381,6 @@ def test_gpt4vision_call_retry_with_success(gpt4vision, mock_openai_client):
# Assert
assert response.answer == "A description of the image."
- assert mock_openai_client.chat.completions.create.call_count == 2 # Should be called twice
+ assert (
+ mock_openai_client.chat.completions.create.call_count == 2
+ ) # Should be called twice