from unittest.mock import Mock, patch

from transformers import AutoModelForCausalLM, AutoTokenizer

from swarms import ToolAgent


def _make_person_schema():
    """Return a fresh JSON schema describing a simple "person" object.

    A new dict is built on every call so that a mutation made during one
    test (by the test itself or by ToolAgent) can never leak into another.
    """
    return {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "number"},
            "is_student": {"type": "boolean"},
            "courses": {"type": "array", "items": {"type": "string"}},
        },
    }


def _make_model_and_tokenizer():
    """Return ``(model, tokenizer)`` mocks spec'd on the transformers classes.

    ``spec=`` makes the mocks reject attribute access that the real
    ``AutoModelForCausalLM`` / ``AutoTokenizer`` classes do not define.
    """
    return Mock(spec=AutoModelForCausalLM), Mock(spec=AutoTokenizer)


def test_tool_agent_init():
    """ToolAgent stores its positional constructor arguments as attributes."""
    model, tokenizer = _make_model_and_tokenizer()
    json_schema = _make_person_schema()
    name = "Test Agent"
    description = "This is a test agent"

    agent = ToolAgent(name, description, model, tokenizer, json_schema)

    assert agent.name == name
    assert agent.description == description
    assert agent.model == model
    assert agent.tokenizer == tokenizer
    assert agent.json_schema == json_schema


@patch.object(ToolAgent, "run")
def test_tool_agent_run(mock_run):
    """Calling ``agent.run(task)`` reaches ``ToolAgent.run`` once with ``task``.

    NOTE(review): ``ToolAgent.run`` is itself replaced by ``mock_run`` for
    the duration of this test, so the assertion below only verifies the
    call reached the patch -- it does not exercise ToolAgent's real
    generation logic. Consider patching the model/tokenizer calls that
    ``run`` makes instead, and asserting on its output.
    """
    model, tokenizer = _make_model_and_tokenizer()
    json_schema = _make_person_schema()
    name = "Test Agent"
    description = "This is a test agent"
    task = (
        "Generate a person's information based on the following"
        " schema:"
    )

    agent = ToolAgent(name, description, model, tokenizer, json_schema)
    agent.run(task)

    # The class-level patch is not bound to ``agent``, so only ``task`` is
    # recorded as an argument (no ``self``).
    mock_run.assert_called_once_with(task)


def test_tool_agent_init_with_kwargs():
    """Extra keyword arguments end up as same-named attributes on the agent."""
    model, tokenizer = _make_model_and_tokenizer()
    json_schema = _make_person_schema()
    name = "Test Agent"
    description = "This is a test agent"
    kwargs = {
        "debug": True,
        "max_array_length": 20,
        "max_number_tokens": 12,
        "temperature": 0.5,
        "max_string_token_length": 20,
    }

    agent = ToolAgent(
        name, description, model, tokenizer, json_schema, **kwargs
    )

    assert agent.name == name
    assert agent.description == description
    assert agent.model == model
    assert agent.tokenizer == tokenizer
    assert agent.json_schema == json_schema
    # One assertion per forwarded kwarg; the key is the failure message so a
    # mismatch names the offending attribute.
    for key, expected in kwargs.items():
        assert getattr(agent, key) == expected, key