Update gpt4_vision_api.py

pull/239/head
pliny authored 1 year ago, committed by GitHub
parent b726f04b02
commit 57cf98b484

@@ -15,10 +15,7 @@ from termcolor import colored
 try:
     import cv2
 except ImportError:
-    print(
-        "OpenCV not installed. Please install OpenCV to use this"
-        " model."
-    )
+    print("OpenCV not installed. Please install OpenCV to use this model.")
     raise ImportError

 # Load environment variables
@@ -103,79 +100,60 @@ class GPT4VisionAPI:
         if self.meta_prompt:
             self.system_prompt = self.meta_prompt_init()

-    def encode_image(self, img: str):
+    def encode_image(self, img_path: str):
         """Encode image to base64."""
-        if not os.path.exists(img):
-            print(f"Image file not found: {img}")
+        if not os.path.exists(img_path):
+            print(f"Image file not found: {img_path}")
             return None

-        with open(img, "rb") as image_file:
+        with open(img_path, "rb") as image_file:
             return base64.b64encode(image_file.read()).decode("utf-8")

     def download_img_then_encode(self, img: str):
         """Download image from URL then encode image to base64 using requests"""
         pass

     # Function to handle vision tasks
     def run(
         self,
-        task: Optional[str] = None,
-        img: Optional[str] = None,
-        *args,
-        **kwargs,
-    ):
+        image_path,
+        task):
         """Run the model."""
         try:
-            base64_image = self.encode_image(img)
+            base64_image = self.encode_image(image_path)
             headers = {
                 "Content-Type": "application/json",
-                "Authorization": f"Bearer {openai_api_key}",
+                "Authorization": f"Bearer {self.openai_api_key}",
             }
             payload = {
                 "model": self.model_name,
                 "messages": [
-                    {
-                        "role": "system",
-                        "content": [self.system_prompt],
-                    },
+                    {"role": "system", "content": [self.system_prompt]},
                     {
                         "role": "user",
                         "content": [
                             {"type": "text", "text": task},
-                            {
-                                "type": "image_url",
-                                "image_url": {
-                                    "url": f"data:image/jpeg;base64,{base64_image}"
-                                },
-                            },
+                            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                         ],
                     },
                 ],
                 "max_tokens": self.max_tokens,
             }
-            response = requests.post(
-                self.openai_proxy,
-                headers=headers,
-                json=payload,
-            )
+            response = requests.post(self.openai_proxy, headers=headers, json=payload)

             out = response.json()
-            content = out["choices"][0]["message"]["content"]
-
-            if self.streaming_enabled:
-                content = self.stream_response(content)
-            else:
-                pass
-
-            if self.beautify:
-                content = colored(content, "cyan")
-                print(content)
-            else:
-                print(content)
+            if 'choices' in out and out['choices']:
+                content = out["choices"][0].get("message", {}).get("content", None)
+                return content
+            else:
+                print("No valid response in 'choices'")
+                return None

         except Exception as error:
             print(f"Error with the request: {error}")
-            raise error
+            return None

     def video_prompt(self, frames):
         """
@@ -249,9 +227,7 @@ class GPT4VisionAPI:
             if not success:
                 break
             _, buffer = cv2.imencode(".jpg", frame)
-            base64_frames.append(
-                base64.b64encode(buffer).decode("utf-8")
-            )
+            base64_frames.append(base64.b64encode(buffer).decode("utf-8"))

         video.release()
         print(len(base64_frames), "frames read.")
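This hunk shows only the tail of the frame-extraction loop; a standalone sketch of the overall pattern (OpenCV frame grab, JPEG encode, base64), with the video path and frame cap as illustrative values:

    import base64
    import cv2

    def frames_to_base64(video_path: str, max_frames: int = 100) -> list:
        """Read frames from a video and return them as base64-encoded JPEG strings."""
        video = cv2.VideoCapture(video_path)
        base64_frames = []
        while video.isOpened() and len(base64_frames) < max_frames:
            success, frame = video.read()
            if not success:
                break
            _, buffer = cv2.imencode(".jpg", frame)
            base64_frames.append(base64.b64encode(buffer).decode("utf-8"))
        video.release()
        print(len(base64_frames), "frames read.")
        return base64_frames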
@@ -276,10 +252,7 @@ class GPT4VisionAPI:
             payload = {
                 "model": self.model_name,
                 "messages": [
-                    {
-                        "role": "system",
-                        "content": [self.system_prompt],
-                    },
+                    {"role": "system", "content": [self.system_prompt]},
                     {
                         "role": "user",
                         "content": [
@@ -287,7 +260,9 @@ class GPT4VisionAPI:
                             {
                                 "type": "image_url",
                                 "image_url": {
-                                    "url": f"data:image/jpeg;base64,{base64_image}"
+                                    "url": (
+                                        f"data:image/jpeg;base64,{base64_image}"
+                                    )
                                 },
                             },
                         ],
@@ -337,9 +312,7 @@ class GPT4VisionAPI:
         """

         # Instantiate the thread pool executor
-        with ThreadPoolExecutor(
-            max_workers=self.max_workers
-        ) as executor:
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
             results = executor.map(self.run, tasks, imgs)

         # Print the results for debugging
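For reference, a standalone sketch of the ThreadPoolExecutor.map pattern used above: the callable is applied to the i-th element of each iterable in lockstep, so the iterables must line up with the callable's parameters. The names describe_image, tasks, and imgs are illustrative, not from this file:

    from concurrent.futures import ThreadPoolExecutor

    def describe_image(task: str, img: str) -> str:
        # Stand-in for a per-item worker such as self.run.
        return f"{task} -> {img}"

    tasks = ["Describe the scene", "Count the people"]
    imgs = ["a.jpg", "b.jpg"]

    # map() pairs tasks[i] with imgs[i], runs the calls on worker threads,
    # and yields results in input order.
    with ThreadPoolExecutor(max_workers=4) as executor:
        results = list(executor.map(describe_image, tasks, imgs))

    for result in results:
        print(result)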
@@ -385,7 +358,9 @@ class GPT4VisionAPI:
                             {
                                 "type": "image_url",
                                 "image_url": {
-                                    "url": f"data:image/jpeg;base64,{base64_image}"
+                                    "url": (
+                                        f"data:image/jpeg;base64,{base64_image}"
+                                    )
                                 },
                             },
                         ],
@@ -395,9 +370,7 @@ class GPT4VisionAPI:
             }
             async with aiohttp.ClientSession() as session:
                 async with session.post(
-                    self.openai_proxy,
-                    headers=headers,
-                    data=json.dumps(payload),
+                    self.openai_proxy, headers=headers, data=json.dumps(payload)
                 ) as response:
                     out = await response.json()
                     content = out["choices"][0]["message"]["content"]
@@ -406,9 +379,7 @@ class GPT4VisionAPI:
             print(f"Error with the request {error}")
             raise error

-    def run_batch(
-        self, tasks_images: List[Tuple[str, str]]
-    ) -> List[str]:
+    def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
         """Process a batch of tasks and images"""
         with concurrent.futures.ThreadPoolExecutor() as executor:
             futures = [
@@ -435,9 +406,7 @@ class GPT4VisionAPI:
         """Process a batch of tasks and images asynchronously with retries"""
         loop = asyncio.get_event_loop()
         futures = [
-            loop.run_in_executor(
-                None, self.run_with_retries, task, img
-            )
+            loop.run_in_executor(None, self.run_with_retries, task, img)
             for task, img in tasks_images
         ]
         return await asyncio.gather(*futures)
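A self-contained sketch of the loop.run_in_executor / asyncio.gather pattern this batch method relies on: each blocking call is handed to the default thread pool and the resulting futures are awaited together. blocking_describe and its inputs are illustrative placeholders:

    import asyncio
    import time

    def blocking_describe(task: str, img: str) -> str:
        # Stand-in for a blocking worker such as self.run_with_retries.
        time.sleep(0.1)
        return f"{task}: {img}"

    async def run_batch_async(tasks_images):
        loop = asyncio.get_event_loop()
        # None selects the default ThreadPoolExecutor; each call runs off the event loop.
        futures = [
            loop.run_in_executor(None, blocking_describe, task, img)
            for task, img in tasks_images
        ]
        # gather() returns results in the same order as its awaitables.
        return await asyncio.gather(*futures)

    print(asyncio.run(run_batch_async([("Describe", "a.jpg"), ("Caption", "b.jpg")])))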
@@ -445,9 +414,7 @@ class GPT4VisionAPI:
     def health_check(self):
         """Health check for the GPT4Vision model"""
         try:
-            response = requests.get(
-                "https://api.openai.com/v1/engines"
-            )
+            response = requests.get("https://api.openai.com/v1/engines")
             return response.status_code == 200
         except requests.RequestException as error:
             print(f"Health check failed: {error}")
