fix examples and code formatting

pull/1018/head
Kye Gomez 3 weeks ago
parent 63a998bb1f
commit 5be0ab609e

@ -5,15 +5,18 @@ from swarms.utils.litellm_tokenizer import count_tokens
# Load environment variables from .env file
load_dotenv()
def demonstrate_truncation():
# Using a smaller context length to clearly see the truncation effect
context_length = 25
print(f"Creating a conversation instance with context length {context_length}")
print(
f"Creating a conversation instance with context length {context_length}"
)
# Using Claude model as the tokenizer model
conversation = Conversation(
context_length=context_length,
tokenizer_model_name="claude-3-7-sonnet-20250219"
tokenizer_model_name="claude-3-7-sonnet-20250219",
)
# Adding first message - short message
@ -23,7 +26,9 @@ def demonstrate_truncation():
# Display token count
tokens = count_tokens(short_message, conversation.tokenizer_model_name)
tokens = count_tokens(
short_message, conversation.tokenizer_model_name
)
print(f"Short message token count: {tokens}")
# Adding second message - long message, should be truncated
@ -32,19 +37,27 @@ def demonstrate_truncation():
conversation.add("assistant", long_message)
# Display long message token count
tokens = count_tokens(long_message, conversation.tokenizer_model_name)
tokens = count_tokens(
long_message, conversation.tokenizer_model_name
)
print(f"Long message token count: {tokens}")
# Display current conversation total token count
total_tokens = sum(count_tokens(msg["content"], conversation.tokenizer_model_name)
for msg in conversation.conversation_history)
total_tokens = sum(
count_tokens(
msg["content"], conversation.tokenizer_model_name
)
for msg in conversation.conversation_history
)
print(f"Total token count before truncation: {total_tokens}")
# Print the complete conversation history before truncation
print("\nConversation history before truncation:")
for i, msg in enumerate(conversation.conversation_history):
print(f"[{i}] {msg['role']}: {msg['content']}")
print(f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}")
print(
f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}"
)
# Execute truncation
print("\nExecuting truncation...")
@ -54,31 +67,50 @@ def demonstrate_truncation():
print("\nConversation history after truncation:")
for i, msg in enumerate(conversation.conversation_history):
print(f"[{i}] {msg['role']}: {msg['content']}")
print(f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}")
print(
f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}"
)
# Display total token count after truncation
total_tokens = sum(count_tokens(msg["content"], conversation.tokenizer_model_name)
for msg in conversation.conversation_history)
total_tokens = sum(
count_tokens(
msg["content"], conversation.tokenizer_model_name
)
for msg in conversation.conversation_history
)
print(f"\nTotal token count after truncation: {total_tokens}")
print(f"Context length limit: {context_length}")
# Verify if successfully truncated below the limit
if total_tokens <= context_length:
print("✅ Success: Total token count is now less than or equal to context length limit")
print(
"✅ Success: Total token count is now less than or equal to context length limit"
)
else:
print("❌ Failure: Total token count still exceeds context length limit")
print(
"❌ Failure: Total token count still exceeds context length limit"
)
# Test sentence boundary truncation
print("\n\nTesting sentence boundary truncation:")
sentence_test = Conversation(context_length=15, tokenizer_model_name="claude-3-opus-20240229")
sentence_test = Conversation(
context_length=15,
tokenizer_model_name="claude-3-opus-20240229",
)
test_text = "This is the first sentence. This is the second very long sentence that contains a lot of content. This is the third sentence."
print(f"Original text: '{test_text}'")
print(f"Original token count: {count_tokens(test_text, sentence_test.tokenizer_model_name)}")
print(
f"Original token count: {count_tokens(test_text, sentence_test.tokenizer_model_name)}"
)
# Using binary search for truncation
truncated = sentence_test._binary_search_truncate(test_text, 10, sentence_test.tokenizer_model_name)
truncated = sentence_test._binary_search_truncate(
test_text, 10, sentence_test.tokenizer_model_name
)
print(f"Truncated text: '{truncated}'")
print(f"Truncated token count: {count_tokens(truncated, sentence_test.tokenizer_model_name)}")
print(
f"Truncated token count: {count_tokens(truncated, sentence_test.tokenizer_model_name)}"
)
# Check if truncated at period
if truncated.endswith("."):

