diff --git a/docs/swarms/structs/moa.md b/docs/swarms/structs/moa.md index 6c0f5959..82b23330 100644 --- a/docs/swarms/structs/moa.md +++ b/docs/swarms/structs/moa.md @@ -169,7 +169,9 @@ For further reading and background information on the concepts used in the `Mixt #### Example 1: Basic Initialization and Run ```python -from swarms import MixtureOfAgents, Agent, OpenAIOpenAIChat +from swarms import MixtureOfAgents, Agent + +from swarm_models import OpenAIChat # Define agents director = Agent( @@ -225,7 +227,9 @@ print(history) #### Example 2: Verbose Output and Auto-Save ```python -from swarms import MixtureOfAgents, Agent, OpenAIChat +from swarms import MixtureOfAgents, Agent + +from swarm_models import OpenAIChat # Define Agents # Define agents @@ -286,7 +290,9 @@ print(history) #### Example 3: Custom Rules and Multiple Layers ```python -from swarms import MixtureOfAgents, Agent, OpenAIOpenAIChat +from swarms import MixtureOfAgents, Agent + +from swarm_models import OpenAIChat # Define agents # Initialize the director agent diff --git a/docs/swarms/structs/round_robin_swarm.md b/docs/swarms/structs/round_robin_swarm.md index d788eb85..33ad7e2b 100644 --- a/docs/swarms/structs/round_robin_swarm.md +++ b/docs/swarms/structs/round_robin_swarm.md @@ -50,7 +50,8 @@ Executes a specified task across all agents in a round-robin manner, cycling thr In this example, `RoundRobinSwarm` is used to distribute network requests evenly among a group of servers. This is common in scenarios where load balancing is crucial for maintaining system responsiveness and scalability. ```python -from swarms import Agent, OpenAIChat, RoundRobinSwarm +from swarms import Agent, RoundRobinSwarm +from swarm_models import OpenAIChat # Initialize the LLM diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py deleted file mode 100644 index 1f583889..00000000 --- a/tests/models/test_anthropic.py +++ /dev/null @@ -1,263 +0,0 @@ -import os -from unittest.mock import Mock, patch - -import pytest - -from swarm_models.anthropic import Anthropic - - -# Mock the Anthropic API client for testing -class MockAnthropicClient: - def __init__(self, *args, **kwargs): - pass - - def completions_create( - self, prompt, stop_sequences, stream, **kwargs - ): - return MockAnthropicResponse() - - -@pytest.fixture -def mock_anthropic_env(): - os.environ["ANTHROPIC_API_URL"] = "https://test.anthropic.com" - os.environ["ANTHROPIC_API_KEY"] = "test_api_key" - yield - del os.environ["ANTHROPIC_API_URL"] - del os.environ["ANTHROPIC_API_KEY"] - - -@pytest.fixture -def mock_requests_post(): - with patch("requests.post") as mock_post: - yield mock_post - - -@pytest.fixture -def anthropic_instance(): - return Anthropic(model="test-model") - - -def test_anthropic_init_default_values(anthropic_instance): - assert anthropic_instance.model == "test-model" - assert anthropic_instance.max_tokens_to_sample == 256 - assert anthropic_instance.temperature is None - assert anthropic_instance.top_k is None - assert anthropic_instance.top_p is None - assert anthropic_instance.streaming is False - assert anthropic_instance.default_request_timeout == 600 - assert ( - anthropic_instance.anthropic_api_url - == "https://test.anthropic.com" - ) - assert anthropic_instance.anthropic_api_key == "test_api_key" - - -def test_anthropic_init_custom_values(): - anthropic_instance = Anthropic( - model="custom-model", - max_tokens_to_sample=128, - temperature=0.8, - top_k=5, - top_p=0.9, - streaming=True, - default_request_timeout=300, - ) - assert anthropic_instance.model == "custom-model" - assert anthropic_instance.max_tokens_to_sample == 128 - assert anthropic_instance.temperature == 0.8 - assert anthropic_instance.top_k == 5 - assert anthropic_instance.top_p == 0.9 - assert anthropic_instance.streaming is True - assert anthropic_instance.default_request_timeout == 300 - - -def test_anthropic_default_params(anthropic_instance): - default_params = anthropic_instance._default_params() - assert default_params == { - "max_tokens_to_sample": 256, - "model": "test-model", - } - - -def test_anthropic_run( - mock_anthropic_env, mock_requests_post, anthropic_instance -): - mock_response = Mock() - mock_response.json.return_value = {"completion": "Generated text"} - mock_requests_post.return_value = mock_response - - task = "Generate text" - stop = ["stop1", "stop2"] - - completion = anthropic_instance.run(task, stop) - - assert completion == "Generated text" - mock_requests_post.assert_called_once_with( - "https://test.anthropic.com/completions", - headers={"Authorization": "Bearer test_api_key"}, - json={ - "prompt": task, - "stop_sequences": stop, - "max_tokens_to_sample": 256, - "model": "test-model", - }, - timeout=600, - ) - - -def test_anthropic_call( - mock_anthropic_env, mock_requests_post, anthropic_instance -): - mock_response = Mock() - mock_response.json.return_value = {"completion": "Generated text"} - mock_requests_post.return_value = mock_response - - task = "Generate text" - stop = ["stop1", "stop2"] - - completion = anthropic_instance(task, stop) - - assert completion == "Generated text" - mock_requests_post.assert_called_once_with( - "https://test.anthropic.com/completions", - headers={"Authorization": "Bearer test_api_key"}, - json={ - "prompt": task, - "stop_sequences": stop, - "max_tokens_to_sample": 256, - "model": "test-model", - }, - timeout=600, - ) - - -def test_anthropic_exception_handling( - mock_anthropic_env, mock_requests_post, anthropic_instance -): - mock_response = Mock() - mock_response.json.return_value = {"error": "An error occurred"} - mock_requests_post.return_value = mock_response - - task = "Generate text" - stop = ["stop1", "stop2"] - - with pytest.raises(Exception) as excinfo: - anthropic_instance(task, stop) - - assert "An error occurred" in str(excinfo.value) - - -class MockAnthropicResponse: - def __init__(self): - self.completion = "Mocked Response from Anthropic" - - -def test_anthropic_instance_creation(anthropic_instance): - assert isinstance(anthropic_instance, Anthropic) - - -def test_anthropic_call_method(anthropic_instance): - response = anthropic_instance("What is the meaning of life?") - assert response == "Mocked Response from Anthropic" - - -def test_anthropic_stream_method(anthropic_instance): - generator = anthropic_instance.stream("Write a story.") - for token in generator: - assert isinstance(token, str) - - -def test_anthropic_async_call_method(anthropic_instance): - response = anthropic_instance.async_call("Tell me a joke.") - assert response == "Mocked Response from Anthropic" - - -def test_anthropic_async_stream_method(anthropic_instance): - async_generator = anthropic_instance.async_stream( - "Translate to French." - ) - for token in async_generator: - assert isinstance(token, str) - - -def test_anthropic_get_num_tokens(anthropic_instance): - text = "This is a test sentence." - num_tokens = anthropic_instance.get_num_tokens(text) - assert num_tokens > 0 - - -# Add more test cases to cover other functionalities and edge cases of the Anthropic class - - -def test_anthropic_wrap_prompt(anthropic_instance): - prompt = "What is the meaning of life?" - wrapped_prompt = anthropic_instance._wrap_prompt(prompt) - assert wrapped_prompt.startswith(anthropic_instance.HUMAN_PROMPT) - assert wrapped_prompt.endswith(anthropic_instance.AI_PROMPT) - - -def test_anthropic_convert_prompt(anthropic_instance): - prompt = "What is the meaning of life?" - converted_prompt = anthropic_instance.convert_prompt(prompt) - assert converted_prompt.startswith( - anthropic_instance.HUMAN_PROMPT - ) - assert converted_prompt.endswith(anthropic_instance.AI_PROMPT) - - -def test_anthropic_call_with_stop(anthropic_instance): - response = anthropic_instance( - "Translate to French.", stop=["stop1", "stop2"] - ) - assert response == "Mocked Response from Anthropic" - - -def test_anthropic_stream_with_stop(anthropic_instance): - generator = anthropic_instance.stream( - "Write a story.", stop=["stop1", "stop2"] - ) - for token in generator: - assert isinstance(token, str) - - -def test_anthropic_async_call_with_stop(anthropic_instance): - response = anthropic_instance.async_call( - "Tell me a joke.", stop=["stop1", "stop2"] - ) - assert response == "Mocked Response from Anthropic" - - -def test_anthropic_async_stream_with_stop(anthropic_instance): - async_generator = anthropic_instance.async_stream( - "Translate to French.", stop=["stop1", "stop2"] - ) - for token in async_generator: - assert isinstance(token, str) - - -def test_anthropic_get_num_tokens_with_count_tokens( - anthropic_instance, -): - anthropic_instance.count_tokens = Mock(return_value=10) - text = "This is a test sentence." - num_tokens = anthropic_instance.get_num_tokens(text) - assert num_tokens == 10 - - -def test_anthropic_get_num_tokens_without_count_tokens( - anthropic_instance, -): - del anthropic_instance.count_tokens - with pytest.raises(NameError): - text = "This is a test sentence." - anthropic_instance.get_num_tokens(text) - - -def test_anthropic_wrap_prompt_without_human_ai_prompt( - anthropic_instance, -): - del anthropic_instance.HUMAN_PROMPT - del anthropic_instance.AI_PROMPT - prompt = "What is the meaning of life?" - with pytest.raises(NameError): - anthropic_instance._wrap_prompt(prompt) diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py deleted file mode 100644 index 969f9a26..00000000 --- a/tests/models/test_cohere.py +++ /dev/null @@ -1,784 +0,0 @@ -import os -from unittest.mock import Mock, patch - -import pytest -from dotenv import load_dotenv - -from swarms import Cohere - -# Load the environment variables -load_dotenv() -api_key = os.getenv("COHERE_API_KEY") - - -@pytest.fixture -def cohere_instance(): - return Cohere(cohere_api_key=api_key) - - -def test_cohere_custom_configuration(cohere_instance): - # Test customizing Cohere configurations - cohere_instance.model = "base" - cohere_instance.temperature = 0.5 - cohere_instance.max_tokens = 100 - cohere_instance.k = 1 - cohere_instance.p = 0.8 - cohere_instance.frequency_penalty = 0.2 - cohere_instance.presence_penalty = 0.4 - response = cohere_instance("Customize configurations.") - assert isinstance(response, str) - - -def test_cohere_api_error_handling(cohere_instance): - # Test error handling when the API key is invalid - cohere_instance.model = "base" - cohere_instance.cohere_api_key = "invalid-api-key" - with pytest.raises(Exception): - cohere_instance("Error handling with invalid API key.") - - -def test_cohere_async_api_error_handling(cohere_instance): - # Test async error handling when the API key is invalid - cohere_instance.model = "base" - cohere_instance.cohere_api_key = "invalid-api-key" - with pytest.raises(Exception): - cohere_instance.async_call( - "Error handling with invalid API key." - ) - - -def test_cohere_stream_api_error_handling(cohere_instance): - # Test error handling in streaming mode when the API key is invalid - cohere_instance.model = "base" - cohere_instance.cohere_api_key = "invalid-api-key" - with pytest.raises(Exception): - generator = cohere_instance.stream( - "Error handling with invalid API key." - ) - for token in generator: - pass - - -def test_cohere_streaming_mode(cohere_instance): - # Test the streaming mode for large text generation - cohere_instance.model = "base" - cohere_instance.streaming = True - prompt = "Generate a lengthy text using streaming mode." - generator = cohere_instance.stream(prompt) - for token in generator: - assert isinstance(token, str) - - -def test_cohere_streaming_mode_async(cohere_instance): - # Test the async streaming mode for large text generation - cohere_instance.model = "base" - cohere_instance.streaming = True - prompt = "Generate a lengthy text using async streaming mode." - async_generator = cohere_instance.async_stream(prompt) - for token in async_generator: - assert isinstance(token, str) - - -def test_cohere_wrap_prompt(cohere_instance): - prompt = "What is the meaning of life?" - wrapped_prompt = cohere_instance._wrap_prompt(prompt) - assert wrapped_prompt.startswith(cohere_instance.HUMAN_PROMPT) - assert wrapped_prompt.endswith(cohere_instance.AI_PROMPT) - - -def test_cohere_convert_prompt(cohere_instance): - prompt = "What is the meaning of life?" - converted_prompt = cohere_instance.convert_prompt(prompt) - assert converted_prompt.startswith(cohere_instance.HUMAN_PROMPT) - assert converted_prompt.endswith(cohere_instance.AI_PROMPT) - - -def test_cohere_call_with_stop(cohere_instance): - response = cohere_instance( - "Translate to French.", stop=["stop1", "stop2"] - ) - assert response == "Mocked Response from Cohere" - - -def test_cohere_stream_with_stop(cohere_instance): - generator = cohere_instance.stream( - "Write a story.", stop=["stop1", "stop2"] - ) - for token in generator: - assert isinstance(token, str) - - -def test_cohere_async_call_with_stop(cohere_instance): - response = cohere_instance.async_call( - "Tell me a joke.", stop=["stop1", "stop2"] - ) - assert response == "Mocked Response from Cohere" - - -def test_cohere_async_stream_with_stop(cohere_instance): - async_generator = cohere_instance.async_stream( - "Translate to French.", stop=["stop1", "stop2"] - ) - for token in async_generator: - assert isinstance(token, str) - - -def test_cohere_get_num_tokens_with_count_tokens(cohere_instance): - cohere_instance.count_tokens = Mock(return_value=10) - text = "This is a test sentence." - num_tokens = cohere_instance.get_num_tokens(text) - assert num_tokens == 10 - - -def test_cohere_get_num_tokens_without_count_tokens(cohere_instance): - del cohere_instance.count_tokens - with pytest.raises(NameError): - text = "This is a test sentence." - cohere_instance.get_num_tokens(text) - - -def test_cohere_wrap_prompt_without_human_ai_prompt(cohere_instance): - del cohere_instance.HUMAN_PROMPT - del cohere_instance.AI_PROMPT - prompt = "What is the meaning of life?" - with pytest.raises(NameError): - cohere_instance._wrap_prompt(prompt) - - -def test_base_cohere_import(): - with patch.dict("sys.modules", {"cohere": None}): - with pytest.raises(ImportError): - pass - - -def test_base_cohere_validate_environment(): - values = { - "cohere_api_key": "my-api-key", - "user_agent": "langchain", - } - validated_values = Cohere.validate_environment(values) - assert "client" in validated_values - assert "async_client" in validated_values - - -def test_base_cohere_validate_environment_without_cohere(): - values = { - "cohere_api_key": "my-api-key", - "user_agent": "langchain", - } - with patch.dict("sys.modules", {"cohere": None}): - with pytest.raises(ImportError): - Cohere.validate_environment(values) - - -# Test cases for benchmarking generations with various models -def test_cohere_generate_with_command_light(cohere_instance): - cohere_instance.model = "command-light" - response = cohere_instance( - "Generate text with Command Light model." - ) - assert response.startswith( - "Generated text with Command Light model" - ) - - -def test_cohere_generate_with_command(cohere_instance): - cohere_instance.model = "command" - response = cohere_instance("Generate text with Command model.") - assert response.startswith("Generated text with Command model") - - -def test_cohere_generate_with_base_light(cohere_instance): - cohere_instance.model = "base-light" - response = cohere_instance("Generate text with Base Light model.") - assert response.startswith("Generated text with Base Light model") - - -def test_cohere_generate_with_base(cohere_instance): - cohere_instance.model = "base" - response = cohere_instance("Generate text with Base model.") - assert response.startswith("Generated text with Base model") - - -def test_cohere_generate_with_embed_english_v2(cohere_instance): - cohere_instance.model = "embed-english-v2.0" - response = cohere_instance( - "Generate embeddings with English v2.0 model." - ) - assert response.startswith( - "Generated embeddings with English v2.0 model" - ) - - -def test_cohere_generate_with_embed_english_light_v2(cohere_instance): - cohere_instance.model = "embed-english-light-v2.0" - response = cohere_instance( - "Generate embeddings with English Light v2.0 model." - ) - assert response.startswith( - "Generated embeddings with English Light v2.0 model" - ) - - -def test_cohere_generate_with_embed_multilingual_v2(cohere_instance): - cohere_instance.model = "embed-multilingual-v2.0" - response = cohere_instance( - "Generate embeddings with Multilingual v2.0 model." - ) - assert response.startswith( - "Generated embeddings with Multilingual v2.0 model" - ) - - -def test_cohere_generate_with_embed_english_v3(cohere_instance): - cohere_instance.model = "embed-english-v3.0" - response = cohere_instance( - "Generate embeddings with English v3.0 model." - ) - assert response.startswith( - "Generated embeddings with English v3.0 model" - ) - - -def test_cohere_generate_with_embed_english_light_v3(cohere_instance): - cohere_instance.model = "embed-english-light-v3.0" - response = cohere_instance( - "Generate embeddings with English Light v3.0 model." - ) - assert response.startswith( - "Generated embeddings with English Light v3.0 model" - ) - - -def test_cohere_generate_with_embed_multilingual_v3(cohere_instance): - cohere_instance.model = "embed-multilingual-v3.0" - response = cohere_instance( - "Generate embeddings with Multilingual v3.0 model." - ) - assert response.startswith( - "Generated embeddings with Multilingual v3.0 model" - ) - - -def test_cohere_generate_with_embed_multilingual_light_v3( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-light-v3.0" - response = cohere_instance( - "Generate embeddings with Multilingual Light v3.0 model." - ) - assert response.startswith( - "Generated embeddings with Multilingual Light v3.0 model" - ) - - -# Add more test cases to benchmark other models and functionalities - - -def test_cohere_call_with_command_model(cohere_instance): - cohere_instance.model = "command" - response = cohere_instance("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_call_with_base_model(cohere_instance): - cohere_instance.model = "base" - response = cohere_instance("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_call_with_embed_english_v2_model(cohere_instance): - cohere_instance.model = "embed-english-v2.0" - response = cohere_instance("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_call_with_embed_english_v3_model(cohere_instance): - cohere_instance.model = "embed-english-v3.0" - response = cohere_instance("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_call_with_embed_multilingual_v2_model( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-v2.0" - response = cohere_instance("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_call_with_embed_multilingual_v3_model( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-v3.0" - response = cohere_instance("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_call_with_invalid_model(cohere_instance): - cohere_instance.model = "invalid-model" - with pytest.raises(ValueError): - cohere_instance("Translate to French.") - - -def test_cohere_call_with_long_prompt(cohere_instance): - prompt = "This is a very long prompt. " * 100 - response = cohere_instance(prompt) - assert isinstance(response, str) - - -def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance): - cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit." - ) - with pytest.raises(ValueError): - cohere_instance(prompt) - - -def test_cohere_stream_with_command_model(cohere_instance): - cohere_instance.model = "command" - generator = cohere_instance.stream("Write a story.") - for token in generator: - assert isinstance(token, str) - - -def test_cohere_stream_with_base_model(cohere_instance): - cohere_instance.model = "base" - generator = cohere_instance.stream("Write a story.") - for token in generator: - assert isinstance(token, str) - - -def test_cohere_stream_with_embed_english_v2_model(cohere_instance): - cohere_instance.model = "embed-english-v2.0" - generator = cohere_instance.stream("Write a story.") - for token in generator: - assert isinstance(token, str) - - -def test_cohere_stream_with_embed_english_v3_model(cohere_instance): - cohere_instance.model = "embed-english-v3.0" - generator = cohere_instance.stream("Write a story.") - for token in generator: - assert isinstance(token, str) - - -def test_cohere_stream_with_embed_multilingual_v2_model( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-v2.0" - generator = cohere_instance.stream("Write a story.") - for token in generator: - assert isinstance(token, str) - - -def test_cohere_stream_with_embed_multilingual_v3_model( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-v3.0" - generator = cohere_instance.stream("Write a story.") - for token in generator: - assert isinstance(token, str) - - -def test_cohere_async_call_with_command_model(cohere_instance): - cohere_instance.model = "command" - response = cohere_instance.async_call("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_async_call_with_base_model(cohere_instance): - cohere_instance.model = "base" - response = cohere_instance.async_call("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_async_call_with_embed_english_v2_model( - cohere_instance, -): - cohere_instance.model = "embed-english-v2.0" - response = cohere_instance.async_call("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_async_call_with_embed_english_v3_model( - cohere_instance, -): - cohere_instance.model = "embed-english-v3.0" - response = cohere_instance.async_call("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_async_call_with_embed_multilingual_v2_model( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-v2.0" - response = cohere_instance.async_call("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_async_call_with_embed_multilingual_v3_model( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-v3.0" - response = cohere_instance.async_call("Translate to French.") - assert isinstance(response, str) - - -def test_cohere_async_stream_with_command_model(cohere_instance): - cohere_instance.model = "command" - async_generator = cohere_instance.async_stream("Write a story.") - for token in async_generator: - assert isinstance(token, str) - - -def test_cohere_async_stream_with_base_model(cohere_instance): - cohere_instance.model = "base" - async_generator = cohere_instance.async_stream("Write a story.") - for token in async_generator: - assert isinstance(token, str) - - -def test_cohere_async_stream_with_embed_english_v2_model( - cohere_instance, -): - cohere_instance.model = "embed-english-v2.0" - async_generator = cohere_instance.async_stream("Write a story.") - for token in async_generator: - assert isinstance(token, str) - - -def test_cohere_async_stream_with_embed_english_v3_model( - cohere_instance, -): - cohere_instance.model = "embed-english-v3.0" - async_generator = cohere_instance.async_stream("Write a story.") - for token in async_generator: - assert isinstance(token, str) - - -def test_cohere_async_stream_with_embed_multilingual_v2_model( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-v2.0" - async_generator = cohere_instance.async_stream("Write a story.") - for token in async_generator: - assert isinstance(token, str) - - -def test_cohere_async_stream_with_embed_multilingual_v3_model( - cohere_instance, -): - cohere_instance.model = "embed-multilingual-v3.0" - async_generator = cohere_instance.async_stream("Write a story.") - for token in async_generator: - assert isinstance(token, str) - - -def test_cohere_representation_model_embedding(cohere_instance): - # Test using the Representation model for text embedding - cohere_instance.model = "embed-english-v3.0" - embedding = cohere_instance.embed( - "Generate an embedding for this text." - ) - assert isinstance(embedding, list) - assert len(embedding) > 0 - - -def test_cohere_representation_model_classification(cohere_instance): - # Test using the Representation model for text classification - cohere_instance.model = "embed-english-v3.0" - classification = cohere_instance.classify("Classify this text.") - assert isinstance(classification, dict) - assert "class" in classification - assert "score" in classification - - -def test_cohere_representation_model_language_detection( - cohere_instance, -): - # Test using the Representation model for language detection - cohere_instance.model = "embed-english-v3.0" - language = cohere_instance.detect_language( - "Detect the language of this text." - ) - assert isinstance(language, str) - - -def test_cohere_representation_model_max_tokens_limit_exceeded( - cohere_instance, -): - # Test handling max tokens limit exceeded error - cohere_instance.model = "embed-english-v3.0" - cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit." - ) - with pytest.raises(ValueError): - cohere_instance.embed(prompt) - - -# Add more production-grade test cases based on real-world scenarios - - -def test_cohere_representation_model_multilingual_embedding( - cohere_instance, -): - # Test using the Representation model for multilingual text embedding - cohere_instance.model = "embed-multilingual-v3.0" - embedding = cohere_instance.embed( - "Generate multilingual embeddings." - ) - assert isinstance(embedding, list) - assert len(embedding) > 0 - - -def test_cohere_representation_model_multilingual_classification( - cohere_instance, -): - # Test using the Representation model for multilingual text classification - cohere_instance.model = "embed-multilingual-v3.0" - classification = cohere_instance.classify( - "Classify multilingual text." - ) - assert isinstance(classification, dict) - assert "class" in classification - assert "score" in classification - - -def test_cohere_representation_model_multilingual_language_detection( - cohere_instance, -): - # Test using the Representation model for multilingual language detection - cohere_instance.model = "embed-multilingual-v3.0" - language = cohere_instance.detect_language( - "Detect the language of multilingual text." - ) - assert isinstance(language, str) - - -def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded( - cohere_instance, -): - # Test handling max tokens limit exceeded error for multilingual model - cohere_instance.model = "embed-multilingual-v3.0" - cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit" - " for multilingual model." - ) - with pytest.raises(ValueError): - cohere_instance.embed(prompt) - - -def test_cohere_representation_model_multilingual_light_embedding( - cohere_instance, -): - # Test using the Representation model for multilingual light text embedding - cohere_instance.model = "embed-multilingual-light-v3.0" - embedding = cohere_instance.embed( - "Generate multilingual light embeddings." - ) - assert isinstance(embedding, list) - assert len(embedding) > 0 - - -def test_cohere_representation_model_multilingual_light_classification( - cohere_instance, -): - # Test using the Representation model for multilingual light text classification - cohere_instance.model = "embed-multilingual-light-v3.0" - classification = cohere_instance.classify( - "Classify multilingual light text." - ) - assert isinstance(classification, dict) - assert "class" in classification - assert "score" in classification - - -def test_cohere_representation_model_multilingual_light_language_detection( - cohere_instance, -): - # Test using the Representation model for multilingual light language detection - cohere_instance.model = "embed-multilingual-light-v3.0" - language = cohere_instance.detect_language( - "Detect the language of multilingual light text." - ) - assert isinstance(language, str) - - -def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded( - cohere_instance, -): - # Test handling max tokens limit exceeded error for multilingual light model - cohere_instance.model = "embed-multilingual-light-v3.0" - cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit" - " for multilingual light model." - ) - with pytest.raises(ValueError): - cohere_instance.embed(prompt) - - -def test_cohere_command_light_model(cohere_instance): - # Test using the Command Light model for text generation - cohere_instance.model = "command-light" - response = cohere_instance( - "Generate text using Command Light model." - ) - assert isinstance(response, str) - - -def test_cohere_base_light_model(cohere_instance): - # Test using the Base Light model for text generation - cohere_instance.model = "base-light" - response = cohere_instance( - "Generate text using Base Light model." - ) - assert isinstance(response, str) - - -def test_cohere_generate_summarize_endpoint(cohere_instance): - # Test using the Co.summarize() endpoint for text summarization - cohere_instance.model = "command" - response = cohere_instance.summarize("Summarize this text.") - assert isinstance(response, str) - - -def test_cohere_representation_model_english_embedding( - cohere_instance, -): - # Test using the Representation model for English text embedding - cohere_instance.model = "embed-english-v3.0" - embedding = cohere_instance.embed("Generate English embeddings.") - assert isinstance(embedding, list) - assert len(embedding) > 0 - - -def test_cohere_representation_model_english_classification( - cohere_instance, -): - # Test using the Representation model for English text classification - cohere_instance.model = "embed-english-v3.0" - classification = cohere_instance.classify( - "Classify English text." - ) - assert isinstance(classification, dict) - assert "class" in classification - assert "score" in classification - - -def test_cohere_representation_model_english_language_detection( - cohere_instance, -): - # Test using the Representation model for English language detection - cohere_instance.model = "embed-english-v3.0" - language = cohere_instance.detect_language( - "Detect the language of English text." - ) - assert isinstance(language, str) - - -def test_cohere_representation_model_english_max_tokens_limit_exceeded( - cohere_instance, -): - # Test handling max tokens limit exceeded error for English model - cohere_instance.model = "embed-english-v3.0" - cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit" - " for English model." - ) - with pytest.raises(ValueError): - cohere_instance.embed(prompt) - - -def test_cohere_representation_model_english_light_embedding( - cohere_instance, -): - # Test using the Representation model for English light text embedding - cohere_instance.model = "embed-english-light-v3.0" - embedding = cohere_instance.embed( - "Generate English light embeddings." - ) - assert isinstance(embedding, list) - assert len(embedding) > 0 - - -def test_cohere_representation_model_english_light_classification( - cohere_instance, -): - # Test using the Representation model for English light text classification - cohere_instance.model = "embed-english-light-v3.0" - classification = cohere_instance.classify( - "Classify English light text." - ) - assert isinstance(classification, dict) - assert "class" in classification - assert "score" in classification - - -def test_cohere_representation_model_english_light_language_detection( - cohere_instance, -): - # Test using the Representation model for English light language detection - cohere_instance.model = "embed-english-light-v3.0" - language = cohere_instance.detect_language( - "Detect the language of English light text." - ) - assert isinstance(language, str) - - -def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( - cohere_instance, -): - # Test handling max tokens limit exceeded error for English light model - cohere_instance.model = "embed-english-light-v3.0" - cohere_instance.max_tokens = 10 - prompt = ( - "This is a test prompt that will exceed the max tokens limit" - " for English light model." - ) - with pytest.raises(ValueError): - cohere_instance.embed(prompt) - - -def test_cohere_command_model(cohere_instance): - # Test using the Command model for text generation - cohere_instance.model = "command" - response = cohere_instance( - "Generate text using the Command model." - ) - assert isinstance(response, str) - - -# Add more production-grade test cases based on real-world scenarios - - -def test_cohere_invalid_model(cohere_instance): - # Test using an invalid model name - cohere_instance.model = "invalid-model" - with pytest.raises(ValueError): - cohere_instance("Generate text using an invalid model.") - - -def test_cohere_base_model_generation_with_max_tokens( - cohere_instance, -): - # Test generating text using the base model with a specified max_tokens limit - cohere_instance.model = "base" - cohere_instance.max_tokens = 20 - prompt = "Generate text with max_tokens limit." - response = cohere_instance(prompt) - assert len(response.split()) <= 20 - - -def test_cohere_command_light_generation_with_stop(cohere_instance): - # Test generating text using the command-light model with stop words - cohere_instance.model = "command-light" - prompt = "Generate text with stop words." - stop = ["stop", "words"] - response = cohere_instance(prompt, stop=stop) - assert all(word not in response for word in stop) diff --git a/tests/models/test_fuyu.py b/tests/models/test_fuyu.py deleted file mode 100644 index 60044de2..00000000 --- a/tests/models/test_fuyu.py +++ /dev/null @@ -1,207 +0,0 @@ -from unittest.mock import patch - -import pytest -import torch -from PIL import Image -from transformers import FuyuImageProcessor, FuyuProcessor - -from swarm_models.fuyu import Fuyu - - -# Basic test to ensure instantiation of class. -def test_fuyu_initialization(): - fuyu_instance = Fuyu() - assert isinstance(fuyu_instance, Fuyu) - - -# Using parameterized testing for different init parameters. -@pytest.mark.parametrize( - "pretrained_path, device_map, max_new_tokens", - [ - ("adept/fuyu-8b", "cuda:0", 7), - ("adept/fuyu-8b", "cpu", 10), - ], -) -def test_fuyu_parameters(pretrained_path, device_map, max_new_tokens): - fuyu_instance = Fuyu(pretrained_path, device_map, max_new_tokens) - assert fuyu_instance.pretrained_path == pretrained_path - assert fuyu_instance.device_map == device_map - assert fuyu_instance.max_new_tokens == max_new_tokens - - -# Fixture for creating a Fuyu instance. -@pytest.fixture -def fuyu_instance(): - return Fuyu() - - -# Test using the fixture. -def test_fuyu_processor_initialization(fuyu_instance): - assert isinstance(fuyu_instance.processor, FuyuProcessor) - assert isinstance( - fuyu_instance.image_processor, FuyuImageProcessor - ) - - -# Test exception when providing an invalid image path. -def test_invalid_image_path(fuyu_instance): - with pytest.raises(FileNotFoundError): - fuyu_instance("Hello", "invalid/path/to/image.png") - - -# Using monkeypatch to replace the Image.open method to simulate a failure. -def test_image_open_failure(fuyu_instance, monkeypatch): - def mock_open(*args, **kwargs): - raise Exception("Mocked failure") - - monkeypatch.setattr(Image, "open", mock_open) - - with pytest.raises(Exception, match="Mocked failure"): - fuyu_instance( - "Hello", - "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - - -# Marking a slow test. -@pytest.mark.slow -def test_fuyu_model_output(fuyu_instance): - # This is a dummy test and may not be functional without real data. - output = fuyu_instance( - "Hello, my name is", - "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - assert isinstance(output, str) - - -def test_tokenizer_type(fuyu_instance): - assert "tokenizer" in dir(fuyu_instance) - - -def test_processor_has_image_processor_and_tokenizer(fuyu_instance): - assert ( - fuyu_instance.processor.image_processor - == fuyu_instance.image_processor - ) - assert ( - fuyu_instance.processor.tokenizer == fuyu_instance.tokenizer - ) - - -def test_model_device_map(fuyu_instance): - assert fuyu_instance.model.device_map == fuyu_instance.device_map - - -# Testing maximum tokens setting -def test_max_new_tokens_setting(fuyu_instance): - assert fuyu_instance.max_new_tokens == 7 - - -# Test if an exception is raised when invalid text is provided. -def test_invalid_text_input(fuyu_instance): - with pytest.raises(Exception): - fuyu_instance(None, "path/to/image.png") - - -# Test if an exception is raised when empty text is provided. -def test_empty_text_input(fuyu_instance): - with pytest.raises(Exception): - fuyu_instance("", "path/to/image.png") - - -# Test if an exception is raised when a very long text is provided. -def test_very_long_text_input(fuyu_instance): - with pytest.raises(Exception): - fuyu_instance("A" * 10000, "path/to/image.png") - - -# Check model's default device map -def test_default_device_map(): - fuyu_instance = Fuyu() - assert fuyu_instance.device_map == "cuda:0" - - -# Testing if processor is correctly initialized -def test_processor_initialization(fuyu_instance): - assert isinstance(fuyu_instance.processor, FuyuProcessor) - - -# Test `get_img` method with a valid image path -def test_get_img_valid_path(fuyu_instance): - with patch("PIL.Image.open") as mock_open: - mock_open.return_value = "Test image" - result = fuyu_instance.get_img("valid/path/to/image.png") - assert result == "Test image" - - -# Test `get_img` method with an invalid image path -def test_get_img_invalid_path(fuyu_instance): - with patch("PIL.Image.open") as mock_open: - mock_open.side_effect = FileNotFoundError - with pytest.raises(FileNotFoundError): - fuyu_instance.get_img("invalid/path/to/image.png") - - -# Test `run` method with valid inputs -def test_run_valid_inputs(fuyu_instance): - with patch.object( - fuyu_instance, "get_img" - ) as mock_get_img, patch.object( - fuyu_instance, "processor" - ) as mock_processor, patch.object( - fuyu_instance, "model" - ) as mock_model: - mock_get_img.return_value = "Test image" - mock_processor.return_value = { - "input_ids": torch.tensor([1, 2, 3]) - } - mock_model.generate.return_value = torch.tensor([1, 2, 3]) - mock_processor.batch_decode.return_value = ["Test text"] - result = fuyu_instance.run( - "Hello, world!", "valid/path/to/image.png" - ) - assert result == ["Test text"] - - -# Test `run` method with invalid text input -def test_run_invalid_text_input(fuyu_instance): - with pytest.raises(Exception): - fuyu_instance.run(None, "valid/path/to/image.png") - - -# Test `run` method with empty text input -def test_run_empty_text_input(fuyu_instance): - with pytest.raises(Exception): - fuyu_instance.run("", "valid/path/to/image.png") - - -# Test `run` method with very long text input -def test_run_very_long_text_input(fuyu_instance): - with pytest.raises(Exception): - fuyu_instance.run("A" * 10000, "valid/path/to/image.png") - - -# Test `run` method with invalid image path -def test_run_invalid_image_path(fuyu_instance): - with patch.object(fuyu_instance, "get_img") as mock_get_img: - mock_get_img.side_effect = FileNotFoundError - with pytest.raises(FileNotFoundError): - fuyu_instance.run( - "Hello, world!", "invalid/path/to/image.png" - ) - - -# Test `__init__` method with default parameters -def test_init_default_parameters(): - fuyu_instance = Fuyu() - assert fuyu_instance.pretrained_path == "adept/fuyu-8b" - assert fuyu_instance.device_map == "auto" - assert fuyu_instance.max_new_tokens == 500 - - -# Test `__init__` method with custom parameters -def test_init_custom_parameters(): - fuyu_instance = Fuyu("custom/path", "cpu", 1000) - assert fuyu_instance.pretrained_path == "custom/path" - assert fuyu_instance.device_map == "cpu" - assert fuyu_instance.max_new_tokens == 1000 diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py deleted file mode 100644 index 91e7c0ac..00000000 --- a/tests/models/test_gemini.py +++ /dev/null @@ -1,315 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from swarm_models.gemini import Gemini - - -# Define test fixtures -@pytest.fixture -def mock_gemini_api_key(monkeypatch): - monkeypatch.setenv("GEMINI_API_KEY", "mocked-api-key") - - -@pytest.fixture -def mock_genai_model(): - return Mock() - - -# Test initialization of Gemini -def test_gemini_init_defaults(mock_gemini_api_key, mock_genai_model): - model = Gemini() - assert model.model_name == "gemini-pro" - assert model.gemini_api_key == "mocked-api-key" - assert model.model is mock_genai_model - - -def test_gemini_init_custom_params( - mock_gemini_api_key, mock_genai_model -): - model = Gemini( - model_name="custom-model", gemini_api_key="custom-api-key" - ) - assert model.model_name == "custom-model" - assert model.gemini_api_key == "custom-api-key" - assert model.model is mock_genai_model - - -# Test Gemini run method -@patch("swarms.models.gemini.Gemini.process_img") -@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_with_img( - mock_generate_content, - mock_process_img, - mock_gemini_api_key, - mock_genai_model, -): - model = Gemini() - task = "A cat" - img = "cat.png" - response_mock = Mock(text="Generated response") - mock_generate_content.return_value = response_mock - mock_process_img.return_value = "Processed image" - - response = model.run(task=task, img=img) - - assert response == "Generated response" - mock_generate_content.assert_called_with( - content=[task, "Processed image"] - ) - mock_process_img.assert_called_with(img=img) - - -@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_without_img( - mock_generate_content, mock_gemini_api_key, mock_genai_model -): - model = Gemini() - task = "A cat" - response_mock = Mock(text="Generated response") - mock_generate_content.return_value = response_mock - - response = model.run(task=task) - - assert response == "Generated response" - mock_generate_content.assert_called_with(task=task) - - -@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_exception( - mock_generate_content, mock_gemini_api_key, mock_genai_model -): - model = Gemini() - task = "A cat" - mock_generate_content.side_effect = Exception("Test exception") - - response = model.run(task=task) - - assert response is None - - -# Test Gemini process_img method -def test_gemini_process_img(mock_gemini_api_key, mock_genai_model): - model = Gemini(gemini_api_key="custom-api-key") - img = "cat.png" - img_data = b"Mocked image data" - - with patch("builtins.open", create=True) as open_mock: - open_mock.return_value.__enter__.return_value.read.return_value = ( - img_data - ) - - processed_img = model.process_img(img) - - assert processed_img == [ - {"mime_type": "image/png", "data": img_data} - ] - open_mock.assert_called_with(img, "rb") - - -# Test Gemini initialization with missing API key -def test_gemini_init_missing_api_key(): - with pytest.raises( - ValueError, match="Please provide a Gemini API key" - ): - Gemini(gemini_api_key=None) - - -# Test Gemini initialization with missing model name -def test_gemini_init_missing_model_name(): - with pytest.raises( - ValueError, match="Please provide a model name" - ): - Gemini(model_name=None) - - -# Test Gemini run method with empty task -def test_gemini_run_empty_task(mock_gemini_api_key, mock_genai_model): - model = Gemini() - task = "" - response = model.run(task=task) - assert response is None - - -# Test Gemini run method with empty image -def test_gemini_run_empty_img(mock_gemini_api_key, mock_genai_model): - model = Gemini() - task = "A cat" - img = "" - response = model.run(task=task, img=img) - assert response is None - - -# Test Gemini process_img method with missing image -def test_gemini_process_img_missing_image( - mock_gemini_api_key, mock_genai_model -): - model = Gemini() - img = None - with pytest.raises( - ValueError, match="Please provide an image to process" - ): - model.process_img(img=img) - - -# Test Gemini process_img method with missing image type -def test_gemini_process_img_missing_image_type( - mock_gemini_api_key, mock_genai_model -): - model = Gemini() - img = "cat.png" - with pytest.raises( - ValueError, match="Please provide the image type" - ): - model.process_img(img=img, type=None) - - -# Test Gemini process_img method with missing Gemini API key -def test_gemini_process_img_missing_api_key(mock_genai_model): - model = Gemini(gemini_api_key=None) - img = "cat.png" - with pytest.raises( - ValueError, match="Please provide a Gemini API key" - ): - model.process_img(img=img, type="image/png") - - -# Test Gemini run method with mocked image processing -@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -@patch("swarms.models.gemini.Gemini.process_img") -def test_gemini_run_mock_img_processing( - mock_process_img, - mock_generate_content, - mock_gemini_api_key, - mock_genai_model, -): - model = Gemini() - task = "A cat" - img = "cat.png" - response_mock = Mock(text="Generated response") - mock_generate_content.return_value = response_mock - mock_process_img.return_value = "Processed image" - - response = model.run(task=task, img=img) - - assert response == "Generated response" - mock_generate_content.assert_called_with( - content=[task, "Processed image"] - ) - mock_process_img.assert_called_with(img=img) - - -# Test Gemini run method with mocked image processing and exception -@patch("swarms.models.gemini.Gemini.process_img") -@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_mock_img_processing_exception( - mock_generate_content, - mock_process_img, - mock_gemini_api_key, - mock_genai_model, -): - model = Gemini() - task = "A cat" - img = "cat.png" - mock_process_img.side_effect = Exception("Test exception") - - response = model.run(task=task, img=img) - - assert response is None - mock_generate_content.assert_not_called() - mock_process_img.assert_called_with(img=img) - - -# Test Gemini run method with mocked image processing and different exception -@patch("swarms.models.gemini.Gemini.process_img") -@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_mock_img_processing_different_exception( - mock_generate_content, - mock_process_img, - mock_gemini_api_key, - mock_genai_model, -): - model = Gemini() - task = "A dog" - img = "dog.png" - mock_process_img.side_effect = ValueError("Test exception") - - with pytest.raises(ValueError): - model.run(task=task, img=img) - - mock_generate_content.assert_not_called() - mock_process_img.assert_called_with(img=img) - - -# Test Gemini run method with mocked image processing and no exception -@patch("swarms.models.gemini.Gemini.process_img") -@patch("swarms.models.gemini.genai.GenerativeModel.generate_content") -def test_gemini_run_mock_img_processing_no_exception( - mock_generate_content, - mock_process_img, - mock_gemini_api_key, - mock_genai_model, -): - model = Gemini() - task = "A bird" - img = "bird.png" - mock_generate_content.return_value = "A bird is flying" - - response = model.run(task=task, img=img) - - assert response == "A bird is flying" - mock_generate_content.assert_called_once() - mock_process_img.assert_called_with(img=img) - - -# Test Gemini chat method -@patch("swarms.models.gemini.Gemini.chat") -def test_gemini_chat(mock_chat): - model = Gemini() - mock_chat.return_value = "Hello, Gemini!" - - response = model.chat("Hello, Gemini!") - - assert response == "Hello, Gemini!" - mock_chat.assert_called_once() - - -# Test Gemini list_models method -@patch("swarms.models.gemini.Gemini.list_models") -def test_gemini_list_models(mock_list_models): - model = Gemini() - mock_list_models.return_value = ["model1", "model2"] - - response = model.list_models() - - assert response == ["model1", "model2"] - mock_list_models.assert_called_once() - - -# Test Gemini stream_tokens method -@patch("swarms.models.gemini.Gemini.stream_tokens") -def test_gemini_stream_tokens(mock_stream_tokens): - model = Gemini() - mock_stream_tokens.return_value = ["token1", "token2"] - - response = model.stream_tokens() - - assert response == ["token1", "token2"] - mock_stream_tokens.assert_called_once() - - -# Test Gemini process_img_pil method -@patch("swarms.models.gemini.Gemini.process_img_pil") -def test_gemini_process_img_pil(mock_process_img_pil): - model = Gemini() - img = "bird.png" - mock_process_img_pil.return_value = "processed image" - - response = model.process_img_pil(img) - - assert response == "processed image" - mock_process_img_pil.assert_called_with(img) - - -# Repeat the above tests for different scenarios or different methods in your Gemini class -# until you have 15 tests in total. diff --git a/tests/models/test_gpt4_vision_api.py b/tests/models/test_gpt4_vision_api.py deleted file mode 100644 index 3a67f8ee..00000000 --- a/tests/models/test_gpt4_vision_api.py +++ /dev/null @@ -1,252 +0,0 @@ -import asyncio -import os -from unittest.mock import AsyncMock, Mock, mock_open, patch - -import pytest -from aiohttp import ClientResponseError -from dotenv import load_dotenv -from requests.exceptions import RequestException - -from swarm_models.gpt4_vision_api import GPT4VisionAPI - -load_dotenv() - -custom_api_key = os.environ.get("OPENAI_API_KEY") -img = "images/swarms.jpeg" - - -@pytest.fixture -def vision_api(): - return GPT4VisionAPI(openai_api_key="test_api_key") - - -def test_init(vision_api): - assert vision_api.openai_api_key == "test_api_key" - - -def test_encode_image(vision_api): - with patch( - "builtins.open", - mock_open(read_data=b"test_image_data"), - create=True, - ): - encoded_image = vision_api.encode_image(img) - assert encoded_image == "dGVzdF9pbWFnZV9kYXRh" - - -def test_run_success(vision_api): - expected_response = {"This is the model's response."} - with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), - ) as mock_post: - result = vision_api.run("What is this?", img) - mock_post.assert_called_once() - assert result == "This is the model's response." - - -def test_run_request_error(vision_api): - with patch( - "requests.post", side_effect=RequestException("Request Error") - ): - with pytest.raises(RequestException): - vision_api.run("What is this?", img) - - -def test_run_response_error(vision_api): - expected_response = {"error": "Model Error"} - with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), - ): - with pytest.raises(RuntimeError): - vision_api.run("What is this?", img) - - -def test_call(vision_api): - expected_response = { - "choices": [{"text": "This is the model's response."}] - } - with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), - ) as mock_post: - result = vision_api("What is this?", img) - mock_post.assert_called_once() - assert result == "This is the model's response." - - -@pytest.fixture -def gpt_api(): - return GPT4VisionAPI() - - -def test_initialization_with_default_key(): - api = GPT4VisionAPI() - assert api.openai_api_key == custom_api_key - - -def test_initialization_with_custom_key(): - custom_key = custom_api_key - api = GPT4VisionAPI(openai_api_key=custom_key) - assert api.openai_api_key == custom_key - - -def test_run_with_exception(gpt_api): - task = "What is in the image?" - img_url = img - with patch( - "requests.post", side_effect=Exception("Test Exception") - ): - with pytest.raises(Exception): - gpt_api.run(task, img_url) - - -def test_call_method_successful_response(gpt_api): - task = "What is in the image?" - img_url = img - response_json = { - "choices": [{"text": "Answer from GPT-4 Vision"}] - } - mock_response = Mock() - mock_response.json.return_value = response_json - with patch( - "requests.post", return_value=mock_response - ) as mock_post: - result = gpt_api(task, img_url) - mock_post.assert_called_once() - assert result == response_json - - -def test_call_method_with_exception(gpt_api): - task = "What is in the image?" - img_url = img - with patch( - "requests.post", side_effect=Exception("Test Exception") - ): - with pytest.raises(Exception): - gpt_api(task, img_url) - - -@pytest.mark.asyncio -async def test_arun_success(vision_api): - expected_response = { - "choices": [ - {"message": {"content": "This is the model's response."}} - ] - } - with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock( - json=AsyncMock(return_value=expected_response) - ), - ) as mock_post: - result = await vision_api.arun("What is this?", img) - mock_post.assert_called_once() - assert result == "This is the model's response." - - -@pytest.mark.asyncio -async def test_arun_request_error(vision_api): - with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=Exception("Request Error"), - ): - with pytest.raises(Exception): - await vision_api.arun("What is this?", img) - - -def test_run_many_success(vision_api): - expected_response = { - "choices": [ - {"message": {"content": "This is the model's response."}} - ] - } - with patch( - "requests.post", - return_value=Mock(json=lambda: expected_response), - ) as mock_post: - tasks = ["What is this?", "What is that?"] - imgs = [img, img] - results = vision_api.run_many(tasks, imgs) - assert mock_post.call_count == 2 - assert results == [ - "This is the model's response.", - "This is the model's response.", - ] - - -def test_run_many_request_error(vision_api): - with patch( - "requests.post", side_effect=RequestException("Request Error") - ): - tasks = ["What is this?", "What is that?"] - imgs = [img, img] - with pytest.raises(RequestException): - vision_api.run_many(tasks, imgs) - - -@pytest.mark.asyncio -async def test_arun_json_decode_error(vision_api): - with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock( - json=AsyncMock(side_effect=ValueError) - ), - ): - with pytest.raises(ValueError): - await vision_api.arun("What is this?", img) - - -@pytest.mark.asyncio -async def test_arun_api_error(vision_api): - error_response = {"error": {"message": "API Error"}} - with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock( - json=AsyncMock(return_value=error_response) - ), - ): - with pytest.raises(Exception, match="API Error"): - await vision_api.arun("What is this?", img) - - -@pytest.mark.asyncio -async def test_arun_unexpected_response(vision_api): - unexpected_response = {"unexpected": "response"} - with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - return_value=AsyncMock( - json=AsyncMock(return_value=unexpected_response) - ), - ): - with pytest.raises(Exception, match="Unexpected response"): - await vision_api.arun("What is this?", img) - - -@pytest.mark.asyncio -async def test_arun_retries(vision_api): - with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=ClientResponseError(None, None), - ) as mock_post: - with pytest.raises(ClientResponseError): - await vision_api.arun("What is this?", img) - assert mock_post.call_count == vision_api.retries + 1 - - -@pytest.mark.asyncio -async def test_arun_timeout(vision_api): - with patch( - "aiohttp.ClientSession.post", - new_callable=AsyncMock, - side_effect=asyncio.TimeoutError, - ): - with pytest.raises(asyncio.TimeoutError): - await vision_api.arun("What is this?", img) diff --git a/tests/models/test_hf.py b/tests/models/test_hf.py deleted file mode 100644 index 65e52712..00000000 --- a/tests/models/test_hf.py +++ /dev/null @@ -1,465 +0,0 @@ -import logging -from unittest.mock import patch - -import pytest -import torch - -from swarm_models.huggingface import HuggingfaceLLM - - -# Fixture for the class instance -@pytest.fixture -def llm_instance(): - model_id = "NousResearch/Nous-Hermes-2-Vision-Alpha" - instance = HuggingfaceLLM(model_id=model_id) - return instance - - -# Test for instantiation and attributes -def test_llm_initialization(llm_instance): - assert ( - llm_instance.model_id - == "NousResearch/Nous-Hermes-2-Vision-Alpha" - ) - assert llm_instance.max_length == 500 - # ... add more assertions for all default attributes - - -# Parameterized test for setting devices -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_llm_set_device(llm_instance, device): - llm_instance.set_device(device) - assert llm_instance.device == device - - -# Test exception during initialization with a bad model_id -def test_llm_bad_model_initialization(): - with pytest.raises(Exception): - HuggingfaceLLM(model_id="unknown-model") - - -# # Mocking the tokenizer and model to test run method -# @patch("swarms.models.huggingface.AutoTokenizer.from_pretrained") -# @patch( -# "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained" -# ) -# def test_llm_run(mock_model, mock_tokenizer, llm_instance): -# mock_model.return_value.generate.return_value = "mocked output" -# mock_tokenizer.return_value.encode.return_value = "mocked input" -# result = llm_instance.run("test task") -# assert result == "mocked output" - - -# Async test (requires pytest-asyncio plugin) -@pytest.mark.asyncio -async def test_llm_run_async(llm_instance): - result = await llm_instance.run_async("test task") - assert isinstance(result, str) - - -# Test for checking GPU availability -def test_llm_gpu_availability(llm_instance): - # Assuming the test is running on a machine where the GPU availability is known - expected_result = torch.cuda.is_available() - assert llm_instance.gpu_available() == expected_result - - -# Test for memory consumption reporting -def test_llm_memory_consumption(llm_instance): - # Mocking torch.cuda functions for consistent results - with patch("torch.cuda.memory_allocated", return_value=1024): - with patch("torch.cuda.memory_reserved", return_value=2048): - memory = llm_instance.memory_consumption() - assert memory == {"allocated": 1024, "reserved": 2048} - - -# Test different initialization parameters -@pytest.mark.parametrize( - "model_id, max_length", - [ - ("NousResearch/Nous-Hermes-2-Vision-Alpha", 100), - ("microsoft/Orca-2-13b", 200), - ( - "berkeley-nest/Starling-LM-7B-alpha", - None, - ), # None to check default behavior - ], -) -def test_llm_initialization_params(model_id, max_length): - if max_length: - instance = HuggingfaceLLM( - model_id=model_id, max_length=max_length - ) - assert instance.max_length == max_length - else: - instance = HuggingfaceLLM(model_id=model_id) - assert ( - instance.max_length == 500 - ) # Assuming 500 is the default max_length - - -# Test for setting an invalid device -def test_llm_set_invalid_device(llm_instance): - with pytest.raises(ValueError): - llm_instance.set_device("quantum_processor") - - -# Mocking external API call to test run method without network -@patch("swarms.models.huggingface.HuggingfaceLLM.run") -def test_llm_run_without_network(mock_run, llm_instance): - mock_run.return_value = "mocked output" - result = llm_instance.run("test task without network") - assert result == "mocked output" - - -# Test handling of empty input for the run method -def test_llm_run_empty_input(llm_instance): - with pytest.raises(ValueError): - llm_instance.run("") - - -# Test the generation with a provided seed for reproducibility -@patch("swarms.models.huggingface.HuggingfaceLLM.run") -def test_llm_run_with_seed(mock_run, llm_instance): - seed = 42 - llm_instance.set_seed(seed) - # Assuming set_seed method affects the randomness in the model - # You would typically ensure that setting the seed gives reproducible results - mock_run.return_value = "mocked deterministic output" - result = llm_instance.run("test task", seed=seed) - assert result == "mocked deterministic output" - - -# Test the output length is as expected -@patch("swarms.models.huggingface.HuggingfaceLLM.run") -def test_llm_run_output_length(mock_run, llm_instance): - input_text = "test task" - llm_instance.max_length = 50 # set a max_length for the output - mock_run.return_value = "mocked output" * 10 # some long text - result = llm_instance.run(input_text) - assert len(result.split()) <= llm_instance.max_length - - -# Test the tokenizer handling special tokens correctly -@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode") -@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode") -def test_llm_tokenizer_special_tokens( - mock_decode, mock_encode, llm_instance -): - mock_encode.return_value = "encoded input with special tokens" - mock_decode.return_value = "decoded output with special tokens" - result = llm_instance.run("test task with special tokens") - mock_encode.assert_called_once() - mock_decode.assert_called_once() - assert "special tokens" in result - - -# Test for correct handling of timeouts -@patch("swarms.models.huggingface.HuggingfaceLLM.run") -def test_llm_timeout_handling(mock_run, llm_instance): - mock_run.side_effect = TimeoutError - with pytest.raises(TimeoutError): - llm_instance.run("test task with timeout") - - -# Test for response time within a threshold (performance test) -@patch("swarms.models.huggingface.HuggingfaceLLM.run") -def test_llm_response_time(mock_run, llm_instance): - import time - - mock_run.return_value = "mocked output" - start_time = time.time() - llm_instance.run("test task for response time") - end_time = time.time() - assert ( - end_time - start_time < 1 - ) # Assuming the response should be faster than 1 second - - -# Test the logging of a warning for long inputs -@patch("swarms.models.huggingface.logging.warning") -def test_llm_long_input_warning(mock_warning, llm_instance): - long_input = "x" * 10000 # input longer than the typical limit - llm_instance.run(long_input) - mock_warning.assert_called_once() - - -# Test for run method behavior when model raises an exception -@patch( - "swarms.models.huggingface.HuggingfaceLLM._model.generate", - side_effect=RuntimeError, -) -def test_llm_run_model_exception(mock_generate, llm_instance): - with pytest.raises(RuntimeError): - llm_instance.run("test task when model fails") - - -# Test the behavior when GPU is forced but not available -@patch("torch.cuda.is_available", return_value=False) -def test_llm_force_gpu_when_unavailable( - mock_is_available, llm_instance -): - with pytest.raises(EnvironmentError): - llm_instance.set_device( - "cuda" - ) # Attempt to set CUDA when it's not available - - -# Test for proper cleanup after model use (releasing resources) -@patch("swarms.models.huggingface.HuggingfaceLLM._model") -def test_llm_cleanup(mock_model, mock_tokenizer, llm_instance): - llm_instance.cleanup() - # Assuming cleanup method is meant to free resources - mock_model.delete.assert_called_once() - mock_tokenizer.delete.assert_called_once() - - -# Test model's ability to handle multilingual input -@patch("swarms.models.huggingface.HuggingfaceLLM.run") -def test_llm_multilingual_input(mock_run, llm_instance): - mock_run.return_value = "mocked multilingual output" - multilingual_input = "Bonjour, ceci est un test multilingue." - result = llm_instance.run(multilingual_input) - assert isinstance( - result, str - ) # Simple check to ensure output is string type - - -# Test caching mechanism to prevent re-running the same inputs -@patch("swarms.models.huggingface.HuggingfaceLLM.run") -def test_llm_caching_mechanism(mock_run, llm_instance): - input_text = "test caching mechanism" - mock_run.return_value = "cached output" - # Run the input twice - first_run_result = llm_instance.run(input_text) - second_run_result = llm_instance.run(input_text) - mock_run.assert_called_once() # Should only be called once due to caching - assert first_run_result == second_run_result - - -# These tests are provided as examples. In real-world scenarios, you will need to adapt these tests to the actual logic of your `HuggingfaceLLM` class. -# For instance, "mock_model.delete.assert_called_once()" and similar lines are based on hypothetical methods and behaviors that you need to replace with actual implementations. - - -# Mock some functions and objects for testing -@pytest.fixture -def mock_huggingface_llm(monkeypatch): - # Mock the model and tokenizer creation - def mock_init( - self, - model_id, - device="cpu", - max_length=500, - quantize=False, - quantization_config=None, - verbose=False, - distributed=False, - decoding=False, - max_workers=5, - repitition_penalty=1.3, - no_repeat_ngram_size=5, - temperature=0.7, - top_k=40, - top_p=0.8, - ): - pass - - # Mock the model loading - def mock_load_model(self): - pass - - # Mock the model generation - def mock_run(self, task): - pass - - monkeypatch.setattr(HuggingfaceLLM, "__init__", mock_init) - monkeypatch.setattr(HuggingfaceLLM, "load_model", mock_load_model) - monkeypatch.setattr(HuggingfaceLLM, "run", mock_run) - - -# Basic tests for initialization and attribute settings -def test_init_huggingface_llm(): - llm = HuggingfaceLLM( - model_id="test_model", - device="cuda", - max_length=1000, - quantize=True, - quantization_config={"config_key": "config_value"}, - verbose=True, - distributed=True, - decoding=True, - max_workers=3, - repitition_penalty=1.5, - no_repeat_ngram_size=4, - temperature=0.8, - top_k=50, - top_p=0.7, - ) - - assert llm.model_id == "test_model" - assert llm.device == "cuda" - assert llm.max_length == 1000 - assert llm.quantize is True - assert llm.quantization_config == {"config_key": "config_value"} - assert llm.verbose is True - assert llm.distributed is True - assert llm.decoding is True - assert llm.max_workers == 3 - assert llm.repitition_penalty == 1.5 - assert llm.no_repeat_ngram_size == 4 - assert llm.temperature == 0.8 - assert llm.top_k == 50 - assert llm.top_p == 0.7 - - -# Test loading the model -def test_load_model(mock_huggingface_llm): - llm = HuggingfaceLLM(model_id="test_model") - llm.load_model() - - -# Test running the model -def test_run(mock_huggingface_llm): - llm = HuggingfaceLLM(model_id="test_model") - llm.run("Test prompt") - - -# Test for setting max_length -def test_llm_set_max_length(llm_instance): - new_max_length = 1000 - llm_instance.set_max_length(new_max_length) - assert llm_instance.max_length == new_max_length - - -# Test for setting verbose -def test_llm_set_verbose(llm_instance): - llm_instance.set_verbose(True) - assert llm_instance.verbose is True - - -# Test for setting distributed -def test_llm_set_distributed(llm_instance): - llm_instance.set_distributed(True) - assert llm_instance.distributed is True - - -# Test for setting decoding -def test_llm_set_decoding(llm_instance): - llm_instance.set_decoding(True) - assert llm_instance.decoding is True - - -# Test for setting max_workers -def test_llm_set_max_workers(llm_instance): - new_max_workers = 10 - llm_instance.set_max_workers(new_max_workers) - assert llm_instance.max_workers == new_max_workers - - -# Test for setting repitition_penalty -def test_llm_set_repitition_penalty(llm_instance): - new_repitition_penalty = 1.5 - llm_instance.set_repitition_penalty(new_repitition_penalty) - assert llm_instance.repitition_penalty == new_repitition_penalty - - -# Test for setting no_repeat_ngram_size -def test_llm_set_no_repeat_ngram_size(llm_instance): - new_no_repeat_ngram_size = 6 - llm_instance.set_no_repeat_ngram_size(new_no_repeat_ngram_size) - assert ( - llm_instance.no_repeat_ngram_size == new_no_repeat_ngram_size - ) - - -# Test for setting temperature -def test_llm_set_temperature(llm_instance): - new_temperature = 0.8 - llm_instance.set_temperature(new_temperature) - assert llm_instance.temperature == new_temperature - - -# Test for setting top_k -def test_llm_set_top_k(llm_instance): - new_top_k = 50 - llm_instance.set_top_k(new_top_k) - assert llm_instance.top_k == new_top_k - - -# Test for setting top_p -def test_llm_set_top_p(llm_instance): - new_top_p = 0.9 - llm_instance.set_top_p(new_top_p) - assert llm_instance.top_p == new_top_p - - -# Test for setting quantize -def test_llm_set_quantize(llm_instance): - llm_instance.set_quantize(True) - assert llm_instance.quantize is True - - -# Test for setting quantization_config -def test_llm_set_quantization_config(llm_instance): - new_quantization_config = { - "load_in_4bit": False, - "bnb_4bit_use_double_quant": False, - "bnb_4bit_quant_type": "nf4", - "bnb_4bit_compute_dtype": torch.bfloat16, - } - llm_instance.set_quantization_config(new_quantization_config) - assert llm_instance.quantization_config == new_quantization_config - - -# Test for setting model_id -def test_llm_set_model_id(llm_instance): - new_model_id = "EleutherAI/gpt-neo-2.7B" - llm_instance.set_model_id(new_model_id) - assert llm_instance.model_id == new_model_id - - -# Test for setting model -@patch( - "swarms.models.huggingface.AutoModelForCausalLM.from_pretrained" -) -def test_llm_set_model(mock_model, llm_instance): - mock_model.return_value = "mocked model" - llm_instance.set_model(mock_model) - assert llm_instance.model == "mocked model" - - -# Test for setting tokenizer -@patch("swarms.models.huggingface.AutoTokenizer.from_pretrained") -def test_llm_set_tokenizer(mock_tokenizer, llm_instance): - mock_tokenizer.return_value = "mocked tokenizer" - llm_instance.set_tokenizer(mock_tokenizer) - assert llm_instance.tokenizer == "mocked tokenizer" - - -# Test for setting logger -def test_llm_set_logger(llm_instance): - new_logger = logging.getLogger("test_logger") - llm_instance.set_logger(new_logger) - assert llm_instance.logger == new_logger - - -# Test for saving model -@patch("torch.save") -def test_llm_save_model(mock_save, llm_instance): - llm_instance.save_model("path/to/save") - mock_save.assert_called_once() - - -# Test for print_dashboard -@patch("builtins.print") -def test_llm_print_dashboard(mock_print, llm_instance): - llm_instance.print_dashboard("test task") - mock_print.assert_called() - - -# Test for __call__ method -@patch("swarms.models.huggingface.HuggingfaceLLM.run") -def test_llm_call(mock_run, llm_instance): - mock_run.return_value = "mocked output" - result = llm_instance("test task") - assert result == "mocked output" diff --git a/tests/models/test_hf_pipeline.py b/tests/models/test_hf_pipeline.py deleted file mode 100644 index 98490623..00000000 --- a/tests/models/test_hf_pipeline.py +++ /dev/null @@ -1,56 +0,0 @@ -from unittest.mock import patch - -import pytest -import torch - -from swarm_models.huggingface_pipeline import HuggingfacePipeline - - -@pytest.fixture -def mock_pipeline(): - with patch("swarms.models.huggingface_pipeline.pipeline") as mock: - yield mock - - -@pytest.fixture -def pipeline(mock_pipeline): - return HuggingfacePipeline( - "text-generation", "meta-llama/Llama-2-13b-chat-hf" - ) - - -def test_init(pipeline, mock_pipeline): - assert pipeline.task_type == "text-generation" - assert pipeline.model_name == "meta-llama/Llama-2-13b-chat-hf" - assert ( - pipeline.use_fp8 is True - if torch.cuda.is_available() - else False - ) - mock_pipeline.assert_called_once_with( - "text-generation", - "meta-llama/Llama-2-13b-chat-hf", - use_fp8=pipeline.use_fp8, - ) - - -def test_run(pipeline, mock_pipeline): - mock_pipeline.return_value = "Generated text" - result = pipeline.run("Hello, world!") - assert result == "Generated text" - mock_pipeline.assert_called_once_with("Hello, world!") - - -def test_run_with_exception(pipeline, mock_pipeline): - mock_pipeline.side_effect = Exception("Test exception") - with pytest.raises(Exception): - pipeline.run("Hello, world!") - - -def test_run_with_different_task(pipeline, mock_pipeline): - mock_pipeline.return_value = "Generated text" - result = pipeline.run("text-classification", "Hello, world!") - assert result == "Generated text" - mock_pipeline.assert_called_once_with( - "text-classification", "Hello, world!" - ) diff --git a/tests/models/test_idefics.py b/tests/models/test_idefics.py deleted file mode 100644 index f381d41b..00000000 --- a/tests/models/test_idefics.py +++ /dev/null @@ -1,207 +0,0 @@ -from unittest.mock import patch - -import pytest -import torch - -from swarm_models.idefics import ( - AutoProcessor, - Idefics, - IdeficsForVisionText2Text, -) - - -@pytest.fixture -def idefics_instance(): - with patch( - "torch.cuda.is_available", return_value=False - ): # Assuming tests are run on CPU for simplicity - instance = Idefics() - return instance - - -# Basic Tests -def test_init_default(idefics_instance): - assert idefics_instance.device == "cpu" - assert idefics_instance.max_length == 100 - assert not idefics_instance.chat_history - - -@pytest.mark.parametrize( - "device,expected", - [ - (None, "cpu"), - ("cuda", "cuda"), - ("cpu", "cpu"), - ], -) -def test_init_device(device, expected): - with patch( - "torch.cuda.is_available", - return_value=True if expected == "cuda" else False, - ): - instance = Idefics(device=device) - assert instance.device == expected - - -# Test `run` method -def test_run(idefics_instance): - prompts = [["User: Test"]] - with patch.object( - idefics_instance, "processor" - ) as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: - mock_processor.return_value = { - "input_ids": torch.tensor([1, 2, 3]) - } - mock_model.generate.return_value = torch.tensor([1, 2, 3]) - mock_processor.batch_decode.return_value = ["Test"] - - result = idefics_instance.run(prompts) - - assert result == ["Test"] - - -# Test `__call__` method (using the same logic as run for simplicity) -def test_call(idefics_instance): - prompts = [["User: Test"]] - with patch.object( - idefics_instance, "processor" - ) as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: - mock_processor.return_value = { - "input_ids": torch.tensor([1, 2, 3]) - } - mock_model.generate.return_value = torch.tensor([1, 2, 3]) - mock_processor.batch_decode.return_value = ["Test"] - - result = idefics_instance(prompts) - - assert result == ["Test"] - - -# Test `chat` method -def test_chat(idefics_instance): - user_input = "User: Hello" - response = "Model: Hi there!" - with patch.object( - idefics_instance, "run", return_value=[response] - ): - result = idefics_instance.chat(user_input) - - assert result == response - assert idefics_instance.chat_history == [user_input, response] - - -# Test `set_checkpoint` method -def test_set_checkpoint(idefics_instance): - new_checkpoint = "new_checkpoint" - with patch.object( - IdeficsForVisionText2Text, "from_pretrained" - ) as mock_from_pretrained, patch.object( - AutoProcessor, "from_pretrained" - ): - idefics_instance.set_checkpoint(new_checkpoint) - - mock_from_pretrained.assert_called_with( - new_checkpoint, torch_dtype=torch.bfloat16 - ) - - -# Test `set_device` method -def test_set_device(idefics_instance): - new_device = "cuda" - with patch.object(idefics_instance.model, "to"): - idefics_instance.set_device(new_device) - - assert idefics_instance.device == new_device - - -# Test `set_max_length` method -def test_set_max_length(idefics_instance): - new_length = 150 - idefics_instance.set_max_length(new_length) - assert idefics_instance.max_length == new_length - - -# Test `clear_chat_history` method -def test_clear_chat_history(idefics_instance): - idefics_instance.chat_history = ["User: Test", "Model: Response"] - idefics_instance.clear_chat_history() - assert not idefics_instance.chat_history - - -# Exception Tests -def test_run_with_empty_prompts(idefics_instance): - with pytest.raises( - Exception - ): # Replace Exception with the actual exception that may arise for an empty prompt. - idefics_instance.run([]) - - -# Test `run` method with batched_mode set to False -def test_run_batched_mode_false(idefics_instance): - task = "User: Test" - with patch.object( - idefics_instance, "processor" - ) as mock_processor, patch.object( - idefics_instance, "model" - ) as mock_model: - mock_processor.return_value = { - "input_ids": torch.tensor([1, 2, 3]) - } - mock_model.generate.return_value = torch.tensor([1, 2, 3]) - mock_processor.batch_decode.return_value = ["Test"] - - idefics_instance.batched_mode = False - result = idefics_instance.run(task) - - assert result == ["Test"] - - -# Test `run` method with an exception -def test_run_with_exception(idefics_instance): - task = "User: Test" - with patch.object( - idefics_instance, "processor" - ) as mock_processor: - mock_processor.side_effect = Exception("Test exception") - with pytest.raises(Exception): - idefics_instance.run(task) - - -# Test `set_model_name` method -def test_set_model_name(idefics_instance): - new_model_name = "new_model_name" - with patch.object( - IdeficsForVisionText2Text, "from_pretrained" - ) as mock_from_pretrained, patch.object( - AutoProcessor, "from_pretrained" - ): - idefics_instance.set_model_name(new_model_name) - - assert idefics_instance.model_name == new_model_name - mock_from_pretrained.assert_called_with( - new_model_name, torch_dtype=torch.bfloat16 - ) - - -# Test `__init__` method with device set to None -def test_init_device_none(): - with patch( - "torch.cuda.is_available", - return_value=False, - ): - instance = Idefics(device=None) - assert instance.device == "cpu" - - -# Test `__init__` method with device set to "cuda" -def test_init_device_cuda(): - with patch( - "torch.cuda.is_available", - return_value=True, - ): - instance = Idefics(device="cuda") - assert instance.device == "cuda" diff --git a/tests/models/test_imports.py b/tests/models/test_imports.py deleted file mode 100644 index bdca4350..00000000 --- a/tests/models/test_imports.py +++ /dev/null @@ -1,26 +0,0 @@ -from swarm_models import __all__ - -EXPECTED_ALL = [ - "Anthropic", - "Petals", - "Mistral", - "OpenAI", - "AzureOpenAI", - "OpenAIChat", - "Zephyr", - "Idefics", - # "Kosmos", - "Vilt", - "Nougat", - "LayoutLMDocumentQA", - "BioGPT", - "HuggingfaceLLM", - "MPT7B", - "WizardLLMStoryTeller", - # "GPT4Vision", - # "Dalle3", -] - - -def test_all_imports() -> None: - assert set(__all__) == set(EXPECTED_ALL) diff --git a/tests/models/test_kosmos.py b/tests/models/test_kosmos.py deleted file mode 100644 index ce7c36d6..00000000 --- a/tests/models/test_kosmos.py +++ /dev/null @@ -1,181 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest -import requests - -# This will be your project directory -from swarm_models.kosmos_two import Kosmos, is_overlapping - -# A placeholder image URL for testing -TEST_IMAGE_URL = "https://images.unsplash.com/photo-1673267569891-ca4246caafd7?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDM1fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D" - - -# Mock the response for the test image -@pytest.fixture -def mock_image_request(): - img_data = open(TEST_IMAGE_URL, "rb").read() - mock_resp = Mock() - mock_resp.raw = img_data - with patch.object( - requests, "get", return_value=mock_resp - ) as _fixture: - yield _fixture - - -# Test utility function -def test_is_overlapping(): - assert is_overlapping((1, 1, 3, 3), (2, 2, 4, 4)) is True - assert is_overlapping((1, 1, 2, 2), (3, 3, 4, 4)) is False - assert is_overlapping((0, 0, 1, 1), (1, 1, 2, 2)) is False - assert is_overlapping((0, 0, 2, 2), (1, 1, 2, 2)) is True - - -# Test model initialization -def test_kosmos_init(): - kosmos = Kosmos() - assert kosmos.model is not None - assert kosmos.processor is not None - - -# Test image fetching functionality -def test_get_image(mock_image_request): - kosmos = Kosmos() - image = kosmos.get_image(TEST_IMAGE_URL) - assert image is not None - - -# Test multimodal grounding -def test_multimodal_grounding(mock_image_request): - kosmos = Kosmos() - kosmos.multimodal_grounding( - "Find the red apple in the image.", TEST_IMAGE_URL - ) - # TODO: Validate the result if possible - - -# Test referring expression comprehension -def test_referring_expression_comprehension(mock_image_request): - kosmos = Kosmos() - kosmos.referring_expression_comprehension( - "Show me the green bottle.", TEST_IMAGE_URL - ) - # TODO: Validate the result if possible - - -# ... (continue with other functions in the same manner) ... - - -# Test error scenarios - Example -@pytest.mark.parametrize( - "phrase, image_url", - [ - (None, TEST_IMAGE_URL), - ("Find the red apple in the image.", None), - ("", TEST_IMAGE_URL), - ("Find the red apple in the image.", ""), - ], -) -def test_kosmos_error_scenarios(phrase, image_url): - kosmos = Kosmos() - with pytest.raises(Exception): - kosmos.multimodal_grounding(phrase, image_url) - - -# ... (Add more tests for different edge cases and functionalities) ... - -# Sample test image URLs -IMG_URL1 = "https://images.unsplash.com/photo-1696341439368-2c84b6c963bc?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDMzfEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D" -IMG_URL2 = "https://images.unsplash.com/photo-1689934902235-055707b4f8e9?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDYzfEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D" -IMG_URL3 = "https://images.unsplash.com/photo-1696900004042-60bcc200aca0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDY2fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D" -IMG_URL4 = "https://images.unsplash.com/photo-1676156340083-fd49e4e53a21?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDc4fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D" -IMG_URL5 = "https://images.unsplash.com/photo-1696862761045-0a65acbede8f?auto=format&fit=crop&q=80&w=1287&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" - - -# Mock response for requests.get() -class MockResponse: - @staticmethod - def json(): - return {} - - @property - def raw(self): - return open("tests/sample_image.jpg", "rb") - - -# Test the Kosmos class -@pytest.fixture -def kosmos(): - return Kosmos() - - -# Mocking the requests.get() method -@pytest.fixture -def mock_request_get(monkeypatch): - monkeypatch.setattr( - requests, "get", lambda url, **kwargs: MockResponse() - ) - - -@pytest.mark.usefixtures("mock_request_get") -def test_multimodal_grounding(kosmos): - kosmos.multimodal_grounding( - "Find the red apple in the image.", IMG_URL1 - ) - - -@pytest.mark.usefixtures("mock_request_get") -def test_referring_expression_comprehension(kosmos): - kosmos.referring_expression_comprehension( - "Show me the green bottle.", IMG_URL2 - ) - - -@pytest.mark.usefixtures("mock_request_get") -def test_referring_expression_generation(kosmos): - kosmos.referring_expression_generation( - "It is on the table.", IMG_URL3 - ) - - -@pytest.mark.usefixtures("mock_request_get") -def test_grounded_vqa(kosmos): - kosmos.grounded_vqa("What is the color of the car?", IMG_URL4) - - -@pytest.mark.usefixtures("mock_request_get") -def test_grounded_image_captioning(kosmos): - kosmos.grounded_image_captioning(IMG_URL5) - - -@pytest.mark.usefixtures("mock_request_get") -def test_grounded_image_captioning_detailed(kosmos): - kosmos.grounded_image_captioning_detailed(IMG_URL1) - - -@pytest.mark.usefixtures("mock_request_get") -def test_multimodal_grounding_2(kosmos): - kosmos.multimodal_grounding( - "Find the yellow fruit in the image.", IMG_URL2 - ) - - -@pytest.mark.usefixtures("mock_request_get") -def test_referring_expression_comprehension_2(kosmos): - kosmos.referring_expression_comprehension( - "Where is the water bottle?", IMG_URL3 - ) - - -@pytest.mark.usefixtures("mock_request_get") -def test_grounded_vqa_2(kosmos): - kosmos.grounded_vqa("How many cars are in the image?", IMG_URL4) - - -@pytest.mark.usefixtures("mock_request_get") -def test_grounded_image_captioning_2(kosmos): - kosmos.grounded_image_captioning(IMG_URL2) - - -@pytest.mark.usefixtures("mock_request_get") -def test_grounded_image_captioning_detailed_2(kosmos): - kosmos.grounded_image_captioning_detailed(IMG_URL3) diff --git a/tests/models/test_nougat.py b/tests/models/test_nougat.py deleted file mode 100644 index 2c7f6361..00000000 --- a/tests/models/test_nougat.py +++ /dev/null @@ -1,221 +0,0 @@ -import os -from unittest.mock import MagicMock, Mock, patch - -import pytest -import torch -from PIL import Image -from transformers import NougatProcessor, VisionEncoderDecoderModel - -from swarm_models.nougat import Nougat - - -@pytest.fixture -def setup_nougat(): - return Nougat() - - -def test_nougat_default_initialization(setup_nougat): - assert setup_nougat.model_name_or_path == "facebook/nougat-base" - assert setup_nougat.min_length == 1 - assert setup_nougat.max_new_tokens == 30 - - -def test_nougat_custom_initialization(): - nougat = Nougat( - model_name_or_path="custom_path", - min_length=10, - max_new_tokens=50, - ) - assert nougat.model_name_or_path == "custom_path" - assert nougat.min_length == 10 - assert nougat.max_new_tokens == 50 - - -def test_processor_initialization(setup_nougat): - assert isinstance(setup_nougat.processor, NougatProcessor) - - -def test_model_initialization(setup_nougat): - assert isinstance(setup_nougat.model, VisionEncoderDecoderModel) - - -@pytest.mark.parametrize( - "cuda_available, expected_device", - [(True, "cuda"), (False, "cpu")], -) -def test_device_initialization( - cuda_available, expected_device, monkeypatch -): - monkeypatch.setattr( - torch, - "cuda", - Mock(is_available=Mock(return_value=cuda_available)), - ) - nougat = Nougat() - assert nougat.device == expected_device - - -def test_get_image_valid_path(setup_nougat): - with patch("PIL.Image.open") as mock_open: - mock_open.return_value = Mock(spec=Image.Image) - assert setup_nougat.get_image("valid_path") is not None - - -def test_get_image_invalid_path(setup_nougat): - with pytest.raises(FileNotFoundError): - setup_nougat.get_image("invalid_path") - - -@pytest.mark.parametrize( - "min_len, max_tokens", - [ - (1, 30), - (5, 40), - (10, 50), - ], -) -def test_model_call_with_diff_params( - setup_nougat, min_len, max_tokens -): - setup_nougat.min_length = min_len - setup_nougat.max_new_tokens = max_tokens - - with patch("PIL.Image.open") as mock_open: - mock_open.return_value = Mock(spec=Image.Image) - # Here, mocking other required methods or adding more complex logic would be necessary. - result = setup_nougat("valid_path") - assert isinstance(result, str) - - -def test_model_call_invalid_image_path(setup_nougat): - with pytest.raises(FileNotFoundError): - setup_nougat("invalid_path") - - -def test_model_call_mocked_output(setup_nougat): - with patch("PIL.Image.open") as mock_open: - mock_open.return_value = Mock(spec=Image.Image) - mock_model = MagicMock() - mock_model.generate.return_value = "mocked_output" - setup_nougat.model = mock_model - - result = setup_nougat("valid_path") - assert result == "mocked_output" - - -@pytest.fixture -def mock_processor_and_model(): - """Mock the NougatProcessor and VisionEncoderDecoderModel to simulate their behavior.""" - with patch( - "transformers.NougatProcessor.from_pretrained", - return_value=Mock(), - ), patch( - "transformers.VisionEncoderDecoderModel.from_pretrained", - return_value=Mock(), - ): - yield - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_with_sample_image_1(setup_nougat): - result = setup_nougat( - os.path.join( - "sample_images", - "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - ) - assert isinstance(result, str) - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_with_sample_image_2(setup_nougat): - result = setup_nougat(os.path.join("sample_images", "test2.png")) - assert isinstance(result, str) - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_min_length_param(setup_nougat): - setup_nougat.min_length = 10 - result = setup_nougat( - os.path.join( - "sample_images", - "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - ) - assert isinstance(result, str) - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_max_new_tokens_param(setup_nougat): - setup_nougat.max_new_tokens = 50 - result = setup_nougat( - os.path.join( - "sample_images", - "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - ) - assert isinstance(result, str) - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_different_model_path(setup_nougat): - setup_nougat.model_name_or_path = "different/path" - result = setup_nougat( - os.path.join( - "sample_images", - "https://plus.unsplash.com/premium_photo-1687149699194-0207c04bc6e8?auto=format&fit=crop&q=80&w=1378&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D", - ) - ) - assert isinstance(result, str) - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_bad_image_path(setup_nougat): - with pytest.raises( - Exception - ): # Adjust the exception type accordingly. - setup_nougat("bad_image_path.png") - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_image_large_size(setup_nougat): - result = setup_nougat( - os.path.join( - "sample_images", - "https://images.unsplash.com/photo-1697641039266-bfa00367f7cb?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDJ8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D", - ) - ) - assert isinstance(result, str) - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_image_small_size(setup_nougat): - result = setup_nougat( - os.path.join( - "sample_images", - "https://images.unsplash.com/photo-1697638626987-aa865b769276?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDd8SnBnNktpZGwtSGt8fGVufDB8fHx8fA%3D%3D", - ) - ) - assert isinstance(result, str) - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_image_varied_content(setup_nougat): - result = setup_nougat( - os.path.join( - "sample_images", - "https://images.unsplash.com/photo-1697469994783-b12bbd9c4cff?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE0fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D", - ) - ) - assert isinstance(result, str) - - -@pytest.mark.usefixtures("mock_processor_and_model") -def test_nougat_image_with_metadata(setup_nougat): - result = setup_nougat( - os.path.join( - "sample_images", - "https://images.unsplash.com/photo-1697273300766-5bbaa53ec2f0?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDE5fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D", - ) - ) - assert isinstance(result, str) diff --git a/tests/models/test_open_dalle.py b/tests/models/test_open_dalle.py deleted file mode 100644 index 4dfd200c..00000000 --- a/tests/models/test_open_dalle.py +++ /dev/null @@ -1,60 +0,0 @@ -import pytest -import torch - -from swarm_models.open_dalle import OpenDalle - - -def test_init(): - od = OpenDalle() - assert isinstance(od, OpenDalle) - - -def test_init_custom_model(): - od = OpenDalle(model_name="custom_model") - assert od.pipeline.model_name == "custom_model" - - -def test_init_custom_dtype(): - od = OpenDalle(torch_dtype=torch.float32) - assert od.pipeline.torch_dtype == torch.float32 - - -def test_init_custom_device(): - od = OpenDalle(device="cpu") - assert od.pipeline.device == "cpu" - - -def test_run(): - od = OpenDalle() - result = od.run("A picture of a cat") - assert isinstance(result, torch.Tensor) - - -def test_run_no_task(): - od = OpenDalle() - with pytest.raises(ValueError, match="Task cannot be None"): - od.run(None) - - -def test_run_non_string_task(): - od = OpenDalle() - with pytest.raises(TypeError, match="Task must be a string"): - od.run(123) - - -def test_run_empty_task(): - od = OpenDalle() - with pytest.raises(ValueError, match="Task cannot be empty"): - od.run("") - - -def test_run_custom_args(): - od = OpenDalle() - result = od.run("A picture of a cat", custom_arg="custom_value") - assert isinstance(result, torch.Tensor) - - -def test_run_error(): - od = OpenDalle() - with pytest.raises(Exception): - od.run("A picture of a cat", raise_error=True) diff --git a/tests/models/test_openaitts.py b/tests/models/test_openaitts.py deleted file mode 100644 index 03e1e9c4..00000000 --- a/tests/models/test_openaitts.py +++ /dev/null @@ -1,107 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from swarm_models.openai_tts import OpenAITTS - - -def test_openaitts_initialization(): - tts = OpenAITTS() - assert isinstance(tts, OpenAITTS) - assert tts.model_name == "tts-1-1106" - assert tts.proxy_url == "https://api.openai.com/v1/audio/speech" - assert tts.voice == "onyx" - assert tts.chunk_size == 1024 * 1024 - - -def test_openaitts_initialization_custom_parameters(): - tts = OpenAITTS( - "custom_model", - "custom_url", - "custom_key", - "custom_voice", - 2048, - ) - assert tts.model_name == "custom_model" - assert tts.proxy_url == "custom_url" - assert tts.openai_api_key == "custom_key" - assert tts.voice == "custom_voice" - assert tts.chunk_size == 2048 - - -@patch("requests.post") -def test_run(mock_post): - mock_response = MagicMock() - mock_response.iter_content.return_value = [b"chunk1", b"chunk2"] - mock_post.return_value = mock_response - tts = OpenAITTS() - audio = tts.run("Hello world") - assert audio == b"chunk1chunk2" - mock_post.assert_called_once_with( - "https://api.openai.com/v1/audio/speech", - headers={"Authorization": f"Bearer {tts.openai_api_key}"}, - json={ - "model": "tts-1-1106", - "input": "Hello world", - "voice": "onyx", - }, - ) - - -@patch("requests.post") -def test_run_empty_task(mock_post): - tts = OpenAITTS() - with pytest.raises(Exception): - tts.run("") - - -@patch("requests.post") -def test_run_very_long_task(mock_post): - tts = OpenAITTS() - with pytest.raises(Exception): - tts.run("A" * 10000) - - -@patch("requests.post") -def test_run_invalid_task(mock_post): - tts = OpenAITTS() - with pytest.raises(Exception): - tts.run(None) - - -@patch("requests.post") -def test_run_custom_model(mock_post): - mock_response = MagicMock() - mock_response.iter_content.return_value = [b"chunk1", b"chunk2"] - mock_post.return_value = mock_response - tts = OpenAITTS("custom_model") - audio = tts.run("Hello world") - assert audio == b"chunk1chunk2" - mock_post.assert_called_once_with( - "https://api.openai.com/v1/audio/speech", - headers={"Authorization": f"Bearer {tts.openai_api_key}"}, - json={ - "model": "custom_model", - "input": "Hello world", - "voice": "onyx", - }, - ) - - -@patch("requests.post") -def test_run_custom_voice(mock_post): - mock_response = MagicMock() - mock_response.iter_content.return_value = [b"chunk1", b"chunk2"] - mock_post.return_value = mock_response - tts = OpenAITTS(voice="custom_voice") - audio = tts.run("Hello world") - assert audio == b"chunk1chunk2" - mock_post.assert_called_once_with( - "https://api.openai.com/v1/audio/speech", - headers={"Authorization": f"Bearer {tts.openai_api_key}"}, - json={ - "model": "tts-1-1106", - "input": "Hello world", - "voice": "custom_voice", - }, - ) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py deleted file mode 100644 index 3e5c937e..00000000 --- a/tests/models/test_qwen.py +++ /dev/null @@ -1,61 +0,0 @@ -from unittest.mock import Mock, patch - -from swarm_models.qwen import QwenVLMultiModal - - -def test_post_init(): - with patch( - "swarms.models.qwen.AutoTokenizer.from_pretrained" - ) as mock_tokenizer, patch( - "swarms.models.qwen.AutoModelForCausalLM.from_pretrained" - ) as mock_model: - mock_tokenizer.return_value = Mock() - mock_model.return_value = Mock() - - model = QwenVLMultiModal() - mock_tokenizer.assert_called_once_with( - model.model_name, trust_remote_code=True - ) - mock_model.assert_called_once_with( - model.model_name, - device_map=model.device, - trust_remote_code=True, - ) - - -def test_run(): - with patch( - "swarms.models.qwen.AutoTokenizer.from_list_format" - ) as mock_format, patch( - "swarms.models.qwen.AutoTokenizer.__call__" - ) as mock_call, patch( - "swarms.models.qwen.AutoModelForCausalLM.generate" - ) as mock_generate, patch( - "swarms.models.qwen.AutoTokenizer.decode" - ) as mock_decode: - mock_format.return_value = Mock() - mock_call.return_value = Mock() - mock_generate.return_value = Mock() - mock_decode.return_value = "response" - - model = QwenVLMultiModal() - response = model.run( - "Hello, how are you?", "https://example.com/image.jpg" - ) - - assert response == "response" - - -def test_chat(): - with patch( - "swarms.models.qwen.AutoModelForCausalLM.chat" - ) as mock_chat: - mock_chat.return_value = ("response", ["history"]) - - model = QwenVLMultiModal() - response, history = model.chat( - "Hello, how are you?", "https://example.com/image.jpg" - ) - - assert response == "response" - assert history == ["history"] diff --git a/tests/models/test_ssd_1b.py b/tests/models/test_ssd_1b.py deleted file mode 100644 index 86a7e94a..00000000 --- a/tests/models/test_ssd_1b.py +++ /dev/null @@ -1,165 +0,0 @@ -import pytest -from PIL import Image - -from swarm_models.ssd_1b import SSD1B - - -# Create fixtures if needed -@pytest.fixture -def ssd1b_model(): - return SSD1B() - - -# Basic tests for model initialization and method call -def test_ssd1b_model_initialization(ssd1b_model): - assert ssd1b_model is not None - - -def test_ssd1b_call(ssd1b_model): - task = "A painting of a dog" - neg_prompt = "ugly, blurry, poor quality" - image_url = ssd1b_model(task, neg_prompt) - assert isinstance(image_url, str) - assert image_url.startswith( - "https://" - ) # Assuming it starts with "https://" - - -# Add more tests for various aspects of the class and methods - - -# Example of a parameterized test for different tasks -@pytest.mark.parametrize( - "task", ["A painting of a cat", "A painting of a tree"] -) -def test_ssd1b_parameterized_task(ssd1b_model, task): - image_url = ssd1b_model(task) - assert isinstance(image_url, str) - assert image_url.startswith( - "https://" - ) # Assuming it starts with "https://" - - -# Example of a test using mocks to isolate units of code -def test_ssd1b_with_mock(ssd1b_model, mocker): - mocker.patch( - "your_module.StableDiffusionXLPipeline" - ) # Mock the pipeline - task = "A painting of a cat" - image_url = ssd1b_model(task) - assert isinstance(image_url, str) - assert image_url.startswith( - "https://" - ) # Assuming it starts with "https://" - - -def test_ssd1b_call_with_cache(ssd1b_model): - task = "A painting of a dog" - neg_prompt = "ugly, blurry, poor quality" - image_url1 = ssd1b_model(task, neg_prompt) - image_url2 = ssd1b_model(task, neg_prompt) # Should use cache - assert image_url1 == image_url2 - - -def test_ssd1b_invalid_task(ssd1b_model): - invalid_task = "" - with pytest.raises(ValueError): - ssd1b_model(invalid_task) - - -def test_ssd1b_failed_api_call(ssd1b_model, mocker): - mocker.patch( - "your_module.StableDiffusionXLPipeline" - ) # Mock the pipeline to raise an exception - task = "A painting of a cat" - with pytest.raises(Exception): - ssd1b_model(task) - - -def test_ssd1b_process_batch_concurrently(ssd1b_model): - tasks = [ - "A painting of a dog", - "A beautiful sunset", - "A portrait of a person", - ] - results = ssd1b_model.process_batch_concurrently(tasks) - assert isinstance(results, list) - assert len(results) == len(tasks) - - -def test_ssd1b_process_empty_batch_concurrently(ssd1b_model): - tasks = [] - results = ssd1b_model.process_batch_concurrently(tasks) - assert isinstance(results, list) - assert len(results) == 0 - - -def test_ssd1b_download_image(ssd1b_model): - task = "A painting of a dog" - neg_prompt = "ugly, blurry, poor quality" - image_url = ssd1b_model(task, neg_prompt) - img = ssd1b_model._download_image(image_url) - assert isinstance(img, Image.Image) - - -def test_ssd1b_generate_uuid(ssd1b_model): - uuid_str = ssd1b_model._generate_uuid() - assert isinstance(uuid_str, str) - assert len(uuid_str) == 36 # UUID format - - -def test_ssd1b_rate_limited_call(ssd1b_model): - task = "A painting of a dog" - image_url = ssd1b_model.rate_limited_call(task) - assert isinstance(image_url, str) - assert image_url.startswith("https://") - - -# Test cases for additional scenarios and behaviors -def test_ssd1b_dashboard_printing(ssd1b_model, capsys): - ssd1b_model.dashboard = True - ssd1b_model.print_dashboard() - captured = capsys.readouterr() - assert "SSD1B Dashboard:" in captured.out - - -def test_ssd1b_generate_image_name(ssd1b_model): - task = "A painting of a dog" - img_name = ssd1b_model._generate_image_name(task) - assert isinstance(img_name, str) - assert len(img_name) > 0 - - -def test_ssd1b_set_width_height(ssd1b_model, mocker): - img = mocker.MagicMock() - width, height = 800, 600 - result = ssd1b_model.set_width_height(img, width, height) - assert result == img.resize.return_value - - -def test_ssd1b_read_img(ssd1b_model, mocker): - img = mocker.MagicMock() - result = ssd1b_model.read_img(img) - assert result == img.open.return_value - - -def test_ssd1b_convert_to_bytesio(ssd1b_model, mocker): - img = mocker.MagicMock() - img_format = "PNG" - result = ssd1b_model.convert_to_bytesio(img, img_format) - assert isinstance(result, bytes) - - -def test_ssd1b_save_image(ssd1b_model, mocker, tmp_path): - img = mocker.MagicMock() - img_name = "test.png" - save_path = tmp_path / img_name - ssd1b_model._download_image(img, img_name, save_path) - assert save_path.exists() - - -def test_ssd1b_repr_str(ssd1b_model): - task = "A painting of a dog" - image_url = ssd1b_model(task) - assert repr(ssd1b_model) == f"SSD1B(image_url={image_url})" - assert str(ssd1b_model) == f"SSD1B(image_url={image_url})" diff --git a/tests/models/test_timm_model.py b/tests/models/test_timm_model.py deleted file mode 100644 index 5fdaac5a..00000000 --- a/tests/models/test_timm_model.py +++ /dev/null @@ -1,90 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest -import torch - -from swarm_models import TimmModel - - -def test_timm_model_init(): - with patch("swarms.models.timm.list_models") as mock_list_models: - model_name = "resnet18" - pretrained = True - in_chans = 3 - timm_model = TimmModel(model_name, pretrained, in_chans) - mock_list_models.assert_called_once() - assert timm_model.model_name == model_name - assert timm_model.pretrained == pretrained - assert timm_model.in_chans == in_chans - assert timm_model.models == mock_list_models.return_value - - -def test_timm_model_call(): - with patch( - "swarms.models.timm.create_model" - ) as mock_create_model: - model_name = "resnet18" - pretrained = True - in_chans = 3 - timm_model = TimmModel(model_name, pretrained, in_chans) - task = torch.rand(1, in_chans, 224, 224) - result = timm_model(task) - mock_create_model.assert_called_once_with( - model_name, pretrained=pretrained, in_chans=in_chans - ) - assert result == mock_create_model.return_value(task) - - -def test_timm_model_list_models(): - with patch("swarms.models.timm.list_models") as mock_list_models: - model_name = "resnet18" - pretrained = True - in_chans = 3 - timm_model = TimmModel(model_name, pretrained, in_chans) - result = timm_model.list_models() - mock_list_models.assert_called_once() - assert result == mock_list_models.return_value - - -def test_get_supported_models(): - model_handler = TimmModel() - supported_models = model_handler._get_supported_models() - assert isinstance(supported_models, list) - assert len(supported_models) > 0 - - -def test_create_model(sample_model_info): - model_handler = TimmModel() - model = model_handler._create_model(sample_model_info) - assert isinstance(model, torch.nn.Module) - - -def test_call(sample_model_info): - model_handler = TimmModel() - input_tensor = torch.randn(1, 3, 224, 224) - output_shape = model_handler.__call__( - sample_model_info, input_tensor - ) - assert isinstance(output_shape, torch.Size) - - -def test_get_supported_models_mock(): - model_handler = TimmModel() - model_handler._get_supported_models = Mock( - return_value=["resnet18", "resnet50"] - ) - supported_models = model_handler._get_supported_models() - assert supported_models == ["resnet18", "resnet50"] - - -def test_create_model_mock(sample_model_info): - model_handler = TimmModel() - model_handler._create_model = Mock(return_value=torch.nn.Module()) - model = model_handler._create_model(sample_model_info) - assert isinstance(model, torch.nn.Module) - - -def test_coverage_report(): - # Install pytest-cov - # Run tests with coverage report - pytest.main(["--cov=my_module", "--cov-report=html"]) diff --git a/tests/models/test_togther.py b/tests/models/test_togther.py deleted file mode 100644 index c7a0421c..00000000 --- a/tests/models/test_togther.py +++ /dev/null @@ -1,145 +0,0 @@ -import logging -from unittest.mock import Mock, patch - -import pytest -import requests - -from swarm_models.together import TogetherLLM - - -@pytest.fixture -def mock_api_key(monkeypatch): - monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key") - - -def test_init_defaults(): - model = TogetherLLM() - assert model.together_api_key == "mocked-api-key" - assert model.logging_enabled is False - assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1" - assert model.max_workers == 10 - assert model.max_tokens == 300 - assert model.api_endpoint == "https://api.together.xyz" - assert model.beautify is False - assert model.streaming_enabled is False - assert model.meta_prompt is False - assert model.system_prompt is None - - -def test_init_custom_params(mock_api_key): - model = TogetherLLM( - together_api_key="custom-api-key", - logging_enabled=True, - model_name="custom-model", - max_workers=5, - max_tokens=500, - api_endpoint="https://custom-api.together.xyz", - beautify=True, - streaming_enabled=True, - meta_prompt="meta-prompt", - system_prompt="system-prompt", - ) - assert model.together_api_key == "custom-api-key" - assert model.logging_enabled is True - assert model.model_name == "custom-model" - assert model.max_workers == 5 - assert model.max_tokens == 500 - assert model.api_endpoint == "https://custom-api.together.xyz" - assert model.beautify is True - assert model.streaming_enabled is True - assert model.meta_prompt == "meta-prompt" - assert model.system_prompt == "system-prompt" - - -@patch("swarms.models.together_model.requests.post") -def test_run_success(mock_post, mock_api_key): - mock_response = Mock() - mock_response.json.return_value = { - "choices": [{"message": {"content": "Generated response"}}] - } - mock_post.return_value = mock_response - - model = TogetherLLM() - task = "What is the color of the object?" - response = model.run(task) - - assert response == "Generated response" - - -@patch("swarms.models.together_model.requests.post") -def test_run_failure(mock_post, mock_api_key): - mock_post.side_effect = requests.exceptions.RequestException( - "Request failed" - ) - - model = TogetherLLM() - task = "What is the color of the object?" - response = model.run(task) - - assert response is None - - -def test_run_with_logging_enabled(caplog, mock_api_key): - model = TogetherLLM(logging_enabled=True) - task = "What is the color of the object?" - - with caplog.at_level(logging.DEBUG): - model.run(task) - - assert "Sending request to" in caplog.text - - -@pytest.mark.parametrize( - "invalid_input", [None, 123, ["list", "of", "items"]] -) -def test_invalid_task_input(invalid_input, mock_api_key): - model = TogetherLLM() - response = model.run(invalid_input) - - assert response is None - - -@patch("swarms.models.together_model.requests.post") -def test_run_streaming_enabled(mock_post, mock_api_key): - mock_response = Mock() - mock_response.json.return_value = { - "choices": [{"message": {"content": "Generated response"}}] - } - mock_post.return_value = mock_response - - model = TogetherLLM(streaming_enabled=True) - task = "What is the color of the object?" - response = model.run(task) - - assert response == "Generated response" - - -@patch("swarms.models.together_model.requests.post") -def test_run_empty_choices(mock_post, mock_api_key): - mock_response = Mock() - mock_response.json.return_value = {"choices": []} - mock_post.return_value = mock_response - - model = TogetherLLM() - task = "What is the color of the object?" - response = model.run(task) - - assert response is None - - -@patch("swarms.models.together_model.requests.post") -def test_run_with_exception(mock_post, mock_api_key): - mock_post.side_effect = Exception("Test exception") - - model = TogetherLLM() - task = "What is the color of the object?" - response = model.run(task) - - assert response is None - - -def test_init_logging_disabled(monkeypatch): - monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key") - model = TogetherLLM() - assert model.logging_enabled is False - assert not model.system_prompt diff --git a/tests/models/test_vilt.py b/tests/models/test_vilt.py deleted file mode 100644 index 8e222637..00000000 --- a/tests/models/test_vilt.py +++ /dev/null @@ -1,110 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from swarm_models.vilt import Image, Vilt, requests - - -# Fixture for Vilt instance -@pytest.fixture -def vilt_instance(): - return Vilt() - - -# 1. Test Initialization -def test_vilt_initialization(vilt_instance): - assert isinstance(vilt_instance, Vilt) - assert vilt_instance.processor is not None - assert vilt_instance.model is not None - - -# 2. Test Model Predictions -@patch.object(requests, "get") -@patch.object(Image, "open") -def test_vilt_prediction( - mock_image_open, mock_requests_get, vilt_instance -): - mock_image = Mock() - mock_image_open.return_value = mock_image - mock_requests_get.return_value.raw = Mock() - - # It's a mock response, so no real answer expected - with pytest.raises( - Exception - ): # Ensure exception is more specific - vilt_instance( - "What is this image", - "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", - ) - - -# 3. Test Exception Handling for network -@patch.object( - requests, - "get", - side_effect=requests.RequestException("Network error"), -) -def test_vilt_network_exception(vilt_instance): - with pytest.raises(requests.RequestException): - vilt_instance( - "What is this image", - "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", - ) - - -# Parameterized test cases for different inputs -@pytest.mark.parametrize( - "text,image_url", - [ - ("What is this?", "http://example.com/image1.jpg"), - ("Who is in the image?", "http://example.com/image2.jpg"), - ( - "Where was this picture taken?", - "http://example.com/image3.jpg", - ), - # ... Add more scenarios - ], -) -def test_vilt_various_inputs(text, image_url, vilt_instance): - with pytest.raises( - Exception - ): # Again, ensure exception is more specific - vilt_instance(text, image_url) - - -# Test with invalid or empty text -@pytest.mark.parametrize( - "text,image_url", - [ - ( - "", - "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", - ), - ( - None, - "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", - ), - ( - " ", - "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80", - ), - # ... Add more scenarios - ], -) -def test_vilt_invalid_text(text, image_url, vilt_instance): - with pytest.raises(ValueError): - vilt_instance(text, image_url) - - -# Test with invalid or empty image_url -@pytest.mark.parametrize( - "text,image_url", - [ - ("What is this?", ""), - ("Who is in the image?", None), - ("Where was this picture taken?", " "), - ], -) -def test_vilt_invalid_image_url(text, image_url, vilt_instance): - with pytest.raises(ValueError): - vilt_instance(text, image_url) diff --git a/tests/models/test_zeroscope.py b/tests/models/test_zeroscope.py deleted file mode 100644 index c8809cd1..00000000 --- a/tests/models/test_zeroscope.py +++ /dev/null @@ -1,122 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from swarm_models.zeroscope import ZeroscopeTTV - - -@patch("swarms.models.zeroscope.DiffusionPipeline") -@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler") -def test_zeroscope_ttv_init(mock_scheduler, mock_pipeline): - zeroscope = ZeroscopeTTV() - mock_pipeline.from_pretrained.assert_called_once() - mock_scheduler.assert_called_once() - assert zeroscope.model_name == "cerspense/zeroscope_v2_576w" - assert zeroscope.chunk_size == 1 - assert zeroscope.dim == 1 - assert zeroscope.num_inference_steps == 40 - assert zeroscope.height == 320 - assert zeroscope.width == 576 - assert zeroscope.num_frames == 36 - - -@patch("swarms.models.zeroscope.DiffusionPipeline") -@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler") -def test_zeroscope_ttv_forward(mock_scheduler, mock_pipeline): - zeroscope = ZeroscopeTTV() - mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) - mock_pipeline_instance.enable_vae_slicing.assert_called_once() - mock_pipeline_instance.enable_forward_chunking.assert_called_once_with( - chunk_size=1, dim=1 - ) - result = zeroscope.forward("Test task") - assert result == "Generated frames" - mock_pipeline_instance.assert_called_once_with( - "Test task", - num_inference_steps=40, - height=320, - width=576, - num_frames=36, - ) - - -@patch("swarms.models.zeroscope.DiffusionPipeline") -@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler") -def test_zeroscope_ttv_forward_error(mock_scheduler, mock_pipeline): - zeroscope = ZeroscopeTTV() - mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) - mock_pipeline_instance.side_effect = Exception("Test error") - with pytest.raises(Exception, match="Test error"): - zeroscope.forward("Test task") - - -@patch("swarms.models.zeroscope.DiffusionPipeline") -@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler") -def test_zeroscope_ttv_call(mock_scheduler, mock_pipeline): - zeroscope = ZeroscopeTTV() - mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) - result = zeroscope.__call__("Test task") - assert result == "Generated frames" - mock_pipeline_instance.assert_called_once_with( - "Test task", - num_inference_steps=40, - height=320, - width=576, - num_frames=36, - ) - - -@patch("swarms.models.zeroscope.DiffusionPipeline") -@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler") -def test_zeroscope_ttv_call_error(mock_scheduler, mock_pipeline): - zeroscope = ZeroscopeTTV() - mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) - mock_pipeline_instance.side_effect = Exception("Test error") - with pytest.raises(Exception, match="Test error"): - zeroscope.__call__("Test task") - - -@patch("swarms.models.zeroscope.DiffusionPipeline") -@patch("swarms.models.zeroscope.DPMSolverMultistepScheduler") -def test_zeroscope_ttv_save_video_path(mock_scheduler, mock_pipeline): - zeroscope = ZeroscopeTTV() - mock_pipeline_instance = MagicMock() - mock_pipeline.from_pretrained.return_value = ( - mock_pipeline_instance - ) - mock_pipeline_instance.return_value = MagicMock( - frames="Generated frames" - ) - result = zeroscope.save_video_path("Test video path") - assert result == "Test video path" - mock_pipeline_instance.assert_called_once_with( - "Test video path", - num_inference_steps=40, - height=320, - width=576, - num_frames=36, - )