parent
f39d722f2a
commit
fd58cfa2a1
@ -0,0 +1,96 @@
|
||||
import time
|
||||
import os
|
||||
|
||||
import pygame
|
||||
import speech_recognition as sr
|
||||
from dotenv import load_dotenv
|
||||
from playsound import playsound
|
||||
|
||||
from swarms import OpenAIChat, OpenAITTS
|
||||
|
||||
# Load the environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Get the API key from the environment
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
# Initialize the language model
|
||||
llm = OpenAIChat(
|
||||
openai_api_key=openai_api_key,
|
||||
)
|
||||
|
||||
# Initialize the text-to-speech model
|
||||
tts = OpenAITTS(
|
||||
model_name="tts-1-1106",
|
||||
voice="onyx",
|
||||
openai_api_key=openai_api_key,
|
||||
saved_filepath="runs/tts_speech.wav",
|
||||
)
|
||||
|
||||
# Initialize the speech recognition model
|
||||
r = sr.Recognizer()
|
||||
|
||||
|
||||
def play_audio(file_path):
|
||||
# Check if the file exists
|
||||
if not os.path.isfile(file_path):
|
||||
print(f"Audio file {file_path} not found.")
|
||||
return
|
||||
|
||||
# Initialize the mixer module
|
||||
pygame.mixer.init()
|
||||
|
||||
try:
|
||||
# Load the mp3 file
|
||||
pygame.mixer.music.load(file_path)
|
||||
|
||||
# Play the mp3 file
|
||||
pygame.mixer.music.play()
|
||||
|
||||
# Wait for the audio to finish playing
|
||||
while pygame.mixer.music.get_busy():
|
||||
pygame.time.Clock().tick(10)
|
||||
except pygame.error as e:
|
||||
print(f"Couldn't play {file_path}: {e}")
|
||||
finally:
|
||||
# Stop the mixer module and free resources
|
||||
pygame.mixer.quit()
|
||||
|
||||
while True:
|
||||
# Listen for user speech
|
||||
with sr.Microphone() as source:
|
||||
print("Listening...")
|
||||
audio = r.listen(source)
|
||||
|
||||
# Convert speech to text
|
||||
try:
|
||||
print("Recognizing...")
|
||||
task = r.recognize_google(audio)
|
||||
print(f"User said: {task}")
|
||||
except sr.UnknownValueError:
|
||||
print("Could not understand audio")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Run the Gemini model on the task
|
||||
print("Running GPT4 model...")
|
||||
out = llm(task)
|
||||
print(f"Gemini output: {out}")
|
||||
|
||||
# Convert the Gemini output to speech
|
||||
print("Running text-to-speech model...")
|
||||
out = tts.run_and_save(out)
|
||||
print(f"Text-to-speech output: {out}")
|
||||
|
||||
# Ask the user if they want to play the audio
|
||||
# play_audio = input("Do you want to play the audio? (yes/no): ")
|
||||
# if play_audio.lower() == "yes":
|
||||
# Initialize the mixer module
|
||||
# Play the audio file
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
playsound('runs/tts_speech.wav')
|
@ -1,336 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
|
||||
|
||||
from swarms.models.distilled_whisperx import (
|
||||
DistilWhisperModel,
|
||||
async_retry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def distil_whisper_model():
|
||||
return DistilWhisperModel()
|
||||
|
||||
|
||||
def create_audio_file(
|
||||
data: np.ndarray, sample_rate: int, file_path: str
|
||||
):
|
||||
data.tofile(file_path)
|
||||
return file_path
|
||||
|
||||
|
||||
def test_initialization(distil_whisper_model):
|
||||
assert isinstance(distil_whisper_model, DistilWhisperModel)
|
||||
assert isinstance(distil_whisper_model.model, torch.nn.Module)
|
||||
assert isinstance(distil_whisper_model.processor, torch.nn.Module)
|
||||
assert distil_whisper_model.device in ["cpu", "cuda:0"]
|
||||
|
||||
|
||||
def test_transcribe_audio_file(distil_whisper_model):
|
||||
test_data = np.random.rand(
|
||||
16000
|
||||
) # Simulated audio data (1 second)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".wav", delete=False
|
||||
) as audio_file:
|
||||
audio_file_path = create_audio_file(
|
||||
test_data, 16000, audio_file.name
|
||||
)
|
||||
transcription = distil_whisper_model.transcribe(
|
||||
audio_file_path
|
||||
)
|
||||
os.remove(audio_file_path)
|
||||
|
||||
assert isinstance(transcription, str)
|
||||
assert transcription.strip() != ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_transcribe_audio_file(distil_whisper_model):
|
||||
test_data = np.random.rand(
|
||||
16000
|
||||
) # Simulated audio data (1 second)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".wav", delete=False
|
||||
) as audio_file:
|
||||
audio_file_path = create_audio_file(
|
||||
test_data, 16000, audio_file.name
|
||||
)
|
||||
transcription = await distil_whisper_model.async_transcribe(
|
||||
audio_file_path
|
||||
)
|
||||
os.remove(audio_file_path)
|
||||
|
||||
assert isinstance(transcription, str)
|
||||
assert transcription.strip() != ""
|
||||
|
||||
|
||||
def test_transcribe_audio_data(distil_whisper_model):
|
||||
test_data = np.random.rand(
|
||||
16000
|
||||
) # Simulated audio data (1 second)
|
||||
transcription = distil_whisper_model.transcribe(
|
||||
test_data.tobytes()
|
||||
)
|
||||
|
||||
assert isinstance(transcription, str)
|
||||
assert transcription.strip() != ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_transcribe_audio_data(distil_whisper_model):
|
||||
test_data = np.random.rand(
|
||||
16000
|
||||
) # Simulated audio data (1 second)
|
||||
transcription = await distil_whisper_model.async_transcribe(
|
||||
test_data.tobytes()
|
||||
)
|
||||
|
||||
assert isinstance(transcription, str)
|
||||
assert transcription.strip() != ""
|
||||
|
||||
|
||||
def test_real_time_transcribe(distil_whisper_model, capsys):
|
||||
test_data = np.random.rand(
|
||||
16000 * 5
|
||||
) # Simulated audio data (5 seconds)
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".wav", delete=False
|
||||
) as audio_file:
|
||||
audio_file_path = create_audio_file(
|
||||
test_data, 16000, audio_file.name
|
||||
)
|
||||
|
||||
distil_whisper_model.real_time_transcribe(
|
||||
audio_file_path, chunk_duration=1
|
||||
)
|
||||
|
||||
os.remove(audio_file_path)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Starting real-time transcription..." in captured.out
|
||||
assert "Chunk" in captured.out
|
||||
|
||||
|
||||
def test_real_time_transcribe_audio_file_not_found(
|
||||
distil_whisper_model, capsys
|
||||
):
|
||||
audio_file_path = "non_existent_audio.wav"
|
||||
distil_whisper_model.real_time_transcribe(
|
||||
audio_file_path, chunk_duration=1
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "The audio file was not found." in captured.out
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_retry():
|
||||
def _mock_async_retry(
|
||||
retries=3, exceptions=(Exception,), delay=1
|
||||
):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
with patch(
|
||||
"distil_whisper_model.async_retry", new=_mock_async_retry()
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_retry_decorator_success():
|
||||
async def mock_async_function():
|
||||
return "Success"
|
||||
|
||||
decorated_function = async_retry()(mock_async_function)
|
||||
result = await decorated_function()
|
||||
assert result == "Success"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_retry_decorator_failure():
|
||||
async def mock_async_function():
|
||||
raise Exception("Error")
|
||||
|
||||
decorated_function = async_retry()(mock_async_function)
|
||||
with pytest.raises(Exception, match="Error"):
|
||||
await decorated_function()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_retry_decorator_multiple_attempts():
|
||||
async def mock_async_function():
|
||||
if mock_async_function.attempts == 0:
|
||||
mock_async_function.attempts += 1
|
||||
raise Exception("Error")
|
||||
else:
|
||||
return "Success"
|
||||
|
||||
mock_async_function.attempts = 0
|
||||
decorated_function = async_retry(max_retries=2)(
|
||||
mock_async_function
|
||||
)
|
||||
result = await decorated_function()
|
||||
assert result == "Success"
|
||||
|
||||
|
||||
def test_create_audio_file():
|
||||
test_data = np.random.rand(
|
||||
16000
|
||||
) # Simulated audio data (1 second)
|
||||
sample_rate = 16000
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".wav", delete=False
|
||||
) as audio_file:
|
||||
audio_file_path = create_audio_file(
|
||||
test_data, sample_rate, audio_file.name
|
||||
)
|
||||
|
||||
assert os.path.exists(audio_file_path)
|
||||
os.remove(audio_file_path)
|
||||
|
||||
|
||||
# test_distilled_whisperx.py
|
||||
|
||||
|
||||
# Fixtures for setting up model, processor, and audio files
|
||||
@pytest.fixture(scope="module")
|
||||
def model_id():
|
||||
return "distil-whisper/distil-large-v2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def whisper_model(model_id):
|
||||
return DistilWhisperModel(model_id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def audio_file_path(tmp_path_factory):
|
||||
# You would create a small temporary MP3 file here for testing
|
||||
# or use a public domain MP3 file's path
|
||||
return "path/to/valid_audio.mp3"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def invalid_audio_file_path():
|
||||
return "path/to/invalid_audio.mp3"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def audio_dict():
|
||||
# This should represent a valid audio dictionary as expected by the model
|
||||
return {"array": torch.randn(1, 16000), "sampling_rate": 16000}
|
||||
|
||||
|
||||
# Test initialization
|
||||
def test_initialization(whisper_model):
|
||||
assert whisper_model.model is not None
|
||||
assert whisper_model.processor is not None
|
||||
|
||||
|
||||
# Test successful transcription with file path
|
||||
def test_transcribe_with_file_path(whisper_model, audio_file_path):
|
||||
transcription = whisper_model.transcribe(audio_file_path)
|
||||
assert isinstance(transcription, str)
|
||||
|
||||
|
||||
# Test successful transcription with audio dict
|
||||
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
|
||||
transcription = whisper_model.transcribe(audio_dict)
|
||||
assert isinstance(transcription, str)
|
||||
|
||||
|
||||
# Test for file not found error
|
||||
def test_file_not_found(whisper_model, invalid_audio_file_path):
|
||||
with pytest.raises(Exception):
|
||||
whisper_model.transcribe(invalid_audio_file_path)
|
||||
|
||||
|
||||
# Asynchronous tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_transcription_success(
|
||||
whisper_model, audio_file_path
|
||||
):
|
||||
transcription = await whisper_model.async_transcribe(
|
||||
audio_file_path
|
||||
)
|
||||
assert isinstance(transcription, str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_transcription_failure(
|
||||
whisper_model, invalid_audio_file_path
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await whisper_model.async_transcribe(invalid_audio_file_path)
|
||||
|
||||
|
||||
# Testing real-time transcription simulation
|
||||
def test_real_time_transcription(
|
||||
whisper_model, audio_file_path, capsys
|
||||
):
|
||||
whisper_model.real_time_transcribe(
|
||||
audio_file_path, chunk_duration=1
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
assert "Starting real-time transcription..." in captured.out
|
||||
|
||||
|
||||
# Testing retry decorator for asynchronous function
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_retry():
|
||||
@async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
|
||||
async def failing_func():
|
||||
raise ValueError("Test")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await failing_func()
|
||||
|
||||
|
||||
# Mocking the actual model to avoid GPU/CPU intensive operations during test
|
||||
@pytest.fixture
|
||||
def mocked_model(monkeypatch):
|
||||
model_mock = AsyncMock(AutoModelForSpeechSeq2Seq)
|
||||
processor_mock = MagicMock(AutoProcessor)
|
||||
monkeypatch.setattr(
|
||||
"swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
|
||||
model_mock,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
|
||||
processor_mock,
|
||||
)
|
||||
return model_mock, processor_mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_transcribe_with_mocked_model(
|
||||
mocked_model, audio_file_path
|
||||
):
|
||||
model_mock, processor_mock = mocked_model
|
||||
# Set up what the mock should return when it's called
|
||||
model_mock.return_value.generate.return_value = torch.tensor(
|
||||
[[0]]
|
||||
)
|
||||
processor_mock.return_value.batch_decode.return_value = [
|
||||
"mocked transcription"
|
||||
]
|
||||
model_wrapper = DistilWhisperModel()
|
||||
transcription = await model_wrapper.async_transcribe(
|
||||
audio_file_path
|
||||
)
|
||||
assert transcription == "mocked transcription"
|
Loading…
Reference in new issue