[FIXES][Qdrant, WhisperX][Import Catches with install]

pull/286/head
Kye 1 year ago
parent 7529c5280e
commit f68832af38

@@ -1,14 +1,29 @@
+import subprocess
 import uuid
 from typing import Optional
 from attr import define, field, Factory
 from dataclasses import dataclass
 from swarms.memory.base import BaseVectorStore
-from sqlalchemy.engine import Engine
-from sqlalchemy import create_engine, Column, String, JSON
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.dialects.postgresql import UUID
-from sqlalchemy.orm import Session
-from pgvector.sqlalchemy import Vector
+
+try:
+    from sqlalchemy.engine import Engine
+    from sqlalchemy import create_engine, Column, String, JSON
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy.dialects.postgresql import UUID
+    from sqlalchemy.orm import Session
+except ImportError:
+    print("The PgVectorVectorStore requires sqlalchemy to be installed")
+    print("pip install sqlalchemy")
+    subprocess.run(["pip", "install", "sqlalchemy"])
+
+try:
+    from pgvector.sqlalchemy import Vector
+except ImportError:
+    print("The PgVectorVectorStore requires pgvector to be installed")
+    print("pip install pgvector")
+    subprocess.run(["pip", "install", "pgvector"])
 
 
 @define
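Review note: the new guard prints an install hint and shells out to pip, but the current process leaves the `except` block with the module still unbound, so the classes defined below would still raise a `NameError` if SQLAlchemy or pgvector is genuinely missing. A minimal sketch of one way to retry the import after installing, assuming a hypothetical `_import_or_install` helper (not part of this commit) and using the active interpreter's pip:

```python
import importlib
import subprocess
import sys


def _import_or_install(module_name: str, pip_name: str = ""):
    """Import module_name; if missing, pip-install pip_name and retry once."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        package = pip_name or module_name
        print(f"{module_name} is not installed; running pip install {package}")
        # Install into the interpreter that is actually running this code.
        subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
        # Make sure the freshly installed package is picked up, then retry.
        importlib.invalidate_caches()
        return importlib.import_module(module_name)


# Hypothetical usage for the pgvector dependency guarded above.
Vector = _import_or_install("pgvector.sqlalchemy", "pgvector").Vector
```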

@@ -1,12 +1,29 @@
+import subprocess
 from typing import List
-from sentence_transformers import SentenceTransformer
 from httpx import RequestError
-from qdrant_client import QdrantClient
-from qdrant_client.http.models import (
-    Distance,
-    VectorParams,
-    PointStruct,
-)
+
+try:
+    from sentence_transformers import SentenceTransformer
+except ImportError:
+    print("Please install the sentence-transformers package")
+    print("pip install sentence-transformers")
+    print("pip install qdrant-client")
+    subprocess.run(["pip", "install", "sentence-transformers"])
+
+try:
+    from qdrant_client import QdrantClient
+    from qdrant_client.http.models import (
+        Distance,
+        VectorParams,
+        PointStruct,
+    )
+except ImportError:
+    print("Please install the qdrant-client package")
+    print("pip install qdrant-client")
+    subprocess.run(["pip", "install", "qdrant-client"])
 
 
 class Qdrant:
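For context, a minimal sketch (not from this commit) of how the guarded imports fit together: `SentenceTransformer` produces the embeddings, `VectorParams` and `Distance` describe the collection, and `PointStruct` is the record that `QdrantClient` upserts. The collection name and embedding model below are placeholders, and the in-memory client is only for illustration:

```python
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder embedding model
client = QdrantClient(":memory:")                # throwaway in-memory instance

# VectorParams fixes the vector size and similarity metric for the collection.
client.create_collection(
    collection_name="demo",
    vectors_config=VectorParams(
        size=model.get_sentence_embedding_dimension(),
        distance=Distance.COSINE,
    ),
)

# PointStruct is one stored record: id, vector, and an optional payload.
vector = model.encode("hello world").tolist()
client.upsert(
    collection_name="demo",
    points=[PointStruct(id=1, vector=vector, payload={"text": "hello world"})],
)

hits = client.search(collection_name="demo", query_vector=vector, limit=1)
print(hits)
```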

@@ -2,7 +2,7 @@ import os
 import subprocess
 
 try:
-    import swarms.models.whisperx_model as whisperx_model
+    import whisperx
     from pydub import AudioSegment
     from pytube import YouTube
 except Exception as error:
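If the catch-and-install convention used for the Qdrant and pgvector modules were extended here as well (the commit itself keeps the existing broad `except Exception as error:` handler), a sketch could look like the following; the package names simply mirror the imports above and the block is only an illustration, not part of this diff:

```python
# Sketch only: same catch-and-install pattern applied to this module's imports.
import subprocess
import sys

try:
    import whisperx
    from pydub import AudioSegment
    from pytube import YouTube
except ImportError as error:
    print(f"Missing dependency for WhisperX: {error}")
    print("pip install whisperx pydub pytube")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "whisperx", "pydub", "pytube"],
        check=False,
    )
```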
@@ -66,17 +66,17 @@ class WhisperX:
         compute_type = "float16"
 
         # 1. Transcribe with original Whisper (batched) 🗣️
-        model = whisperx_model.load_model(
+        model = whisperx.load_model(
             "large-v2", device, compute_type=compute_type
         )
-        audio = whisperx_model.load_audio(audio_file)
+        audio = whisperx.load_audio(audio_file)
         result = model.transcribe(audio, batch_size=batch_size)
 
         # 2. Align Whisper output 🔍
-        model_a, metadata = whisperx_model.load_align_model(
+        model_a, metadata = whisperx.load_align_model(
             language_code=result["language"], device=device
         )
-        result = whisperx_model.align(
+        result = whisperx.align(
             result["segments"],
             model_a,
             metadata,
@@ -86,7 +86,7 @@ class WhisperX:
         )
 
         # 3. Assign speaker labels 🏷️
-        diarize_model = whisperx_model.DiarizationPipeline(
+        diarize_model = whisperx.DiarizationPipeline(
             use_auth_token=self.hf_api_key, device=device
         )
         diarize_model(audio_file)
@@ -101,18 +101,18 @@ class WhisperX:
             print("The key 'segments' is not found in the result.")
 
 
     def transcribe(self, audio_file):
-        model = whisperx_model.load_model(
+        model = whisperx.load_model(
             "large-v2", self.device, self.compute_type
         )
-        audio = whisperx_model.load_audio(audio_file)
+        audio = whisperx.load_audio(audio_file)
         result = model.transcribe(audio, batch_size=self.batch_size)
 
         # 2. Align Whisper output 🔍
-        model_a, metadata = whisperx_model.load_align_model(
+        model_a, metadata = whisperx.load_align_model(
             language_code=result["language"], device=self.device
         )
-        result = whisperx_model.align(
+        result = whisperx.align(
             result["segments"],
             model_a,
             metadata,
@@ -122,7 +122,7 @@ class WhisperX:
         )
 
         # 3. Assign speaker labels 🏷️
-        diarize_model = whisperx_model.DiarizationPipeline(
+        diarize_model = whisperx.DiarizationPipeline(
             use_auth_token=self.hf_api_key, device=self.device
         )
 
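Taken together, the renamed calls follow the whisperx package's documented transcribe → align → diarize flow. A self-contained sketch of that flow, with the audio path, batch size, and Hugging Face token as placeholders (the final `assign_word_speakers` step comes from whisperx's usage docs and is not shown in this diff):

```python
import whisperx

device = "cuda"
audio_file = "audio.mp3"   # placeholder path
hf_api_key = "hf_..."      # placeholder Hugging Face token for diarization

# 1. Transcribe with the batched Whisper model.
model = whisperx.load_model("large-v2", device, compute_type="float16")
audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=16)

# 2. Align the transcript to get word-level timestamps.
model_a, metadata = whisperx.load_align_model(
    language_code=result["language"], device=device
)
result = whisperx.align(
    result["segments"], model_a, metadata, audio, device,
    return_char_alignments=False,
)

# 3. Assign speaker labels with the diarization pipeline.
diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_api_key, device=device)
diarize_segments = diarize_model(audio)
result = whisperx.assign_word_speakers(diarize_segments, result)
print(result["segments"])
```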
