Merge e397b8b19e into 221e9419ec
	
		
	
				
					
				
			
						commit
						422b02b5b8
					
				| @ -0,0 +1,22 @@ | |||||||
|  | [flake8] | ||||||
|  | max-line-length = 88 | ||||||
|  | extend-ignore = E203, W503 | ||||||
|  | exclude = | ||||||
|  |     .git, | ||||||
|  |     __pycache__, | ||||||
|  |     build, | ||||||
|  |     dist, | ||||||
|  |     *.egg-info, | ||||||
|  |     .eggs, | ||||||
|  |     .tox, | ||||||
|  |     .venv, | ||||||
|  |     venv, | ||||||
|  |     .env, | ||||||
|  |     .pytest_cache, | ||||||
|  |     .coverage, | ||||||
|  |     htmlcov, | ||||||
|  |     .mypy_cache, | ||||||
|  |     .ruff_cache | ||||||
|  | per-file-ignores = | ||||||
|  |     __init__.py: F401 | ||||||
|  | max-complexity = 10  | ||||||
| @ -0,0 +1,106 @@ | |||||||
|  | from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig | ||||||
|  | 
 | ||||||
|  | def main(): | ||||||
|  |     # Initialize TypeDB wrapper with custom configuration | ||||||
|  |     config = TypeDBConfig( | ||||||
|  |         uri="localhost:1729", | ||||||
|  |         database="swarms_example", | ||||||
|  |         username="admin", | ||||||
|  |         password="password" | ||||||
|  |     ) | ||||||
|  |      | ||||||
|  |     # Define schema for a simple knowledge graph | ||||||
|  |     schema = """ | ||||||
|  |     define | ||||||
|  |     person sub entity, | ||||||
|  |         owns name: string, | ||||||
|  |         owns age: long, | ||||||
|  |         plays role; | ||||||
|  |      | ||||||
|  |     role sub entity, | ||||||
|  |         owns title: string, | ||||||
|  |         owns department: string; | ||||||
|  |      | ||||||
|  |     works_at sub relation, | ||||||
|  |         relates person, | ||||||
|  |         relates role; | ||||||
|  |     """ | ||||||
|  |      | ||||||
|  |     # Example data insertion | ||||||
|  |     insert_queries = [ | ||||||
|  |         """ | ||||||
|  |         insert | ||||||
|  |         $p isa person, has name "John Doe", has age 30; | ||||||
|  |         $r isa role, has title "Software Engineer", has department "Engineering"; | ||||||
|  |         (person: $p, role: $r) isa works_at; | ||||||
|  |         """, | ||||||
|  |         """ | ||||||
|  |         insert | ||||||
|  |         $p isa person, has name "Jane Smith", has age 28; | ||||||
|  |         $r isa role, has title "Data Scientist", has department "Data Science"; | ||||||
|  |         (person: $p, role: $r) isa works_at; | ||||||
|  |         """ | ||||||
|  |     ] | ||||||
|  |      | ||||||
|  |     # Example queries | ||||||
|  |     query_queries = [ | ||||||
|  |         # Get all people | ||||||
|  |         "match $p isa person; get;", | ||||||
|  |          | ||||||
|  |         # Get people in Engineering department | ||||||
|  |         """ | ||||||
|  |         match | ||||||
|  |         $p isa person; | ||||||
|  |         $r isa role, has department "Engineering"; | ||||||
|  |         (person: $p, role: $r) isa works_at; | ||||||
|  |         get $p; | ||||||
|  |         """, | ||||||
|  |          | ||||||
|  |         # Get people with their roles | ||||||
|  |         """ | ||||||
|  |         match | ||||||
|  |         $p isa person, has name $n; | ||||||
|  |         $r isa role, has title $t; | ||||||
|  |         (person: $p, role: $r) isa works_at; | ||||||
|  |         get $n, $t; | ||||||
|  |         """ | ||||||
|  |     ] | ||||||
|  |      | ||||||
|  |     try: | ||||||
|  |         with TypeDBWrapper(config) as db: | ||||||
|  |             # Define schema | ||||||
|  |             print("Defining schema...") | ||||||
|  |             db.define_schema(schema) | ||||||
|  |              | ||||||
|  |             # Insert data | ||||||
|  |             print("\nInserting data...") | ||||||
|  |             for query in insert_queries: | ||||||
|  |                 db.insert_data(query) | ||||||
|  |              | ||||||
|  |             # Query data | ||||||
|  |             print("\nQuerying data...") | ||||||
|  |             for i, query in enumerate(query_queries, 1): | ||||||
|  |                 print(f"\nQuery {i}:") | ||||||
|  |                 results = db.query_data(query) | ||||||
|  |                 print(f"Results: {results}") | ||||||
|  |              | ||||||
|  |             # Example of deleting data | ||||||
|  |             print("\nDeleting data...") | ||||||
|  |             delete_query = """ | ||||||
|  |             match | ||||||
|  |             $p isa person, has name "John Doe"; | ||||||
|  |             delete $p; | ||||||
|  |             """ | ||||||
|  |             db.delete_data(delete_query) | ||||||
|  |              | ||||||
|  |             # Verify deletion | ||||||
|  |             print("\nVerifying deletion...") | ||||||
|  |             verify_query = "match $p isa person, has name $n; get $n;" | ||||||
|  |             results = db.query_data(verify_query) | ||||||
|  |             print(f"Remaining people: {results}") | ||||||
|  |              | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"Error: {e}") | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main()  | ||||||
| @ -0,0 +1,44 @@ | |||||||
|  | from swarms.utils.vllm_wrapper import VLLMWrapper | ||||||
|  | 
 | ||||||
