@@ -1,3 +1,4 @@
+from abc import abstractmethod
 import asyncio
 import base64
 import concurrent.futures
@@ -7,8 +8,8 @@ from io import BytesIO
 from typing import List, Optional, Tuple

 import requests
-from ABC import abstractmethod
 from PIL import Image
+from termcolor import colored


 class BaseMultiModalModel:
@@ -37,7 +38,6 @@ class BaseMultiModalModel:
         self.retries = retries
         self.chat_history = []

-
     @abstractmethod
     def __call__(self, text: str, img: str):
         """Run the model"""
@@ -61,17 +61,17 @@ class BaseMultiModalModel:
         except requests.RequestException as error:
             print(f"Error fetching image from {img} and error: {error}")
             return None

     def encode_img(self, img: str):
         """Encode the image to base64"""
         with open(img, "rb") as image_file:
             return base64.b64encode(image_file.read()).decode("utf-8")

     def get_img(self, img: str):
         """Get the image from the path"""
         image_pil = Image.open(img)
         return image_pil

     def clear_chat_history(self):
         """Clear the chat history"""
         self.chat_history = []
@@ -87,11 +87,11 @@ class BaseMultiModalModel:
         Args:
             tasks (List[str]): List of tasks
             imgs (List[str]): List of image paths

         Returns:
             List[str]: List of responses


         """
         # Instantiate the thread pool executor
         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@@ -101,7 +101,6 @@ class BaseMultiModalModel:
             for result in results:
                 print(result)

-
     def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
         """Process a batch of tasks and images"""
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -133,11 +132,11 @@ class BaseMultiModalModel:
                 for task, img in tasks_images
             ]
             return await asyncio.gather(*futures)

     def unique_chat_history(self):
         """Get the unique chat history"""
         return list(set(self.chat_history))

     def run_with_retries(self, task: str, img: str):
         """Run the model with retries"""
         for i in range(self.retries):
@@ -146,7 +145,7 @@ class BaseMultiModalModel:
             except Exception as error:
                 print(f"Error with the request {error}")
                 continue

     def run_batch_with_retries(self, tasks_images: List[Tuple[str, str]]):
         """Run the model with retries"""
         for i in range(self.retries):
@@ -188,28 +187,37 @@ class BaseMultiModalModel:
         if self.start_time and self.end_time:
             return self.end_time - self.start_time
         return 0

     def get_chat_history(self):
         """Get the chat history"""
         return self.chat_history

     def get_unique_chat_history(self):
         """Get the unique chat history"""
         return list(set(self.chat_history))

     def get_chat_history_length(self):
         """Get the chat history length"""
         return len(self.chat_history)

     def get_unique_chat_history_length(self):
         """Get the unique chat history length"""
         return len(list(set(self.chat_history)))

     def get_chat_history_tokens(self):
         """Get the chat history tokens"""
         return self._num_tokens()

     def print_beautiful(self, content: str, color: str = "cyan"):
         """Print Beautifully with termcolor"""
         content = colored(content, color)
         print(content)
+
+    def stream(self, content: str):
+        """Stream the output
+
+        Args:
+            content (str): _description_
+        """
+        for chunk in content:
+            print(chunk)
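
Usage sketch (not part of the patch): a minimal subclass exercising the helpers touched above. The ExampleModel name, its echo-style __call__ body, and the no-argument construction are illustrative assumptions; only methods whose bodies appear in the diff are called.

# Minimal sketch, assuming BaseMultiModalModel from the patched module is in scope
# and that its constructor arguments are optional. ExampleModel and its __call__
# body are hypothetical placeholders, not part of the patch.
class ExampleModel(BaseMultiModalModel):
    def __call__(self, text: str, img: str):
        # A real subclass would query a vision-language model with `text` and `img`;
        # here we just record a stub answer in the chat history kept by the base class.
        answer = f"answer about {img}: {text}"
        self.chat_history.append(answer)
        return answer


model = ExampleModel()
model("Describe the image", "photo.jpg")
model.print_beautiful(f"history length: {model.get_chat_history_length()}")  # colored via termcolor
model.stream("done")  # the new stream() helper prints each chunk of `content` (here, each character)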