Improve tool agent exceptions

pull/744/head
ChethanUK 3 months ago
parent 515394e762
commit 8e741a1074

@ -1,3 +1,8 @@
"""
This module initializes the agents package by importing necessary components
from the swarms framework, including stopping conditions and the ToolAgent.
"""
from swarms.structs.stopping_conditions import (
check_cancelled,
check_complete,
@ -11,8 +16,13 @@ from swarms.structs.stopping_conditions import (
check_success,
)
from swarms.agents.tool_agent import ToolAgent
from swarms.agents.create_agents_from_yaml import (
create_agents_from_yaml,
from swarms.agents.create_agents_from_yaml import create_agents_from_yaml
from swarms.agents.exceptions import (
ErrorSeverity,
ToolAgentError,
ValidationError,
ModelNotProvidedError,
SecurityError
)
__all__ = [
@ -28,4 +38,9 @@ __all__ = [
"check_exit",
"check_end",
"create_agents_from_yaml",
"ErrorSeverity",
"ToolAgentError",
"ValidationError",
"ModelNotProvidedError",
"SecurityError"
]

@ -0,0 +1,41 @@
from enum import Enum
from typing import Dict, Optional
class ErrorSeverity(Enum):
"""Enum for error severity levels."""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class ToolAgentError(Exception):
"""Base exception class for ToolAgent errors."""
def __init__(
self,
message: str,
severity: ErrorSeverity = ErrorSeverity.MEDIUM,
details: Optional[Dict] = None
):
self.severity = severity
self.details = details or {}
super().__init__(message)
class ValidationError(ToolAgentError):
"""Raised when validation fails."""
pass
class ModelNotProvidedError(ToolAgentError):
"""Raised when neither model nor llm is provided."""
pass
class SecurityError(ToolAgentError):
"""Raised when security checks fail."""
pass
class SchemaValidationError(ToolAgentError):
"""Raised when JSON schema validation fails."""
pass
class ConfigurationError(ToolAgentError):
"""Raised when there's an error in configuration."""
pass

@ -2,7 +2,13 @@ from typing import Any, Optional, Callable
from swarms.tools.json_former import Jsonformer
from swarms.utils.loguru_logger import initialize_logger
from swarms.utils.lazy_loader import lazy_import_decorator
from swarms.agents.exceptions import (
ToolAgentError,
ValidationError,
ModelNotProvidedError,
ConfigurationError,
ErrorSeverity
)
logger = initialize_logger(log_folder="tool_agent")
@ -102,57 +108,116 @@ class ToolAgent:
The output of the tool agent.
Raises:
Exception: If an error occurs during the execution of the tool agent.
ValidationError: If input validation fails
ModelNotProvidedError: If neither model nor llm is provided
ToolAgentError: For general execution errors
SchemaValidationError: If JSON schema validation fails
"""
try:
if self.model:
logger.info(f"Running {self.name} for task: {task}")
self.toolagent = Jsonformer(
model=self.model,
tokenizer=self.tokenizer,
json_schema=self.json_schema,
llm=self.llm,
prompt=task,
max_number_tokens=self.max_number_tokens,
*args,
**kwargs,
# Input validation
if not isinstance(task, str):
raise ValidationError(
"Task must be a string",
severity=ErrorSeverity.HIGH
)
if not task.strip():
raise ValidationError(
"Task cannot be empty",
severity=ErrorSeverity.HIGH
)
if self.parsing_function:
out = self.parsing_function(self.toolagent())
else:
out = self.toolagent()
if self.model:
logger.info(f"Running {self.name} for task: {task}")
try:
self.toolagent = Jsonformer(
model=self.model,
tokenizer=self.tokenizer,
json_schema=self.json_schema,
llm=self.llm,
prompt=task,
max_number_tokens=self.max_number_tokens,
*args,
**kwargs,
)
except Exception as e:
raise ConfigurationError(
"Failed to initialize Jsonformer",
severity=ErrorSeverity.HIGH,
details={"original_error": str(e)}
)
try:
if self.parsing_function:
out = self.parsing_function(self.toolagent())
else:
out = self.toolagent()
return out
except Exception as e:
raise ToolAgentError(
"Error during task execution",
severity=ErrorSeverity.HIGH,
details={"original_error": str(e)}
)
return out
elif self.llm:
logger.info(f"Running {self.name} for task: {task}")
self.toolagent = Jsonformer(
json_schema=self.json_schema,
llm=self.llm,
prompt=task,
max_number_tokens=self.max_number_tokens,
*args,
**kwargs,
)
if self.parsing_function:
out = self.parsing_function(self.toolagent())
else:
out = self.toolagent()
return out
try:
self.toolagent = Jsonformer(
json_schema=self.json_schema,
llm=self.llm,
prompt=task,
max_number_tokens=self.max_number_tokens,
*args,
**kwargs,
)
except Exception as e:
raise ConfigurationError(
"Failed to initialize Jsonformer with LLM",
severity=ErrorSeverity.HIGH,
details={"original_error": str(e)}
)
try:
if self.parsing_function:
out = self.parsing_function(self.toolagent())
else:
out = self.toolagent()
return out
except Exception as e:
raise ToolAgentError(
"Error during LLM task execution",
severity=ErrorSeverity.HIGH,
details={"original_error": str(e)}
)
else:
raise Exception(
"Either model or llm should be provided to the"
" ToolAgent"
raise ModelNotProvidedError(
"Either model or llm should be provided to the ToolAgent",
severity=ErrorSeverity.CRITICAL
)
except (ValidationError, ModelNotProvidedError, ConfigurationError) as e:
# Re-raise these specific exceptions without wrapping
logger.error(
f"Error running {self.name} for task: {task}",
error_type=type(e).__name__,
severity=e.severity,
details=e.details
)
raise
except Exception as error:
# Wrap unexpected exceptions
logger.error(
f"Error running {self.name} for task: {task}"
f"Unexpected error running {self.name} for task: {task}",
error=str(error)
)
raise ToolAgentError(
f"Unexpected error in ToolAgent: {str(error)}",
severity=ErrorSeverity.CRITICAL,
details={"original_error": str(error)}
)
raise error
def __call__(self, task: str, *args, **kwargs):
return self.run(task, *args, **kwargs)

@ -0,0 +1,32 @@
from swarms.agents.exceptions import (
ErrorSeverity,
ToolAgentError,
ValidationError,
ModelNotProvidedError
)
def test_error_severity():
assert ErrorSeverity.LOW.value == "low"
assert ErrorSeverity.MEDIUM.value == "medium"
assert ErrorSeverity.HIGH.value == "high"
assert ErrorSeverity.CRITICAL.value == "critical"
def test_tool_agent_error():
error = ToolAgentError(
"Test error",
severity=ErrorSeverity.HIGH,
details={"test": "value"}
)
assert str(error) == "Test error"
assert error.severity == ErrorSeverity.HIGH
assert error.details == {"test": "value"}
def test_validation_error():
error = ValidationError("Validation failed")
assert isinstance(error, ToolAgentError)
assert str(error) == "Validation failed"
def test_model_not_provided_error():
error = ModelNotProvidedError("Model missing")
assert isinstance(error, ToolAgentError)
assert str(error) == "Model missing"
Loading…
Cancel
Save