parent
069b2aed45
commit
9c0c6c06cc
@ -1,68 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from swarms.swarms.orchestrate import Orchestrator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
return {"task_id": 1, "task_data": "data"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def orchestrator(mock_agent, mock_vector_db):
|
||||
agent_list = [mock_agent for _ in range(5)]
|
||||
task_queue = []
|
||||
return Orchestrator(mock_agent, agent_list, task_queue, mock_vector_db)
|
||||
|
||||
|
||||
def test_assign_task(orchestrator, mock_agent, mock_task, mock_vector_db):
|
||||
orchestrator.task_queue.append(mock_task)
|
||||
orchestrator.assign_task(0, mock_task)
|
||||
|
||||
mock_agent.process_task.assert_called_once()
|
||||
mock_vector_db.add_documents.assert_called_once()
|
||||
|
||||
|
||||
def test_retrieve_results(orchestrator, mock_vector_db):
|
||||
mock_vector_db.query.return_value = "expected_result"
|
||||
assert orchestrator.retrieve_results(0) == "expected_result"
|
||||
|
||||
|
||||
def test_update_vector_db(orchestrator, mock_vector_db):
|
||||
data = {"vector": [0.1, 0.2, 0.3], "task_id": 1}
|
||||
orchestrator.update_vector_db(data)
|
||||
mock_vector_db.add_documents.assert_called_once_with(
|
||||
[data["vector"]], [str(data["task_id"])]
|
||||
)
|
||||
|
||||
|
||||
def test_get_vector_db(orchestrator, mock_vector_db):
|
||||
assert orchestrator.get_vector_db() == mock_vector_db
|
||||
|
||||
|
||||
def test_append_to_db(orchestrator, mock_vector_db):
|
||||
collection = "test_collection"
|
||||
result = "test_result"
|
||||
orchestrator.append_to_db(collection, result)
|
||||
mock_vector_db.append_document.assert_called_once_with(
|
||||
collection, result, id=str(id(result))
|
||||
)
|
||||
|
||||
|
||||
def test_run(orchestrator, mock_agent, mock_vector_db):
|
||||
objective = "test_objective"
|
||||
collection = "test_collection"
|
||||
orchestrator.run(objective, collection)
|
||||
|
||||
mock_agent.process_task.assert_called()
|
||||
mock_vector_db.append_document.assert_called()
|
@ -1,71 +1,68 @@
|
||||
import numpy as np
|
||||
from swarms.swarms.orchestrate import Orchestrator, Worker
|
||||
import chromadb
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from swarms.swarms.orchestrate import Orchestrator
|
||||
|
||||
|
||||
def test_orchestrator_initialization():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
assert isinstance(orchestrator, Orchestrator)
|
||||
assert orchestrator.agents.qsize() == 5
|
||||
assert orchestrator.task_queue.qsize() == 0
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return Mock()
|
||||
|
||||
|
||||
def test_orchestrator_assign_task():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
orchestrator.assign_task(1, {"content": "task1"})
|
||||
assert orchestrator.task_queue.qsize() == 1
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
return {"task_id": 1, "task_data": "data"}
|
||||
|
||||
|
||||
def test_orchestrator_embed():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
result = orchestrator.embed("Hello, world!", "api_key", "model_name")
|
||||
assert isinstance(result, np.ndarray)
|
||||
@pytest.fixture
|
||||
def mock_vector_db():
|
||||
return Mock()
|
||||
|
||||
|
||||
def test_orchestrator_retrieve_results():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
result = orchestrator.retrieve_results(1)
|
||||
assert isinstance(result, list)
|
||||
@pytest.fixture
|
||||
def orchestrator(mock_agent, mock_vector_db):
|
||||
agent_list = [mock_agent for _ in range(5)]
|
||||
task_queue = []
|
||||
return Orchestrator(mock_agent, agent_list, task_queue, mock_vector_db)
|
||||
|
||||
|
||||
def test_orchestrator_update_vector_db():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
data = {"vector": np.array([1, 2, 3]), "task_id": 1}
|
||||
orchestrator.update_vector_db(data)
|
||||
assert orchestrator.collection.count() == 1
|
||||
def test_assign_task(orchestrator, mock_agent, mock_task, mock_vector_db):
|
||||
orchestrator.task_queue.append(mock_task)
|
||||
orchestrator.assign_task(0, mock_task)
|
||||
|
||||
mock_agent.process_task.assert_called_once()
|
||||
mock_vector_db.add_documents.assert_called_once()
|
||||
|
||||
def test_orchestrator_get_vector_db():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
result = orchestrator.get_vector_db()
|
||||
assert isinstance(result, chromadb.Collection)
|
||||
|
||||
def test_retrieve_results(orchestrator, mock_vector_db):
|
||||
mock_vector_db.query.return_value = "expected_result"
|
||||
assert orchestrator.retrieve_results(0) == "expected_result"
|
||||
|
||||
def test_orchestrator_append_to_db():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
orchestrator.append_to_db("Hello, world!")
|
||||
assert orchestrator.collection.count() == 1
|
||||
|
||||
def test_update_vector_db(orchestrator, mock_vector_db):
|
||||
data = {"vector": [0.1, 0.2, 0.3], "task_id": 1}
|
||||
orchestrator.update_vector_db(data)
|
||||
mock_vector_db.add_documents.assert_called_once_with(
|
||||
[data["vector"]], [str(data["task_id"])]
|
||||
)
|
||||
|
||||
def test_orchestrator_run():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
result = orchestrator.run("Write a short story.")
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_vector_db(orchestrator, mock_vector_db):
|
||||
assert orchestrator.get_vector_db() == mock_vector_db
|
||||
|
||||
def test_orchestrator_chat():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
orchestrator.chat(1, 2, "Hello, Agent 2!")
|
||||
assert orchestrator.collection.count() == 1
|
||||
|
||||
def test_append_to_db(orchestrator, mock_vector_db):
|
||||
collection = "test_collection"
|
||||
result = "test_result"
|
||||
orchestrator.append_to_db(collection, result)
|
||||
mock_vector_db.append_document.assert_called_once_with(
|
||||
collection, result, id=str(id(result))
|
||||
)
|
||||
|
||||
def test_orchestrator_add_agents():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
orchestrator.add_agents(5)
|
||||
assert orchestrator.agents.qsize() == 10
|
||||
|
||||
def test_run(orchestrator, mock_agent, mock_vector_db):
|
||||
objective = "test_objective"
|
||||
collection = "test_collection"
|
||||
orchestrator.run(objective, collection)
|
||||
|
||||
def test_orchestrator_remove_agents():
|
||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
||||
orchestrator.remove_agents(3)
|
||||
assert orchestrator.agents.qsize() == 2
|
||||
mock_agent.process_task.assert_called()
|
||||
mock_vector_db.append_document.assert_called()
|
||||
|
Loading…
Reference in new issue