You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/tests/agents/agents.py

145 lines
5.3 KiB

import pytest
from unittest.mock import Mock, patch
from swarms.agents.agents import (
AgentNodeInitializer,
AgentNode,
agent,
) # replace with actual import
# For initializing AgentNodeInitializer in multiple tests
@pytest.fixture
def mock_agent_node_initializer():
with patch("swarms.agents.agents.ChatOpenAI") as mock_llm, patch(
"swarms.agents.agents.AutoGPT"
) as mock_agent:
initializer = AgentNodeInitializer(
model_type="openai",
model_id="test",
openai_api_key="test_key",
temperature=0.5,
)
initializer.llm = mock_llm
initializer.tools = [Mock(spec=BaseTool)]
initializer.vectorstore = Mock()
initializer.agent = mock_agent
return initializer
# Test initialize_llm method of AgentNodeInitializer class
@pytest.mark.parametrize("model_type", ["openai", "huggingface", "invalid"])
def test_agent_node_initializer_initialize_llm(model_type, mock_agent_node_initializer):
with patch("swarms.agents.agents.ChatOpenAI") as mock_openai, patch(
"swarms.agents.agents.HuggingFaceLLM"
) as mock_huggingface:
if model_type == "invalid":
with pytest.raises(ValueError):
mock_agent_node_initializer.initialize_llm(
model_type, "model_id", "openai_api_key", 0.5
)
else:
mock_agent_node_initializer.initialize_llm(
model_type, "model_id", "openai_api_key", 0.5
)
if model_type == "openai":
mock_openai.assert_called_once()
elif model_type == "huggingface":
mock_huggingface.assert_called_once()
# Test add_tool method of AgentNodeInitializer class
def test_agent_node_initializer_add_tool(mock_agent_node_initializer):
with patch("swarms.agents.agents.BaseTool") as mock_base_tool:
mock_agent_node_initializer.add_tool(mock_base_tool)
assert mock_base_tool in mock_agent_node_initializer.tools
# Test run method of AgentNodeInitializer class
@pytest.mark.parametrize("prompt", ["valid prompt", ""])
def test_agent_node_initializer_run(prompt, mock_agent_node_initializer):
if prompt == "":
with pytest.raises(ValueError):
mock_agent_node_initializer.run(prompt)
else:
assert mock_agent_node_initializer.run(prompt) == "Task completed by AgentNode"
# For initializing AgentNode in multiple tests
@pytest.fixture
def mock_agent_node():
with patch("swarms.agents.agents.ChatOpenAI") as mock_llm, patch(
"swarms.agents.agents.AgentNodeInitializer"
) as mock_agent_node_initializer:
mock_agent_node = AgentNode("test_key")
mock_agent_node.llm_class = mock_llm
mock_agent_node.vectorstore = Mock()
mock_agent_node_initializer.llm = mock_llm
return mock_agent_node
# Test initialize_llm method of AgentNode class
@pytest.mark.parametrize("llm_class", ["openai", "huggingface"])
def test_agent_node_initialize_llm(llm_class, mock_agent_node):
with patch("swarms.agents.agents.ChatOpenAI") as mock_openai, patch(
"swarms.agents.agents.HuggingFaceLLM"
) as mock_huggingface:
mock_agent_node.initialize_llm(llm_class)
if llm_class == "openai":
mock_openai.assert_called_once()
elif llm_class == "huggingface":
mock_huggingface.assert_called_once()
# Test initialize_tools method of AgentNode class
def test_agent_node_initialize_tools(mock_agent_node):
with patch("swarms.agents.agents.DuckDuckGoSearchRun") as mock_ddg, patch(
"swarms.agents.agents.WriteFileTool"
) as mock_write_file, patch(
"swarms.agents.agents.ReadFileTool"
) as mock_read_file, patch(
"swarms.agents.agents.process_csv"
) as mock_process_csv, patch(
"swarms.agents.agents.WebpageQATool"
) as mock_webpage_qa:
mock_agent_node.initialize_tools("openai")
assert mock_ddg.called
assert mock_write_file.called
assert mock_read_file.called
assert mock_process_csv.called
assert mock_webpage_qa.called
# Test create_agent method of AgentNode class
def test_agent_node_create_agent(mock_agent_node):
with patch.object(mock_agent_node, "initialize_llm"), patch.object(
mock_agent_node, "initialize_tools"
), patch.object(mock_agent_node, "initialize_vectorstore"), patch(
"swarms.agents.agents.AgentNodeInitializer"
) as mock_agent_node_initializer:
mock_agent_node.create_agent()
mock_agent_node_initializer.assert_called_once()
mock_agent_node_initializer.return_value.create_agent.assert_called_once()
# Test agent function
@pytest.mark.parametrize(
"openai_api_key,objective",
[("valid_key", "valid_objective"), ("", "valid_objective"), ("valid_key", "")],
)
def test_agent(openai_api_key, objective):
if openai_api_key == "" or objective == "":
with pytest.raises(ValueError):
agent(openai_api_key, objective)
else:
with patch(
"swarms.agents.agents.AgentNodeInitializer"
) as mock_agent_node_initializer:
mock_agent_node = (
mock_agent_node_initializer.return_value.create_agent.return_value
)
mock_agent_node.run.return_value = "Agent output"
result = agent(openai_api_key, objective)
assert result == "Agent output"