diff --git a/tests/swarms/dialogue_simulator.py b/tests/swarms/dialogue_simulator.py index f6bb8ad3..634c29d1 100644 --- a/tests/swarms/dialogue_simulator.py +++ b/tests/swarms/dialogue_simulator.py @@ -7,13 +7,13 @@ def test_dialoguesimulator_initialization(): assert isinstance(dialoguesimulator, DialogueSimulator) assert len(dialoguesimulator.agents) == 5 -@patch('swarms.workers.Worker.run') +@patch('swarms.workers.worker.Worker.run') def test_dialoguesimulator_run(mock_run): dialoguesimulator = DialogueSimulator(agents=[Worker]*5) dialoguesimulator.run(max_iters=5, name="Agent 1", message="Hello, world!") assert mock_run.call_count == 30 -@patch('swarms.workers.Worker.run') +@patch('swarms.workers.worker.Worker.run') def test_dialoguesimulator_run_without_name_and_message(mock_run): dialoguesimulator = DialogueSimulator(agents=[Worker]*5) dialoguesimulator.run(max_iters=5) diff --git a/tests/swarms/scalable_groupchat.py b/tests/swarms/scalable_groupchat.py index e69de29b..583514a7 100644 --- a/tests/swarms/scalable_groupchat.py +++ b/tests/swarms/scalable_groupchat.py @@ -0,0 +1,44 @@ +import pytest +from unittest.mock import patch, MagicMock +from swarms.swarms.scalable_groupchat import ScalableGroupChat, Worker + +def test_scalablegroupchat_initialization(): + scalablegroupchat = ScalableGroupChat(worker_count=5, collection_name="swarm", api_key="api_key") + assert isinstance(scalablegroupchat, ScalableGroupChat) + assert len(scalablegroupchat.workers) == 5 + assert scalablegroupchat.collection_name == "swarm" + assert scalablegroupchat.api_key == "api_key" + +@patch('chromadb.utils.embedding_functions.OpenAIEmbeddingFunction') +def test_scalablegroupchat_embed(mock_openaiembeddingfunction): + scalablegroupchat = ScalableGroupChat(worker_count=5, collection_name="swarm", api_key="api_key") + result = scalablegroupchat.embed("input", "model_name") + assert mock_openaiembeddingfunction.call_count == 1 + +@patch('swarms.swarms.scalable_groupchat.ScalableGroupChat.collection.query') +def test_scalablegroupchat_retrieve_results(mock_query): + scalablegroupchat = ScalableGroupChat(worker_count=5, collection_name="swarm", api_key="api_key") + result = scalablegroupchat.retrieve_results(1) + assert mock_query.call_count == 1 + +@patch('swarms.swarms.scalable_groupchat.ScalableGroupChat.collection.add') +def test_scalablegroupchat_update_vector_db(mock_add): + scalablegroupchat = ScalableGroupChat(worker_count=5, collection_name="swarm", api_key="api_key") + scalablegroupchat.update_vector_db({"vector": "vector", "task_id": "task_id"}) + assert mock_add.call_count == 1 + +@patch('swarms.swarms.scalable_groupchat.ScalableGroupChat.collection.add') +def test_scalablegroupchat_append_to_db(mock_add): + scalablegroupchat = ScalableGroupChat(worker_count=5, collection_name="swarm", api_key="api_key") + scalablegroupchat.append_to_db("result") + assert mock_add.call_count == 1 + +@patch('swarms.swarms.scalable_groupchat.ScalableGroupChat.collection.add') +@patch('swarms.swarms.scalable_groupchat.ScalableGroupChat.embed') +@patch('swarms.swarms.scalable_groupchat.ScalableGroupChat.run') +def test_scalablegroupchat_chat(mock_run, mock_embed, mock_add): + scalablegroupchat = ScalableGroupChat(worker_count=5, collection_name="swarm", api_key="api_key") + scalablegroupchat.chat(sender_id=1, receiver_id=2, message="Hello, Agent 2!") + assert mock_embed.call_count == 1 + assert mock_add.call_count == 1 + assert mock_run.call_count == 1 \ No newline at end of file diff --git a/tests/swarms/simple_swarm.py b/tests/swarms/simple_swarm.py index 4753a99b..c168dc7f 100644 --- a/tests/swarms/simple_swarm.py +++ b/tests/swarms/simple_swarm.py @@ -16,7 +16,7 @@ def test_simpleswarm_distribute(): simpleswarm.distribute("task2", priority=1) assert simpleswarm.priority_queue.qsize() == 1 -@patch('swarms.workers.Worker.run') +@patch('swarms.workers.worker.Worker.run') def test_simpleswarm_process_task(mock_run): simpleswarm = SimpleSwarm(num_workers=5, openai_api_key="api_key", ai_name="ai_name") result = simpleswarm._process_task("task1")