You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/tools/test_base_tool.py

127 lines
3.3 KiB

from pydantic import BaseModel
from typing import Optional
import json
from swarms.tools.base_tool import BaseTool
class TestModel(BaseModel):
    # Small pydantic model used as a fixture by the tests in this file:
    # it is passed as a class to ``base_model_to_dict`` and as an instance
    # to ``detect_tool_input_type``.
    # NOTE(review): the ``Test`` prefix may lead pytest to try collecting
    # this class as a test class -- confirm whether that matters here.
    name: str  # no default; supplied by the tests as "Test"
    age: int  # no default; supplied by the tests as 25
    email: Optional[str] = None  # may be omitted; defaults to None
def sample_function(x: int, y: int) -> int:
    """Test function for addition."""
    # The docstring above is kept verbatim: this function is handed to the
    # tool under test, which presumably may read it -- verify before editing.
    total = x + y
    return total
def test_func_to_dict():
    """Check the schema ``func_to_dict`` produces for ``sample_function``.

    The result must be tagged with type ``"function"``, carry the supplied
    name, and contain a ``"parameters"`` entry under ``"function"``.
    """
    print("Testing func_to_dict")

    schema = BaseTool().func_to_dict(
        function=sample_function,
        name="sample_function",
        description="Test function",
    )

    # The type tag is checked before the nested entry is touched, so a
    # missing "function" key cannot mask a wrong type.
    assert schema["type"] == "function"
    function_entry = schema["function"]
    assert function_entry["name"] == "sample_function"
    assert "parameters" in function_entry

    print("func_to_dict test passed")
def test_base_model_to_dict():
    """Check the dictionary ``base_model_to_dict`` builds from ``TestModel``.

    The result must have a ``"type"`` key, and its ``"properties"`` entry
    must itself hold a ``"properties"`` mapping that lists ``"name"``.
    """
    print("Testing base_model_to_dict")

    schema = BaseTool().base_model_to_dict(TestModel)

    assert "type" in schema
    outer_properties = schema["properties"]
    assert "properties" in outer_properties
    assert "name" in outer_properties["properties"]

    print("base_model_to_dict test passed")
def test_detect_tool_input_type():
    """Check the label ``detect_tool_input_type`` returns for each input kind.

    A pydantic model instance, a plain dict, and a function are expected to
    be reported as ``"Pydantic"``, ``"Dictionary"`` and ``"Function"``.
    """
    print("Testing detect_tool_input_type")

    tool = BaseTool()
    # (input, expected label) pairs, checked in this order.
    cases = (
        (TestModel(name="Test", age=25), "Pydantic"),
        ({"key": "value"}, "Dictionary"),
        (sample_function, "Function"),
    )
    for candidate, expected_label in cases:
        assert tool.detect_tool_input_type(candidate) == expected_label

    print("detect_tool_input_type test passed")
def test_execute_tool_by_name():
    """Check that ``execute_tool_by_name`` runs a mapped function.

    ``sample_function`` is registered in ``function_map``; a JSON string
    naming it with ``x=1`` and ``y=2`` must produce ``3``.
    """
    print("Testing execute_tool_by_name")

    registry = {"sample_function": sample_function}
    tool = BaseTool(function_map=registry, verbose=True)

    call_payload = {
        "name": "sample_function",
        "parameters": {"x": 1, "y": 2},
    }
    outcome = tool.execute_tool_by_name(
        "sample_function", json.dumps(call_payload)
    )
    assert outcome == 3

    print("execute_tool_by_name test passed")
def test_check_str_for_functions_valid():
    """Check ``check_str_for_functions_valid`` on one good and one bad payload.

    A JSON string of type ``"function"`` naming the mapped ``test_func`` must
    yield exactly ``True``; a JSON string with type ``"invalid"`` must yield
    exactly ``False``.
    """
    print("Testing check_str_for_functions_valid")

    tool = BaseTool(function_map={"test_func": lambda x: x})

    well_formed = {"type": "function", "function": {"name": "test_func"}}
    malformed = {"type": "invalid"}

    # Identity checks (``is``) are intentional: the method must return real
    # booleans, not merely truthy/falsy values.
    for payload, expected in ((well_formed, True), (malformed, False)):
        verdict = tool.check_str_for_functions_valid(json.dumps(payload))
        assert verdict is expected

    print("check_str_for_functions_valid test passed")
def test_convert_funcs_into_tools():
    """Check that ``convert_funcs_into_tools`` registers the given function.

    After conversion, ``function_map`` must contain a callable stored under
    the key ``"sample_function"``.
    """
    print("Testing convert_funcs_into_tools")

    tool = BaseTool(tools=[sample_function])
    tool.convert_funcs_into_tools()

    registered = tool.function_map
    assert "sample_function" in registered
    assert callable(registered["sample_function"])

    print("convert_funcs_into_tools test passed")
def run_all_tests(tests: Optional[list] = None) -> int:
    """Run each test callable in order and report how many of them failed.

    A failing test does not stop the run: the failure is printed and the
    remaining tests still execute.

    Args:
        tests: Zero-argument callables to run. When ``None`` (the default),
            the six ``test_*`` functions defined in this module are run.

    Returns:
        The number of callables that raised, counting both assertion
        failures and unexpected errors. ``0`` means everything passed.
    """
    print("Starting all tests")

    if tests is None:
        # Built only when no explicit list is given, so the runner can be
        # driven with an arbitrary set of callables.
        tests = [
            test_func_to_dict,
            test_base_model_to_dict,
            test_detect_tool_input_type,
            test_execute_tool_by_name,
            test_check_str_for_functions_valid,
            test_convert_funcs_into_tools,
        ]

    failures = 0
    for test in tests:
        try:
            test()
        except AssertionError as e:
            failures += 1
            print(f"Test {test.__name__} failed: {str(e)}")
        except Exception as e:
            # Deliberately broad: one broken test must not abort the rest,
            # but it is still counted as a failure.
            failures += 1
            print(f"Unexpected error in {test.__name__}: {str(e)}")

    print("All tests completed")
    return failures


if __name__ == "__main__":
    # Surface failures through the exit status so scripts and CI running this
    # file directly do not see a failed run as a success.
    if run_all_tests():
        raise SystemExit(1)