chore: sync

pull/307/head
Zack 1 year ago
parent 0c6bc9b281
commit 80f288c832

@ -232,8 +232,10 @@ class SSD1B:
except Exception as error:
print(
colored(
f"Error running SSD1B: {error} try optimizing"
" your api key and or try again",
(
f"Error running SSD1B: {error} try optimizing"
" your api key and or try again"
),
"red",
)
)

@ -0,0 +1,65 @@
import pytest
import os
import shutil
from swarms.agents.idea_to_image_agent import Idea2Image
openai_key = os.getenv("OPENAI_API_KEY")
dalle_cookie = os.getenv("BING_COOKIE")
# Constants for testing
TEST_PROMPT = "Happy fish."
TEST_OUTPUT_FOLDER = "test_images/"
OPENAI_API_KEY = openai_key
DALLE_COOKIE = dalle_cookie
@pytest.fixture(scope="module")
def idea2image_instance():
# Create an instance of the Idea2Image class
idea2image = Idea2Image(
image=TEST_PROMPT,
openai_api_key=OPENAI_API_KEY,
cookie=DALLE_COOKIE,
output_folder=TEST_OUTPUT_FOLDER,
)
yield idea2image
# Clean up the test output folder after testing
if os.path.exists(TEST_OUTPUT_FOLDER):
shutil.rmtree(TEST_OUTPUT_FOLDER)
def test_idea2image_instance(idea2image_instance):
# Check if the instance is created successfully
assert isinstance(idea2image_instance, Idea2Image)
def test_llm_prompt(idea2image_instance):
# Test the llm_prompt method
prompt = idea2image_instance.llm_prompt()
assert isinstance(prompt, str)
def test_generate_image(idea2image_instance):
# Test the generate_image method
idea2image_instance.generate_image()
# Check if the output folder is created
assert os.path.exists(TEST_OUTPUT_FOLDER)
# Check if files are downloaded (assuming DALLE-3 responds with URLs)
files = os.listdir(TEST_OUTPUT_FOLDER)
assert len(files) > 0
def test_invalid_openai_api_key():
# Test with an invalid OpenAI API key
with pytest.raises(Exception) as exc_info:
Idea2Image(
image=TEST_PROMPT,
openai_api_key="invalid_api_key",
cookie=DALLE_COOKIE,
output_folder=TEST_OUTPUT_FOLDER,
)
assert "Failed to initialize OpenAIChat" in str(exc_info.value)
if __name__ == "__main__":
pytest.main()

@ -0,0 +1,33 @@
import pytest
from unittest.mock import patch
from swarms.models.pegasus import PegasusEmbedding
def test_init():
with patch("your_module.Pegasus") as MockPegasus:
embedder = PegasusEmbedding(modality="text")
MockPegasus.assert_called_once()
assert embedder.pegasus == MockPegasus.return_value
def test_init_exception():
with patch("your_module.Pegasus", side_effect=Exception("Test exception")):
with pytest.raises(Exception) as e:
PegasusEmbedding(modality="text")
assert str(e.value) == "Test exception"
def test_embed():
with patch("your_module.Pegasus") as MockPegasus:
embedder = PegasusEmbedding(modality="text")
embedder.embed("Hello world")
MockPegasus.return_value.embed.assert_called_once()
def test_embed_exception():
with patch("your_module.Pegasus") as MockPegasus:
MockPegasus.return_value.embed.side_effect = Exception("Test exception")
embedder = PegasusEmbedding(modality="text")
with pytest.raises(Exception) as e:
embedder.embed("Hello world")
assert str(e.value) == "Test exception"

