[DEMOS][TESTS]

pull/334/head
Kye 1 year ago
parent f39d722f2a
commit fd58cfa2a1

@@ -0,0 +1,96 @@
import time
import os
import pygame
import speech_recognition as sr
from dotenv import load_dotenv
from playsound import playsound
from swarms import OpenAIChat, OpenAITTS
# Load the environment variables
load_dotenv()

# Get the API key from the environment
openai_api_key = os.environ.get("OPENAI_API_KEY")

# Initialize the language model
llm = OpenAIChat(
    openai_api_key=openai_api_key,
)

# Initialize the text-to-speech model
tts = OpenAITTS(
    model_name="tts-1-1106",
    voice="onyx",
    openai_api_key=openai_api_key,
    saved_filepath="runs/tts_speech.wav",
)

# Initialize the speech recognition model
r = sr.Recognizer()

def play_audio(file_path):
    # Check if the file exists
    if not os.path.isfile(file_path):
        print(f"Audio file {file_path} not found.")
        return

    # Initialize the mixer module
    pygame.mixer.init()

    try:
        # Load the mp3 file
        pygame.mixer.music.load(file_path)

        # Play the mp3 file
        pygame.mixer.music.play()

        # Wait for the audio to finish playing
        while pygame.mixer.music.get_busy():
            pygame.time.Clock().tick(10)
    except pygame.error as e:
        print(f"Couldn't play {file_path}: {e}")
    finally:
        # Stop the mixer module and free resources
        pygame.mixer.quit()

while True:
    # Listen for user speech
    with sr.Microphone() as source:
        print("Listening...")
        audio = r.listen(source)

    # Convert speech to text
    try:
        print("Recognizing...")
        task = r.recognize_google(audio)
        print(f"User said: {task}")
    except sr.UnknownValueError:
        print("Could not understand audio")
        continue
    except Exception as e:
        print(f"Error: {e}")
        continue

    # Run the GPT-4 model on the task
    print("Running GPT-4 model...")
    out = llm(task)
    print(f"GPT-4 output: {out}")

    # Convert the model output to speech
    print("Running text-to-speech model...")
    out = tts.run_and_save(out)
    print(f"Text-to-speech output: {out}")

    # Ask the user if they want to play the audio
    # play_audio = input("Do you want to play the audio? (yes/no): ")
    # if play_audio.lower() == "yes":
    # Initialize the mixer module
    # Play the audio file
    time.sleep(5)
    playsound('runs/tts_speech.wav')
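
For a quick check without a microphone, here is a minimal text-only sketch of the same pipeline, relying only on the calls used above (llm(task), tts.run_and_save(out), playsound) and an OPENAI_API_KEY in the environment; the typed task string is made up for illustration.

import os

from dotenv import load_dotenv
from playsound import playsound
from swarms import OpenAIChat, OpenAITTS

load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")

llm = OpenAIChat(openai_api_key=openai_api_key)
tts = OpenAITTS(
    model_name="tts-1-1106",
    voice="onyx",
    openai_api_key=openai_api_key,
    saved_filepath="runs/tts_speech.wav",
)

# Type the task instead of speaking it, then voice the model's reply.
task = "Give me a one-sentence summary of what a swarm of agents is."
out = llm(task)
print(f"Model output: {out}")

out = tts.run_and_save(out)
print(f"Text-to-speech output: {out}")
playsound("runs/tts_speech.wav")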

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "2.3.0"
version = "2.3.8"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@@ -50,7 +50,7 @@ class WeaviateDB(VectorDatabase):
        grpc_secure: Optional[bool] = None,
        auth_client_secret: Optional[Any] = None,
        additional_headers: Optional[Dict[str, str]] = None,
-        additional_config: Optional[weaviate.AdditionalConfig] = None,
+        additional_config: Optional[Any] = None,
        connection_params: Dict[str, Any] = None,
        *args,
        **kwargs,

@@ -108,7 +108,11 @@ class BaseMultiModalModel:
        pass

    def __call__(
-        self, task: str = None, img: str = None, *args, **kwargs
+        self,
+        task: Optional[str] = None,
+        img: Optional[str] = None,
+        *args,
+        **kwargs,
    ):
        """Call the model

@@ -39,14 +39,11 @@ class FastViT:
    Returns:
        ClassificationResult: a pydantic BaseModel containing the class ids and confidences of the model's predictions

    Example:
-        >>> fastvit = FastViT()
-        >>> result = fastvit(img="path_to_image.jpg", confidence_threshold=0.5)

    To use, create a json file called: fast_vit_classes.json
    """

    def __init__(self):
@@ -62,7 +59,7 @@ class FastViT:
    def __call__(
        self, img: str, confidence_threshold: float = 0.5
    ) -> ClassificationResult:
-        """classifies the input image and returns the top k classes and their probabilities"""
+        """Classifies the input image and returns the top k classes and their probabilities"""
        img = Image.open(img).convert("RGB")
        img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
@@ -81,7 +78,6 @@ class FastViT:
        # Convert to Python lists and map class indices to labels if needed
        top_probs = top_probs.cpu().numpy().tolist()
        top_classes = top_classes.cpu().numpy().tolist()
-        # top_class_labels = [FASTVIT_IMAGENET_1K_CLASSES[i] for i in top_classes]  # Uncomment if class labels are needed
        return ClassificationResult(
            class_id=top_classes, confidence=top_probs
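
For reference, a short usage sketch matching the signature and return value above; the import path and image path are assumptions, and fast_vit_classes.json must exist as the docstring notes.

from swarms.models.fastvit import FastViT  # import path assumed

fastvit = FastViT()
result = fastvit(img="path_to_image.jpg", confidence_threshold=0.5)

# ClassificationResult carries the top class ids and their probabilities.
print(result.class_id)
print(result.confidence)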

@@ -66,5 +66,5 @@ def check_device(
    return devices

-devices = check_device()
-logging.info(f"Using device(s): {devices}")
+# devices = check_device()
+# logging.info(f"Using device(s): {devices}")
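
With the module-level call commented out, device detection no longer runs or logs at import time; callers that still want that behaviour can invoke it explicitly. A minimal sketch, with the import path assumed:

import logging

from swarms.utils import check_device  # import path assumed

devices = check_device()
logging.info(f"Using device(s): {devices}")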

@@ -1,336 +0,0 @@
import os
import tempfile
from functools import wraps
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

from swarms.models.distilled_whisperx import (
    DistilWhisperModel,
    async_retry,
)


@pytest.fixture
def distil_whisper_model():
    return DistilWhisperModel()


def create_audio_file(
    data: np.ndarray, sample_rate: int, file_path: str
):
    data.tofile(file_path)
    return file_path


def test_initialization(distil_whisper_model):
    assert isinstance(distil_whisper_model, DistilWhisperModel)
    assert isinstance(distil_whisper_model.model, torch.nn.Module)
    assert isinstance(distil_whisper_model.processor, torch.nn.Module)
    assert distil_whisper_model.device in ["cpu", "cuda:0"]


def test_transcribe_audio_file(distil_whisper_model):
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = create_audio_file(
            test_data, 16000, audio_file.name
        )
        transcription = distil_whisper_model.transcribe(
            audio_file_path
        )
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_file(distil_whisper_model):
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = create_audio_file(
            test_data, 16000, audio_file.name
        )
        transcription = await distil_whisper_model.async_transcribe(
            audio_file_path
        )
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_transcribe_audio_data(distil_whisper_model):
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    transcription = distil_whisper_model.transcribe(
        test_data.tobytes()
    )

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_data(distil_whisper_model):
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    transcription = await distil_whisper_model.async_transcribe(
        test_data.tobytes()
    )

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_real_time_transcribe(distil_whisper_model, capsys):
    test_data = np.random.rand(
        16000 * 5
    )  # Simulated audio data (5 seconds)
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = create_audio_file(
            test_data, 16000, audio_file.name
        )
        distil_whisper_model.real_time_transcribe(
            audio_file_path, chunk_duration=1
        )
        os.remove(audio_file_path)

    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out
    assert "Chunk" in captured.out


def test_real_time_transcribe_audio_file_not_found(
    distil_whisper_model, capsys
):
    audio_file_path = "non_existent_audio.wav"
    distil_whisper_model.real_time_transcribe(
        audio_file_path, chunk_duration=1
    )

    captured = capsys.readouterr()
    assert "The audio file was not found." in captured.out


@pytest.fixture
def mock_async_retry():
    def _mock_async_retry(
        retries=3, exceptions=(Exception,), delay=1
    ):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return wrapper

        return decorator

    with patch(
        "distil_whisper_model.async_retry", new=_mock_async_retry()
    ):
        yield


@pytest.mark.asyncio
async def test_async_retry_decorator_success():
    async def mock_async_function():
        return "Success"

    decorated_function = async_retry()(mock_async_function)
    result = await decorated_function()
    assert result == "Success"


@pytest.mark.asyncio
async def test_async_retry_decorator_failure():
    async def mock_async_function():
        raise Exception("Error")

    decorated_function = async_retry()(mock_async_function)
    with pytest.raises(Exception, match="Error"):
        await decorated_function()


@pytest.mark.asyncio
async def test_async_retry_decorator_multiple_attempts():
    async def mock_async_function():
        if mock_async_function.attempts == 0:
            mock_async_function.attempts += 1
            raise Exception("Error")
        else:
            return "Success"

    mock_async_function.attempts = 0
    decorated_function = async_retry(max_retries=2)(
        mock_async_function
    )
    result = await decorated_function()
    assert result == "Success"


def test_create_audio_file():
    test_data = np.random.rand(
        16000
    )  # Simulated audio data (1 second)
    sample_rate = 16000
    with tempfile.NamedTemporaryFile(
        suffix=".wav", delete=False
    ) as audio_file:
        audio_file_path = create_audio_file(
            test_data, sample_rate, audio_file.name
        )
        assert os.path.exists(audio_file_path)
        os.remove(audio_file_path)


# test_distilled_whisperx.py

# Fixtures for setting up model, processor, and audio files
@pytest.fixture(scope="module")
def model_id():
    return "distil-whisper/distil-large-v2"


@pytest.fixture(scope="module")
def whisper_model(model_id):
    return DistilWhisperModel(model_id)


@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
    # You would create a small temporary MP3 file here for testing
    # or use a public domain MP3 file's path
    return "path/to/valid_audio.mp3"


@pytest.fixture(scope="session")
def invalid_audio_file_path():
    return "path/to/invalid_audio.mp3"


@pytest.fixture(scope="session")
def audio_dict():
    # This should represent a valid audio dictionary as expected by the model
    return {"array": torch.randn(1, 16000), "sampling_rate": 16000}


# Test initialization
def test_initialization(whisper_model):
    assert whisper_model.model is not None
    assert whisper_model.processor is not None


# Test successful transcription with file path
def test_transcribe_with_file_path(whisper_model, audio_file_path):
    transcription = whisper_model.transcribe(audio_file_path)
    assert isinstance(transcription, str)


# Test successful transcription with audio dict
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
    transcription = whisper_model.transcribe(audio_dict)
    assert isinstance(transcription, str)


# Test for file not found error
def test_file_not_found(whisper_model, invalid_audio_file_path):
    with pytest.raises(Exception):
        whisper_model.transcribe(invalid_audio_file_path)


# Asynchronous tests
@pytest.mark.asyncio
async def test_async_transcription_success(
    whisper_model, audio_file_path
):
    transcription = await whisper_model.async_transcribe(
        audio_file_path
    )
    assert isinstance(transcription, str)


@pytest.mark.asyncio
async def test_async_transcription_failure(
    whisper_model, invalid_audio_file_path
):
    with pytest.raises(Exception):
        await whisper_model.async_transcribe(invalid_audio_file_path)


# Testing real-time transcription simulation
def test_real_time_transcription(
    whisper_model, audio_file_path, capsys
):
    whisper_model.real_time_transcribe(
        audio_file_path, chunk_duration=1
    )
    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out


# Testing retry decorator for asynchronous function
@pytest.mark.asyncio
async def test_async_retry():
    @async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
    async def failing_func():
        raise ValueError("Test")

    with pytest.raises(ValueError):
        await failing_func()


# Mocking the actual model to avoid GPU/CPU intensive operations during test
@pytest.fixture
def mocked_model(monkeypatch):
    model_mock = AsyncMock(AutoModelForSpeechSeq2Seq)
    processor_mock = MagicMock(AutoProcessor)
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
        model_mock,
    )
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
        processor_mock,
    )
    return model_mock, processor_mock


@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(
    mocked_model, audio_file_path
):
    model_mock, processor_mock = mocked_model
    # Set up what the mock should return when it's called
    model_mock.return_value.generate.return_value = torch.tensor(
        [[0]]
    )
    processor_mock.return_value.batch_decode.return_value = [
        "mocked transcription"
    ]
    model_wrapper = DistilWhisperModel()
    transcription = await model_wrapper.async_transcribe(
        audio_file_path
    )
    assert transcription == "mocked transcription"

@@ -32,9 +32,7 @@ def test_load_model_torch_no_device_specified(mocker):
def test_load_model_torch_device_specified(mocker):
    mock_model = MagicMock(spec=torch.nn.Module)
    mocker.patch("torch.load", return_value=mock_model)
-    load_model_torch(
-        "model_path", device=torch.device("cuda")
-    )
+    load_model_torch("model_path", device=torch.device("cuda"))
    mock_model.to.assert_called_once_with(torch.device("cuda"))

@@ -44,7 +44,5 @@ def test_prep_torch_inference_device_specified(mocker):
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )
-    prep_torch_inference(
-        "model_path", device=torch.device("cuda")
-    )
+    prep_torch_inference("model_path", device=torch.device("cuda"))
    mock_model.eval.assert_called_once()
