You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
3.3 KiB
127 lines
3.3 KiB
from pydantic import BaseModel
|
|
from typing import Optional
|
|
import json
|
|
|
|
from swarms.tools.base_tool import BaseTool
|
|
|
|
|
|
class TestModel(BaseModel):
|
|
name: str
|
|
age: int
|
|
email: Optional[str] = None
|
|
|
|
|
|
def sample_function(x: int, y: int) -> int:
|
|
"""Test function for addition."""
|
|
return x + y
|
|
|
|
|
|
def test_func_to_dict():
|
|
print("Testing func_to_dict")
|
|
tool = BaseTool()
|
|
|
|
result = tool.func_to_dict(
|
|
function=sample_function,
|
|
name="sample_function",
|
|
description="Test function",
|
|
)
|
|
|
|
assert result["type"] == "function"
|
|
assert result["function"]["name"] == "sample_function"
|
|
assert "parameters" in result["function"]
|
|
print("func_to_dict test passed")
|
|
|
|
|
|
def test_base_model_to_dict():
|
|
print("Testing base_model_to_dict")
|
|
tool = BaseTool()
|
|
|
|
result = tool.base_model_to_dict(TestModel)
|
|
|
|
assert "type" in result
|
|
assert "properties" in result["properties"]
|
|
assert "name" in result["properties"]["properties"]
|
|
print("base_model_to_dict test passed")
|
|
|
|
|
|
def test_detect_tool_input_type():
|
|
print("Testing detect_tool_input_type")
|
|
tool = BaseTool()
|
|
|
|
model = TestModel(name="Test", age=25)
|
|
assert tool.detect_tool_input_type(model) == "Pydantic"
|
|
|
|
dict_input = {"key": "value"}
|
|
assert tool.detect_tool_input_type(dict_input) == "Dictionary"
|
|
|
|
assert tool.detect_tool_input_type(sample_function) == "Function"
|
|
print("detect_tool_input_type test passed")
|
|
|
|
|
|
def test_execute_tool_by_name():
|
|
print("Testing execute_tool_by_name")
|
|
tool = BaseTool(
|
|
function_map={"sample_function": sample_function},
|
|
verbose=True,
|
|
)
|
|
|
|
response = json.dumps(
|
|
{"name": "sample_function", "parameters": {"x": 1, "y": 2}}
|
|
)
|
|
|
|
result = tool.execute_tool_by_name("sample_function", response)
|
|
assert result == 3
|
|
print("execute_tool_by_name test passed")
|
|
|
|
|
|
def test_check_str_for_functions_valid():
|
|
print("Testing check_str_for_functions_valid")
|
|
tool = BaseTool(function_map={"test_func": lambda x: x})
|
|
|
|
valid_json = json.dumps(
|
|
{"type": "function", "function": {"name": "test_func"}}
|
|
)
|
|
|
|
assert tool.check_str_for_functions_valid(valid_json) is True
|
|
|
|
invalid_json = json.dumps({"type": "invalid"})
|
|
assert tool.check_str_for_functions_valid(invalid_json) is False
|
|
print("check_str_for_functions_valid test passed")
|
|
|
|
|
|
def test_convert_funcs_into_tools():
|
|
print("Testing convert_funcs_into_tools")
|
|
tool = BaseTool(tools=[sample_function])
|
|
|
|
tool.convert_funcs_into_tools()
|
|
assert "sample_function" in tool.function_map
|
|
assert callable(tool.function_map["sample_function"])
|
|
print("convert_funcs_into_tools test passed")
|
|
|
|
|
|
def run_all_tests():
|
|
print("Starting all tests")
|
|
|
|
tests = [
|
|
test_func_to_dict,
|
|
test_base_model_to_dict,
|
|
test_detect_tool_input_type,
|
|
test_execute_tool_by_name,
|
|
test_check_str_for_functions_valid,
|
|
test_convert_funcs_into_tools,
|
|
]
|
|
|
|
for test in tests:
|
|
try:
|
|
test()
|
|
except AssertionError as e:
|
|
print(f"Test {test.__name__} failed: {str(e)}")
|
|
except Exception as e:
|
|
print(f"Unexpected error in {test.__name__}: {str(e)}")
|
|
|
|
print("All tests completed")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_all_tests()
|