From 51c82cf1f2fb31fcf31a148e7dcbbe1dfb69c377 Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 25 Nov 2023 02:16:59 -0800
Subject: [PATCH] Jarvis demo, base multimodal model, whisperx -> whisperx_model

---
 .../jarvis_multi_modal_auto_agent/jarvis.py   |  20 ++
 swarms/models/__init__.py                     |   4 +-
 swarms/models/base_multimodal_model.py        | 209 ++++++++++++++++++
 swarms/models/fuyu.py                         |  10 +-
 swarms/models/gpt4_vision_api.py              |   1 -
 swarms/models/kosmos_two.py                   |  37 ++--
 .../models/{whisperx.py => whisperx_model.py} |  22 +-
 tests/models/test_whisperx.py                 |   2 +-
 8 files changed, 263 insertions(+), 42 deletions(-)
 create mode 100644 playground/demos/jarvis_multi_modal_auto_agent/jarvis.py
 create mode 100644 swarms/models/base_multimodal_model.py
 rename swarms/models/{whisperx.py => whisperx_model.py} (85%)

diff --git a/playground/demos/jarvis_multi_modal_auto_agent/jarvis.py b/playground/demos/jarvis_multi_modal_auto_agent/jarvis.py
new file mode 100644
index 00000000..3e0a05cc
--- /dev/null
+++ b/playground/demos/jarvis_multi_modal_auto_agent/jarvis.py
@@ -0,0 +1,20 @@
+from swarms.structs import Flow
+from swarms.models.gpt4_vision_api import GPT4VisionAPI
+from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
+    MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
+)
+
+
+llm = GPT4VisionAPI()
+
+task = "What is the color of the object?"
+img = "images/swarms.jpeg"
+
+# Initialize the workflow
+flow = Flow(
+    llm=llm,
+    sop=MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
+    max_loops="auto",
+)
+
+flow.run(task=task, img=img)

diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py
index deac803f..b7f3b8ce 100644
--- a/swarms/models/__init__.py
+++ b/swarms/models/__init__.py
@@ -20,8 +20,6 @@ from swarms.models.mpt import MPT7B  # noqa: E402
 
 # MultiModal Models
 from swarms.models.idefics import Idefics  # noqa: E402
-
-# from swarms.models.kosmos_two import Kosmos  # noqa: E402
 from swarms.models.vilt import Vilt  # noqa: E402
 from swarms.models.nougat import Nougat  # noqa: E402
 from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA  # noqa: E402
@@ -30,6 +28,8 @@ from swarms.models.gpt4_vision_api import GPT4VisionAPI  # noqa: E402
 # from swarms.models.gpt4v import GPT4Vision
 # from swarms.models.dalle3 import Dalle3
 # from swarms.models.distilled_whisperx import DistilWhisperModel  # noqa: E402
+# from swarms.models.whisperx_model import WhisperX  # noqa: E402
+# from swarms.models.kosmos_two import Kosmos  # noqa: E402
 
 __all__ = [
     "Anthropic",
diff --git a/swarms/models/base_multimodal_model.py b/swarms/models/base_multimodal_model.py
new file mode 100644
index 00000000..54eed0ed
--- /dev/null
+++ b/swarms/models/base_multimodal_model.py
@@ -0,0 +1,209 @@
+import asyncio
+import base64
+import concurrent.futures
+import time
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from typing import List, Optional, Tuple
+
+import requests
+from abc import abstractmethod
+from PIL import Image
+
+
+class BaseMultiModalModel:
+    def __init__(
+        self,
+        model_name: Optional[str],
+        temperature: Optional[float] = 0.5,
+        max_tokens: Optional[int] = 500,
+        max_workers: Optional[int] = 10,
+        top_p: Optional[float] = 1,
+        top_k: Optional[int] = 50,
+        device: Optional[str] = "cuda",
+        max_new_tokens: Optional[int] = 500,
+        retries: Optional[int] = 3,
+    ):
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_workers = max_workers
+        self.top_p = top_p
+        self.top_k = top_k
+        self.device = device
+        self.max_new_tokens = max_new_tokens
+        self.retries = retries
+        self.chat_history = []
+
+    @abstractmethod
+    def __call__(self, text: str, img: str):
+        """Run the model"""
+        pass
+
+    def run(self, task: str, img: str):
+        """Run the model"""
+        pass
+
+    async def arun(self, task: str, img: str):
+        """Run the model asynchronously"""
+        pass
+
+    def get_img_from_web(self, img: str):
+        """Get the image from the web"""
+        try:
+            response = requests.get(img)
+            response.raise_for_status()
+            image_pil = Image.open(BytesIO(response.content))
+            return image_pil
+        except requests.RequestException as error:
+            print(f"Error fetching image from {img} and error: {error}")
+            return None
+
+    def encode_img(self, img: str):
+        """Encode the image to base64"""
+        with open(img, "rb") as image_file:
+            return base64.b64encode(image_file.read()).decode("utf-8")
+
+    def get_img(self, img: str):
+        """Get the image from the path"""
+        image_pil = Image.open(img)
+        return image_pil
+
+    def clear_chat_history(self):
+        """Clear the chat history"""
+        self.chat_history = []
+
+    def run_many(
+        self,
+        tasks: List[str],
+        imgs: List[str],
+    ):
+        """
+        Run the model on multiple tasks and images all at once using concurrent
+
+        Args:
+            tasks (List[str]): List of tasks
+            imgs (List[str]): List of image paths
+
+        Returns:
+            List[str]: List of responses
+        """
+        # Fan the (task, img) pairs out across the thread pool
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            results = list(executor.map(self.run, tasks, imgs))
+
+        return results
+
+    def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
+        """Process a batch of tasks and images"""
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(self.run, task, img)
+                for task, img in tasks_images
+            ]
+            results = [future.result() for future in futures]
+        return results
+
+    async def run_batch_async(
+        self, tasks_images: List[Tuple[str, str]]
+    ) -> List[str]:
+        """Process a batch of tasks and images asynchronously"""
+        loop = asyncio.get_event_loop()
+        futures = [
+            loop.run_in_executor(None, self.run, task, img)
+            for task, img in tasks_images
+        ]
+        return await asyncio.gather(*futures)
+
+    async def run_batch_async_with_retries(
+        self, tasks_images: List[Tuple[str, str]]
+    ) -> List[str]:
+        """Process a batch of tasks and images asynchronously with retries"""
+        loop = asyncio.get_event_loop()
+        futures = [
+            loop.run_in_executor(None, self.run_with_retries, task, img)
+            for task, img in tasks_images
+        ]
+        return await asyncio.gather(*futures)
+
+    def unique_chat_history(self):
+        """Get the unique chat history"""
+        return list(set(self.chat_history))
+
+    def run_with_retries(self, task: str, img: str):
+        """Run the model with retries"""
+        for i in range(self.retries):
+            try:
+                return self.run(task, img)
+            except Exception as error:
+                print(f"Error with the request {error}")
+                continue
+
+    def run_batch_with_retries(self, tasks_images: List[Tuple[str, str]]):
+        """Run the model with retries"""
+        for i in range(self.retries):
+            try:
+                return self.run_batch(tasks_images)
+            except Exception as error:
+                print(f"Error with the request {error}")
+                continue
+
+    def _tokens_per_second(self) -> float:
+        """Tokens per second"""
+        elapsed_time = self.end_time - self.start_time
+        if elapsed_time == 0:
+            return float("inf")
+        return self._num_tokens() / elapsed_time
+
+    def _time_for_generation(self, task: str, img: str) -> float:
+        """Time for Generation"""
+        self.start_time = time.time()
+        self.run(task, img)
+        self.end_time = time.time()
+        return self.end_time - self.start_time
+
+    @abstractmethod
+    def generate_summary(self, text: str) -> str:
+        """Generate Summary"""
+        pass
+
+    def set_temperature(self, value: float):
+        """Set Temperature"""
+        self.temperature = value
+
+    def set_max_tokens(self, value: int):
+        """Set new max tokens"""
+        self.max_tokens = value
+
+    def get_generation_time(self) -> float:
+        """Get generation time"""
+        if self.start_time and self.end_time:
+            return self.end_time - self.start_time
+        return 0
+
+    def get_chat_history(self):
+        """Get the chat history"""
+        return self.chat_history
+
+    def get_unique_chat_history(self):
+        """Get the unique chat history"""
+        return list(set(self.chat_history))
+
+    def get_chat_history_length(self):
+        """Get the chat history length"""
+        return len(self.chat_history)
+
+    def get_unique_chat_history_length(self):
+        """Get the unique chat history length"""
+        return len(list(set(self.chat_history)))
+
+    def get_chat_history_tokens(self):
+        """Get the chat history tokens"""
+        return self._num_tokens()
\ No newline at end of file
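
The new BaseMultiModalModel is abstract at its core (__call__ and
generate_summary carry @abstractmethod), so the batching, retry, and history
helpers only become useful through a subclass. A minimal sketch of the
intended contract follows; the EchoVisionModel name and its trivial behavior
are hypothetical illustrations, not part of this patch:

    from swarms.models.base_multimodal_model import BaseMultiModalModel

    class EchoVisionModel(BaseMultiModalModel):
        # Toy subclass used only to illustrate the hooks a real model fills in.
        def __call__(self, text: str, img: str):
            # A real model would load `img` and run inference here.
            response = f"saw {img}, heard {text!r}"
            self.chat_history.append(response)
            return response

        def run(self, task: str, img: str):
            # run() is the entry point that run_many/run_batch fan out over.
            return self(task, img)

        def generate_summary(self, text: str) -> str:
            return text[:50]

    model = EchoVisionModel(model_name="echo")
    print(model.run_batch([("describe", "a.png"), ("count objects", "b.png")]))

Note that _tokens_per_second and get_chat_history_tokens call a _num_tokens()
helper this file never defines, so concrete subclasses are expected to
provide it.
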
diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py
index ed955260..79dc1c47 100644
--- a/swarms/models/fuyu.py
+++ b/swarms/models/fuyu.py
@@ -63,9 +63,9 @@ class Fuyu:
 
     def __call__(self, text: str, img: str):
         """Call the model with text and img paths"""
-        image_pil = Image.open(img)
+        img = self.get_img(img)
         model_inputs = self.processor(
-            text=text, images=[image_pil], device=self.device_map
+            text=text, images=[img], device=self.device_map
         )
 
         for k, v in model_inputs.items():
@@ -79,13 +79,13 @@ class Fuyu:
         )
         return print(str(text))
 
-    def get_img_from_web(self, img_url: str):
+    def get_img_from_web(self, img: str):
         """Get the image from the web"""
         try:
-            response = requests.get(img_url)
+            response = requests.get(img)
             response.raise_for_status()
             image_pil = Image.open(BytesIO(response.content))
             return image_pil
         except requests.RequestException as error:
-            print(f"Error fetching image from {img_url} and error: {error}")
+            print(f"Error fetching image from {img} and error: {error}")
             return None

diff --git a/swarms/models/gpt4_vision_api.py b/swarms/models/gpt4_vision_api.py
index 2a242670..8cf9371d 100644
--- a/swarms/models/gpt4_vision_api.py
+++ b/swarms/models/gpt4_vision_api.py
@@ -114,7 +114,6 @@ class GPT4VisionAPI:
         except Exception as error:
             print(f"Error with the request: {error}")
             raise error
-        # Function to handle vision tasks
 
     def __call__(self, task: str, img: str):
         """Run the model."""

diff --git a/swarms/models/kosmos_two.py b/swarms/models/kosmos_two.py
index c696ef34..7e9da590 100644
--- a/swarms/models/kosmos_two.py
+++ b/swarms/models/kosmos_two.py
@@ -18,38 +18,31 @@ def is_overlapping(rect1, rect2):
 
 class Kosmos:
     """
+    Kosmos model by Yen-Chun Shieh
+
+    Parameters
+    ----------
+    model_name : str
+        Path to the pretrained model
+
+    Examples
+    --------
+    >>> kosmos = Kosmos()
+    >>> kosmos("Hello, my name is", "path/to/image.png")
 
-    Args:
-
-
-    # Initialize Kosmos
-    kosmos = Kosmos()
-
-    # Perform multimodal grounding
-    kosmos.multimodal_grounding("Find the red apple in the image.", "https://example.com/apple.jpg")
-
-    # Perform referring expression comprehension
-    kosmos.referring_expression_comprehension("Show me the green bottle.", "https://example.com/bottle.jpg")
-
-    # Generate referring expressions
-    kosmos.referring_expression_generation("It is on the table.", "https://example.com/table.jpg")
-
-    # Perform grounded visual question answering
-    kosmos.grounded_vqa("What is the color of the car?", "https://example.com/car.jpg")
-
-    # Generate grounded image caption
-    kosmos.grounded_image_captioning("https://example.com/beach.jpg")
     """
 
     def __init__(
        self,
        model_name="ydshieh/kosmos-2-patch14-224",
+        *args,
+        **kwargs,
    ):
        self.model = AutoModelForVision2Seq.from_pretrained(
-            model_name, trust_remote_code=True
+            model_name, trust_remote_code=True, *args, **kwargs
        )
        self.processor = AutoProcessor.from_pretrained(
-            model_name, trust_remote_code=True
+            model_name, trust_remote_code=True, *args, **kwargs
        )
 
     def get_image(self, url):
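
The Kosmos constructor now forwards *args/**kwargs to both
AutoModelForVision2Seq.from_pretrained and AutoProcessor.from_pretrained, so
only keyword arguments that both loaders accept are safe to pass through. A
sketch under that assumption (revision and cache_dir are standard
from_pretrained kwargs accepted by both):

    from swarms.models.kosmos_two import Kosmos

    kosmos = Kosmos(
        model_name="ydshieh/kosmos-2-patch14-224",
        revision="main",            # pin a model revision
        cache_dir="/tmp/hf-cache",  # shared download cache
    )
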
"https://example.com/car.jpg") - - # Generate grounded image caption - kosmos.grounded_image_captioning("https://example.com/beach.jpg") """ def __init__( self, model_name="ydshieh/kosmos-2-patch14-224", + *args, + **kwargs, ): self.model = AutoModelForVision2Seq.from_pretrained( - model_name, trust_remote_code=True + model_name, trust_remote_code=True, *args, **kwargs ) self.processor = AutoProcessor.from_pretrained( - model_name, trust_remote_code=True + model_name, trust_remote_code=True, *args, **kwargs ) def get_image(self, url): diff --git a/swarms/models/whisperx.py b/swarms/models/whisperx_model.py similarity index 85% rename from swarms/models/whisperx.py rename to swarms/models/whisperx_model.py index 338971da..883c3edb 100644 --- a/swarms/models/whisperx.py +++ b/swarms/models/whisperx_model.py @@ -2,7 +2,7 @@ import os import subprocess try: - import whisperx + import swarms.models.whisperx_model as whisperx_model from pydub import AudioSegment from pytube import YouTube except Exception as error: @@ -66,17 +66,17 @@ class WhisperX: compute_type = "float16" # 1. Transcribe with original Whisper (batched) 🗣️ - model = whisperx.load_model( + model = whisperx_model.load_model( "large-v2", device, compute_type=compute_type ) - audio = whisperx.load_audio(audio_file) + audio = whisperx_model.load_audio(audio_file) result = model.transcribe(audio, batch_size=batch_size) # 2. Align Whisper output 🔍 - model_a, metadata = whisperx.load_align_model( + model_a, metadata = whisperx_model.load_align_model( language_code=result["language"], device=device ) - result = whisperx.align( + result = whisperx_model.align( result["segments"], model_a, metadata, @@ -86,7 +86,7 @@ class WhisperX: ) # 3. Assign speaker labels 🏷️ - diarize_model = whisperx.DiarizationPipeline( + diarize_model = whisperx_model.DiarizationPipeline( use_auth_token=self.hf_api_key, device=device ) diarize_model(audio_file) @@ -99,16 +99,16 @@ class WhisperX: print("The key 'segments' is not found in the result.") def transcribe(self, audio_file): - model = whisperx.load_model("large-v2", self.device, self.compute_type) - audio = whisperx.load_audio(audio_file) + model = whisperx_model.load_model("large-v2", self.device, self.compute_type) + audio = whisperx_model.load_audio(audio_file) result = model.transcribe(audio, batch_size=self.batch_size) # 2. Align Whisper output 🔍 - model_a, metadata = whisperx.load_align_model( + model_a, metadata = whisperx_model.load_align_model( language_code=result["language"], device=self.device ) - result = whisperx.align( + result = whisperx_model.align( result["segments"], model_a, metadata, @@ -118,7 +118,7 @@ class WhisperX: ) # 3. Assign speaker labels 🏷️ - diarize_model = whisperx.DiarizationPipeline( + diarize_model = whisperx_model.DiarizationPipeline( use_auth_token=self.hf_api_key, device=self.device ) diff --git a/tests/models/test_whisperx.py b/tests/models/test_whisperx.py index 5fad3431..ed671cb2 100644 --- a/tests/models/test_whisperx.py +++ b/tests/models/test_whisperx.py @@ -7,7 +7,7 @@ import pytest import whisperx from pydub import AudioSegment from pytube import YouTube -from swarms.models.whisperx import WhisperX +from swarms.models.whisperx_model import WhisperX # Fixture to create a temporary directory for testing