|  | def main(): | ||||||
|  |     # Initialize the vLLM wrapper with a model | ||||||
|  |     # Note: You'll need to have the model downloaded or specify a HuggingFace model ID | ||||||
|  |     llm = VLLMWrapper( | ||||||
|  |         model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID | ||||||
|  |         temperature=0.7, | ||||||
|  |         max_tokens=1000, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     # Example task | ||||||
|  |     task = "What are the benefits of using vLLM for inference?" | ||||||
|  | 
 | ||||||
|  |     # Run inference | ||||||
|  |     response = llm.run(task) | ||||||
|  |     print("Response:", response) | ||||||
|  | 
 | ||||||
|  |     # Example with system prompt | ||||||
|  |     llm_with_system = VLLMWrapper( | ||||||
|  |         model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID | ||||||
|  |         system_prompt="You are a helpful AI assistant that provides concise answers.", | ||||||
|  |         temperature=0.7, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     # Run inference with system prompt | ||||||
|  |     response = llm_with_system.run(task) | ||||||
|  |     print("\nResponse with system prompt:", response) | ||||||
|  | 
 | ||||||
|  |     # Example with batched inference | ||||||
|  |     tasks = [ | ||||||
|  |         "What is vLLM?", | ||||||
|  |         "How does vLLM improve inference speed?", | ||||||
|  |         "What are the main features of vLLM?" | ||||||
|  |     ] | ||||||
|  |      | ||||||
|  |     responses = llm.batched_run(tasks, batch_size=2) | ||||||
|  |     print("\nBatched responses:") | ||||||
|  |     for task, response in zip(tasks, responses): | ||||||
|  |         print(f"\nTask: {task}") | ||||||
|  |         print(f"Response: {response}") | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main()  | ||||||
| @ -0,0 +1,4 @@ | |||||||
|  | [pyupgrade] | ||||||
|  | py3-plus = True | ||||||
|  | py39-plus = True | ||||||
|  | keep-runtime-typing = True  | ||||||
| @ -0,0 +1,55 @@ | |||||||
|  | #!/usr/bin/env python3 | ||||||
|  | import subprocess | ||||||
|  | import sys | ||||||
|  | from pathlib import Path | ||||||
|  | 
 | ||||||
|  | def run_command(command: list[str], cwd: Path) -> bool: | ||||||
|  |     """Run a command and return True if successful.""" | ||||||
|  |     try: | ||||||
|  |         result = subprocess.run( | ||||||
|  |             command, | ||||||
|  |             cwd=cwd, | ||||||
|  |             capture_output=True, | ||||||
|  |             text=True, | ||||||
|  |             check=True | ||||||
|  |         ) | ||||||
|  |         return True | ||||||
|  |     except subprocess.CalledProcessError as e: | ||||||
|  |         print(f"Error running {' '.join(command)}:") | ||||||
|  |         print(e.stdout) | ||||||
|  |         print(e.stderr, file=sys.stderr) | ||||||
|  |         return False | ||||||
|  | 
 | ||||||
|  | def main(): | ||||||
|  |     """Run all code quality checks.""" | ||||||
|  |     root_dir = Path(__file__).parent.parent | ||||||
|  |     success = True | ||||||
|  | 
 | ||||||
|  |     # Run flake8 | ||||||
|  |     print("\nRunning flake8...") | ||||||
|  |     if not run_command(["flake8", "swarms", "tests"], root_dir): | ||||||
|  |         success = False | ||||||
|  | 
 | ||||||
|  |     # Run pyupgrade | ||||||
|  |     print("\nRunning pyupgrade...") | ||||||
|  |     if not run_command(["pyupgrade", "--py39-plus", "swarms", "tests"], root_dir): | ||||||
|  |         success = False | ||||||
|  | 
 | ||||||
|  |     # Run black | ||||||
|  |     print("\nRunning black...") | ||||||
|  |     if not run_command(["black", "--check", "swarms", "tests"], root_dir): | ||||||
|  |         success = False | ||||||
|  | 
 | ||||||
|  |     # Run ruff | ||||||
|  |     print("\nRunning ruff...") | ||||||
|  |     if not run_command(["ruff", "check", "swarms", "tests"], root_dir): | ||||||
|  |         success = False | ||||||
|  | 
 | ||||||
|  |     if not success: | ||||||
|  |         print("\nCode quality checks failed. Please fix the issues and try again.") | ||||||
|  |         sys.exit(1) | ||||||
|  |     else: | ||||||
|  |         print("\nAll code quality checks passed!") | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main()  | ||||||
| @ -0,0 +1,32 @@ | |||||||
|  | from typing import Any, Dict, Optional | ||||||
|  | 
 | ||||||
|  | class ToolAgentError(Exception): | ||||||
|  |     """Base exception for all tool agent errors.""" | ||||||
|  |     def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): | ||||||
|  |         self.message = message | ||||||
|  |         self.details = details or {} | ||||||
|  |         super().__init__(self.message) | ||||||
|  | 
 | ||||||
|  | class ToolExecutionError(ToolAgentError): | ||||||
|  |     """Raised when a tool fails to execute.""" | ||||||
|  |     def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None): | ||||||
|  |         message = f"Failed to execute tool '{tool_name}': {str(error)}" | ||||||
|  |         super().__init__(message, details) | ||||||
|  | 
 | ||||||
