|
|
|
import pytest
|
|
|
|
from unittest.mock import Mock, patch
|
|
|
|
from swarms.agents.agents import (
|
|
|
|
AgentNodeInitializer,
|
|
|
|
AgentNode,
|
|
|
|
agent,
|
|
|
|
) # replace with actual import
|
|
|
|
|
|
|
|
|
|
|
|
# For initializing AgentNodeInitializer in multiple tests
|
|
|
|
@pytest.fixture
def mock_agent_node_initializer():
    """Yield an ``AgentNodeInitializer`` whose collaborators are mocks.

    ``ChatOpenAI`` and ``AutoGPT`` in ``swarms.agents.agents`` are patched for
    the whole lifetime of the fixture. The fixture *yields* from inside the
    ``with`` block so the patches stay active while the test body runs; a
    ``return`` there would leave the ``with`` and undo both patches before
    the test even started.

    After construction, the instance's ``llm``, ``tools``, ``vectorstore``
    and ``agent`` attributes are overwritten with mocks.
    """
    # BaseTool is only needed as a Mock spec. It was previously referenced
    # without any import, which raised NameError during fixture setup.
    # Taking it from the module under test keeps the spec identical to the
    # class that module itself uses.
    from swarms.agents.agents import BaseTool

    with patch("swarms.agents.agents.ChatOpenAI") as mock_llm, patch(
        "swarms.agents.agents.AutoGPT"
    ) as mock_agent:
        initializer = AgentNodeInitializer(
            model_type="openai",
            model_id="test",
            openai_api_key="test_key",
            temperature=0.5,
        )
        initializer.llm = mock_llm
        initializer.tools = [Mock(spec=BaseTool)]
        initializer.vectorstore = Mock()
        initializer.agent = mock_agent

        # Yield (not return) so the patches above remain in effect for the test.
        yield initializer
|
|
|
|
|
|
|
|
|
|
|
|
# Test initialize_llm method of AgentNodeInitializer class
|
|
|
|
@pytest.mark.parametrize("model_type", ["openai", "huggingface", "invalid"])
def test_agent_node_initializer_initialize_llm(model_type, mock_agent_node_initializer):
    """Check that ``initialize_llm`` builds the backend selected by ``model_type``.

    An unknown model type must raise ``ValueError``. A known one must
    construct the corresponding (patched) LLM class exactly once.
    """
    call_args = (model_type, "model_id", "openai_api_key", 0.5)

    with patch("swarms.agents.agents.ChatOpenAI") as openai_cls, patch(
        "swarms.agents.agents.HuggingFaceLLM"
    ) as huggingface_cls:
        if model_type == "invalid":
            with pytest.raises(ValueError):
                mock_agent_node_initializer.initialize_llm(*call_args)
            return

        mock_agent_node_initializer.initialize_llm(*call_args)

        # Each supported model type maps to the class it should instantiate.
        backend_for = {
            "openai": openai_cls,
            "huggingface": huggingface_cls,
        }
        expected_backend = backend_for.get(model_type)
        if expected_backend is not None:
            expected_backend.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
# Test add_tool method of AgentNodeInitializer class
|
|
|
|
def test_agent_node_initializer_add_tool(mock_agent_node_initializer):
    """Check that ``add_tool`` places the supplied tool in the initializer's ``tools``."""
    with patch("swarms.agents.agents.BaseTool") as fake_tool:
        mock_agent_node_initializer.add_tool(fake_tool)

        registered_tools = mock_agent_node_initializer.tools
        assert fake_tool in registered_tools
