Jarvis demo, base multimmodalmodel, whisperx -> whisperx_model

pull/187/head
Kye 1 year ago
parent 9390efb8aa
commit 51c82cf1f2

@ -0,0 +1,20 @@
from swarms.structs import Flow
from swarms.models.gpt4_vision_api import GPT4VisionAPI
from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
)
llm = GPT4VisionAPI()
task = "What is the color of the object?"
img = "images/swarms.jpeg"
## Initialize the workflow
flow = Flow(
llm=llm,
sop=MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
max_loops="auto",
)
flow.run(task=task, img=img)

@ -20,8 +20,6 @@ from swarms.models.mpt import MPT7B # noqa: E402
# MultiModal Models # MultiModal Models
from swarms.models.idefics import Idefics # noqa: E402 from swarms.models.idefics import Idefics # noqa: E402
# from swarms.models.kosmos_two import Kosmos # noqa: E402
from swarms.models.vilt import Vilt # noqa: E402 from swarms.models.vilt import Vilt # noqa: E402
from swarms.models.nougat import Nougat # noqa: E402 from swarms.models.nougat import Nougat # noqa: E402
from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # noqa: E402 from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # noqa: E402
@ -30,6 +28,8 @@ from swarms.models.gpt4_vision_api import GPT4VisionAPI # noqa: E402
# from swarms.models.gpt4v import GPT4Vision # from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3 # from swarms.models.dalle3 import Dalle3
# from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402 # from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402
# from swarms.models.whisperx_model import WhisperX # noqa: E402
# from swarms.models.kosmos_two import Kosmos # noqa: E402
__all__ = [ __all__ = [
"Anthropic", "Anthropic",

@ -0,0 +1,209 @@
import asyncio
import base64
import concurrent.futures
import time
from concurrent import ThreadPoolExecutor
from io import BytesIO
from typing import List, Optional, Tuple
import requests
from ABC import abstractmethod
from PIL import Image
class BaseMultiModalModel:
def __init__(
self,
model_name: Optional[str],
temperature: Optional[int] = 0.5,
max_tokens: Optional[int] = 500,
max_workers: Optional[int] = 10,
top_p: Optional[int] = 1,
top_k: Optional[int] = 50,
device: Optional[str] = "cuda",
max_new_tokens: Optional[int] = 500,
retries: Optional[int] = 3,
):
self.model_name = model_name
self.temperature = temperature
self.max_tokens = max_tokens
self.max_workers = max_workers
self.top_p = top_p
self.top_k = top_k
self.device = device
self.max_new_tokens = max_new_tokens
self.retries = retries
self.chat_history = []
@abstractmethod
def __call__(self, text: str, img: str):
"""Run the model"""
pass
def run(self, task: str, img: str):
"""Run the model"""
pass
async def arun(self, task: str, img: str):
"""Run the model asynchronously"""
pass
def get_img_from_web(self, img: str):
"""Get the image from the web"""
try:
response = requests.get(img)
response.raise_for_status()
image_pil = Image.open(BytesIO(response.content))
return image_pil
except requests.RequestException as error:
print(f"Error fetching image from {img} and error: {error}")
return None
def encode_img(self, img: str):
"""Encode the image to base64"""
with open(img, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def get_img(self, img: str):
"""Get the image from the path"""
image_pil = Image.open(img)
return image_pil
def clear_chat_history(self):
"""Clear the chat history"""
self.chat_history = []
def run_many(
self,
tasks: List[str],
imgs: List[str],
):
"""
Run the model on multiple tasks and images all at once using concurrent
Args:
tasks (List[str]): List of tasks
imgs (List[str]): List of image paths
Returns:
List[str]: List of responses
"""
# Instantiate the thread pool executor
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
results = executor.map(self.run, tasks, imgs)
# Print the results for debugging
for result in results:
print(result)
def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
"""Process a batch of tasks and images"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.run, task, img)
for task, img in tasks_images
]
results = [future.result() for future in futures]
return results
async def run_batch_async(
self, tasks_images: List[Tuple[str, str]]
) -> List[str]:
"""Process a batch of tasks and images asynchronously"""
loop = asyncio.get_event_loop()
futures = [
loop.run_in_executor(None, self.run, task, img)
for task, img in tasks_images
]
return await asyncio.gather(*futures)
async def run_batch_async_with_retries(
self, tasks_images: List[Tuple[str, str]]
) -> List[str]:
"""Process a batch of tasks and images asynchronously with retries"""
loop = asyncio.get_event_loop()
futures = [
loop.run_in_executor(None, self.run_with_retries, task, img)
for task, img in tasks_images
]
return await asyncio.gather(*futures)
def unique_chat_history(self):
"""Get the unique chat history"""
return list(set(self.chat_history))
def run_with_retries(self, task: str, img: str):
"""Run the model with retries"""
for i in range(self.retries):
try:
return self.run(task, img)
except Exception as error:
print(f"Error with the request {error}")
continue
def run_batch_with_retries(self, tasks_images: List[Tuple[str, str]]):
"""Run the model with retries"""
for i in range(self.retries):
try:
return self.run_batch(tasks_images)
except Exception as error:
print(f"Error with the request {error}")
continue
def _tokens_per_second(self) -> float:
"""Tokens per second"""
elapsed_time = self.end_time - self.start_time
if elapsed_time == 0:
return float("inf")
return self._num_tokens() / elapsed_time
def _time_for_generation(self, task: str) -> float:
"""Time for Generation"""
self.start_time = time.time()
self.run(task)
self.end_time = time.time()
return self.end_time - self.start_time
@abstractmethod
def generate_summary(self, text: str) -> str:
"""Generate Summary"""
pass
def set_temperature(self, value: float):
"""Set Temperature"""
self.temperature = value
def set_max_tokens(self, value: int):
"""Set new max tokens"""
self.max_tokens = value
def get_generation_time(self) -> float:
"""Get generation time"""
if self.start_time and self.end_time:
return self.end_time - self.start_time
return 0
def get_chat_history(self):
"""Get the chat history"""
return self.chat_history
def get_unique_chat_history(self):
"""Get the unique chat history"""
return list(set(self.chat_history))
def get_chat_history_length(self):
"""Get the chat history length"""
return len(self.chat_history)
def get_unique_chat_history_length(self):
"""Get the unique chat history length"""
return len(list(set(self.chat_history)))
def get_chat_history_tokens(self):
"""Get the chat history tokens"""
return self._num_tokens()

@ -63,9 +63,9 @@ class Fuyu:
def __call__(self, text: str, img: str): def __call__(self, text: str, img: str):
"""Call the model with text and img paths""" """Call the model with text and img paths"""
image_pil = Image.open(img) img = self.get_img(img)
model_inputs = self.processor( model_inputs = self.processor(
text=text, images=[image_pil], device=self.device_map text=text, images=[img], device=self.device_map
) )
for k, v in model_inputs.items(): for k, v in model_inputs.items():
@ -79,13 +79,13 @@ class Fuyu:
) )
return print(str(text)) return print(str(text))
def get_img_from_web(self, img_url: str): def get_img_from_web(self, img: str):
"""Get the image from the web""" """Get the image from the web"""
try: try:
response = requests.get(img_url) response = requests.get(img)
response.raise_for_status() response.raise_for_status()
image_pil = Image.open(BytesIO(response.content)) image_pil = Image.open(BytesIO(response.content))
return image_pil return image_pil
except requests.RequestException as error: except requests.RequestException as error:
print(f"Error fetching image from {img_url} and error: {error}") print(f"Error fetching image from {img} and error: {error}")
return None return None

@ -114,7 +114,6 @@ class GPT4VisionAPI:
except Exception as error: except Exception as error:
print(f"Error with the request: {error}") print(f"Error with the request: {error}")
raise error raise error
# Function to handle vision tasks
def __call__(self, task: str, img: str): def __call__(self, task: str, img: str):
"""Run the model.""" """Run the model."""

@ -18,38 +18,31 @@ def is_overlapping(rect1, rect2):
class Kosmos: class Kosmos:
""" """
Kosmos model by Yen-Chun Shieh
Parameters
----------
model_name : str
Path to the pretrained model
Examples
--------
>>> kosmos = Kosmos()
>>> kosmos("Hello, my name is", "path/to/image.png")
Args:
# Initialize Kosmos
kosmos = Kosmos()
# Perform multimodal grounding
kosmos.multimodal_grounding("Find the red apple in the image.", "https://example.com/apple.jpg")
# Perform referring expression comprehension
kosmos.referring_expression_comprehension("Show me the green bottle.", "https://example.com/bottle.jpg")
# Generate referring expressions
kosmos.referring_expression_generation("It is on the table.", "https://example.com/table.jpg")
# Perform grounded visual question answering
kosmos.grounded_vqa("What is the color of the car?", "https://example.com/car.jpg")
# Generate grounded image caption
kosmos.grounded_image_captioning("https://example.com/beach.jpg")
""" """
def __init__( def __init__(
self, self,
model_name="ydshieh/kosmos-2-patch14-224", model_name="ydshieh/kosmos-2-patch14-224",
*args,
**kwargs,
): ):
self.model = AutoModelForVision2Seq.from_pretrained( self.model = AutoModelForVision2Seq.from_pretrained(
model_name, trust_remote_code=True model_name, trust_remote_code=True, *args, **kwargs
) )
self.processor = AutoProcessor.from_pretrained( self.processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=True model_name, trust_remote_code=True, *args, **kwargs
) )
def get_image(self, url): def get_image(self, url):

@ -2,7 +2,7 @@ import os
import subprocess import subprocess
try: try:
import whisperx import swarms.models.whisperx_model as whisperx_model
from pydub import AudioSegment from pydub import AudioSegment
from pytube import YouTube from pytube import YouTube
except Exception as error: except Exception as error:
@ -66,17 +66,17 @@ class WhisperX:
compute_type = "float16" compute_type = "float16"
# 1. Transcribe with original Whisper (batched) 🗣️ # 1. Transcribe with original Whisper (batched) 🗣️
model = whisperx.load_model( model = whisperx_model.load_model(
"large-v2", device, compute_type=compute_type "large-v2", device, compute_type=compute_type
) )
audio = whisperx.load_audio(audio_file) audio = whisperx_model.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size) result = model.transcribe(audio, batch_size=batch_size)
# 2. Align Whisper output 🔍 # 2. Align Whisper output 🔍
model_a, metadata = whisperx.load_align_model( model_a, metadata = whisperx_model.load_align_model(
language_code=result["language"], device=device language_code=result["language"], device=device
) )
result = whisperx.align( result = whisperx_model.align(
result["segments"], result["segments"],
model_a, model_a,
metadata, metadata,
@ -86,7 +86,7 @@ class WhisperX:
) )
# 3. Assign speaker labels 🏷️ # 3. Assign speaker labels 🏷️
diarize_model = whisperx.DiarizationPipeline( diarize_model = whisperx_model.DiarizationPipeline(
use_auth_token=self.hf_api_key, device=device use_auth_token=self.hf_api_key, device=device
) )
diarize_model(audio_file) diarize_model(audio_file)
@ -99,16 +99,16 @@ class WhisperX:
print("The key 'segments' is not found in the result.") print("The key 'segments' is not found in the result.")
def transcribe(self, audio_file): def transcribe(self, audio_file):
model = whisperx.load_model("large-v2", self.device, self.compute_type) model = whisperx_model.load_model("large-v2", self.device, self.compute_type)
audio = whisperx.load_audio(audio_file) audio = whisperx_model.load_audio(audio_file)
result = model.transcribe(audio, batch_size=self.batch_size) result = model.transcribe(audio, batch_size=self.batch_size)
# 2. Align Whisper output 🔍 # 2. Align Whisper output 🔍
model_a, metadata = whisperx.load_align_model( model_a, metadata = whisperx_model.load_align_model(
language_code=result["language"], device=self.device language_code=result["language"], device=self.device
) )
result = whisperx.align( result = whisperx_model.align(
result["segments"], result["segments"],
model_a, model_a,
metadata, metadata,
@ -118,7 +118,7 @@ class WhisperX:
) )
# 3. Assign speaker labels 🏷️ # 3. Assign speaker labels 🏷️
diarize_model = whisperx.DiarizationPipeline( diarize_model = whisperx_model.DiarizationPipeline(
use_auth_token=self.hf_api_key, device=self.device use_auth_token=self.hf_api_key, device=self.device
) )

@ -7,7 +7,7 @@ import pytest
import whisperx import whisperx
from pydub import AudioSegment from pydub import AudioSegment
from pytube import YouTube from pytube import YouTube
from swarms.models.whisperx import WhisperX from swarms.models.whisperx_model import WhisperX
# Fixture to create a temporary directory for testing # Fixture to create a temporary directory for testing

Loading…
Cancel
Save