pull/802/merge
Pavan Kumar authored 3 weeks ago · committed by GitHub
commit f29a1e2180

@ -0,0 +1,22 @@
[flake8]
max-line-length = 88
extend-ignore = E203, W503
exclude =
    .git,
    __pycache__,
    build,
    dist,
    *.egg-info,
    .eggs,
    .tox,
    .venv,
    venv,
    .env,
    .pytest_cache,
    .coverage,
    htmlcov,
    .mypy_cache,
    .ruff_cache
per-file-ignores =
    __init__.py: F401
max-complexity = 10
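For context, E203 (whitespace before ':') and W503 (line break before a binary operator) conflict with Black's formatting style, which is why they are ignored alongside max-line-length = 88. A small illustrative snippet (names invented for the example):

def window_total(values, lower, upper, offset):
    # Black-style slice spacing that E203 would otherwise flag
    window = values[lower + offset : upper + offset]
    return (
        sum(window)
        + len(window)  # operator at line start, which W503 would otherwise flag
    )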

@ -0,0 +1,44 @@
from swarms.utils.vllm_wrapper import VLLMWrapper


def main():
    # Initialize the vLLM wrapper with a model
    # Note: You'll need to have the model downloaded or specify a HuggingFace model ID
    llm = VLLMWrapper(
        model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
        temperature=0.7,
        max_tokens=1000,
    )

    # Example task
    task = "What are the benefits of using vLLM for inference?"

    # Run inference
    response = llm.run(task)
    print("Response:", response)

    # Example with system prompt
    llm_with_system = VLLMWrapper(
        model_name="meta-llama/Llama-2-7b-chat-hf",  # Replace with your model path or HF model ID
        system_prompt="You are a helpful AI assistant that provides concise answers.",
        temperature=0.7,
    )

    # Run inference with system prompt
    response = llm_with_system.run(task)
    print("\nResponse with system prompt:", response)

    # Example with batched inference
    tasks = [
        "What is vLLM?",
        "How does vLLM improve inference speed?",
        "What are the main features of vLLM?",
    ]
    responses = llm.batched_run(tasks, batch_size=2)
    print("\nBatched responses:")
    for task, response in zip(tasks, responses):
        print(f"\nTask: {task}")
        print(f"Response: {response}")


if __name__ == "__main__":
    main()

@ -0,0 +1,4 @@
[pyupgrade]
py3-plus = True
py39-plus = True
keep-runtime-typing = True
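As a rough illustration of what these settings correspond to (the function below is made up; pyupgrade itself is driven by the --py39-plus flag in the quality script further down):

# Before: typing generics that pyupgrade rewrites under --py39-plus
from typing import Dict, List

def count_words(lines: List[str]) -> Dict[str, int]:
    ...

# After: PEP 585 built-in generics
def count_words(lines: list[str]) -> dict[str, int]:
    ...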

@ -1,4 +1,3 @@
torch>=2.1.1,<3.0
transformers>=4.39.0,<4.51.0
asyncio>=3.4.3,<4.0
@ -23,3 +22,10 @@ pytest>=8.1.1
networkx
aiofiles
httpx
vllm>=0.2.0
flake8>=6.1.0
flake8-bugbear>=23.3.12
flake8-comprehensions>=3.12.0
flake8-simplify>=0.19.3
flake8-unused-arguments>=0.0.4
pyupgrade>=3.15.0

@ -0,0 +1,55 @@
#!/usr/bin/env python3
import subprocess
import sys
from pathlib import Path


def run_command(command: list[str], cwd: Path) -> bool:
    """Run a command and return True if successful."""
    try:
        subprocess.run(
            command,
            cwd=cwd,
            capture_output=True,
            text=True,
            check=True,
        )
        return True
    except subprocess.CalledProcessError as e:
        print(f"Error running {' '.join(command)}:")
        print(e.stdout)
        print(e.stderr, file=sys.stderr)
        return False


def main():
    """Run all code quality checks."""
    root_dir = Path(__file__).parent.parent
    success = True

    # Run flake8
    print("\nRunning flake8...")
    if not run_command(["flake8", "swarms", "tests"], root_dir):
        success = False

    # Run pyupgrade (it operates on files, not directories)
    print("\nRunning pyupgrade...")
    py_files = [
        str(path)
        for target in ("swarms", "tests")
        for path in (root_dir / target).rglob("*.py")
    ]
    if py_files and not run_command(
        ["pyupgrade", "--py39-plus", *py_files], root_dir
    ):
        success = False

    # Run black
    print("\nRunning black...")
    if not run_command(["black", "--check", "swarms", "tests"], root_dir):
        success = False

    # Run ruff
    print("\nRunning ruff...")
    if not run_command(["ruff", "check", "swarms", "tests"], root_dir):
        success = False

    if not success:
        print("\nCode quality checks failed. Please fix the issues and try again.")
        sys.exit(1)
    else:
        print("\nAll code quality checks passed!")


if __name__ == "__main__":
    main()

@ -0,0 +1,32 @@
from typing import Any, Dict, Optional


class ToolAgentError(Exception):
    """Base exception for all tool agent errors."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)


class ToolExecutionError(ToolAgentError):
    """Raised when a tool fails to execute."""

    def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None):
        message = f"Failed to execute tool '{tool_name}': {str(error)}"
        super().__init__(message, details)


class ToolValidationError(ToolAgentError):
    """Raised when tool parameters fail validation."""

    def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None):
        message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
        super().__init__(message, details)


class ToolNotFoundError(ToolAgentError):
    """Raised when a requested tool is not found."""

    def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None):
        message = f"Tool '{tool_name}' not found"
        super().__init__(message, details)


class ToolParameterError(ToolAgentError):
    """Raised when tool parameters are invalid."""

    def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None):
        message = f"Invalid parameters for tool '{tool_name}': {error}"
        super().__init__(message, details)
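A minimal usage sketch for these exceptions (the tool name, error, and details below are hypothetical, chosen only to show how the message and the details dictionary travel with the exception):

from swarms.agents.exceptions import ToolExecutionError

try:
    # Hypothetical failure while calling a tool named "web_search"
    raise ToolExecutionError(
        "web_search",
        TimeoutError("upstream timed out"),
        details={"query": "vLLM benchmarks", "timeout_s": 30},
    )
except ToolExecutionError as err:
    print(err.message)  # Failed to execute tool 'web_search': upstream timed out
    print(err.details)  # {'query': 'vLLM benchmarks', 'timeout_s': 30}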

@ -1,156 +1,243 @@
import time
from typing import List, Optional, Dict, Any, Callable

from loguru import logger
# vLLM classes are used during initialization below; this import mirrors
# swarms.utils.vllm_wrapper and is assumed to be present in the new module.
from vllm import LLM, SamplingParams

from swarms.agents.exceptions import (
    ToolAgentError,
    ToolExecutionError,
    ToolValidationError,
    ToolNotFoundError,
    ToolParameterError,
)


class ToolAgent:
    """
    A wrapper class for vLLM that provides a similar interface to LiteLLM.
    This class handles model initialization and inference using vLLM.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        system_prompt: Optional[str] = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        retry_attempts: int = 3,
        retry_interval: float = 1.0,
        *args,
        **kwargs,
    ):
        """
        Initialize the vLLM wrapper with the given parameters.

        Args:
            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool): Whether to stream the output. Defaults to False.
            temperature (float): The temperature for sampling. Defaults to 0.5.
            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
            tool_choice (str): How to choose tools. Defaults to "auto".
            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
            retry_attempts (int): Number of retry attempts for failed operations. Defaults to 3.
            retry_interval (float): Time to wait between retries in seconds. Defaults to 1.0.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls
        self.retry_attempts = retry_attempts
        self.retry_interval = retry_interval

        # Initialize vLLM
        try:
            self.llm = LLM(model=model_name, **kwargs)
            self.sampling_params = SamplingParams(
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            raise ToolExecutionError(
                "model_initialization",
                e,
                {"model_name": model_name, "kwargs": kwargs},
            )

    def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None:
        """
        Validate tool parameters before execution.

        Args:
            tool_name (str): Name of the tool to validate
            parameters (Dict[str, Any]): Parameters to validate

        Raises:
            ToolValidationError: If validation fails
        """
        if not self.tools_list_dictionary:
            raise ToolValidationError(
                tool_name,
                "parameters",
                "No tools available for validation",
            )

        tool_spec = next(
            (tool for tool in self.tools_list_dictionary if tool["name"] == tool_name),
            None,
        )

        if not tool_spec:
            raise ToolNotFoundError(tool_name)

        required_params = {
            param["name"] for param in tool_spec["parameters"]
            if param.get("required", True)
        }

        missing_params = required_params - set(parameters.keys())
        if missing_params:
            raise ToolParameterError(
                tool_name,
                f"Missing required parameters: {', '.join(missing_params)}",
            )

    def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute a function with retry logic.

        Args:
            func (Callable): Function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Any: Result of the function execution

        Raises:
            ToolExecutionError: If all retry attempts fail
        """
        last_error = None
        for attempt in range(self.retry_attempts):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(
                    f"Attempt {attempt + 1}/{self.retry_attempts} failed: {str(e)}"
                )
                if attempt < self.retry_attempts - 1:
                    time.sleep(self.retry_interval)

        raise ToolExecutionError(
            func.__name__,
            last_error,
            {"attempts": self.retry_attempts},
        )

    def run(self, task: str, *args, **kwargs) -> str:
        """
        Run the tool agent for the specified task.

        Args:
            task (str): The task to be performed by the tool agent.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            The output of the tool agent.

        Raises:
            ToolExecutionError: If an error occurs during execution.
        """
        try:
            if not self.llm:
                raise ToolExecutionError(
                    "run",
                    Exception("LLM not initialized"),
                    {"task": task},
                )

            logger.info(f"Running task: {task}")

            # Prepare the prompt
            prompt = self._prepare_prompt(task)

            # Execute with retry logic
            outputs = self._execute_with_retry(
                self.llm.generate,
                prompt,
                self.sampling_params,
            )
            response = outputs[0].outputs[0].text.strip()
            return response
        except Exception as error:
            logger.error(f"Error running task: {error}")
            raise ToolExecutionError(
                "run",
                error,
                {"task": task, "args": args, "kwargs": kwargs},
            )

    def _prepare_prompt(self, task: str) -> str:
        """
        Prepare the prompt for the given task.

        Args:
            task (str): The task to prepare the prompt for.

        Returns:
            str: The prepared prompt.
        """
        if self.system_prompt:
            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
        return f"User: {task}\nAssistant:"

    def __call__(self, task: str, *args, **kwargs) -> str:
        """
        Call the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        return self.run(task, *args, **kwargs)

    def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
        """
        Run the model for multiple tasks in batches.

        Args:
            tasks (List[str]): List of tasks to run.
            batch_size (int): Size of each batch. Defaults to 10.

        Returns:
            List[str]: List of model responses.

        Raises:
            ToolExecutionError: If an error occurs during batch execution.
        """
        logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
        results = []
        try:
            for i in range(0, len(tasks), batch_size):
                batch = tasks[i:i + batch_size]
                for task in batch:
                    logger.info(f"Running task: {task}")
                    try:
                        result = self.run(task)
                        results.append(result)
                    except ToolExecutionError as e:
                        logger.error(f"Failed to execute task '{task}': {e}")
                        results.append(f"Error: {str(e)}")
                        continue
            logger.info("Completed all tasks.")
            return results
        except Exception as error:
            logger.error(f"Error in batch execution: {error}")
            raise ToolExecutionError(
                "batched_run",
                error,
                {"tasks": tasks, "batch_size": batch_size},
            )
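A short usage sketch for the rewritten ToolAgent (illustrative only, not part of the diff: the tool specification mirrors the shape _validate_tool expects, and the model ID is a placeholder):

from swarms import ToolAgent

tools = [
    {
        "name": "get_weather",
        "parameters": [
            {"name": "city", "required": True},
            {"name": "units", "required": False},
        ],
    }
]

agent = ToolAgent(
    model_name="meta-llama/Llama-2-7b-chat-hf",  # placeholder model ID
    system_prompt="You are a concise assistant.",
    tools_list_dictionary=tools,
    retry_attempts=2,
)

# Validation is independent of generation; missing required parameters raise ToolParameterError
agent._validate_tool("get_weather", {"city": "Berlin"})

print(agent.run("Summarize what vLLM does in one sentence."))
print(agent.batched_run(["What is vLLM?", "Name one vLLM feature."], batch_size=1))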

@ -0,0 +1,138 @@
from typing import List, Optional, Dict, Any

from loguru import logger

try:
    from vllm import LLM, SamplingParams
except ImportError:
    import subprocess
    import sys

    print("Installing vllm")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "vllm"])
    print("vllm installed")
    from vllm import LLM, SamplingParams


class VLLMWrapper:
    """
    A wrapper class for vLLM that provides a similar interface to LiteLLM.
    This class handles model initialization and inference using vLLM.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        system_prompt: Optional[str] = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        *args,
        **kwargs,
    ):
        """
        Initialize the vLLM wrapper with the given parameters.

        Args:
            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool): Whether to stream the output. Defaults to False.
            temperature (float): The temperature for sampling. Defaults to 0.5.
            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
            tool_choice (str): How to choose tools. Defaults to "auto".
            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls

        # Initialize vLLM
        self.llm = LLM(model=model_name, **kwargs)
        self.sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def _prepare_prompt(self, task: str) -> str:
        """
        Prepare the prompt for the given task.

        Args:
            task (str): The task to prepare the prompt for.

        Returns:
            str: The prepared prompt.
        """
        if self.system_prompt:
            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
        return f"User: {task}\nAssistant:"

    def run(self, task: str, *args, **kwargs) -> str:
        """
        Run the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        try:
            prompt = self._prepare_prompt(task)
            outputs = self.llm.generate(prompt, self.sampling_params)
            response = outputs[0].outputs[0].text.strip()
            return response
        except Exception as error:
            logger.error(f"Error in VLLMWrapper: {error}")
            raise error

    def __call__(self, task: str, *args, **kwargs) -> str:
        """
        Call the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        return self.run(task, *args, **kwargs)

    def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
        """
        Run the model for multiple tasks in batches.

        Args:
            tasks (List[str]): List of tasks to run.
            batch_size (int): Size of each batch. Defaults to 10.

        Returns:
            List[str]: List of model responses.
        """
        logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
        results = []
        for i in range(0, len(tasks), batch_size):
            batch = tasks[i:i + batch_size]
            for task in batch:
                logger.info(f"Running task: {task}")
                results.append(self.run(task))
        logger.info("Completed all tasks.")
        return results

@ -1,8 +1,15 @@
from unittest.mock import Mock, patch

import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from swarms import ToolAgent
from swarms.agents.exceptions import (
    ToolExecutionError,
    ToolValidationError,
    ToolNotFoundError,
    ToolParameterError
)


def test_tool_agent_init():
@ -99,3 +106,123 @@ def test_tool_agent_init_with_kwargs():
        agent.max_string_token_length
        == kwargs["max_string_token_length"]
    )

def test_tool_agent_initialization():
    """Test tool agent initialization with valid parameters."""
    agent = ToolAgent(
        model_name="test-model",
        temperature=0.7,
        max_tokens=1000
    )
    assert agent.model_name == "test-model"
    assert agent.temperature == 0.7
    assert agent.max_tokens == 1000
    assert agent.retry_attempts == 3
    assert agent.retry_interval == 1.0


def test_tool_agent_initialization_error():
    """Test tool agent initialization with invalid model."""
    with pytest.raises(ToolExecutionError) as exc_info:
        ToolAgent(model_name="invalid-model")
    assert "model_initialization" in str(exc_info.value)


def test_tool_validation():
    """Test tool parameter validation."""
    tools_list = [{
        "name": "test_tool",
        "parameters": [
            {"name": "required_param", "required": True},
            {"name": "optional_param", "required": False}
        ]
    }]
    agent = ToolAgent(tools_list_dictionary=tools_list)

    # Test missing required parameter
    with pytest.raises(ToolParameterError) as exc_info:
        agent._validate_tool("test_tool", {})
    assert "Missing required parameters" in str(exc_info.value)

    # Test valid parameters
    agent._validate_tool("test_tool", {"required_param": "value"})

    # Test non-existent tool
    with pytest.raises(ToolNotFoundError) as exc_info:
        agent._validate_tool("non_existent_tool", {})
    assert "Tool 'non_existent_tool' not found" in str(exc_info.value)


def test_retry_mechanism():
    """Test retry mechanism for failed operations."""
    mock_llm = Mock()
    # generate() returns a list of request outputs, so the successful
    # attempt is wrapped in a list to match run()'s outputs[0].outputs[0].text access.
    mock_llm.generate.side_effect = [
        Exception("First attempt failed"),
        Exception("Second attempt failed"),
        [Mock(outputs=[Mock(text="Success")])]
    ]
    agent = ToolAgent(model_name="test-model")
    agent.llm = mock_llm

    # Test successful retry
    result = agent.run("test task")
    assert result == "Success"
    assert mock_llm.generate.call_count == 3

    # Test all retries failing
    mock_llm.generate.side_effect = Exception("All attempts failed")
    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task")
    assert "All attempts failed" in str(exc_info.value)


def test_batched_execution():
    """Test batched execution with error handling."""
    mock_llm = Mock()
    mock_llm.generate.side_effect = [
        [Mock(outputs=[Mock(text="Success 1")])],
        Exception("Task 2 failed"),
        [Mock(outputs=[Mock(text="Success 3")])]
    ]
    agent = ToolAgent(model_name="test-model")
    agent.llm = mock_llm

    tasks = ["Task 1", "Task 2", "Task 3"]
    results = agent.batched_run(tasks)

    assert len(results) == 3
    assert results[0] == "Success 1"
    assert "Error" in results[1]
    assert results[2] == "Success 3"


def test_prompt_preparation():
    """Test prompt preparation with and without system prompt."""
    # Test without system prompt
    agent = ToolAgent()
    prompt = agent._prepare_prompt("test task")
    assert prompt == "User: test task\nAssistant:"

    # Test with system prompt
    agent = ToolAgent(system_prompt="You are a helpful assistant")
    prompt = agent._prepare_prompt("test task")
    assert prompt == "You are a helpful assistant\n\nUser: test task\nAssistant:"


def test_tool_execution_error_handling():
    """Test error handling during tool execution."""
    agent = ToolAgent(model_name="test-model")
    agent.llm = None  # Simulate uninitialized LLM

    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task")
    assert "LLM not initialized" in str(exc_info.value)

    # Test with invalid parameters
    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task", invalid_param="value")
    assert "Error running task" in str(exc_info.value)
