parent 991979dfc6
commit 4596ddc5ff

Dockerfile (modified):

@@ -1,31 +1,42 @@
+# ==================================
 # Use an official Python runtime as a parent image
 FROM python:3.9-slim
 
-# Set environment variables to make Python output unbuffered and disable the PIP cache
+# Set environment variables
 ENV PYTHONDONTWRITEBYTECODE 1
 ENV PYTHONUNBUFFERED 1
-ENV PIP_NO_CACHE_DIR off
-ENV PIP_DISABLE_PIP_VERSION_CHECK on
-ENV PIP_DEFAULT_TIMEOUT 100
 
 # Set the working directory in the container
-WORKDIR /usr/src/app
+WORKDIR /usr/src/swarm_cloud
 
-# Copy the current directory contents into the container at /usr/src/app
+
+# Install Python dependencies
+# COPY requirements.txt and pyproject.toml if you're using poetry for dependency management
+COPY requirements.txt .
+RUN pip install --upgrade pip
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Install the 'swarms' package, assuming it's available on PyPI
+RUN pip install swarms
+
+# Copy the rest of the application
 COPY . .
 
-# Install Poetry
-RUN pip install poetry
+# Add entrypoint script if needed
+# COPY ./entrypoint.sh .
+# RUN chmod +x /usr/src/swarm_cloud/entrypoint.sh
 
-# Disable virtualenv creation by poetry and install dependencies
-RUN poetry config virtualenvs.create false
-RUN poetry install --no-interaction --no-ansi
+# Expose port if your application has a web interface
+# EXPOSE 5000
 
-# Install the 'swarms' package if it's not included in the poetry.lock
-RUN pip install swarms
+# # Define environment variable for the swarm to work
+# ENV SWARM_API_KEY=your_swarm_api_key_here
 
-# Assuming tests require pytest to run
-RUN pip install pytest
+# # Add Docker CMD or ENTRYPOINT script to run the application
+# CMD python your_swarm_startup_script.py
+# Or use the entrypoint script if you have one
+# ENTRYPOINT ["/usr/src/swarm_cloud/entrypoint.sh"]
 
-# Run pytest on all tests in the tests directory
-CMD find ./tests -name '*.py' -exec pytest {} +
+# If you're using `CMD` to execute a Python script, make sure it's executable
+# RUN chmod +x your_swarm_startup_script.py
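
The commented-out CMD above assumes a startup script named your_swarm_startup_script.py. A minimal sketch of what such a script might look like follows; the entry point and everything inside main() are assumptions, not part of this commit. It reads the SWARM_API_KEY placeholder from the environment before doing any work:

# your_swarm_startup_script.py -- hypothetical sketch, not part of this commit.
import os


def main():
    # SWARM_API_KEY mirrors the commented-out ENV placeholder in the Dockerfile.
    api_key = os.getenv("SWARM_API_KEY")
    if not api_key:
        raise SystemExit("SWARM_API_KEY is not set")
    # ... construct and run the swarm here (actual swarms API calls omitted) ...


if __name__ == "__main__":
    main()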

New file (a Dockerfile for running the test suite):

@@ -0,0 +1,33 @@
# TESTING
# -==================
# Use an official Python runtime as a parent image
FROM python:3.9-slim

# Set environment variables to make Python output unbuffered and disable the PIP cache
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
ENV PIP_NO_CACHE_DIR off
ENV PIP_DISABLE_PIP_VERSION_CHECK on
ENV PIP_DEFAULT_TIMEOUT 100

# Set the working directory in the container
WORKDIR /usr/src/app

# Copy the current directory contents into the container at /usr/src/app
COPY . .

# Install Poetry
RUN pip install poetry

# Disable virtualenv creation by poetry and install dependencies
RUN poetry config virtualenvs.create false
RUN poetry install --no-interaction --no-ansi

# Install the 'swarms' package if it's not included in the poetry.lock
RUN pip install swarms

# Assuming tests require pytest to run
RUN pip install pytest

# Run pytest on all tests in the tests directory
CMD find ./tests -name '*.py' -exec pytest {} +

New file (Python tests for swarms.models.autotemp):

@@ -0,0 +1,76 @@
import os
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import patch

import pytest
from dotenv import load_dotenv

from swarms.models.autotemp import AutoTempAgent

# Load .env before reading the key, so OPENAI_API_KEY is actually picked up.
load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")


@pytest.fixture
def auto_temp_agent():
    return AutoTempAgent(api_key=api_key)


def test_initialization(auto_temp_agent):
    assert isinstance(auto_temp_agent, AutoTempAgent)
    assert auto_temp_agent.auto_select is True
    assert auto_temp_agent.max_workers == 6
    assert auto_temp_agent.temperature == 0.5
    assert auto_temp_agent.alt_temps == [0.4, 0.6, 0.8, 1.0, 1.2, 1.4]


def test_evaluate_output(auto_temp_agent):
    output = "This is a test output."
    with patch("swarms.models.OpenAIChat") as MockOpenAIChat:
        mock_instance = MockOpenAIChat.return_value
        mock_instance.return_value = "Score: 95.5"
        score = auto_temp_agent.evaluate_output(output)
        assert score == 95.5
        mock_instance.assert_called_once()


def test_run_auto_select(auto_temp_agent):
    task = "Generate a blog post."
    temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4"
    result = auto_temp_agent.run(task, temperature_string)
    assert "Best AutoTemp Output" in result
    assert "Temp" in result
    assert "Score" in result


def test_run_no_scores(auto_temp_agent):
    task = "Invalid task."
    temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4"
    with ThreadPoolExecutor(max_workers=auto_temp_agent.max_workers) as executor:
        with patch.object(
            executor, "submit", side_effect=[None, None, None, None, None, None]
        ):
            result = auto_temp_agent.run(task, temperature_string)
            assert result == "No valid outputs generated."


def test_run_manual_select(auto_temp_agent):
    auto_temp_agent.auto_select = False
    task = "Generate a blog post."
    temperature_string = "0.4,0.6,0.8,1.0,1.2,1.4"
    result = auto_temp_agent.run(task, temperature_string)
    assert "Best AutoTemp Output" not in result
    assert "Temp" in result
    assert "Score" in result


def test_failed_initialization():
    with pytest.raises(Exception):
        AutoTempAgent()


def test_failed_evaluate_output(auto_temp_agent):
    output = "This is a test output."
    with patch("swarms.models.OpenAIChat") as MockOpenAIChat:
        mock_instance = MockOpenAIChat.return_value
        mock_instance.return_value = "Invalid score text"
        score = auto_temp_agent.evaluate_output(output)
        assert score == 0.0
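
Taken together, test_evaluate_output and test_failed_evaluate_output pin down a small contract for evaluate_output: a model reply like "Score: 95.5" parses to 95.5, and unparseable text falls back to 0.0. A minimal parser satisfying that contract might look like the following sketch (parse_score is a hypothetical helper, not the actual AutoTempAgent implementation):

import re


def parse_score(text: str) -> float:
    # Pull the number out of a reply such as "Score: 95.5"; fall back to 0.0.
    match = re.search(r"Score:\s*(\d+(?:\.\d+)?)", text)
    return float(match.group(1)) if match else 0.0


assert parse_score("Score: 95.5") == 95.5
assert parse_score("Invalid score text") == 0.0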

New file (Python tests for swarms.models.distill_whisperx):

@@ -0,0 +1,154 @@
import os
import tempfile
from functools import wraps
from unittest.mock import patch

import numpy as np
import pytest
import torch

from swarms.models.distill_whisperx import DistilWhisperModel, async_retry


@pytest.fixture
def distil_whisper_model():
    return DistilWhisperModel()


def create_audio_file(data: np.ndarray, sample_rate: int, file_path: str):
    # Writes raw float64 samples rather than a proper WAV container; it only
    # needs to give the model a file path to read during tests.
    data.tofile(file_path)
    return file_path


def test_initialization(distil_whisper_model):
    assert isinstance(distil_whisper_model, DistilWhisperModel)
    assert isinstance(distil_whisper_model.model, torch.nn.Module)
    assert isinstance(distil_whisper_model.processor, torch.nn.Module)
    assert distil_whisper_model.device in ["cpu", "cuda:0"]


def test_transcribe_audio_file(distil_whisper_model):
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
        audio_file_path = create_audio_file(test_data, 16000, audio_file.name)
        transcription = distil_whisper_model.transcribe(audio_file_path)
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_file(distil_whisper_model):
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
        audio_file_path = create_audio_file(test_data, 16000, audio_file.name)
        transcription = await distil_whisper_model.async_transcribe(audio_file_path)
        os.remove(audio_file_path)

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_transcribe_audio_data(distil_whisper_model):
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    transcription = distil_whisper_model.transcribe(test_data.tobytes())

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


@pytest.mark.asyncio
async def test_async_transcribe_audio_data(distil_whisper_model):
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    transcription = await distil_whisper_model.async_transcribe(test_data.tobytes())

    assert isinstance(transcription, str)
    assert transcription.strip() != ""


def test_real_time_transcribe(distil_whisper_model, capsys):
    test_data = np.random.rand(16000 * 5)  # Simulated audio data (5 seconds)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
        audio_file_path = create_audio_file(test_data, 16000, audio_file.name)

        distil_whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)

        os.remove(audio_file_path)

    captured = capsys.readouterr()
    assert "Starting real-time transcription..." in captured.out
    assert "Chunk" in captured.out


def test_real_time_transcribe_audio_file_not_found(distil_whisper_model, capsys):
    audio_file_path = "non_existent_audio.wav"
    distil_whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)

    captured = capsys.readouterr()
    assert "The audio file was not found." in captured.out


@pytest.fixture
def mock_async_retry():
    def _mock_async_retry(retries=3, exceptions=(Exception,), delay=1):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return wrapper

        return decorator

    # Replace the retry decorator where it is defined, so decorated methods
    # run without retry behavior during tests.
    with patch("swarms.models.distill_whisperx.async_retry", new=_mock_async_retry()):
        yield


@pytest.mark.asyncio
async def test_async_retry_decorator_success():
    async def mock_async_function():
        return "Success"

    decorated_function = async_retry()(mock_async_function)
    result = await decorated_function()
    assert result == "Success"


@pytest.mark.asyncio
async def test_async_retry_decorator_failure():
    async def mock_async_function():
        raise Exception("Error")

    decorated_function = async_retry()(mock_async_function)
    with pytest.raises(Exception, match="Error"):
        await decorated_function()


@pytest.mark.asyncio
async def test_async_retry_decorator_multiple_attempts():
    async def mock_async_function():
        if mock_async_function.attempts == 0:
            mock_async_function.attempts += 1
            raise Exception("Error")
        else:
            return "Success"

    mock_async_function.attempts = 0
    decorated_function = async_retry(max_retries=2)(mock_async_function)
    result = await decorated_function()
    assert result == "Success"


def test_create_audio_file():
    test_data = np.random.rand(16000)  # Simulated audio data (1 second)
    sample_rate = 16000
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
        audio_file_path = create_audio_file(test_data, sample_rate, audio_file.name)

        assert os.path.exists(audio_file_path)
        os.remove(audio_file_path)


if __name__ == "__main__":
    pytest.main()
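
The three async_retry tests above only constrain the decorator's observable behavior: it returns the wrapped result on success, re-raises once retries are exhausted, and tolerates one transient failure when called as async_retry(max_retries=2). A decorator consistent with that behavior could be sketched as follows (async_retry_sketch is illustrative; the real decorator in swarms.models.distill_whisperx may differ in signature and backoff):

import asyncio
from functools import wraps


def async_retry_sketch(max_retries=3, exceptions=(Exception,), delay=0):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # One initial call plus up to max_retries retries.
            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except exceptions:
                    if attempt == max_retries:
                        raise  # out of retries: surface the original error
                    if delay:
                        await asyncio.sleep(delay)

        return wrapper

    return decorator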

New file (Python tests for swarms.models.jina_embeds):

@@ -0,0 +1,82 @@
import pytest
import torch

from swarms.models.jina_embeds import JinaEmbeddings


@pytest.fixture
def model():
    return JinaEmbeddings("bert-base-uncased", verbose=True)


def test_initialization(model):
    assert isinstance(model, JinaEmbeddings)
    assert model.device in ["cuda", "cpu"]
    assert model.max_length == 500
    assert model.verbose is True


def test_run_sync(model):
    task = "Encode this text"
    result = model.run(task)
    assert isinstance(result, torch.Tensor)
    assert result.shape == (model.max_length,)


def test_run_async(model):
    task = "Encode this text"
    result = model.run_async(task)
    assert isinstance(result, torch.Tensor)
    assert result.shape == (model.max_length,)


def test_save_model(tmp_path, model):
    model_path = tmp_path / "model"
    model.save_model(model_path)
    assert (model_path / "config.json").is_file()
    assert (model_path / "pytorch_model.bin").is_file()
    assert (model_path / "vocab.txt").is_file()


def test_gpu_available(model):
    gpu_status = model.gpu_available()
    if torch.cuda.is_available():
        assert gpu_status is True
    else:
        assert gpu_status is False


def test_memory_consumption(model):
    memory_stats = model.memory_consumption()
    if torch.cuda.is_available():
        assert "allocated" in memory_stats
        assert "reserved" in memory_stats
    else:
        assert "error" in memory_stats


def test_cosine_similarity(model):
    task1 = "This is a sample text for testing."
    task2 = "Another sample text for testing."
    embeddings1 = model.run(task1)
    embeddings2 = model.run(task2)
    sim = model.cos_sim(embeddings1, embeddings2)
    assert isinstance(sim, torch.Tensor)
    assert -1.0 <= sim.item() <= 1.0


def test_failed_load_model(caplog):
    with pytest.raises(Exception):
        JinaEmbeddings("invalid_model")
    assert "Failed to load the model or the tokenizer" in caplog.text


def test_failed_generate_text(caplog, model):
    with pytest.raises(Exception):
        model.run("invalid_task")
    assert "Failed to generate the text" in caplog.text


@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_change_device(model, device):
    model.device = device
    assert model.device == device
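
test_cosine_similarity only requires that cos_sim return a torch.Tensor whose value lies in [-1, 1]. For reference, a minimal helper with that property can be built directly on torch (cos_sim_sketch is illustrative, not the JinaEmbeddings implementation):

import torch
import torch.nn.functional as F


def cos_sim_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Cosine similarity of two flattened embedding vectors; bounded in [-1, 1].
    return F.cosine_similarity(a.flatten(), b.flatten(), dim=0)


x, y = torch.randn(500), torch.randn(500)
sim = cos_sim_sketch(x, y)
assert -1.0 <= sim.item() <= 1.0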