diff --git a/swarms/agents/exceptions.py b/swarms/agents/exceptions.py new file mode 100644 index 00000000..a07fa88f --- /dev/null +++ b/swarms/agents/exceptions.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Optional + +class ToolAgentError(Exception): + """Base exception for all tool agent errors.""" + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): + self.message = message + self.details = details or {} + super().__init__(self.message) + +class ToolExecutionError(ToolAgentError): + """Raised when a tool fails to execute.""" + def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None): + message = f"Failed to execute tool '{tool_name}': {str(error)}" + super().__init__(message, details) + +class ToolValidationError(ToolAgentError): + """Raised when tool parameters fail validation.""" + def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None): + message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}" + super().__init__(message, details) + +class ToolNotFoundError(ToolAgentError): + """Raised when a requested tool is not found.""" + def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None): + message = f"Tool '{tool_name}' not found" + super().__init__(message, details) + +class ToolParameterError(ToolAgentError): + """Raised when tool parameters are invalid.""" + def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None): + message = f"Invalid parameters for tool '{tool_name}': {error}" + super().__init__(message, details) \ No newline at end of file diff --git a/swarms/agents/tool_agent.py b/swarms/agents/tool_agent.py index 2d19ec26..d69d6c2f 100644 --- a/swarms/agents/tool_agent.py +++ b/swarms/agents/tool_agent.py @@ -1,156 +1,243 @@ -from typing import Any, Optional, Callable -from swarms.tools.json_former import Jsonformer -from swarms.utils.loguru_logger import initialize_logger - -logger = initialize_logger(log_folder="tool_agent") - +from typing import List, Optional, Dict, Any, Callable +from loguru import logger +from swarms.agents.exceptions import ( + ToolAgentError, + ToolExecutionError, + ToolValidationError, + ToolNotFoundError, + ToolParameterError +) class ToolAgent: """ - Represents a tool agent that performs a specific task using a model and tokenizer. - - Args: - name (str): The name of the tool agent. - description (str): A description of the tool agent. - model (Any): The model used by the tool agent. - tokenizer (Any): The tokenizer used by the tool agent. - json_schema (Any): The JSON schema used by the tool agent. - *args: Variable length arguments. - **kwargs: Keyword arguments. - - Attributes: - name (str): The name of the tool agent. - description (str): A description of the tool agent. - model (Any): The model used by the tool agent. - tokenizer (Any): The tokenizer used by the tool agent. - json_schema (Any): The JSON schema used by the tool agent. - - Methods: - run: Runs the tool agent for a specific task. - - Raises: - Exception: If an error occurs while running the tool agent. - - - Example: - from transformers import AutoModelForCausalLM, AutoTokenizer - from swarms import ToolAgent - - - model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-12b") - tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-12b") - - json_schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "number"}, - "is_student": {"type": "boolean"}, - "courses": { - "type": "array", - "items": {"type": "string"} - } - } - } - - task = "Generate a person's information based on the following schema:" - agent = ToolAgent(model=model, tokenizer=tokenizer, json_schema=json_schema) - generated_data = agent.run(task) - - print(generated_data) + A wrapper class for vLLM that provides a similar interface to LiteLLM. + This class handles model initialization and inference using vLLM. """ def __init__( self, - name: str = "Function Calling Agent", - description: str = "Generates a function based on the input json schema and the task", - model: Any = None, - tokenizer: Any = None, - json_schema: Any = None, - max_number_tokens: int = 500, - parsing_function: Optional[Callable] = None, - llm: Any = None, + model_name: str = "meta-llama/Llama-2-7b-chat-hf", + system_prompt: Optional[str] = None, + stream: bool = False, + temperature: float = 0.5, + max_tokens: int = 4000, + max_completion_tokens: int = 4000, + tools_list_dictionary: Optional[List[Dict[str, Any]]] = None, + tool_choice: str = "auto", + parallel_tool_calls: bool = False, + retry_attempts: int = 3, + retry_interval: float = 1.0, *args, **kwargs, ): - super().__init__( - agent_name=name, - agent_description=description, - llm=llm, - **kwargs, + """ + Initialize the vLLM wrapper with the given parameters. + Args: + model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf". + system_prompt (str, optional): The system prompt to use. Defaults to None. + stream (bool): Whether to stream the output. Defaults to False. + temperature (float): The temperature for sampling. Defaults to 0.5. + max_tokens (int): The maximum number of tokens to generate. Defaults to 4000. + max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000. + tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None. + tool_choice (str): How to choose tools. Defaults to "auto". + parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False. + retry_attempts (int): Number of retry attempts for failed operations. Defaults to 3. + retry_interval (float): Time to wait between retries in seconds. Defaults to 1.0. + """ + self.model_name = model_name + self.system_prompt = system_prompt + self.stream = stream + self.temperature = temperature + self.max_tokens = max_tokens + self.max_completion_tokens = max_completion_tokens + self.tools_list_dictionary = tools_list_dictionary + self.tool_choice = tool_choice + self.parallel_tool_calls = parallel_tool_calls + self.retry_attempts = retry_attempts + self.retry_interval = retry_interval + + # Initialize vLLM + try: + self.llm = LLM(model=model_name, **kwargs) + self.sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_tokens, + ) + except Exception as e: + raise ToolExecutionError( + "model_initialization", + e, + {"model_name": model_name, "kwargs": kwargs} + ) + + def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None: + """ + Validate tool parameters before execution. + Args: + tool_name (str): Name of the tool to validate + parameters (Dict[str, Any]): Parameters to validate + Raises: + ToolValidationError: If validation fails + """ + if not self.tools_list_dictionary: + raise ToolValidationError( + tool_name, + "parameters", + "No tools available for validation" + ) + + tool_spec = next( + (tool for tool in self.tools_list_dictionary if tool["name"] == tool_name), + None ) - self.name = name - self.description = description - self.model = model - self.tokenizer = tokenizer - self.json_schema = json_schema - self.max_number_tokens = max_number_tokens - self.parsing_function = parsing_function - - def run(self, task: str, *args, **kwargs): + + if not tool_spec: + raise ToolNotFoundError(tool_name) + + required_params = { + param["name"] for param in tool_spec["parameters"] + if param.get("required", True) + } + + missing_params = required_params - set(parameters.keys()) + if missing_params: + raise ToolParameterError( + tool_name, + f"Missing required parameters: {', '.join(missing_params)}" + ) + + def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any: """ - Run the tool agent for the specified task. + Execute a function with retry logic. + Args: + func (Callable): Function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + Returns: + Any: Result of the function execution + Raises: + ToolExecutionError: If all retry attempts fail + """ + last_error = None + for attempt in range(self.retry_attempts): + try: + return func(*args, **kwargs) + except Exception as e: + last_error = e + logger.warning( + f"Attempt {attempt + 1}/{self.retry_attempts} failed: {str(e)}" + ) + if attempt < self.retry_attempts - 1: + time.sleep(self.retry_interval) + raise ToolExecutionError( + func.__name__, + last_error, + {"attempts": self.retry_attempts} + ) + + def run(self, task: str, *args, **kwargs) -> str: + """ + Run the tool agent for the specified task. Args: task (str): The task to be performed by the tool agent. *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. - Returns: The output of the tool agent. - Raises: - Exception: If an error occurs during the execution of the tool agent. + ToolExecutionError: If an error occurs during execution. """ try: - if self.model: - logger.info(f"Running {self.name} for task: {task}") - self.toolagent = Jsonformer( - model=self.model, - tokenizer=self.tokenizer, - json_schema=self.json_schema, - llm=self.llm, - prompt=task, - max_number_tokens=self.max_number_tokens, - *args, - **kwargs, + if not self.llm: + raise ToolExecutionError( + "run", + Exception("LLM not initialized"), + {"task": task} ) - if self.parsing_function: - out = self.parsing_function(self.toolagent()) - else: - out = self.toolagent() - - return out - elif self.llm: - logger.info(f"Running {self.name} for task: {task}") - self.toolagent = Jsonformer( - json_schema=self.json_schema, - llm=self.llm, - prompt=task, - max_number_tokens=self.max_number_tokens, - *args, - **kwargs, - ) + logger.info(f"Running task: {task}") + + # Prepare the prompt + prompt = self._prepare_prompt(task) + + # Execute with retry logic + outputs = self._execute_with_retry( + self.llm.generate, + prompt, + self.sampling_params + ) + + response = outputs[0].outputs[0].text.strip() + return response - if self.parsing_function: - out = self.parsing_function(self.toolagent()) - else: - out = self.toolagent() + except Exception as error: + logger.error(f"Error running task: {error}") + raise ToolExecutionError( + "run", + error, + {"task": task, "args": args, "kwargs": kwargs} + ) - return out + def _prepare_prompt(self, task: str) -> str: + """ + Prepare the prompt for the given task. + Args: + task (str): The task to prepare the prompt for. + Returns: + str: The prepared prompt. + """ + if self.system_prompt: + return f"{self.system_prompt}\n\nUser: {task}\nAssistant:" + return f"User: {task}\nAssistant:" - else: - raise Exception( - "Either model or llm should be provided to the" - " ToolAgent" - ) + def __call__(self, task: str, *args, **kwargs) -> str: + """ + Call the model for the given task. + Args: + task (str): The task to run the model for. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + Returns: + str: The model's response. + """ + return self.run(task, *args, **kwargs) + + def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]: + """ + Run the model for multiple tasks in batches. + Args: + tasks (List[str]): List of tasks to run. + batch_size (int): Size of each batch. Defaults to 10. + Returns: + List[str]: List of model responses. + Raises: + ToolExecutionError: If an error occurs during batch execution. + """ + logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}") + results = [] + + try: + for i in range(0, len(tasks), batch_size): + batch = tasks[i:i + batch_size] + for task in batch: + logger.info(f"Running task: {task}") + try: + result = self.run(task) + results.append(result) + except ToolExecutionError as e: + logger.error(f"Failed to execute task '{task}': {e}") + results.append(f"Error: {str(e)}") + continue + + logger.info("Completed all tasks.") + return results except Exception as error: - logger.error( - f"Error running {self.name} for task: {task}" + logger.error(f"Error in batch execution: {error}") + raise ToolExecutionError( + "batched_run", + error, + {"tasks": tasks, "batch_size": batch_size} ) - raise error - - def __call__(self, task: str, *args, **kwargs): - return self.run(task, *args, **kwargs) diff --git a/tests/agents/test_tool_agent.py b/tests/agents/test_tool_agent.py index 691489c0..9f8344d0 100644 --- a/tests/agents/test_tool_agent.py +++ b/tests/agents/test_tool_agent.py @@ -1,8 +1,15 @@ from unittest.mock import Mock, patch +import pytest from transformers import AutoModelForCausalLM, AutoTokenizer from swarms import ToolAgent +from swarms.agents.exceptions import ( + ToolExecutionError, + ToolValidationError, + ToolNotFoundError, + ToolParameterError +) def test_tool_agent_init(): @@ -99,3 +106,123 @@ def test_tool_agent_init_with_kwargs(): agent.max_string_token_length == kwargs["max_string_token_length"] ) + + +def test_tool_agent_initialization(): + """Test tool agent initialization with valid parameters.""" + agent = ToolAgent( + model_name="test-model", + temperature=0.7, + max_tokens=1000 + ) + assert agent.model_name == "test-model" + assert agent.temperature == 0.7 + assert agent.max_tokens == 1000 + assert agent.retry_attempts == 3 + assert agent.retry_interval == 1.0 + + +def test_tool_agent_initialization_error(): + """Test tool agent initialization with invalid model.""" + with pytest.raises(ToolExecutionError) as exc_info: + ToolAgent(model_name="invalid-model") + assert "model_initialization" in str(exc_info.value) + + +def test_tool_validation(): + """Test tool parameter validation.""" + tools_list = [{ + "name": "test_tool", + "parameters": [ + {"name": "required_param", "required": True}, + {"name": "optional_param", "required": False} + ] + }] + + agent = ToolAgent(tools_list_dictionary=tools_list) + + # Test missing required parameter + with pytest.raises(ToolParameterError) as exc_info: + agent._validate_tool("test_tool", {}) + assert "Missing required parameters" in str(exc_info.value) + + # Test valid parameters + agent._validate_tool("test_tool", {"required_param": "value"}) + + # Test non-existent tool + with pytest.raises(ToolNotFoundError) as exc_info: + agent._validate_tool("non_existent_tool", {}) + assert "Tool 'non_existent_tool' not found" in str(exc_info.value) + + +def test_retry_mechanism(): + """Test retry mechanism for failed operations.""" + mock_llm = Mock() + mock_llm.generate.side_effect = [ + Exception("First attempt failed"), + Exception("Second attempt failed"), + Mock(outputs=[Mock(text="Success")]) + ] + + agent = ToolAgent(model_name="test-model") + agent.llm = mock_llm + + # Test successful retry + result = agent.run("test task") + assert result == "Success" + assert mock_llm.generate.call_count == 3 + + # Test all retries failing + mock_llm.generate.side_effect = Exception("All attempts failed") + with pytest.raises(ToolExecutionError) as exc_info: + agent.run("test task") + assert "All attempts failed" in str(exc_info.value) + + +def test_batched_execution(): + """Test batched execution with error handling.""" + mock_llm = Mock() + mock_llm.generate.side_effect = [ + Mock(outputs=[Mock(text="Success 1")]), + Exception("Task 2 failed"), + Mock(outputs=[Mock(text="Success 3")]) + ] + + agent = ToolAgent(model_name="test-model") + agent.llm = mock_llm + + tasks = ["Task 1", "Task 2", "Task 3"] + results = agent.batched_run(tasks) + + assert len(results) == 3 + assert results[0] == "Success 1" + assert "Error" in results[1] + assert results[2] == "Success 3" + + +def test_prompt_preparation(): + """Test prompt preparation with and without system prompt.""" + # Test without system prompt + agent = ToolAgent() + prompt = agent._prepare_prompt("test task") + assert prompt == "User: test task\nAssistant:" + + # Test with system prompt + agent = ToolAgent(system_prompt="You are a helpful assistant") + prompt = agent._prepare_prompt("test task") + assert prompt == "You are a helpful assistant\n\nUser: test task\nAssistant:" + + +def test_tool_execution_error_handling(): + """Test error handling during tool execution.""" + agent = ToolAgent(model_name="test-model") + agent.llm = None # Simulate uninitialized LLM + + with pytest.raises(ToolExecutionError) as exc_info: + agent.run("test task") + assert "LLM not initialized" in str(exc_info.value) + + # Test with invalid parameters + with pytest.raises(ToolExecutionError) as exc_info: + agent.run("test task", invalid_param="value") + assert "Error running task" in str(exc_info.value)