|
|
|
@ -5,8 +5,8 @@ from swarms.agents.agents import AgentNodeInitializer, AgentNode, agent # repla
|
|
|
|
|
# For initializing AgentNodeInitializer in multiple tests
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
def mock_agent_node_initializer():
|
|
|
|
|
with patch('your_module.ChatOpenAI') as mock_llm, \
|
|
|
|
|
patch('your_module.AutoGPT') as mock_agent:
|
|
|
|
|
with patch('swarms.agents.agents.ChatOpenAI') as mock_llm, \
|
|
|
|
|
patch('swarms.agents.agents.AutoGPT') as mock_agent:
|
|
|
|
|
|
|
|
|
|
initializer = AgentNodeInitializer(model_type='openai', model_id='test', openai_api_key='test_key', temperature=0.5)
|
|
|
|
|
initializer.llm = mock_llm
|
|
|
|
@ -20,8 +20,8 @@ def mock_agent_node_initializer():
|
|
|
|
|
# Test initialize_llm method of AgentNodeInitializer class
|
|
|
|
|
@pytest.mark.parametrize("model_type", ['openai', 'huggingface', 'invalid'])
|
|
|
|
|
def test_agent_node_initializer_initialize_llm(model_type, mock_agent_node_initializer):
|
|
|
|
|
with patch('your_module.ChatOpenAI') as mock_openai, \
|
|
|
|
|
patch('your_module.HuggingFaceLLM') as mock_huggingface:
|
|
|
|
|
with patch('swarms.agents.agents.ChatOpenAI') as mock_openai, \
|
|
|
|
|
patch('swarms.agents.agents.HuggingFaceLLM') as mock_huggingface:
|
|
|
|
|
|
|
|
|
|
if model_type == 'invalid':
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
@ -36,7 +36,7 @@ def test_agent_node_initializer_initialize_llm(model_type, mock_agent_node_initi
|
|
|
|
|
|
|
|
|
|
# Test add_tool method of AgentNodeInitializer class
|
|
|
|
|
def test_agent_node_initializer_add_tool(mock_agent_node_initializer):
|
|
|
|
|
with patch('your_module.BaseTool') as mock_base_tool:
|
|
|
|
|
with patch('swarms.agents.agents.BaseTool') as mock_base_tool:
|
|
|
|
|
mock_agent_node_initializer.add_tool(mock_base_tool)
|
|
|
|
|
assert mock_base_tool in mock_agent_node_initializer.tools
|
|
|
|
|
|
|
|
|
@ -50,12 +50,11 @@ def test_agent_node_initializer_run(prompt, mock_agent_node_initializer):
|
|
|
|
|
else:
|
|
|
|
|
assert mock_agent_node_initializer.run(prompt) == "Task completed by AgentNode"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# For initializing AgentNode in multiple tests
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
def mock_agent_node():
|
|
|
|
|
with patch('your_module.ChatOpenAI') as mock_llm, \
|
|
|
|
|
patch('your_module.AgentNodeInitializer') as mock_agent_node_initializer:
|
|
|
|
|
with patch('swarms.agents.agents.ChatOpenAI') as mock_llm, \
|
|
|
|
|
patch('swarms.agents.agents.AgentNodeInitializer') as mock_agent_node_initializer:
|
|
|
|
|
|
|
|
|
|
mock_agent_node = AgentNode('test_key')
|
|
|
|
|
mock_agent_node.llm_class = mock_llm
|
|
|
|
@ -64,12 +63,11 @@ def mock_agent_node():
|
|
|
|
|
|
|
|
|
|
return mock_agent_node
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test initialize_llm method of AgentNode class
|
|
|
|
|
@pytest.mark.parametrize("llm_class", ['openai', 'huggingface'])
|
|
|
|
|
def test_agent_node_initialize_llm(llm_class, mock_agent_node):
|
|
|
|
|
with patch('your_module.ChatOpenAI') as mock_openai, \
|
|
|
|
|
patch('your_module.HuggingFaceLLM') as mock_huggingface:
|
|
|
|
|
with patch('swarms.agents.agents.ChatOpenAI') as mock_openai, \
|
|
|
|
|
patch('swarms.agents.agents.HuggingFaceLLM') as mock_huggingface:
|
|
|
|
|
|
|
|
|
|
mock_agent_node.initialize_llm(llm_class)
|
|
|
|
|
if llm_class == 'openai':
|
|
|
|
@ -77,14 +75,13 @@ def test_agent_node_initialize_llm(llm_class, mock_agent_node):
|
|
|
|
|
elif llm_class == 'huggingface':
|
|
|
|
|
mock_huggingface.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test initialize_tools method of AgentNode class
|
|
|
|
|
def test_agent_node_initialize_tools(mock_agent_node):
|
|
|
|
|
with patch('your_module.DuckDuckGoSearchRun') as mock_ddg, \
|
|
|
|
|
patch('your_module.WriteFileTool') as mock_write_file, \
|
|
|
|
|
patch('your_module.ReadFileTool') as mock_read_file, \
|
|
|
|
|
patch('your_module.process_csv') as mock_process_csv, \
|
|
|
|
|
patch('your_module.WebpageQATool') as mock_webpage_qa:
|
|
|
|
|
with patch('swarms.agents.agents.DuckDuckGoSearchRun') as mock_ddg, \
|
|
|
|
|
patch('swarms.agents.agents.WriteFileTool') as mock_write_file, \
|
|
|
|
|
patch('swarms.agents.agents.ReadFileTool') as mock_read_file, \
|
|
|
|
|
patch('swarms.agents.agents.process_csv') as mock_process_csv, \
|
|
|
|
|
patch('swarms.agents.agents.WebpageQATool') as mock_webpage_qa:
|
|
|
|
|
|
|
|
|
|
mock_agent_node.initialize_tools('openai')
|
|
|
|
|
assert mock_ddg.called
|
|
|
|
@ -99,7 +96,7 @@ def test_agent_node_create_agent(mock_agent_node):
|
|
|
|
|
with patch.object(mock_agent_node, 'initialize_llm'), \
|
|
|
|
|
patch.object(mock_agent_node, 'initialize_tools'), \
|
|
|
|
|
patch.object(mock_agent_node, 'initialize_vectorstore'), \
|
|
|
|
|
patch('your_module.AgentNodeInitializer') as mock_agent_node_initializer:
|
|
|
|
|
patch('swarms.agents.agents.AgentNodeInitializer') as mock_agent_node_initializer:
|
|
|
|
|
|
|
|
|
|
mock_agent_node.create_agent()
|
|
|
|
|
mock_agent_node_initializer.assert_called_once()
|
|
|
|
@ -113,7 +110,7 @@ def test_agent(openai_api_key, objective):
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
agent(openai_api_key, objective)
|
|
|
|
|
else:
|
|
|
|
|
with patch('your_module.AgentNodeInitializer') as mock_agent_node_initializer:
|
|
|
|
|
with patch('swarms.agents.agents.AgentNodeInitializer') as mock_agent_node_initializer:
|
|
|
|
|
mock_agent_node = mock_agent_node_initializer.return_value.create_agent.return_value
|
|
|
|
|
mock_agent_node.run.return_value = 'Agent output'
|
|
|
|
|
result = agent(openai_api_key, objective)
|
|
|
|
|