@ -0,0 +1,52 @@
import pytest
from unittest.mock import Mock
from swarms.memory.ocean import OceanDB
@pytest.fixture
def mock_ocean_client():
return Mock()
@pytest.fixture
def mock_collection():
return Mock()
@pytest.fixture
def ocean_db(mock_ocean_client):
OceanDB.client = mock_ocean_client
return OceanDB()
def test_init(ocean_db, mock_ocean_client):
mock_ocean_client.heartbeat.return_value = "OK"
assert ocean_db.client.heartbeat() == "OK"
def test_create_collection(ocean_db, mock_ocean_client, mock_collection):
mock_ocean_client.create_collection.return_value = mock_collection
collection = ocean_db.create_collection("test", "text")
assert collection == mock_collection
def test_append_document(ocean_db, mock_collection):
document = "test_document"
id = "test_id"
ocean_db.append_document(mock_collection, document, id)
mock_collection.add.assert_called_once_with(documents=[document], ids=[id])
def test_add_documents(ocean_db, mock_collection):
documents = ["test_document1", "test_document2"]
ids = ["test_id1", "test_id2"]
ocean_db.add_documents(mock_collection, documents, ids)
mock_collection.add.assert_called_once_with(documents=documents, ids=ids)
def test_query(ocean_db, mock_collection):
query_texts = ["test_query"]
n_results = 10
mock_collection.query.return_value = "query_result"
result = ocean_db.query(mock_collection, query_texts, n_results)
assert result == "query_result"

@ -0,0 +1,93 @@
import unittest
from unittest.mock import patch
from RevChatGPTModelv4 import RevChatGPTModelv4
class TestRevChatGPT(unittest.TestCase):
def setUp(self):
self.access_token = "123"
self.model = RevChatGPTModelv4(access_token=self.access_token)
def test_run(self):
prompt = "What is the capital of France?"
self.model.start_time = 10
self.model.end_time = 20
response = self.model.run(prompt)
self.assertEqual(response, "The capital of France is Paris.")
self.assertEqual(self.model.start_time, 10)
self.assertEqual(self.model.end_time, 20)
def test_generate_summary(self):
text = "Hello world. This is some text. It has multiple sentences."
summary = self.model.generate_summary(text)
self.assertEqual(summary, "")
@patch("RevChatGPTModelv4.Chatbot.install_plugin")
def test_enable_plugin(self, mock_install_plugin):
plugin_id = "plugin123"
self.model.enable_plugin(plugin_id)
mock_install_plugin.assert_called_with(plugin_id=plugin_id)
@patch("RevChatGPTModelv4.Chatbot.get_plugins")
def test_list_plugins(self, mock_get_plugins):
mock_get_plugins.return_value = [{"id": "123", "name": "Test Plugin"}]
plugins = self.model.list_plugins()
self.assertEqual(len(plugins), 1)
self.assertEqual(plugins[0]["id"], "123")
self.assertEqual(plugins[0]["name"], "Test Plugin")
@patch("RevChatGPTModelv4.Chatbot.get_conversations")
def test_get_conversations(self, mock_get_conversations):
self.model.chatbot.get_conversations()
mock_get_conversations.assert_called()
@patch("RevChatGPTModelv4.Chatbot.get_msg_history")
def test_get_msg_history(self, mock_get_msg_history):
convo_id = "123"
self.model.chatbot.get_msg_history(convo_id)
mock_get_msg_history.assert_called_with(convo_id)
@patch("RevChatGPTModelv4.Chatbot.share_conversation")
def test_share_conversation(self, mock_share_conversation):
self.model.chatbot.share_conversation()
mock_share_conversation.assert_called()
@patch("RevChatGPTModelv4.Chatbot.gen_title")
def test_gen_title(self, mock_gen_title):
convo_id = "123"
message_id = "456"
self.model.chatbot.gen_title(convo_id, message_id)
mock_gen_title.assert_called_with(convo_id, message_id)
@patch("RevChatGPTModelv4.Chatbot.change_title")
def test_change_title(self, mock_change_title):
convo_id = "123"
title = "New Title"
self.model.chatbot.change_title(convo_id, title)
mock_change_title.assert_called_with(convo_id, title)
@patch("RevChatGPTModelv4.Chatbot.delete_conversation")
def test_delete_conversation(self, mock_delete_conversation):
convo_id = "123"
self.model.chatbot.delete_conversation(convo_id)
mock_delete_conversation.assert_called_with(convo_id)
@patch("RevChatGPTModelv4.Chatbot.clear_conversations")
def test_clear_conversations(self, mock_clear_conversations):
self.model.chatbot.clear_conversations()
mock_clear_conversations.assert_called()
@patch("RevChatGPTModelv4.Chatbot.rollback_conversation")
def test_rollback_conversation(self, mock_rollback_conversation):
num = 2
self.model.chatbot.rollback_conversation(num)
mock_rollback_conversation.assert_called_with(num)
@patch("RevChatGPTModelv4.Chatbot.reset_chat")
def test_reset_chat(self, mock_reset_chat):
self.model.chatbot.reset_chat()
mock_reset_chat.assert_called()
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

