You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/agents/test_tool_agent.py

125 lines
3.3 KiB

from unittest.mock import Mock, patch
from transformers import AutoModelForCausalLM, AutoTokenizer
from swarms import ToolAgent
def test_tool_agent_init():
    """Check that ToolAgent exposes each positional constructor argument.

    The model and tokenizer are ``Mock`` objects spec'd against the
    ``transformers`` classes, so no real model or tokenizer is constructed.
    """
    fake_model = Mock(spec=AutoModelForCausalLM)
    fake_tokenizer = Mock(spec=AutoTokenizer)
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
            "is_student": {"type": "boolean"},
            "courses": {"type": "array", "items": {"type": "string"}},
        },
    }
    agent_name = "Test Agent"
    agent_description = "This is a test agent"

    agent = ToolAgent(
        agent_name, agent_description, fake_model, fake_tokenizer, schema
    )

    # Attribute name -> value it should hold; checked in declaration order.
    expected_attributes = {
        "name": agent_name,
        "description": agent_description,
        "model": fake_model,
        "tokenizer": fake_tokenizer,
        "json_schema": schema,
    }
    for attribute, expected in expected_attributes.items():
        assert getattr(agent, attribute) == expected
@patch.object(ToolAgent, "run")
def test_tool_agent_run(mock_run):
    """Check that ``agent.run(task)`` results in exactly one call with ``task``.

    ``patch.object`` swaps ``ToolAgent.run`` for ``mock_run`` on the class.
    The mock is not a descriptor, so the call is recorded without ``self``.

    NOTE(review): because ``run`` itself is patched out, this only verifies
    that the call reached the mock. It does not exercise the real
    ``ToolAgent.run`` implementation.
    """
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
            "is_student": {"type": "boolean"},
            "courses": {"type": "array", "items": {"type": "string"}},
        },
    }
    task = "Generate a person's information based on the following schema:"

    agent = ToolAgent(
        "Test Agent",
        "This is a test agent",
        Mock(spec=AutoModelForCausalLM),
        Mock(spec=AutoTokenizer),
        schema,
    )
    agent.run(task)

    mock_run.assert_called_once_with(task)
def test_tool_agent_init_with_kwargs():
    """Check that ToolAgent keeps positional args and extra keyword args.

    Every keyword argument passed to the constructor is expected to be
    readable back as an attribute of the same name.
    """
    fake_model = Mock(spec=AutoModelForCausalLM)
    fake_tokenizer = Mock(spec=AutoTokenizer)
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
            "is_student": {"type": "boolean"},
            "courses": {"type": "array", "items": {"type": "string"}},
        },
    }
    agent_name = "Test Agent"
    agent_description = "This is a test agent"
    extra_kwargs = {
        "debug": True,
        "max_array_length": 20,
        "max_number_tokens": 12,
        "temperature": 0.5,
        "max_string_token_length": 20,
    }

    agent = ToolAgent(
        agent_name,
        agent_description,
        fake_model,
        fake_tokenizer,
        schema,
        **extra_kwargs,
    )

    positional_expectations = (
        ("name", agent_name),
        ("description", agent_description),
        ("model", fake_model),
        ("tokenizer", fake_tokenizer),
        ("json_schema", schema),
    )
    for attribute, expected in positional_expectations:
        assert getattr(agent, attribute) == expected

    # Dict insertion order keeps the checks in the same sequence as declared.
    for option, expected in extra_kwargs.items():
        assert getattr(agent, option) == expected