parent
44009e3dc5
commit
1e8137249a
@ -1,279 +0,0 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import concurrent.futures
|
|
||||||
import re
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import requests
|
|
||||||
from cachetools import TTLCache
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from openai import OpenAI
|
|
||||||
from ratelimit import limits, sleep_and_retry
|
|
||||||
from termcolor import colored
|
|
||||||
|
|
||||||
# ENV: load variables (e.g. OPENAI_API_KEY) from a local .env file into the
# process environment at import time, before any class-level os.getenv() runs.
load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class GPT4VisionResponse:
    """A response structure for GPT-4 Vision: wraps the answer text."""

    # Plain-text answer returned by the GPT-4 Vision API.
    answer: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class GPT4Vision:
    """
    GPT4Vision model class — a thin wrapper around the OpenAI
    GPT-4 Vision chat-completions endpoint.

    Attributes:
    -----------
    max_retries: int
        The maximum number of retries to make to the API
    backoff_factor: float
        The backoff factor to use for exponential backoff
    timeout_seconds: int
        The timeout in seconds for the API request
    openai_api_key: str
        The API key to use for the API request
    quality: str
        The quality of the image to generate
    max_tokens: int
        The maximum number of tokens to use for the API request

    Methods:
    --------
    process_img(img: str) -> str:
        Base64-encodes a local image file for use in an API request
    run(task: str, img: str) -> str:
        Makes a call to the GPT-4 Vision API and returns the answer text

    Example:
    >>> gpt4vision = GPT4Vision()
    >>> img = "https://cdn.openai.com/dall-e/encoded/feats/feats_01J9J5ZKJZJY9.png"
    >>> task = "A painting of a dog"
    >>> answer = gpt4vision.run(task, img)
    >>> print(answer)
    """

    max_retries: int = 3
    model: str = "gpt-4-vision-preview"
    backoff_factor: float = 2.0
    timeout_seconds: int = 10
    # BUG FIX: `None or os.getenv(...)` always evaluated to the getenv result;
    # the redundant `None or` has been dropped.
    openai_api_key: Optional[str] = os.getenv("OPENAI_API_KEY")
    # 'low' or 'high' for respectively fast or high quality; high uses more tokens.
    quality: str = "low"
    # Max tokens to use for the API request; the server-side maximum is unknown.
    max_tokens: int = 200
    # No annotation -> NOT a dataclass field: one client shared by every
    # instance, created once at class-definition time.
    client = OpenAI(
        api_key=openai_api_key,
    )
    dashboard: bool = True
    call_limit: int = 1
    period_seconds: int = 60

    # Cache for storing API responses (entries kept for 10 minutes).
    # NOTE(review): nothing in this class reads or writes the cache yet.
    cache = TTLCache(maxsize=100, ttl=600)

    class Config:
        """Config class for the GPT4Vision model"""

        # BUG FIX: was misspelled `arbitary_types_allowed`, which pydantic-style
        # consumers would silently ignore.
        arbitrary_types_allowed = True

    def process_img(self, img: str) -> str:
        """Return the base64-encoded contents of the local image file *img*."""
        with open(img, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @sleep_and_retry
    @limits(
        calls=call_limit, period=period_seconds
    )  # rate limit: call_limit calls per period_seconds (1 call/minute by default)
    def run(self, task: str, img: str):
        """
        Run the GPT-4 Vision model on a single task/image pair.

        Task: str
            The task (prompt text) to run
        Img: str
            The image URL to run the task on

        Returns the answer text on success, or an error string on failure.
        """
        if self.dashboard:
            self.print_dashboard()
        try:
            response = self.client.chat.completions.create(
                model="gpt-4-vision-preview",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": task},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": str(img),
                                },
                            },
                        ],
                    }
                ],
                max_tokens=self.max_tokens,
            )

            # BUG FIX: the original did `out = print(...)`, so the method
            # always returned None; return the message content instead.
            out = response.choices[0].message.content
            print(out)
            return out
        except openai.OpenAIError as e:
            # logger.error(f"OpenAI API error: {e}")
            return (
                f"OpenAI API error: Could not process the image. {e}"
            )
        except Exception as e:
            return (
                "Unexpected error occurred while processing the"
                f" image. {e}"
            )

    def run_with_retries(self, task: str, img: str):
        """
        Run the model, retrying with exponential backoff on failure.

        BUG FIX: this method was referenced by run_batch_async_with_retries
        but never defined, which raised AttributeError at call time.
        """
        import time  # function-scope import; module import block untouched

        delay = 1.0
        last_error: Optional[Exception] = None
        for _ in range(max(1, self.max_retries)):
            try:
                return self.run(task, img)
            except Exception as error:
                last_error = error
                time.sleep(delay)
                delay *= self.backoff_factor
        raise last_error

    def clean_output(self, output: str):
        """
        Extract the message content from a printed Choice repr.

        Prints the cleaned content (original behavior) and — BUG FIX —
        also returns it (None when nothing matched) so callers such as
        the commented-out `out = self.clean_output(out)` in run() work.
        """
        # Regex pattern to find the Choice object representation in the output
        pattern = r"Choice\(.*?\(content=\"(.*?)\".*?\)\)"
        match = re.search(pattern, output, re.DOTALL)

        if match:
            # Un-escape quotes to get the clean content
            content = match.group(1).replace(r"\"", '"')
            print(content)
            return content
        print("No content found in the output.")
        return None

    async def arun(self, task: str, img: str):
        """
        Arun is an async version of run.

        Task: str
            The task to run
        Img: str
            The image to run the task on

        BUG FIX: the original awaited the synchronous
        client.chat.completions.create call, which raises TypeError at
        runtime; the blocking call now runs in the default executor.
        """
        loop = asyncio.get_event_loop()
        try:
            response = await loop.run_in_executor(
                None,
                lambda: self.client.chat.completions.create(
                    model="gpt-4-vision-preview",
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": task},
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": img,
                                    },
                                },
                            ],
                        }
                    ],
                    max_tokens=self.max_tokens,
                ),
            )

            # BUG FIX: return the answer text instead of print()'s None.
            out = response.choices[0].message.content
            print(out)
            return out
        except openai.OpenAIError as e:
            # logger.error(f"OpenAI API error: {e}")
            return (
                f"OpenAI API error: Could not process the image. {e}"
            )
        except Exception as e:
            return (
                "Unexpected error occurred while processing the"
                f" image. {e}"
            )

    def run_batch(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch of (task, image) pairs concurrently with threads."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.run, task, img)
                for task, img in tasks_images
            ]
            return [future.result() for future in futures]

    async def run_batch_async(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch of (task, image) pairs asynchronously."""
        loop = asyncio.get_event_loop()
        futures = [
            loop.run_in_executor(None, self.run, task, img)
            for task, img in tasks_images
        ]
        return await asyncio.gather(*futures)

    async def run_batch_async_with_retries(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch asynchronously, retrying each item on failure."""
        loop = asyncio.get_event_loop()
        futures = [
            loop.run_in_executor(
                None, self.run_with_retries, task, img
            )
            for task, img in tasks_images
        ]
        return await asyncio.gather(*futures)

    def print_dashboard(self):
        """Print (in green) and return a summary of the model configuration."""
        dashboard = f"""
            GPT4Vision Dashboard
            -------------------
            Max Retries: {self.max_retries}
            Model: {self.model}
            Backoff Factor: {self.backoff_factor}
            Timeout Seconds: {self.timeout_seconds}
            Image Quality: {self.quality}
            Max Tokens: {self.max_tokens}

            """
        # BUG FIX: the original returned the result of print(), i.e. None;
        # return the dashboard text so callers can reuse it.
        print(colored(dashboard, "green"))
        return dashboard

    def health_check(self):
        """Health check: True when the OpenAI engines endpoint is reachable."""
        try:
            response = requests.get(
                "https://api.openai.com/v1/engines"
            )
            return response.status_code == 200
        except requests.RequestException as error:
            print(f"Health check failed: {error}")
            return False

    def sanitize_input(self, text: str) -> str:
        """
        Sanitize input to prevent injection attacks.

        Parameters:
        text: str - The input text to be sanitized.

        Returns:
        The sanitized text (everything except word characters and
        whitespace removed).
        """
        # Simple sanitization; should be expanded based on context and usage.
        sanitized_text = re.sub(r"[^\w\s]", "", text)
        return sanitized_text
|
|
Loading…
Reference in new issue