From 8c0a1dedfa97fc570650cb964c6d74398d975a2b Mon Sep 17 00:00:00 2001 From: Kye Date: Thu, 7 Dec 2023 13:10:01 -0800 Subject: [PATCH] [FIXES][Qdrant, WhisperX][Import Catches with install] --- swarms/memory/pg.py | 27 +++++++++++++++++++++------ swarms/memory/qdrant.py | 31 ++++++++++++++++++++++++------- swarms/models/whisperx_model.py | 22 +++++++++++----------- 3 files changed, 56 insertions(+), 24 deletions(-) diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index 334ccf70..1d44984b 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -1,14 +1,29 @@ +import subprocess import uuid from typing import Optional from attr import define, field, Factory from dataclasses import dataclass from swarms.memory.base import BaseVectorStore -from sqlalchemy.engine import Engine -from sqlalchemy import create_engine, Column, String, JSON -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Session -from pgvector.sqlalchemy import Vector + +try: + from sqlalchemy.engine import Engine + from sqlalchemy import create_engine, Column, String, JSON + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.dialects.postgresql import UUID + from sqlalchemy.orm import Session +except ImportError: + print("The PgVectorVectorStore requires sqlalchemy to be installed") + print("pip install sqlalchemy") + subprocess.run(["pip", "install", "sqlalchemy"]) + +try: + + from pgvector.sqlalchemy import Vector +except ImportError: + print("The PgVectorVectorStore requires pgvector to be installed") + print("pip install pgvector") + subprocess.run(["pip", "install", "pgvector"]) + @define diff --git a/swarms/memory/qdrant.py b/swarms/memory/qdrant.py index 76a5785b..14a63e0b 100644 --- a/swarms/memory/qdrant.py +++ b/swarms/memory/qdrant.py @@ -1,12 +1,29 @@ +import subprocess from typing import List -from sentence_transformers import SentenceTransformer from httpx import RequestError 
-from qdrant_client import QdrantClient -from qdrant_client.http.models import ( - Distance, - VectorParams, - PointStruct, -) + +try: + + from sentence_transformers import SentenceTransformer +except ImportError: + print("Please install the sentence-transformers package") + print("pip install sentence-transformers") + print("pip install qdrant-client") + subprocess.run(["pip", "install", "sentence-transformers"]) + + +try: + + from qdrant_client import QdrantClient + from qdrant_client.http.models import ( + Distance, + VectorParams, + PointStruct, + ) +except ImportError: + print("Please install the qdrant-client package") + print("pip install qdrant-client") + subprocess.run(["pip", "install", "qdrant-client"]) class Qdrant: diff --git a/swarms/models/whisperx_model.py b/swarms/models/whisperx_model.py index a41d0430..e3b76fae 100644 --- a/swarms/models/whisperx_model.py +++ b/swarms/models/whisperx_model.py @@ -2,7 +2,7 @@ import os import subprocess try: - import swarms.models.whisperx_model as whisperx_model + import whisperx from pydub import AudioSegment from pytube import YouTube except Exception as error: @@ -66,17 +66,17 @@ class WhisperX: compute_type = "float16" # 1. Transcribe with original Whisper (batched) 🗣️ - model = whisperx_model.load_model( + model = whisperx.load_model( "large-v2", device, compute_type=compute_type ) - audio = whisperx_model.load_audio(audio_file) + audio = whisperx.load_audio(audio_file) result = model.transcribe(audio, batch_size=batch_size) # 2. Align Whisper output 🔍 - model_a, metadata = whisperx_model.load_align_model( + model_a, metadata = whisperx.load_align_model( language_code=result["language"], device=device ) - result = whisperx_model.align( + result = whisperx.align( result["segments"], model_a, metadata, @@ -86,7 +86,7 @@ class WhisperX: ) # 3. 
Assign speaker labels 🏷️ - diarize_model = whisperx_model.DiarizationPipeline( + diarize_model = whisperx.DiarizationPipeline( use_auth_token=self.hf_api_key, device=device ) diarize_model(audio_file) @@ -101,18 +101,18 @@ class WhisperX: print("The key 'segments' is not found in the result.") def transcribe(self, audio_file): - model = whisperx_model.load_model( + model = whisperx.load_model( "large-v2", self.device, self.compute_type ) - audio = whisperx_model.load_audio(audio_file) + audio = whisperx.load_audio(audio_file) result = model.transcribe(audio, batch_size=self.batch_size) # 2. Align Whisper output 🔍 - model_a, metadata = whisperx_model.load_align_model( + model_a, metadata = whisperx.load_align_model( language_code=result["language"], device=self.device ) - result = whisperx_model.align( + result = whisperx.align( result["segments"], model_a, metadata, @@ -122,7 +122,7 @@ class WhisperX: ) # 3. Assign speaker labels 🏷️ - diarize_model = whisperx_model.DiarizationPipeline( + diarize_model = whisperx.DiarizationPipeline( use_auth_token=self.hf_api_key, device=self.device )