@ -0,0 +1,69 @@
from unittest.mock import patch, MagicMock
from swarms.structs.nonlinear_workflow import NonLinearWorkflow, Task
class MockTask(Task):
def can_execute(self):
return True
def execute(self):
return "Task executed"
def test_nonlinearworkflow_initialization():
agents = MagicMock()
iters_per_task = MagicMock()
workflow = NonLinearWorkflow(agents, iters_per_task)
assert isinstance(workflow, NonLinearWorkflow)
assert workflow.agents == agents
assert workflow.tasks == []
def test_nonlinearworkflow_add():
agents = MagicMock()
iters_per_task = MagicMock()
workflow = NonLinearWorkflow(agents, iters_per_task)
task = MockTask("task1")
workflow.add(task)
assert workflow.tasks == [task]
@patch("your_module.NonLinearWorkflow.is_finished")
@patch("your_module.NonLinearWorkflow.output_tasks")
def test_nonlinearworkflow_run(mock_output_tasks, mock_is_finished):
agents = MagicMock()
iters_per_task = MagicMock()
workflow = NonLinearWorkflow(agents, iters_per_task)
task = MockTask("task1")
workflow.add(task)
mock_is_finished.return_value = False
mock_output_tasks.return_value = [task]
workflow.run()
assert mock_output_tasks.called
def test_nonlinearworkflow_output_tasks():
agents = MagicMock()
iters_per_task = MagicMock()
workflow = NonLinearWorkflow(agents, iters_per_task)
task = MockTask("task1")
workflow.add(task)
assert workflow.output_tasks() == [task]
def test_nonlinearworkflow_to_graph():
agents = MagicMock()
iters_per_task = MagicMock()
workflow = NonLinearWorkflow(agents, iters_per_task)
task = MockTask("task1")
workflow.add(task)
assert workflow.to_graph() == {"task1": set()}
def test_nonlinearworkflow_order_tasks():
agents = MagicMock()
iters_per_task = MagicMock()
workflow = NonLinearWorkflow(agents, iters_per_task)
task = MockTask("task1")
workflow.add(task)
assert workflow.order_tasks() == [task]

@ -0,0 +1,69 @@
from unittest.mock import patch, MagicMock
from swarms.structs.workflow import Workflow
def test_workflow_initialization():
agent = MagicMock()
workflow = Workflow(agent)
assert isinstance(workflow, Workflow)
assert workflow.agent == agent
assert workflow.tasks == []
assert workflow.parallel is False
def test_workflow_add():
agent = MagicMock()
workflow = Workflow(agent)
task = workflow.add("What's the weather in miami")
assert isinstance(task, Workflow.Task)
assert task.task == "What's the weather in miami"
assert task.parents == []
assert task.children == []
assert task.output is None
assert task.structure == workflow
def test_workflow_first_task():
agent = MagicMock()
workflow = Workflow(agent)
assert workflow.first_task() is None
workflow.add("What's the weather in miami")
assert workflow.first_task().task == "What's the weather in miami"
def test_workflow_last_task():
agent = MagicMock()
workflow = Workflow(agent)
assert workflow.last_task() is None
workflow.add("What's the weather in miami")
assert workflow.last_task().task == "What's the weather in miami"
@patch("your_module.Workflow.__run_from_task")
def test_workflow_run(mock_run_from_task):
agent = MagicMock()
workflow = Workflow(agent)
workflow.add("What's the weather in miami")
workflow.run()
mock_run_from_task.assert_called_once()
def test_workflow_context():
agent = MagicMock()
workflow = Workflow(agent)
task = workflow.add("What's the weather in miami")
assert workflow.context(task) == {
"parent_output": None,
"parent": None,
"child": None,
}
@patch("your_module.Workflow.Task.execute")
def test_workflow___run_from_task(mock_execute):
agent = MagicMock()
workflow = Workflow(agent)
task = workflow.add("What's the weather in miami")
mock_execute.return_value = "Sunny"
workflow.__run_from_task(task)
mock_execute.assert_called_once()

