From 57cf98b48412e89897e7ca546bf6afb41843dfc0 Mon Sep 17 00:00:00 2001
From: pliny <133052465+elder-plinius@users.noreply.github.com>
Date: Fri, 1 Dec 2023 08:39:33 -0800
Subject: [PATCH] Update gpt4_vision_api.py

---
 swarms/models/gpt4_vision_api.py | 101 +++++++++++--------------------
 1 file changed, 34 insertions(+), 67 deletions(-)

diff --git a/swarms/models/gpt4_vision_api.py b/swarms/models/gpt4_vision_api.py
index 6efb68f4..b589fe70 100644
--- a/swarms/models/gpt4_vision_api.py
+++ b/swarms/models/gpt4_vision_api.py
@@ -15,10 +15,7 @@ from termcolor import colored
 try:
     import cv2
 except ImportError:
-    print(
-        "OpenCV not installed. Please install OpenCV to use this"
-        " model."
-    )
+    print("OpenCV not installed. Please install OpenCV to use this model.")
     raise ImportError
 
 # Load environment variables
@@ -103,79 +100,60 @@ class GPT4VisionAPI:
         if self.meta_prompt:
             self.system_prompt = self.meta_prompt_init()
 
-    def encode_image(self, img: str):
+    def encode_image(self, img_path: str):
         """Encode image to base64."""
-        if not os.path.exists(img):
-            print(f"Image file not found: {img}")
+        if not os.path.exists(img_path):
+            print(f"Image file not found: {img_path}")
             return None
 
-        with open(img, "rb") as image_file:
+        with open(img_path, "rb") as image_file:
             return base64.b64encode(image_file.read()).decode("utf-8")
 
+    
+
     def download_img_then_encode(self, img: str):
         """Download image from URL then encode image to base64 using requests"""
         pass
 
     # Function to handle vision tasks
     def run(
-        self,
-        task: Optional[str] = None,
-        img: Optional[str] = None,
-        *args,
-        **kwargs,
-    ):
+            self, 
+            image_path, 
+            task):
         """Run the model."""
         try:
-            base64_image = self.encode_image(img)
+            base64_image = self.encode_image(image_path)
             headers = {
                 "Content-Type": "application/json",
-                "Authorization": f"Bearer {openai_api_key}",
+                "Authorization": f"Bearer {self.openai_api_key}",
             }
             payload = {
                 "model": self.model_name,
                 "messages": [
-                    {
-                        "role": "system",
-                        "content": [self.system_prompt],
-                    },
+                    {"role": "system", "content": [self.system_prompt]},
                     {
                         "role": "user",
                         "content": [
                             {"type": "text", "text": task},
-                            {
-                                "type": "image_url",
-                                "image_url": {
-                                    "url": f"data:image/jpeg;base64,{base64_image}"
-                                },
-                            },
+                            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                         ],
                     },
                 ],
                 "max_tokens": self.max_tokens,
             }
-            response = requests.post(
-                self.openai_proxy,
-                headers=headers,
-                json=payload,
-            )
+            response = requests.post(self.openai_proxy, headers=headers, json=payload)
 
             out = response.json()
-            content = out["choices"][0]["message"]["content"]
-
-            if self.streaming_enabled:
-                content = self.stream_response(content)
-            else:
-                pass
-
-            if self.beautify:
-                content = colored(content, "cyan")
-                print(content)
+            if 'choices' in out and out['choices']:
+                content = out["choices"][0].get("message", {}).get("content", None)
+                return content
             else:
-                print(content)
+                print("No valid response in 'choices'")
+                return None
 
         except Exception as error:
             print(f"Error with the request: {error}")
-            raise error
+            return None
 
     def video_prompt(self, frames):
         """
@@ -249,9 +227,7 @@ class GPT4VisionAPI:
             if not success:
                 break
             _, buffer = cv2.imencode(".jpg", frame)
-            base64_frames.append(
-                base64.b64encode(buffer).decode("utf-8")
-            )
+            base64_frames.append(base64.b64encode(buffer).decode("utf-8"))
 
         video.release()
         print(len(base64_frames), "frames read.")
@@ -276,10 +252,7 @@ class GPT4VisionAPI:
             payload = {
                 "model": self.model_name,
                 "messages": [
-                    {
-                        "role": "system",
-                        "content": [self.system_prompt],
-                    },
+                    {"role": "system", "content": [self.system_prompt]},
                     {
                         "role": "user",
                         "content": [
@@ -287,7 +260,9 @@ class GPT4VisionAPI:
                             {
                                 "type": "image_url",
                                 "image_url": {
-                                    "url": f"data:image/jpeg;base64,{base64_image}"
+                                    "url": (
+                                        f"data:image/jpeg;base64,{base64_image}"
+                                    )
                                 },
                             },
                         ],
@@ -337,9 +312,7 @@ class GPT4VisionAPI:
 
         """
         # Instantiate the thread pool executor
-        with ThreadPoolExecutor(
-            max_workers=self.max_workers
-        ) as executor:
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
             results = executor.map(self.run, tasks, imgs)
 
         # Print the results for debugging
@@ -385,7 +358,9 @@ class GPT4VisionAPI:
                             {
                                 "type": "image_url",
                                 "image_url": {
-                                    "url": f"data:image/jpeg;base64,{base64_image}"
+                                    "url": (
+                                        f"data:image/jpeg;base64,{base64_image}"
+                                    )
                                 },
                             },
                         ],
@@ -395,9 +370,7 @@ class GPT4VisionAPI:
             }
             async with aiohttp.ClientSession() as session:
                 async with session.post(
-                    self.openai_proxy,
-                    headers=headers,
-                    data=json.dumps(payload),
+                    self.openai_proxy, headers=headers, data=json.dumps(payload)
                 ) as response:
                     out = await response.json()
                     content = out["choices"][0]["message"]["content"]
@@ -406,9 +379,7 @@ class GPT4VisionAPI:
             print(f"Error with the request {error}")
             raise error
 
-    def run_batch(
-        self, tasks_images: List[Tuple[str, str]]
-    ) -> List[str]:
+    def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
         """Process a batch of tasks and images"""
         with concurrent.futures.ThreadPoolExecutor() as executor:
             futures = [
@@ -435,9 +406,7 @@ class GPT4VisionAPI:
         """Process a batch of tasks and images asynchronously with retries"""
         loop = asyncio.get_event_loop()
         futures = [
-            loop.run_in_executor(
-                None, self.run_with_retries, task, img
-            )
+            loop.run_in_executor(None, self.run_with_retries, task, img)
             for task, img in tasks_images
         ]
         return await asyncio.gather(*futures)
@@ -445,9 +414,7 @@ class GPT4VisionAPI:
     def health_check(self):
         """Health check for the GPT4Vision model"""
         try:
-            response = requests.get(
-                "https://api.openai.com/v1/engines"
-            )
+            response = requests.get("https://api.openai.com/v1/engines")
             return response.status_code == 200
         except requests.RequestException as error:
             print(f"Health check failed: {error}")