parent
069b2aed45
commit
9c0c6c06cc
@ -1,68 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from unittest.mock import Mock
|
|
||||||
from swarms.swarms.orchestrate import Orchestrator
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_agent():
|
|
||||||
return Mock()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_task():
|
|
||||||
return {"task_id": 1, "task_data": "data"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_vector_db():
|
|
||||||
return Mock()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def orchestrator(mock_agent, mock_vector_db):
|
|
||||||
agent_list = [mock_agent for _ in range(5)]
|
|
||||||
task_queue = []
|
|
||||||
return Orchestrator(mock_agent, agent_list, task_queue, mock_vector_db)
|
|
||||||
|
|
||||||
|
|
||||||
def test_assign_task(orchestrator, mock_agent, mock_task, mock_vector_db):
|
|
||||||
orchestrator.task_queue.append(mock_task)
|
|
||||||
orchestrator.assign_task(0, mock_task)
|
|
||||||
|
|
||||||
mock_agent.process_task.assert_called_once()
|
|
||||||
mock_vector_db.add_documents.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_results(orchestrator, mock_vector_db):
|
|
||||||
mock_vector_db.query.return_value = "expected_result"
|
|
||||||
assert orchestrator.retrieve_results(0) == "expected_result"
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_vector_db(orchestrator, mock_vector_db):
|
|
||||||
data = {"vector": [0.1, 0.2, 0.3], "task_id": 1}
|
|
||||||
orchestrator.update_vector_db(data)
|
|
||||||
mock_vector_db.add_documents.assert_called_once_with(
|
|
||||||
[data["vector"]], [str(data["task_id"])]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_vector_db(orchestrator, mock_vector_db):
|
|
||||||
assert orchestrator.get_vector_db() == mock_vector_db
|
|
||||||
|
|
||||||
|
|
||||||
def test_append_to_db(orchestrator, mock_vector_db):
|
|
||||||
collection = "test_collection"
|
|
||||||
result = "test_result"
|
|
||||||
orchestrator.append_to_db(collection, result)
|
|
||||||
mock_vector_db.append_document.assert_called_once_with(
|
|
||||||
collection, result, id=str(id(result))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_run(orchestrator, mock_agent, mock_vector_db):
|
|
||||||
objective = "test_objective"
|
|
||||||
collection = "test_collection"
|
|
||||||
orchestrator.run(objective, collection)
|
|
||||||
|
|
||||||
mock_agent.process_task.assert_called()
|
|
||||||
mock_vector_db.append_document.assert_called()
|
|
@ -1,71 +1,68 @@
|
|||||||
import numpy as np
|
import pytest
|
||||||
from swarms.swarms.orchestrate import Orchestrator, Worker
|
from unittest.mock import Mock
|
||||||
import chromadb
|
from swarms.swarms.orchestrate import Orchestrator
|
||||||
|
|
||||||
|
|
||||||
def test_orchestrator_initialization():
|
@pytest.fixture
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
def mock_agent():
|
||||||
assert isinstance(orchestrator, Orchestrator)
|
return Mock()
|
||||||
assert orchestrator.agents.qsize() == 5
|
|
||||||
assert orchestrator.task_queue.qsize() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_orchestrator_assign_task():
|
@pytest.fixture
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
def mock_task():
|
||||||
orchestrator.assign_task(1, {"content": "task1"})
|
return {"task_id": 1, "task_data": "data"}
|
||||||
assert orchestrator.task_queue.qsize() == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_orchestrator_embed():
|
@pytest.fixture
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
def mock_vector_db():
|
||||||
result = orchestrator.embed("Hello, world!", "api_key", "model_name")
|
return Mock()
|
||||||
assert isinstance(result, np.ndarray)
|
|
||||||
|
|
||||||
|
|
||||||
def test_orchestrator_retrieve_results():
|
@pytest.fixture
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
def orchestrator(mock_agent, mock_vector_db):
|
||||||
result = orchestrator.retrieve_results(1)
|
agent_list = [mock_agent for _ in range(5)]
|
||||||
assert isinstance(result, list)
|
task_queue = []
|
||||||
|
return Orchestrator(mock_agent, agent_list, task_queue, mock_vector_db)
|
||||||
|
|
||||||
|
|
||||||
def test_orchestrator_update_vector_db():
|
def test_assign_task(orchestrator, mock_agent, mock_task, mock_vector_db):
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
orchestrator.task_queue.append(mock_task)
|
||||||
data = {"vector": np.array([1, 2, 3]), "task_id": 1}
|
orchestrator.assign_task(0, mock_task)
|
||||||
orchestrator.update_vector_db(data)
|
|
||||||
assert orchestrator.collection.count() == 1
|
|
||||||
|
|
||||||
|
mock_agent.process_task.assert_called_once()
|
||||||
|
mock_vector_db.add_documents.assert_called_once()
|
||||||
|
|
||||||
def test_orchestrator_get_vector_db():
|
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
|
||||||
result = orchestrator.get_vector_db()
|
|
||||||
assert isinstance(result, chromadb.Collection)
|
|
||||||
|
|
||||||
|
def test_retrieve_results(orchestrator, mock_vector_db):
|
||||||
|
mock_vector_db.query.return_value = "expected_result"
|
||||||
|
assert orchestrator.retrieve_results(0) == "expected_result"
|
||||||
|
|
||||||
def test_orchestrator_append_to_db():
|
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
|
||||||
orchestrator.append_to_db("Hello, world!")
|
|
||||||
assert orchestrator.collection.count() == 1
|
|
||||||
|
|
||||||
|
def test_update_vector_db(orchestrator, mock_vector_db):
|
||||||
|
data = {"vector": [0.1, 0.2, 0.3], "task_id": 1}
|
||||||
|
orchestrator.update_vector_db(data)
|
||||||
|
mock_vector_db.add_documents.assert_called_once_with(
|
||||||
|
[data["vector"]], [str(data["task_id"])]
|
||||||
|
)
|
||||||
|
|
||||||
def test_orchestrator_run():
|
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
|
||||||
result = orchestrator.run("Write a short story.")
|
|
||||||
assert isinstance(result, list)
|
|
||||||
|
|
||||||
|
def test_get_vector_db(orchestrator, mock_vector_db):
|
||||||
|
assert orchestrator.get_vector_db() == mock_vector_db
|
||||||
|
|
||||||
def test_orchestrator_chat():
|
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
|
||||||
orchestrator.chat(1, 2, "Hello, Agent 2!")
|
|
||||||
assert orchestrator.collection.count() == 1
|
|
||||||
|
|
||||||
|
def test_append_to_db(orchestrator, mock_vector_db):
|
||||||
|
collection = "test_collection"
|
||||||
|
result = "test_result"
|
||||||
|
orchestrator.append_to_db(collection, result)
|
||||||
|
mock_vector_db.append_document.assert_called_once_with(
|
||||||
|
collection, result, id=str(id(result))
|
||||||
|
)
|
||||||
|
|
||||||
def test_orchestrator_add_agents():
|
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
|
||||||
orchestrator.add_agents(5)
|
|
||||||
assert orchestrator.agents.qsize() == 10
|
|
||||||
|
|
||||||
|
def test_run(orchestrator, mock_agent, mock_vector_db):
|
||||||
|
objective = "test_objective"
|
||||||
|
collection = "test_collection"
|
||||||
|
orchestrator.run(objective, collection)
|
||||||
|
|
||||||
def test_orchestrator_remove_agents():
|
mock_agent.process_task.assert_called()
|
||||||
orchestrator = Orchestrator(agent=Worker, agent_list=[Worker] * 5, task_queue=[])
|
mock_vector_db.append_document.assert_called()
|
||||||
orchestrator.remove_agents(3)
|
|
||||||
assert orchestrator.agents.qsize() == 2
|
|
||||||
|
Loading…
Reference in new issue