@ -0,0 +1,22 @@
from unittest.mock import patch
from swarms.swarms.dialogue_simulator import DialogueSimulator, Worker
def test_dialoguesimulator_initialization():
dialoguesimulator = DialogueSimulator(agents=[Worker] * 5)
assert isinstance(dialoguesimulator, DialogueSimulator)
assert len(dialoguesimulator.agents) == 5
@patch("swarms.workers.worker.Worker.run")
def test_dialoguesimulator_run(mock_run):
dialoguesimulator = DialogueSimulator(agents=[Worker] * 5)
dialoguesimulator.run(max_iters=5, name="Agent 1", message="Hello, world!")
assert mock_run.call_count == 30
@patch("swarms.workers.worker.Worker.run")
def test_dialoguesimulator_run_without_name_and_message(mock_run):
dialoguesimulator = DialogueSimulator(agents=[Worker] * 5)
dialoguesimulator.run(max_iters=5)
assert mock_run.call_count == 25

@ -0,0 +1,68 @@
import pytest
from unittest.mock import Mock
from swarms.swarms.orchestrate import Orchestrator
@pytest.fixture
def mock_agent():
return Mock()
@pytest.fixture
def mock_task():
return {"task_id": 1, "task_data": "data"}
@pytest.fixture
def mock_vector_db():
return Mock()
@pytest.fixture
def orchestrator(mock_agent, mock_vector_db):
agent_list = [mock_agent for _ in range(5)]
task_queue = []
return Orchestrator(mock_agent, agent_list, task_queue, mock_vector_db)
def test_assign_task(orchestrator, mock_agent, mock_task, mock_vector_db):
orchestrator.task_queue.append(mock_task)
orchestrator.assign_task(0, mock_task)
mock_agent.process_task.assert_called_once()
mock_vector_db.add_documents.assert_called_once()
def test_retrieve_results(orchestrator, mock_vector_db):
mock_vector_db.query.return_value = "expected_result"
assert orchestrator.retrieve_results(0) == "expected_result"
def test_update_vector_db(orchestrator, mock_vector_db):
data = {"vector": [0.1, 0.2, 0.3], "task_id": 1}
orchestrator.update_vector_db(data)
mock_vector_db.add_documents.assert_called_once_with(
[data["vector"]], [str(data["task_id"])]
)
def test_get_vector_db(orchestrator, mock_vector_db):
assert orchestrator.get_vector_db() == mock_vector_db
def test_append_to_db(orchestrator, mock_vector_db):
collection = "test_collection"
result = "test_result"
orchestrator.append_to_db(collection, result)
mock_vector_db.append_document.assert_called_once_with(
collection, result, id=str(id(result))
)
def test_run(orchestrator, mock_agent, mock_vector_db):
objective = "test_objective"
collection = "test_collection"
orchestrator.run(objective, collection)
mock_agent.process_task.assert_called()
mock_vector_db.append_document.assert_called()

@ -0,0 +1,51 @@
from unittest.mock import patch
from swarms.swarms.simple_swarm import SimpleSwarm
def test_simpleswarm_initialization():
simpleswarm = SimpleSwarm(
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
)
assert isinstance(simpleswarm, SimpleSwarm)
assert len(simpleswarm.workers) == 5
assert simpleswarm.task_queue.qsize() == 0
assert simpleswarm.priority_queue.qsize() == 0
def test_simpleswarm_distribute():
simpleswarm = SimpleSwarm(
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
)
simpleswarm.distribute("task1")
assert simpleswarm.task_queue.qsize() == 1
simpleswarm.distribute("task2", priority=1)
assert simpleswarm.priority_queue.qsize() == 1
@patch("swarms.workers.worker.Worker.run")
def test_simpleswarm_process_task(mock_run):
simpleswarm = SimpleSwarm(
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
)
simpleswarm._process_task("task1")
assert mock_run.call_count == 5
def test_simpleswarm_run():
simpleswarm = SimpleSwarm(
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
)
simpleswarm.distribute("task1")
simpleswarm.distribute("task2", priority=1)
results = simpleswarm.run()
assert len(results) == 2
@patch("swarms.workers.Worker.run")
def test_simpleswarm_run_old(mock_run):
simpleswarm = SimpleSwarm(
num_workers=5, openai_api_key="api_key", ai_name="ai_name"
)
results = simpleswarm.run_old("task1")
assert len(results) == 5
assert mock_run.call_count == 5
Loading…
Cancel
Save