From 9dc68ee04ff6cf6723edccbcf5bf57e9c71ce6bc Mon Sep 17 00:00:00 2001 From: pliny <133052465+elder-plinius@users.noreply.github.com> Date: Sat, 2 Dec 2023 15:12:34 -0800 Subject: [PATCH] Update gpt4_vision_api.py --- swarms/models/gpt4_vision_api.py | 83 ++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/swarms/models/gpt4_vision_api.py b/swarms/models/gpt4_vision_api.py index b589fe70..cd6e5ddb 100644 --- a/swarms/models/gpt4_vision_api.py +++ b/swarms/models/gpt4_vision_api.py @@ -15,7 +15,10 @@ from termcolor import colored try: import cv2 except ImportError: - print("OpenCV not installed. Please install OpenCV to use this model.") + print( + "OpenCV not installed. Please install OpenCV to use this" + " model." + ) raise ImportError # Load environment variables @@ -100,29 +103,24 @@ class GPT4VisionAPI: if self.meta_prompt: self.system_prompt = self.meta_prompt_init() - def encode_image(self, img_path: str): + def encode_image(self, img: str): """Encode image to base64.""" - if not os.path.exists(img_path): - print(f"Image file not found: {img_path}") + if not os.path.exists(img): + print(f"Image file not found: {img}") return None - with open(img_path, "rb") as image_file: + with open(img, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") - - def download_img_then_encode(self, img: str): """Download image from URL then encode image to base64 using requests""" pass # Function to handle vision tasks - def run( - self, - image_path, - task): + def run(self, img, task): """Run the model.""" try: - base64_image = self.encode_image(image_path) + base64_image = self.encode_image(img) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}", @@ -130,22 +128,36 @@ class GPT4VisionAPI: payload = { "model": self.model_name, "messages": [ - {"role": "system", "content": [self.system_prompt]}, + { + "role": "system", + "content": [self.system_prompt], + }, { "role": "user", "content": [ {"type": "text", "text": task}, - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + }, + }, ], }, ], "max_tokens": self.max_tokens, } - response = requests.post(self.openai_proxy, headers=headers, json=payload) + response = requests.post( + self.openai_proxy, headers=headers, json=payload + ) out = response.json() - if 'choices' in out and out['choices']: - content = out["choices"][0].get("message", {}).get("content", None) + if "choices" in out and out["choices"]: + content = ( + out["choices"][0] + .get("message", {}) + .get("content", None) + ) return content else: print("No valid response in 'choices'") @@ -227,7 +239,9 @@ class GPT4VisionAPI: if not success: break _, buffer = cv2.imencode(".jpg", frame) - base64_frames.append(base64.b64encode(buffer).decode("utf-8")) + base64_frames.append( + base64.b64encode(buffer).decode("utf-8") + ) video.release() print(len(base64_frames), "frames read.") @@ -252,7 +266,10 @@ class GPT4VisionAPI: payload = { "model": self.model_name, "messages": [ - {"role": "system", "content": [self.system_prompt]}, + { + "role": "system", + "content": [self.system_prompt], + }, { "role": "user", "content": [ @@ -260,9 +277,7 @@ class GPT4VisionAPI: { "type": "image_url", "image_url": { - "url": ( - f"data:image/jpeg;base64,{base64_image}" - ) + "url": f"data:image/jpeg;base64,{base64_image}" }, }, ], @@ -312,7 +327,9 @@ class GPT4VisionAPI: """ # Instantiate the thread pool executor - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with ThreadPoolExecutor( + max_workers=self.max_workers + ) as executor: results = executor.map(self.run, tasks, imgs) # Print the results for debugging @@ -358,9 +375,7 @@ class GPT4VisionAPI: { "type": "image_url", "image_url": { - "url": ( - f"data:image/jpeg;base64,{base64_image}" - ) + "url": f"data:image/jpeg;base64,{base64_image}" }, }, ], @@ -370,7 +385,9 @@ class GPT4VisionAPI: } async with aiohttp.ClientSession() as session: async with session.post( - self.openai_proxy, headers=headers, data=json.dumps(payload) + self.openai_proxy, + headers=headers, + data=json.dumps(payload), ) as response: out = await response.json() content = out["choices"][0]["message"]["content"] @@ -379,7 +396,9 @@ class GPT4VisionAPI: print(f"Error with the request {error}") raise error - def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]: + def run_batch( + self, tasks_images: List[Tuple[str, str]] + ) -> List[str]: """Process a batch of tasks and images""" with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ @@ -406,7 +425,9 @@ class GPT4VisionAPI: """Process a batch of tasks and images asynchronously with retries""" loop = asyncio.get_event_loop() futures = [ - loop.run_in_executor(None, self.run_with_retries, task, img) + loop.run_in_executor( + None, self.run_with_retries, task, img + ) for task, img in tasks_images ] return await asyncio.gather(*futures) @@ -414,7 +435,9 @@ class GPT4VisionAPI: def health_check(self): """Health check for the GPT4Vision model""" try: - response = requests.get("https://api.openai.com/v1/engines") + response = requests.get( + "https://api.openai.com/v1/engines" + ) return response.status_code == 200 except requests.RequestException as error: print(f"Health check failed: {error}")