pull/286/head
Kye 1 year ago
parent 44009e3dc5
commit 1e8137249a

@ -1,279 +0,0 @@
import os
import asyncio
import base64
import concurrent.futures
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple
import openai
import requests
from cachetools import TTLCache
from dotenv import load_dotenv
from openai import OpenAI
from ratelimit import limits, sleep_and_retry
from termcolor import colored
# ENV
load_dotenv()
@dataclass
class GPT4VisionResponse:
"""A response structure for GPT-4"""
answer: str
@dataclass
class GPT4Vision:
"""
GPT4Vision model class
Attributes:
-----------
max_retries: int
The maximum number of retries to make to the API
backoff_factor: float
The backoff factor to use for exponential backoff
timeout_seconds: int
The timeout in seconds for the API request
api_key: str
The API key to use for the API request
quality: str
The quality of the image to generate
max_tokens: int
The maximum number of tokens to use for the API request
Methods:
--------
process_img(self, img_path: str) -> str:
Processes the image to be used for the API request
run(self, img: Union[str, List[str]], tasks: List[str]) -> GPT4VisionResponse:
Makes a call to the GPT-4 Vision API and returns the image url
Example:
>>> gpt4vision = GPT4Vision()
>>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
>>> tasks = ["A painting of a dog"]
>>> answer = gpt4vision(img, tasks)
>>> print(answer)
"""
max_retries: int = 3
model: str = "gpt-4-vision-preview"
backoff_factor: float = 2.0
timeout_seconds: int = 10
openai_api_key: Optional[str] = None or os.getenv(
"OPENAI_API_KEY"
)
# 'Low' or 'High' for respesctively fast or high quality, but high more token usage
quality: str = "low"
# Max tokens to use for the API request, the maximum might be 3,000 but we don't know
max_tokens: int = 200
client = OpenAI(
api_key=openai_api_key,
)
dashboard: bool = True
call_limit: int = 1
period_seconds: int = 60
# Cache for storing API Responses
cache = TTLCache(maxsize=100, ttl=600) # Cache for 10 minutes
class Config:
"""Config class for the GPT4Vision model"""
arbitary_types_allowed = True
def process_img(self, img: str) -> str:
"""Processes the image to be used for the API request"""
with open(img, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
@sleep_and_retry
@limits(
calls=call_limit, period=period_seconds
) # Rate limit of 10 calls per minute
def run(self, task: str, img: str):
"""
Run the GPT-4 Vision model
Task: str
The task to run
Img: str
The image to run the task on
"""
if self.dashboard:
self.print_dashboard()
try:
response = self.client.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": task},
{
"type": "image_url",
"image_url": {
"url": str(img),
},
},
],
}
],
max_tokens=self.max_tokens,
)
out = print(response.choices[0])
# out = self.clean_output(out)
return out
except openai.OpenAIError as e:
# logger.error(f"OpenAI API error: {e}")
return (
f"OpenAI API error: Could not process the image. {e}"
)
except Exception as e:
return (
"Unexpected error occurred while processing the"
f" image. {e}"
)
def clean_output(self, output: str):
# Regex pattern to find the Choice object representation in the output
pattern = r"Choice\(.*?\(content=\"(.*?)\".*?\)\)"
match = re.search(pattern, output, re.DOTALL)
if match:
# Extract the content from the matched pattern
content = match.group(1)
# Replace escaped quotes to get the clean content
content = content.replace(r"\"", '"')
print(content)
else:
print("No content found in the output.")
async def arun(self, task: str, img: str):
"""
Arun is an async version of run
Task: str
The task to run
Img: str
The image to run the task on
"""
try:
response = await self.client.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": task},
{
"type": "image_url",
"image_url": {
"url": img,
},
},
],
}
],
max_tokens=self.max_tokens,
)
return print(response.choices[0])
except openai.OpenAIError as e:
# logger.error(f"OpenAI API error: {e}")
return (
f"OpenAI API error: Could not process the image. {e}"
)
except Exception as e:
return (
"Unexpected error occurred while processing the"
f" image. {e}"
)
def run_batch(
self, tasks_images: List[Tuple[str, str]]
) -> List[str]:
"""Process a batch of tasks and images"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.run, task, img)
for task, img in tasks_images
]
results = [future.result() for future in futures]
return results
async def run_batch_async(
self, tasks_images: List[Tuple[str, str]]
) -> List[str]:
"""Process a batch of tasks and images asynchronously"""
loop = asyncio.get_event_loop()
futures = [
loop.run_in_executor(None, self.run, task, img)
for task, img in tasks_images
]
return await asyncio.gather(*futures)
async def run_batch_async_with_retries(
self, tasks_images: List[Tuple[str, str]]
) -> List[str]:
"""Process a batch of tasks and images asynchronously with retries"""
loop = asyncio.get_event_loop()
futures = [
loop.run_in_executor(
None, self.run_with_retries, task, img
)
for task, img in tasks_images
]
return await asyncio.gather(*futures)
def print_dashboard(self):
dashboard = print(
colored(
f"""
GPT4Vision Dashboard
-------------------
Max Retries: {self.max_retries}
Model: {self.model}
Backoff Factor: {self.backoff_factor}
Timeout Seconds: {self.timeout_seconds}
Image Quality: {self.quality}
Max Tokens: {self.max_tokens}
""",
"green",
)
)
return dashboard
def health_check(self):
"""Health check for the GPT4Vision model"""
try:
response = requests.get(
"https://api.openai.com/v1/engines"
)
return response.status_code == 200
except requests.RequestException as error:
print(f"Health check failed: {error}")
return False
def sanitize_input(self, text: str) -> str:
"""
Sanitize input to prevent injection attacks.
Parameters:
text: str - The input text to be sanitized.
Returns:
The sanitized text.
"""
# Example of simple sanitization, this should be expanded based on the context and usage
sanitized_text = re.sub(r"[^\w\s]", "", text)
return sanitized_text
Loading…
Cancel
Save