From 8e741a1074a77e067d83f720a5cc77d4fea2720a Mon Sep 17 00:00:00 2001
From: ChethanUK <chethanuk@outlook.com>
Date: Sat, 11 Jan 2025 02:14:55 +0530
Subject: [PATCH] Improve tool agent exceptions

---
 swarms/agents/__init__.py       |  19 ++++-
 swarms/agents/exceptions.py     |  41 ++++++++++
 swarms/agents/tool_agent.py     | 141 +++++++++++++++++++++++---------
 tests/agents/test_exceptions.py |  32 ++++++++
 4 files changed, 193 insertions(+), 40 deletions(-)
 create mode 100644 swarms/agents/exceptions.py
 create mode 100644 tests/agents/test_exceptions.py

diff --git a/swarms/agents/__init__.py b/swarms/agents/__init__.py
index 68f75f99..769c8cc5 100644
--- a/swarms/agents/__init__.py
+++ b/swarms/agents/__init__.py
@@ -1,3 +1,8 @@
+"""
+This module initializes the agents package by importing necessary components
+from the swarms framework, including stopping conditions and the ToolAgent.
+"""
+
 from swarms.structs.stopping_conditions import (
     check_cancelled,
     check_complete,
@@ -11,8 +16,13 @@ from swarms.structs.stopping_conditions import (
     check_success,
 )
 from swarms.agents.tool_agent import ToolAgent
-from swarms.agents.create_agents_from_yaml import (
-    create_agents_from_yaml,
+from swarms.agents.create_agents_from_yaml import create_agents_from_yaml
+from swarms.agents.exceptions import (
+    ErrorSeverity,
+    ToolAgentError,
+    ValidationError,
+    ModelNotProvidedError,
+    SecurityError
 )
 
 __all__ = [
@@ -28,4 +38,9 @@ __all__ = [
     "check_exit",
     "check_end",
     "create_agents_from_yaml",
+    "ErrorSeverity",
+    "ToolAgentError",
+    "ValidationError",
+    "ModelNotProvidedError",
+    "SecurityError"
 ]
diff --git a/swarms/agents/exceptions.py b/swarms/agents/exceptions.py
new file mode 100644
index 00000000..5d3c544e
--- /dev/null
+++ b/swarms/agents/exceptions.py
@@ -0,0 +1,41 @@
+from enum import Enum
+from typing import Dict, Optional
+
+class ErrorSeverity(Enum):
+    """Enum for error severity levels."""
+    LOW = "low"
+    MEDIUM = "medium"
+    HIGH = "high"
+    CRITICAL = "critical"
+
+class ToolAgentError(Exception):
+    """Base exception class for ToolAgent errors."""
+    def __init__(
+        self, 
+        message: str, 
+        severity: ErrorSeverity = ErrorSeverity.MEDIUM,
+        details: Optional[Dict] = None
+    ):
+        self.severity = severity
+        self.details = details or {}
+        super().__init__(message)
+
+class ValidationError(ToolAgentError):
+    """Raised when validation fails."""
+    pass
+
+class ModelNotProvidedError(ToolAgentError):
+    """Raised when neither model nor llm is provided."""
+    pass
+
+class SecurityError(ToolAgentError):
+    """Raised when security checks fail."""
+    pass
+
+class SchemaValidationError(ToolAgentError):
+    """Raised when JSON schema validation fails."""
+    pass
+
+class ConfigurationError(ToolAgentError):
+    """Raised when there's an error in configuration."""
+    pass
\ No newline at end of file
diff --git a/swarms/agents/tool_agent.py b/swarms/agents/tool_agent.py
index b686f3b0..54839675 100644
--- a/swarms/agents/tool_agent.py
+++ b/swarms/agents/tool_agent.py
@@ -2,7 +2,13 @@ from typing import Any, Optional, Callable
 from swarms.tools.json_former import Jsonformer
 from swarms.utils.loguru_logger import initialize_logger
 from swarms.utils.lazy_loader import lazy_import_decorator
-
+from swarms.agents.exceptions import (
+    ToolAgentError,
+    ValidationError,
+    ModelNotProvidedError,
+    ConfigurationError,
+    ErrorSeverity
+)
 logger = initialize_logger(log_folder="tool_agent")
 
 
@@ -102,57 +108,116 @@ class ToolAgent:
             The output of the tool agent.
 
         Raises:
-            Exception: If an error occurs during the execution of the tool agent.
+            ValidationError: If input validation fails
+            ModelNotProvidedError: If neither model nor llm is provided
+            ToolAgentError: For general execution errors
+            SchemaValidationError: If JSON schema validation fails
         """
         try:
-            if self.model:
-                logger.info(f"Running {self.name} for task: {task}")
-                self.toolagent = Jsonformer(
-                    model=self.model,
-                    tokenizer=self.tokenizer,
-                    json_schema=self.json_schema,
-                    llm=self.llm,
-                    prompt=task,
-                    max_number_tokens=self.max_number_tokens,
-                    *args,
-                    **kwargs,
+            # Input validation
+            if not isinstance(task, str):
+                raise ValidationError(
+                    "Task must be a string",
+                    severity=ErrorSeverity.HIGH
+                )
+            
+            if not task.strip():
+                raise ValidationError(
+                    "Task cannot be empty",
+                    severity=ErrorSeverity.HIGH
                 )
 
-                if self.parsing_function:
-                    out = self.parsing_function(self.toolagent())
-                else:
-                    out = self.toolagent()
+            if self.model:
+                logger.info(f"Running {self.name} for task: {task}")
+                try:
+                    self.toolagent = Jsonformer(
+                        model=self.model,
+                        tokenizer=self.tokenizer,
+                        json_schema=self.json_schema,
+                        llm=self.llm,
+                        prompt=task,
+                        max_number_tokens=self.max_number_tokens,
+                        *args,
+                        **kwargs,
+                    )
+                except Exception as e:
+                    raise ConfigurationError(
+                        "Failed to initialize Jsonformer",
+                        severity=ErrorSeverity.HIGH,
+                        details={"original_error": str(e)}
+                    )
+
+                try:
+                    if self.parsing_function:
+                        out = self.parsing_function(self.toolagent())
+                    else:
+                        out = self.toolagent()
+                    return out
+                except Exception as e:
+                    raise ToolAgentError(
+                        "Error during task execution",
+                        severity=ErrorSeverity.HIGH,
+                        details={"original_error": str(e)}
+                    )
 
-                return out
             elif self.llm:
                 logger.info(f"Running {self.name} for task: {task}")
-                self.toolagent = Jsonformer(
-                    json_schema=self.json_schema,
-                    llm=self.llm,
-                    prompt=task,
-                    max_number_tokens=self.max_number_tokens,
-                    *args,
-                    **kwargs,
-                )
-
-                if self.parsing_function:
-                    out = self.parsing_function(self.toolagent())
-                else:
-                    out = self.toolagent()
-
-                return out
+                try:
+                    self.toolagent = Jsonformer(
+                        json_schema=self.json_schema,
+                        llm=self.llm,
+                        prompt=task,
+                        max_number_tokens=self.max_number_tokens,
+                        *args,
+                        **kwargs,
+                    )
+                except Exception as e:
+                    raise ConfigurationError(
+                        "Failed to initialize Jsonformer with LLM",
+                        severity=ErrorSeverity.HIGH,
+                        details={"original_error": str(e)}
+                    )
+
+                try:
+                    if self.parsing_function:
+                        out = self.parsing_function(self.toolagent())
+                    else:
+                        out = self.toolagent()
+                    return out
+                except Exception as e:
+                    raise ToolAgentError(
+                        "Error during LLM task execution",
+                        severity=ErrorSeverity.HIGH,
+                        details={"original_error": str(e)}
+                    )
 
             else:
-                raise Exception(
-                    "Either model or llm should be provided to the"
-                    " ToolAgent"
+                raise ModelNotProvidedError(
+                    "Either model or llm should be provided to the ToolAgent",
+                    severity=ErrorSeverity.CRITICAL
                 )
 
+        except (ValidationError, ModelNotProvidedError, ConfigurationError) as e:
+            # Re-raise these specific exceptions without wrapping
+            logger.error(
+                f"Error running {self.name} for task: {task}",
+                error_type=type(e).__name__,
+                severity=e.severity,
+                details=e.details
+            )
+            raise
+
         except Exception as error:
+            # Wrap unexpected exceptions
             logger.error(
-                f"Error running {self.name} for task: {task}"
+                f"Unexpected error running {self.name} for task: {task}",
+                error=str(error)
+            )
+            raise ToolAgentError(
+                f"Unexpected error in ToolAgent: {str(error)}",
+                severity=ErrorSeverity.CRITICAL,
+                details={"original_error": str(error)}
             )
-            raise error
 
     def __call__(self, task: str, *args, **kwargs):
         return self.run(task, *args, **kwargs)
diff --git a/tests/agents/test_exceptions.py b/tests/agents/test_exceptions.py
new file mode 100644
index 00000000..6099c255
--- /dev/null
+++ b/tests/agents/test_exceptions.py
@@ -0,0 +1,32 @@
+from swarms.agents.exceptions import (
+    ErrorSeverity,
+    ToolAgentError,
+    ValidationError,
+    ModelNotProvidedError
+)
+
+def test_error_severity():
+    assert ErrorSeverity.LOW.value == "low"
+    assert ErrorSeverity.MEDIUM.value == "medium"
+    assert ErrorSeverity.HIGH.value == "high"
+    assert ErrorSeverity.CRITICAL.value == "critical"
+
+def test_tool_agent_error():
+    error = ToolAgentError(
+        "Test error",
+        severity=ErrorSeverity.HIGH,
+        details={"test": "value"}
+    )
+    assert str(error) == "Test error"
+    assert error.severity == ErrorSeverity.HIGH
+    assert error.details == {"test": "value"}
+
+def test_validation_error():
+    error = ValidationError("Validation failed")
+    assert isinstance(error, ToolAgentError)
+    assert str(error) == "Validation failed"
+
+def test_model_not_provided_error():
+    error = ModelNotProvidedError("Model missing")
+    assert isinstance(error, ToolAgentError)
+    assert str(error) == "Model missing"