pull/801/merge
Pavan Kumar 4 days ago committed by GitHub
commit be5a744319

@@ -0,0 +1,32 @@
from typing import Any, Dict, Optional


class ToolAgentError(Exception):
    """Base exception for all tool agent errors."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)


class ToolExecutionError(ToolAgentError):
    """Raised when a tool fails to execute."""

    def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None):
        message = f"Failed to execute tool '{tool_name}': {str(error)}"
        super().__init__(message, details)


class ToolValidationError(ToolAgentError):
    """Raised when tool parameters fail validation."""

    def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None):
        message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
        super().__init__(message, details)


class ToolNotFoundError(ToolAgentError):
    """Raised when a requested tool is not found."""

    def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None):
        message = f"Tool '{tool_name}' not found"
        super().__init__(message, details)


class ToolParameterError(ToolAgentError):
    """Raised when tool parameters are invalid."""

    def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None):
        message = f"Invalid parameters for tool '{tool_name}': {error}"
        super().__init__(message, details)
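
For reference, a minimal sketch of how this hierarchy is meant to be consumed at a call site; the `agent` object and task string below are hypothetical, and `ToolAgentError` serves as the catch-all base:

    try:
        result = agent.run("summarize the report")  # hypothetical agent
    except ToolNotFoundError as e:
        print(f"Unknown tool: {e.message}")
    except ToolExecutionError as e:
        # details carries the structured context passed at raise time
        print(f"Execution failed: {e.message} ({e.details})")
    except ToolAgentError as e:
        print(f"Other tool agent error: {e.message}")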

@@ -1,156 +1,243 @@
import time
from typing import List, Optional, Dict, Any, Callable

from loguru import logger
from vllm import LLM, SamplingParams

from swarms.agents.exceptions import (
    ToolAgentError,
    ToolExecutionError,
    ToolValidationError,
    ToolNotFoundError,
    ToolParameterError,
)
class ToolAgent:
    """
    A wrapper class for vLLM that provides a similar interface to LiteLLM.
    This class handles model initialization and inference using vLLM.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        system_prompt: Optional[str] = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        retry_attempts: int = 3,
        retry_interval: float = 1.0,
        *args,
        **kwargs,
    ):
        """
        Initialize the vLLM wrapper with the given parameters.

        Args:
            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool): Whether to stream the output. Defaults to False.
            temperature (float): The temperature for sampling. Defaults to 0.5.
            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
            tool_choice (str): How to choose tools. Defaults to "auto".
            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
            retry_attempts (int): Number of retry attempts for failed operations. Defaults to 3.
            retry_interval (float): Time to wait between retries in seconds. Defaults to 1.0.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls
        self.retry_attempts = retry_attempts
        self.retry_interval = retry_interval

        # Initialize the vLLM engine and its sampling configuration
        try:
            self.llm = LLM(model=model_name, **kwargs)
            self.sampling_params = SamplingParams(
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            raise ToolExecutionError(
                "model_initialization",
                e,
                {"model_name": model_name, "kwargs": kwargs},
            )
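
    # Construction sketch (hypothetical values; assumes the model weights are
    # available locally and vLLM is installed):
    #
    #     agent = ToolAgent(
    #         model_name="meta-llama/Llama-2-7b-chat-hf",
    #         system_prompt="You are a helpful assistant",
    #         temperature=0.3,
    #         retry_attempts=5,    # retry transient failures up to five times
    #         retry_interval=2.0,  # wait two seconds between attempts
    #     )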
    def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None:
        """
        Validate tool parameters before execution.

        Args:
            tool_name (str): Name of the tool to validate
            parameters (Dict[str, Any]): Parameters to validate

        Raises:
            ToolValidationError: If validation fails
        """
        if not self.tools_list_dictionary:
            raise ToolValidationError(
                tool_name,
                "parameters",
                "No tools available for validation",
            )

        tool_spec = next(
            (tool for tool in self.tools_list_dictionary if tool["name"] == tool_name),
            None,
        )

        if not tool_spec:
            raise ToolNotFoundError(tool_name)

        required_params = {
            param["name"]
            for param in tool_spec["parameters"]
            if param.get("required", True)
        }
        missing_params = required_params - set(parameters.keys())

        if missing_params:
            raise ToolParameterError(
                tool_name,
                f"Missing required parameters: {', '.join(missing_params)}",
            )
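
    # The expected shape of tools_list_dictionary follows from the checks above.
    # A hypothetical spec and the calls it allows (names are illustrative only):
    #
    #     tools = [{
    #         "name": "get_weather",
    #         "parameters": [
    #             {"name": "city", "required": True},
    #             {"name": "units", "required": False},
    #         ],
    #     }]
    #     agent = ToolAgent(tools_list_dictionary=tools)
    #     agent._validate_tool("get_weather", {"city": "Paris"})  # passes
    #     agent._validate_tool("get_weather", {})  # ToolParameterError: missing 'city'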
    def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function with retry logic.

        Args:
            func (Callable): Function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Any: Result of the function execution

        Raises:
            ToolExecutionError: If all retry attempts fail
        """
        last_error = None
        for attempt in range(self.retry_attempts):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(
                    f"Attempt {attempt + 1}/{self.retry_attempts} failed: {str(e)}"
                )
                if attempt < self.retry_attempts - 1:
                    time.sleep(self.retry_interval)

        raise ToolExecutionError(
            func.__name__,
            last_error,
            {"attempts": self.retry_attempts},
        )
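
    # Design note: the loop above waits a fixed retry_interval between attempts.
    # A hypothetical exponential-backoff variant would change only the sleep call:
    #
    #     if attempt < self.retry_attempts - 1:
    #         time.sleep(self.retry_interval * (2 ** attempt))  # 1s, 2s, 4s, ...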
    def run(self, task: str, *args, **kwargs) -> str:
        """
        Run the tool agent for the specified task.

        Args:
            task (str): The task to be performed by the tool agent.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            str: The output of the tool agent.

        Raises:
            ToolExecutionError: If an error occurs during execution.
        """
        try:
            if not self.llm:
                raise ToolExecutionError(
                    "run",
                    Exception("LLM not initialized"),
                    {"task": task},
                )

            logger.info(f"Running task: {task}")

            # Prepare the prompt
            prompt = self._prepare_prompt(task)

            # Execute with retry logic
            outputs = self._execute_with_retry(
                self.llm.generate,
                prompt,
                self.sampling_params,
            )

            response = outputs[0].outputs[0].text.strip()
            return response
        except ToolExecutionError:
            # Already wrapped with context; re-raise without double-wrapping
            raise
        except Exception as error:
            logger.error(f"Error running task: {error}")
            raise ToolExecutionError(
                "run",
                error,
                {"task": task, "args": args, "kwargs": kwargs},
            )
    def _prepare_prompt(self, task: str) -> str:
        """
        Prepare the prompt for the given task.

        Args:
            task (str): The task to prepare the prompt for.

        Returns:
            str: The prepared prompt.
        """
        if self.system_prompt:
            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
        return f"User: {task}\nAssistant:"
    def __call__(self, task: str, *args, **kwargs) -> str:
        """
        Call the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        return self.run(task, *args, **kwargs)
    def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
        """
        Run the model for multiple tasks in batches.

        Args:
            tasks (List[str]): List of tasks to run.
            batch_size (int): Size of each batch. Defaults to 10.

        Returns:
            List[str]: List of model responses.

        Raises:
            ToolExecutionError: If an error occurs during batch execution.
        """
        logger.info(
            f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
        )
        results = []

        try:
            for i in range(0, len(tasks), batch_size):
                batch = tasks[i : i + batch_size]
                for task in batch:
                    logger.info(f"Running task: {task}")
                    try:
                        result = self.run(task)
                        results.append(result)
                    except ToolExecutionError as e:
                        logger.error(f"Failed to execute task '{task}': {e}")
                        results.append(f"Error: {str(e)}")
                        continue

            logger.info("Completed all tasks.")
            return results
        except Exception as error:
            logger.error(f"Error in batch execution: {error}")
            raise ToolExecutionError(
                "batched_run",
                error,
                {"tasks": tasks, "batch_size": batch_size},
            )
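
A quick usage sketch under the same assumptions as the construction example above (task strings are hypothetical); note that per-task failures in batched_run come back as "Error: ..." strings rather than aborting the whole batch:

    answer = agent.run("Explain retry logic in one sentence")
    same_answer = agent("Explain retry logic in one sentence")  # __call__ delegates to run

    tasks = ["Summarize doc A", "Summarize doc B", "Summarize doc C"]
    results = agent.batched_run(tasks, batch_size=2)
    for task, result in zip(tasks, results):
        print(task, "->", result)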

@@ -1,8 +1,15 @@
from unittest.mock import Mock, patch

import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from swarms import ToolAgent
from swarms.agents.exceptions import (
    ToolExecutionError,
    ToolValidationError,
    ToolNotFoundError,
    ToolParameterError,
)


def test_tool_agent_init():
@@ -99,3 +106,123 @@ def test_tool_agent_init_with_kwargs():
        agent.max_string_token_length
        == kwargs["max_string_token_length"]
    )


def test_tool_agent_initialization():
    """Test tool agent initialization with valid parameters."""
    # Patch the vLLM engine (module path assumed) so no model weights are loaded.
    with patch("swarms.agents.tool_agent.LLM"):
        agent = ToolAgent(
            model_name="test-model",
            temperature=0.7,
            max_tokens=1000,
        )
    assert agent.model_name == "test-model"
    assert agent.temperature == 0.7
    assert agent.max_tokens == 1000
    assert agent.retry_attempts == 3
    assert agent.retry_interval == 1.0


def test_tool_agent_initialization_error():
    """Test tool agent initialization with invalid model."""
    with pytest.raises(ToolExecutionError) as exc_info:
        ToolAgent(model_name="invalid-model")
    assert "model_initialization" in str(exc_info.value)


def test_tool_validation():
    """Test tool parameter validation."""
    tools_list = [{
        "name": "test_tool",
        "parameters": [
            {"name": "required_param", "required": True},
            {"name": "optional_param", "required": False},
        ],
    }]

    # Patch the vLLM engine (module path assumed) so no model weights are loaded.
    with patch("swarms.agents.tool_agent.LLM"):
        agent = ToolAgent(tools_list_dictionary=tools_list)

    # Test missing required parameter
    with pytest.raises(ToolParameterError) as exc_info:
        agent._validate_tool("test_tool", {})
    assert "Missing required parameters" in str(exc_info.value)

    # Test valid parameters
    agent._validate_tool("test_tool", {"required_param": "value"})

    # Test non-existent tool
    with pytest.raises(ToolNotFoundError) as exc_info:
        agent._validate_tool("non_existent_tool", {})
    assert "Tool 'non_existent_tool' not found" in str(exc_info.value)


def test_retry_mechanism():
    """Test retry mechanism for failed operations."""
    mock_llm = Mock()
    # llm.generate returns a list of request outputs, so the success value
    # must be wrapped in a list for outputs[0].outputs[0].text to resolve.
    mock_llm.generate.side_effect = [
        Exception("First attempt failed"),
        Exception("Second attempt failed"),
        [Mock(outputs=[Mock(text="Success")])],
    ]

    with patch("swarms.agents.tool_agent.LLM"):
        agent = ToolAgent(model_name="test-model")
    agent.llm = mock_llm

    # Test successful retry
    result = agent.run("test task")
    assert result == "Success"
    assert mock_llm.generate.call_count == 3

    # Test all retries failing
    mock_llm.generate.side_effect = Exception("All attempts failed")
    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task")
    assert "All attempts failed" in str(exc_info.value)


def test_batched_execution():
    """Test batched execution with error handling."""
    mock_llm = Mock()
    mock_llm.generate.side_effect = [
        [Mock(outputs=[Mock(text="Success 1")])],
        Exception("Task 2 failed"),
        [Mock(outputs=[Mock(text="Success 3")])],
    ]

    with patch("swarms.agents.tool_agent.LLM"):
        agent = ToolAgent(model_name="test-model")
    agent.llm = mock_llm
    # Disable retries so the single failure for task 2 surfaces immediately
    # instead of consuming the mock response intended for task 3.
    agent.retry_attempts = 1

    tasks = ["Task 1", "Task 2", "Task 3"]
    results = agent.batched_run(tasks)

    assert len(results) == 3
    assert results[0] == "Success 1"
    assert "Error" in results[1]
    assert results[2] == "Success 3"


def test_prompt_preparation():
    """Test prompt preparation with and without system prompt."""
    # Patch the vLLM engine (module path assumed) so no model weights are loaded.
    with patch("swarms.agents.tool_agent.LLM"):
        # Test without system prompt
        agent = ToolAgent()
        prompt = agent._prepare_prompt("test task")
        assert prompt == "User: test task\nAssistant:"

        # Test with system prompt
        agent = ToolAgent(system_prompt="You are a helpful assistant")
        prompt = agent._prepare_prompt("test task")
        assert prompt == "You are a helpful assistant\n\nUser: test task\nAssistant:"


def test_tool_execution_error_handling():
    """Test error handling during tool execution."""
    with patch("swarms.agents.tool_agent.LLM"):
        agent = ToolAgent(model_name="test-model")
    agent.llm = None  # Simulate uninitialized LLM

    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task")
    assert "LLM not initialized" in str(exc_info.value)

    # The wrapped error keeps the standard ToolExecutionError message format
    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task", invalid_param="value")
    assert "Failed to execute tool" in str(exc_info.value)
