From 9402dab489aa598d62f2593709457910e0ae4ef7 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 11 Nov 2023 10:33:00 -0500 Subject: [PATCH] tests for groupchats, anthropic --- tests/apps/discord.py | 2 +- tests/models/anthropic.py | 98 +++++++++- tests/models/auto_temp.py | 2 +- tests/models/bingchat.py | 2 +- tests/models/huggingface.py | 2 +- tests/models/kosmos.py | 8 +- tests/models/revgptv1.py | 2 +- tests/models/whisperx.py | 2 +- tests/structs/sequential_workflow.py | 10 +- tests/swarms/groupchat.py | 214 +++++++++++++++++++++ tests/utils/subprocess_code_interpreter.py | 10 +- 11 files changed, 332 insertions(+), 20 deletions(-) diff --git a/tests/apps/discord.py b/tests/apps/discord.py index bc8daa80..60198e40 100644 --- a/tests/apps/discord.py +++ b/tests/apps/discord.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch, Mock, MagicMock +from unittest.mock import patch, Mock from apps.discord import ( Bot, ) # Replace 'Bot' with the name of the file containing your bot's code. diff --git a/tests/models/anthropic.py b/tests/models/anthropic.py index 4dbd365d..feb703a6 100644 --- a/tests/models/anthropic.py +++ b/tests/models/anthropic.py @@ -1,9 +1,20 @@ import os -import pytest from unittest.mock import Mock, patch + +import pytest + from swarms.models.anthropic import Anthropic +# Mock the Anthropic API client for testing +class MockAnthropicClient: + def __init__(self, *args, **kwargs): + pass + + def completions_create(self, prompt, stop_sequences, stream, **kwargs): + return MockAnthropicResponse() + + @pytest.fixture def mock_anthropic_env(): os.environ["ANTHROPIC_API_URL"] = "https://test.anthropic.com" @@ -125,3 +136,88 @@ def test_anthropic_exception_handling( anthropic_instance(task, stop) assert "An error occurred" in str(excinfo.value) + + +class MockAnthropicResponse: + def __init__(self): + self.completion = "Mocked Response from Anthropic" + +def test_anthropic_instance_creation(anthropic_instance): + assert isinstance(anthropic_instance, Anthropic) + +def test_anthropic_call_method(anthropic_instance): + response = anthropic_instance("What is the meaning of life?") + assert response == "Mocked Response from Anthropic" + +def test_anthropic_stream_method(anthropic_instance): + generator = anthropic_instance.stream("Write a story.") + for token in generator: + assert isinstance(token, str) + +def test_anthropic_async_call_method(anthropic_instance): + response = anthropic_instance.async_call("Tell me a joke.") + assert response == "Mocked Response from Anthropic" + +def test_anthropic_async_stream_method(anthropic_instance): + async_generator = anthropic_instance.async_stream("Translate to French.") + for token in async_generator: + assert isinstance(token, str) + +def test_anthropic_get_num_tokens(anthropic_instance): + text = "This is a test sentence." + num_tokens = anthropic_instance.get_num_tokens(text) + assert num_tokens > 0 + +# Add more test cases to cover other functionalities and edge cases of the Anthropic class + + +def test_anthropic_wrap_prompt(anthropic_instance): + prompt = "What is the meaning of life?" + wrapped_prompt = anthropic_instance._wrap_prompt(prompt) + assert wrapped_prompt.startswith(anthropic_instance.HUMAN_PROMPT) + assert wrapped_prompt.endswith(anthropic_instance.AI_PROMPT) + +def test_anthropic_convert_prompt(anthropic_instance): + prompt = "What is the meaning of life?" + converted_prompt = anthropic_instance.convert_prompt(prompt) + assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT) + assert converted_prompt.endswith(anthropic_instance.AI_PROMPT) + +def test_anthropic_call_with_stop(anthropic_instance): + response = anthropic_instance("Translate to French.", stop=["stop1", "stop2"]) + assert response == "Mocked Response from Anthropic" + +def test_anthropic_stream_with_stop(anthropic_instance): + generator = anthropic_instance.stream("Write a story.", stop=["stop1", "stop2"]) + for token in generator: + assert isinstance(token, str) + +def test_anthropic_async_call_with_stop(anthropic_instance): + response = anthropic_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"]) + assert response == "Mocked Response from Anthropic" + +def test_anthropic_async_stream_with_stop(anthropic_instance): + async_generator = anthropic_instance.async_stream("Translate to French.", stop=["stop1", "stop2"]) + for token in async_generator: + assert isinstance(token, str) + +def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance): + anthropic_instance.count_tokens = Mock(return_value=10) + text = "This is a test sentence." + num_tokens = anthropic_instance.get_num_tokens(text) + assert num_tokens == 10 + +def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance): + del anthropic_instance.count_tokens + with pytest.raises(NameError): + text = "This is a test sentence." + anthropic_instance.get_num_tokens(text) + +def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance): + del anthropic_instance.HUMAN_PROMPT + del anthropic_instance.AI_PROMPT + prompt = "What is the meaning of life?" + with pytest.raises(NameError): + anthropic_instance._wrap_prompt(prompt) + + diff --git a/tests/models/auto_temp.py b/tests/models/auto_temp.py index a3461769..bd37e5bb 100644 --- a/tests/models/auto_temp.py +++ b/tests/models/auto_temp.py @@ -1,6 +1,6 @@ import os from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest from dotenv import load_dotenv diff --git a/tests/models/bingchat.py b/tests/models/bingchat.py index 5ed2c6ef..ce3af99d 100644 --- a/tests/models/bingchat.py +++ b/tests/models/bingchat.py @@ -3,7 +3,7 @@ import json import os # Assuming the BingChat class is in a file named "bing_chat.py" -from bing_chat import BingChat, ConversationStyle +from bing_chat import BingChat class TestBingChat(unittest.TestCase): diff --git a/tests/models/huggingface.py b/tests/models/huggingface.py index 71fefa67..9a27054a 100644 --- a/tests/models/huggingface.py +++ b/tests/models/huggingface.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest import torch diff --git a/tests/models/kosmos.py b/tests/models/kosmos.py index cffa41e6..11d224d1 100644 --- a/tests/models/kosmos.py +++ b/tests/models/kosmos.py @@ -22,10 +22,10 @@ def mock_image_request(): # Test utility function def test_is_overlapping(): - assert is_overlapping((1, 1, 3, 3), (2, 2, 4, 4)) == True - assert is_overlapping((1, 1, 2, 2), (3, 3, 4, 4)) == False - assert is_overlapping((0, 0, 1, 1), (1, 1, 2, 2)) == False - assert is_overlapping((0, 0, 2, 2), (1, 1, 2, 2)) == True + assert is_overlapping((1, 1, 3, 3), (2, 2, 4, 4)) is True + assert is_overlapping((1, 1, 2, 2), (3, 3, 4, 4)) is False + assert is_overlapping((0, 0, 1, 1), (1, 1, 2, 2)) is False + assert is_overlapping((0, 0, 2, 2), (1, 1, 2, 2)) is True # Test model initialization diff --git a/tests/models/revgptv1.py b/tests/models/revgptv1.py index 95dbb3c6..12ceeea0 100644 --- a/tests/models/revgptv1.py +++ b/tests/models/revgptv1.py @@ -15,7 +15,7 @@ class TestRevChatGPT(unittest.TestCase): def test_run_time(self): prompt = "Generate a 300 word essay about technology." - response = self.model.run(prompt) + self.model.run(prompt) self.assertLess(self.model.end_time - self.model.start_time, 60) def test_generate_summary(self): diff --git a/tests/models/whisperx.py b/tests/models/whisperx.py index 17a28857..af2fe219 100644 --- a/tests/models/whisperx.py +++ b/tests/models/whisperx.py @@ -1,7 +1,7 @@ import os import subprocess import tempfile -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest import whisperx diff --git a/tests/structs/sequential_workflow.py b/tests/structs/sequential_workflow.py index 7bd3e4a4..103711de 100644 --- a/tests/structs/sequential_workflow.py +++ b/tests/structs/sequential_workflow.py @@ -65,10 +65,10 @@ def test_sequential_workflow_initialization(): assert isinstance(workflow, SequentialWorkflow) assert len(workflow.tasks) == 0 assert workflow.max_loops == 1 - assert workflow.autosave == False + assert workflow.autosave is False assert workflow.saved_state_filepath == "sequential_workflow_state.json" - assert workflow.restore_state_filepath == None - assert workflow.dashboard == False + assert workflow.restore_state_filepath is None + assert workflow.dashboard is False def test_sequential_workflow_add_task(): @@ -87,7 +87,7 @@ def test_sequential_workflow_reset_workflow(): task_flow = MockOpenAIChat() workflow.add(task_description, task_flow) workflow.reset_workflow() - assert workflow.tasks[0].result == None + assert workflow.tasks[0].result is None def test_sequential_workflow_get_task_results(): @@ -330,4 +330,4 @@ def test_real_world_usage_with_environment_variables(): def test_real_world_usage_no_openai_key(): # Ensure that an exception is raised when the OpenAI API key is not set with pytest.raises(ValueError): - llm = OpenAIChat() # API key not provided, should raise an exception + OpenAIChat() # API key not provided, should raise an exception diff --git a/tests/swarms/groupchat.py b/tests/swarms/groupchat.py index e69de29b..68609e31 100644 --- a/tests/swarms/groupchat.py +++ b/tests/swarms/groupchat.py @@ -0,0 +1,214 @@ +import pytest + +from swarms.models import OpenAIChat +from swarms.models.anthropic import Anthropic +from swarms.structs.flow import Flow +from swarms.swarms.flow import GroupChat, GroupChatManager + +llm = OpenAIChat() +llm2 = Anthropic() + +# Mock the OpenAI class for testing +class MockOpenAI: + def __init__(self, *args, **kwargs): + pass + + def generate_reply(self, content): + return {"role": "mocked_agent", "content": "Mocked Reply"} + + +# Create fixtures for agents and a sample message +@pytest.fixture +def agent1(): + return Flow(name="Agent1", llm=llm) + + +@pytest.fixture +def agent2(): + return Flow(name="Agent2", llm=llm2) + + +@pytest.fixture +def sample_message(): + return {"role": "Agent1", "content": "Hello, World!"} + + +# Test the initialization of GroupChat +def test_groupchat_initialization(agent1, agent2): + groupchat = GroupChat(agents=[agent1, agent2]) + assert len(groupchat.agents) == 2 + assert len(groupchat.messages) == 0 + assert groupchat.max_round == 10 + assert groupchat.admin_name == "Admin" + + +# Test resetting the GroupChat +def test_groupchat_reset(agent1, agent2, sample_message): + groupchat = GroupChat(agents=[agent1, agent2]) + groupchat.messages.append(sample_message) + groupchat.reset() + assert len(groupchat.messages) == 0 + + +# Test finding an agent by name +def test_groupchat_find_agent_by_name(agent1, agent2): + groupchat = GroupChat(agents=[agent1, agent2]) + found_agent = groupchat.agent_by_name("Agent1") + assert found_agent == agent1 + + +# Test selecting the next agent +def test_groupchat_select_next_agent(agent1, agent2): + groupchat = GroupChat(agents=[agent1, agent2]) + next_agent = groupchat.next_agent(agent1) + assert next_agent == agent2 + + +# Add more tests for different methods and scenarios as needed + + +# Test the GroupChatManager +def test_groupchat_manager(agent1, agent2): + groupchat = GroupChat(agents=[agent1, agent2]) + selector = agent1 # Assuming agent1 is the selector + manager = GroupChatManager(groupchat, selector) + task = "Task for agent2" + reply = manager(task) + assert reply["role"] == "Agent2" + assert reply["content"] == "Reply from Agent2" + + +# Test selecting the next speaker when there is only one agent +def test_groupchat_select_speaker_single_agent(agent1): + groupchat = GroupChat(agents=[agent1]) + selector = agent1 + manager = GroupChatManager(groupchat, selector) + task = "Task for agent1" + reply = manager(task) + assert reply["role"] == "Agent1" + assert reply["content"] == "Reply from Agent1" + + +# Test selecting the next speaker when GroupChat is underpopulated +def test_groupchat_select_speaker_underpopulated(agent1, agent2): + groupchat = GroupChat(agents=[agent1, agent2]) + selector = agent1 + manager = GroupChatManager(groupchat, selector) + task = "Task for agent1" + reply = manager(task) + assert reply["role"] == "Agent2" + assert reply["content"] == "Reply from Agent2" + + +# Test formatting history +def test_groupchat_format_history(agent1, agent2, sample_message): + groupchat = GroupChat(agents=[agent1, agent2]) + groupchat.messages.append(sample_message) + formatted_history = groupchat.format_history(groupchat.messages) + expected_history = "'Agent1:Hello, World!" + assert formatted_history == expected_history + + +# Test agent names property +def test_groupchat_agent_names(agent1, agent2): + groupchat = GroupChat(agents=[agent1, agent2]) + names = groupchat.agent_names + assert len(names) == 2 + assert "Agent1" in names + assert "Agent2" in names + + +# Test GroupChatManager initialization +def test_groupchat_manager_initialization(agent1, agent2): + groupchat = GroupChat(agents=[agent1, agent2]) + selector = agent1 + manager = GroupChatManager(groupchat, selector) + assert manager.groupchat == groupchat + assert manager.selector == selector + +# Test case to ensure GroupChatManager generates a reply from an agent +def test_groupchat_manager_generate_reply(): + # Create a GroupChat with two agents + agents = [agent1, agent2] + groupchat = GroupChat(agents=agents, messages=[], max_round=10) + + # Mock the OpenAI class and GroupChat selector + mocked_openai = MockOpenAI() + selector = agent1 + + # Initialize GroupChatManager + manager = GroupChatManager( + groupchat=groupchat, selector=selector, openai=mocked_openai + ) + + # Generate a reply + task = "Write me a riddle" + reply = manager(task) + + # Check if a valid reply is generated + assert "role" in reply + assert "content" in reply + assert reply["role"] in groupchat.agent_names + + +# Test case to ensure GroupChat selects the next speaker correctly +def test_groupchat_select_speaker(): + agent3 = Flow(name="agent3", llm=llm) + agents = [agent1, agent2, agent3] + groupchat = GroupChat(agents=agents, messages=[], max_round=10) + + # Initialize GroupChatManager with agent1 as selector + selector = agent1 + manager = GroupChatManager(groupchat=groupchat, selector=selector) + + # Simulate selecting the next speaker + last_speaker = agent1 + next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + + # Ensure the next speaker is agent2 + assert next_speaker == agent2 + + +# Test case to ensure GroupChat handles underpopulated group correctly +def test_groupchat_underpopulated_group(): + agent1 = Flow(name="agent1", llm=llm) + agents = [agent1] + groupchat = GroupChat(agents=agents, messages=[], max_round=10) + + # Initialize GroupChatManager with agent1 as selector + selector = agent1 + manager = GroupChatManager(groupchat=groupchat, selector=selector) + + # Simulate selecting the next speaker in an underpopulated group + last_speaker = agent1 + next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + + # Ensure the next speaker is the same as the last speaker in an underpopulated group + assert next_speaker == last_speaker + + +# Test case to ensure GroupChatManager handles the maximum rounds correctly +def test_groupchat_max_rounds(): + agents = [agent1, agent2] + groupchat = GroupChat(agents=agents, messages=[], max_round=2) + + # Initialize GroupChatManager with agent1 as selector + selector = agent1 + manager = GroupChatManager(groupchat=groupchat, selector=selector) + + # Simulate the conversation with max rounds + last_speaker = agent1 + for _ in range(2): + next_speaker = manager.select_speaker( + last_speaker=last_speaker, selector=selector + ) + last_speaker = next_speaker + + # Try one more round, should stay with the last speaker + next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector) + + # Ensure the next speaker is the same as the last speaker after reaching max rounds + assert next_speaker == last_speaker + + +# Continue adding more test cases as needed to cover various scenarios and functionalities of the code. diff --git a/tests/utils/subprocess_code_interpreter.py b/tests/utils/subprocess_code_interpreter.py index 601f8a09..ab7c748f 100644 --- a/tests/utils/subprocess_code_interpreter.py +++ b/tests/utils/subprocess_code_interpreter.py @@ -1,7 +1,9 @@ -import time +import subprocess import threading +import time + import pytest -import subprocess + from swarms.utils.code_interpreter import BaseCodeInterpreter, SubprocessCodeInterpreter @@ -141,7 +143,7 @@ def test_subprocess_code_interpreter_run_debug_mode( ): subprocess_code_interpreter.debug_mode = True code = 'print("Hello, World!")' - result = list(subprocess_code_interpreter.run(code)) + list(subprocess_code_interpreter.run(code)) captured = capsys.readouterr() assert "Running code:\n" in captured.out assert "Received output line:\n" in captured.out @@ -152,7 +154,7 @@ def test_subprocess_code_interpreter_run_no_debug_mode( ): subprocess_code_interpreter.debug_mode = False code = 'print("Hello, World!")' - result = list(subprocess_code_interpreter.run(code)) + list(subprocess_code_interpreter.run(code)) captured = capsys.readouterr() assert "Running code:\n" not in captured.out assert "Received output line:\n" not in captured.out