@ -9,10 +9,10 @@ def main():
on the specified bill.
"""
senator_simulation = SenatorAssembly()
# senator_simulation.simulate_vote_concurrent(
# "A bill proposing to deregulate the IPO (Initial Public Offering) market in the United States as extensively as possible. The bill seeks to remove or significantly reduce existing regulatory requirements and oversight for companies seeking to go public, with the aim of increasing market efficiency and access to capital. Senators must consider the potential economic, legal, and ethical consequences of such broad deregulation, and cast their votes accordingly.",
# batch_size=10,
# )
senator_simulation.simulate_vote_concurrent(
"A bill proposing to deregulate the IPO (Initial Public Offering) market in the United States as extensively as possible. The bill seeks to remove or significantly reduce existing regulatory requirements and oversight for companies seeking to go public, with the aim of increasing market efficiency and access to capital. Senators must consider the potential economic, legal, and ethical consequences of such broad deregulation, and cast their votes accordingly.",
batch_size=10,
)
if __name__ == "__main__":

@ -1,32 +1,66 @@
from typing import Any, Dict, Optional
class ToolAgentError(Exception):
"""Base exception for all tool agent errors."""
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
def __init__(
self, message: str, details: Optional[Dict[str, Any]] = None
):
self.message = message
self.details = details or {}
super().__init__(self.message)
class ToolExecutionError(ToolAgentError):
"""Raised when a tool fails to execute."""
def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None):
message = f"Failed to execute tool '{tool_name}': {str(error)}"
def __init__(
self,
tool_name: str,
error: Exception,
details: Optional[Dict[str, Any]] = None,
):
message = (
f"Failed to execute tool '{tool_name}': {str(error)}"
)
super().__init__(message, details)
class ToolValidationError(ToolAgentError):
"""Raised when tool parameters fail validation."""
def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None):
def __init__(
self,
tool_name: str,
param_name: str,
error: str,
details: Optional[Dict[str, Any]] = None,
):
message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
super().__init__(message, details)
class ToolNotFoundError(ToolAgentError):
"""Raised when a requested tool is not found."""
def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None):
def __init__(
self, tool_name: str, details: Optional[Dict[str, Any]] = None
):
message = f"Tool '{tool_name}' not found"
super().__init__(message, details)
class ToolParameterError(ToolAgentError):
"""Raised when tool parameters are invalid."""
def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None):
message = f"Invalid parameters for tool '{tool_name}': {error}"
def __init__(
self,
tool_name: str,
error: str,
details: Optional[Dict[str, Any]] = None,
):
message = (
f"Invalid parameters for tool '{tool_name}': {error}"
)
super().__init__(message, details)

@ -1,13 +1,13 @@
from typing import List, Optional, Dict, Any, Callable
from loguru import logger
from swarms.agents.exceptions import (
ToolAgentError,
ToolExecutionError,
ToolValidationError,
ToolNotFoundError,
ToolParameterError
ToolParameterError,
)
class ToolAgent:
"""
A wrapper class for vLLM that provides a similar interface to LiteLLM.
@ -68,10 +68,12 @@ class ToolAgent:
raise ToolExecutionError(
"model_initialization",
e,
{"model_name": model_name, "kwargs": kwargs}
{"model_name": model_name, "kwargs": kwargs},
)
def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None:
def _validate_tool(
self, tool_name: str, parameters: Dict[str, Any]
) -> None:
"""
Validate tool parameters before execution.
Args:
@ -84,19 +86,24 @@ class ToolAgent:
raise ToolValidationError(
tool_name,
"parameters",
"No tools available for validation"
"No tools available for validation",
)
tool_spec = next(
(tool for tool in self.tools_list_dictionary if tool["name"] == tool_name),
None
(
tool
for tool in self.tools_list_dictionary
if tool["name"] == tool_name
),
None,
)
if not tool_spec:
raise ToolNotFoundError(tool_name)
required_params = {
param["name"] for param in tool_spec["parameters"]
param["name"]
for param in tool_spec["parameters"]
if param.get("required", True)
}
@ -104,10 +111,12 @@ class ToolAgent:
if missing_params:
raise ToolParameterError(
tool_name,
f"Missing required parameters: {', '.join(missing_params)}"
f"Missing required parameters: {', '.join(missing_params)}",
)
def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any:
def _execute_with_retry(
self, func: Callable, *args, **kwargs
) -> Any:
"""
Execute a function with retry logic.
Args:
@ -134,7 +143,7 @@ class ToolAgent:
raise ToolExecutionError(
func.__name__,
last_error,
{"attempts": self.retry_attempts}
{"attempts": self.retry_attempts},
)
def run(self, task: str, *args, **kwargs) -> str:
@ -154,7 +163,7 @@ class ToolAgent:
raise ToolExecutionError(
"run",
Exception("LLM not initialized"),
{"task": task}
{"task": task},
)
logger.info(f"Running task: {task}")
@ -164,9 +173,7 @@ class ToolAgent:
# Execute with retry logic
outputs = self._execute_with_retry(
self.llm.generate,
prompt,
self.sampling_params
self.llm.generate, prompt, self.sampling_params
)
response = outputs[0].outputs[0].text.strip()
@ -177,7 +184,7 @@ class ToolAgent:
raise ToolExecutionError(
"run",
error,
{"task": task, "args": args, "kwargs": kwargs}
{"task": task, "args": args, "kwargs": kwargs},
)
def _prepare_prompt(self, task: str) -> str:
@ -204,7 +211,9 @@ class ToolAgent:
"""
return self.run(task, *args, **kwargs)
def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
def batched_run(
self, tasks: List[str], batch_size: int = 10
) -> List[str]:
"""
Run the model for multiple tasks in batches.
Args:
@ -215,19 +224,23 @@ class ToolAgent:
Raises:
ToolExecutionError: If an error occurs during batch execution.
"""
logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
logger.info(
f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
)
results = []
try:
for i in range(0, len(tasks), batch_size):
batch = tasks[i:i + batch_size]
batch = tasks[i : i + batch_size]
for task in batch:
logger.info(f"Running task: {task}")
try:
result = self.run(task)
results.append(result)
except ToolExecutionError as e:
logger.error(f"Failed to execute task '{task}': {e}")
logger.error(
f"Failed to execute task '{task}': {e}"
)
results.append(f"Error: {str(e)}")
continue
@ -239,5 +252,5 @@ class ToolAgent:
raise ToolExecutionError(
"batched_run",
error,
{"tasks": tasks, "batch_size": batch_size}
{"tasks": tasks, "batch_size": batch_size},
)

@ -1276,7 +1276,6 @@ class Conversation:
None
"""
total_tokens = 0
truncated_history = []
@ -1289,7 +1288,9 @@ class Conversation:
content = str(content)
# Calculate token count for this message
token_count = count_tokens(content, self.tokenizer_model_name)
token_count = count_tokens(
content, self.tokenizer_model_name
)
# Check if adding this message would exceed the limit
if total_tokens + token_count <= self.context_length:
@ -1309,7 +1310,7 @@ class Conversation:
truncated_content = self._binary_search_truncate(
content,
remaining_tokens,
self.tokenizer_model_name
self.tokenizer_model_name,
)
# Create the truncated message
@ -1329,7 +1330,9 @@ class Conversation:
# Update conversation history
self.conversation_history = truncated_history
def _binary_search_truncate(self, text, target_tokens, model_name):
def _binary_search_truncate(
self, text, target_tokens, model_name
):
"""
Use binary search to find the maximum text substring that fits the target token count.
@ -1342,7 +1345,6 @@ class Conversation:
str: Truncated text with token count not exceeding target_tokens
"""
# If text is empty or target tokens is 0, return empty string
if not text or target_tokens <= 0:
return ""
@ -1372,12 +1374,17 @@ class Conversation:
right = mid - 1
# Try to truncate at sentence boundaries if possible
sentence_delimiters = ['.', '!', '?', '\n']
sentence_delimiters = [".", "!", "?", "\n"]
for delimiter in sentence_delimiters:
last_pos = best_text.rfind(delimiter)
if last_pos > len(best_text) * 0.75: # Only truncate at sentence boundary if we don't lose too much content
truncated_at_sentence = best_text[:last_pos+1]
if count_tokens(truncated_at_sentence, model_name) <= target_tokens:
if (
last_pos > len(best_text) * 0.75
): # Only truncate at sentence boundary if we don't lose too much content
truncated_at_sentence = best_text[: last_pos + 1]
if (
count_tokens(truncated_at_sentence, model_name)
<= target_tokens
):
return truncated_at_sentence
return best_text