|
|
|
|
|
|
|
|
|
|
|
|
# Test run method of AgentNodeInitializer class
|
|
|
|
@pytest.mark.parametrize("prompt", ["valid prompt", ""])
def test_agent_node_initializer_run(prompt, mock_agent_node_initializer):
    """Check that ``run`` rejects an empty prompt and otherwise reports completion."""
    if prompt != "":
        outcome = mock_agent_node_initializer.run(prompt)
        assert outcome == "Task completed by AgentNode"
        return

    # The empty prompt is expected to be refused outright.
    with pytest.raises(ValueError):
        mock_agent_node_initializer.run(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
# For initializing AgentNode in multiple tests
|
|
|
|
@pytest.fixture
def mock_agent_node():
    """Yield an ``AgentNode`` built while its LLM and initializer are patched.

    ``ChatOpenAI`` and ``AgentNodeInitializer`` in ``swarms.agents.agents`` are
    replaced with mocks. The fixture yields from inside the ``with`` block so
    both patches remain in effect for the whole test. Returning from inside
    the ``with`` (as before) tore them down before the test body ran. Any
    module-level lookup of those names during the test would then have hit
    the real classes.

    After construction, the node's ``llm_class`` is set to the patched
    ``ChatOpenAI`` mock and its ``vectorstore`` to a fresh ``Mock``.
    """
    with patch("swarms.agents.agents.ChatOpenAI") as mock_llm, patch(
        "swarms.agents.agents.AgentNodeInitializer"
    ) as patched_initializer_cls:
        node = AgentNode("test_key")
        node.llm_class = mock_llm
        node.vectorstore = Mock()
        # This sets an attribute on the patched class mock itself, not on the
        # instances it produces (those come from ``return_value``).
        patched_initializer_cls.llm = mock_llm

        # Yield (not return) so the patches above remain active for the test.
        yield node
|
|
|
|
|
|
|
|
|
|
|
|
# Test initialize_llm method of AgentNode class
|
|
|
|
@pytest.mark.parametrize("llm_class", ["openai", "huggingface"])
def test_agent_node_initialize_llm(llm_class, mock_agent_node):
    """Check that ``AgentNode.initialize_llm`` instantiates the requested backend once."""
    with patch("swarms.agents.agents.ChatOpenAI") as openai_cls, patch(
        "swarms.agents.agents.HuggingFaceLLM"
    ) as huggingface_cls:
        mock_agent_node.initialize_llm(llm_class)

        # Look up which patched class this parametrization should have used.
        backend_for = {"openai": openai_cls, "huggingface": huggingface_cls}
        chosen_backend = backend_for.get(llm_class)
        if chosen_backend is not None:
            chosen_backend.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
# Test initialize_tools method of AgentNode class
|
|
|
|
def test_agent_node_initialize_tools(mock_agent_node):
    """Check that ``initialize_tools`` uses every built-in tool it is expected to wire up.

    Each tool factory is patched in ``swarms.agents.agents``, and the test
    asserts that every one of them was called at least once.
    """
    with patch("swarms.agents.agents.DuckDuckGoSearchRun") as ddg_search, patch(
        "swarms.agents.agents.WriteFileTool"
    ) as write_file_tool, patch(
        "swarms.agents.agents.ReadFileTool"
    ) as read_file_tool, patch(
        "swarms.agents.agents.process_csv"
    ) as csv_processor, patch(
        "swarms.agents.agents.WebpageQATool"
    ) as webpage_qa_tool:
        mock_agent_node.initialize_tools("openai")

        # Checked in the same order as the patches were applied.
        expected_used = (
            ("DuckDuckGoSearchRun", ddg_search),
            ("WriteFileTool", write_file_tool),
            ("ReadFileTool", read_file_tool),
            ("process_csv", csv_processor),
            ("WebpageQATool", webpage_qa_tool),
        )
        for tool_name, patched_tool in expected_used:
            assert patched_tool.called, f"{tool_name} was not used"
|
|
|
|
|
|
|
|
|
|
|
|
# Test create_agent method of AgentNode class
|
|
|
|
def test_agent_node_create_agent(mock_agent_node):
    """Check that ``create_agent`` builds one initializer and asks it for one agent.

    The node's own ``initialize_llm``, ``initialize_tools`` and
    ``initialize_vectorstore`` are replaced with mocks. Only the
    delegation to ``AgentNodeInitializer`` is exercised.
    """
    node = mock_agent_node
    with patch.object(node, "initialize_llm"), patch.object(
        node, "initialize_tools"
    ), patch.object(node, "initialize_vectorstore"), patch(
        "swarms.agents.agents.AgentNodeInitializer"
    ) as initializer_cls:
        node.create_agent()

        initializer_cls.assert_called_once()
        built_initializer = initializer_cls.return_value
        built_initializer.create_agent.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
# Test agent function
|
|
|
|
@pytest.mark.parametrize(
    "openai_api_key,objective",
    [("valid_key", "valid_objective"), ("", "valid_objective"), ("valid_key", "")],
)
def test_agent(openai_api_key, objective):
    """Check ``agent``'s input handling and its return value.

    ``agent`` must raise ``ValueError`` for an empty key or objective. Otherwise
    it must return what the created agent's ``run`` produces.
    """
    if "" in (openai_api_key, objective):
        with pytest.raises(ValueError):
            agent(openai_api_key, objective)
        return

    with patch("swarms.agents.agents.AgentNodeInitializer") as initializer_cls:
        # The agent object that AgentNodeInitializer(...).create_agent() hands back.
        built_agent = initializer_cls.return_value.create_agent.return_value
        built_agent.run.return_value = "Agent output"

        assert agent(openai_api_key, objective) == "Agent output"
|