pull/803/merge
Pavan Kumar committed 422b02b5b8, 3 weeks ago (via GitHub)

@ -0,0 +1,22 @@
[flake8]
max-line-length = 88
extend-ignore = E203, W503
exclude =
    .git,
    __pycache__,
    build,
    dist,
    *.egg-info,
    .eggs,
    .tox,
    .venv,
    venv,
    .env,
    .pytest_cache,
    .coverage,
    htmlcov,
    .mypy_cache,
    .ruff_cache
per-file-ignores =
    __init__.py: F401
max-complexity = 10
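
The per-file-ignores entry is there because package __init__.py files typically re-export names for their public API, which flake8 would otherwise flag as unused imports (F401). A minimal sketch of the pattern (module names are illustrative, not taken from this PR):

# __init__.py -- re-exports like these trigger F401 without the per-file ignore
from .tool_agent import ToolAgent
from .exceptions import ToolAgentError

__all__ = ["ToolAgent", "ToolAgentError"]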

@ -0,0 +1,106 @@
from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig


def main():
    # Initialize TypeDB wrapper with custom configuration
    config = TypeDBConfig(
        uri="localhost:1729",
        database="swarms_example",
        username="admin",
        password="password"
    )

    # Define schema for a simple knowledge graph.
    # TypeQL 2.x declares attributes separately, relations declare named
    # roles that entity types play, and "role" is a reserved word, so the
    # job entity is named job_role here.
    schema = """
    define
    name sub attribute, value string;
    age sub attribute, value long;
    title sub attribute, value string;
    department sub attribute, value string;
    person sub entity,
        owns name,
        owns age,
        plays works_at:employee;
    job_role sub entity,
        owns title,
        owns department,
        plays works_at:position;
    works_at sub relation,
        relates employee,
        relates position;
    """

    # Example data insertion
    insert_queries = [
        """
        insert
        $p isa person, has name "John Doe", has age 30;
        $r isa job_role, has title "Software Engineer", has department "Engineering";
        (employee: $p, position: $r) isa works_at;
        """,
        """
        insert
        $p isa person, has name "Jane Smith", has age 28;
        $r isa job_role, has title "Data Scientist", has department "Data Science";
        (employee: $p, position: $r) isa works_at;
        """
    ]

    # Example queries
    query_queries = [
        # Get all people
        "match $p isa person; get;",
        # Get people in the Engineering department
        """
        match
        $p isa person;
        $r isa job_role, has department "Engineering";
        (employee: $p, position: $r) isa works_at;
        get $p;
        """,
        # Get people with their roles
        """
        match
        $p isa person, has name $n;
        $r isa job_role, has title $t;
        (employee: $p, position: $r) isa works_at;
        get $n, $t;
        """
    ]

    try:
        with TypeDBWrapper(config) as db:
            # Define schema
            print("Defining schema...")
            db.define_schema(schema)

            # Insert data
            print("\nInserting data...")
            for query in insert_queries:
                db.insert_data(query)

            # Query data
            print("\nQuerying data...")
            for i, query in enumerate(query_queries, 1):
                print(f"\nQuery {i}:")
                results = db.query_data(query)
                print(f"Results: {results}")

            # Example of deleting data (TypeQL 2.x requires the type in the
            # delete clause: `delete $p isa person;`)
            print("\nDeleting data...")
            delete_query = """
            match
            $p isa person, has name "John Doe";
            delete
            $p isa person;
            """
            db.delete_data(delete_query)

            # Verify deletion
            print("\nVerifying deletion...")
            verify_query = "match $p isa person, has name $n; get $n;"
            results = db.query_data(verify_query)
            print(f"Remaining people: {results}")
    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    main()
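
This example assumes a TypeDB server is listening on localhost:1729 and that the swarms_example database already exists. A minimal sketch for creating it with the same client the wrapper uses (assuming the typedb-client 2.x databases() API):

from typedb.client import TypeDB

# Sketch: create the target database once before running the example.
with TypeDB.core_client("localhost:1729") as client:
    if not client.databases().contains("swarms_example"):
        client.databases().create("swarms_example")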

@ -0,0 +1,44 @@
from swarms.utils.vllm_wrapper import VLLMWrapper


def main():
    # Initialize the vLLM wrapper with a model
    # Note: You'll need to have the model downloaded or specify a HuggingFace model ID
    llm = VLLMWrapper(
        model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
        temperature=0.7,
        max_tokens=1000,
    )

    # Example task
    task = "What are the benefits of using vLLM for inference?"

    # Run inference
    response = llm.run(task)
    print("Response:", response)

    # Example with system prompt
    llm_with_system = VLLMWrapper(
        model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
        system_prompt="You are a helpful AI assistant that provides concise answers.",
        temperature=0.7,
    )

    # Run inference with system prompt
    response = llm_with_system.run(task)
    print("\nResponse with system prompt:", response)

    # Example with batched inference
    tasks = [
        "What is vLLM?",
        "How does vLLM improve inference speed?",
        "What are the main features of vLLM?"
    ]
    responses = llm.batched_run(tasks, batch_size=2)
    print("\nBatched responses:")
    for task, response in zip(tasks, responses):
        print(f"\nTask: {task}")
        print(f"Response: {response}")


if __name__ == "__main__":
    main()
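
For context, a sketch of the raw vLLM calls the wrapper builds on; generate() accepts a single prompt or a list of prompts and returns one RequestOutput per prompt (the model ID is a placeholder, as above):

from vllm import LLM, SamplingParams

# Sketch: direct vLLM usage underlying VLLMWrapper.run/batched_run.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
params = SamplingParams(temperature=0.7, max_tokens=1000)
outputs = llm.generate(["What is vLLM?", "How does vLLM improve inference speed?"], params)
for out in outputs:
    print(out.outputs[0].text.strip())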

@ -0,0 +1,4 @@
[pyupgrade]
# Note: pyupgrade itself is driven by CLI flags (the check script later in
# this diff passes --py39-plus); this section records the intended settings.
py3-plus = True
py39-plus = True
keep-runtime-typing = True

@ -1,4 +1,3 @@
torch>=2.1.1,<3.0
transformers>=4.39.0,<4.49.0
asyncio>=3.4.3,<4.0
@ -23,3 +22,13 @@ pytest>=8.1.1
networkx
aiofiles
httpx
vllm>=0.2.0
flake8>=6.1.0
flake8-bugbear>=23.3.12
flake8-comprehensions>=3.12.0
flake8-simplify>=0.19.3
flake8-unused-arguments>=0.0.4
pyupgrade>=3.15.0
typedb-client>=2.25.0
typedb-protocol>=2.25.0
typedb-driver>=2.25.0

@ -0,0 +1,55 @@
#!/usr/bin/env python3
import subprocess
import sys
from pathlib import Path


def run_command(command: list[str], cwd: Path) -> bool:
    """Run a command and return True if successful."""
    try:
        subprocess.run(
            command,
            cwd=cwd,
            capture_output=True,
            text=True,
            check=True,
        )
        return True
    except subprocess.CalledProcessError as e:
        print(f"Error running {' '.join(command)}:")
        print(e.stdout)
        print(e.stderr, file=sys.stderr)
        return False


def main():
    """Run all code quality checks."""
    root_dir = Path(__file__).parent.parent
    success = True

    # Run flake8
    print("\nRunning flake8...")
    if not run_command(["flake8", "swarms", "tests"], root_dir):
        success = False

    # Run pyupgrade (it takes file paths, not directories, so collect the
    # Python files under the two source trees first)
    print("\nRunning pyupgrade...")
    python_files = [
        str(path)
        for tree in ("swarms", "tests")
        for path in (root_dir / tree).rglob("*.py")
    ]
    if not run_command(["pyupgrade", "--py39-plus", *python_files], root_dir):
        success = False

    # Run black
    print("\nRunning black...")
    if not run_command(["black", "--check", "swarms", "tests"], root_dir):
        success = False

    # Run ruff
    print("\nRunning ruff...")
    if not run_command(["ruff", "check", "swarms", "tests"], root_dir):
        success = False

    if not success:
        print("\nCode quality checks failed. Please fix the issues and try again.")
        sys.exit(1)
    else:
        print("\nAll code quality checks passed!")


if __name__ == "__main__":
    main()
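
A note on usage: the script assumes it lives one directory below the repository root (the diff does not show its path), so root_dir resolves to the repo root and it can be invoked from anywhere inside the checkout, e.g. with python on the script file directly.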

@ -0,0 +1,32 @@
from typing import Any, Dict, Optional


class ToolAgentError(Exception):
    """Base exception for all tool agent errors."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)


class ToolExecutionError(ToolAgentError):
    """Raised when a tool fails to execute."""

    def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None):
        message = f"Failed to execute tool '{tool_name}': {str(error)}"
        super().__init__(message, details)


class ToolValidationError(ToolAgentError):
    """Raised when tool parameters fail validation."""

    def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None):
        message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
        super().__init__(message, details)


class ToolNotFoundError(ToolAgentError):
    """Raised when a requested tool is not found."""

    def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None):
        message = f"Tool '{tool_name}' not found"
        super().__init__(message, details)


class ToolParameterError(ToolAgentError):
    """Raised when tool parameters are invalid."""

    def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None):
        message = f"Invalid parameters for tool '{tool_name}': {error}"
        super().__init__(message, details)
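
Because every error type subclasses ToolAgentError, callers can handle specific failures first and fall back to the base class. A hypothetical usage sketch (the agent object is assumed):

# Sketch: specific handlers first, the base class as the catch-all.
try:
    agent.run("some task")
except ToolNotFoundError as e:
    print(f"Unknown tool: {e.message}", e.details)
except ToolAgentError as e:
    print(f"Tool agent failed: {e.message}")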

@ -1,156 +1,243 @@
import time
from typing import List, Optional, Dict, Any, Callable

from loguru import logger

# LLM and SamplingParams are used below but were never imported in this file;
# they come from vLLM, the same dependency VLLMWrapper uses.
from vllm import LLM, SamplingParams

from swarms.agents.exceptions import (
    ToolAgentError,
    ToolExecutionError,
    ToolValidationError,
    ToolNotFoundError,
    ToolParameterError
)


class ToolAgent:
    """
    A wrapper class for vLLM that provides a similar interface to LiteLLM.
    This class handles model initialization and inference using vLLM.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        system_prompt: Optional[str] = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        retry_attempts: int = 3,
        retry_interval: float = 1.0,
        *args,
        **kwargs,
    ):
        """
        Initialize the vLLM wrapper with the given parameters.

        Args:
            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool): Whether to stream the output. Defaults to False.
            temperature (float): The temperature for sampling. Defaults to 0.5.
            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
            tool_choice (str): How to choose tools. Defaults to "auto".
            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
            retry_attempts (int): Number of retry attempts for failed operations. Defaults to 3.
            retry_interval (float): Time to wait between retries in seconds. Defaults to 1.0.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls
        self.retry_attempts = retry_attempts
        self.retry_interval = retry_interval

        # Initialize vLLM
        try:
            self.llm = LLM(model=model_name, **kwargs)
            self.sampling_params = SamplingParams(
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            raise ToolExecutionError(
                "model_initialization",
                e,
                {"model_name": model_name, "kwargs": kwargs}
            )

    def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None:
        """
        Validate tool parameters before execution.

        Args:
            tool_name (str): Name of the tool to validate
            parameters (Dict[str, Any]): Parameters to validate

        Raises:
            ToolValidationError: If validation fails
        """
        if not self.tools_list_dictionary:
            raise ToolValidationError(
                tool_name,
                "parameters",
                "No tools available for validation"
            )

        tool_spec = next(
            (tool for tool in self.tools_list_dictionary if tool["name"] == tool_name),
            None
        )
        if not tool_spec:
            raise ToolNotFoundError(tool_name)

        required_params = {
            param["name"]
            for param in tool_spec["parameters"]
            if param.get("required", True)
        }
        missing_params = required_params - set(parameters.keys())
        if missing_params:
            raise ToolParameterError(
                tool_name,
                f"Missing required parameters: {', '.join(missing_params)}"
            )

    def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function with retry logic.

        Args:
            func (Callable): Function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Any: Result of the function execution

        Raises:
            ToolExecutionError: If all retry attempts fail
        """
        last_error = None
        for attempt in range(self.retry_attempts):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(
                    f"Attempt {attempt + 1}/{self.retry_attempts} failed: {str(e)}"
                )
                if attempt < self.retry_attempts - 1:
                    time.sleep(self.retry_interval)

        # getattr: not every callable (e.g. a Mock in tests) has __name__
        raise ToolExecutionError(
            getattr(func, "__name__", "operation"),
            last_error,
            {"attempts": self.retry_attempts}
        )

    def run(self, task: str, *args, **kwargs) -> str:
        """
        Run the tool agent for the specified task.

        Args:
            task (str): The task to be performed by the tool agent.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            The output of the tool agent.

        Raises:
            ToolExecutionError: If an error occurs during execution.
        """
        try:
            if not self.llm:
                raise ToolExecutionError(
                    "run",
                    Exception("LLM not initialized"),
                    {"task": task}
                )

            logger.info(f"Running task: {task}")

            # Prepare the prompt
            prompt = self._prepare_prompt(task)

            # Execute with retry logic
            outputs = self._execute_with_retry(
                self.llm.generate,
                prompt,
                self.sampling_params
            )
            response = outputs[0].outputs[0].text.strip()
            return response
        except Exception as error:
            logger.error(f"Error running task: {error}")
            raise ToolExecutionError(
                "run",
                error,
                {"task": task, "args": args, "kwargs": kwargs}
            )

    def _prepare_prompt(self, task: str) -> str:
        """
        Prepare the prompt for the given task.

        Args:
            task (str): The task to prepare the prompt for.

        Returns:
            str: The prepared prompt.
        """
        if self.system_prompt:
            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
        return f"User: {task}\nAssistant:"

    def __call__(self, task: str, *args, **kwargs) -> str:
        """
        Call the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        return self.run(task, *args, **kwargs)

    def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
        """
        Run the model for multiple tasks in batches.

        Args:
            tasks (List[str]): List of tasks to run.
            batch_size (int): Size of each batch. Defaults to 10.

        Returns:
            List[str]: List of model responses.

        Raises:
            ToolExecutionError: If an error occurs during batch execution.
        """
        logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
        results = []
        try:
            for i in range(0, len(tasks), batch_size):
                batch = tasks[i:i + batch_size]
                for task in batch:
                    logger.info(f"Running task: {task}")
                    try:
                        result = self.run(task)
                        results.append(result)
                    except ToolExecutionError as e:
                        logger.error(f"Failed to execute task '{task}': {e}")
                        results.append(f"Error: {str(e)}")
                        continue
            logger.info("Completed all tasks.")
            return results
        except Exception as error:
            logger.error(f"Error in batch execution: {error}")
            raise ToolExecutionError(
                "batched_run",
                error,
                {"tasks": tasks, "batch_size": batch_size}
            )
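
A hypothetical end-to-end sketch of the new ToolAgent (the tool spec and task are illustrative; loading the default model requires a GPU and HuggingFace access):

# Sketch: declare one tool, validate its parameters, then run a task.
tools = [{
    "name": "search",
    "parameters": [{"name": "query", "required": True}],
}]
agent = ToolAgent(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a terse research assistant.",
    tools_list_dictionary=tools,
)
agent._validate_tool("search", {"query": "paged attention"})  # raises ToolParameterError if "query" is missing
print(agent.run("Summarize the benefits of paged attention."))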

@ -0,0 +1,168 @@
from typing import Dict, List, Optional, Any
from dataclasses import dataclass

from loguru import logger
from typedb.client import TypeDB, SessionType, TransactionType


@dataclass
class TypeDBConfig:
    """Configuration for TypeDB connection."""

    uri: str = "localhost:1729"
    database: str = "swarms"
    username: Optional[str] = None
    password: Optional[str] = None
    timeout: int = 30


class TypeDBWrapper:
    """
    A wrapper class for TypeDB that provides graph database operations for Swarms.
    This class handles connection, schema management, and data operations.
    """

    def __init__(self, config: Optional[TypeDBConfig] = None):
        """
        Initialize the TypeDB wrapper with the given configuration.

        Args:
            config (Optional[TypeDBConfig]): Configuration for TypeDB connection.
        """
        self.config = config or TypeDBConfig()
        self.client = None
        self.session = None
        self._connect()

    def _connect(self) -> None:
        """Establish connection to TypeDB."""
        try:
            self.client = TypeDB.core_client(self.config.uri)
            if self.config.username and self.config.password:
                # Note: per-session credentials are a TypeDB Cluster feature;
                # the core server ignores them.
                self.session = self.client.session(
                    self.config.database,
                    SessionType.DATA,
                    self.config.username,
                    self.config.password
                )
            else:
                self.session = self.client.session(
                    self.config.database,
                    SessionType.DATA
                )
            logger.info(f"Connected to TypeDB at {self.config.uri}")
        except Exception as e:
            logger.error(f"Failed to connect to TypeDB: {e}")
            raise

    def _ensure_connection(self) -> None:
        """Ensure connection is active, reconnect if necessary."""
        if not self.session or not self.session.is_open():
            self._connect()

    def define_schema(self, schema: str) -> None:
        """
        Define the database schema.

        Note: TypeDB accepts define queries only in a SCHEMA session; this
        wrapper holds a DATA session, so servers that enforce the distinction
        will reject this call (see the sketch after this file).

        Args:
            schema (str): TypeQL schema definition.
        """
        try:
            with self.session.transaction(TransactionType.WRITE) as transaction:
                transaction.query.define(schema)
                transaction.commit()
            logger.info("Schema defined successfully")
        except Exception as e:
            logger.error(f"Failed to define schema: {e}")
            raise

    def insert_data(self, query: str) -> None:
        """
        Insert data using TypeQL query.

        Args:
            query (str): TypeQL insert query.
        """
        try:
            with self.session.transaction(TransactionType.WRITE) as transaction:
                transaction.query.insert(query)
                transaction.commit()
            logger.info("Data inserted successfully")
        except Exception as e:
            logger.error(f"Failed to insert data: {e}")
            raise

    def query_data(self, query: str) -> List[Dict[str, Any]]:
        """
        Query data using TypeQL query.

        Args:
            query (str): TypeQL query.

        Returns:
            List[Dict[str, Any]]: Query results.
        """
        try:
            with self.session.transaction(TransactionType.READ) as transaction:
                result = transaction.query.get(query)
                return [self._convert_concept_to_dict(concept) for concept in result]
        except Exception as e:
            logger.error(f"Failed to query data: {e}")
            raise

    def _convert_concept_to_dict(self, concept: Any) -> Dict[str, Any]:
        """
        Convert a TypeDB concept to a dictionary.

        Args:
            concept: TypeDB concept.

        Returns:
            Dict[str, Any]: Dictionary representation of the concept.
        """
        try:
            if hasattr(concept, "get_type"):
                concept_type = concept.get_type()
                if hasattr(concept, "get_value"):
                    return {
                        "type": concept_type.get_label_name(),
                        "value": concept.get_value()
                    }
                elif hasattr(concept, "get_attributes"):
                    return {
                        "type": concept_type.get_label_name(),
                        "attributes": {
                            attr.get_type().get_label_name(): attr.get_value()
                            for attr in concept.get_attributes()
                        }
                    }
            return {"type": "unknown", "value": str(concept)}
        except Exception as e:
            logger.error(f"Failed to convert concept to dict: {e}")
            return {"type": "error", "value": str(e)}

    def delete_data(self, query: str) -> None:
        """
        Delete data using TypeQL query.

        Args:
            query (str): TypeQL delete query.
        """
        try:
            with self.session.transaction(TransactionType.WRITE) as transaction:
                transaction.query.delete(query)
                transaction.commit()
            logger.info("Data deleted successfully")
        except Exception as e:
            logger.error(f"Failed to delete data: {e}")
            raise

    def close(self) -> None:
        """Close the TypeDB connection."""
        try:
            if self.session:
                self.session.close()
            if self.client:
                self.client.close()
            logger.info("TypeDB connection closed")
        except Exception as e:
            logger.error(f"Failed to close TypeDB connection: {e}")
            raise

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()
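
One caveat worth making concrete: TypeDB applies define queries only inside a SCHEMA session, while the wrapper above opens a single DATA session. A hedged sketch of the schema path, assuming the typedb-client 2.x API used above:

# Sketch: run define queries in a dedicated SCHEMA session.
with TypeDB.core_client("localhost:1729") as client:
    with client.session("swarms", SessionType.SCHEMA) as session:
        with session.transaction(TransactionType.WRITE) as tx:
            tx.query.define("define email sub attribute, value string;")
            tx.commit()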

@ -0,0 +1,138 @@
from typing import List, Optional, Dict, Any

from loguru import logger

try:
    from vllm import LLM, SamplingParams
except ImportError:
    import subprocess
    import sys

    print("Installing vllm")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"])
    print("vllm installed")
    from vllm import LLM, SamplingParams


class VLLMWrapper:
    """
    A wrapper class for vLLM that provides a similar interface to LiteLLM.
    This class handles model initialization and inference using vLLM.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        system_prompt: Optional[str] = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        *args,
        **kwargs,
    ):
        """
        Initialize the vLLM wrapper with the given parameters.

        Args:
            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool): Whether to stream the output. Defaults to False.
            temperature (float): The temperature for sampling. Defaults to 0.5.
            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
            tool_choice (str): How to choose tools. Defaults to "auto".
            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls

        # Initialize vLLM
        self.llm = LLM(model=model_name, **kwargs)
        self.sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def _prepare_prompt(self, task: str) -> str:
        """
        Prepare the prompt for the given task.

        Args:
            task (str): The task to prepare the prompt for.

        Returns:
            str: The prepared prompt.
        """
        if self.system_prompt:
            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
        return f"User: {task}\nAssistant:"

    def run(self, task: str, *args, **kwargs) -> str:
        """
        Run the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        try:
            prompt = self._prepare_prompt(task)
            outputs = self.llm.generate(prompt, self.sampling_params)
            response = outputs[0].outputs[0].text.strip()
            return response
        except Exception as error:
            logger.error(f"Error in VLLMWrapper: {error}")
            raise error

    def __call__(self, task: str, *args, **kwargs) -> str:
        """
        Call the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        return self.run(task, *args, **kwargs)

    def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
        """
        Run the model for multiple tasks in batches.

        Args:
            tasks (List[str]): List of tasks to run.
            batch_size (int): Size of each batch. Defaults to 10.

        Returns:
            List[str]: List of model responses.
        """
        logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
        results = []
        for i in range(0, len(tasks), batch_size):
            batch = tasks[i:i + batch_size]
            for task in batch:
                logger.info(f"Running task: {task}")
                results.append(self.run(task))
        logger.info("Completed all tasks.")
        return results
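
Since vLLM's generate() already accepts a list of prompts, batched_run could submit each batch in a single call instead of looping task-by-task. A sketch of that variant, assuming the same attributes as the class above:

def batched_run_vectorized(self, tasks: List[str], batch_size: int = 10) -> List[str]:
    # Sketch: one generate() call per batch; vLLM schedules the whole batch internally.
    results = []
    for i in range(0, len(tasks), batch_size):
        prompts = [self._prepare_prompt(t) for t in tasks[i:i + batch_size]]
        outputs = self.llm.generate(prompts, self.sampling_params)
        results.extend(out.outputs[0].text.strip() for out in outputs)
    return results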

@ -1,8 +1,15 @@
from unittest.mock import Mock, patch
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer
from swarms import ToolAgent
from swarms.agents.exceptions import (
    ToolExecutionError,
    ToolValidationError,
    ToolNotFoundError,
    ToolParameterError
)

def test_tool_agent_init():
@ -99,3 +106,123 @@ def test_tool_agent_init_with_kwargs():
        agent.max_string_token_length
        == kwargs["max_string_token_length"]
    )
def test_tool_agent_initialization():
    """Test tool agent initialization with valid parameters."""
    agent = ToolAgent(
        model_name="test-model",
        temperature=0.7,
        max_tokens=1000
    )
    assert agent.model_name == "test-model"
    assert agent.temperature == 0.7
    assert agent.max_tokens == 1000
    assert agent.retry_attempts == 3
    assert agent.retry_interval == 1.0


def test_tool_agent_initialization_error():
    """Test tool agent initialization with invalid model."""
    with pytest.raises(ToolExecutionError) as exc_info:
        ToolAgent(model_name="invalid-model")
    assert "model_initialization" in str(exc_info.value)


def test_tool_validation():
    """Test tool parameter validation."""
    tools_list = [{
        "name": "test_tool",
        "parameters": [
            {"name": "required_param", "required": True},
            {"name": "optional_param", "required": False}
        ]
    }]
    agent = ToolAgent(tools_list_dictionary=tools_list)

    # Test missing required parameter
    with pytest.raises(ToolParameterError) as exc_info:
        agent._validate_tool("test_tool", {})
    assert "Missing required parameters" in str(exc_info.value)

    # Test valid parameters
    agent._validate_tool("test_tool", {"required_param": "value"})

    # Test non-existent tool
    with pytest.raises(ToolNotFoundError) as exc_info:
        agent._validate_tool("non_existent_tool", {})
    assert "Tool 'non_existent_tool' not found" in str(exc_info.value)


def test_retry_mechanism():
    """Test retry mechanism for failed operations."""
    mock_llm = Mock()
    # generate() returns a list of RequestOutput-like objects, so the
    # successful attempt must be wrapped in a list
    mock_llm.generate.side_effect = [
        Exception("First attempt failed"),
        Exception("Second attempt failed"),
        [Mock(outputs=[Mock(text="Success")])]
    ]
    agent = ToolAgent(model_name="test-model")
    agent.llm = mock_llm

    # Test successful retry
    result = agent.run("test task")
    assert result == "Success"
    assert mock_llm.generate.call_count == 3

    # Test all retries failing
    mock_llm.generate.side_effect = Exception("All attempts failed")
    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task")
    assert "All attempts failed" in str(exc_info.value)


def test_batched_execution():
    """Test batched execution with error handling."""
    mock_llm = Mock()
    mock_llm.generate.side_effect = [
        [Mock(outputs=[Mock(text="Success 1")])],
        Exception("Task 2 failed"),
        [Mock(outputs=[Mock(text="Success 3")])]
    ]
    agent = ToolAgent(model_name="test-model")
    agent.llm = mock_llm
    # One generate() call per task, so the failing side effect maps to task 2
    agent.retry_attempts = 1

    tasks = ["Task 1", "Task 2", "Task 3"]
    results = agent.batched_run(tasks)
    assert len(results) == 3
    assert results[0] == "Success 1"
    assert "Error" in results[1]
    assert results[2] == "Success 3"


def test_prompt_preparation():
    """Test prompt preparation with and without system prompt."""
    # Test without system prompt
    agent = ToolAgent()
    prompt = agent._prepare_prompt("test task")
    assert prompt == "User: test task\nAssistant:"

    # Test with system prompt
    agent = ToolAgent(system_prompt="You are a helpful assistant")
    prompt = agent._prepare_prompt("test task")
    assert prompt == "You are a helpful assistant\n\nUser: test task\nAssistant:"


def test_tool_execution_error_handling():
    """Test error handling during tool execution."""
    agent = ToolAgent(model_name="test-model")
    agent.llm = None  # Simulate uninitialized LLM
    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task")
    assert "LLM not initialized" in str(exc_info.value)

    # Test with extra parameters (still fails on the uninitialized LLM,
    # wrapped in the run-level ToolExecutionError)
    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task", invalid_param="value")
    assert "Failed to execute tool 'run'" in str(exc_info.value)

@ -0,0 +1,129 @@
import pytest
from unittest.mock import Mock, patch
from typedb.client import SessionType, TransactionType
from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig


@pytest.fixture
def mock_typedb():
    """Mock TypeDB client and session."""
    with patch('swarms.utils.typedb_wrapper.TypeDB') as mock_typedb:
        mock_client = Mock()
        mock_session = Mock()
        mock_typedb.core_client.return_value = mock_client
        mock_client.session.return_value = mock_session
        yield mock_typedb, mock_client, mock_session


@pytest.fixture
def typedb_wrapper(mock_typedb):
    """Create a TypeDBWrapper instance with mocked dependencies."""
    config = TypeDBConfig(
        uri="test:1729",
        database="test_db",
        username="test_user",
        password="test_pass"
    )
    return TypeDBWrapper(config)


def test_initialization(typedb_wrapper):
    """Test TypeDBWrapper initialization."""
    assert typedb_wrapper.config.uri == "test:1729"
    assert typedb_wrapper.config.database == "test_db"
    assert typedb_wrapper.config.username == "test_user"
    assert typedb_wrapper.config.password == "test_pass"


def test_connect(typedb_wrapper, mock_typedb):
    """Test connection to TypeDB."""
    mock_typedb, mock_client, mock_session = mock_typedb
    typedb_wrapper._connect()
    # The constructor already connected once, so assert on the latest call,
    # and match the real arguments (SessionType.DATA, not the string "DATA")
    mock_typedb.core_client.assert_called_with("test:1729")
    mock_client.session.assert_called_with(
        "test_db",
        SessionType.DATA,
        "test_user",
        "test_pass"
    )


def test_define_schema(typedb_wrapper, mock_typedb):
    """Test schema definition."""
    mock_typedb, mock_client, mock_session = mock_typedb
    schema = "define person sub entity;"
    with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
        mock_transaction.return_value.__enter__.return_value.query.define.return_value = None
        typedb_wrapper.define_schema(schema)
        mock_transaction.assert_called_once_with(TransactionType.WRITE)
        mock_transaction.return_value.__enter__.return_value.query.define.assert_called_once_with(schema)


def test_insert_data(typedb_wrapper, mock_typedb):
    """Test data insertion."""
    mock_typedb, mock_client, mock_session = mock_typedb
    query = "insert $p isa person;"
    with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
        mock_transaction.return_value.__enter__.return_value.query.insert.return_value = None
        typedb_wrapper.insert_data(query)
        mock_transaction.assert_called_once_with(TransactionType.WRITE)
        mock_transaction.return_value.__enter__.return_value.query.insert.assert_called_once_with(query)


def test_query_data(typedb_wrapper, mock_typedb):
    """Test data querying."""
    mock_typedb, mock_client, mock_session = mock_typedb
    query = "match $p isa person; get;"
    mock_result = [Mock()]
    with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
        mock_transaction.return_value.__enter__.return_value.query.get.return_value = mock_result
        result = typedb_wrapper.query_data(query)
        mock_transaction.assert_called_once_with(TransactionType.READ)
        mock_transaction.return_value.__enter__.return_value.query.get.assert_called_once_with(query)
        assert len(result) == 1


def test_delete_data(typedb_wrapper, mock_typedb):
    """Test data deletion."""
    mock_typedb, mock_client, mock_session = mock_typedb
    query = "match $p isa person; delete $p isa person;"
    with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
        mock_transaction.return_value.__enter__.return_value.query.delete.return_value = None
        typedb_wrapper.delete_data(query)
        mock_transaction.assert_called_once_with(TransactionType.WRITE)
        mock_transaction.return_value.__enter__.return_value.query.delete.assert_called_once_with(query)


def test_close(typedb_wrapper, mock_typedb):
    """Test connection closing."""
    mock_typedb, mock_client, mock_session = mock_typedb
    typedb_wrapper.close()
    mock_session.close.assert_called_once()
    mock_client.close.assert_called_once()


def test_context_manager(typedb_wrapper, mock_typedb):
    """Test context manager functionality."""
    mock_typedb, mock_client, mock_session = mock_typedb
    with typedb_wrapper as db:
        assert db == typedb_wrapper
    mock_session.close.assert_called_once()
    mock_client.close.assert_called_once()


def test_error_handling(typedb_wrapper, mock_typedb):
    """Test error handling."""
    mock_typedb, mock_client, mock_session = mock_typedb

    # Test connection error
    mock_typedb.core_client.side_effect = Exception("Connection failed")
    with pytest.raises(Exception) as exc_info:
        typedb_wrapper._connect()
    assert "Connection failed" in str(exc_info.value)

    # Test query error
    with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
        mock_transaction.return_value.__enter__.return_value.query.get.side_effect = Exception("Query failed")
        with pytest.raises(Exception) as exc_info:
            typedb_wrapper.query_data("test query")
        assert "Query failed" in str(exc_info.value)