You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
2.6 KiB
86 lines
2.6 KiB
import pytest
|
|
import logging
|
|
from unittest.mock import patch
|
|
from swarms.swarms.swarms import (
|
|
HierarchicalSwarm,
|
|
) # replace with your actual module name
|
|
|
|
|
|
@pytest.fixture
def swarm():
    """Provide a HierarchicalSwarm built from an explicit set of options.

    The keyword values below are the ones ``test_swarm_init`` reads back
    from the instance.
    """
    settings = dict(
        model_id="gpt-4",
        openai_api_key="some_api_key",
        use_vectorstore=True,
        embedding_size=1024,
        use_async=False,
        human_in_the_loop=True,
        model_type="openai",
        boss_prompt="boss",
        worker_prompt="worker",
        temperature=0.5,
        max_iterations=100,
        logging_enabled=True,
    )
    return HierarchicalSwarm(**settings)
|
|
|
|
|
|
@pytest.fixture
def swarm_no_logging():
    """Provide a HierarchicalSwarm constructed with ``logging_enabled=False``."""
    quiet_swarm = HierarchicalSwarm(logging_enabled=False)
    return quiet_swarm
|
|
|
|
|
|
def test_swarm_init(swarm):
    """Verify the ``swarm`` fixture exposes the expected attribute values.

    Value-carrying options are compared for equality, boolean options are
    checked for truthiness only, and ``logger`` must be a ``logging.Logger``.
    """
    expected_values = (
        ("model_id", "gpt-4"),
        ("openai_api_key", "some_api_key"),
        ("embedding_size", 1024),
        ("model_type", "openai"),
        ("boss_prompt", "boss"),
        ("worker_prompt", "worker"),
        ("temperature", 0.5),
        ("max_iterations", 100),
    )
    for attribute, expected in expected_values:
        assert getattr(swarm, attribute) == expected

    # Flags: only truthiness is asserted, not identity with True/False.
    assert swarm.use_vectorstore
    assert not swarm.use_async
    assert swarm.human_in_the_loop
    assert swarm.logging_enabled

    assert isinstance(swarm.logger, logging.Logger)
|
|
|
|
|
|
def test_swarm_no_logging_init(swarm_no_logging):
    """Verify the logging-disabled fixture reports logging off.

    Both the ``logging_enabled`` flag and the ``disabled`` attribute of the
    instance's logger are checked.
    """
    quiet_swarm = swarm_no_logging
    assert not quiet_swarm.logging_enabled

    swarm_logger = quiet_swarm.logger
    assert swarm_logger.disabled
|
|
|
|
|
|
# NOTE(review): patch targets assume OpenAI and HuggingFaceLLM are looked up
# in swarms.swarms.swarms, the module this file imports HierarchicalSwarm
# from. The previous "your_module" placeholder could not be imported.
@patch("swarms.swarms.swarms.OpenAI")
@patch("swarms.swarms.swarms.HuggingFaceLLM")
def test_initialize_llm(mock_huggingface, mock_openai, swarm):
    """Verify ``initialize_llm`` constructs the LLM class matching its argument.

    Decorators apply bottom-up, so the ``HuggingFaceLLM`` mock is the first
    positional argument and the ``OpenAI`` mock the second.
    """
    swarm.initialize_llm("openai")
    mock_openai.assert_called_once_with(
        openai_api_key="some_api_key", temperature=0.5
    )

    swarm.initialize_llm("huggingface")
    mock_huggingface.assert_called_once_with(model_id="gpt-4", temperature=0.5)
|
|
|
|
|
|
# patch.object targets the imported class directly, so no module path string
# (previously the unimportable "your_module" placeholder) is needed.
@patch.object(HierarchicalSwarm, "initialize_llm")
def test_initialize_tools(mock_llm, swarm):
    """Verify ``initialize_tools`` includes the value from ``initialize_llm``.

    ``initialize_llm`` is replaced by a mock returning a sentinel string, and
    that sentinel is expected to appear in the returned tools collection.
    """
    mock_llm.return_value = "mock_llm_class"

    tools = swarm.initialize_tools("openai")

    assert "mock_llm_class" in tools
|
|
|
|
|
|
# patch.object targets the imported class directly, so no module path string
# (previously the unimportable "your_module" placeholder) is needed.
@patch.object(HierarchicalSwarm, "initialize_llm")
def test_initialize_tools_with_extra_tools(mock_llm, swarm):
    """Verify entries passed as ``extra_tools`` appear in the returned tools.

    ``initialize_llm`` is mocked out so only the handling of ``extra_tools``
    is exercised here.
    """
    mock_llm.return_value = "mock_llm_class"

    tools = swarm.initialize_tools("openai", extra_tools=["tool1", "tool2"])

    assert "tool1" in tools
    assert "tool2" in tools
|
|
|
|
|
|
# NOTE(review): patch targets assume OpenAIEmbeddings and FAISS are looked up
# in swarms.swarms.swarms, the module this file imports HierarchicalSwarm
# from. The previous "your_module" placeholder could not be imported.
@patch("swarms.swarms.swarms.OpenAIEmbeddings")
@patch("swarms.swarms.swarms.FAISS")
def test_initialize_vectorstore(mock_faiss, mock_openai_embeddings, swarm):
    """Verify ``initialize_vectorstore`` calls ``FAISS`` once with the embed function.

    Decorators apply bottom-up, so the ``FAISS`` mock is the first positional
    argument and the ``OpenAIEmbeddings`` mock the second.
    """
    # Imported locally: only ``patch`` is imported from unittest.mock at file level.
    from unittest.mock import ANY

    mock_openai_embeddings.return_value.embed_query = "embed_query"

    swarm.initialize_vectorstore()

    # The original matched the 2nd and 3rd arguments with an undefined
    # ``instance_of`` helper against ``faiss.IndexFlatL2`` and
    # ``InMemoryDocstore``, neither of which is imported in this file.
    # ANY keeps the call-shape check runnable.
    # TODO(review): tighten to isinstance checks once those classes are imported.
    mock_faiss.assert_called_once_with("embed_query", ANY, ANY, {})
|