From 589d4e1aecf649bae333f23d7d7252dd4e659fb7 Mon Sep 17 00:00:00 2001
From: Kye
Date: Thu, 29 Jun 2023 11:12:35 -0400
Subject: [PATCH] whisper x tool

---
 README.md            |   2 +
 requirements.txt     |   8 ++-
 swarms/tools/main.py | 123 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 132 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 0d981040..b8456cae 100644
--- a/README.md
+++ b/README.md
@@ -129,3 +129,5 @@ Thank you for being a part of our project!
 
 * Create benchmrks
 * Create evaluations
+
+* Add a new tool that uses WhisperX to transcribe a YouTube video
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 857913ae..bf3c5982 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -47,4 +47,10 @@ oogle-api-python-client
 google-auth-oauthli
 google-auth-httplib2
 beautifulsoup4
-O365
\ No newline at end of file
+O365
+
+
+# whisperx
+pytube
+pydub
+git+https://github.com/m-bain/whisperx.git@v3
\ No newline at end of file
diff --git a/swarms/tools/main.py b/swarms/tools/main.py
index 71ec3dc2..2d5572f1 100644
--- a/swarms/tools/main.py
+++ b/swarms/tools/main.py
@@ -2211,3 +2211,126 @@ agent_executor = create_vectorstore_router_agent(
     llm=llm, toolkit=router_toolkit, verbose=True
 )
 
+
+
+
+
+
+############################################### ===========================> WhisperX speech to text
+import os
+
+from pydantic import BaseModel, Field
+from pydub import AudioSegment
+from pytube import YouTube
+import whisperx
+from langchain.tools import tool
+
+
+# Define a custom input schema for the YouTube URL
+class YouTubeVideoInput(BaseModel):
+    video_url: str = Field(description="YouTube Video URL to transcribe")
+
+
+def download_youtube_video(video_url, audio_format='mp3'):
+    audio_file = f'video.{audio_format}'
+
+    # Download the audio-only stream of the video
+    yt = YouTube(video_url)
+    yt_stream = yt.streams.filter(only_audio=True).first()
+    yt_stream.download(filename='video.mp4')
+
+    # Convert the downloaded stream to the requested audio format
+    video = AudioSegment.from_file("video.mp4", format="mp4")
+    video.export(audio_file, format=audio_format)
+    os.remove("video.mp4")
+
+    return audio_file
+
+
+@tool("transcribe_youtube_video", args_schema=YouTubeVideoInput, return_direct=True)
+def transcribe_youtube_video(video_url: str) -> str:
+    """Transcribes a YouTube video."""
+    audio_file = download_youtube_video(video_url)
+
+    # Assumes a CUDA GPU; on CPU-only machines use device="cpu" with
+    # compute_type="int8"
+    device = "cuda"
+    batch_size = 16
+    compute_type = "float16"
+
+    # 1. Transcribe with original Whisper (batched)
+    model = whisperx.load_model("large-v2", device, compute_type=compute_type)
+    audio = whisperx.load_audio(audio_file)
+    result = model.transcribe(audio, batch_size=batch_size)
+
+    # 2. Align Whisper output
+    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+
+    # 3. Assign speaker labels; diarization needs a Hugging Face access
+    # token, read here from the HF_TOKEN environment variable instead of
+    # being hard-coded
+    diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ.get("HF_TOKEN"), device=device)
+    diarize_segments = diarize_model(audio_file)
+    result = whisperx.assign_word_speakers(diarize_segments, result)
+
+    try:
+        segments = result["segments"]
+        transcription = " ".join(segment['text'] for segment in segments)
+        return transcription
+    except KeyError:
+        return "Transcription failed: no 'segments' key in the WhisperX result."
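+
+# Illustrative usage (comment only, not executed): an agent or a plain
+# script can call the tool directly. The URL below is a placeholder, and a
+# CUDA GPU plus an HF_TOKEN environment variable are assumed:
+#
+#   transcription = transcribe_youtube_video.run(
+#       "https://www.youtube.com/watch?v=<VIDEO_ID>"
+#   )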
+
+
+###################################################
+from typing import Optional, Type
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain.tools import BaseTool
+
+
+class AudioInput(BaseModel):
+    audio_file: str = Field(description="Path to audio file")
+
+
+class TranscribeAudioTool(BaseTool):
+    name = "transcribe_audio"
+    description = "Transcribes an audio file using WhisperX"
+    args_schema: Type[AudioInput] = AudioInput
+
+    def _run(
+        self,
+        audio_file: str,
+        device: str = "cuda",
+        batch_size: int = 16,
+        compute_type: str = "float16",
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Use the tool."""
+        # Transcribe, then align the output for word-level timestamps
+        model = whisperx.load_model("large-v2", device, compute_type=compute_type)
+        audio = whisperx.load_audio(audio_file)
+        result = model.transcribe(audio, batch_size=batch_size)
+
+        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+        result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+
+        # Assign speaker labels; the Hugging Face token is read from the
+        # HF_TOKEN environment variable rather than hard-coded
+        diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ.get("HF_TOKEN"), device=device)
+        diarize_segments = diarize_model(audio_file)
+        result = whisperx.assign_word_speakers(diarize_segments, result)
+
+        try:
+            segments = result["segments"]
+            transcription = " ".join(segment['text'] for segment in segments)
+            return transcription
+        except KeyError:
+            return "Transcription failed: no 'segments' key in the WhisperX result."
+
+    async def _arun(
+        self,
+        audio_file: str,
+        device: str = "cuda",
+        batch_size: int = 16,
+        compute_type: str = "float16",
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        """Use the tool asynchronously."""
+        raise NotImplementedError("transcribe_audio does not support async")
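
Usage sketch (illustrative, not part of the patch): TranscribeAudioTool follows
the standard LangChain BaseTool interface, so once the patch is applied it can
be run on a local audio file or handed to an agent alongside the other tools in
swarms/tools/main.py. The file name below is a placeholder; a CUDA GPU and an
HF_TOKEN environment variable are assumed.

    from swarms.tools.main import TranscribeAudioTool

    tool = TranscribeAudioTool()
    # BaseTool.run accepts a dict matching the args_schema (AudioInput)
    print(tool.run({"audio_file": "meeting.mp3"}))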