|
|
|
from unittest.mock import Mock, patch
|
|
|
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
from swarms import ToolAgent
|
|
|
|
|
|
|
|
|
|
|
|
def test_tool_agent_init():
    """Check that ToolAgent exposes its constructor arguments as attributes."""
    # Spec'd mocks stand in for the model and tokenizer objects.
    mock_model = Mock(spec=AutoModelForCausalLM)
    mock_tokenizer = Mock(spec=AutoTokenizer)

    properties = {
        "name": {"type": "string"},
        "age": {"type": "number"},
        "is_student": {"type": "boolean"},
        "courses": {"type": "array", "items": {"type": "string"}},
    }
    schema = {"type": "object", "properties": properties}

    agent_name = "Test Agent"
    agent_description = "This is a test agent"

    # Arguments are passed positionally, in the order the test relies on.
    agent = ToolAgent(
        agent_name,
        agent_description,
        mock_model,
        mock_tokenizer,
        schema,
    )

    assert agent.name == agent_name
    assert agent.description == agent_description
    assert agent.model == mock_model
    assert agent.tokenizer == mock_tokenizer
    assert agent.json_schema == schema
|
|
|
|
|
|
|
|
|
|
|
|
@patch.object(ToolAgent, "run")
def test_tool_agent_run(mock_run):
    """Check that calling ``run`` on an agent reaches the patched method once.

    NOTE(review): ``ToolAgent.run`` is replaced by a mock for the duration
    of this test, so the real implementation is never executed; only the
    call and its argument are verified.
    """
    # Spec'd mocks stand in for the model and tokenizer objects.
    mock_model = Mock(spec=AutoModelForCausalLM)
    mock_tokenizer = Mock(spec=AutoTokenizer)

    properties = {
        "name": {"type": "string"},
        "age": {"type": "number"},
        "is_student": {"type": "boolean"},
        "courses": {"type": "array", "items": {"type": "string"}},
    }
    schema = {"type": "object", "properties": properties}

    agent_name = "Test Agent"
    agent_description = "This is a test agent"
    task = "Generate a person's information based on the following schema:"

    agent = ToolAgent(
        agent_name,
        agent_description,
        mock_model,
        mock_tokenizer,
        schema,
    )
    agent.run(task)

    mock_run.assert_called_once_with(task)
|
|
|
|
|
|
|
|
|
|
|
|
def test_tool_agent_init_with_kwargs():
    """Check that extra keyword options are exposed as agent attributes.

    The agent is built with the same five positional arguments as the
    plain init test plus a set of keyword options; every positional
    argument and every option must be readable back from the instance.
    """
    # Spec'd mocks stand in for the model and tokenizer objects.
    mock_model = Mock(spec=AutoModelForCausalLM)
    mock_tokenizer = Mock(spec=AutoTokenizer)

    properties = {
        "name": {"type": "string"},
        "age": {"type": "number"},
        "is_student": {"type": "boolean"},
        "courses": {"type": "array", "items": {"type": "string"}},
    }
    schema = {"type": "object", "properties": properties}

    agent_name = "Test Agent"
    agent_description = "This is a test agent"

    options = {
        "debug": True,
        "max_array_length": 20,
        "max_number_tokens": 12,
        "temperature": 0.5,
        "max_string_token_length": 20,
    }

    agent = ToolAgent(
        agent_name,
        agent_description,
        mock_model,
        mock_tokenizer,
        schema,
        **options,
    )

    assert agent.name == agent_name
    assert agent.description == agent_description
    assert agent.model == mock_model
    assert agent.tokenizer == mock_tokenizer
    assert agent.json_schema == schema

    # Each keyword option must come back unchanged under its own name;
    # dict order matches the order the options were declared above.
    for option_name, expected in options.items():
        assert getattr(agent, option_name) == expected
|