@ -6,9 +6,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from swarms import ToolAgent
from swarms.agents.exceptions import (
ToolExecutionError,
ToolValidationError,
ToolNotFoundError,
ToolParameterError
ToolParameterError,
)
@ -111,9 +110,7 @@ def test_tool_agent_init_with_kwargs():
def test_tool_agent_initialization():
"""Test tool agent initialization with valid parameters."""
agent = ToolAgent(
model_name="test-model",
temperature=0.7,
max_tokens=1000
model_name="test-model", temperature=0.7, max_tokens=1000
)
assert agent.model_name == "test-model"
assert agent.temperature == 0.7
@ -131,13 +128,15 @@ def test_tool_agent_initialization_error():
def test_tool_validation():
"""Test tool parameter validation."""
tools_list = [{
"name": "test_tool",
"parameters": [
{"name": "required_param", "required": True},
{"name": "optional_param", "required": False}
]
}]
tools_list = [
{
"name": "test_tool",
"parameters": [
{"name": "required_param", "required": True},
{"name": "optional_param", "required": False},
],
}
]
agent = ToolAgent(tools_list_dictionary=tools_list)
@ -161,7 +160,7 @@ def test_retry_mechanism():
mock_llm.generate.side_effect = [
Exception("First attempt failed"),
Exception("Second attempt failed"),
Mock(outputs=[Mock(text="Success")])
Mock(outputs=[Mock(text="Success")]),
]
agent = ToolAgent(model_name="test-model")
@ -185,7 +184,7 @@ def test_batched_execution():
mock_llm.generate.side_effect = [
Mock(outputs=[Mock(text="Success 1")]),
Exception("Task 2 failed"),
Mock(outputs=[Mock(text="Success 3")])
Mock(outputs=[Mock(text="Success 3")]),
]
agent = ToolAgent(model_name="test-model")
@ -210,7 +209,10 @@ def test_prompt_preparation():
# Test with system prompt
agent = ToolAgent(system_prompt="You are a helpful assistant")
prompt = agent._prepare_prompt("test task")
assert prompt == "You are a helpful assistant\n\nUser: test task\nAssistant:"
assert (
prompt
== "You are a helpful assistant\n\nUser: test task\nAssistant:"
)
def test_tool_execution_error_handling():

Loading…
Cancel
Save