parent
790e0e383d
commit
df2be1d22e
@ -1,263 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from swarm_models.anthropic import Anthropic
|
||||
|
||||
|
||||
# Mock the Anthropic API client for testing
|
||||
class MockAnthropicClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def completions_create(
|
||||
self, prompt, stop_sequences, stream, **kwargs
|
||||
):
|
||||
return MockAnthropicResponse()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_anthropic_env():
|
||||
os.environ["ANTHROPIC_API_URL"] = "https://test.anthropic.com"
|
||||
os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
|
||||
yield
|
||||
del os.environ["ANTHROPIC_API_URL"]
|
||||
del os.environ["ANTHROPIC_API_KEY"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_requests_post():
|
||||
with patch("requests.post") as mock_post:
|
||||
yield mock_post
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def anthropic_instance():
|
||||
return Anthropic(model="test-model")
|
||||
|
||||
|
||||
def test_anthropic_init_default_values(anthropic_instance):
|
||||
assert anthropic_instance.model == "test-model"
|
||||
assert anthropic_instance.max_tokens_to_sample == 256
|
||||
assert anthropic_instance.temperature is None
|
||||
assert anthropic_instance.top_k is None
|
||||
assert anthropic_instance.top_p is None
|
||||
assert anthropic_instance.streaming is False
|
||||
assert anthropic_instance.default_request_timeout == 600
|
||||
assert (
|
||||
anthropic_instance.anthropic_api_url
|
||||
== "https://test.anthropic.com"
|
||||
)
|
||||
assert anthropic_instance.anthropic_api_key == "test_api_key"
|
||||
|
||||
|
||||
def test_anthropic_init_custom_values():
|
||||
anthropic_instance = Anthropic(
|
||||
model="custom-model",
|
||||
max_tokens_to_sample=128,
|
||||
temperature=0.8,
|
||||
top_k=5,
|
||||
top_p=0.9,
|
||||
streaming=True,
|
||||
default_request_timeout=300,
|
||||
)
|
||||
assert anthropic_instance.model == "custom-model"
|
||||
assert anthropic_instance.max_tokens_to_sample == 128
|
||||
assert anthropic_instance.temperature == 0.8
|
||||
assert anthropic_instance.top_k == 5
|
||||
assert anthropic_instance.top_p == 0.9
|
||||
assert anthropic_instance.streaming is True
|
||||
assert anthropic_instance.default_request_timeout == 300
|
||||
|
||||
|
||||
def test_anthropic_default_params(anthropic_instance):
|
||||
default_params = anthropic_instance._default_params()
|
||||
assert default_params == {
|
||||
"max_tokens_to_sample": 256,
|
||||
"model": "test-model",
|
||||
}
|
||||
|
||||
|
||||
def test_anthropic_run(
|
||||
mock_anthropic_env, mock_requests_post, anthropic_instance
|
||||
):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"completion": "Generated text"}
|
||||
mock_requests_post.return_value = mock_response
|
||||
|
||||
task = "Generate text"
|
||||
stop = ["stop1", "stop2"]
|
||||
|
||||
completion = anthropic_instance.run(task, stop)
|
||||
|
||||
assert completion == "Generated text"
|
||||
mock_requests_post.assert_called_once_with(
|
||||
"https://test.anthropic.com/completions",
|
||||
headers={"Authorization": "Bearer test_api_key"},
|
||||
json={
|
||||
"prompt": task,
|
||||
"stop_sequences": stop,
|
||||
"max_tokens_to_sample": 256,
|
||||
"model": "test-model",
|
||||
},
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
|
||||
def test_anthropic_call(
|
||||
mock_anthropic_env, mock_requests_post, anthropic_instance
|
||||
):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"completion": "Generated text"}
|
||||
mock_requests_post.return_value = mock_response
|
||||
|
||||
task = "Generate text"
|
||||
stop = ["stop1", "stop2"]
|
||||
|
||||
completion = anthropic_instance(task, stop)
|
||||
|
||||
assert completion == "Generated text"
|
||||
mock_requests_post.assert_called_once_with(
|
||||
"https://test.anthropic.com/completions",
|
||||
headers={"Authorization": "Bearer test_api_key"},
|
||||
json={
|
||||
"prompt": task,
|
||||
"stop_sequences": stop,
|
||||
"max_tokens_to_sample": 256,
|
||||
"model": "test-model",
|
||||
},
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
|
||||
def test_anthropic_exception_handling(
|
||||
mock_anthropic_env, mock_requests_post, anthropic_instance
|
||||
):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"error": "An error occurred"}
|
||||
mock_requests_post.return_value = mock_response
|
||||
|
||||
task = "Generate text"
|
||||
stop = ["stop1", "stop2"]
|
||||
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
anthropic_instance(task, stop)
|
||||
|
||||
assert "An error occurred" in str(excinfo.value)
|
||||
|
||||
|
||||
class MockAnthropicResponse:
|
||||
def __init__(self):
|
||||
self.completion = "Mocked Response from Anthropic"
|
||||
|
||||
|
||||
def test_anthropic_instance_creation(anthropic_instance):
|
||||
assert isinstance(anthropic_instance, Anthropic)
|
||||
|
||||
|
||||
def test_anthropic_call_method(anthropic_instance):
|
||||
response = anthropic_instance("What is the meaning of life?")
|
||||
assert response == "Mocked Response from Anthropic"
|
||||
|
||||
|
||||
def test_anthropic_stream_method(anthropic_instance):
|
||||
generator = anthropic_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_anthropic_async_call_method(anthropic_instance):
|
||||
response = anthropic_instance.async_call("Tell me a joke.")
|
||||
assert response == "Mocked Response from Anthropic"
|
||||
|
||||
|
||||
def test_anthropic_async_stream_method(anthropic_instance):
|
||||
async_generator = anthropic_instance.async_stream(
|
||||
"Translate to French."
|
||||
)
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_anthropic_get_num_tokens(anthropic_instance):
|
||||
text = "This is a test sentence."
|
||||
num_tokens = anthropic_instance.get_num_tokens(text)
|
||||
assert num_tokens > 0
|
||||
|
||||
|
||||
# Add more test cases to cover other functionalities and edge cases of the Anthropic class
|
||||
|
||||
|
||||
def test_anthropic_wrap_prompt(anthropic_instance):
|
||||
prompt = "What is the meaning of life?"
|
||||
wrapped_prompt = anthropic_instance._wrap_prompt(prompt)
|
||||
assert wrapped_prompt.startswith(anthropic_instance.HUMAN_PROMPT)
|
||||
assert wrapped_prompt.endswith(anthropic_instance.AI_PROMPT)
|
||||
|
||||
|
||||
def test_anthropic_convert_prompt(anthropic_instance):
|
||||
prompt = "What is the meaning of life?"
|
||||
converted_prompt = anthropic_instance.convert_prompt(prompt)
|
||||
assert converted_prompt.startswith(
|
||||
anthropic_instance.HUMAN_PROMPT
|
||||
)
|
||||
assert converted_prompt.endswith(anthropic_instance.AI_PROMPT)
|
||||
|
||||
|
||||
def test_anthropic_call_with_stop(anthropic_instance):
|
||||
response = anthropic_instance(
|
||||
"Translate to French.", stop=["stop1", "stop2"]
|
||||
)
|
||||
assert response == "Mocked Response from Anthropic"
|
||||
|
||||
|
||||
def test_anthropic_stream_with_stop(anthropic_instance):
|
||||
generator = anthropic_instance.stream(
|
||||
"Write a story.", stop=["stop1", "stop2"]
|
||||
)
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_anthropic_async_call_with_stop(anthropic_instance):
|
||||
response = anthropic_instance.async_call(
|
||||
"Tell me a joke.", stop=["stop1", "stop2"]
|
||||
)
|
||||
assert response == "Mocked Response from Anthropic"
|
||||
|
||||
|
||||
def test_anthropic_async_stream_with_stop(anthropic_instance):
|
||||
async_generator = anthropic_instance.async_stream(
|
||||
"Translate to French.", stop=["stop1", "stop2"]
|
||||
)
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_anthropic_get_num_tokens_with_count_tokens(
|
||||
anthropic_instance,
|
||||
):
|
||||
anthropic_instance.count_tokens = Mock(return_value=10)
|
||||
text = "This is a test sentence."
|
||||
num_tokens = anthropic_instance.get_num_tokens(text)
|
||||
assert num_tokens == 10
|
||||
|
||||
|
||||
def test_anthropic_get_num_tokens_without_count_tokens(
|
||||
anthropic_instance,
|
||||
):
|
||||
del anthropic_instance.count_tokens
|
||||
with pytest.raises(NameError):
|
||||
text = "This is a test sentence."
|
||||
anthropic_instance.get_num_tokens(text)
|
||||
|
||||
|
||||
def test_anthropic_wrap_prompt_without_human_ai_prompt(
|
||||
anthropic_instance,
|
||||
):
|
||||
del anthropic_instance.HUMAN_PROMPT
|
||||
del anthropic_instance.AI_PROMPT
|
||||
prompt = "What is the meaning of life?"
|
||||
with pytest.raises(NameError):
|
||||
anthropic_instance._wrap_prompt(prompt)
|
@ -1,784 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from swarms import Cohere
|
||||
|
||||
# Load the environment variables
|
||||
load_dotenv()
|
||||
api_key = os.getenv("COHERE_API_KEY")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cohere_instance():
|
||||
return Cohere(cohere_api_key=api_key)
|
||||
|
||||
|
||||
def test_cohere_custom_configuration(cohere_instance):
|
||||
# Test customizing Cohere configurations
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.temperature = 0.5
|
||||
cohere_instance.max_tokens = 100
|
||||
cohere_instance.k = 1
|
||||
cohere_instance.p = 0.8
|
||||
cohere_instance.frequency_penalty = 0.2
|
||||
cohere_instance.presence_penalty = 0.4
|
||||
response = cohere_instance("Customize configurations.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_api_error_handling(cohere_instance):
|
||||
# Test error handling when the API key is invalid
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||
with pytest.raises(Exception):
|
||||
cohere_instance("Error handling with invalid API key.")
|
||||
|
||||
|
||||
def test_cohere_async_api_error_handling(cohere_instance):
|
||||
# Test async error handling when the API key is invalid
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||
with pytest.raises(Exception):
|
||||
cohere_instance.async_call(
|
||||
"Error handling with invalid API key."
|
||||
)
|
||||
|
||||
|
||||
def test_cohere_stream_api_error_handling(cohere_instance):
|
||||
# Test error handling in streaming mode when the API key is invalid
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||
with pytest.raises(Exception):
|
||||
generator = cohere_instance.stream(
|
||||
"Error handling with invalid API key."
|
||||
)
|
||||
for token in generator:
|
||||
pass
|
||||
|
||||
|
||||
def test_cohere_streaming_mode(cohere_instance):
|
||||
# Test the streaming mode for large text generation
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.streaming = True
|
||||
prompt = "Generate a lengthy text using streaming mode."
|
||||
generator = cohere_instance.stream(prompt)
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_streaming_mode_async(cohere_instance):
|
||||
# Test the async streaming mode for large text generation
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.streaming = True
|
||||
prompt = "Generate a lengthy text using async streaming mode."
|
||||
async_generator = cohere_instance.async_stream(prompt)
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_wrap_prompt(cohere_instance):
|
||||
prompt = "What is the meaning of life?"
|
||||
wrapped_prompt = cohere_instance._wrap_prompt(prompt)
|
||||
assert wrapped_prompt.startswith(cohere_instance.HUMAN_PROMPT)
|
||||
assert wrapped_prompt.endswith(cohere_instance.AI_PROMPT)
|
||||
|
||||
|
||||
def test_cohere_convert_prompt(cohere_instance):
|
||||
prompt = "What is the meaning of life?"
|
||||
converted_prompt = cohere_instance.convert_prompt(prompt)
|
||||
assert converted_prompt.startswith(cohere_instance.HUMAN_PROMPT)
|
||||
assert converted_prompt.endswith(cohere_instance.AI_PROMPT)
|
||||
|
||||
|
||||
def test_cohere_call_with_stop(cohere_instance):
|
||||
response = cohere_instance(
|
||||
"Translate to French.", stop=["stop1", "stop2"]
|
||||
)
|
||||
assert response == "Mocked Response from Cohere"
|
||||
|
||||
|
||||
def test_cohere_stream_with_stop(cohere_instance):
|
||||
generator = cohere_instance.stream(
|
||||
"Write a story.", stop=["stop1", "stop2"]
|
||||
)
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_stop(cohere_instance):
|
||||
response = cohere_instance.async_call(
|
||||
"Tell me a joke.", stop=["stop1", "stop2"]
|
||||
)
|
||||
assert response == "Mocked Response from Cohere"
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_stop(cohere_instance):
|
||||
async_generator = cohere_instance.async_stream(
|
||||
"Translate to French.", stop=["stop1", "stop2"]
|
||||
)
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_get_num_tokens_with_count_tokens(cohere_instance):
|
||||
cohere_instance.count_tokens = Mock(return_value=10)
|
||||
text = "This is a test sentence."
|
||||
num_tokens = cohere_instance.get_num_tokens(text)
|
||||
assert num_tokens == 10
|
||||
|
||||
|
||||
def test_cohere_get_num_tokens_without_count_tokens(cohere_instance):
|
||||
del cohere_instance.count_tokens
|
||||
with pytest.raises(NameError):
|
||||
text = "This is a test sentence."
|
||||
cohere_instance.get_num_tokens(text)
|
||||
|
||||
|
||||
def test_cohere_wrap_prompt_without_human_ai_prompt(cohere_instance):
|
||||
del cohere_instance.HUMAN_PROMPT
|
||||
del cohere_instance.AI_PROMPT
|
||||
prompt = "What is the meaning of life?"
|
||||
with pytest.raises(NameError):
|
||||
cohere_instance._wrap_prompt(prompt)
|
||||
|
||||
|
||||
def test_base_cohere_import():
|
||||
with patch.dict("sys.modules", {"cohere": None}):
|
||||
with pytest.raises(ImportError):
|
||||
pass
|
||||
|
||||
|
||||
def test_base_cohere_validate_environment():
|
||||
values = {
|
||||
"cohere_api_key": "my-api-key",
|
||||
"user_agent": "langchain",
|
||||
}
|
||||
validated_values = Cohere.validate_environment(values)
|
||||
assert "client" in validated_values
|
||||
assert "async_client" in validated_values
|
||||
|
||||
|
||||
def test_base_cohere_validate_environment_without_cohere():
|
||||
values = {
|
||||
"cohere_api_key": "my-api-key",
|
||||
"user_agent": "langchain",
|
||||
}
|
||||
with patch.dict("sys.modules", {"cohere": None}):
|
||||
with pytest.raises(ImportError):
|
||||
Cohere.validate_environment(values)
|
||||
|
||||
|
||||
# Test cases for benchmarking generations with various models
|
||||
def test_cohere_generate_with_command_light(cohere_instance):
|
||||
cohere_instance.model = "command-light"
|
||||
response = cohere_instance(
|
||||
"Generate text with Command Light model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated text with Command Light model"
|
||||
)
|
||||
|
||||
|
||||
def test_cohere_generate_with_command(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance("Generate text with Command model.")
|
||||
assert response.startswith("Generated text with Command model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_base_light(cohere_instance):
|
||||
cohere_instance.model = "base-light"
|
||||
response = cohere_instance("Generate text with Base Light model.")
|
||||
assert response.startswith("Generated text with Base Light model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_base(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
response = cohere_instance("Generate text with Base model.")
|
||||
assert response.startswith("Generated text with Base model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_english_v2(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
response = cohere_instance(
|
||||
"Generate embeddings with English v2.0 model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated embeddings with English v2.0 model"
|
||||
)
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_english_light_v2(cohere_instance):
|
||||
cohere_instance.model = "embed-english-light-v2.0"
|
||||
response = cohere_instance(
|
||||
"Generate embeddings with English Light v2.0 model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated embeddings with English Light v2.0 model"
|
||||
)
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_multilingual_v2(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
response = cohere_instance(
|
||||
"Generate embeddings with Multilingual v2.0 model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated embeddings with Multilingual v2.0 model"
|
||||
)
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_english_v3(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
response = cohere_instance(
|
||||
"Generate embeddings with English v3.0 model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated embeddings with English v3.0 model"
|
||||
)
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_english_light_v3(cohere_instance):
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
response = cohere_instance(
|
||||
"Generate embeddings with English Light v3.0 model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated embeddings with English Light v3.0 model"
|
||||
)
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_multilingual_v3(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
response = cohere_instance(
|
||||
"Generate embeddings with Multilingual v3.0 model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated embeddings with Multilingual v3.0 model"
|
||||
)
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_multilingual_light_v3(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
response = cohere_instance(
|
||||
"Generate embeddings with Multilingual Light v3.0 model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated embeddings with Multilingual Light v3.0 model"
|
||||
)
|
||||
|
||||
|
||||
# Add more test cases to benchmark other models and functionalities
|
||||
|
||||
|
||||
def test_cohere_call_with_command_model(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_base_model(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_embed_english_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_embed_english_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_embed_multilingual_v2_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_embed_multilingual_v3_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_invalid_model(cohere_instance):
|
||||
cohere_instance.model = "invalid-model"
|
||||
with pytest.raises(ValueError):
|
||||
cohere_instance("Translate to French.")
|
||||
|
||||
|
||||
def test_cohere_call_with_long_prompt(cohere_instance):
|
||||
prompt = "This is a very long prompt. " * 100
|
||||
response = cohere_instance(prompt)
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance):
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = (
|
||||
"This is a test prompt that will exceed the max tokens limit."
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
cohere_instance(prompt)
|
||||
|
||||
|
||||
def test_cohere_stream_with_command_model(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_base_model(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_embed_english_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_embed_english_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_embed_multilingual_v2_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_embed_multilingual_v3_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_command_model(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_base_model(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_embed_english_v2_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_embed_english_v3_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_embed_multilingual_v2_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_embed_multilingual_v3_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_command_model(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_base_model(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_embed_english_v2_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_embed_english_v3_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_embed_multilingual_v2_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_embed_multilingual_v3_model(
|
||||
cohere_instance,
|
||||
):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_embedding(cohere_instance):
|
||||
# Test using the Representation model for text embedding
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
embedding = cohere_instance.embed(
|
||||
"Generate an embedding for this text."
|
||||
)
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_classification(cohere_instance):
|
||||
# Test using the Representation model for text classification
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
classification = cohere_instance.classify("Classify this text.")
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_language_detection(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for language detection
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
language = cohere_instance.detect_language(
|
||||
"Detect the language of this text."
|
||||
)
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_max_tokens_limit_exceeded(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test handling max tokens limit exceeded error
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = (
|
||||
"This is a test prompt that will exceed the max tokens limit."
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
# Add more production-grade test cases based on real-world scenarios
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_embedding(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for multilingual text embedding
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
embedding = cohere_instance.embed(
|
||||
"Generate multilingual embeddings."
|
||||
)
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_classification(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for multilingual text classification
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
classification = cohere_instance.classify(
|
||||
"Classify multilingual text."
|
||||
)
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_language_detection(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for multilingual language detection
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
language = cohere_instance.detect_language(
|
||||
"Detect the language of multilingual text."
|
||||
)
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test handling max tokens limit exceeded error for multilingual model
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = (
|
||||
"This is a test prompt that will exceed the max tokens limit"
|
||||
" for multilingual model."
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_light_embedding(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for multilingual light text embedding
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
embedding = cohere_instance.embed(
|
||||
"Generate multilingual light embeddings."
|
||||
)
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_light_classification(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for multilingual light text classification
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
classification = cohere_instance.classify(
|
||||
"Classify multilingual light text."
|
||||
)
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_light_language_detection(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for multilingual light language detection
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
language = cohere_instance.detect_language(
|
||||
"Detect the language of multilingual light text."
|
||||
)
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test handling max tokens limit exceeded error for multilingual light model
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = (
|
||||
"This is a test prompt that will exceed the max tokens limit"
|
||||
" for multilingual light model."
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
def test_cohere_command_light_model(cohere_instance):
|
||||
# Test using the Command Light model for text generation
|
||||
cohere_instance.model = "command-light"
|
||||
response = cohere_instance(
|
||||
"Generate text using Command Light model."
|
||||
)
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_base_light_model(cohere_instance):
|
||||
# Test using the Base Light model for text generation
|
||||
cohere_instance.model = "base-light"
|
||||
response = cohere_instance(
|
||||
"Generate text using Base Light model."
|
||||
)
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_generate_summarize_endpoint(cohere_instance):
|
||||
# Test using the Co.summarize() endpoint for text summarization
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance.summarize("Summarize this text.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_embedding(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for English text embedding
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
embedding = cohere_instance.embed("Generate English embeddings.")
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_classification(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for English text classification
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
classification = cohere_instance.classify(
|
||||
"Classify English text."
|
||||
)
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_language_detection(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for English language detection
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
language = cohere_instance.detect_language(
|
||||
"Detect the language of English text."
|
||||
)
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_max_tokens_limit_exceeded(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test handling max tokens limit exceeded error for English model
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = (
|
||||
"This is a test prompt that will exceed the max tokens limit"
|
||||
" for English model."
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_light_embedding(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for English light text embedding
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
embedding = cohere_instance.embed(
|
||||
"Generate English light embeddings."
|
||||
)
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_light_classification(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for English light text classification
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
classification = cohere_instance.classify(
|
||||
"Classify English light text."
|
||||
)
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_light_language_detection(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for English light language detection
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
language = cohere_instance.detect_language(
|
||||
"Detect the language of English light text."
|
||||
)
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_light_max_tokens_limit_exceeded(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test handling max tokens limit exceeded error for English light model
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = (
|
||||
"This is a test prompt that will exceed the max tokens limit"
|
||||
" for English light model."
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
def test_cohere_command_model(cohere_instance):
|
||||
# Test using the Command model for text generation
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance(
|
||||
"Generate text using the Command model."
|
||||
)
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
# Add more production-grade test cases based on real-world scenarios
|
||||
|
||||
|
||||
def test_cohere_invalid_model(cohere_instance):
|
||||
# Test using an invalid model name
|
||||
cohere_instance.model = "invalid-model"
|
||||
with pytest.raises(ValueError):
|
||||
cohere_instance("Generate text using an invalid model.")
|
||||
|
||||
|
||||
def test_cohere_base_model_generation_with_max_tokens(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test generating text using the base model with a specified max_tokens limit
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.max_tokens = 20
|
||||
prompt = "Generate text with max_tokens limit."
|
||||
response = cohere_instance(prompt)
|
||||
assert len(response.split()) <= 20
|
||||
|
||||
|
||||
def test_cohere_command_light_generation_with_stop(cohere_instance):
|
||||
# Test generating text using the command-light model with stop words
|
||||
cohere_instance.model = "command-light"
|
||||
prompt = "Generate text with stop words."
|
||||
stop = ["stop", "words"]
|
||||
response = cohere_instance(prompt, stop=stop)
|
||||
assert all(word not in response for word in stop)
|
@ -1,207 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import FuyuImageProcessor, FuyuProcessor
|
||||
|
||||
from swarm_models.fuyu import Fuyu
|
||||
|
||||
|
||||
# Basic test to ensure instantiation of class.
|
||||
def test_fuyu_initialization():
|
||||
fuyu_instance = Fuyu()
|
||||
assert isinstance(fuyu_instance, Fuyu)
|
||||
|
||||
|
||||
# Using parameterized testing for different init parameters.
|
||||
@pytest.mark.parametrize(
|
||||
"pretrained_path, device_map, max_new_tokens",
|
||||
[
|
||||
("adept/fuyu-8b", "cuda:0", 7),
|
||||
("adept/fuyu-8b", "cpu", 10),
|
||||
],
|
||||
)
|
||||
def test_fuyu_parameters(pretrained_path, device_map, max_new_tokens):
|
||||
fuyu_instance = Fuyu(pretrained_path, device_map, max_new_tokens)
|
||||
assert fuyu_instance.pretrained_path == pretrained_path
|
||||
assert fuyu_instance.device_map == device_map
|
||||
assert fuyu_instance.max_new_tokens == max_new_tokens
|
||||
|
||||
|
||||
# Fixture for creating a Fuyu instance.
|
||||
@pytest.fixture
|
||||
def fuyu_instance():
|
||||
return Fuyu()
|
||||
|
||||
|
||||
# Test using the fixture.
|
||||
def test_fuyu_processor_initialization(fuyu_instance):
|
||||
assert isinstance(fuyu_instance.processor, FuyuProcessor)
|
||||
assert isinstance(
|
||||
fuyu_instance.image_processor, FuyuImageProcessor
|
||||
)
|
||||
|
||||
|
||||
# Test exception when providing an invalid image path.
|
||||
def test_invalid_image_path(fuyu_instance):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
fuyu_instance("Hello", "invalid/path/to/image.png")
|
||||
|
||||
|
||||
# Using monkeypatch to replace the Image.open method to simulate a failure.
|
||||
def test_image_open_failure(fuyu_instance, monkeypatch):
|
||||
def mock_open(*args, **kwargs):
|
||||
raise Exception("Mocked failure")
|
||||
|
||||
monkeypatch.setattr(Image, "open", mock_open)
|
||||
|
||||
with pytest.raises(Exception, match="Mocked failure"):
|
||||
fuyu_instance(
|
||||
"Hello",
|
||||
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
|
||||
)
|
||||
|
||||
|
||||
# Marking a slow test.
|
||||
@pytest.mark.slow
|
||||
def test_fuyu_model_output(fuyu_instance):
|
||||
# This is a dummy test and may not be functional without real data.
|
||||
output = fuyu_instance(
|
||||
"Hello, my name is",
|
||||
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
|
||||
)
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_tokenizer_type(fuyu_instance):
|
||||
assert "tokenizer" in dir(fuyu_instance)
|
||||
|
||||
|
||||
def test_processor_has_image_processor_and_tokenizer(fuyu_instance):
|
||||
assert (
|
||||
fuyu_instance.processor.image_processor
|
||||
== fuyu_instance.image_processor
|
||||
)
|
||||
assert (
|
||||
fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer
|
||||
)
|
||||
|
||||
|
||||
def test_model_device_map(fuyu_instance):
|
||||
assert fuyu_instance.model.device_map == fuyu_instance.device_map
|
||||
|
||||
|
||||
# Testing maximum tokens setting
|
||||
def test_max_new_tokens_setting(fuyu_instance):
|
||||
assert fuyu_instance.max_new_tokens == 7
|
||||
|
||||
|
||||
# Test if an exception is raised when invalid text is provided.
|
||||
def test_invalid_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance(None, "path/to/image.png")
|
||||
|
||||
|
||||
# Test if an exception is raised when empty text is provided.
|
||||
def test_empty_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance("", "path/to/image.png")
|
||||
|
||||
|
||||
# Test if an exception is raised when a very long text is provided.
|
||||
def test_very_long_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance("A" * 10000, "path/to/image.png")
|
||||
|
||||
|
||||
# Check model's default device map
|
||||
def test_default_device_map():
|
||||
fuyu_instance = Fuyu()
|
||||
assert fuyu_instance.device_map == "cuda:0"
|
||||
|
||||
|
||||
# Testing if processor is correctly initialized
|
||||
def test_processor_initialization(fuyu_instance):
|
||||
assert isinstance(fuyu_instance.processor, FuyuProcessor)
|
||||
|
||||
|
||||
# Test `get_img` method with a valid image path
|
||||
def test_get_img_valid_path(fuyu_instance):
|
||||
with patch("PIL.Image.open") as mock_open:
|
||||
mock_open.return_value = "Test image"
|
||||
result = fuyu_instance.get_img("valid/path/to/image.png")
|
||||
assert result == "Test image"
|
||||
|
||||
|
||||
# Test `get_img` method with an invalid image path
|
||||
def test_get_img_invalid_path(fuyu_instance):
|
||||
with patch("PIL.Image.open") as mock_open:
|
||||
mock_open.side_effect = FileNotFoundError
|
||||
with pytest.raises(FileNotFoundError):
|
||||
fuyu_instance.get_img("invalid/path/to/image.png")
|
||||
|
||||
|
||||
# Test `run` method with valid inputs
|
||||
def test_run_valid_inputs(fuyu_instance):
|
||||
with patch.object(
|
||||
fuyu_instance, "get_img"
|
||||
) as mock_get_img, patch.object(
|
||||
fuyu_instance, "processor"
|
||||
) as mock_processor, patch.object(
|
||||
fuyu_instance, "model"
|
||||
) as mock_model:
|
||||
mock_get_img.return_value = "Test image"
|
||||
mock_processor.return_value = {
|
||||
"input_ids": torch.tensor([1, 2, 3])
|
||||
}
|
||||
mock_model.generate.return_value = torch.tensor([1, 2, 3])
|
||||
mock_processor.batch_decode.return_value = ["Test text"]
|
||||
result = fuyu_instance.run(
|
||||
"Hello, world!", "valid/path/to/image.png"
|
||||
)
|
||||
assert result == ["Test text"]
|
||||
|
||||
|
||||
# Test `run` method with invalid text input
|
||||
def test_run_invalid_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance.run(None, "valid/path/to/image.png")
|
||||
|
||||
|
||||
# Test `run` method with empty text input
|
||||
def test_run_empty_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance.run("", "valid/path/to/image.png")
|
||||
|
||||
|
||||
# Test `run` method with very long text input
|
||||
def test_run_very_long_text_input(fuyu_instance):
|
||||
with pytest.raises(Exception):
|
||||
fuyu_instance.run("A" * 10000, "valid/path/to/image.png")
|
||||
|
||||
|
||||
# Test `run` method with invalid image path
|
||||
def test_run_invalid_image_path(fuyu_instance):
|
||||
with patch.object(fuyu_instance, "get_img") as mock_get_img:
|
||||
mock_get_img.side_effect = FileNotFoundError
|
||||
with pytest.raises(FileNotFoundError):
|
||||
fuyu_instance.run(
|
||||
"Hello, world!", "invalid/path/to/image.png"
|
||||
)
|
||||
|
||||
|
||||
# Test `__init__` method with default parameters
|
||||
def test_init_default_parameters():
|
||||
fuyu_instance = Fuyu()
|
||||
assert fuyu_instance.pretrained_path == "adept/fuyu-8b"
|
||||
assert fuyu_instance.device_map == "auto"
|
||||
assert fuyu_instance.max_new_tokens == 500
|
||||
|
||||
|
||||
# Test `__init__` method with custom parameters
|
||||
def test_init_custom_parameters():
|
||||
fuyu_instance = Fuyu("custom/path", "cpu", 1000)
|
||||
assert fuyu_instance.pretrained_path == "custom/path"
|
||||
assert fuyu_instance.device_map == "cpu"
|
||||
assert fuyu_instance.max_new_tokens == 1000
|
@ -1,315 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from swarm_models.gemini import Gemini
|
||||
|
||||
|
||||
# Define test fixtures
|
||||
@pytest.fixture
|
||||
def mock_gemini_api_key(monkeypatch):
|
||||
monkeypatch.setenv("GEMINI_API_KEY", "mocked-api-key")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_genai_model():
|
||||
return Mock()
|
||||
|
||||
|
||||
# Test initialization of Gemini
|
||||
def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model):
|
||||
model = Gemini()
|
||||
assert model.model_name == "gemini-pro"
|
||||
assert model.gemini_api_key == "mocked-api-key"
|
||||
assert model.model is mock_genai_model
|
||||
|
||||
|
||||
def test_gemini_init_custom_params(
|
||||
mock_gemini_api_key, mock_genai_model
|
||||
):
|
||||
model = Gemini(
|
||||
model_name="custom-model", gemini_api_key="custom-api-key"
|
||||
)
|
||||
assert model.model_name == "custom-model"
|
||||
assert model.gemini_api_key == "custom-api-key"
|
||||
assert model.model is mock_genai_model
|
||||
|
||||
|
||||
# Test Gemini run method
|
||||
@patch("swarms.models.gemini.Gemini.process_img")
|
||||
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||||
def test_gemini_run_with_img(
|
||||
mock_generate_content,
|
||||
mock_process_img,
|
||||
mock_gemini_api_key,
|
||||
mock_genai_model,
|
||||
):
|
||||
model = Gemini()
|
||||
task = "A cat"
|
||||
img = "cat.png"
|
||||
response_mock = Mock(text="Generated response")
|
||||
mock_generate_content.return_value = response_mock
|
||||
mock_process_img.return_value = "Processed image"
|
||||
|
||||
response = model.run(task=task, img=img)
|
||||
|
||||
assert response == "Generated response"
|
||||
mock_generate_content.assert_called_with(
|
||||
content=[task, "Processed image"]
|
||||
)
|
||||
mock_process_img.assert_called_with(img=img)
|
||||
|
||||
|
||||
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||||
def test_gemini_run_without_img(
|
||||
mock_generate_content, mock_gemini_api_key, mock_genai_model
|
||||
):
|
||||
model = Gemini()
|
||||
task = "A cat"
|
||||
response_mock = Mock(text="Generated response")
|
||||
mock_generate_content.return_value = response_mock
|
||||
|
||||
response = model.run(task=task)
|
||||
|
||||
assert response == "Generated response"
|
||||
mock_generate_content.assert_called_with(task=task)
|
||||
|
||||
|
||||
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||||
def test_gemini_run_exception(
|
||||
mock_generate_content, mock_gemini_api_key, mock_genai_model
|
||||
):
|
||||
model = Gemini()
|
||||
task = "A cat"
|
||||
mock_generate_content.side_effect = Exception("Test exception")
|
||||
|
||||
response = model.run(task=task)
|
||||
|
||||
assert response is None
|
||||
|
||||
|
||||
# Test Gemini process_img method
|
||||
def test_gemini_process_img(mock_gemini_api_key, mock_genai_model):
|
||||
model = Gemini(gemini_api_key="custom-api-key")
|
||||
img = "cat.png"
|
||||
img_data = b"Mocked image data"
|
||||
|
||||
with patch("builtins.open", create=True) as open_mock:
|
||||
open_mock.return_value.__enter__.return_value.read.return_value = (
|
||||
img_data
|
||||
)
|
||||
|
||||
processed_img = model.process_img(img)
|
||||
|
||||
assert processed_img == [
|
||||
{"mime_type": "image/png", "data": img_data}
|
||||
]
|
||||
open_mock.assert_called_with(img, "rb")
|
||||
|
||||
|
||||
# Test Gemini initialization with missing API key
|
||||
def test_gemini_init_missing_api_key():
|
||||
with pytest.raises(
|
||||
ValueError, match="Please provide a Gemini API key"
|
||||
):
|
||||
Gemini(gemini_api_key=None)
|
||||
|
||||
|
||||
# Test Gemini initialization with missing model name
|
||||
def test_gemini_init_missing_model_name():
|
||||
with pytest.raises(
|
||||
ValueError, match="Please provide a model name"
|
||||
):
|
||||
Gemini(model_name=None)
|
||||
|
||||
|
||||
# Test Gemini run method with empty task
|
||||
def test_gemini_run_empty_task(mock_gemini_api_key, mock_genai_model):
|
||||
model = Gemini()
|
||||
task = ""
|
||||
response = model.run(task=task)
|
||||
assert response is None
|
||||
|
||||
|
||||
# Test Gemini run method with empty image
|
||||
def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model):
|
||||
model = Gemini()
|
||||
task = "A cat"
|
||||
img = ""
|
||||
response = model.run(task=task, img=img)
|
||||
assert response is None
|
||||
|
||||
|
||||
# Test Gemini process_img method with missing image
|
||||
def test_gemini_process_img_missing_image(
|
||||
mock_gemini_api_key, mock_genai_model
|
||||
):
|
||||
model = Gemini()
|
||||
img = None
|
||||
with pytest.raises(
|
||||
ValueError, match="Please provide an image to process"
|
||||
):
|
||||
model.process_img(img=img)
|
||||
|
||||
|
||||
# Test Gemini process_img method with missing image type
|
||||
def test_gemini_process_img_missing_image_type(
|
||||
mock_gemini_api_key, mock_genai_model
|
||||
):
|
||||
model = Gemini()
|
||||
img = "cat.png"
|
||||
with pytest.raises(
|
||||
ValueError, match="Please provide the image type"
|
||||
):
|
||||
model.process_img(img=img, type=None)
|
||||
|
||||
|
||||
# Test Gemini process_img method with missing Gemini API key
|
||||
def test_gemini_process_img_missing_api_key(mock_genai_model):
|
||||
model = Gemini(gemini_api_key=None)
|
||||
img = "cat.png"
|
||||
with pytest.raises(
|
||||
ValueError, match="Please provide a Gemini API key"
|
||||
):
|
||||
model.process_img(img=img, type="image/png")
|
||||
|
||||
|
||||
# Test Gemini run method with mocked image processing
|
||||
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||||
@patch("swarms.models.gemini.Gemini.process_img")
|
||||
def test_gemini_run_mock_img_processing(
|
||||
mock_process_img,
|
||||
mock_generate_content,
|
||||
mock_gemini_api_key,
|
||||
mock_genai_model,
|
||||
):
|
||||
model = Gemini()
|
||||
task = "A cat"
|
||||
img = "cat.png"
|
||||
response_mock = Mock(text="Generated response")
|
||||
mock_generate_content.return_value = response_mock
|
||||
mock_process_img.return_value = "Processed image"
|
||||
|
||||
response = model.run(task=task, img=img)
|
||||
|
||||
assert response == "Generated response"
|
||||
mock_generate_content.assert_called_with(
|
||||
content=[task, "Processed image"]
|
||||
)
|
||||
mock_process_img.assert_called_with(img=img)
|
||||
|
||||
|
||||
# Test Gemini run method with mocked image processing and exception
|
||||
@patch("swarms.models.gemini.Gemini.process_img")
|
||||
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||||
def test_gemini_run_mock_img_processing_exception(
|
||||
mock_generate_content,
|
||||
mock_process_img,
|
||||
mock_gemini_api_key,
|
||||
mock_genai_model,
|
||||
):
|
||||
model = Gemini()
|
||||
task = "A cat"
|
||||
img = "cat.png"
|
||||
mock_process_img.side_effect = Exception("Test exception")
|
||||
|
||||
response = model.run(task=task, img=img)
|
||||
|
||||
assert response is None
|
||||
mock_generate_content.assert_not_called()
|
||||
mock_process_img.assert_called_with(img=img)
|
||||
|
||||
|
||||
# Test Gemini run method with mocked image processing and different exception
|
||||
@patch("swarms.models.gemini.Gemini.process_img")
|
||||
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||||
def test_gemini_run_mock_img_processing_different_exception(
|
||||
mock_generate_content,
|
||||
mock_process_img,
|
||||
mock_gemini_api_key,
|
||||
mock_genai_model,
|
||||
):
|
||||
model = Gemini()
|
||||
task = "A dog"
|
||||
img = "dog.png"
|
||||
mock_process_img.side_effect = ValueError("Test exception")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.run(task=task, img=img)
|
||||
|
||||
mock_generate_content.assert_not_called()
|
||||
mock_process_img.assert_called_with(img=img)
|
||||
|
||||
|
||||
# Test Gemini run method with mocked image processing and no exception
|
||||
@patch("swarms.models.gemini.Gemini.process_img")
|
||||
@patch("swarms.models.gemini.genai.GenerativeModel.generate_content")
|
||||
def test_gemini_run_mock_img_processing_no_exception(
|
||||
mock_generate_content,
|
||||
mock_process_img,
|
||||
mock_gemini_api_key,
|
||||
mock_genai_model,
|
||||
):
|
||||
model = Gemini()
|
||||
task = "A bird"
|
||||
img = "bird.png"
|
||||
mock_generate_content.return_value = "A bird is flying"
|
||||
|
||||
response = model.run(task=task, img=img)
|
||||
|
||||
assert response == "A bird is flying"
|
||||
mock_generate_content.assert_called_once()
|
||||
mock_process_img.assert_called_with(img=img)
|
||||
|
||||
|
||||
# Test Gemini chat method
|
||||
@patch("swarms.models.gemini.Gemini.chat")
|
||||
def test_gemini_chat(mock_chat):
|
||||
model = Gemini()
|
||||
mock_chat.return_value = "Hello, Gemini!"
|
||||
|
||||
response = model.chat("Hello, Gemini!")
|
||||
|
||||
assert response == "Hello, Gemini!"
|
||||
mock_chat.assert_called_once()
|
||||
|
||||
|
||||
# Test Gemini list_models method
|
||||
@patch("swarms.models.gemini.Gemini.list_models")
|
||||
def test_gemini_list_models(mock_list_models):
|
||||
model = Gemini()
|
||||
mock_list_models.return_value = ["model1", "model2"]
|
||||
|
||||
response = model.list_models()
|
||||
|
||||
assert response == ["model1", "model2"]
|
||||
mock_list_models.assert_called_once()
|
||||
|
||||
|
||||
# Test Gemini stream_tokens method
|
||||
@patch("swarms.models.gemini.Gemini.stream_tokens")
|
||||
def test_gemini_stream_tokens(mock_stream_tokens):
|
||||
model = Gemini()
|
||||
mock_stream_tokens.return_value = ["token1", "token2"]
|
||||
|
||||
response = model.stream_tokens()
|
||||
|
||||
assert response == ["token1", "token2"]
|
||||
mock_stream_tokens.assert_called_once()
|
||||
|
||||
|
||||
# Test Gemini process_img_pil method
|
||||
@patch("swarms.models.gemini.Gemini.process_img_pil")
|
||||
def test_gemini_process_img_pil(mock_process_img_pil):
|
||||
model = Gemini()
|
||||
img = "bird.png"
|
||||
mock_process_img_pil.return_value = "processed image"
|
||||
|
||||
response = model.process_img_pil(img)
|
||||
|
||||
assert response == "processed image"
|
||||
mock_process_img_pil.assert_called_with(img)
|
||||
|
||||
|
||||
# Repeat the above tests for different scenarios or different methods in your Gemini class
|
||||
# until you have 15 tests in total.
|
@ -1,252 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, Mock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import ClientResponseError
|
||||
from dotenv import load_dotenv
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from swarm_models.gpt4_vision_api import GPT4VisionAPI
|
||||
|
||||
load_dotenv()
|
||||
|
||||
custom_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
img = "images/swarms.jpeg"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vision_api():
|
||||
return GPT4VisionAPI(openai_api_key="test_api_key")
|
||||
|
||||
|
||||
def test_init(vision_api):
|
||||
assert vision_api.openai_api_key == "test_api_key"
|
||||
|
||||
|
||||
def test_encode_image(vision_api):
|
||||
with patch(
|
||||
"builtins.open",
|
||||
mock_open(read_data=b"test_image_data"),
|
||||
create=True,
|
||||
):
|
||||
encoded_image = vision_api.encode_image(img)
|
||||
assert encoded_image == "dGVzdF9pbWFnZV9kYXRh"
|
||||
|
||||
|
||||
def test_run_success(vision_api):
|
||||
expected_response = {"This is the model's response."}
|
||||
with patch(
|
||||
"requests.post",
|
||||
return_value=Mock(json=lambda: expected_response),
|
||||
) as mock_post:
|
||||
result = vision_api.run("What is this?", img)
|
||||
mock_post.assert_called_once()
|
||||
assert result == "This is the model's response."
|
||||
|
||||
|
||||
def test_run_request_error(vision_api):
|
||||
with patch(
|
||||
"requests.post", side_effect=RequestException("Request Error")
|
||||
):
|
||||
with pytest.raises(RequestException):
|
||||
vision_api.run("What is this?", img)
|
||||
|
||||
|
||||
def test_run_response_error(vision_api):
|
||||
expected_response = {"error": "Model Error"}
|
||||
with patch(
|
||||
"requests.post",
|
||||
return_value=Mock(json=lambda: expected_response),
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
vision_api.run("What is this?", img)
|
||||
|
||||
|
||||
def test_call(vision_api):
|
||||
expected_response = {
|
||||
"choices": [{"text": "This is the model's response."}]
|
||||
}
|
||||
with patch(
|
||||
"requests.post",
|
||||
return_value=Mock(json=lambda: expected_response),
|
||||
) as mock_post:
|
||||
result = vision_api("What is this?", img)
|
||||
mock_post.assert_called_once()
|
||||
assert result == "This is the model's response."
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gpt_api():
|
||||
return GPT4VisionAPI()
|
||||
|
||||
|
||||
def test_initialization_with_default_key():
|
||||
api = GPT4VisionAPI()
|
||||
assert api.openai_api_key == custom_api_key
|
||||
|
||||
|
||||
def test_initialization_with_custom_key():
|
||||
custom_key = custom_api_key
|
||||
api = GPT4VisionAPI(openai_api_key=custom_key)
|
||||
assert api.openai_api_key == custom_key
|
||||
|
||||
|
||||
def test_run_with_exception(gpt_api):
|
||||
task = "What is in the image?"
|
||||
img_url = img
|
||||
with patch(
|
||||
"requests.post", side_effect=Exception("Test Exception")
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
gpt_api.run(task, img_url)
|
||||
|
||||
|
||||
def test_call_method_successful_response(gpt_api):
|
||||
task = "What is in the image?"
|
||||
img_url = img
|
||||
response_json = {
|
||||
"choices": [{"text": "Answer from GPT-4 Vision"}]
|
||||
}
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = response_json
|
||||
with patch(
|
||||
"requests.post", return_value=mock_response
|
||||
) as mock_post:
|
||||
result = gpt_api(task, img_url)
|
||||
mock_post.assert_called_once()
|
||||
assert result == response_json
|
||||
|
||||
|
||||
def test_call_method_with_exception(gpt_api):
|
||||
task = "What is in the image?"
|
||||
img_url = img
|
||||
with patch(
|
||||
"requests.post", side_effect=Exception("Test Exception")
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
gpt_api(task, img_url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_success(vision_api):
|
||||
expected_response = {
|
||||
"choices": [
|
||||
{"message": {"content": "This is the model's response."}}
|
||||
]
|
||||
}
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=AsyncMock(
|
||||
json=AsyncMock(return_value=expected_response)
|
||||
),
|
||||
) as mock_post:
|
||||
result = await vision_api.arun("What is this?", img)
|
||||
mock_post.assert_called_once()
|
||||
assert result == "This is the model's response."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_request_error(vision_api):
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Request Error"),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await vision_api.arun("What is this?", img)
|
||||
|
||||
|
||||
def test_run_many_success(vision_api):
|
||||
expected_response = {
|
||||
"choices": [
|
||||
{"message": {"content": "This is the model's response."}}
|
||||
]
|
||||
}
|
||||
with patch(
|
||||
"requests.post",
|
||||
return_value=Mock(json=lambda: expected_response),
|
||||
) as mock_post:
|
||||
tasks = ["What is this?", "What is that?"]
|
||||
imgs = [img, img]
|
||||
results = vision_api.run_many(tasks, imgs)
|
||||
assert mock_post.call_count == 2
|
||||
assert results == [
|
||||
"This is the model's response.",
|
||||
"This is the model's response.",
|
||||
]
|
||||
|
||||
|
||||
def test_run_many_request_error(vision_api):
|
||||
with patch(
|
||||
"requests.post", side_effect=RequestException("Request Error")
|
||||
):
|
||||
tasks = ["What is this?", "What is that?"]
|
||||
imgs = [img, img]
|
||||
with pytest.raises(RequestException):
|
||||
vision_api.run_many(tasks, imgs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_json_decode_error(vision_api):
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=AsyncMock(
|
||||
json=AsyncMock(side_effect=ValueError)
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
await vision_api.arun("What is this?", img)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_api_error(vision_api):
|
||||
error_response = {"error": {"message": "API Error"}}
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=AsyncMock(
|
||||
json=AsyncMock(return_value=error_response)
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await vision_api.arun("What is this?", img)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_unexpected_response(vision_api):
|
||||
unexpected_response = {"unexpected": "response"}
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post",
|
||||
new_callable=AsyncMock,
|
||||
return_value=AsyncMock(
|
||||
json=AsyncMock(return_value=unexpected_response)
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match="Unexpected response"):
|
||||
await vision_api.arun("What is this?", img)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_retries(vision_api):
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ClientResponseError(None, None),
|
||||
) as mock_post:
|
||||
with pytest.raises(ClientResponseError):
|
||||
await vision_api.arun("What is this?", img)
|
||||
assert mock_post.call_count == vision_api.retries + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_timeout(vision_api):
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=asyncio.TimeoutError,
|
||||
):
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await vision_api.arun("What is this?", img)
|
@ -1,465 +0,0 @@
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from swarm_models.huggingface import HuggingfaceLLM
|
||||
|
||||
|
||||
# Fixture for the class instance
|
||||
@pytest.fixture
|
||||
def llm_instance():
|
||||
model_id = "NousResearch/Nous-Hermes-2-Vision-Alpha"
|
||||
instance = HuggingfaceLLM(model_id=model_id)
|
||||
return instance
|
||||
|
||||
|
||||
# Test for instantiation and attributes
|
||||
def test_llm_initialization(llm_instance):
|
||||
assert (
|
||||
llm_instance.model_id
|
||||
== "NousResearch/Nous-Hermes-2-Vision-Alpha"
|
||||
)
|
||||
assert llm_instance.max_length == 500
|
||||
# ... add more assertions for all default attributes
|
||||
|
||||
|
||||
# Parameterized test for setting devices
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||
def test_llm_set_device(llm_instance, device):
|
||||
llm_instance.set_device(device)
|
||||
assert llm_instance.device == device
|
||||
|
||||
|
||||
# Test exception during initialization with a bad model_id
|
||||
def test_llm_bad_model_initialization():
|
||||
with pytest.raises(Exception):
|
||||
HuggingfaceLLM(model_id="unknown-model")
|
||||
|
||||
|
||||
# # Mocking the tokenizer and model to test run method
|
||||
# @patch("swarms.models.huggingface.AutoTokenizer.from_pretrained")
|
||||
# @patch(
|
||||
# "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained"
|
||||
# )
|
||||
# def test_llm_run(mock_model, mock_tokenizer, llm_instance):
|
||||
# mock_model.return_value.generate.return_value = "mocked output"
|
||||
# mock_tokenizer.return_value.encode.return_value = "mocked input"
|
||||
# result = llm_instance.run("test task")
|
||||
# assert result == "mocked output"
|
||||
|
||||
|
||||
# Async test (requires pytest-asyncio plugin)
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_run_async(llm_instance):
|
||||
result = await llm_instance.run_async("test task")
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
# Test for checking GPU availability
|
||||
def test_llm_gpu_availability(llm_instance):
|
||||
# Assuming the test is running on a machine where the GPU availability is known
|
||||
expected_result = torch.cuda.is_available()
|
||||
assert llm_instance.gpu_available() == expected_result
|
||||
|
||||
|
||||
# Test for memory consumption reporting
|
||||
def test_llm_memory_consumption(llm_instance):
|
||||
# Mocking torch.cuda functions for consistent results
|
||||
with patch("torch.cuda.memory_allocated", return_value=1024):
|
||||
with patch("torch.cuda.memory_reserved", return_value=2048):
|
||||
memory = llm_instance.memory_consumption()
|
||||
assert memory == {"allocated": 1024, "reserved": 2048}
|
||||
|
||||
|
||||
# Test different initialization parameters
|
||||
@pytest.mark.parametrize(
|
||||
"model_id, max_length",
|
||||
[
|
||||
("NousResearch/Nous-Hermes-2-Vision-Alpha", 100),
|
||||
("microsoft/Orca-2-13b", 200),
|
||||
(
|
||||
"berkeley-nest/Starling-LM-7B-alpha",
|
||||
None,
|
||||
), # None to check default behavior
|
||||
],
|
||||
)
|
||||
def test_llm_initialization_params(model_id, max_length):
|
||||
if max_length:
|
||||
instance = HuggingfaceLLM(
|
||||
model_id=model_id, max_length=max_length
|
||||
)
|
||||
assert instance.max_length == max_length
|
||||
else:
|
||||
instance = HuggingfaceLLM(model_id=model_id)
|
||||
assert (
|
||||
instance.max_length == 500
|
||||
) # Assuming 500 is the default max_length
|
||||
|
||||
|
||||
# Test for setting an invalid device
|
||||
def test_llm_set_invalid_device(llm_instance):
|
||||
with pytest.raises(ValueError):
|
||||
llm_instance.set_device("quantum_processor")
|
||||
|
||||
|
||||
# Mocking external API call to test run method without network
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
||||
def test_llm_run_without_network(mock_run, llm_instance):
|
||||
mock_run.return_value = "mocked output"
|
||||
result = llm_instance.run("test task without network")
|
||||
assert result == "mocked output"
|
||||
|
||||
|
||||
# Test handling of empty input for the run method
|
||||
def test_llm_run_empty_input(llm_instance):
|
||||
with pytest.raises(ValueError):
|
||||
llm_instance.run("")
|
||||
|
||||
|
||||
# Test the generation with a provided seed for reproducibility
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
||||
def test_llm_run_with_seed(mock_run, llm_instance):
|
||||
seed = 42
|
||||
llm_instance.set_seed(seed)
|
||||
# Assuming set_seed method affects the randomness in the model
|
||||
# You would typically ensure that setting the seed gives reproducible results
|
||||
mock_run.return_value = "mocked deterministic output"
|
||||
result = llm_instance.run("test task", seed=seed)
|
||||
assert result == "mocked deterministic output"
|
||||
|
||||
|
||||
# Test the output length is as expected
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
||||
def test_llm_run_output_length(mock_run, llm_instance):
|
||||
input_text = "test task"
|
||||
llm_instance.max_length = 50 # set a max_length for the output
|
||||
mock_run.return_value = "mocked output" * 10 # some long text
|
||||
result = llm_instance.run(input_text)
|
||||
assert len(result.split()) <= llm_instance.max_length
|
||||
|
||||
|
||||
# Test the tokenizer handling special tokens correctly
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode")
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode")
|
||||
def test_llm_tokenizer_special_tokens(
|
||||
mock_decode, mock_encode, llm_instance
|
||||
):
|
||||
mock_encode.return_value = "encoded input with special tokens"
|
||||
mock_decode.return_value = "decoded output with special tokens"
|
||||
result = llm_instance.run("test task with special tokens")
|
||||
mock_encode.assert_called_once()
|
||||
mock_decode.assert_called_once()
|
||||
assert "special tokens" in result
|
||||
|
||||
|
||||
# Test for correct handling of timeouts
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
||||
def test_llm_timeout_handling(mock_run, llm_instance):
|
||||
mock_run.side_effect = TimeoutError
|
||||
with pytest.raises(TimeoutError):
|
||||
llm_instance.run("test task with timeout")
|
||||
|
||||
|
||||
# Test for response time within a threshold (performance test)
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
||||
def test_llm_response_time(mock_run, llm_instance):
|
||||
import time
|
||||
|
||||
mock_run.return_value = "mocked output"
|
||||
start_time = time.time()
|
||||
llm_instance.run("test task for response time")
|
||||
end_time = time.time()
|
||||
assert (
|
||||
end_time - start_time < 1
|
||||
) # Assuming the response should be faster than 1 second
|
||||
|
||||
|
||||
# Test the logging of a warning for long inputs
|
||||
@patch("swarms.models.huggingface.logging.warning")
|
||||
def test_llm_long_input_warning(mock_warning, llm_instance):
|
||||
long_input = "x" * 10000 # input longer than the typical limit
|
||||
llm_instance.run(long_input)
|
||||
mock_warning.assert_called_once()
|
||||
|
||||
|
||||
# Test for run method behavior when model raises an exception
|
||||
@patch(
|
||||
"swarms.models.huggingface.HuggingfaceLLM._model.generate",
|
||||
side_effect=RuntimeError,
|
||||
)
|
||||
def test_llm_run_model_exception(mock_generate, llm_instance):
|
||||
with pytest.raises(RuntimeError):
|
||||
llm_instance.run("test task when model fails")
|
||||
|
||||
|
||||
# Test the behavior when GPU is forced but not available
|
||||
@patch("torch.cuda.is_available", return_value=False)
|
||||
def test_llm_force_gpu_when_unavailable(
|
||||
mock_is_available, llm_instance
|
||||
):
|
||||
with pytest.raises(EnvironmentError):
|
||||
llm_instance.set_device(
|
||||
"cuda"
|
||||
) # Attempt to set CUDA when it's not available
|
||||
|
||||
|
||||
# Test for proper cleanup after model use (releasing resources)
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM._model")
|
||||
def test_llm_cleanup(mock_model, mock_tokenizer, llm_instance):
|
||||
llm_instance.cleanup()
|
||||
# Assuming cleanup method is meant to free resources
|
||||
mock_model.delete.assert_called_once()
|
||||
mock_tokenizer.delete.assert_called_once()
|
||||
|
||||
|
||||
# Test model's ability to handle multilingual input
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
||||
def test_llm_multilingual_input(mock_run, llm_instance):
|
||||
mock_run.return_value = "mocked multilingual output"
|
||||
multilingual_input = "Bonjour, ceci est un test multilingue."
|
||||
result = llm_instance.run(multilingual_input)
|
||||
assert isinstance(
|
||||
result, str
|
||||
) # Simple check to ensure output is string type
|
||||
|
||||
|
||||
# Test caching mechanism to prevent re-running the same inputs
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
||||
def test_llm_caching_mechanism(mock_run, llm_instance):
|
||||
input_text = "test caching mechanism"
|
||||
mock_run.return_value = "cached output"
|
||||
# Run the input twice
|
||||
first_run_result = llm_instance.run(input_text)
|
||||
second_run_result = llm_instance.run(input_text)
|
||||
mock_run.assert_called_once() # Should only be called once due to caching
|
||||
assert first_run_result == second_run_result
|
||||
|
||||
|
||||
# These tests are provided as examples. In real-world scenarios, you will need to adapt these tests to the actual logic of your `HuggingfaceLLM` class.
|
||||
# For instance, "mock_model.delete.assert_called_once()" and similar lines are based on hypothetical methods and behaviors that you need to replace with actual implementations.
|
||||
|
||||
|
||||
# Mock some functions and objects for testing
|
||||
@pytest.fixture
|
||||
def mock_huggingface_llm(monkeypatch):
|
||||
# Mock the model and tokenizer creation
|
||||
def mock_init(
|
||||
self,
|
||||
model_id,
|
||||
device="cpu",
|
||||
max_length=500,
|
||||
quantize=False,
|
||||
quantization_config=None,
|
||||
verbose=False,
|
||||
distributed=False,
|
||||
decoding=False,
|
||||
max_workers=5,
|
||||
repitition_penalty=1.3,
|
||||
no_repeat_ngram_size=5,
|
||||
temperature=0.7,
|
||||
top_k=40,
|
||||
top_p=0.8,
|
||||
):
|
||||
pass
|
||||
|
||||
# Mock the model loading
|
||||
def mock_load_model(self):
|
||||
pass
|
||||
|
||||
# Mock the model generation
|
||||
def mock_run(self, task):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(HuggingfaceLLM, "__init__", mock_init)
|
||||
monkeypatch.setattr(HuggingfaceLLM, "load_model", mock_load_model)
|
||||
monkeypatch.setattr(HuggingfaceLLM, "run", mock_run)
|
||||
|
||||
|
||||
# Basic tests for initialization and attribute settings
|
||||
def test_init_huggingface_llm():
|
||||
llm = HuggingfaceLLM(
|
||||
model_id="test_model",
|
||||
device="cuda",
|
||||
max_length=1000,
|
||||
quantize=True,
|
||||
quantization_config={"config_key": "config_value"},
|
||||
verbose=True,
|
||||
distributed=True,
|
||||
decoding=True,
|
||||
max_workers=3,
|
||||
repitition_penalty=1.5,
|
||||
no_repeat_ngram_size=4,
|
||||
temperature=0.8,
|
||||
top_k=50,
|
||||
top_p=0.7,
|
||||
)
|
||||
|
||||
assert llm.model_id == "test_model"
|
||||
assert llm.device == "cuda"
|
||||
assert llm.max_length == 1000
|
||||
assert llm.quantize is True
|
||||
assert llm.quantization_config == {"config_key": "config_value"}
|
||||
assert llm.verbose is True
|
||||
assert llm.distributed is True
|
||||
assert llm.decoding is True
|
||||
assert llm.max_workers == 3
|
||||
assert llm.repitition_penalty == 1.5
|
||||
assert llm.no_repeat_ngram_size == 4
|
||||
assert llm.temperature == 0.8
|
||||
assert llm.top_k == 50
|
||||
assert llm.top_p == 0.7
|
||||
|
||||
|
||||
# Test loading the model
|
||||
def test_load_model(mock_huggingface_llm):
|
||||
llm = HuggingfaceLLM(model_id="test_model")
|
||||
llm.load_model()
|
||||
|
||||
|
||||
# Test running the model
|
||||
def test_run(mock_huggingface_llm):
|
||||
llm = HuggingfaceLLM(model_id="test_model")
|
||||
llm.run("Test prompt")
|
||||
|
||||
|
||||
# Test for setting max_length
|
||||
def test_llm_set_max_length(llm_instance):
|
||||
new_max_length = 1000
|
||||
llm_instance.set_max_length(new_max_length)
|
||||
assert llm_instance.max_length == new_max_length
|
||||
|
||||
|
||||
# Test for setting verbose
|
||||
def test_llm_set_verbose(llm_instance):
|
||||
llm_instance.set_verbose(True)
|
||||
assert llm_instance.verbose is True
|
||||
|
||||
|
||||
# Test for setting distributed
|
||||
def test_llm_set_distributed(llm_instance):
|
||||
llm_instance.set_distributed(True)
|
||||
assert llm_instance.distributed is True
|
||||
|
||||
|
||||
# Test for setting decoding
|
||||
def test_llm_set_decoding(llm_instance):
|
||||
llm_instance.set_decoding(True)
|
||||
assert llm_instance.decoding is True
|
||||
|
||||
|
||||
# Test for setting max_workers
|
||||
def test_llm_set_max_workers(llm_instance):
|
||||
new_max_workers = 10
|
||||
llm_instance.set_max_workers(new_max_workers)
|
||||
assert llm_instance.max_workers == new_max_workers
|
||||
|
||||
|
||||
# Test for setting repitition_penalty
|
||||
def test_llm_set_repitition_penalty(llm_instance):
|
||||
new_repitition_penalty = 1.5
|
||||
llm_instance.set_repitition_penalty(new_repitition_penalty)
|
||||
assert llm_instance.repitition_penalty == new_repitition_penalty
|
||||
|
||||
|
||||
# Test for setting no_repeat_ngram_size
|
||||
def test_llm_set_no_repeat_ngram_size(llm_instance):
|
||||
new_no_repeat_ngram_size = 6
|
||||
llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size)
|
||||
assert (
|
||||
llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size
|
||||
)
|
||||
|
||||
|
||||
# Test for setting temperature
|
||||
def test_llm_set_temperature(llm_instance):
|
||||
new_temperature = 0.8
|
||||
llm_instance.set_temperature(new_temperature)
|
||||
assert llm_instance.temperature == new_temperature
|
||||
|
||||
|
||||
# Test for setting top_k
|
||||
def test_llm_set_top_k(llm_instance):
|
||||
new_top_k = 50
|
||||
llm_instance.set_top_k(new_top_k)
|
||||
assert llm_instance.top_k == new_top_k
|
||||
|
||||
|
||||
# Test for setting top_p
|
||||
def test_llm_set_top_p(llm_instance):
|
||||
new_top_p = 0.9
|
||||
llm_instance.set_top_p(new_top_p)
|
||||
assert llm_instance.top_p == new_top_p
|
||||
|
||||
|
||||
# Test for setting quantize
|
||||
def test_llm_set_quantize(llm_instance):
|
||||
llm_instance.set_quantize(True)
|
||||
assert llm_instance.quantize is True
|
||||
|
||||
|
||||
# Test for setting quantization_config
|
||||
def test_llm_set_quantization_config(llm_instance):
|
||||
new_quantization_config = {
|
||||
"load_in_4bit": False,
|
||||
"bnb_4bit_use_double_quant": False,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.bfloat16,
|
||||
}
|
||||
llm_instance.set_quantization_config(new_quantization_config)
|
||||
assert llm_instance.quantization_config == new_quantization_config
|
||||
|
||||
|
||||
# Test for setting model_id
|
||||
def test_llm_set_model_id(llm_instance):
|
||||
new_model_id = "EleutherAI/gpt-neo-2.7B"
|
||||
llm_instance.set_model_id(new_model_id)
|
||||
assert llm_instance.model_id == new_model_id
|
||||
|
||||
|
||||
# Test for setting model
|
||||
@patch(
|
||||
"swarms.models.huggingface.AutoModelForCausalLM.from_pretrained"
|
||||
)
|
||||
def test_llm_set_model(mock_model, llm_instance):
|
||||
mock_model.return_value = "mocked model"
|
||||
llm_instance.set_model(mock_model)
|
||||
assert llm_instance.model == "mocked model"
|
||||
|
||||
|
||||
# Test for setting tokenizer
|
||||
@patch("swarms.models.huggingface.AutoTokenizer.from_pretrained")
|
||||
def test_llm_set_tokenizer(mock_tokenizer, llm_instance):
|
||||
mock_tokenizer.return_value = "mocked tokenizer"
|
||||
llm_instance.set_tokenizer(mock_tokenizer)
|
||||
assert llm_instance.tokenizer == "mocked tokenizer"
|
||||
|
||||
|
||||
# Test for setting logger
|
||||
def test_llm_set_logger(llm_instance):
|
||||
new_logger = logging.getLogger("test_logger")
|
||||
llm_instance.set_logger(new_logger)
|
||||
assert llm_instance.logger == new_logger
|
||||
|
||||
|
||||
# Test for saving model
|
||||
@patch("torch.save")
|
||||
def test_llm_save_model(mock_save, llm_instance):
|
||||
llm_instance.save_model("path/to/save")
|
||||
mock_save.assert_called_once()
|
||||
|
||||
|
||||
# Test for print_dashboard
|
||||
@patch("builtins.print")
|
||||
def test_llm_print_dashboard(mock_print, llm_instance):
|
||||
llm_instance.print_dashboard("test task")
|
||||
mock_print.assert_called()
|
||||
|
||||
|
||||
# Test for __call__ method
|
||||
@patch("swarms.models.huggingface.HuggingfaceLLM.run")
|
||||
def test_llm_call(mock_run, llm_instance):
|
||||
mock_run.return_value = "mocked output"
|
||||
result = llm_instance("test task")
|
||||
assert result == "mocked output"
|
@ -1,56 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from swarm_models.huggingface_pipeline import HuggingfacePipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline():
|
||||
with patch("swarms.models.huggingface_pipeline.pipeline") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(mock_pipeline):
|
||||
return HuggingfacePipeline(
|
||||
"text-generation", "meta-llama/Llama-2-13b-chat-hf"
|
||||
)
|
||||
|
||||
|
||||
def test_init(pipeline, mock_pipeline):
|
||||
assert pipeline.task_type == "text-generation"
|
||||
assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf"
|
||||
assert (
|
||||
pipeline.use_fp8 is True
|
||||
if torch.cuda.is_available()
|
||||
else False
|
||||
)
|
||||
mock_pipeline.assert_called_once_with(
|
||||
"text-generation",
|
||||
"meta-llama/Llama-2-13b-chat-hf",
|
||||
use_fp8=pipeline.use_fp8,
|
||||
)
|
||||
|
||||
|
||||
def test_run(pipeline, mock_pipeline):
|
||||
mock_pipeline.return_value = "Generated text"
|
||||
result = pipeline.run("Hello, world!")
|
||||
assert result == "Generated text"
|
||||
mock_pipeline.assert_called_once_with("Hello, world!")
|
||||
|
||||
|
||||
def test_run_with_exception(pipeline, mock_pipeline):
|
||||
mock_pipeline.side_effect = Exception("Test exception")
|
||||
with pytest.raises(Exception):
|
||||
pipeline.run("Hello, world!")
|
||||
|
||||
|
||||
def test_run_with_different_task(pipeline, mock_pipeline):
|
||||
mock_pipeline.return_value = "Generated text"
|
||||
result = pipeline.run("text-classification", "Hello, world!")
|
||||
assert result == "Generated text"
|
||||
mock_pipeline.assert_called_once_with(
|
||||
"text-classification", "Hello, world!"
|
||||
)
|
@ -1,207 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from swarm_models.idefics import (
|
||||
AutoProcessor,
|
||||
Idefics,
|
||||
IdeficsForVisionText2Text,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def idefics_instance():
|
||||
with patch(
|
||||
"torch.cuda.is_available", return_value=False
|
||||
): # Assuming tests are run on CPU for simplicity
|
||||
instance = Idefics()
|
||||
return instance
|
||||
|
||||
|
||||
# Basic Tests
|
||||
def test_init_default(idefics_instance):
|
||||
assert idefics_instance.device == "cpu"
|
||||
assert idefics_instance.max_length == 100
|
||||
assert not idefics_instance.chat_history
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"device,expected",
|
||||
[
|
||||
(None, "cpu"),
|
||||
("cuda", "cuda"),
|
||||
("cpu", "cpu"),
|
||||
],
|
||||
)
|
||||
def test_init_device(device, expected):
|
||||
with patch(
|
||||
"torch.cuda.is_available",
|
||||
return_value=True if expected == "cuda" else False,
|
||||
):
|
||||
instance = Idefics(device=device)
|
||||
assert instance.device == expected
|
||||
|
||||
|
||||
# Test `run` method
|
||||
def test_run(idefics_instance):
|
||||
prompts = [["User: Test"]]
|
||||
with patch.object(
|
||||
idefics_instance, "processor"
|
||||
) as mock_processor, patch.object(
|
||||
idefics_instance, "model"
|
||||
) as mock_model:
|
||||
mock_processor.return_value = {
|
||||
"input_ids": torch.tensor([1, 2, 3])
|
||||
}
|
||||
mock_model.generate.return_value = torch.tensor([1, 2, 3])
|
||||
mock_processor.batch_decode.return_value = ["Test"]
|
||||
|
||||
result = idefics_instance.run(prompts)
|
||||
|
||||
assert result == ["Test"]
|
||||
|
||||
|
||||
# Test `__call__` method (using the same logic as run for simplicity)
|
||||
def test_call(idefics_instance):
|
||||
prompts = [["User: Test"]]
|
||||
with patch.object(
|
||||
idefics_instance, "processor"
|
||||
) as mock_processor, patch.object(
|
||||
idefics_instance, "model"
|
||||
) as mock_model:
|
||||
mock_processor.return_value = {
|
||||
"input_ids": torch.tensor([1, 2, 3])
|
||||
}
|
||||
mock_model.generate.return_value = torch.tensor([1, 2, 3])
|
||||
mock_processor.batch_decode.return_value = ["Test"]
|
||||
|
||||
result = idefics_instance(prompts)
|
||||
|
||||
assert result == ["Test"]
|
||||
|
||||
|
||||
# Test `chat` method
|
||||
def test_chat(idefics_instance):
|
||||
user_input = "User: Hello"
|
||||
response = "Model: Hi there!"
|
||||
with patch.object(
|
||||
idefics_instance, "run", return_value=[response]
|
||||
):
|
||||
result = idefics_instance.chat(user_input)
|
||||
|
||||
assert result == response
|
||||
assert idefics_instance.chat_history == [user_input, response]
|
||||
|
||||
|
||||
# Test `set_checkpoint` method
|
||||
def test_set_checkpoint(idefics_instance):
|
||||
new_checkpoint = "new_checkpoint"
|
||||
with patch.object(
|
||||
IdeficsForVisionText2Text, "from_pretrained"
|
||||
) as mock_from_pretrained, patch.object(
|
||||
AutoProcessor, "from_pretrained"
|
||||
):
|
||||
idefics_instance.set_checkpoint(new_checkpoint)
|
||||
|
||||
mock_from_pretrained.assert_called_with(
|
||||
new_checkpoint, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
|
||||
# Test `set_device` method
|
||||
def test_set_device(idefics_instance):
|
||||
new_device = "cuda"
|
||||
with patch.object(idefics_instance.model, "to"):
|
||||
idefics_instance.set_device(new_device)
|
||||
|
||||
assert idefics_instance.device == new_device
|
||||
|
||||
|
||||
# Test `set_max_length` method
|
||||
def test_set_max_length(idefics_instance):
|
||||
new_length = 150
|
||||
idefics_instance.set_max_length(new_length)
|
||||
assert idefics_instance.max_length == new_length
|
||||
|
||||
|
||||
# Test `clear_chat_history` method
|
||||
def test_clear_chat_history(idefics_instance):
|
||||
idefics_instance.chat_history = ["User: Test", "Model: Response"]
|
||||
idefics_instance.clear_chat_history()
|
||||
assert not idefics_instance.chat_history
|
||||
|
||||
|
||||
# Exception Tests
|
||||
def test_run_with_empty_prompts(idefics_instance):
|
||||
with pytest.raises(
|
||||
Exception
|
||||
): # Replace Exception with the actual exception that may arise for an empty prompt.
|
||||
idefics_instance.run([])
|
||||
|
||||
|
||||
# Test `run` method with batched_mode set to False
|
||||
def test_run_batched_mode_false(idefics_instance):
|
||||
task = "User: Test"
|
||||
with patch.object(
|
||||
idefics_instance, "processor"
|
||||
) as mock_processor, patch.object(
|
||||
idefics_instance, "model"
|
||||
) as mock_model:
|
||||
mock_processor.return_value = {
|
||||
"input_ids": torch.tensor([1, 2, 3])
|
||||
}
|
||||
mock_model.generate.return_value = torch.tensor([1, 2, 3])
|
||||
mock_processor.batch_decode.return_value = ["Test"]
|
||||
|
||||
idefics_instance.batched_mode = False
|
||||
result = idefics_instance.run(task)
|
||||
|
||||
assert result == ["Test"]
|
||||
|
||||
|
||||
# Test `run` method with an exception
|
||||
def test_run_with_exception(idefics_instance):
|
||||
task = "User: Test"
|
||||
with patch.object(
|
||||
idefics_instance, "processor"
|
||||
) as mock_processor:
|
||||
mock_processor.side_effect = Exception("Test exception")
|
||||
with pytest.raises(Exception):
|
||||
idefics_instance.run(task)
|
||||
|
||||
|
||||
# Test `set_model_name` method
|
||||
def test_set_model_name(idefics_instance):
|
||||
new_model_name = "new_model_name"
|
||||
with patch.object(
|
||||
IdeficsForVisionText2Text, "from_pretrained"
|
||||
) as mock_from_pretrained, patch.object(
|
||||
AutoProcessor, "from_pretrained"
|
||||
):
|
||||
idefics_instance.set_model_name(new_model_name)
|
||||
|
||||
assert idefics_instance.model_name == new_model_name
|
||||
mock_from_pretrained.assert_called_with(
|
||||
new_model_name, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
|
||||
# Test `__init__` method with device set to None
|
||||
def test_init_device_none():
|
||||
with patch(
|
||||
"torch.cuda.is_available",
|
||||
return_value=False,
|
||||
):
|
||||
instance = Idefics(device=None)
|
||||
assert instance.device == "cpu"
|
||||
|
||||
|
||||
# Test `__init__` method with device set to "cuda"
|
||||
def test_init_device_cuda():
|
||||
with patch(
|
||||
"torch.cuda.is_available",
|
||||
return_value=True,
|
||||
):
|
||||
instance = Idefics(device="cuda")
|
||||
assert instance.device == "cuda"
|
@ -1,26 +0,0 @@
|
||||
from swarm_models import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"Anthropic",
|
||||
"Petals",
|
||||
"Mistral",
|
||||
"OpenAI",
|
||||
"AzureOpenAI",
|
||||
"OpenAIChat",
|
||||
"Zephyr",
|
||||
"Idefics",
|
||||
# "Kosmos",
|
||||
"Vilt",
|
||||
"Nougat",
|
||||
"LayoutLMDocumentQA",
|
||||
"BioGPT",
|
||||
"HuggingfaceLLM",
|
||||
"MPT7B",
|
||||
"WizardLLMStoryTeller",
|
||||
# "GPT4Vision",
|
||||
# "Dalle3",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
@ -1,181 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
# This will be your project directory
|
||||
from swarm_models.kosmos_two import Kosmos, is_overlapping
|
||||
|
||||
# A placeholder image URL for testing
|
||||
TEST_IMAGE_URL = "https://images.unsplash.com/photo-1673267569891-ca4246caafd7?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDM1fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D"
|
||||
|
||||
|
||||
# Mock the response for the test image
|
||||
@pytest.fixture
|
||||
def mock_image_request():
|
||||
img_data = open(TEST_IMAGE_URL, "rb").read()
|
||||
mock_resp = Mock()
|
||||
mock_resp.raw = img_data
|
||||
with patch.object(
|
||||
requests, "get", return_value=mock_resp
|
||||
) as _fixture:
|
||||
yield _fixture
|
||||
|
||||
|
||||
# Test utility function
|
||||
def test_is_overlapping():
|
||||
assert is_overlapping((1, 1, 3, 3), (2, 2, 4, 4)) is True
|
||||
assert is_overlapping((1, 1, 2, 2), (3, 3, 4, 4)) is False
|
||||
assert is_overlapping((0, 0, 1, 1), (1, 1, 2, 2)) is False
|
||||
assert is_overlapping((0, 0, 2, 2), (1, 1, 2, 2)) is True
|
||||
|
||||
|
||||
# Test model initialization
|
||||
def test_kosmos_init():
|
||||
kosmos = Kosmos()
|
||||
assert kosmos.model is not None
|
||||
assert kosmos.processor is not None
|
||||
|
||||
|
||||
# Test image fetching functionality
|
||||
def test_get_image(mock_image_request):
|
||||
kosmos = Kosmos()
|
||||
image = kosmos.get_image(TEST_IMAGE_URL)
|
||||
assert image is not None
|
||||
|
||||
|
||||
# Test multimodal grounding
|
||||
def test_multimodal_grounding(mock_image_request):
|
||||
kosmos = Kosmos()
|
||||
kosmos.multimodal_grounding(
|
||||
"Find the red apple in the image.", TEST_IMAGE_URL
|
||||
)
|
||||
# TODO: Validate the result if possible
|
||||
|
||||
|
||||
# Test referring expression comprehension
|
||||
def test_referring_expression_comprehension(mock_image_request):
|
||||
kosmos = Kosmos()
|
||||
kosmos.referring_expression_comprehension(
|
||||
"Show me the green bottle.", TEST_IMAGE_URL
|
||||
)
|
||||
# TODO: Validate the result if possible
|
||||
|
||||
|
||||
# ... (continue with other functions in the same manner) ...
|
||||
|
||||
|
||||
# Test error scenarios - Example
|
||||
@pytest.mark.parametrize(
|
||||
"phrase, image_url",
|
||||
[
|
||||
(None, TEST_IMAGE_URL),
|
||||
("Find the red apple in the image.", None),
|
||||
("", TEST_IMAGE_URL),
|
||||
("Find the red apple in the image.", ""),
|
||||
],
|
||||
)
|
||||
def test_kosmos_error_scenarios(phrase, image_url):
|
||||
kosmos = Kosmos()
|
||||
with pytest.raises(Exception):
|
||||
kosmos.multimodal_grounding(phrase, image_url)
|
||||
|
||||
|
||||
# ... (Add more tests for different edge cases and functionalities) ...
|
||||
|
||||
# Sample test image URLs
|
||||
IMG_URL1 = "https://images.unsplash.com/photo-1696341439368-2c84b6c963bc?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDMzfEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D"
|
||||
IMG_URL2 = "https://images.unsplash.com/photo-1689934902235-055707b4f8e9?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDYzfEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D"
|
||||
IMG_URL3 = "https://images.unsplash.com/photo-1696900004042-60bcc200aca0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDY2fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D"
|
||||
IMG_URL4 = "https://images.unsplash.com/photo-1676156340083-fd49e4e53a21?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDc4fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D"
|
||||
IMG_URL5 = "https://images.unsplash.com/photo-1696862761045-0a65acbede8f?auto=format&fit=crop&q=80&w=1287&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
|
||||
|
||||
# Mock response for requests.get()
|
||||
class MockResponse:
|
||||
@staticmethod
|
||||
def json():
|
||||
return {}
|
||||
|
||||
@property
|
||||
def raw(self):
|
||||
return open("tests/sample_image.jpg", "rb")
|
||||
|
||||
|
||||
# Test the Kosmos class
|
||||
@pytest.fixture
|
||||
def kosmos():
|
||||
return Kosmos()
|
||||
|
||||
|
||||
# Mocking the requests.get() method
|
||||
@pytest.fixture
|
||||
def mock_request_get(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
requests, "get", lambda url, **kwargs: MockResponse()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_multimodal_grounding(kosmos):
|
||||
kosmos.multimodal_grounding(
|
||||
"Find the red apple in the image.", IMG_URL1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_referring_expression_comprehension(kosmos):
|
||||
kosmos.referring_expression_comprehension(
|
||||
"Show me the green bottle.", IMG_URL2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_referring_expression_generation(kosmos):
|
||||
kosmos.referring_expression_generation(
|
||||
"It is on the table.", IMG_URL3
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_grounded_vqa(kosmos):
|
||||
kosmos.grounded_vqa("What is the color of the car?", IMG_URL4)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_grounded_image_captioning(kosmos):
|
||||
kosmos.grounded_image_captioning(IMG_URL5)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_grounded_image_captioning_detailed(kosmos):
|
||||
kosmos.grounded_image_captioning_detailed(IMG_URL1)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_multimodal_grounding_2(kosmos):
|
||||
kosmos.multimodal_grounding(
|
||||
"Find the yellow fruit in the image.", IMG_URL2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_referring_expression_comprehension_2(kosmos):
|
||||
kosmos.referring_expression_comprehension(
|
||||
"Where is the water bottle?", IMG_URL3
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_grounded_vqa_2(kosmos):
|
||||
kosmos.grounded_vqa("How many cars are in the image?", IMG_URL4)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_grounded_image_captioning_2(kosmos):
|
||||
kosmos.grounded_image_captioning(IMG_URL2)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_request_get")
|
||||
def test_grounded_image_captioning_detailed_2(kosmos):
|
||||
kosmos.grounded_image_captioning_detailed(IMG_URL3)
|
@ -1,221 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import NougatProcessor, VisionEncoderDecoderModel
|
||||
|
||||
from swarm_models.nougat import Nougat
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_nougat():
|
||||
return Nougat()
|
||||
|
||||
|
||||
def test_nougat_default_initialization(setup_nougat):
|
||||
assert setup_nougat.model_name_or_path == "facebook/nougat-base"
|
||||
assert setup_nougat.min_length == 1
|
||||
assert setup_nougat.max_new_tokens == 30
|
||||
|
||||
|
||||
def test_nougat_custom_initialization():
|
||||
nougat = Nougat(
|
||||
model_name_or_path="custom_path",
|
||||
min_length=10,
|
||||
max_new_tokens=50,
|
||||
)
|
||||
assert nougat.model_name_or_path == "custom_path"
|
||||
assert nougat.min_length == 10
|
||||
assert nougat.max_new_tokens == 50
|
||||
|
||||
|
||||
def test_processor_initialization(setup_nougat):
|
||||
assert isinstance(setup_nougat.processor, NougatProcessor)
|
||||
|
||||
|
||||
def test_model_initialization(setup_nougat):
|
||||
assert isinstance(setup_nougat.model, VisionEncoderDecoderModel)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cuda_available, expected_device",
|
||||
[(True, "cuda"), (False, "cpu")],
|
||||
)
|
||||
def test_device_initialization(
|
||||
cuda_available, expected_device, monkeypatch
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
torch,
|
||||
"cuda",
|
||||
Mock(is_available=Mock(return_value=cuda_available)),
|
||||
)
|
||||
nougat = Nougat()
|
||||
assert nougat.device == expected_device
|
||||
|
||||
|
||||
def test_get_image_valid_path(setup_nougat):
|
||||
with patch("PIL.Image.open") as mock_open:
|
||||
mock_open.return_value = Mock(spec=Image.Image)
|
||||
assert setup_nougat.get_image("valid_path") is not None
|
||||
|
||||
|
||||
def test_get_image_invalid_path(setup_nougat):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
setup_nougat.get_image("invalid_path")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"min_len, max_tokens",
|
||||
[
|
||||
(1, 30),
|
||||
(5, 40),
|
||||
(10, 50),
|
||||
],
|
||||
)
|
||||
def test_model_call_with_diff_params(
|
||||
setup_nougat, min_len, max_tokens
|
||||
):
|
||||
setup_nougat.min_length = min_len
|
||||
setup_nougat.max_new_tokens = max_tokens
|
||||
|
||||
with patch("PIL.Image.open") as mock_open:
|
||||
mock_open.return_value = Mock(spec=Image.Image)
|
||||
# Here, mocking other required methods or adding more complex logic would be necessary.
|
||||
result = setup_nougat("valid_path")
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_model_call_invalid_image_path(setup_nougat):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
setup_nougat("invalid_path")
|
||||
|
||||
|
||||
def test_model_call_mocked_output(setup_nougat):
|
||||
with patch("PIL.Image.open") as mock_open:
|
||||
mock_open.return_value = Mock(spec=Image.Image)
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate.return_value = "mocked_output"
|
||||
setup_nougat.model = mock_model
|
||||
|
||||
result = setup_nougat("valid_path")
|
||||
assert result == "mocked_output"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_processor_and_model():
|
||||
"""Mock the NougatProcessor and VisionEncoderDecoderModel to simulate their behavior."""
|
||||
with patch(
|
||||
"transformers.NougatProcessor.from_pretrained",
|
||||
return_value=Mock(),
|
||||
), patch(
|
||||
"transformers.VisionEncoderDecoderModel.from_pretrained",
|
||||
return_value=Mock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_with_sample_image_1(setup_nougat):
|
||||
result = setup_nougat(
|
||||
os.path.join(
|
||||
"sample_images",
|
||||
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
|
||||
)
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_with_sample_image_2(setup_nougat):
|
||||
result = setup_nougat(os.path.join("sample_images", "test2.png"))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_min_length_param(setup_nougat):
|
||||
setup_nougat.min_length = 10
|
||||
result = setup_nougat(
|
||||
os.path.join(
|
||||
"sample_images",
|
||||
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
|
||||
)
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_max_new_tokens_param(setup_nougat):
|
||||
setup_nougat.max_new_tokens = 50
|
||||
result = setup_nougat(
|
||||
os.path.join(
|
||||
"sample_images",
|
||||
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
|
||||
)
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_different_model_path(setup_nougat):
|
||||
setup_nougat.model_name_or_path = "different/path"
|
||||
result = setup_nougat(
|
||||
os.path.join(
|
||||
"sample_images",
|
||||
"https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
|
||||
)
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_bad_image_path(setup_nougat):
|
||||
with pytest.raises(
|
||||
Exception
|
||||
): # Adjust the exception type accordingly.
|
||||
setup_nougat("bad_image_path.png")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_image_large_size(setup_nougat):
|
||||
result = setup_nougat(
|
||||
os.path.join(
|
||||
"sample_images",
|
||||
"https://images.unsplash.com/photo-1697641039266-bfa00367f7cb?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDJ8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D",
|
||||
)
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_image_small_size(setup_nougat):
|
||||
result = setup_nougat(
|
||||
os.path.join(
|
||||
"sample_images",
|
||||
"https://images.unsplash.com/photo-1697638626987-aa865b769276?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDd8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D",
|
||||
)
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_image_varied_content(setup_nougat):
|
||||
result = setup_nougat(
|
||||
os.path.join(
|
||||
"sample_images",
|
||||
"https://images.unsplash.com/photo-1697469994783-b12bbd9c4cff?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE0fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D",
|
||||
)
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_processor_and_model")
|
||||
def test_nougat_image_with_metadata(setup_nougat):
|
||||
result = setup_nougat(
|
||||
os.path.join(
|
||||
"sample_images",
|
||||
"https://images.unsplash.com/photo-1697273300766-5bbaa53ec2f0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE5fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D",
|
||||
)
|
||||
)
|
||||
assert isinstance(result, str)
|
@ -1,60 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from swarm_models.open_dalle import OpenDalle
|
||||
|
||||
|
||||
def test_init():
|
||||
od = OpenDalle()
|
||||
assert isinstance(od, OpenDalle)
|
||||
|
||||
|
||||
def test_init_custom_model():
|
||||
od = OpenDalle(model_name="custom_model")
|
||||
assert od.pipeline.model_name == "custom_model"
|
||||
|
||||
|
||||
def test_init_custom_dtype():
|
||||
od = OpenDalle(torch_dtype=torch.float32)
|
||||
assert od.pipeline.torch_dtype == torch.float32
|
||||
|
||||
|
||||
def test_init_custom_device():
|
||||
od = OpenDalle(device="cpu")
|
||||
assert od.pipeline.device == "cpu"
|
||||
|
||||
|
||||
def test_run():
|
||||
od = OpenDalle()
|
||||
result = od.run("A picture of a cat")
|
||||
assert isinstance(result, torch.Tensor)
|
||||
|
||||
|
||||
def test_run_no_task():
|
||||
od = OpenDalle()
|
||||
with pytest.raises(ValueError, match="Task cannot be None"):
|
||||
od.run(None)
|
||||
|
||||
|
||||
def test_run_non_string_task():
|
||||
od = OpenDalle()
|
||||
with pytest.raises(TypeError, match="Task must be a string"):
|
||||
od.run(123)
|
||||
|
||||
|
||||
def test_run_empty_task():
|
||||
od = OpenDalle()
|
||||
with pytest.raises(ValueError, match="Task cannot be empty"):
|
||||
od.run("")
|
||||
|
||||
|
||||
def test_run_custom_args():
|
||||
od = OpenDalle()
|
||||
result = od.run("A picture of a cat", custom_arg="custom_value")
|
||||
assert isinstance(result, torch.Tensor)
|
||||
|
||||
|
||||
def test_run_error():
|
||||
od = OpenDalle()
|
||||
with pytest.raises(Exception):
|
||||
od.run("A picture of a cat", raise_error=True)
|
@ -1,107 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from swarm_models.openai_tts import OpenAITTS
|
||||
|
||||
|
||||
def test_openaitts_initialization():
|
||||
tts = OpenAITTS()
|
||||
assert isinstance(tts, OpenAITTS)
|
||||
assert tts.model_name == "tts-1-1106"
|
||||
assert tts.proxy_url == "https://api.openai.com/v1/audio/speech"
|
||||
assert tts.voice == "onyx"
|
||||
assert tts.chunk_size == 1024 * 1024
|
||||
|
||||
|
||||
def test_openaitts_initialization_custom_parameters():
|
||||
tts = OpenAITTS(
|
||||
"custom_model",
|
||||
"custom_url",
|
||||
"custom_key",
|
||||
"custom_voice",
|
||||
2048,
|
||||
)
|
||||
assert tts.model_name == "custom_model"
|
||||
assert tts.proxy_url == "custom_url"
|
||||
assert tts.openai_api_key == "custom_key"
|
||||
assert tts.voice == "custom_voice"
|
||||
assert tts.chunk_size == 2048
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run(mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||
mock_post.return_value = mock_response
|
||||
tts = OpenAITTS()
|
||||
audio = tts.run("Hello world")
|
||||
assert audio == b"chunk1chunk2"
|
||||
mock_post.assert_called_once_with(
|
||||
"https://api.openai.com/v1/audio/speech",
|
||||
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||
json={
|
||||
"model": "tts-1-1106",
|
||||
"input": "Hello world",
|
||||
"voice": "onyx",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_empty_task(mock_post):
|
||||
tts = OpenAITTS()
|
||||
with pytest.raises(Exception):
|
||||
tts.run("")
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_very_long_task(mock_post):
|
||||
tts = OpenAITTS()
|
||||
with pytest.raises(Exception):
|
||||
tts.run("A" * 10000)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_invalid_task(mock_post):
|
||||
tts = OpenAITTS()
|
||||
with pytest.raises(Exception):
|
||||
tts.run(None)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_custom_model(mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||
mock_post.return_value = mock_response
|
||||
tts = OpenAITTS("custom_model")
|
||||
audio = tts.run("Hello world")
|
||||
assert audio == b"chunk1chunk2"
|
||||
mock_post.assert_called_once_with(
|
||||
"https://api.openai.com/v1/audio/speech",
|
||||
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||
json={
|
||||
"model": "custom_model",
|
||||
"input": "Hello world",
|
||||
"voice": "onyx",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@patch("requests.post")
|
||||
def test_run_custom_voice(mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_content.return_value = [b"chunk1", b"chunk2"]
|
||||
mock_post.return_value = mock_response
|
||||
tts = OpenAITTS(voice="custom_voice")
|
||||
audio = tts.run("Hello world")
|
||||
assert audio == b"chunk1chunk2"
|
||||
mock_post.assert_called_once_with(
|
||||
"https://api.openai.com/v1/audio/speech",
|
||||
headers={"Authorization": f"Bearer {tts.openai_api_key}"},
|
||||
json={
|
||||
"model": "tts-1-1106",
|
||||
"input": "Hello world",
|
||||
"voice": "custom_voice",
|
||||
},
|
||||
)
|
@ -1,61 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from swarm_models.qwen import QwenVLMultiModal
|
||||
|
||||
|
||||
def test_post_init():
|
||||
with patch(
|
||||
"swarms.models.qwen.AutoTokenizer.from_pretrained"
|
||||
) as mock_tokenizer, patch(
|
||||
"swarms.models.qwen.AutoModelForCausalLM.from_pretrained"
|
||||
) as mock_model:
|
||||
mock_tokenizer.return_value = Mock()
|
||||
mock_model.return_value = Mock()
|
||||
|
||||
model = QwenVLMultiModal()
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
model.model_name, trust_remote_code=True
|
||||
)
|
||||
mock_model.assert_called_once_with(
|
||||
model.model_name,
|
||||
device_map=model.device,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
|
||||
def test_run():
|
||||
with patch(
|
||||
"swarms.models.qwen.AutoTokenizer.from_list_format"
|
||||
) as mock_format, patch(
|
||||
"swarms.models.qwen.AutoTokenizer.__call__"
|
||||
) as mock_call, patch(
|
||||
"swarms.models.qwen.AutoModelForCausalLM.generate"
|
||||
) as mock_generate, patch(
|
||||
"swarms.models.qwen.AutoTokenizer.decode"
|
||||
) as mock_decode:
|
||||
mock_format.return_value = Mock()
|
||||
mock_call.return_value = Mock()
|
||||
mock_generate.return_value = Mock()
|
||||
mock_decode.return_value = "response"
|
||||
|
||||
model = QwenVLMultiModal()
|
||||
response = model.run(
|
||||
"Hello, how are you?", "https://example.com/image.jpg"
|
||||
)
|
||||
|
||||
assert response == "response"
|
||||
|
||||
|
||||
def test_chat():
|
||||
with patch(
|
||||
"swarms.models.qwen.AutoModelForCausalLM.chat"
|
||||
) as mock_chat:
|
||||
mock_chat.return_value = ("response", ["history"])
|
||||
|
||||
model = QwenVLMultiModal()
|
||||
response, history = model.chat(
|
||||
"Hello, how are you?", "https://example.com/image.jpg"
|
||||
)
|
||||
|
||||
assert response == "response"
|
||||
assert history == ["history"]
|
@ -1,165 +0,0 @@
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from swarm_models.ssd_1b import SSD1B
|
||||
|
||||
|
||||
# Create fixtures if needed
|
||||
@pytest.fixture
|
||||
def ssd1b_model():
|
||||
return SSD1B()
|
||||
|
||||
|
||||
# Basic tests for model initialization and method call
|
||||
def test_ssd1b_model_initialization(ssd1b_model):
|
||||
assert ssd1b_model is not None
|
||||
|
||||
|
||||
def test_ssd1b_call(ssd1b_model):
|
||||
task = "A painting of a dog"
|
||||
neg_prompt = "ugly, blurry, poor quality"
|
||||
image_url = ssd1b_model(task, neg_prompt)
|
||||
assert isinstance(image_url, str)
|
||||
assert image_url.startswith(
|
||||
"https://"
|
||||
) # Assuming it starts with "https://"
|
||||
|
||||
|
||||
# Add more tests for various aspects of the class and methods
|
||||
|
||||
|
||||
# Example of a parameterized test for different tasks
|
||||
@pytest.mark.parametrize(
|
||||
"task", ["A painting of a cat", "A painting of a tree"]
|
||||
)
|
||||
def test_ssd1b_parameterized_task(ssd1b_model, task):
|
||||
image_url = ssd1b_model(task)
|
||||
assert isinstance(image_url, str)
|
||||
assert image_url.startswith(
|
||||
"https://"
|
||||
) # Assuming it starts with "https://"
|
||||
|
||||
|
||||
# Example of a test using mocks to isolate units of code
|
||||
def test_ssd1b_with_mock(ssd1b_model, mocker):
|
||||
mocker.patch(
|
||||
"your_module.StableDiffusionXLPipeline"
|
||||
) # Mock the pipeline
|
||||
task = "A painting of a cat"
|
||||
image_url = ssd1b_model(task)
|
||||
assert isinstance(image_url, str)
|
||||
assert image_url.startswith(
|
||||
"https://"
|
||||
) # Assuming it starts with "https://"
|
||||
|
||||
|
||||
def test_ssd1b_call_with_cache(ssd1b_model):
|
||||
task = "A painting of a dog"
|
||||
neg_prompt = "ugly, blurry, poor quality"
|
||||
image_url1 = ssd1b_model(task, neg_prompt)
|
||||
image_url2 = ssd1b_model(task, neg_prompt) # Should use cache
|
||||
assert image_url1 == image_url2
|
||||
|
||||
|
||||
def test_ssd1b_invalid_task(ssd1b_model):
|
||||
invalid_task = ""
|
||||
with pytest.raises(ValueError):
|
||||
ssd1b_model(invalid_task)
|
||||
|
||||
|
||||
def test_ssd1b_failed_api_call(ssd1b_model, mocker):
|
||||
mocker.patch(
|
||||
"your_module.StableDiffusionXLPipeline"
|
||||
) # Mock the pipeline to raise an exception
|
||||
task = "A painting of a cat"
|
||||
with pytest.raises(Exception):
|
||||
ssd1b_model(task)
|
||||
|
||||
|
||||
def test_ssd1b_process_batch_concurrently(ssd1b_model):
|
||||
tasks = [
|
||||
"A painting of a dog",
|
||||
"A beautiful sunset",
|
||||
"A portrait of a person",
|
||||
]
|
||||
results = ssd1b_model.process_batch_concurrently(tasks)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(tasks)
|
||||
|
||||
|
||||
def test_ssd1b_process_empty_batch_concurrently(ssd1b_model):
|
||||
tasks = []
|
||||
results = ssd1b_model.process_batch_concurrently(tasks)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_ssd1b_download_image(ssd1b_model):
|
||||
task = "A painting of a dog"
|
||||
neg_prompt = "ugly, blurry, poor quality"
|
||||
image_url = ssd1b_model(task, neg_prompt)
|
||||
img = ssd1b_model._download_image(image_url)
|
||||
assert isinstance(img, Image.Image)
|
||||
|
||||
|
||||
def test_ssd1b_generate_uuid(ssd1b_model):
|
||||
uuid_str = ssd1b_model._generate_uuid()
|
||||
assert isinstance(uuid_str, str)
|
||||
assert len(uuid_str) == 36 # UUID format
|
||||
|
||||
|
||||
def test_ssd1b_rate_limited_call(ssd1b_model):
|
||||
task = "A painting of a dog"
|
||||
image_url = ssd1b_model.rate_limited_call(task)
|
||||
assert isinstance(image_url, str)
|
||||
assert image_url.startswith("https://")
|
||||
|
||||
|
||||
# Test cases for additional scenarios and behaviors
|
||||
def test_ssd1b_dashboard_printing(ssd1b_model, capsys):
|
||||
ssd1b_model.dashboard = True
|
||||
ssd1b_model.print_dashboard()
|
||||
captured = capsys.readouterr()
|
||||
assert "SSD1B Dashboard:" in captured.out
|
||||
|
||||
|
||||
def test_ssd1b_generate_image_name(ssd1b_model):
|
||||
task = "A painting of a dog"
|
||||
img_name = ssd1b_model._generate_image_name(task)
|
||||
assert isinstance(img_name, str)
|
||||
assert len(img_name) > 0
|
||||
|
||||
|
||||
def test_ssd1b_set_width_height(ssd1b_model, mocker):
|
||||
img = mocker.MagicMock()
|
||||
width, height = 800, 600
|
||||
result = ssd1b_model.set_width_height(img, width, height)
|
||||
assert result == img.resize.return_value
|
||||
|
||||
|
||||
def test_ssd1b_read_img(ssd1b_model, mocker):
|
||||
img = mocker.MagicMock()
|
||||
result = ssd1b_model.read_img(img)
|
||||
assert result == img.open.return_value
|
||||
|
||||
|
||||
def test_ssd1b_convert_to_bytesio(ssd1b_model, mocker):
|
||||
img = mocker.MagicMock()
|
||||
img_format = "PNG"
|
||||
result = ssd1b_model.convert_to_bytesio(img, img_format)
|
||||
assert isinstance(result, bytes)
|
||||
|
||||
|
||||
def test_ssd1b_save_image(ssd1b_model, mocker, tmp_path):
|
||||
img = mocker.MagicMock()
|
||||
img_name = "test.png"
|
||||
save_path = tmp_path / img_name
|
||||
ssd1b_model._download_image(img, img_name, save_path)
|
||||
assert save_path.exists()
|
||||
|
||||
|
||||
def test_ssd1b_repr_str(ssd1b_model):
|
||||
task = "A painting of a dog"
|
||||
image_url = ssd1b_model(task)
|
||||
assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})"
|
||||
assert str(ssd1b_model) == f"SSD1B(image_url={image_url})"
|
@ -1,90 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from swarm_models import TimmModel
|
||||
|
||||
|
||||
def test_timm_model_init():
|
||||
with patch("swarms.models.timm.list_models") as mock_list_models:
|
||||
model_name = "resnet18"
|
||||
pretrained = True
|
||||
in_chans = 3
|
||||
timm_model = TimmModel(model_name, pretrained, in_chans)
|
||||
mock_list_models.assert_called_once()
|
||||
assert timm_model.model_name == model_name
|
||||
assert timm_model.pretrained == pretrained
|
||||
assert timm_model.in_chans == in_chans
|
||||
assert timm_model.models == mock_list_models.return_value
|
||||
|
||||
|
||||
def test_timm_model_call():
|
||||
with patch(
|
||||
"swarms.models.timm.create_model"
|
||||
) as mock_create_model:
|
||||
model_name = "resnet18"
|
||||
pretrained = True
|
||||
in_chans = 3
|
||||
timm_model = TimmModel(model_name, pretrained, in_chans)
|
||||
task = torch.rand(1, in_chans, 224, 224)
|
||||
result = timm_model(task)
|
||||
mock_create_model.assert_called_once_with(
|
||||
model_name, pretrained=pretrained, in_chans=in_chans
|
||||
)
|
||||
assert result == mock_create_model.return_value(task)
|
||||
|
||||
|
||||
def test_timm_model_list_models():
|
||||
with patch("swarms.models.timm.list_models") as mock_list_models:
|
||||
model_name = "resnet18"
|
||||
pretrained = True
|
||||
in_chans = 3
|
||||
timm_model = TimmModel(model_name, pretrained, in_chans)
|
||||
result = timm_model.list_models()
|
||||
mock_list_models.assert_called_once()
|
||||
assert result == mock_list_models.return_value
|
||||
|
||||
|
||||
def test_get_supported_models():
|
||||
model_handler = TimmModel()
|
||||
supported_models = model_handler._get_supported_models()
|
||||
assert isinstance(supported_models, list)
|
||||
assert len(supported_models) > 0
|
||||
|
||||
|
||||
def test_create_model(sample_model_info):
|
||||
model_handler = TimmModel()
|
||||
model = model_handler._create_model(sample_model_info)
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
|
||||
|
||||
def test_call(sample_model_info):
|
||||
model_handler = TimmModel()
|
||||
input_tensor = torch.randn(1, 3, 224, 224)
|
||||
output_shape = model_handler.__call__(
|
||||
sample_model_info, input_tensor
|
||||
)
|
||||
assert isinstance(output_shape, torch.Size)
|
||||
|
||||
|
||||
def test_get_supported_models_mock():
|
||||
model_handler = TimmModel()
|
||||
model_handler._get_supported_models = Mock(
|
||||
return_value=["resnet18", "resnet50"]
|
||||
)
|
||||
supported_models = model_handler._get_supported_models()
|
||||
assert supported_models == ["resnet18", "resnet50"]
|
||||
|
||||
|
||||
def test_create_model_mock(sample_model_info):
|
||||
model_handler = TimmModel()
|
||||
model_handler._create_model = Mock(return_value=torch.nn.Module())
|
||||
model = model_handler._create_model(sample_model_info)
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
|
||||
|
||||
def test_coverage_report():
|
||||
# Install pytest-cov
|
||||
# Run tests with coverage report
|
||||
pytest.main(["--cov=my_module", "--cov-report=html"])
|
@ -1,145 +0,0 @@
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from swarm_models.together import TogetherLLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_key(monkeypatch):
|
||||
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
|
||||
|
||||
|
||||
def test_init_defaults():
|
||||
model = TogetherLLM()
|
||||
assert model.together_api_key == "mocked-api-key"
|
||||
assert model.logging_enabled is False
|
||||
assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
assert model.max_workers == 10
|
||||
assert model.max_tokens == 300
|
||||
assert model.api_endpoint == "https://api.together.xyz"
|
||||
assert model.beautify is False
|
||||
assert model.streaming_enabled is False
|
||||
assert model.meta_prompt is False
|
||||
assert model.system_prompt is None
|
||||
|
||||
|
||||
def test_init_custom_params(mock_api_key):
|
||||
model = TogetherLLM(
|
||||
together_api_key="custom-api-key",
|
||||
logging_enabled=True,
|
||||
model_name="custom-model",
|
||||
max_workers=5,
|
||||
max_tokens=500,
|
||||
api_endpoint="https://custom-api.together.xyz",
|
||||
beautify=True,
|
||||
streaming_enabled=True,
|
||||
meta_prompt="meta-prompt",
|
||||
system_prompt="system-prompt",
|
||||
)
|
||||
assert model.together_api_key == "custom-api-key"
|
||||
assert model.logging_enabled is True
|
||||
assert model.model_name == "custom-model"
|
||||
assert model.max_workers == 5
|
||||
assert model.max_tokens == 500
|
||||
assert model.api_endpoint == "https://custom-api.together.xyz"
|
||||
assert model.beautify is True
|
||||
assert model.streaming_enabled is True
|
||||
assert model.meta_prompt == "meta-prompt"
|
||||
assert model.system_prompt == "system-prompt"
|
||||
|
||||
|
||||
@patch("swarms.models.together_model.requests.post")
|
||||
def test_run_success(mock_post, mock_api_key):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Generated response"}}]
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
model = TogetherLLM()
|
||||
task = "What is the color of the object?"
|
||||
response = model.run(task)
|
||||
|
||||
assert response == "Generated response"
|
||||
|
||||
|
||||
@patch("swarms.models.together_model.requests.post")
|
||||
def test_run_failure(mock_post, mock_api_key):
|
||||
mock_post.side_effect = requests.exceptions.RequestException(
|
||||
"Request failed"
|
||||
)
|
||||
|
||||
model = TogetherLLM()
|
||||
task = "What is the color of the object?"
|
||||
response = model.run(task)
|
||||
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_run_with_logging_enabled(caplog, mock_api_key):
|
||||
model = TogetherLLM(logging_enabled=True)
|
||||
task = "What is the color of the object?"
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
model.run(task)
|
||||
|
||||
assert "Sending request to" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_input", [None, 123, ["list", "of", "items"]]
|
||||
)
|
||||
def test_invalid_task_input(invalid_input, mock_api_key):
|
||||
model = TogetherLLM()
|
||||
response = model.run(invalid_input)
|
||||
|
||||
assert response is None
|
||||
|
||||
|
||||
@patch("swarms.models.together_model.requests.post")
|
||||
def test_run_streaming_enabled(mock_post, mock_api_key):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Generated response"}}]
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
model = TogetherLLM(streaming_enabled=True)
|
||||
task = "What is the color of the object?"
|
||||
response = model.run(task)
|
||||
|
||||
assert response == "Generated response"
|
||||
|
||||
|
||||
@patch("swarms.models.together_model.requests.post")
|
||||
def test_run_empty_choices(mock_post, mock_api_key):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"choices": []}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
model = TogetherLLM()
|
||||
task = "What is the color of the object?"
|
||||
response = model.run(task)
|
||||
|
||||
assert response is None
|
||||
|
||||
|
||||
@patch("swarms.models.together_model.requests.post")
|
||||
def test_run_with_exception(mock_post, mock_api_key):
|
||||
mock_post.side_effect = Exception("Test exception")
|
||||
|
||||
model = TogetherLLM()
|
||||
task = "What is the color of the object?"
|
||||
response = model.run(task)
|
||||
|
||||
assert response is None
|
||||
|
||||
|
||||
def test_init_logging_disabled(monkeypatch):
|
||||
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
|
||||
model = TogetherLLM()
|
||||
assert model.logging_enabled is False
|
||||
assert not model.system_prompt
|
@ -1,110 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from swarm_models.vilt import Image, Vilt, requests
|
||||
|
||||
|
||||
# Fixture for Vilt instance
|
||||
@pytest.fixture
|
||||
def vilt_instance():
|
||||
return Vilt()
|
||||
|
||||
|
||||
# 1. Test Initialization
|
||||
def test_vilt_initialization(vilt_instance):
|
||||
assert isinstance(vilt_instance, Vilt)
|
||||
assert vilt_instance.processor is not None
|
||||
assert vilt_instance.model is not None
|
||||
|
||||
|
||||
# 2. Test Model Predictions
|
||||
@patch.object(requests, "get")
|
||||
@patch.object(Image, "open")
|
||||
def test_vilt_prediction(
|
||||
mock_image_open, mock_requests_get, vilt_instance
|
||||
):
|
||||
mock_image = Mock()
|
||||
mock_image_open.return_value = mock_image
|
||||
mock_requests_get.return_value.raw = Mock()
|
||||
|
||||
# It's a mock response, so no real answer expected
|
||||
with pytest.raises(
|
||||
Exception
|
||||
): # Ensure exception is more specific
|
||||
vilt_instance(
|
||||
"What is this image",
|
||||
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
|
||||
)
|
||||
|
||||
|
||||
# 3. Test Exception Handling for network
|
||||
@patch.object(
|
||||
requests,
|
||||
"get",
|
||||
side_effect=requests.RequestException("Network error"),
|
||||
)
|
||||
def test_vilt_network_exception(vilt_instance):
|
||||
with pytest.raises(requests.RequestException):
|
||||
vilt_instance(
|
||||
"What is this image",
|
||||
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
|
||||
)
|
||||
|
||||
|
||||
# Parameterized test cases for different inputs
|
||||
@pytest.mark.parametrize(
|
||||
"text,image_url",
|
||||
[
|
||||
("What is this?", "http://example.com/image1.jpg"),
|
||||
("Who is in the image?", "http://example.com/image2.jpg"),
|
||||
(
|
||||
"Where was this picture taken?",
|
||||
"http://example.com/image3.jpg",
|
||||
),
|
||||
# ... Add more scenarios
|
||||
],
|
||||
)
|
||||
def test_vilt_various_inputs(text, image_url, vilt_instance):
|
||||
with pytest.raises(
|
||||
Exception
|
||||
): # Again, ensure exception is more specific
|
||||
vilt_instance(text, image_url)
|
||||
|
||||
|
||||
# Test with invalid or empty text
|
||||
@pytest.mark.parametrize(
|
||||
"text,image_url",
|
||||
[
|
||||
(
|
||||
"",
|
||||
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
|
||||
),
|
||||
(
|
||||
None,
|
||||
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
|
||||
),
|
||||
(
|
||||
" ",
|
||||
"https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80",
|
||||
),
|
||||
# ... Add more scenarios
|
||||
],
|
||||
)
|
||||
def test_vilt_invalid_text(text, image_url, vilt_instance):
|
||||
with pytest.raises(ValueError):
|
||||
vilt_instance(text, image_url)
|
||||
|
||||
|
||||
# Test with invalid or empty image_url
|
||||
@pytest.mark.parametrize(
|
||||
"text,image_url",
|
||||
[
|
||||
("What is this?", ""),
|
||||
("Who is in the image?", None),
|
||||
("Where was this picture taken?", " "),
|
||||
],
|
||||
)
|
||||
def test_vilt_invalid_image_url(text, image_url, vilt_instance):
|
||||
with pytest.raises(ValueError):
|
||||
vilt_instance(text, image_url)
|
@ -1,122 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from swarm_models.zeroscope import ZeroscopeTTV
|
||||
|
||||
|
||||
@patch("swarms.models.zeroscope.DiffusionPipeline")
|
||||
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
|
||||
def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline):
|
||||
zeroscope = ZeroscopeTTV()
|
||||
mock_pipeline.from_pretrained.assert_called_once()
|
||||
mock_scheduler.assert_called_once()
|
||||
assert zeroscope.model_name == "cerspense/zeroscope_v2_576w"
|
||||
assert zeroscope.chunk_size == 1
|
||||
assert zeroscope.dim == 1
|
||||
assert zeroscope.num_inference_steps == 40
|
||||
assert zeroscope.height == 320
|
||||
assert zeroscope.width == 576
|
||||
assert zeroscope.num_frames == 36
|
||||
|
||||
|
||||
@patch("swarms.models.zeroscope.DiffusionPipeline")
|
||||
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
|
||||
def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline):
|
||||
zeroscope = ZeroscopeTTV()
|
||||
mock_pipeline_instance = MagicMock()
|
||||
mock_pipeline.from_pretrained.return_value = (
|
||||
mock_pipeline_instance
|
||||
)
|
||||
mock_pipeline_instance.return_value = MagicMock(
|
||||
frames="Generated frames"
|
||||
)
|
||||
mock_pipeline_instance.enable_vae_slicing.assert_called_once()
|
||||
mock_pipeline_instance.enable_forward_chunking.assert_called_once_with(
|
||||
chunk_size=1, dim=1
|
||||
)
|
||||
result = zeroscope.forward("Test task")
|
||||
assert result == "Generated frames"
|
||||
mock_pipeline_instance.assert_called_once_with(
|
||||
"Test task",
|
||||
num_inference_steps=40,
|
||||
height=320,
|
||||
width=576,
|
||||
num_frames=36,
|
||||
)
|
||||
|
||||
|
||||
@patch("swarms.models.zeroscope.DiffusionPipeline")
|
||||
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
|
||||
def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline):
|
||||
zeroscope = ZeroscopeTTV()
|
||||
mock_pipeline_instance = MagicMock()
|
||||
mock_pipeline.from_pretrained.return_value = (
|
||||
mock_pipeline_instance
|
||||
)
|
||||
mock_pipeline_instance.return_value = MagicMock(
|
||||
frames="Generated frames"
|
||||
)
|
||||
mock_pipeline_instance.side_effect = Exception("Test error")
|
||||
with pytest.raises(Exception, match="Test error"):
|
||||
zeroscope.forward("Test task")
|
||||
|
||||
|
||||
@patch("swarms.models.zeroscope.DiffusionPipeline")
|
||||
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
|
||||
def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline):
|
||||
zeroscope = ZeroscopeTTV()
|
||||
mock_pipeline_instance = MagicMock()
|
||||
mock_pipeline.from_pretrained.return_value = (
|
||||
mock_pipeline_instance
|
||||
)
|
||||
mock_pipeline_instance.return_value = MagicMock(
|
||||
frames="Generated frames"
|
||||
)
|
||||
result = zeroscope.__call__("Test task")
|
||||
assert result == "Generated frames"
|
||||
mock_pipeline_instance.assert_called_once_with(
|
||||
"Test task",
|
||||
num_inference_steps=40,
|
||||
height=320,
|
||||
width=576,
|
||||
num_frames=36,
|
||||
)
|
||||
|
||||
|
||||
@patch("swarms.models.zeroscope.DiffusionPipeline")
|
||||
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
|
||||
def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline):
|
||||
zeroscope = ZeroscopeTTV()
|
||||
mock_pipeline_instance = MagicMock()
|
||||
mock_pipeline.from_pretrained.return_value = (
|
||||
mock_pipeline_instance
|
||||
)
|
||||
mock_pipeline_instance.return_value = MagicMock(
|
||||
frames="Generated frames"
|
||||
)
|
||||
mock_pipeline_instance.side_effect = Exception("Test error")
|
||||
with pytest.raises(Exception, match="Test error"):
|
||||
zeroscope.__call__("Test task")
|
||||
|
||||
|
||||
@patch("swarms.models.zeroscope.DiffusionPipeline")
|
||||
@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler")
|
||||
def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline):
|
||||
zeroscope = ZeroscopeTTV()
|
||||
mock_pipeline_instance = MagicMock()
|
||||
mock_pipeline.from_pretrained.return_value = (
|
||||
mock_pipeline_instance
|
||||
)
|
||||
mock_pipeline_instance.return_value = MagicMock(
|
||||
frames="Generated frames"
|
||||
)
|
||||
result = zeroscope.save_video_path("Test video path")
|
||||
assert result == "Test video path"
|
||||
mock_pipeline_instance.assert_called_once_with(
|
||||
"Test video path",
|
||||
num_inference_steps=40,
|
||||
height=320,
|
||||
width=576,
|
||||
num_frames=36,
|
||||
)
|
Loading…
Reference in new issue