parent
c5ea99aece
commit
aa6f6f3636
@ -0,0 +1,101 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from pydantic import ValidationError
|
||||
from swarms.agents.tools.agent_tools import *
|
||||
from swarms.boss.boss_node import BossNodeInitializer, BossNode
|
||||
# For initializing BossNodeInitializer in multiple tests
|
||||
@pytest.fixture
|
||||
def mock_boss_node_initializer():
|
||||
llm = Mock()
|
||||
vectorstore = Mock()
|
||||
agent_executor = Mock()
|
||||
max_iterations = 5
|
||||
|
||||
boss_node_initializer = BossNodeInitializer(llm, vectorstore, agent_executor, max_iterations)
|
||||
|
||||
return boss_node_initializer
|
||||
|
||||
|
||||
# Test BossNodeInitializer class __init__ method
|
||||
def test_boss_node_initializer_init(mock_boss_node_initializer):
|
||||
with patch('swarms.agents.tools.agent_tools.BabyAGI.from_llm') as mock_from_llm:
|
||||
assert isinstance(mock_boss_node_initializer, BossNodeInitializer)
|
||||
mock_from_llm.assert_called_once()
|
||||
|
||||
|
||||
# Test initialize_vectorstore method of BossNodeInitializer class
|
||||
def test_boss_node_initializer_initialize_vectorstore(mock_boss_node_initializer):
|
||||
with patch('swarms.agents.tools.agent_tools.OpenAIEmbeddings') as mock_embeddings, \
|
||||
patch('swarms.agents.tools.agent_tools.FAISS') as mock_faiss:
|
||||
|
||||
result = mock_boss_node_initializer.initialize_vectorstore()
|
||||
mock_embeddings.assert_called_once()
|
||||
mock_faiss.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
|
||||
# Test initialize_llm method of BossNodeInitializer class
|
||||
def test_boss_node_initializer_initialize_llm(mock_boss_node_initializer):
|
||||
with patch('swarms.agents.tools.agent_tools.OpenAI') as mock_llm:
|
||||
result = mock_boss_node_initializer.initialize_llm(mock_llm)
|
||||
mock_llm.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
|
||||
# Test create_task method of BossNodeInitializer class
|
||||
@pytest.mark.parametrize("objective", ['valid objective', ''])
|
||||
def test_boss_node_initializer_create_task(objective, mock_boss_node_initializer):
|
||||
if objective == '':
|
||||
with pytest.raises(ValueError):
|
||||
mock_boss_node_initializer.create_task(objective)
|
||||
else:
|
||||
assert mock_boss_node_initializer.create_task(objective) == {"objective": objective}
|
||||
|
||||
|
||||
# Test run method of BossNodeInitializer class
|
||||
@pytest.mark.parametrize("task", ['valid task', ''])
|
||||
def test_boss_node_initializer_run(task, mock_boss_node_initializer):
|
||||
with patch.object(mock_boss_node_initializer, 'baby_agi'):
|
||||
if task == '':
|
||||
with pytest.raises(ValueError):
|
||||
mock_boss_node_initializer.run(task)
|
||||
else:
|
||||
try:
|
||||
mock_boss_node_initializer.run(task)
|
||||
mock_boss_node_initializer.baby_agi.assert_called_once_with(task)
|
||||
except Exception:
|
||||
pytest.fail("Unexpected Error!")
|
||||
|
||||
|
||||
# Test BossNode function
|
||||
@pytest.mark.parametrize("api_key, objective, llm_class, max_iterations",
|
||||
[('valid_key', 'valid_objective', OpenAI, 5),
|
||||
('', 'valid_objective', OpenAI, 5),
|
||||
('valid_key', '', OpenAI, 5),
|
||||
('valid_key', 'valid_objective', '', 5),
|
||||
('valid_key', 'valid_objective', OpenAI, 0)])
|
||||
def test_boss_node(api_key, objective, llm_class, max_iterations):
|
||||
with patch('os.getenv') as mock_getenv, \
|
||||
patch('swarms.agents.tools.agent_tools.PromptTemplate.from_template') as mock_from_template, \
|
||||
patch('swarms.agents.tools.agent_tools.LLMChain') as mock_llm_chain, \
|
||||
patch('swarms.agents.tools.agent_tools.ZeroShotAgent.create_prompt') as mock_create_prompt, \
|
||||
patch('swarms.agents.tools.agent_tools.ZeroShotAgent') as mock_zero_shot_agent, \
|
||||
patch('swarms.agents.tools.agent_tools.AgentExecutor.from_agent_and_tools') as mock_from_agent_and_tools, \
|
||||
patch('swarms.agents.tools.agent_tools.BossNodeInitializer') as mock_boss_node_initializer, \
|
||||
patch.object(mock_boss_node_initializer, 'create_task') as mock_create_task, \
|
||||
patch.object(mock_boss_node_initializer, 'run') as mock_run:
|
||||
|
||||
if api_key == '' or objective == '' or llm_class == '' or max_iterations <= 0:
|
||||
with pytest.raises(ValueError):
|
||||
BossNode(objective, api_key, vectorstore=None, worker_node=None, llm_class=llm_class, max_iterations=max_iterations, verbose=False)
|
||||
else:
|
||||
mock_getenv.return_value = 'valid_key'
|
||||
BossNode(objective, api_key, vectorstore=None, worker_node=None, llm_class=llm_class, max_iterations=max_iterations, verbose=False)
|
||||
mock_from_template.assert_called_once()
|
||||
mock_llm_chain.assert_called_once()
|
||||
mock_create_prompt.assert_called_once()
|
||||
mock_zero_shot_agent.assert_called_once()
|
||||
mock_from_agent_and_tools.assert_called_once()
|
||||
mock_boss_node_initializer.assert_called_once()
|
||||
mock_create_task.assert_called_once()
|
||||
mock_run.assert_called_once()
|
Loading…
Reference in new issue