diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..4d18c26e --- /dev/null +++ b/.flake8 @@ -0,0 +1,22 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203, W503 +exclude = + .git, + __pycache__, + build, + dist, + *.egg-info, + .eggs, + .tox, + .venv, + venv, + .env, + .pytest_cache, + .coverage, + htmlcov, + .mypy_cache, + .ruff_cache +per-file-ignores = + __init__.py: F401 +max-complexity = 10 \ No newline at end of file diff --git a/examples/typedb_example.py b/examples/typedb_example.py new file mode 100644 index 00000000..8eacdc94 --- /dev/null +++ b/examples/typedb_example.py @@ -0,0 +1,106 @@ +from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig + +def main(): + # Initialize TypeDB wrapper with custom configuration + config = TypeDBConfig( + uri="localhost:1729", + database="swarms_example", + username="admin", + password="password" + ) + + # Define schema for a simple knowledge graph + schema = """ + define + person sub entity, + owns name: string, + owns age: long, + plays role; + + role sub entity, + owns title: string, + owns department: string; + + works_at sub relation, + relates person, + relates role; + """ + + # Example data insertion + insert_queries = [ + """ + insert + $p isa person, has name "John Doe", has age 30; + $r isa role, has title "Software Engineer", has department "Engineering"; + (person: $p, role: $r) isa works_at; + """, + """ + insert + $p isa person, has name "Jane Smith", has age 28; + $r isa role, has title "Data Scientist", has department "Data Science"; + (person: $p, role: $r) isa works_at; + """ + ] + + # Example queries + query_queries = [ + # Get all people + "match $p isa person; get;", + + # Get people in Engineering department + """ + match + $p isa person; + $r isa role, has department "Engineering"; + (person: $p, role: $r) isa works_at; + get $p; + """, + + # Get people with their roles + """ + match + $p isa person, has name $n; + $r isa role, has title $t; + (person: $p, role: $r) isa works_at; + get $n, $t; + """ + ] + + try: + with TypeDBWrapper(config) as db: + # Define schema + print("Defining schema...") + db.define_schema(schema) + + # Insert data + print("\nInserting data...") + for query in insert_queries: + db.insert_data(query) + + # Query data + print("\nQuerying data...") + for i, query in enumerate(query_queries, 1): + print(f"\nQuery {i}:") + results = db.query_data(query) + print(f"Results: {results}") + + # Example of deleting data + print("\nDeleting data...") + delete_query = """ + match + $p isa person, has name "John Doe"; + delete $p; + """ + db.delete_data(delete_query) + + # Verify deletion + print("\nVerifying deletion...") + verify_query = "match $p isa person, has name $n; get $n;" + results = db.query_data(verify_query) + print(f"Remaining people: {results}") + + except Exception as e: + print(f"Error: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/vllm_example.py b/examples/vllm_example.py new file mode 100644 index 00000000..231a68fc --- /dev/null +++ b/examples/vllm_example.py @@ -0,0 +1,44 @@ +from swarms.utils.vllm_wrapper import VLLMWrapper + +def main(): + # Initialize the vLLM wrapper with a model + # Note: You'll need to have the model downloaded or specify a HuggingFace model ID + llm = VLLMWrapper( + model_name="meta-llama/Llama-2-7b-chat-hf", # Replace with your model path or HF model ID + temperature=0.7, + max_tokens=1000, + ) + + # Example task + task = "What are the benefits of using vLLM for inference?" + + # Run inference + response = llm.run(task) + print("Response:", response) + + # Example with system prompt + llm_with_system = VLLMWrapper( + model_name="meta-llama/Llama-2-7b-chat-hf", # Replace with your model path or HF model ID + system_prompt="You are a helpful AI assistant that provides concise answers.", + temperature=0.7, + ) + + # Run inference with system prompt + response = llm_with_system.run(task) + print("\nResponse with system prompt:", response) + + # Example with batched inference + tasks = [ + "What is vLLM?", + "How does vLLM improve inference speed?", + "What are the main features of vLLM?" + ] + + responses = llm.batched_run(tasks, batch_size=2) + print("\nBatched responses:") + for task, response in zip(tasks, responses): + print(f"\nTask: {task}") + print(f"Response: {response}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyupgrade.ini b/pyupgrade.ini new file mode 100644 index 00000000..3f064eff --- /dev/null +++ b/pyupgrade.ini @@ -0,0 +1,4 @@ +[pyupgrade] +py3-plus = True +py39-plus = True +keep-runtime-typing = True \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 913e77de..6b50d912 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ - torch>=2.1.1,<3.0 transformers>=4.39.0,<4.49.0 asyncio>=3.4.3,<4.0 @@ -23,3 +22,13 @@ pytest>=8.1.1 networkx aiofiles httpx +vllm>=0.2.0 +flake8>=6.1.0 +flake8-bugbear>=23.3.12 +flake8-comprehensions>=3.12.0 +flake8-simplify>=0.19.3 +flake8-unused-arguments>=0.0.4 +pyupgrade>=3.15.0 +typedb-client>=2.25.0 +typedb-protocol>=2.25.0 +typedb-driver>=2.25.0 diff --git a/scripts/check_code_quality.py b/scripts/check_code_quality.py new file mode 100755 index 00000000..fdfc3ccc --- /dev/null +++ b/scripts/check_code_quality.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +import subprocess +import sys +from pathlib import Path + +def run_command(command: list[str], cwd: Path) -> bool: + """Run a command and return True if successful.""" + try: + result = subprocess.run( + command, + cwd=cwd, + capture_output=True, + text=True, + check=True + ) + return True + except subprocess.CalledProcessError as e: + print(f"Error running {' '.join(command)}:") + print(e.stdout) + print(e.stderr, file=sys.stderr) + return False + +def main(): + """Run all code quality checks.""" + root_dir = Path(__file__).parent.parent + success = True + + # Run flake8 + print("\nRunning flake8...") + if not run_command(["flake8", "swarms", "tests"], root_dir): + success = False + + # Run pyupgrade + print("\nRunning pyupgrade...") + if not run_command(["pyupgrade", "--py39-plus", "swarms", "tests"], root_dir): + success = False + + # Run black + print("\nRunning black...") + if not run_command(["black", "--check", "swarms", "tests"], root_dir): + success = False + + # Run ruff + print("\nRunning ruff...") + if not run_command(["ruff", "check", "swarms", "tests"], root_dir): + success = False + + if not success: + print("\nCode quality checks failed. Please fix the issues and try again.") + sys.exit(1) + else: + print("\nAll code quality checks passed!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/swarms/agents/exceptions.py b/swarms/agents/exceptions.py new file mode 100644 index 00000000..a07fa88f --- /dev/null +++ b/swarms/agents/exceptions.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Optional + +class ToolAgentError(Exception): + """Base exception for all tool agent errors.""" + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): + self.message = message + self.details = details or {} + super().__init__(self.message) + +class ToolExecutionError(ToolAgentError): + """Raised when a tool fails to execute.""" + def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None): + message = f"Failed to execute tool '{tool_name}': {str(error)}" + super().__init__(message, details) + +class ToolValidationError(ToolAgentError): + """Raised when tool parameters fail validation.""" + def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None): + message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}" + super().__init__(message, details) + +class ToolNotFoundError(ToolAgentError): + """Raised when a requested tool is not found.""" + def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None): + message = f"Tool '{tool_name}' not found" + super().__init__(message, details) + +class ToolParameterError(ToolAgentError): + """Raised when tool parameters are invalid.""" + def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None): + message = f"Invalid parameters for tool '{tool_name}': {error}" + super().__init__(message, details) \ No newline at end of file diff --git a/swarms/agents/tool_agent.py b/swarms/agents/tool_agent.py index 2d19ec26..d69d6c2f 100644 --- a/swarms/agents/tool_agent.py +++ b/swarms/agents/tool_agent.py @@ -1,156 +1,243 @@ -from typing import Any, Optional, Callable -from swarms.tools.json_former import Jsonformer -from swarms.utils.loguru_logger import initialize_logger - -logger = initialize_logger(log_folder="tool_agent") - +from typing import List, Optional, Dict, Any, Callable +from loguru import logger +from swarms.agents.exceptions import ( + ToolAgentError, + ToolExecutionError, + ToolValidationError, + ToolNotFoundError, + ToolParameterError +) class ToolAgent: """ - Represents a tool agent that performs a specific task using a model and tokenizer. - - Args: - name (str): The name of the tool agent. - description (str): A description of the tool agent. - model (Any): The model used by the tool agent. - tokenizer (Any): The tokenizer used by the tool agent. - json_schema (Any): The JSON schema used by the tool agent. - *args: Variable length arguments. - **kwargs: Keyword arguments. - - Attributes: - name (str): The name of the tool agent. - description (str): A description of the tool agent. - model (Any): The model used by the tool agent. - tokenizer (Any): The tokenizer used by the tool agent. - json_schema (Any): The JSON schema used by the tool agent. - - Methods: - run: Runs the tool agent for a specific task. - - Raises: - Exception: If an error occurs while running the tool agent. - - - Example: - from transformers import AutoModelForCausalLM, AutoTokenizer - from swarms import ToolAgent - - - model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-12b") - tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-12b") - - json_schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "number"}, - "is_student": {"type": "boolean"}, - "courses": { - "type": "array", - "items": {"type": "string"} - } - } - } - - task = "Generate a person's information based on the following schema:" - agent = ToolAgent(model=model, tokenizer=tokenizer, json_schema=json_schema) - generated_data = agent.run(task) - - print(generated_data) + A wrapper class for vLLM that provides a similar interface to LiteLLM. + This class handles model initialization and inference using vLLM. """ def __init__( self, - name: str = "Function Calling Agent", - description: str = "Generates a function based on the input json schema and the task", - model: Any = None, - tokenizer: Any = None, - json_schema: Any = None, - max_number_tokens: int = 500, - parsing_function: Optional[Callable] = None, - llm: Any = None, + model_name: str = "meta-llama/Llama-2-7b-chat-hf", + system_prompt: Optional[str] = None, + stream: bool = False, + temperature: float = 0.5, + max_tokens: int = 4000, + max_completion_tokens: int = 4000, + tools_list_dictionary: Optional[List[Dict[str, Any]]] = None, + tool_choice: str = "auto", + parallel_tool_calls: bool = False, + retry_attempts: int = 3, + retry_interval: float = 1.0, *args, **kwargs, ): - super().__init__( - agent_name=name, - agent_description=description, - llm=llm, - **kwargs, + """ + Initialize the vLLM wrapper with the given parameters. + Args: + model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf". + system_prompt (str, optional): The system prompt to use. Defaults to None. + stream (bool): Whether to stream the output. Defaults to False. + temperature (float): The temperature for sampling. Defaults to 0.5. + max_tokens (int): The maximum number of tokens to generate. Defaults to 4000. + max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000. + tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None. + tool_choice (str): How to choose tools. Defaults to "auto". + parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False. + retry_attempts (int): Number of retry attempts for failed operations. Defaults to 3. + retry_interval (float): Time to wait between retries in seconds. Defaults to 1.0. + """ + self.model_name = model_name + self.system_prompt = system_prompt + self.stream = stream + self.temperature = temperature + self.max_tokens = max_tokens + self.max_completion_tokens = max_completion_tokens + self.tools_list_dictionary = tools_list_dictionary + self.tool_choice = tool_choice + self.parallel_tool_calls = parallel_tool_calls + self.retry_attempts = retry_attempts + self.retry_interval = retry_interval + + # Initialize vLLM + try: + self.llm = LLM(model=model_name, **kwargs) + self.sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_tokens, + ) + except Exception as e: + raise ToolExecutionError( + "model_initialization", + e, + {"model_name": model_name, "kwargs": kwargs} + ) + + def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None: + """ + Validate tool parameters before execution. + Args: + tool_name (str): Name of the tool to validate + parameters (Dict[str, Any]): Parameters to validate + Raises: + ToolValidationError: If validation fails + """ + if not self.tools_list_dictionary: + raise ToolValidationError( + tool_name, + "parameters", + "No tools available for validation" + ) + + tool_spec = next( + (tool for tool in self.tools_list_dictionary if tool["name"] == tool_name), + None ) - self.name = name - self.description = description - self.model = model - self.tokenizer = tokenizer - self.json_schema = json_schema - self.max_number_tokens = max_number_tokens - self.parsing_function = parsing_function - - def run(self, task: str, *args, **kwargs): + + if not tool_spec: + raise ToolNotFoundError(tool_name) + + required_params = { + param["name"] for param in tool_spec["parameters"] + if param.get("required", True) + } + + missing_params = required_params - set(parameters.keys()) + if missing_params: + raise ToolParameterError( + tool_name, + f"Missing required parameters: {', '.join(missing_params)}" + ) + + def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any: """ - Run the tool agent for the specified task. + Execute a function with retry logic. + Args: + func (Callable): Function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + Returns: + Any: Result of the function execution + Raises: + ToolExecutionError: If all retry attempts fail + """ + last_error = None + for attempt in range(self.retry_attempts): + try: + return func(*args, **kwargs) + except Exception as e: + last_error = e + logger.warning( + f"Attempt {attempt + 1}/{self.retry_attempts} failed: {str(e)}" + ) + if attempt < self.retry_attempts - 1: + time.sleep(self.retry_interval) + raise ToolExecutionError( + func.__name__, + last_error, + {"attempts": self.retry_attempts} + ) + + def run(self, task: str, *args, **kwargs) -> str: + """ + Run the tool agent for the specified task. Args: task (str): The task to be performed by the tool agent. *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. - Returns: The output of the tool agent. - Raises: - Exception: If an error occurs during the execution of the tool agent. + ToolExecutionError: If an error occurs during execution. """ try: - if self.model: - logger.info(f"Running {self.name} for task: {task}") - self.toolagent = Jsonformer( - model=self.model, - tokenizer=self.tokenizer, - json_schema=self.json_schema, - llm=self.llm, - prompt=task, - max_number_tokens=self.max_number_tokens, - *args, - **kwargs, + if not self.llm: + raise ToolExecutionError( + "run", + Exception("LLM not initialized"), + {"task": task} ) - if self.parsing_function: - out = self.parsing_function(self.toolagent()) - else: - out = self.toolagent() - - return out - elif self.llm: - logger.info(f"Running {self.name} for task: {task}") - self.toolagent = Jsonformer( - json_schema=self.json_schema, - llm=self.llm, - prompt=task, - max_number_tokens=self.max_number_tokens, - *args, - **kwargs, - ) + logger.info(f"Running task: {task}") + + # Prepare the prompt + prompt = self._prepare_prompt(task) + + # Execute with retry logic + outputs = self._execute_with_retry( + self.llm.generate, + prompt, + self.sampling_params + ) + + response = outputs[0].outputs[0].text.strip() + return response - if self.parsing_function: - out = self.parsing_function(self.toolagent()) - else: - out = self.toolagent() + except Exception as error: + logger.error(f"Error running task: {error}") + raise ToolExecutionError( + "run", + error, + {"task": task, "args": args, "kwargs": kwargs} + ) - return out + def _prepare_prompt(self, task: str) -> str: + """ + Prepare the prompt for the given task. + Args: + task (str): The task to prepare the prompt for. + Returns: + str: The prepared prompt. + """ + if self.system_prompt: + return f"{self.system_prompt}\n\nUser: {task}\nAssistant:" + return f"User: {task}\nAssistant:" - else: - raise Exception( - "Either model or llm should be provided to the" - " ToolAgent" - ) + def __call__(self, task: str, *args, **kwargs) -> str: + """ + Call the model for the given task. + Args: + task (str): The task to run the model for. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + Returns: + str: The model's response. + """ + return self.run(task, *args, **kwargs) + + def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]: + """ + Run the model for multiple tasks in batches. + Args: + tasks (List[str]): List of tasks to run. + batch_size (int): Size of each batch. Defaults to 10. + Returns: + List[str]: List of model responses. + Raises: + ToolExecutionError: If an error occurs during batch execution. + """ + logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}") + results = [] + + try: + for i in range(0, len(tasks), batch_size): + batch = tasks[i:i + batch_size] + for task in batch: + logger.info(f"Running task: {task}") + try: + result = self.run(task) + results.append(result) + except ToolExecutionError as e: + logger.error(f"Failed to execute task '{task}': {e}") + results.append(f"Error: {str(e)}") + continue + + logger.info("Completed all tasks.") + return results except Exception as error: - logger.error( - f"Error running {self.name} for task: {task}" + logger.error(f"Error in batch execution: {error}") + raise ToolExecutionError( + "batched_run", + error, + {"tasks": tasks, "batch_size": batch_size} ) - raise error - - def __call__(self, task: str, *args, **kwargs): - return self.run(task, *args, **kwargs) diff --git a/swarms/utils/typedb_wrapper.py b/swarms/utils/typedb_wrapper.py new file mode 100644 index 00000000..8a0b5396 --- /dev/null +++ b/swarms/utils/typedb_wrapper.py @@ -0,0 +1,168 @@ +from typing import Dict, List, Optional, Any, Union +from loguru import logger +from typedb.client import TypeDB, SessionType, TransactionType +from typedb.api.connection.transaction import Transaction +from dataclasses import dataclass +import json + +@dataclass +class TypeDBConfig: + """Configuration for TypeDB connection.""" + uri: str = "localhost:1729" + database: str = "swarms" + username: Optional[str] = None + password: Optional[str] = None + timeout: int = 30 + +class TypeDBWrapper: + """ + A wrapper class for TypeDB that provides graph database operations for Swarms. + This class handles connection, schema management, and data operations. + """ + + def __init__(self, config: Optional[TypeDBConfig] = None): + """ + Initialize the TypeDB wrapper with the given configuration. + Args: + config (Optional[TypeDBConfig]): Configuration for TypeDB connection. + """ + self.config = config or TypeDBConfig() + self.client = None + self.session = None + self._connect() + + def _connect(self) -> None: + """Establish connection to TypeDB.""" + try: + self.client = TypeDB.core_client(self.config.uri) + if self.config.username and self.config.password: + self.session = self.client.session( + self.config.database, + SessionType.DATA, + self.config.username, + self.config.password + ) + else: + self.session = self.client.session( + self.config.database, + SessionType.DATA + ) + logger.info(f"Connected to TypeDB at {self.config.uri}") + except Exception as e: + logger.error(f"Failed to connect to TypeDB: {e}") + raise + + def _ensure_connection(self) -> None: + """Ensure connection is active, reconnect if necessary.""" + if not self.session or not self.session.is_open(): + self._connect() + + def define_schema(self, schema: str) -> None: + """ + Define the database schema. + Args: + schema (str): TypeQL schema definition. + """ + try: + with self.session.transaction(TransactionType.WRITE) as transaction: + transaction.query.define(schema) + transaction.commit() + logger.info("Schema defined successfully") + except Exception as e: + logger.error(f"Failed to define schema: {e}") + raise + + def insert_data(self, query: str) -> None: + """ + Insert data using TypeQL query. + Args: + query (str): TypeQL insert query. + """ + try: + with self.session.transaction(TransactionType.WRITE) as transaction: + transaction.query.insert(query) + transaction.commit() + logger.info("Data inserted successfully") + except Exception as e: + logger.error(f"Failed to insert data: {e}") + raise + + def query_data(self, query: str) -> List[Dict[str, Any]]: + """ + Query data using TypeQL query. + Args: + query (str): TypeQL query. + Returns: + List[Dict[str, Any]]: Query results. + """ + try: + with self.session.transaction(TransactionType.READ) as transaction: + result = transaction.query.get(query) + return [self._convert_concept_to_dict(concept) for concept in result] + except Exception as e: + logger.error(f"Failed to query data: {e}") + raise + + def _convert_concept_to_dict(self, concept: Any) -> Dict[str, Any]: + """ + Convert a TypeDB concept to a dictionary. + Args: + concept: TypeDB concept. + Returns: + Dict[str, Any]: Dictionary representation of the concept. + """ + try: + if hasattr(concept, "get_type"): + concept_type = concept.get_type() + if hasattr(concept, "get_value"): + return { + "type": concept_type.get_label_name(), + "value": concept.get_value() + } + elif hasattr(concept, "get_attributes"): + return { + "type": concept_type.get_label_name(), + "attributes": { + attr.get_type().get_label_name(): attr.get_value() + for attr in concept.get_attributes() + } + } + return {"type": "unknown", "value": str(concept)} + except Exception as e: + logger.error(f"Failed to convert concept to dict: {e}") + return {"type": "error", "value": str(e)} + + def delete_data(self, query: str) -> None: + """ + Delete data using TypeQL query. + Args: + query (str): TypeQL delete query. + """ + try: + with self.session.transaction(TransactionType.WRITE) as transaction: + transaction.query.delete(query) + transaction.commit() + logger.info("Data deleted successfully") + except Exception as e: + logger.error(f"Failed to delete data: {e}") + raise + + def close(self) -> None: + """Close the TypeDB connection.""" + try: + if self.session: + self.session.close() + if self.client: + self.client.close() + logger.info("TypeDB connection closed") + except Exception as e: + logger.error(f"Failed to close TypeDB connection: {e}") + raise + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.close() \ No newline at end of file diff --git a/swarms/utils/vllm_wrapper.py b/swarms/utils/vllm_wrapper.py new file mode 100644 index 00000000..322ce1ad --- /dev/null +++ b/swarms/utils/vllm_wrapper.py @@ -0,0 +1,138 @@ +from typing import List, Optional, Dict, Any +from loguru import logger + +try: + from vllm import LLM, SamplingParams +except ImportError: + import subprocess + import sys + print("Installing vllm") + subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"]) + print("vllm installed") + from vllm import LLM, SamplingParams + +class VLLMWrapper: + """ + A wrapper class for vLLM that provides a similar interface to LiteLLM. + This class handles model initialization and inference using vLLM. + """ + + def __init__( + self, + model_name: str = "meta-llama/Llama-2-7b-chat-hf", + system_prompt: Optional[str] = None, + stream: bool = False, + temperature: float = 0.5, + max_tokens: int = 4000, + max_completion_tokens: int = 4000, + tools_list_dictionary: Optional[List[Dict[str, Any]]] = None, + tool_choice: str = "auto", + parallel_tool_calls: bool = False, + *args, + **kwargs, + ): + """ + Initialize the vLLM wrapper with the given parameters. + + Args: + model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf". + system_prompt (str, optional): The system prompt to use. Defaults to None. + stream (bool): Whether to stream the output. Defaults to False. + temperature (float): The temperature for sampling. Defaults to 0.5. + max_tokens (int): The maximum number of tokens to generate. Defaults to 4000. + max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000. + tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None. + tool_choice (str): How to choose tools. Defaults to "auto". + parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False. + """ + self.model_name = model_name + self.system_prompt = system_prompt + self.stream = stream + self.temperature = temperature + self.max_tokens = max_tokens + self.max_completion_tokens = max_completion_tokens + self.tools_list_dictionary = tools_list_dictionary + self.tool_choice = tool_choice + self.parallel_tool_calls = parallel_tool_calls + + # Initialize vLLM + self.llm = LLM(model=model_name, **kwargs) + self.sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_tokens, + ) + + def _prepare_prompt(self, task: str) -> str: + """ + Prepare the prompt for the given task. + + Args: + task (str): The task to prepare the prompt for. + + Returns: + str: The prepared prompt. + """ + if self.system_prompt: + return f"{self.system_prompt}\n\nUser: {task}\nAssistant:" + return f"User: {task}\nAssistant:" + + def run(self, task: str, *args, **kwargs) -> str: + """ + Run the model for the given task. + + Args: + task (str): The task to run the model for. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The model's response. + """ + try: + prompt = self._prepare_prompt(task) + + outputs = self.llm.generate(prompt, self.sampling_params) + response = outputs[0].outputs[0].text.strip() + + return response + + except Exception as error: + logger.error(f"Error in VLLMWrapper: {error}") + raise error + + def __call__(self, task: str, *args, **kwargs) -> str: + """ + Call the model for the given task. + + Args: + task (str): The task to run the model for. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The model's response. + """ + return self.run(task, *args, **kwargs) + + def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]: + """ + Run the model for multiple tasks in batches. + + Args: + tasks (List[str]): List of tasks to run. + batch_size (int): Size of each batch. Defaults to 10. + + Returns: + List[str]: List of model responses. + """ + logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}") + results = [] + + for i in range(0, len(tasks), batch_size): + batch = tasks[i:i + batch_size] + for task in batch: + logger.info(f"Running task: {task}") + results.append(self.run(task)) + + logger.info("Completed all tasks.") + return results \ No newline at end of file diff --git a/tests/agents/test_tool_agent.py b/tests/agents/test_tool_agent.py index 691489c0..9f8344d0 100644 --- a/tests/agents/test_tool_agent.py +++ b/tests/agents/test_tool_agent.py @@ -1,8 +1,15 @@ from unittest.mock import Mock, patch +import pytest from transformers import AutoModelForCausalLM, AutoTokenizer from swarms import ToolAgent +from swarms.agents.exceptions import ( + ToolExecutionError, + ToolValidationError, + ToolNotFoundError, + ToolParameterError +) def test_tool_agent_init(): @@ -99,3 +106,123 @@ def test_tool_agent_init_with_kwargs(): agent.max_string_token_length == kwargs["max_string_token_length"] ) + + +def test_tool_agent_initialization(): + """Test tool agent initialization with valid parameters.""" + agent = ToolAgent( + model_name="test-model", + temperature=0.7, + max_tokens=1000 + ) + assert agent.model_name == "test-model" + assert agent.temperature == 0.7 + assert agent.max_tokens == 1000 + assert agent.retry_attempts == 3 + assert agent.retry_interval == 1.0 + + +def test_tool_agent_initialization_error(): + """Test tool agent initialization with invalid model.""" + with pytest.raises(ToolExecutionError) as exc_info: + ToolAgent(model_name="invalid-model") + assert "model_initialization" in str(exc_info.value) + + +def test_tool_validation(): + """Test tool parameter validation.""" + tools_list = [{ + "name": "test_tool", + "parameters": [ + {"name": "required_param", "required": True}, + {"name": "optional_param", "required": False} + ] + }] + + agent = ToolAgent(tools_list_dictionary=tools_list) + + # Test missing required parameter + with pytest.raises(ToolParameterError) as exc_info: + agent._validate_tool("test_tool", {}) + assert "Missing required parameters" in str(exc_info.value) + + # Test valid parameters + agent._validate_tool("test_tool", {"required_param": "value"}) + + # Test non-existent tool + with pytest.raises(ToolNotFoundError) as exc_info: + agent._validate_tool("non_existent_tool", {}) + assert "Tool 'non_existent_tool' not found" in str(exc_info.value) + + +def test_retry_mechanism(): + """Test retry mechanism for failed operations.""" + mock_llm = Mock() + mock_llm.generate.side_effect = [ + Exception("First attempt failed"), + Exception("Second attempt failed"), + Mock(outputs=[Mock(text="Success")]) + ] + + agent = ToolAgent(model_name="test-model") + agent.llm = mock_llm + + # Test successful retry + result = agent.run("test task") + assert result == "Success" + assert mock_llm.generate.call_count == 3 + + # Test all retries failing + mock_llm.generate.side_effect = Exception("All attempts failed") + with pytest.raises(ToolExecutionError) as exc_info: + agent.run("test task") + assert "All attempts failed" in str(exc_info.value) + + +def test_batched_execution(): + """Test batched execution with error handling.""" + mock_llm = Mock() + mock_llm.generate.side_effect = [ + Mock(outputs=[Mock(text="Success 1")]), + Exception("Task 2 failed"), + Mock(outputs=[Mock(text="Success 3")]) + ] + + agent = ToolAgent(model_name="test-model") + agent.llm = mock_llm + + tasks = ["Task 1", "Task 2", "Task 3"] + results = agent.batched_run(tasks) + + assert len(results) == 3 + assert results[0] == "Success 1" + assert "Error" in results[1] + assert results[2] == "Success 3" + + +def test_prompt_preparation(): + """Test prompt preparation with and without system prompt.""" + # Test without system prompt + agent = ToolAgent() + prompt = agent._prepare_prompt("test task") + assert prompt == "User: test task\nAssistant:" + + # Test with system prompt + agent = ToolAgent(system_prompt="You are a helpful assistant") + prompt = agent._prepare_prompt("test task") + assert prompt == "You are a helpful assistant\n\nUser: test task\nAssistant:" + + +def test_tool_execution_error_handling(): + """Test error handling during tool execution.""" + agent = ToolAgent(model_name="test-model") + agent.llm = None # Simulate uninitialized LLM + + with pytest.raises(ToolExecutionError) as exc_info: + agent.run("test task") + assert "LLM not initialized" in str(exc_info.value) + + # Test with invalid parameters + with pytest.raises(ToolExecutionError) as exc_info: + agent.run("test task", invalid_param="value") + assert "Error running task" in str(exc_info.value) diff --git a/tests/utils/test_typedb_wrapper.py b/tests/utils/test_typedb_wrapper.py new file mode 100644 index 00000000..076f5461 --- /dev/null +++ b/tests/utils/test_typedb_wrapper.py @@ -0,0 +1,129 @@ +import pytest +from unittest.mock import Mock, patch +from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig + +@pytest.fixture +def mock_typedb(): + """Mock TypeDB client and session.""" + with patch('swarms.utils.typedb_wrapper.TypeDB') as mock_typedb: + mock_client = Mock() + mock_session = Mock() + mock_typedb.core_client.return_value = mock_client + mock_client.session.return_value = mock_session + yield mock_typedb, mock_client, mock_session + +@pytest.fixture +def typedb_wrapper(mock_typedb): + """Create a TypeDBWrapper instance with mocked dependencies.""" + config = TypeDBConfig( + uri="test:1729", + database="test_db", + username="test_user", + password="test_pass" + ) + return TypeDBWrapper(config) + +def test_initialization(typedb_wrapper): + """Test TypeDBWrapper initialization.""" + assert typedb_wrapper.config.uri == "test:1729" + assert typedb_wrapper.config.database == "test_db" + assert typedb_wrapper.config.username == "test_user" + assert typedb_wrapper.config.password == "test_pass" + +def test_connect(typedb_wrapper, mock_typedb): + """Test connection to TypeDB.""" + mock_typedb, mock_client, mock_session = mock_typedb + typedb_wrapper._connect() + + mock_typedb.core_client.assert_called_once_with("test:1729") + mock_client.session.assert_called_once_with( + "test_db", + "DATA", + "test_user", + "test_pass" + ) + +def test_define_schema(typedb_wrapper, mock_typedb): + """Test schema definition.""" + mock_typedb, mock_client, mock_session = mock_typedb + schema = "define person sub entity;" + + with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: + mock_transaction.return_value.__enter__.return_value.query.define.return_value = None + typedb_wrapper.define_schema(schema) + + mock_transaction.assert_called_once_with("WRITE") + mock_transaction.return_value.__enter__.return_value.query.define.assert_called_once_with(schema) + +def test_insert_data(typedb_wrapper, mock_typedb): + """Test data insertion.""" + mock_typedb, mock_client, mock_session = mock_typedb + query = "insert $p isa person;" + + with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: + mock_transaction.return_value.__enter__.return_value.query.insert.return_value = None + typedb_wrapper.insert_data(query) + + mock_transaction.assert_called_once_with("WRITE") + mock_transaction.return_value.__enter__.return_value.query.insert.assert_called_once_with(query) + +def test_query_data(typedb_wrapper, mock_typedb): + """Test data querying.""" + mock_typedb, mock_client, mock_session = mock_typedb + query = "match $p isa person; get;" + mock_result = [Mock()] + + with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: + mock_transaction.return_value.__enter__.return_value.query.get.return_value = mock_result + result = typedb_wrapper.query_data(query) + + mock_transaction.assert_called_once_with("READ") + mock_transaction.return_value.__enter__.return_value.query.get.assert_called_once_with(query) + assert len(result) == 1 + +def test_delete_data(typedb_wrapper, mock_typedb): + """Test data deletion.""" + mock_typedb, mock_client, mock_session = mock_typedb + query = "match $p isa person; delete $p;" + + with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: + mock_transaction.return_value.__enter__.return_value.query.delete.return_value = None + typedb_wrapper.delete_data(query) + + mock_transaction.assert_called_once_with("WRITE") + mock_transaction.return_value.__enter__.return_value.query.delete.assert_called_once_with(query) + +def test_close(typedb_wrapper, mock_typedb): + """Test connection closing.""" + mock_typedb, mock_client, mock_session = mock_typedb + typedb_wrapper.close() + + mock_session.close.assert_called_once() + mock_client.close.assert_called_once() + +def test_context_manager(typedb_wrapper, mock_typedb): + """Test context manager functionality.""" + mock_typedb, mock_client, mock_session = mock_typedb + + with typedb_wrapper as db: + assert db == typedb_wrapper + + mock_session.close.assert_called_once() + mock_client.close.assert_called_once() + +def test_error_handling(typedb_wrapper, mock_typedb): + """Test error handling.""" + mock_typedb, mock_client, mock_session = mock_typedb + + # Test connection error + mock_typedb.core_client.side_effect = Exception("Connection failed") + with pytest.raises(Exception) as exc_info: + typedb_wrapper._connect() + assert "Connection failed" in str(exc_info.value) + + # Test query error + with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: + mock_transaction.return_value.__enter__.return_value.query.get.side_effect = Exception("Query failed") + with pytest.raises(Exception) as exc_info: + typedb_wrapper.query_data("test query") + assert "Query failed" in str(exc_info.value) \ No newline at end of file