Former-commit-id: ef9d4b40a3ac7abe923a21b34ff263f11b1b143dclean-history
parent
2d05a09e1c
commit
daf3c9e6a6
@ -1,61 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
from swarms.swarms.scalable_groupchat import ScalableGroupChat
|
||||
|
||||
|
||||
def test_scalablegroupchat_initialization():
|
||||
scalablegroupchat = ScalableGroupChat(
|
||||
worker_count=5, collection_name="swarm", api_key="api_key"
|
||||
)
|
||||
assert isinstance(scalablegroupchat, ScalableGroupChat)
|
||||
assert len(scalablegroupchat.workers) == 5
|
||||
assert scalablegroupchat.collection_name == "swarm"
|
||||
assert scalablegroupchat.api_key == "api_key"
|
||||
|
||||
|
||||
@patch("chromadb.utils.embedding_functions.OpenAIEmbeddingFunction")
|
||||
def test_scalablegroupchat_embed(mock_openaiembeddingfunction):
|
||||
scalablegroupchat = ScalableGroupChat(
|
||||
worker_count=5, collection_name="swarm", api_key="api_key"
|
||||
)
|
||||
scalablegroupchat.embed("input", "model_name")
|
||||
assert mock_openaiembeddingfunction.call_count == 1
|
||||
|
||||
|
||||
@patch("swarms.swarms.scalable_groupchat.ScalableGroupChat.collection.query")
|
||||
def test_scalablegroupchat_retrieve_results(mock_query):
|
||||
scalablegroupchat = ScalableGroupChat(
|
||||
worker_count=5, collection_name="swarm", api_key="api_key"
|
||||
)
|
||||
scalablegroupchat.retrieve_results(1)
|
||||
assert mock_query.call_count == 1
|
||||
|
||||
|
||||
@patch("swarms.swarms.scalable_groupchat.ScalableGroupChat.collection.add")
|
||||
def test_scalablegroupchat_update_vector_db(mock_add):
|
||||
scalablegroupchat = ScalableGroupChat(
|
||||
worker_count=5, collection_name="swarm", api_key="api_key"
|
||||
)
|
||||
scalablegroupchat.update_vector_db({"vector": "vector", "task_id": "task_id"})
|
||||
assert mock_add.call_count == 1
|
||||
|
||||
|
||||
@patch("swarms.swarms.scalable_groupchat.ScalableGroupChat.collection.add")
|
||||
def test_scalablegroupchat_append_to_db(mock_add):
|
||||
scalablegroupchat = ScalableGroupChat(
|
||||
worker_count=5, collection_name="swarm", api_key="api_key"
|
||||
)
|
||||
scalablegroupchat.append_to_db("result")
|
||||
assert mock_add.call_count == 1
|
||||
|
||||
|
||||
@patch("swarms.swarms.scalable_groupchat.ScalableGroupChat.collection.add")
|
||||
@patch("swarms.swarms.scalable_groupchat.ScalableGroupChat.embed")
|
||||
@patch("swarms.swarms.scalable_groupchat.ScalableGroupChat.run")
|
||||
def test_scalablegroupchat_chat(mock_run, mock_embed, mock_add):
|
||||
scalablegroupchat = ScalableGroupChat(
|
||||
worker_count=5, collection_name="swarm", api_key="api_key"
|
||||
)
|
||||
scalablegroupchat.chat(sender_id=1, receiver_id=2, message="Hello, Agent 2!")
|
||||
assert mock_embed.call_count == 1
|
||||
assert mock_add.call_count == 1
|
||||
assert mock_run.call_count == 1
|
@ -1,85 +0,0 @@
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
from swarms.swarms.swarms import (
|
||||
HierarchicalSwarm,
|
||||
) # replace with your actual module name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def swarm():
|
||||
return HierarchicalSwarm(
|
||||
model_id="gpt-4",
|
||||
openai_api_key="some_api_key",
|
||||
use_vectorstore=True,
|
||||
embedding_size=1024,
|
||||
use_async=False,
|
||||
human_in_the_loop=True,
|
||||
model_type="openai",
|
||||
boss_prompt="boss",
|
||||
worker_prompt="worker",
|
||||
temperature=0.5,
|
||||
max_iterations=100,
|
||||
logging_enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def swarm_no_logging():
|
||||
return HierarchicalSwarm(logging_enabled=False)
|
||||
|
||||
|
||||
def test_swarm_init(swarm):
|
||||
assert swarm.model_id == "gpt-4"
|
||||
assert swarm.openai_api_key == "some_api_key"
|
||||
assert swarm.use_vectorstore
|
||||
assert swarm.embedding_size == 1024
|
||||
assert not swarm.use_async
|
||||
assert swarm.human_in_the_loop
|
||||
assert swarm.model_type == "openai"
|
||||
assert swarm.boss_prompt == "boss"
|
||||
assert swarm.worker_prompt == "worker"
|
||||
assert swarm.temperature == 0.5
|
||||
assert swarm.max_iterations == 100
|
||||
assert swarm.logging_enabled
|
||||
assert isinstance(swarm.logger, logging.Logger)
|
||||
|
||||
|
||||
def test_swarm_no_logging_init(swarm_no_logging):
|
||||
assert not swarm_no_logging.logging_enabled
|
||||
assert swarm_no_logging.logger.disabled
|
||||
|
||||
|
||||
@patch("your_module.OpenAI")
|
||||
@patch("your_module.HuggingFaceLLM")
|
||||
def test_initialize_llm(mock_huggingface, mock_openai, swarm):
|
||||
swarm.initialize_llm("openai")
|
||||
mock_openai.assert_called_once_with(openai_api_key="some_api_key", temperature=0.5)
|
||||
|
||||
swarm.initialize_llm("huggingface")
|
||||
mock_huggingface.assert_called_once_with(model_id="gpt-4", temperature=0.5)
|
||||
|
||||
|
||||
@patch("your_module.HierarchicalSwarm.initialize_llm")
|
||||
def test_initialize_tools(mock_llm, swarm):
|
||||
mock_llm.return_value = "mock_llm_class"
|
||||
tools = swarm.initialize_tools("openai")
|
||||
assert "mock_llm_class" in tools
|
||||
|
||||
|
||||
@patch("your_module.HierarchicalSwarm.initialize_llm")
|
||||
def test_initialize_tools_with_extra_tools(mock_llm, swarm):
|
||||
mock_llm.return_value = "mock_llm_class"
|
||||
tools = swarm.initialize_tools("openai", extra_tools=["tool1", "tool2"])
|
||||
assert "tool1" in tools
|
||||
assert "tool2" in tools
|
||||
|
||||
|
||||
@patch("your_module.OpenAIEmbeddings")
|
||||
@patch("your_module.FAISS")
|
||||
def test_initialize_vectorstore(mock_faiss, mock_openai_embeddings, swarm):
|
||||
mock_openai_embeddings.return_value.embed_query = "embed_query"
|
||||
swarm.initialize_vectorstore()
|
||||
mock_faiss.assert_called_once_with(
|
||||
"embed_query", instance_of(faiss.IndexFlatL2), instance_of(InMemoryDocstore), {}
|
||||
)
|
Loading…
Reference in new issue