From 3fe621e4a7ffba246bec36e6f4539bca19f73ddf Mon Sep 17 00:00:00 2001 From: Kye Date: Thu, 27 Jul 2023 17:46:37 -0400 Subject: [PATCH] tests Former-commit-id: ec6152845c28e9da663d6f5994e57fd35df4266b --- tests/agents/memory/main.py | 52 +++++++++++++ tests/{ => agents/models}/LLM.py | 0 tests/agents/models/hf.py | 69 ++++++++++++++++++ tests/orchestrate.py | 61 ++++++++++++++++ tests/swarms.py | 121 +++++++++++++++++-------------- 5 files changed, 247 insertions(+), 56 deletions(-) create mode 100644 tests/agents/memory/main.py rename tests/{ => agents/models}/LLM.py (100%) create mode 100644 tests/agents/models/hf.py create mode 100644 tests/orchestrate.py diff --git a/tests/agents/memory/main.py b/tests/agents/memory/main.py new file mode 100644 index 00000000..a6ea0706 --- /dev/null +++ b/tests/agents/memory/main.py @@ -0,0 +1,52 @@ +import pytest +from unittest.mock import Mock, MagicMock +from swarms.agents.memory.oceandb import OceanDB + + +@pytest.fixture +def mock_ocean_client(): + return Mock() + + +@pytest.fixture +def mock_collection(): + return Mock() + + +@pytest.fixture +def ocean_db(mock_ocean_client): + OceanDB.client = mock_ocean_client + return OceanDB() + + +def test_init(ocean_db, mock_ocean_client): + mock_ocean_client.heartbeat.return_value = "OK" + assert ocean_db.client.heartbeat() == "OK" + + +def test_create_collection(ocean_db, mock_ocean_client, mock_collection): + mock_ocean_client.create_collection.return_value = mock_collection + collection = ocean_db.create_collection("test", "text") + assert collection == mock_collection + + +def test_append_document(ocean_db, mock_collection): + document = "test_document" + id = "test_id" + ocean_db.append_document(mock_collection, document, id) + mock_collection.add.assert_called_once_with(documents=[document], ids=[id]) + + +def test_add_documents(ocean_db, mock_collection): + documents = ["test_document1", "test_document2"] + ids = ["test_id1", "test_id2"] + ocean_db.add_documents(mock_collection, documents, ids) + mock_collection.add.assert_called_once_with(documents=documents, ids=ids) + + +def test_query(ocean_db, mock_collection): + query_texts = ["test_query"] + n_results = 10 + mock_collection.query.return_value = "query_result" + result = ocean_db.query(mock_collection, query_texts, n_results) + assert result == "query_result" diff --git a/tests/LLM.py b/tests/agents/models/LLM.py similarity index 100% rename from tests/LLM.py rename to tests/agents/models/LLM.py diff --git a/tests/agents/models/hf.py b/tests/agents/models/hf.py new file mode 100644 index 00000000..68648904 --- /dev/null +++ b/tests/agents/models/hf.py @@ -0,0 +1,69 @@ +import pytest +import torch +from unittest.mock import Mock, MagicMock +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from your_module import HuggingFaceLLM # replace with actual import + + +@pytest.fixture +def mock_torch(): + return Mock() + + +@pytest.fixture +def mock_autotokenizer(): + return Mock() + + +@pytest.fixture +def mock_automodelforcausallm(): + return Mock() + + +@pytest.fixture +def mock_bitsandbytesconfig(): + return Mock() + + +@pytest.fixture +def hugging_face_llm(mock_torch, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig): + HuggingFaceLLM.torch = mock_torch + HuggingFaceLLM.AutoTokenizer = mock_autotokenizer + HuggingFaceLLM.AutoModelForCausalLM = mock_automodelforcausallm + HuggingFaceLLM.BitsAndBytesConfig = mock_bitsandbytesconfig + + return HuggingFaceLLM(model_id='test') + + +def test_init(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm): + assert hugging_face_llm.model_id == 'test' + mock_autotokenizer.from_pretrained.assert_called_once_with('test') + mock_automodelforcausallm.from_pretrained.assert_called_once_with('test', quantization_config=None) + + +def test_init_with_quantize(hugging_face_llm, mock_autotokenizer, mock_automodelforcausallm, mock_bitsandbytesconfig): + quantization_config = { + 'load_in_4bit': True, + 'bnb_4bit_use_double_quant': True, + 'bnb_4bit_quant_type': "nf4", + 'bnb_4bit_compute_dtype': torch.bfloat16 + } + mock_bitsandbytesconfig.return_value = quantization_config + + hugging_face_llm = HuggingFaceLLM(model_id='test', quantize=True) + + mock_bitsandbytesconfig.assert_called_once_with(**quantization_config) + mock_autotokenizer.from_pretrained.assert_called_once_with('test') + mock_automodelforcausallm.from_pretrained.assert_called_once_with('test', quantization_config=quantization_config) + + +def test_generate_text(hugging_face_llm): + prompt_text = 'test prompt' + expected_output = 'test output' + hugging_face_llm.tokenizer.encode.return_value = torch.tensor([0]) # Mock tensor + hugging_face_llm.model.generate.return_value = torch.tensor([0]) # Mock tensor + hugging_face_llm.tokenizer.decode.return_value = expected_output + + output = hugging_face_llm.generate_text(prompt_text) + + assert output == expected_output diff --git a/tests/orchestrate.py b/tests/orchestrate.py new file mode 100644 index 00000000..86395a57 --- /dev/null +++ b/tests/orchestrate.py @@ -0,0 +1,61 @@ +import pytest +from unittest.mock import Mock +from swarms.orchestrate import Orchestrator + + +@pytest.fixture +def mock_agent(): + return Mock() + +@pytest.fixture +def mock_task(): + return {"task_id": 1, "task_data": "data"} + +@pytest.fixture +def mock_vector_db(): + return Mock() + +@pytest.fixture +def orchestrator(mock_agent, mock_vector_db): + agent_list = [mock_agent for _ in range(5)] + task_queue = [] + return Orchestrator(mock_agent, agent_list, task_queue, mock_vector_db) + + +def test_assign_task(orchestrator, mock_agent, mock_task, mock_vector_db): + orchestrator.task_queue.append(mock_task) + orchestrator.assign_task(0, mock_task) + + mock_agent.process_task.assert_called_once() + mock_vector_db.add_documents.assert_called_once() + + +def test_retrieve_results(orchestrator, mock_vector_db): + mock_vector_db.query.return_value = "expected_result" + assert orchestrator.retrieve_results(0) == "expected_result" + + +def test_update_vector_db(orchestrator, mock_vector_db): + data = {"vector": [0.1, 0.2, 0.3], "task_id": 1} + orchestrator.update_vector_db(data) + mock_vector_db.add_documents.assert_called_once_with([data['vector']], [str(data['task_id'])]) + + +def test_get_vector_db(orchestrator, mock_vector_db): + assert orchestrator.get_vector_db() == mock_vector_db + + +def test_append_to_db(orchestrator, mock_vector_db): + collection = "test_collection" + result = "test_result" + orchestrator.append_to_db(collection, result) + mock_vector_db.append_document.assert_called_once_with(collection, result, id=str(id(result))) + + +def test_run(orchestrator, mock_agent, mock_vector_db): + objective = "test_objective" + collection = "test_collection" + orchestrator.run(objective, collection) + + mock_agent.process_task.assert_called() + mock_vector_db.append_document.assert_called() diff --git a/tests/swarms.py b/tests/swarms.py index f525b835..55a0e851 100644 --- a/tests/swarms.py +++ b/tests/swarms.py @@ -1,64 +1,73 @@ -import unittest -import swarms -from swarms.workers.worker_node import WorkerNode -from swarms.boss.BossNode import BossNode +import pytest +import logging +from unittest.mock import Mock, patch +from swarms.swarms import HierarchicalSwarm # replace with your actual module name -class TestSwarms(unittest.TestCase): - def setUp(self): - self.swarm = swarms.Swarms('fake_api_key') +@pytest.fixture +def swarm(): + return HierarchicalSwarm( + model_id='gpt-4', + openai_api_key='some_api_key', + use_vectorstore=True, + embedding_size=1024, + use_async=False, + human_in_the_loop=True, + model_type='openai', + boss_prompt='boss', + worker_prompt='worker', + temperature=0.5, + max_iterations=100, + logging_enabled=True + ) - def test_initialize_llm(self): - llm = self.swarm.initialize_llm(swarms.ChatOpenAI) - self.assertIsNotNone(llm) +@pytest.fixture +def swarm_no_logging(): + return HierarchicalSwarm(logging_enabled=False) - def test_initialize_tools(self): - tools = self.swarm.initialize_tools(swarms.ChatOpenAI) - self.assertIsNotNone(tools) +def test_swarm_init(swarm): + assert swarm.model_id == 'gpt-4' + assert swarm.openai_api_key == 'some_api_key' + assert swarm.use_vectorstore + assert swarm.embedding_size == 1024 + assert not swarm.use_async + assert swarm.human_in_the_loop + assert swarm.model_type == 'openai' + assert swarm.boss_prompt == 'boss' + assert swarm.worker_prompt == 'worker' + assert swarm.temperature == 0.5 + assert swarm.max_iterations == 100 + assert swarm.logging_enabled + assert isinstance(swarm.logger, logging.Logger) - def test_initialize_vectorstore(self): - vectorstore = self.swarm.initialize_vectorstore() - self.assertIsNotNone(vectorstore) +def test_swarm_no_logging_init(swarm_no_logging): + assert not swarm_no_logging.logging_enabled + assert swarm_no_logging.logger.disabled - def test_run(self): - objective = "Do a web search for 'OpenAI'" - result = self.swarm.run(objective) - self.assertIsNotNone(result) +@patch('your_module.OpenAI') +@patch('your_module.HuggingFaceLLM') +def test_initialize_llm(mock_huggingface, mock_openai, swarm): + swarm.initialize_llm('openai') + mock_openai.assert_called_once_with(openai_api_key='some_api_key', temperature=0.5) + + swarm.initialize_llm('huggingface') + mock_huggingface.assert_called_once_with(model_id='gpt-4', temperature=0.5) +@patch('your_module.HierarchicalSwarm.initialize_llm') +def test_initialize_tools(mock_llm, swarm): + mock_llm.return_value = 'mock_llm_class' + tools = swarm.initialize_tools('openai') + assert 'mock_llm_class' in tools -class TestWorkerNode(unittest.TestCase): - def setUp(self): - swarm = swarms.Swarms('fake_api_key') - worker_tools = swarm.initialize_tools(swarms.ChatOpenAI) - vectorstore = swarm.initialize_vectorstore() - self.worker_node = swarm.initialize_worker_node(worker_tools, vectorstore) +@patch('your_module.HierarchicalSwarm.initialize_llm') +def test_initialize_tools_with_extra_tools(mock_llm, swarm): + mock_llm.return_value = 'mock_llm_class' + tools = swarm.initialize_tools('openai', extra_tools=['tool1', 'tool2']) + assert 'tool1' in tools + assert 'tool2' in tools - def test_create_agent(self): - self.worker_node.create_agent("Worker 1", "Assistant", False, {}) - self.assertIsNotNone(self.worker_node.agent) - - def test_run(self): - tool_input = {'prompt': "Search the web for 'OpenAI'"} - result = self.worker_node.run(tool_input) - self.assertIsNotNone(result) - - -class TestBossNode(unittest.TestCase): - def setUp(self): - swarm = swarms.Swarms('fake_api_key') - worker_tools = swarm.initialize_tools(swarms.ChatOpenAI) - vectorstore = swarm.initialize_vectorstore() - worker_node = swarm.initialize_worker_node(worker_tools, vectorstore) - self.boss_node = swarm.initialize_boss_node(vectorstore, worker_node) - - def test_create_task(self): - task = self.boss_node.create_task("Do a web search for 'OpenAI'") - self.assertIsNotNone(task) - - def test_execute_task(self): - task = self.boss_node.create_task("Do a web search for 'OpenAI'") - result = self.boss_node.execute_task(task) - self.assertIsNotNone(result) - - -if __name__ == '__main__': - unittest.main() +@patch('your_module.OpenAIEmbeddings') +@patch('your_module.FAISS') +def test_initialize_vectorstore(mock_faiss, mock_openai_embeddings, swarm): + mock_openai_embeddings.return_value.embed_query = 'embed_query' + vectorstore = swarm.initialize_vectorstore() + mock_faiss.assert_called_once_with('embed_query', instance_of(faiss.IndexFlatL2), instance_of(InMemoryDocstore), {})