From 75ebbe04f8ceabb85149afac9a177c25ce699dcc Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 4 Nov 2023 16:53:48 -0400 Subject: [PATCH] distilled whisperx --- example.py | 2 +- groupchat.py | 12 +-- swarms/models/__init__.py | 1 + swarms/models/distilled_whisperx.py | 161 +++++++++++++++++++++++++++- 4 files changed, 167 insertions(+), 9 deletions(-) diff --git a/example.py b/example.py index eb750eb7..aeae1b02 100644 --- a/example.py +++ b/example.py @@ -8,7 +8,7 @@ llm = OpenAIChat( # model_name="gpt-4" openai_api_key=api_key, temperature=0.5, - #max_tokens=100, + # max_tokens=100, ) ## Initialize the workflow diff --git a/groupchat.py b/groupchat.py index a97fbdd4..739181d1 100644 --- a/groupchat.py +++ b/groupchat.py @@ -15,28 +15,28 @@ flow1 = Flow( llm=llm, max_loops=1, system_message="YOU ARE SILLY, YOU OFFER NOTHING OF VALUE", - name='silly', + name="silly", dashboard=True, ) flow2 = Flow( llm=llm, max_loops=1, system_message="YOU ARE VERY SMART AND ANSWER RIDDLES", - name='detective', + name="detective", dashboard=True, ) flow3 = Flow( llm=llm, max_loops=1, system_message="YOU MAKE RIDDLES", - name='riddler', + name="riddler", dashboard=True, ) manager = Flow( llm=llm, max_loops=1, system_message="YOU ARE A GROUP CHAT MANAGER", - name='manager', + name="manager", dashboard=True, ) @@ -45,5 +45,5 @@ manager = Flow( agents = [flow1, flow2, flow3] group_chat = GroupChat(agents=agents, messages=[], max_round=10) -chat_manager = GroupChatManager(groupchat=group_chat, selector = manager) -chat_history = chat_manager("Write me a riddle") \ No newline at end of file +chat_manager = GroupChatManager(groupchat=group_chat, selector=manager) +chat_history = chat_manager("Write me a riddle") diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 328dd013..4cb61b9a 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -16,6 +16,7 @@ from swarms.models.kosmos_two import Kosmos from swarms.models.vilt import Vilt from swarms.models.nougat import Nougat from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA +# from swarms.models.distilled_whisperx import DistilWhisperModel # from swarms.models.fuyu import Fuyu # Not working, wait until they update import sys diff --git a/swarms/models/distilled_whisperx.py b/swarms/models/distilled_whisperx.py index 8062daa4..0a60aaac 100644 --- a/swarms/models/distilled_whisperx.py +++ b/swarms/models/distilled_whisperx.py @@ -1,3 +1,160 @@ -""" +import asyncio +import os +import time +from functools import wraps +from typing import Union -""" +import torch +from termcolor import colored +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline + + +def async_retry(max_retries=3, exceptions=(Exception,), delay=1): + """ + A decorator for adding retry logic to async functions. + :param max_retries: Maximum number of retries before giving up. + :param exceptions: A tuple of exceptions to catch and retry on. + :param delay: Delay between retries. + """ + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + retries = max_retries + while retries: + try: + return await func(*args, **kwargs) + except exceptions as e: + retries -= 1 + if retries <= 0: + raise + print(f"Retry after exception: {e}, Attempts remaining: {retries}") + await asyncio.sleep(delay) + + return wrapper + + return decorator + + +class DistilWhisperModel: + """ + This class encapsulates the Distil-Whisper model for English speech recognition. + It allows for both synchronous and asynchronous transcription of short and long-form audio. + + Args: + model_id: The model ID to use. Defaults to "distil-whisper/distil-large-v2". + + + Attributes: + device: The device to use for inference. + torch_dtype: The torch data type to use for inference. + model_id: The model ID to use. + model: The model instance. + processor: The processor instance. + + Usage: + model_wrapper = DistilWhisperModel() + transcription = model_wrapper('path/to/audio.mp3') + + # For async usage + transcription = asyncio.run(model_wrapper.async_transcribe('path/to/audio.mp3')) + """ + + def __init__(self, model_id="distil-whisper/distil-large-v2"): + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + self.model_id = model_id + self.model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, + torch_dtype=self.torch_dtype, + low_cpu_mem_usage=True, + use_safetensors=True, + ).to(self.device) + self.processor = AutoProcessor.from_pretrained(model_id) + + def __call__(self, inputs: Union[str, dict]): + return self.transcribe(inputs) + + def transcribe(self, inputs: Union[str, dict]): + """ + Synchronously transcribe the given audio input using the Distil-Whisper model. + :param inputs: A string representing the file path or a dict with audio data. + :return: The transcribed text. + """ + pipe = pipeline( + "automatic-speech-recognition", + model=self.model, + tokenizer=self.processor.tokenizer, + feature_extractor=self.processor.feature_extractor, + max_new_tokens=128, + torch_dtype=self.torch_dtype, + device=self.device, + ) + + return pipe(inputs)["text"] + + @async_retry() + async def async_transcribe(self, inputs: Union[str, dict]): + """ + Asynchronously transcribe the given audio input using the Distil-Whisper model. + :param inputs: A string representing the file path or a dict with audio data. + :return: The transcribed text. + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.transcribe, inputs) + + def real_time_transcribe(self, audio_file_path, chunk_duration=5): + """ + Simulates real-time transcription of an audio file, processing and printing results + in chunks with colored output for readability. + + :param audio_file_path: Path to the audio file to be transcribed. + :param chunk_duration: Duration in seconds of each audio chunk to be processed. + """ + if not os.path.isfile(audio_file_path): + print(colored("The audio file was not found.", "red")) + return + + # Assuming `chunk_duration` is in seconds and `processor` can handle chunk-wise processing + try: + with torch.no_grad(): + # Load the whole audio file, but process and transcribe it in chunks + audio_input = self.processor.audio_file_to_array(audio_file_path) + sample_rate = audio_input.sampling_rate + total_duration = len(audio_input.array) / sample_rate + chunks = [ + audio_input.array[i : i + sample_rate * chunk_duration] + for i in range( + 0, len(audio_input.array), sample_rate * chunk_duration + ) + ] + + print(colored("Starting real-time transcription...", "green")) + + for i, chunk in enumerate(chunks): + # Process the current chunk + processed_inputs = self.processor( + chunk, + sampling_rate=sample_rate, + return_tensors="pt", + padding=True, + ) + processed_inputs = processed_inputs.input_values.to(self.device) + + # Generate transcription for the chunk + logits = self.model.generate(processed_inputs) + transcription = self.processor.batch_decode( + logits, skip_special_tokens=True + )[0] + + # Print the chunk's transcription + print( + colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") + + transcription + ) + + # Wait for the chunk's duration to simulate real-time processing + time.sleep(chunk_duration) + + except Exception as e: + print(colored(f"An error occurred during transcription: {e}", "red"))