|  | class ToolValidationError(ToolAgentError): | ||||||
|  |     """Raised when tool parameters fail validation.""" | ||||||
|  |     def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None): | ||||||
|  |         message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}" | ||||||
|  |         super().__init__(message, details) | ||||||
|  | 
 | ||||||
|  | class ToolNotFoundError(ToolAgentError): | ||||||
|  |     """Raised when a requested tool is not found.""" | ||||||
|  |     def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None): | ||||||
|  |         message = f"Tool '{tool_name}' not found" | ||||||
|  |         super().__init__(message, details) | ||||||
|  | 
 | ||||||
|  | class ToolParameterError(ToolAgentError): | ||||||
|  |     """Raised when tool parameters are invalid.""" | ||||||
|  |     def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None): | ||||||
|  |         message = f"Invalid parameters for tool '{tool_name}': {error}" | ||||||
|  |         super().__init__(message, details)  | ||||||
| @ -0,0 +1,168 @@ | |||||||
|  | from typing import Dict, List, Optional, Any, Union | ||||||
|  | from loguru import logger | ||||||
|  | from typedb.client import TypeDB, SessionType, TransactionType | ||||||
|  | from typedb.api.connection.transaction import Transaction | ||||||
|  | from dataclasses import dataclass | ||||||
|  | import json | ||||||
|  | 
 | ||||||
