parent
0c6bc9b281
commit
80f288c832
@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
import os
|
||||
import shutil
|
||||
from swarms.agents.idea_to_image_agent import Idea2Image
|
||||
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
dalle_cookie = os.getenv("BING_COOKIE")
|
||||
|
||||
# Constants for testing
|
||||
TEST_PROMPT = "Happy fish."
|
||||
TEST_OUTPUT_FOLDER = "test_images/"
|
||||
OPENAI_API_KEY = openai_key
|
||||
DALLE_COOKIE = dalle_cookie
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def idea2image_instance():
|
||||
# Create an instance of the Idea2Image class
|
||||
idea2image = Idea2Image(
|
||||
image=TEST_PROMPT,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
cookie=DALLE_COOKIE,
|
||||
output_folder=TEST_OUTPUT_FOLDER,
|
||||
)
|
||||
yield idea2image
|
||||
# Clean up the test output folder after testing
|
||||
if os.path.exists(TEST_OUTPUT_FOLDER):
|
||||
shutil.rmtree(TEST_OUTPUT_FOLDER)
|
||||
|
||||
|
||||
def test_idea2image_instance(idea2image_instance):
|
||||
# Check if the instance is created successfully
|
||||
assert isinstance(idea2image_instance, Idea2Image)
|
||||
|
||||
|
||||
def test_llm_prompt(idea2image_instance):
|
||||
# Test the llm_prompt method
|
||||
prompt = idea2image_instance.llm_prompt()
|
||||
assert isinstance(prompt, str)
|
||||
|
||||
|
||||
def test_generate_image(idea2image_instance):
|
||||
# Test the generate_image method
|
||||
idea2image_instance.generate_image()
|
||||
# Check if the output folder is created
|
||||
assert os.path.exists(TEST_OUTPUT_FOLDER)
|
||||
# Check if files are downloaded (assuming DALLE-3 responds with URLs)
|
||||
files = os.listdir(TEST_OUTPUT_FOLDER)
|
||||
assert len(files) > 0
|
||||
|
||||
|
||||
def test_invalid_openai_api_key():
|
||||
# Test with an invalid OpenAI API key
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Idea2Image(
|
||||
image=TEST_PROMPT,
|
||||
openai_api_key="invalid_api_key",
|
||||
cookie=DALLE_COOKIE,
|
||||
output_folder=TEST_OUTPUT_FOLDER,
|
||||
)
|
||||
assert "Failed to initialize OpenAIChat" in str(exc_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from swarms.models.pegasus import PegasusEmbedding
|
||||
|
||||
|
||||
def test_init():
|
||||
with patch("your_module.Pegasus") as MockPegasus:
|
||||
embedder = PegasusEmbedding(modality="text")
|
||||
MockPegasus.assert_called_once()
|
||||
assert embedder.pegasus == MockPegasus.return_value
|
||||
|
||||
|
||||
def test_init_exception():
|
||||
with patch("your_module.Pegasus", side_effect=Exception("Test exception")):
|
||||
with pytest.raises(Exception) as e:
|
||||
PegasusEmbedding(modality="text")
|
||||
assert str(e.value) == "Test exception"
|
||||
|
||||
|
||||
def test_embed():
|
||||
with patch("your_module.Pegasus") as MockPegasus:
|
||||
embedder = PegasusEmbedding(modality="text")
|
||||
embedder.embed("Hello world")
|
||||
MockPegasus.return_value.embed.assert_called_once()
|
||||
|
||||
|
||||
def test_embed_exception():
|
||||
with patch("your_module.Pegasus") as MockPegasus:
|
||||
MockPegasus.return_value.embed.side_effect = Exception("Test exception")
|
||||
embedder = PegasusEmbedding(modality="text")
|
||||
with pytest.raises(Exception) as e:
|
||||
embedder.embed("Hello world")
|
||||
assert str(e.value) == "Test exception"
|
@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from swarms.memory.ocean import OceanDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ocean_client():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_collection():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ocean_db(mock_ocean_client):
|
||||
OceanDB.client = mock_ocean_client
|
||||
return OceanDB()
|
||||
|
||||
|
||||
def test_init(ocean_db, mock_ocean_client):
|
||||
mock_ocean_client.heartbeat.return_value = "OK"
|
||||
assert ocean_db.client.heartbeat() == "OK"
|
||||
|
||||
|
||||
def test_create_collection(ocean_db, mock_ocean_client, mock_collection):
|
||||
mock_ocean_client.create_collection.return_value = mock_collection
|
||||
collection = ocean_db.create_collection("test", "text")
|
||||
assert collection == mock_collection
|
||||
|
||||
|
||||
def test_append_document(ocean_db, mock_collection):
|
||||
document = "test_document"
|
||||
id = "test_id"
|
||||
ocean_db.append_document(mock_collection, document, id)
|
||||
mock_collection.add.assert_called_once_with(documents=[document], ids=[id])
|
||||
|
||||
|
||||
def test_add_documents(ocean_db, mock_collection):
|
||||
documents = ["test_document1", "test_document2"]
|
||||
ids = ["test_id1", "test_id2"]
|
||||
ocean_db.add_documents(mock_collection, documents, ids)
|
||||
mock_collection.add.assert_called_once_with(documents=documents, ids=ids)
|
||||
|
||||
|
||||
def test_query(ocean_db, mock_collection):
|
||||
query_texts = ["test_query"]
|
||||
n_results = 10
|
||||
mock_collection.query.return_value = "query_result"
|
||||
result = ocean_db.query(mock_collection, query_texts, n_results)
|
||||
assert result == "query_result"
|
@ -0,0 +1,93 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from RevChatGPTModelv4 import RevChatGPTModelv4
|
||||
|
||||
|
||||
class TestRevChatGPT(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.access_token = "123"
|
||||
self.model = RevChatGPTModelv4(access_token=self.access_token)
|
||||
|
||||
def test_run(self):
|
||||
prompt = "What is the capital of France?"
|
||||
self.model.start_time = 10
|
||||
self.model.end_time = 20
|
||||
response = self.model.run(prompt)
|
||||
self.assertEqual(response, "The capital of France is Paris.")
|
||||
self.assertEqual(self.model.start_time, 10)
|
||||
self.assertEqual(self.model.end_time, 20)
|
||||
|
||||
def test_generate_summary(self):
|
||||
text = "Hello world. This is some text. It has multiple sentences."
|
||||
summary = self.model.generate_summary(text)
|
||||
self.assertEqual(summary, "")
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.install_plugin")
|
||||
def test_enable_plugin(self, mock_install_plugin):
|
||||
plugin_id = "plugin123"
|
||||
self.model.enable_plugin(plugin_id)
|
||||
mock_install_plugin.assert_called_with(plugin_id=plugin_id)
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.get_plugins")
|
||||
def test_list_plugins(self, mock_get_plugins):
|
||||
mock_get_plugins.return_value = [{"id": "123", "name": "Test Plugin"}]
|
||||
plugins = self.model.list_plugins()
|
||||
self.assertEqual(len(plugins), 1)
|
||||
self.assertEqual(plugins[0]["id"], "123")
|
||||
self.assertEqual(plugins[0]["name"], "Test Plugin")
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.get_conversations")
|
||||
def test_get_conversations(self, mock_get_conversations):
|
||||
self.model.chatbot.get_conversations()
|
||||
mock_get_conversations.assert_called()
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.get_msg_history")
|
||||
def test_get_msg_history(self, mock_get_msg_history):
|
||||
convo_id = "123"
|
||||
self.model.chatbot.get_msg_history(convo_id)
|
||||
mock_get_msg_history.assert_called_with(convo_id)
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.share_conversation")
|
||||
def test_share_conversation(self, mock_share_conversation):
|
||||
self.model.chatbot.share_conversation()
|
||||
mock_share_conversation.assert_called()
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.gen_title")
|
||||
def test_gen_title(self, mock_gen_title):
|
||||
convo_id = "123"
|
||||
message_id = "456"
|
||||
self.model.chatbot.gen_title(convo_id, message_id)
|
||||
mock_gen_title.assert_called_with(convo_id, message_id)
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.change_title")
|
||||
def test_change_title(self, mock_change_title):
|
||||
convo_id = "123"
|
||||
title = "New Title"
|
||||
self.model.chatbot.change_title(convo_id, title)
|
||||
mock_change_title.assert_called_with(convo_id, title)
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.delete_conversation")
|
||||
def test_delete_conversation(self, mock_delete_conversation):
|
||||
convo_id = "123"
|
||||
self.model.chatbot.delete_conversation(convo_id)
|
||||
mock_delete_conversation.assert_called_with(convo_id)
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.clear_conversations")
|
||||
def test_clear_conversations(self, mock_clear_conversations):
|
||||
self.model.chatbot.clear_conversations()
|
||||
mock_clear_conversations.assert_called()
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.rollback_conversation")
|
||||
def test_rollback_conversation(self, mock_rollback_conversation):
|
||||
num = 2
|
||||
self.model.chatbot.rollback_conversation(num)
|
||||
mock_rollback_conversation.assert_called_with(num)
|
||||
|
||||
@patch("RevChatGPTModelv4.Chatbot.reset_chat")
|
||||
def test_reset_chat(self, mock_reset_chat):
|
||||
self.model.chatbot.reset_chat()
|
||||
mock_reset_chat.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,69 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from swarms.structs.nonlinear_workflow import NonLinearWorkflow, Task
|
||||
|
||||
|
||||
class MockTask(Task):
|
||||
def can_execute(self):
|
||||
return True
|
||||
|
||||
def execute(self):
|
||||
return "Task executed"
|
||||
|
||||
|
||||
def test_nonlinearworkflow_initialization():
|
||||
agents = MagicMock()
|
||||
iters_per_task = MagicMock()
|
||||
workflow = NonLinearWorkflow(agents, iters_per_task)
|
||||
assert isinstance(workflow, NonLinearWorkflow)
|
||||
assert workflow.agents == agents
|
||||
assert workflow.tasks == []
|
||||
|
||||
|
||||
def test_nonlinearworkflow_add():
|
||||
agents = MagicMock()
|
||||
iters_per_task = MagicMock()
|
||||
workflow = NonLinearWorkflow(agents, iters_per_task)
|
||||
task = MockTask("task1")
|
||||
workflow.add(task)
|
||||
assert workflow.tasks == [task]
|
||||
|
||||
|
||||
@patch("your_module.NonLinearWorkflow.is_finished")
|
||||
@patch("your_module.NonLinearWorkflow.output_tasks")
|
||||
def test_nonlinearworkflow_run(mock_output_tasks, mock_is_finished):
|
||||
agents = MagicMock()
|
||||
iters_per_task = MagicMock()
|
||||
workflow = NonLinearWorkflow(agents, iters_per_task)
|
||||
task = MockTask("task1")
|
||||
workflow.add(task)
|
||||
mock_is_finished.return_value = False
|
||||
mock_output_tasks.return_value = [task]
|
||||
workflow.run()
|
||||
assert mock_output_tasks.called
|
||||
|
||||
|
||||
def test_nonlinearworkflow_output_tasks():
|
||||
agents = MagicMock()
|
||||
iters_per_task = MagicMock()
|
||||
workflow = NonLinearWorkflow(agents, iters_per_task)
|
||||
task = MockTask("task1")
|
||||
workflow.add(task)
|
||||
assert workflow.output_tasks() == [task]
|
||||
|
||||
|
||||
def test_nonlinearworkflow_to_graph():
|
||||
agents = MagicMock()
|
||||
iters_per_task = MagicMock()
|
||||
workflow = NonLinearWorkflow(agents, iters_per_task)
|
||||
task = MockTask("task1")
|
||||
workflow.add(task)
|
||||
assert workflow.to_graph() == {"task1": set()}
|
||||
|
||||
|
||||
def test_nonlinearworkflow_order_tasks():
|
||||
agents = MagicMock()
|
||||
iters_per_task = MagicMock()
|
||||
workflow = NonLinearWorkflow(agents, iters_per_task)
|
||||
task = MockTask("task1")
|
||||
workflow.add(task)
|
||||
assert workflow.order_tasks() == [task]
|
@ -0,0 +1,69 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from swarms.structs.workflow import Workflow
|
||||
|
||||
|
||||
def test_workflow_initialization():
|
||||
agent = MagicMock()
|
||||
workflow = Workflow(agent)
|
||||
assert isinstance(workflow, Workflow)
|
||||
assert workflow.agent == agent
|
||||
assert workflow.tasks == []
|
||||
assert workflow.parallel is False
|
||||
|
||||
|
||||
def test_workflow_add():
|
||||
agent = MagicMock()
|
||||
workflow = Workflow(agent)
|
||||
task = workflow.add("What's the weather in miami")
|
||||
assert isinstance(task, Workflow.Task)
|
||||
assert task.task == "What's the weather in miami"
|
||||
assert task.parents == []
|
||||
assert task.children == []
|
||||
assert task.output is None
|
||||
assert task.structure == workflow
|
||||
|
||||
|
||||
def test_workflow_first_task():
|
||||
agent = MagicMock()
|
||||
workflow = Workflow(agent)
|
||||
assert workflow.first_task() is None
|
||||
workflow.add("What's the weather in miami")
|
||||
assert workflow.first_task().task == "What's the weather in miami"
|
||||
|
||||
|
||||
def test_workflow_last_task():
|
||||
agent = MagicMock()
|
||||
workflow = Workflow(agent)
|
||||
assert workflow.last_task() is None
|
||||
workflow.add("What's the weather in miami")
|
||||
assert workflow.last_task().task == "What's the weather in miami"
|
||||
|
||||
|
||||
@patch("your_module.Workflow.__run_from_task")
|
||||
def test_workflow_run(mock_run_from_task):
|
||||
agent = MagicMock()
|
||||
workflow = Workflow(agent)
|
||||
workflow.add("What's the weather in miami")
|
||||
workflow.run()
|
||||
mock_run_from_task.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_context():
|
||||
agent = MagicMock()
|
||||
workflow = Workflow(agent)
|
||||
task = workflow.add("What's the weather in miami")
|
||||
assert workflow.context(task) == {
|
||||
"parent_output": None,
|
||||
"parent": None,
|
||||
"child": None,
|
||||
}
|
||||
|
||||
|
||||
@patch("your_module.Workflow.Task.execute")
|
||||
def test_workflow___run_from_task(mock_execute):
|
||||
agent = MagicMock()
|
||||
workflow = Workflow(agent)
|
||||
task = workflow.add("What's the weather in miami")
|
||||
mock_execute.return_value = "Sunny"
|
||||
workflow.__run_from_task(task)
|
||||
mock_execute.assert_called_once()
|
@ -0,0 +1,22 @@
|
||||
from unittest.mock import patch
|
||||
from swarms.swarms.dialogue_simulator import DialogueSimulator, Worker
|
||||
|
||||
|
||||
def test_dialoguesimulator_initialization():
|
||||
dialoguesimulator = DialogueSimulator(agents=[Worker] * 5)
|
||||
assert isinstance(dialoguesimulator, DialogueSimulator)
|
||||
assert len(dialoguesimulator.agents) == 5
|
||||
|
||||
|
||||
@patch("swarms.workers.worker.Worker.run")
|
||||
def test_dialoguesimulator_run(mock_run):
|
||||
dialoguesimulator = DialogueSimulator(agents=[Worker] * 5)
|
||||
dialoguesimulator.run(max_iters=5, name="Agent 1", message="Hello, world!")
|
||||
assert mock_run.call_count == 30
|
||||
|
||||
|
||||
@patch("swarms.workers.worker.Worker.run")
|
||||
def test_dialoguesimulator_run_without_name_and_message(mock_run):
|
||||
dialoguesimulator = DialogueSimulator(agents=[Worker] * 5)
|
||||
dialoguesimulator.run(max_iters=5)
|
||||
assert mock_run.call_count == 25
|
@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from swarms.swarms.orchestrate import Orchestrator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
return {"task_id": 1, "task_data": "data"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def orchestrator(mock_agent, mock_vector_db):
|
||||
agent_list = [mock_agent for _ in range(5)]
|
||||
task_queue = []
|
||||
return Orchestrator(mock_agent, agent_list, task_queue, mock_vector_db)
|
||||
|
||||
|
||||
def test_assign_task(orchestrator, mock_agent, mock_task, mock_vector_db):
|
||||
orchestrator.task_queue.append(mock_task)
|
||||
orchestrator.assign_task(0, mock_task)
|
||||
|
||||
mock_agent.process_task.assert_called_once()
|
||||
mock_vector_db.add_documents.assert_called_once()
|
||||
|
||||
|
||||
def test_retrieve_results(orchestrator, mock_vector_db):
|
||||
mock_vector_db.query.return_value = "expected_result"
|
||||
assert orchestrator.retrieve_results(0) == "expected_result"
|
||||
|
||||
|
||||
def test_update_vector_db(orchestrator, mock_vector_db):
|
||||
data = {"vector": [0.1, 0.2, 0.3], "task_id": 1}
|
||||
orchestrator.update_vector_db(data)
|
||||
mock_vector_db.add_documents.assert_called_once_with(
|
||||
[data["vector"]], [str(data["task_id"])]
|
||||
)
|
||||
|
||||
|
||||
def test_get_vector_db(orchestrator, mock_vector_db):
|
||||
assert orchestrator.get_vector_db() == mock_vector_db
|
||||
|
||||
|
||||
def test_append_to_db(orchestrator, mock_vector_db):
|
||||
collection = "test_collection"
|
||||
result = "test_result"
|
||||
orchestrator.append_to_db(collection, result)
|
||||
mock_vector_db.append_document.assert_called_once_with(
|
||||
collection, result, id=str(id(result))
|
||||
)
|
||||
|
||||
|
||||
def test_run(orchestrator, mock_agent, mock_vector_db):
|
||||
objective = "test_objective"
|
||||
collection = "test_collection"
|
||||
orchestrator.run(objective, collection)
|
||||
|
||||
mock_agent.process_task.assert_called()
|
||||
mock_vector_db.append_document.assert_called()
|
@ -0,0 +1,51 @@
|
||||
from unittest.mock import patch
|
||||
from swarms.swarms.simple_swarm import SimpleSwarm
|
||||
|
||||
|
||||
def test_simpleswarm_initialization():
|
||||
simpleswarm = SimpleSwarm(
|
||||
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
|
||||
)
|
||||
assert isinstance(simpleswarm, SimpleSwarm)
|
||||
assert len(simpleswarm.workers) == 5
|
||||
assert simpleswarm.task_queue.qsize() == 0
|
||||
assert simpleswarm.priority_queue.qsize() == 0
|
||||
|
||||
|
||||
def test_simpleswarm_distribute():
|
||||
simpleswarm = SimpleSwarm(
|
||||
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
|
||||
)
|
||||
simpleswarm.distribute("task1")
|
||||
assert simpleswarm.task_queue.qsize() == 1
|
||||
simpleswarm.distribute("task2", priority=1)
|
||||
assert simpleswarm.priority_queue.qsize() == 1
|
||||
|
||||
|
||||
@patch("swarms.workers.worker.Worker.run")
|
||||
def test_simpleswarm_process_task(mock_run):
|
||||
simpleswarm = SimpleSwarm(
|
||||
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
|
||||
)
|
||||
simpleswarm._process_task("task1")
|
||||
assert mock_run.call_count == 5
|
||||
|
||||
|
||||
def test_simpleswarm_run():
|
||||
simpleswarm = SimpleSwarm(
|
||||
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
|
||||
)
|
||||
simpleswarm.distribute("task1")
|
||||
simpleswarm.distribute("task2", priority=1)
|
||||
results = simpleswarm.run()
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
@patch("swarms.workers.Worker.run")
|
||||
def test_simpleswarm_run_old(mock_run):
|
||||
simpleswarm = SimpleSwarm(
|
||||
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
|
||||
)
|
||||
results = simpleswarm.run_old("task1")
|
||||
assert len(results) == 5
|
||||
assert mock_run.call_count == 5
|
Loading…
Reference in new issue