Merge e397b8b19e
into 221e9419ec
commit
422b02b5b8
@ -0,0 +1,22 @@
|
|||||||
|
[flake8]
|
||||||
|
max-line-length = 88
|
||||||
|
extend-ignore = E203, W503
|
||||||
|
exclude =
|
||||||
|
.git,
|
||||||
|
__pycache__,
|
||||||
|
build,
|
||||||
|
dist,
|
||||||
|
*.egg-info,
|
||||||
|
.eggs,
|
||||||
|
.tox,
|
||||||
|
.venv,
|
||||||
|
venv,
|
||||||
|
.env,
|
||||||
|
.pytest_cache,
|
||||||
|
.coverage,
|
||||||
|
htmlcov,
|
||||||
|
.mypy_cache,
|
||||||
|
.ruff_cache
|
||||||
|
per-file-ignores =
|
||||||
|
__init__.py: F401
|
||||||
|
max-complexity = 10
|
@ -0,0 +1,106 @@
|
|||||||
|
from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Initialize TypeDB wrapper with custom configuration
|
||||||
|
config = TypeDBConfig(
|
||||||
|
uri="localhost:1729",
|
||||||
|
database="swarms_example",
|
||||||
|
username="admin",
|
||||||
|
password="password"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define schema for a simple knowledge graph
|
||||||
|
schema = """
|
||||||
|
define
|
||||||
|
person sub entity,
|
||||||
|
owns name: string,
|
||||||
|
owns age: long,
|
||||||
|
plays role;
|
||||||
|
|
||||||
|
role sub entity,
|
||||||
|
owns title: string,
|
||||||
|
owns department: string;
|
||||||
|
|
||||||
|
works_at sub relation,
|
||||||
|
relates person,
|
||||||
|
relates role;
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Example data insertion
|
||||||
|
insert_queries = [
|
||||||
|
"""
|
||||||
|
insert
|
||||||
|
$p isa person, has name "John Doe", has age 30;
|
||||||
|
$r isa role, has title "Software Engineer", has department "Engineering";
|
||||||
|
(person: $p, role: $r) isa works_at;
|
||||||
|
""",
|
||||||
|
"""
|
||||||
|
insert
|
||||||
|
$p isa person, has name "Jane Smith", has age 28;
|
||||||
|
$r isa role, has title "Data Scientist", has department "Data Science";
|
||||||
|
(person: $p, role: $r) isa works_at;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
# Example queries
|
||||||
|
query_queries = [
|
||||||
|
# Get all people
|
||||||
|
"match $p isa person; get;",
|
||||||
|
|
||||||
|
# Get people in Engineering department
|
||||||
|
"""
|
||||||
|
match
|
||||||
|
$p isa person;
|
||||||
|
$r isa role, has department "Engineering";
|
||||||
|
(person: $p, role: $r) isa works_at;
|
||||||
|
get $p;
|
||||||
|
""",
|
||||||
|
|
||||||
|
# Get people with their roles
|
||||||
|
"""
|
||||||
|
match
|
||||||
|
$p isa person, has name $n;
|
||||||
|
$r isa role, has title $t;
|
||||||
|
(person: $p, role: $r) isa works_at;
|
||||||
|
get $n, $t;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
with TypeDBWrapper(config) as db:
|
||||||
|
# Define schema
|
||||||
|
print("Defining schema...")
|
||||||
|
db.define_schema(schema)
|
||||||
|
|
||||||
|
# Insert data
|
||||||
|
print("\nInserting data...")
|
||||||
|
for query in insert_queries:
|
||||||
|
db.insert_data(query)
|
||||||
|
|
||||||
|
# Query data
|
||||||
|
print("\nQuerying data...")
|
||||||
|
for i, query in enumerate(query_queries, 1):
|
||||||
|
print(f"\nQuery {i}:")
|
||||||
|
results = db.query_data(query)
|
||||||
|
print(f"Results: {results}")
|
||||||
|
|
||||||
|
# Example of deleting data
|
||||||
|
print("\nDeleting data...")
|
||||||
|
delete_query = """
|
||||||
|
match
|
||||||
|
$p isa person, has name "John Doe";
|
||||||
|
delete $p;
|
||||||
|
"""
|
||||||
|
db.delete_data(delete_query)
|
||||||
|
|
||||||
|
# Verify deletion
|
||||||
|
print("\nVerifying deletion...")
|
||||||
|
verify_query = "match $p isa person, has name $n; get $n;"
|
||||||
|
results = db.query_data(verify_query)
|
||||||
|
print(f"Remaining people: {results}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,44 @@
|
|||||||
|
from swarms.utils.vllm_wrapper import VLLMWrapper
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Initialize the vLLM wrapper with a model
|
||||||
|
# Note: You'll need to have the model downloaded or specify a HuggingFace model ID
|
||||||
|
llm = VLLMWrapper(
|
||||||
|
model_name="meta-llama/Llama-2-7b-chat-hf", # Replace with your model path or HF model ID
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example task
|
||||||
|
task = "What are the benefits of using vLLM for inference?"
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
response = llm.run(task)
|
||||||
|
print("Response:", response)
|
||||||
|
|
||||||
|
# Example with system prompt
|
||||||
|
llm_with_system = VLLMWrapper(
|
||||||
|
model_name="meta-llama/Llama-2-7b-chat-hf", # Replace with your model path or HF model ID
|
||||||
|
system_prompt="You are a helpful AI assistant that provides concise answers.",
|
||||||
|
temperature=0.7,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run inference with system prompt
|
||||||
|
response = llm_with_system.run(task)
|
||||||
|
print("\nResponse with system prompt:", response)
|
||||||
|
|
||||||
|
# Example with batched inference
|
||||||
|
tasks = [
|
||||||
|
"What is vLLM?",
|
||||||
|
"How does vLLM improve inference speed?",
|
||||||
|
"What are the main features of vLLM?"
|
||||||
|
]
|
||||||
|
|
||||||
|
responses = llm.batched_run(tasks, batch_size=2)
|
||||||
|
print("\nBatched responses:")
|
||||||
|
for task, response in zip(tasks, responses):
|
||||||
|
print(f"\nTask: {task}")
|
||||||
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,4 @@
|
|||||||
|
[pyupgrade]
|
||||||
|
py3-plus = True
|
||||||
|
py39-plus = True
|
||||||
|
keep-runtime-typing = True
|
@ -0,0 +1,55 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def run_command(command: list[str], cwd: Path) -> bool:
|
||||||
|
"""Run a command and return True if successful."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
command,
|
||||||
|
cwd=cwd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"Error running {' '.join(command)}:")
|
||||||
|
print(e.stdout)
|
||||||
|
print(e.stderr, file=sys.stderr)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run all code quality checks."""
|
||||||
|
root_dir = Path(__file__).parent.parent
|
||||||
|
success = True
|
||||||
|
|
||||||
|
# Run flake8
|
||||||
|
print("\nRunning flake8...")
|
||||||
|
if not run_command(["flake8", "swarms", "tests"], root_dir):
|
||||||
|
success = False
|
||||||
|
|
||||||
|
# Run pyupgrade
|
||||||
|
print("\nRunning pyupgrade...")
|
||||||
|
if not run_command(["pyupgrade", "--py39-plus", "swarms", "tests"], root_dir):
|
||||||
|
success = False
|
||||||
|
|
||||||
|
# Run black
|
||||||
|
print("\nRunning black...")
|
||||||
|
if not run_command(["black", "--check", "swarms", "tests"], root_dir):
|
||||||
|
success = False
|
||||||
|
|
||||||
|
# Run ruff
|
||||||
|
print("\nRunning ruff...")
|
||||||
|
if not run_command(["ruff", "check", "swarms", "tests"], root_dir):
|
||||||
|
success = False
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
print("\nCode quality checks failed. Please fix the issues and try again.")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
print("\nAll code quality checks passed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,32 @@
|
|||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
class ToolAgentError(Exception):
|
||||||
|
"""Base exception for all tool agent errors."""
|
||||||
|
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||||
|
self.message = message
|
||||||
|
self.details = details or {}
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
class ToolExecutionError(ToolAgentError):
|
||||||
|
"""Raised when a tool fails to execute."""
|
||||||
|
def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None):
|
||||||
|
message = f"Failed to execute tool '{tool_name}': {str(error)}"
|
||||||
|
super().__init__(message, details)
|
||||||
|
|
||||||
|
class ToolValidationError(ToolAgentError):
|
||||||
|
"""Raised when tool parameters fail validation."""
|
||||||
|
def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None):
|
||||||
|
message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
|
||||||
|
super().__init__(message, details)
|
||||||
|
|
||||||
|
class ToolNotFoundError(ToolAgentError):
|
||||||
|
"""Raised when a requested tool is not found."""
|
||||||
|
def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None):
|
||||||
|
message = f"Tool '{tool_name}' not found"
|
||||||
|
super().__init__(message, details)
|
||||||
|
|
||||||
|
class ToolParameterError(ToolAgentError):
|
||||||
|
"""Raised when tool parameters are invalid."""
|
||||||
|
def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None):
|
||||||
|
message = f"Invalid parameters for tool '{tool_name}': {error}"
|
||||||
|
super().__init__(message, details)
|
@ -1,156 +1,243 @@
|
|||||||
from typing import Any, Optional, Callable
|
from typing import List, Optional, Dict, Any, Callable
|
||||||
from swarms.tools.json_former import Jsonformer
|
from loguru import logger
|
||||||
from swarms.utils.loguru_logger import initialize_logger
|
from swarms.agents.exceptions import (
|
||||||
|
ToolAgentError,
|
||||||
logger = initialize_logger(log_folder="tool_agent")
|
ToolExecutionError,
|
||||||
|
ToolValidationError,
|
||||||
|
ToolNotFoundError,
|
||||||
|
ToolParameterError
|
||||||
|
)
|
||||||
|
|
||||||
class ToolAgent:
|
class ToolAgent:
|
||||||
"""
|
"""
|
||||||
Represents a tool agent that performs a specific task using a model and tokenizer.
|
A wrapper class for vLLM that provides a similar interface to LiteLLM.
|
||||||
|
This class handles model initialization and inference using vLLM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
temperature: float = 0.5,
|
||||||
|
max_tokens: int = 4000,
|
||||||
|
max_completion_tokens: int = 4000,
|
||||||
|
tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_choice: str = "auto",
|
||||||
|
parallel_tool_calls: bool = False,
|
||||||
|
retry_attempts: int = 3,
|
||||||
|
retry_interval: float = 1.0,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the vLLM wrapper with the given parameters.
|
||||||
Args:
|
Args:
|
||||||
name (str): The name of the tool agent.
|
model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
|
||||||
description (str): A description of the tool agent.
|
system_prompt (str, optional): The system prompt to use. Defaults to None.
|
||||||
model (Any): The model used by the tool agent.
|
stream (bool): Whether to stream the output. Defaults to False.
|
||||||
tokenizer (Any): The tokenizer used by the tool agent.
|
temperature (float): The temperature for sampling. Defaults to 0.5.
|
||||||
json_schema (Any): The JSON schema used by the tool agent.
|
max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
|
||||||
*args: Variable length arguments.
|
max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
|
||||||
**kwargs: Keyword arguments.
|
tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
|
||||||
|
tool_choice (str): How to choose tools. Defaults to "auto".
|
||||||
Attributes:
|
parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
|
||||||
name (str): The name of the tool agent.
|
retry_attempts (int): Number of retry attempts for failed operations. Defaults to 3.
|
||||||
description (str): A description of the tool agent.
|
retry_interval (float): Time to wait between retries in seconds. Defaults to 1.0.
|
||||||
model (Any): The model used by the tool agent.
|
"""
|
||||||
tokenizer (Any): The tokenizer used by the tool agent.
|
self.model_name = model_name
|
||||||
json_schema (Any): The JSON schema used by the tool agent.
|
self.system_prompt = system_prompt
|
||||||
|
self.stream = stream
|
||||||
Methods:
|
self.temperature = temperature
|
||||||
run: Runs the tool agent for a specific task.
|
self.max_tokens = max_tokens
|
||||||
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.tools_list_dictionary = tools_list_dictionary
|
||||||
|
self.tool_choice = tool_choice
|
||||||
|
self.parallel_tool_calls = parallel_tool_calls
|
||||||
|
self.retry_attempts = retry_attempts
|
||||||
|
self.retry_interval = retry_interval
|
||||||
|
|
||||||
|
# Initialize vLLM
|
||||||
|
try:
|
||||||
|
self.llm = LLM(model=model_name, **kwargs)
|
||||||
|
self.sampling_params = SamplingParams(
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolExecutionError(
|
||||||
|
"model_initialization",
|
||||||
|
e,
|
||||||
|
{"model_name": model_name, "kwargs": kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Validate tool parameters before execution.
|
||||||
|
Args:
|
||||||
|
tool_name (str): Name of the tool to validate
|
||||||
|
parameters (Dict[str, Any]): Parameters to validate
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If an error occurs while running the tool agent.
|
ToolValidationError: If validation fails
|
||||||
|
"""
|
||||||
|
if not self.tools_list_dictionary:
|
||||||
Example:
|
raise ToolValidationError(
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
tool_name,
|
||||||
from swarms import ToolAgent
|
"parameters",
|
||||||
|
"No tools available for validation"
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_spec = next(
|
||||||
|
(tool for tool in self.tools_list_dictionary if tool["name"] == tool_name),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-12b")
|
if not tool_spec:
|
||||||
tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-12b")
|
raise ToolNotFoundError(tool_name)
|
||||||
|
|
||||||
json_schema = {
|
required_params = {
|
||||||
"type": "object",
|
param["name"] for param in tool_spec["parameters"]
|
||||||
"properties": {
|
if param.get("required", True)
|
||||||
"name": {"type": "string"},
|
|
||||||
"age": {"type": "number"},
|
|
||||||
"is_student": {"type": "boolean"},
|
|
||||||
"courses": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {"type": "string"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
task = "Generate a person's information based on the following schema:"
|
missing_params = required_params - set(parameters.keys())
|
||||||
agent = ToolAgent(model=model, tokenizer=tokenizer, json_schema=json_schema)
|
if missing_params:
|
||||||
generated_data = agent.run(task)
|
raise ToolParameterError(
|
||||||
|
tool_name,
|
||||||
|
f"Missing required parameters: {', '.join(missing_params)}"
|
||||||
|
)
|
||||||
|
|
||||||
print(generated_data)
|
def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
|
Execute a function with retry logic.
|
||||||
|
Args:
|
||||||
|
func (Callable): Function to execute
|
||||||
|
*args: Positional arguments for the function
|
||||||
|
**kwargs: Keyword arguments for the function
|
||||||
|
Returns:
|
||||||
|
Any: Result of the function execution
|
||||||
|
Raises:
|
||||||
|
ToolExecutionError: If all retry attempts fail
|
||||||
|
"""
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(self.retry_attempts):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
logger.warning(
|
||||||
|
f"Attempt {attempt + 1}/{self.retry_attempts} failed: {str(e)}"
|
||||||
|
)
|
||||||
|
if attempt < self.retry_attempts - 1:
|
||||||
|
time.sleep(self.retry_interval)
|
||||||
|
|
||||||
def __init__(
|
raise ToolExecutionError(
|
||||||
self,
|
func.__name__,
|
||||||
name: str = "Function Calling Agent",
|
last_error,
|
||||||
description: str = "Generates a function based on the input json schema and the task",
|
{"attempts": self.retry_attempts}
|
||||||
model: Any = None,
|
|
||||||
tokenizer: Any = None,
|
|
||||||
json_schema: Any = None,
|
|
||||||
max_number_tokens: int = 500,
|
|
||||||
parsing_function: Optional[Callable] = None,
|
|
||||||
llm: Any = None,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
agent_name=name,
|
|
||||||
agent_description=description,
|
|
||||||
llm=llm,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
self.name = name
|
|
||||||
self.description = description
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.json_schema = json_schema
|
|
||||||
self.max_number_tokens = max_number_tokens
|
|
||||||
self.parsing_function = parsing_function
|
|
||||||
|
|
||||||
def run(self, task: str, *args, **kwargs):
|
def run(self, task: str, *args, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Run the tool agent for the specified task.
|
Run the tool agent for the specified task.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task (str): The task to be performed by the tool agent.
|
task (str): The task to be performed by the tool agent.
|
||||||
*args: Variable length argument list.
|
*args: Variable length argument list.
|
||||||
**kwargs: Arbitrary keyword arguments.
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The output of the tool agent.
|
The output of the tool agent.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If an error occurs during the execution of the tool agent.
|
ToolExecutionError: If an error occurs during execution.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if self.model:
|
if not self.llm:
|
||||||
logger.info(f"Running {self.name} for task: {task}")
|
raise ToolExecutionError(
|
||||||
self.toolagent = Jsonformer(
|
"run",
|
||||||
model=self.model,
|
Exception("LLM not initialized"),
|
||||||
tokenizer=self.tokenizer,
|
{"task": task}
|
||||||
json_schema=self.json_schema,
|
|
||||||
llm=self.llm,
|
|
||||||
prompt=task,
|
|
||||||
max_number_tokens=self.max_number_tokens,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.parsing_function:
|
logger.info(f"Running task: {task}")
|
||||||
out = self.parsing_function(self.toolagent())
|
|
||||||
else:
|
|
||||||
out = self.toolagent()
|
|
||||||
|
|
||||||
return out
|
|
||||||
elif self.llm:
|
|
||||||
logger.info(f"Running {self.name} for task: {task}")
|
|
||||||
self.toolagent = Jsonformer(
|
|
||||||
json_schema=self.json_schema,
|
|
||||||
llm=self.llm,
|
|
||||||
prompt=task,
|
|
||||||
max_number_tokens=self.max_number_tokens,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.parsing_function:
|
# Prepare the prompt
|
||||||
out = self.parsing_function(self.toolagent())
|
prompt = self._prepare_prompt(task)
|
||||||
else:
|
|
||||||
out = self.toolagent()
|
|
||||||
|
|
||||||
return out
|
# Execute with retry logic
|
||||||
|
outputs = self._execute_with_retry(
|
||||||
else:
|
self.llm.generate,
|
||||||
raise Exception(
|
prompt,
|
||||||
"Either model or llm should be provided to the"
|
self.sampling_params
|
||||||
" ToolAgent"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response = outputs[0].outputs[0].text.strip()
|
||||||
|
return response
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error(
|
logger.error(f"Error running task: {error}")
|
||||||
f"Error running {self.name} for task: {task}"
|
raise ToolExecutionError(
|
||||||
|
"run",
|
||||||
|
error,
|
||||||
|
{"task": task, "args": args, "kwargs": kwargs}
|
||||||
)
|
)
|
||||||
raise error
|
|
||||||
|
|
||||||
def __call__(self, task: str, *args, **kwargs):
|
def _prepare_prompt(self, task: str) -> str:
|
||||||
|
"""
|
||||||
|
Prepare the prompt for the given task.
|
||||||
|
Args:
|
||||||
|
task (str): The task to prepare the prompt for.
|
||||||
|
Returns:
|
||||||
|
str: The prepared prompt.
|
||||||
|
"""
|
||||||
|
if self.system_prompt:
|
||||||
|
return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
|
||||||
|
return f"User: {task}\nAssistant:"
|
||||||
|
|
||||||
|
def __call__(self, task: str, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Call the model for the given task.
|
||||||
|
Args:
|
||||||
|
task (str): The task to run the model for.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
Returns:
|
||||||
|
str: The model's response.
|
||||||
|
"""
|
||||||
return self.run(task, *args, **kwargs)
|
return self.run(task, *args, **kwargs)
|
||||||
|
|
||||||
|
def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
|
||||||
|
"""
|
||||||
|
Run the model for multiple tasks in batches.
|
||||||
|
Args:
|
||||||
|
tasks (List[str]): List of tasks to run.
|
||||||
|
batch_size (int): Size of each batch. Defaults to 10.
|
||||||
|
Returns:
|
||||||
|
List[str]: List of model responses.
|
||||||
|
Raises:
|
||||||
|
ToolExecutionError: If an error occurs during batch execution.
|
||||||
|
"""
|
||||||
|
logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
|
||||||
|
results = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for i in range(0, len(tasks), batch_size):
|
||||||
|
batch = tasks[i:i + batch_size]
|
||||||
|
for task in batch:
|
||||||
|
logger.info(f"Running task: {task}")
|
||||||
|
try:
|
||||||
|
result = self.run(task)
|
||||||
|
results.append(result)
|
||||||
|
except ToolExecutionError as e:
|
||||||
|
logger.error(f"Failed to execute task '{task}': {e}")
|
||||||
|
results.append(f"Error: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info("Completed all tasks.")
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as error:
|
||||||
|
logger.error(f"Error in batch execution: {error}")
|
||||||
|
raise ToolExecutionError(
|
||||||
|
"batched_run",
|
||||||
|
error,
|
||||||
|
{"tasks": tasks, "batch_size": batch_size}
|
||||||
|
)
|
||||||
|
@ -0,0 +1,168 @@
|
|||||||
|
from typing import Dict, List, Optional, Any, Union
|
||||||
|
from loguru import logger
|
||||||
|
from typedb.client import TypeDB, SessionType, TransactionType
|
||||||
|
from typedb.api.connection.transaction import Transaction
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TypeDBConfig:
|
||||||
|
"""Configuration for TypeDB connection."""
|
||||||
|
uri: str = "localhost:1729"
|
||||||
|
database: str = "swarms"
|
||||||
|
username: Optional[str] = None
|
||||||
|
password: Optional[str] = None
|
||||||
|
timeout: int = 30
|
||||||
|
|
||||||
|
class TypeDBWrapper:
|
||||||
|
"""
|
||||||
|
A wrapper class for TypeDB that provides graph database operations for Swarms.
|
||||||
|
This class handles connection, schema management, and data operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[TypeDBConfig] = None):
|
||||||
|
"""
|
||||||
|
Initialize the TypeDB wrapper with the given configuration.
|
||||||
|
Args:
|
||||||
|
config (Optional[TypeDBConfig]): Configuration for TypeDB connection.
|
||||||
|
"""
|
||||||
|
self.config = config or TypeDBConfig()
|
||||||
|
self.client = None
|
||||||
|
self.session = None
|
||||||
|
self._connect()
|
||||||
|
|
||||||
|
def _connect(self) -> None:
|
||||||
|
"""Establish connection to TypeDB."""
|
||||||
|
try:
|
||||||
|
self.client = TypeDB.core_client(self.config.uri)
|
||||||
|
if self.config.username and self.config.password:
|
||||||
|
self.session = self.client.session(
|
||||||
|
self.config.database,
|
||||||
|
SessionType.DATA,
|
||||||
|
self.config.username,
|
||||||
|
self.config.password
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.session = self.client.session(
|
||||||
|
self.config.database,
|
||||||
|
SessionType.DATA
|
||||||
|
)
|
||||||
|
logger.info(f"Connected to TypeDB at {self.config.uri}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to TypeDB: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _ensure_connection(self) -> None:
|
||||||
|
"""Ensure connection is active, reconnect if necessary."""
|
||||||
|
if not self.session or not self.session.is_open():
|
||||||
|
self._connect()
|
||||||
|
|
||||||
|
def define_schema(self, schema: str) -> None:
|
||||||
|
"""
|
||||||
|
Define the database schema.
|
||||||
|
Args:
|
||||||
|
schema (str): TypeQL schema definition.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.session.transaction(TransactionType.WRITE) as transaction:
|
||||||
|
transaction.query.define(schema)
|
||||||
|
transaction.commit()
|
||||||
|
logger.info("Schema defined successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to define schema: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def insert_data(self, query: str) -> None:
|
||||||
|
"""
|
||||||
|
Insert data using TypeQL query.
|
||||||
|
Args:
|
||||||
|
query (str): TypeQL insert query.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.session.transaction(TransactionType.WRITE) as transaction:
|
||||||
|
transaction.query.insert(query)
|
||||||
|
transaction.commit()
|
||||||
|
logger.info("Data inserted successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to insert data: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def query_data(self, query: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Query data using TypeQL query.
|
||||||
|
Args:
|
||||||
|
query (str): TypeQL query.
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: Query results.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.session.transaction(TransactionType.READ) as transaction:
|
||||||
|
result = transaction.query.get(query)
|
||||||
|
return [self._convert_concept_to_dict(concept) for concept in result]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to query data: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _convert_concept_to_dict(self, concept: Any) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert a TypeDB concept to a dictionary.
|
||||||
|
Args:
|
||||||
|
concept: TypeDB concept.
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Dictionary representation of the concept.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if hasattr(concept, "get_type"):
|
||||||
|
concept_type = concept.get_type()
|
||||||
|
if hasattr(concept, "get_value"):
|
||||||
|
return {
|
||||||
|
"type": concept_type.get_label_name(),
|
||||||
|
"value": concept.get_value()
|
||||||
|
}
|
||||||
|
elif hasattr(concept, "get_attributes"):
|
||||||
|
return {
|
||||||
|
"type": concept_type.get_label_name(),
|
||||||
|
"attributes": {
|
||||||
|
attr.get_type().get_label_name(): attr.get_value()
|
||||||
|
for attr in concept.get_attributes()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {"type": "unknown", "value": str(concept)}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to convert concept to dict: {e}")
|
||||||
|
return {"type": "error", "value": str(e)}
|
||||||
|
|
||||||
|
def delete_data(self, query: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete data using TypeQL query.
|
||||||
|
Args:
|
||||||
|
query (str): TypeQL delete query.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self.session.transaction(TransactionType.WRITE) as transaction:
|
||||||
|
transaction.query.delete(query)
|
||||||
|
transaction.commit()
|
||||||
|
logger.info("Data deleted successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete data: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the TypeDB connection."""
|
||||||
|
try:
|
||||||
|
if self.session:
|
||||||
|
self.session.close()
|
||||||
|
if self.client:
|
||||||
|
self.client.close()
|
||||||
|
logger.info("TypeDB connection closed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to close TypeDB connection: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Context manager entry."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Context manager exit."""
|
||||||
|
self.close()
|
@ -0,0 +1,138 @@
|
|||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
except ImportError:
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
print("Installing vllm")
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"])
|
||||||
|
print("vllm installed")
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
class VLLMWrapper:
|
||||||
|
"""
|
||||||
|
A wrapper class for vLLM that provides a similar interface to LiteLLM.
|
||||||
|
This class handles model initialization and inference using vLLM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
temperature: float = 0.5,
|
||||||
|
max_tokens: int = 4000,
|
||||||
|
max_completion_tokens: int = 4000,
|
||||||
|
tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
tool_choice: str = "auto",
|
||||||
|
parallel_tool_calls: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the vLLM wrapper with the given parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
|
||||||
|
system_prompt (str, optional): The system prompt to use. Defaults to None.
|
||||||
|
stream (bool): Whether to stream the output. Defaults to False.
|
||||||
|
temperature (float): The temperature for sampling. Defaults to 0.5.
|
||||||
|
max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
|
||||||
|
max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
|
||||||
|
tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
|
||||||
|
tool_choice (str): How to choose tools. Defaults to "auto".
|
||||||
|
parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.stream = stream
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.tools_list_dictionary = tools_list_dictionary
|
||||||
|
self.tool_choice = tool_choice
|
||||||
|
self.parallel_tool_calls = parallel_tool_calls
|
||||||
|
|
||||||
|
# Initialize vLLM
|
||||||
|
self.llm = LLM(model=model_name, **kwargs)
|
||||||
|
self.sampling_params = SamplingParams(
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_prompt(self, task: str) -> str:
|
||||||
|
"""
|
||||||
|
Prepare the prompt for the given task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The task to prepare the prompt for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The prepared prompt.
|
||||||
|
"""
|
||||||
|
if self.system_prompt:
|
||||||
|
return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
|
||||||
|
return f"User: {task}\nAssistant:"
|
||||||
|
|
||||||
|
def run(self, task: str, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Run the model for the given task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The task to run the model for.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The model's response.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
prompt = self._prepare_prompt(task)
|
||||||
|
|
||||||
|
outputs = self.llm.generate(prompt, self.sampling_params)
|
||||||
|
response = outputs[0].outputs[0].text.strip()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as error:
|
||||||
|
logger.error(f"Error in VLLMWrapper: {error}")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def __call__(self, task: str, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Call the model for the given task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The task to run the model for.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The model's response.
|
||||||
|
"""
|
||||||
|
return self.run(task, *args, **kwargs)
|
||||||
|
|
||||||
|
def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
|
||||||
|
"""
|
||||||
|
Run the model for multiple tasks in batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks (List[str]): List of tasks to run.
|
||||||
|
batch_size (int): Size of each batch. Defaults to 10.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: List of model responses.
|
||||||
|
"""
|
||||||
|
logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i in range(0, len(tasks), batch_size):
|
||||||
|
batch = tasks[i:i + batch_size]
|
||||||
|
for task in batch:
|
||||||
|
logger.info(f"Running task: {task}")
|
||||||
|
results.append(self.run(task))
|
||||||
|
|
||||||
|
logger.info("Completed all tasks.")
|
||||||
|
return results
|
@ -0,0 +1,129 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_typedb():
|
||||||
|
"""Mock TypeDB client and session."""
|
||||||
|
with patch('swarms.utils.typedb_wrapper.TypeDB') as mock_typedb:
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_session = Mock()
|
||||||
|
mock_typedb.core_client.return_value = mock_client
|
||||||
|
mock_client.session.return_value = mock_session
|
||||||
|
yield mock_typedb, mock_client, mock_session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def typedb_wrapper(mock_typedb):
|
||||||
|
"""Create a TypeDBWrapper instance with mocked dependencies."""
|
||||||
|
config = TypeDBConfig(
|
||||||
|
uri="test:1729",
|
||||||
|
database="test_db",
|
||||||
|
username="test_user",
|
||||||
|
password="test_pass"
|
||||||
|
)
|
||||||
|
return TypeDBWrapper(config)
|
||||||
|
|
||||||
|
def test_initialization(typedb_wrapper):
|
||||||
|
"""Test TypeDBWrapper initialization."""
|
||||||
|
assert typedb_wrapper.config.uri == "test:1729"
|
||||||
|
assert typedb_wrapper.config.database == "test_db"
|
||||||
|
assert typedb_wrapper.config.username == "test_user"
|
||||||
|
assert typedb_wrapper.config.password == "test_pass"
|
||||||
|
|
||||||
|
def test_connect(typedb_wrapper, mock_typedb):
|
||||||
|
"""Test connection to TypeDB."""
|
||||||
|
mock_typedb, mock_client, mock_session = mock_typedb
|
||||||
|
typedb_wrapper._connect()
|
||||||
|
|
||||||
|
mock_typedb.core_client.assert_called_once_with("test:1729")
|
||||||
|
mock_client.session.assert_called_once_with(
|
||||||
|
"test_db",
|
||||||
|
"DATA",
|
||||||
|
"test_user",
|
||||||
|
"test_pass"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_define_schema(typedb_wrapper, mock_typedb):
|
||||||
|
"""Test schema definition."""
|
||||||
|
mock_typedb, mock_client, mock_session = mock_typedb
|
||||||
|
schema = "define person sub entity;"
|
||||||
|
|
||||||
|
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.define.return_value = None
|
||||||
|
typedb_wrapper.define_schema(schema)
|
||||||
|
|
||||||
|
mock_transaction.assert_called_once_with("WRITE")
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.define.assert_called_once_with(schema)
|
||||||
|
|
||||||
|
def test_insert_data(typedb_wrapper, mock_typedb):
|
||||||
|
"""Test data insertion."""
|
||||||
|
mock_typedb, mock_client, mock_session = mock_typedb
|
||||||
|
query = "insert $p isa person;"
|
||||||
|
|
||||||
|
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.insert.return_value = None
|
||||||
|
typedb_wrapper.insert_data(query)
|
||||||
|
|
||||||
|
mock_transaction.assert_called_once_with("WRITE")
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.insert.assert_called_once_with(query)
|
||||||
|
|
||||||
|
def test_query_data(typedb_wrapper, mock_typedb):
|
||||||
|
"""Test data querying."""
|
||||||
|
mock_typedb, mock_client, mock_session = mock_typedb
|
||||||
|
query = "match $p isa person; get;"
|
||||||
|
mock_result = [Mock()]
|
||||||
|
|
||||||
|
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.get.return_value = mock_result
|
||||||
|
result = typedb_wrapper.query_data(query)
|
||||||
|
|
||||||
|
mock_transaction.assert_called_once_with("READ")
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.get.assert_called_once_with(query)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_delete_data(typedb_wrapper, mock_typedb):
|
||||||
|
"""Test data deletion."""
|
||||||
|
mock_typedb, mock_client, mock_session = mock_typedb
|
||||||
|
query = "match $p isa person; delete $p;"
|
||||||
|
|
||||||
|
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.delete.return_value = None
|
||||||
|
typedb_wrapper.delete_data(query)
|
||||||
|
|
||||||
|
mock_transaction.assert_called_once_with("WRITE")
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.delete.assert_called_once_with(query)
|
||||||
|
|
||||||
|
def test_close(typedb_wrapper, mock_typedb):
|
||||||
|
"""Test connection closing."""
|
||||||
|
mock_typedb, mock_client, mock_session = mock_typedb
|
||||||
|
typedb_wrapper.close()
|
||||||
|
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
mock_client.close.assert_called_once()
|
||||||
|
|
||||||
|
def test_context_manager(typedb_wrapper, mock_typedb):
|
||||||
|
"""Test context manager functionality."""
|
||||||
|
mock_typedb, mock_client, mock_session = mock_typedb
|
||||||
|
|
||||||
|
with typedb_wrapper as db:
|
||||||
|
assert db == typedb_wrapper
|
||||||
|
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
mock_client.close.assert_called_once()
|
||||||
|
|
||||||
|
def test_error_handling(typedb_wrapper, mock_typedb):
|
||||||
|
"""Test error handling."""
|
||||||
|
mock_typedb, mock_client, mock_session = mock_typedb
|
||||||
|
|
||||||
|
# Test connection error
|
||||||
|
mock_typedb.core_client.side_effect = Exception("Connection failed")
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
typedb_wrapper._connect()
|
||||||
|
assert "Connection failed" in str(exc_info.value)
|
||||||
|
|
||||||
|
# Test query error
|
||||||
|
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||||
|
mock_transaction.return_value.__enter__.return_value.query.get.side_effect = Exception("Query failed")
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
typedb_wrapper.query_data("test query")
|
||||||
|
assert "Query failed" in str(exc_info.value)
|
Loading…
Reference in new issue