|  | @dataclass | ||||||
|  | class TypeDBConfig: | ||||||
|  |     """Configuration for TypeDB connection.""" | ||||||
|  |     uri: str = "localhost:1729" | ||||||
|  |     database: str = "swarms" | ||||||
|  |     username: Optional[str] = None | ||||||
|  |     password: Optional[str] = None | ||||||
|  |     timeout: int = 30 | ||||||
|  | 
 | ||||||
|  | class TypeDBWrapper: | ||||||
|  |     """ | ||||||
|  |     A wrapper class for TypeDB that provides graph database operations for Swarms. | ||||||
|  |     This class handles connection, schema management, and data operations. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, config: Optional[TypeDBConfig] = None): | ||||||
|  |         """ | ||||||
|  |         Initialize the TypeDB wrapper with the given configuration. | ||||||
|  |         Args: | ||||||
|  |             config (Optional[TypeDBConfig]): Configuration for TypeDB connection. | ||||||
|  |         """ | ||||||
|  |         self.config = config or TypeDBConfig() | ||||||
|  |         self.client = None | ||||||
|  |         self.session = None | ||||||
|  |         self._connect() | ||||||
|  | 
 | ||||||
|  |     def _connect(self) -> None: | ||||||
|  |         """Establish connection to TypeDB.""" | ||||||
|  |         try: | ||||||
|  |             self.client = TypeDB.core_client(self.config.uri) | ||||||
|  |             if self.config.username and self.config.password: | ||||||
|  |                 self.session = self.client.session( | ||||||
|  |                     self.config.database, | ||||||
|  |                     SessionType.DATA, | ||||||
|  |                     self.config.username, | ||||||
|  |                     self.config.password | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 self.session = self.client.session( | ||||||
|  |                     self.config.database, | ||||||
|  |                     SessionType.DATA | ||||||
|  |                 ) | ||||||
|  |             logger.info(f"Connected to TypeDB at {self.config.uri}") | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Failed to connect to TypeDB: {e}") | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|  |     def _ensure_connection(self) -> None: | ||||||
|  |         """Ensure connection is active, reconnect if necessary.""" | ||||||
|  |         if not self.session or not self.session.is_open(): | ||||||
|  |             self._connect() | ||||||
|  | 
 | ||||||
|  |     def define_schema(self, schema: str) -> None: | ||||||
|  |         """ | ||||||
|  |         Define the database schema. | ||||||
|  |         Args: | ||||||
|  |             schema (str): TypeQL schema definition. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             with self.session.transaction(TransactionType.WRITE) as transaction: | ||||||
|  |                 transaction.query.define(schema) | ||||||
|  |                 transaction.commit() | ||||||
|  |             logger.info("Schema defined successfully") | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Failed to define schema: {e}") | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|  |     def insert_data(self, query: str) -> None: | ||||||
|  |         """ | ||||||
|  |         Insert data using TypeQL query. | ||||||
|  |         Args: | ||||||
|  |             query (str): TypeQL insert query. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             with self.session.transaction(TransactionType.WRITE) as transaction: | ||||||
|  |                 transaction.query.insert(query) | ||||||
|  |                 transaction.commit() | ||||||
|  |             logger.info("Data inserted successfully") | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Failed to insert data: {e}") | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|  |     def query_data(self, query: str) -> List[Dict[str, Any]]: | ||||||
|  |         """ | ||||||
|  |         Query data using TypeQL query. | ||||||
|  |         Args: | ||||||
|  |             query (str): TypeQL query. | ||||||
|  |         Returns: | ||||||
|  |             List[Dict[str, Any]]: Query results. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             with self.session.transaction(TransactionType.READ) as transaction: | ||||||
|  |                 result = transaction.query.get(query) | ||||||
|  |                 return [self._convert_concept_to_dict(concept) for concept in result] | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Failed to query data: {e}") | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|  |     def _convert_concept_to_dict(self, concept: Any) -> Dict[str, Any]: | ||||||
|  |         """ | ||||||
|  |         Convert a TypeDB concept to a dictionary. | ||||||
|  |         Args: | ||||||
|  |             concept: TypeDB concept. | ||||||
|  |         Returns: | ||||||
|  |             Dict[str, Any]: Dictionary representation of the concept. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             if hasattr(concept, "get_type"): | ||||||
|  |                 concept_type = concept.get_type() | ||||||
|  |                 if hasattr(concept, "get_value"): | ||||||
|  |                     return { | ||||||
|  |                         "type": concept_type.get_label_name(), | ||||||
|  |                         "value": concept.get_value() | ||||||
|  |                     } | ||||||
|  |                 elif hasattr(concept, "get_attributes"): | ||||||
|  |                     return { | ||||||
|  |                         "type": concept_type.get_label_name(), | ||||||
|  |                         "attributes": { | ||||||
|  |                             attr.get_type().get_label_name(): attr.get_value() | ||||||
|  |                             for attr in concept.get_attributes() | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |             return {"type": "unknown", "value": str(concept)} | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Failed to convert concept to dict: {e}") | ||||||
|  |             return {"type": "error", "value": str(e)} | ||||||
|  | 
 | ||||||
|  |     def delete_data(self, query: str) -> None: | ||||||
|  |         """ | ||||||
|  |         Delete data using TypeQL query. | ||||||
|  |         Args: | ||||||
|  |             query (str): TypeQL delete query. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             with self.session.transaction(TransactionType.WRITE) as transaction: | ||||||
|  |                 transaction.query.delete(query) | ||||||
|  |                 transaction.commit() | ||||||
|  |             logger.info("Data deleted successfully") | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Failed to delete data: {e}") | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|  |     def close(self) -> None: | ||||||
|  |         """Close the TypeDB connection.""" | ||||||
|  |         try: | ||||||
|  |             if self.session: | ||||||
|  |                 self.session.close() | ||||||
|  |             if self.client: | ||||||
|  |                 self.client.close() | ||||||
|  |             logger.info("TypeDB connection closed") | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.error(f"Failed to close TypeDB connection: {e}") | ||||||
|  |             raise | ||||||
|  | 
 | ||||||
