whisper x tool

main
Kye 2 years ago
parent bc0a736086
commit 589d4e1aec

@ -129,3 +129,5 @@ Thank you for being a part of our project!
* Create benchmrks * Create benchmrks
* Create evaluations * Create evaluations
* Add new tool that uses WhiseperX to transcribe a youtube video

@ -48,3 +48,9 @@ google-auth-oauthli
google-auth-httplib2 google-auth-httplib2
beautifulsoup4 beautifulsoup4
O365 O365
# whisperx
pytube
pydub
git+https://github.com/m-bain/whisperx.git@v3

@ -2211,3 +2211,126 @@ agent_executor = create_vectorstore_router_agent(
llm=llm, toolkit=router_toolkit, verbose=True llm=llm, toolkit=router_toolkit, verbose=True
) )
############################################### ===========================> Whisperx speech to text
import os
from pydantic import BaseModel, Field
from pydub import AudioSegment
from pytube import YouTube
import whisperx
from langchain.tools import tool
# define a custom input schema for the youtube url
class YouTubeVideoInput(BaseModel):
video_url: str = Field(description="YouTube Video URL to transcribe")
def download_youtube_video(video_url, audio_format='mp3'):
audio_file = f'video.{audio_format}'
# Download video
yt = YouTube(video_url)
yt_stream = yt.streams.filter(only_audio=True).first()
yt_stream.download(filename='video.mp4')
# Convert video to audio
video = AudioSegment.from_file("video.mp4", format="mp4")
video.export(audio_file, format=audio_format)
os.remove("video.mp4")
return audio_file
@tool("transcribe_youtube_video", args_schema=YouTubeVideoInput, return_direct=True)
def transcribe_youtube_video(video_url: str) -> str:
"""Transcribes a YouTube video."""
audio_file = download_youtube_video(video_url)
device = "cuda"
batch_size = 16
compute_type = "float16"
# 1. Transcribe with original Whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type)
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
# 2. Align Whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
# 3. Assign speaker labels
diarize_model = whisperx.DiarizationPipeline(use_auth_token='hugging face stable api key', device=device)
diarize_segments = diarize_model(audio_file)
try:
segments = result["segments"]
transcription = " ".join(segment['text'] for segment in segments)
return transcription
except KeyError:
print("The key 'segments' is not found in the result.")
###################################################
from typing import Optional, Type
from pydantic import BaseModel, Field
from langchain.tools import BaseTool
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
import requests
import whisperx
class AudioInput(BaseModel):
audio_file: str = Field(description="Path to audio file")
class TranscribeAudioTool(BaseTool):
name = "transcribe_audio"
description = "Transcribes an audio file using WhisperX"
args_schema: Type[AudioInput] = AudioInput
def _run(
self,
audio_file: str,
device: str = "cuda",
batch_size: int = 16,
compute_type: str = "float16",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
model = whisperx.load_model("large-v2", device, compute_type=compute_type)
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
diarize_model = whisperx.DiarizationPipeline(use_auth_token='hugging face stable api key', device=device)
diarize_segments = diarize_model(audio_file)
try:
segments = result["segments"]
transcription = " ".join(segment['text'] for segment in segments)
return transcription
except KeyError:
print("The key 'segments' is not found in the result.")
async def _arun(
self,
audio_file: str,
device: str = "cuda",
batch_size: int = 16,
compute_type: str = "float16",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("transcribe_audio does not support async")

Loading…
Cancel
Save