parent
f39d722f2a
commit
fd58cfa2a1
@ -0,0 +1,96 @@
|
|||||||
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pygame
|
||||||
|
import speech_recognition as sr
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from playsound import playsound
|
||||||
|
|
||||||
|
from swarms import OpenAIChat, OpenAITTS
|
||||||
|
|
||||||
|
# Load the environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Get the API key from the environment
|
||||||
|
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
# Initialize the language model
|
||||||
|
llm = OpenAIChat(
|
||||||
|
openai_api_key=openai_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the text-to-speech model
|
||||||
|
tts = OpenAITTS(
|
||||||
|
model_name="tts-1-1106",
|
||||||
|
voice="onyx",
|
||||||
|
openai_api_key=openai_api_key,
|
||||||
|
saved_filepath="runs/tts_speech.wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the speech recognition model
|
||||||
|
r = sr.Recognizer()
|
||||||
|
|
||||||
|
|
||||||
|
def play_audio(file_path):
|
||||||
|
# Check if the file exists
|
||||||
|
if not os.path.isfile(file_path):
|
||||||
|
print(f"Audio file {file_path} not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize the mixer module
|
||||||
|
pygame.mixer.init()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load the mp3 file
|
||||||
|
pygame.mixer.music.load(file_path)
|
||||||
|
|
||||||
|
# Play the mp3 file
|
||||||
|
pygame.mixer.music.play()
|
||||||
|
|
||||||
|
# Wait for the audio to finish playing
|
||||||
|
while pygame.mixer.music.get_busy():
|
||||||
|
pygame.time.Clock().tick(10)
|
||||||
|
except pygame.error as e:
|
||||||
|
print(f"Couldn't play {file_path}: {e}")
|
||||||
|
finally:
|
||||||
|
# Stop the mixer module and free resources
|
||||||
|
pygame.mixer.quit()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Listen for user speech
|
||||||
|
with sr.Microphone() as source:
|
||||||
|
print("Listening...")
|
||||||
|
audio = r.listen(source)
|
||||||
|
|
||||||
|
# Convert speech to text
|
||||||
|
try:
|
||||||
|
print("Recognizing...")
|
||||||
|
task = r.recognize_google(audio)
|
||||||
|
print(f"User said: {task}")
|
||||||
|
except sr.UnknownValueError:
|
||||||
|
print("Could not understand audio")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
# Run the Gemini model on the task
|
||||||
|
print("Running GPT4 model...")
|
||||||
|
out = llm(task)
|
||||||
|
print(f"Gemini output: {out}")
|
||||||
|
|
||||||
|
# Convert the Gemini output to speech
|
||||||
|
print("Running text-to-speech model...")
|
||||||
|
out = tts.run_and_save(out)
|
||||||
|
print(f"Text-to-speech output: {out}")
|
||||||
|
|
||||||
|
# Ask the user if they want to play the audio
|
||||||
|
# play_audio = input("Do you want to play the audio? (yes/no): ")
|
||||||
|
# if play_audio.lower() == "yes":
|
||||||
|
# Initialize the mixer module
|
||||||
|
# Play the audio file
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
playsound('runs/tts_speech.wav')
|
@ -1,336 +0,0 @@
|
|||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from functools import wraps
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
|
|
||||||
|
|
||||||
from swarms.models.distilled_whisperx import (
|
|
||||||
DistilWhisperModel,
|
|
||||||
async_retry,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def distil_whisper_model():
|
|
||||||
return DistilWhisperModel()
|
|
||||||
|
|
||||||
|
|
||||||
def create_audio_file(
|
|
||||||
data: np.ndarray, sample_rate: int, file_path: str
|
|
||||||
):
|
|
||||||
data.tofile(file_path)
|
|
||||||
return file_path
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialization(distil_whisper_model):
|
|
||||||
assert isinstance(distil_whisper_model, DistilWhisperModel)
|
|
||||||
assert isinstance(distil_whisper_model.model, torch.nn.Module)
|
|
||||||
assert isinstance(distil_whisper_model.processor, torch.nn.Module)
|
|
||||||
assert distil_whisper_model.device in ["cpu", "cuda:0"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_transcribe_audio_file(distil_whisper_model):
|
|
||||||
test_data = np.random.rand(
|
|
||||||
16000
|
|
||||||
) # Simulated audio data (1 second)
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
suffix=".wav", delete=False
|
|
||||||
) as audio_file:
|
|
||||||
audio_file_path = create_audio_file(
|
|
||||||
test_data, 16000, audio_file.name
|
|
||||||
)
|
|
||||||
transcription = distil_whisper_model.transcribe(
|
|
||||||
audio_file_path
|
|
||||||
)
|
|
||||||
os.remove(audio_file_path)
|
|
||||||
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
assert transcription.strip() != ""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_transcribe_audio_file(distil_whisper_model):
|
|
||||||
test_data = np.random.rand(
|
|
||||||
16000
|
|
||||||
) # Simulated audio data (1 second)
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
suffix=".wav", delete=False
|
|
||||||
) as audio_file:
|
|
||||||
audio_file_path = create_audio_file(
|
|
||||||
test_data, 16000, audio_file.name
|
|
||||||
)
|
|
||||||
transcription = await distil_whisper_model.async_transcribe(
|
|
||||||
audio_file_path
|
|
||||||
)
|
|
||||||
os.remove(audio_file_path)
|
|
||||||
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
assert transcription.strip() != ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_transcribe_audio_data(distil_whisper_model):
|
|
||||||
test_data = np.random.rand(
|
|
||||||
16000
|
|
||||||
) # Simulated audio data (1 second)
|
|
||||||
transcription = distil_whisper_model.transcribe(
|
|
||||||
test_data.tobytes()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
assert transcription.strip() != ""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_transcribe_audio_data(distil_whisper_model):
|
|
||||||
test_data = np.random.rand(
|
|
||||||
16000
|
|
||||||
) # Simulated audio data (1 second)
|
|
||||||
transcription = await distil_whisper_model.async_transcribe(
|
|
||||||
test_data.tobytes()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
assert transcription.strip() != ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_real_time_transcribe(distil_whisper_model, capsys):
|
|
||||||
test_data = np.random.rand(
|
|
||||||
16000 * 5
|
|
||||||
) # Simulated audio data (5 seconds)
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
suffix=".wav", delete=False
|
|
||||||
) as audio_file:
|
|
||||||
audio_file_path = create_audio_file(
|
|
||||||
test_data, 16000, audio_file.name
|
|
||||||
)
|
|
||||||
|
|
||||||
distil_whisper_model.real_time_transcribe(
|
|
||||||
audio_file_path, chunk_duration=1
|
|
||||||
)
|
|
||||||
|
|
||||||
os.remove(audio_file_path)
|
|
||||||
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert "Starting real-time transcription..." in captured.out
|
|
||||||
assert "Chunk" in captured.out
|
|
||||||
|
|
||||||
|
|
||||||
def test_real_time_transcribe_audio_file_not_found(
|
|
||||||
distil_whisper_model, capsys
|
|
||||||
):
|
|
||||||
audio_file_path = "non_existent_audio.wav"
|
|
||||||
distil_whisper_model.real_time_transcribe(
|
|
||||||
audio_file_path, chunk_duration=1
|
|
||||||
)
|
|
||||||
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert "The audio file was not found." in captured.out
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_async_retry():
|
|
||||||
def _mock_async_retry(
|
|
||||||
retries=3, exceptions=(Exception,), delay=1
|
|
||||||
):
|
|
||||||
def decorator(func):
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"distil_whisper_model.async_retry", new=_mock_async_retry()
|
|
||||||
):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_retry_decorator_success():
|
|
||||||
async def mock_async_function():
|
|
||||||
return "Success"
|
|
||||||
|
|
||||||
decorated_function = async_retry()(mock_async_function)
|
|
||||||
result = await decorated_function()
|
|
||||||
assert result == "Success"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_retry_decorator_failure():
|
|
||||||
async def mock_async_function():
|
|
||||||
raise Exception("Error")
|
|
||||||
|
|
||||||
decorated_function = async_retry()(mock_async_function)
|
|
||||||
with pytest.raises(Exception, match="Error"):
|
|
||||||
await decorated_function()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_retry_decorator_multiple_attempts():
|
|
||||||
async def mock_async_function():
|
|
||||||
if mock_async_function.attempts == 0:
|
|
||||||
mock_async_function.attempts += 1
|
|
||||||
raise Exception("Error")
|
|
||||||
else:
|
|
||||||
return "Success"
|
|
||||||
|
|
||||||
mock_async_function.attempts = 0
|
|
||||||
decorated_function = async_retry(max_retries=2)(
|
|
||||||
mock_async_function
|
|
||||||
)
|
|
||||||
result = await decorated_function()
|
|
||||||
assert result == "Success"
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_audio_file():
|
|
||||||
test_data = np.random.rand(
|
|
||||||
16000
|
|
||||||
) # Simulated audio data (1 second)
|
|
||||||
sample_rate = 16000
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
suffix=".wav", delete=False
|
|
||||||
) as audio_file:
|
|
||||||
audio_file_path = create_audio_file(
|
|
||||||
test_data, sample_rate, audio_file.name
|
|
||||||
)
|
|
||||||
|
|
||||||
assert os.path.exists(audio_file_path)
|
|
||||||
os.remove(audio_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
# test_distilled_whisperx.py
|
|
||||||
|
|
||||||
|
|
||||||
# Fixtures for setting up model, processor, and audio files
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def model_id():
|
|
||||||
return "distil-whisper/distil-large-v2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def whisper_model(model_id):
|
|
||||||
return DistilWhisperModel(model_id)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def audio_file_path(tmp_path_factory):
|
|
||||||
# You would create a small temporary MP3 file here for testing
|
|
||||||
# or use a public domain MP3 file's path
|
|
||||||
return "path/to/valid_audio.mp3"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def invalid_audio_file_path():
|
|
||||||
return "path/to/invalid_audio.mp3"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def audio_dict():
|
|
||||||
# This should represent a valid audio dictionary as expected by the model
|
|
||||||
return {"array": torch.randn(1, 16000), "sampling_rate": 16000}
|
|
||||||
|
|
||||||
|
|
||||||
# Test initialization
|
|
||||||
def test_initialization(whisper_model):
|
|
||||||
assert whisper_model.model is not None
|
|
||||||
assert whisper_model.processor is not None
|
|
||||||
|
|
||||||
|
|
||||||
# Test successful transcription with file path
|
|
||||||
def test_transcribe_with_file_path(whisper_model, audio_file_path):
|
|
||||||
transcription = whisper_model.transcribe(audio_file_path)
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
|
|
||||||
|
|
||||||
# Test successful transcription with audio dict
|
|
||||||
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
|
|
||||||
transcription = whisper_model.transcribe(audio_dict)
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
|
|
||||||
|
|
||||||
# Test for file not found error
|
|
||||||
def test_file_not_found(whisper_model, invalid_audio_file_path):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
whisper_model.transcribe(invalid_audio_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
# Asynchronous tests
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_transcription_success(
|
|
||||||
whisper_model, audio_file_path
|
|
||||||
):
|
|
||||||
transcription = await whisper_model.async_transcribe(
|
|
||||||
audio_file_path
|
|
||||||
)
|
|
||||||
assert isinstance(transcription, str)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_transcription_failure(
|
|
||||||
whisper_model, invalid_audio_file_path
|
|
||||||
):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await whisper_model.async_transcribe(invalid_audio_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
# Testing real-time transcription simulation
|
|
||||||
def test_real_time_transcription(
|
|
||||||
whisper_model, audio_file_path, capsys
|
|
||||||
):
|
|
||||||
whisper_model.real_time_transcribe(
|
|
||||||
audio_file_path, chunk_duration=1
|
|
||||||
)
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
assert "Starting real-time transcription..." in captured.out
|
|
||||||
|
|
||||||
|
|
||||||
# Testing retry decorator for asynchronous function
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_retry():
|
|
||||||
@async_retry(max_retries=2, exceptions=(ValueError,), delay=0)
|
|
||||||
async def failing_func():
|
|
||||||
raise ValueError("Test")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await failing_func()
|
|
||||||
|
|
||||||
|
|
||||||
# Mocking the actual model to avoid GPU/CPU intensive operations during test
|
|
||||||
@pytest.fixture
|
|
||||||
def mocked_model(monkeypatch):
|
|
||||||
model_mock = AsyncMock(AutoModelForSpeechSeq2Seq)
|
|
||||||
processor_mock = MagicMock(AutoProcessor)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
|
|
||||||
model_mock,
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
|
|
||||||
processor_mock,
|
|
||||||
)
|
|
||||||
return model_mock, processor_mock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_transcribe_with_mocked_model(
|
|
||||||
mocked_model, audio_file_path
|
|
||||||
):
|
|
||||||
model_mock, processor_mock = mocked_model
|
|
||||||
# Set up what the mock should return when it's called
|
|
||||||
model_mock.return_value.generate.return_value = torch.tensor(
|
|
||||||
[[0]]
|
|
||||||
)
|
|
||||||
processor_mock.return_value.batch_decode.return_value = [
|
|
||||||
"mocked transcription"
|
|
||||||
]
|
|
||||||
model_wrapper = DistilWhisperModel()
|
|
||||||
transcription = await model_wrapper.async_transcribe(
|
|
||||||
audio_file_path
|
|
||||||
)
|
|
||||||
assert transcription == "mocked transcription"
|
|
Loading…
Reference in new issue