|  |     def __enter__(self): | ||||||
|  |         """Context manager entry.""" | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  |     def __exit__(self, exc_type, exc_val, exc_tb): | ||||||
|  |         """Context manager exit.""" | ||||||
|  |         self.close()  | ||||||
| @ -0,0 +1,138 @@ | |||||||
|  | from typing import List, Optional, Dict, Any | ||||||
|  | from loguru import logger | ||||||
|  | 
 | ||||||
|  | try: | ||||||
|  |     from vllm import LLM, SamplingParams | ||||||
|  | except ImportError: | ||||||
|  |     import subprocess | ||||||
|  |     import sys | ||||||
|  |     print("Installing vllm") | ||||||
|  |     subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"]) | ||||||
|  |     print("vllm installed") | ||||||
|  |     from vllm import LLM, SamplingParams | ||||||
|  | 
 | ||||||
|  | class VLLMWrapper: | ||||||
|  |     """ | ||||||
|  |     A wrapper class for vLLM that provides a similar interface to LiteLLM. | ||||||
|  |     This class handles model initialization and inference using vLLM. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         model_name: str = "meta-llama/Llama-2-7b-chat-hf", | ||||||
|  |         system_prompt: Optional[str] = None, | ||||||
|  |         stream: bool = False, | ||||||
|  |         temperature: float = 0.5, | ||||||
|  |         max_tokens: int = 4000, | ||||||
|  |         max_completion_tokens: int = 4000, | ||||||
|  |         tools_list_dictionary: Optional[List[Dict[str, Any]]] = None, | ||||||
|  |         tool_choice: str = "auto", | ||||||
|  |         parallel_tool_calls: bool = False, | ||||||
|  |         *args, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         """ | ||||||
|  |         Initialize the vLLM wrapper with the given parameters. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf". | ||||||
|  |             system_prompt (str, optional): The system prompt to use. Defaults to None. | ||||||
|  |             stream (bool): Whether to stream the output. Defaults to False. | ||||||
|  |             temperature (float): The temperature for sampling. Defaults to 0.5. | ||||||
|  |             max_tokens (int): The maximum number of tokens to generate. Defaults to 4000. | ||||||
|  |             max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000. | ||||||
|  |             tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None. | ||||||
|  |             tool_choice (str): How to choose tools. Defaults to "auto". | ||||||
|  |             parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False. | ||||||
|  |         """ | ||||||
|  |         self.model_name = model_name | ||||||
|  |         self.system_prompt = system_prompt | ||||||
|  |         self.stream = stream | ||||||
|  |         self.temperature = temperature | ||||||
|  |         self.max_tokens = max_tokens | ||||||
|  |         self.max_completion_tokens = max_completion_tokens | ||||||
|  |         self.tools_list_dictionary = tools_list_dictionary | ||||||
|  |         self.tool_choice = tool_choice | ||||||
|  |         self.parallel_tool_calls = parallel_tool_calls | ||||||
|  |          | ||||||
|  |         # Initialize vLLM | ||||||
|  |         self.llm = LLM(model=model_name, **kwargs) | ||||||
|  |         self.sampling_params = SamplingParams( | ||||||
|  |             temperature=temperature, | ||||||
|  |             max_tokens=max_tokens, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     def _prepare_prompt(self, task: str) -> str: | ||||||
|  |         """ | ||||||
|  |         Prepare the prompt for the given task. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             task (str): The task to prepare the prompt for. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             str: The prepared prompt. | ||||||
|  |         """ | ||||||
|  |         if self.system_prompt: | ||||||
|  |             return f"{self.system_prompt}\n\nUser: {task}\nAssistant:" | ||||||
|  |         return f"User: {task}\nAssistant:" | ||||||
|  | 
 | ||||||
|  |     def run(self, task: str, *args, **kwargs) -> str: | ||||||
|  |         """ | ||||||
|  |         Run the model for the given task. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             task (str): The task to run the model for. | ||||||
|  |             *args: Additional positional arguments. | ||||||
|  |             **kwargs: Additional keyword arguments. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             str: The model's response. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             prompt = self._prepare_prompt(task) | ||||||
|  |              | ||||||
|  |             outputs = self.llm.generate(prompt, self.sampling_params) | ||||||
|  |             response = outputs[0].outputs[0].text.strip() | ||||||
|  |              | ||||||
|  |             return response | ||||||
|  |              | ||||||
|  |         except Exception as error: | ||||||
|  |             logger.error(f"Error in VLLMWrapper: {error}") | ||||||
|  |             raise error | ||||||
|  | 
 | ||||||
|  |     def __call__(self, task: str, *args, **kwargs) -> str: | ||||||
|  |         """ | ||||||
|  |         Call the model for the given task. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             task (str): The task to run the model for. | ||||||
|  |             *args: Additional positional arguments. | ||||||
|  |             **kwargs: Additional keyword arguments. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             str: The model's response. | ||||||
|  |         """ | ||||||
|  |         return self.run(task, *args, **kwargs) | ||||||
|  | 
 | ||||||
|  |     def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]: | ||||||
|  |         """ | ||||||
|  |         Run the model for multiple tasks in batches. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             tasks (List[str]): List of tasks to run. | ||||||
|  |             batch_size (int): Size of each batch. Defaults to 10. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             List[str]: List of model responses. | ||||||
|  |         """ | ||||||
|  |         logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}") | ||||||
|  |         results = [] | ||||||
|  |          | ||||||
|  |         for i in range(0, len(tasks), batch_size): | ||||||
|  |             batch = tasks[i:i + batch_size] | ||||||
|  |             for task in batch: | ||||||
|  |                 logger.info(f"Running task: {task}") | ||||||
|  |                 results.append(self.run(task)) | ||||||
|  |                  | ||||||
|  |         logger.info("Completed all tasks.") | ||||||
|  |         return results  | ||||||
| @ -0,0 +1,129 @@ | |||||||
|  | import pytest | ||||||
|  | from unittest.mock import Mock, patch | ||||||
|  | from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | def mock_typedb(): | ||||||
|  |     """Mock TypeDB client and session.""" | ||||||
|  |     with patch('swarms.utils.typedb_wrapper.TypeDB') as mock_typedb: | ||||||
|  |         mock_client = Mock() | ||||||
|  |         mock_session = Mock() | ||||||
|  |         mock_typedb.core_client.return_value = mock_client | ||||||
|  |         mock_client.session.return_value = mock_session | ||||||
|  |         yield mock_typedb, mock_client, mock_session | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | def typedb_wrapper(mock_typedb): | ||||||
|  |     """Create a TypeDBWrapper instance with mocked dependencies.""" | ||||||
|  |     config = TypeDBConfig( | ||||||
|  |         uri="test:1729", | ||||||
|  |         database="test_db", | ||||||
|  |         username="test_user", | ||||||
|  |         password="test_pass" | ||||||
|  |     ) | ||||||
|  |     return TypeDBWrapper(config) | ||||||
|  | 
 | ||||||
|  | def test_initialization(typedb_wrapper): | ||||||
|  |     """Test TypeDBWrapper initialization.""" | ||||||
|  |     assert typedb_wrapper.config.uri == "test:1729" | ||||||
|  |     assert typedb_wrapper.config.database == "test_db" | ||||||
|  |     assert typedb_wrapper.config.username == "test_user" | ||||||
|  |     assert typedb_wrapper.config.password == "test_pass" | ||||||
|  | 
 | ||||||
|  | def test_connect(typedb_wrapper, mock_typedb): | ||||||
|  |     """Test connection to TypeDB.""" | ||||||
|  |     mock_typedb, mock_client, mock_session = mock_typedb | ||||||
|  |     typedb_wrapper._connect() | ||||||
|  |      | ||||||
|  |     mock_typedb.core_client.assert_called_once_with("test:1729") | ||||||
|  |     mock_client.session.assert_called_once_with( | ||||||
|  |         "test_db", | ||||||
|  |         "DATA", | ||||||
|  |         "test_user", | ||||||
|  |         "test_pass" | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  | def test_define_schema(typedb_wrapper, mock_typedb): | ||||||
|  |     """Test schema definition.""" | ||||||
|  |     mock_typedb, mock_client, mock_session = mock_typedb | ||||||
|  |     schema = "define person sub entity;" | ||||||
|  |      | ||||||
|  |     with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.define.return_value = None | ||||||
|  |         typedb_wrapper.define_schema(schema) | ||||||
|  |          | ||||||
|  |         mock_transaction.assert_called_once_with("WRITE") | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.define.assert_called_once_with(schema) | ||||||
|  | 
 | ||||||
|  | def test_insert_data(typedb_wrapper, mock_typedb): | ||||||
|  |     """Test data insertion.""" | ||||||
|  |     mock_typedb, mock_client, mock_session = mock_typedb | ||||||
|  |     query = "insert $p isa person;" | ||||||
|  |      | ||||||
|  |     with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.insert.return_value = None | ||||||
|  |         typedb_wrapper.insert_data(query) | ||||||
|  |          | ||||||
|  |         mock_transaction.assert_called_once_with("WRITE") | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.insert.assert_called_once_with(query) | ||||||
|  | 
 | ||||||
|  | def test_query_data(typedb_wrapper, mock_typedb): | ||||||
|  |     """Test data querying.""" | ||||||
|  |     mock_typedb, mock_client, mock_session = mock_typedb | ||||||
|  |     query = "match $p isa person; get;" | ||||||
|  |     mock_result = [Mock()] | ||||||
|  |      | ||||||
|  |     with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.get.return_value = mock_result | ||||||
|  |         result = typedb_wrapper.query_data(query) | ||||||
|  |          | ||||||
|  |         mock_transaction.assert_called_once_with("READ") | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.get.assert_called_once_with(query) | ||||||
|  |         assert len(result) == 1 | ||||||
|  | 
 | ||||||
|  | def test_delete_data(typedb_wrapper, mock_typedb): | ||||||
|  |     """Test data deletion.""" | ||||||
|  |     mock_typedb, mock_client, mock_session = mock_typedb | ||||||
|  |     query = "match $p isa person; delete $p;" | ||||||
|  |      | ||||||
|  |     with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.delete.return_value = None | ||||||
|  |         typedb_wrapper.delete_data(query) | ||||||
|  |          | ||||||
|  |         mock_transaction.assert_called_once_with("WRITE") | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.delete.assert_called_once_with(query) | ||||||
|  | 
 | ||||||
|  | def test_close(typedb_wrapper, mock_typedb): | ||||||
|  |     """Test connection closing.""" | ||||||
|  |     mock_typedb, mock_client, mock_session = mock_typedb | ||||||
|  |     typedb_wrapper.close() | ||||||
|  |      | ||||||
|  |     mock_session.close.assert_called_once() | ||||||
|  |     mock_client.close.assert_called_once() | ||||||
|  | 
 | ||||||
|  | def test_context_manager(typedb_wrapper, mock_typedb): | ||||||
|  |     """Test context manager functionality.""" | ||||||
|  |     mock_typedb, mock_client, mock_session = mock_typedb | ||||||
|  |      | ||||||
|  |     with typedb_wrapper as db: | ||||||
|  |         assert db == typedb_wrapper | ||||||
|  |      | ||||||
|  |     mock_session.close.assert_called_once() | ||||||
|  |     mock_client.close.assert_called_once() | ||||||
|  | 
 | ||||||
|  | def test_error_handling(typedb_wrapper, mock_typedb): | ||||||
|  |     """Test error handling.""" | ||||||
|  |     mock_typedb, mock_client, mock_session = mock_typedb | ||||||
|  |      | ||||||
|  |     # Test connection error | ||||||
|  |     mock_typedb.core_client.side_effect = Exception("Connection failed") | ||||||
|  |     with pytest.raises(Exception) as exc_info: | ||||||
|  |         typedb_wrapper._connect() | ||||||
|  |     assert "Connection failed" in str(exc_info.value) | ||||||
|  |      | ||||||
|  |     # Test query error | ||||||
|  |     with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction: | ||||||
|  |         mock_transaction.return_value.__enter__.return_value.query.get.side_effect = Exception("Query failed") | ||||||
|  |         with pytest.raises(Exception) as exc_info: | ||||||
|  |             typedb_wrapper.query_data("test query") | ||||||
|  |         assert "Query failed" in str(exc_info.value)  | ||||||
					Loading…
					
					
				
		Reference in new issue