fix examples and code formatting

pull/1018/head
Kye Gomez 3 weeks ago
parent 63a998bb1f
commit 5be0ab609e

@ -19,4 +19,4 @@ print(
conversation.export_and_count_categories(
tokenizer_model_name="claude-3-5-sonnet-20240620"
)
)
)

@ -5,81 +5,113 @@ from swarms.utils.litellm_tokenizer import count_tokens
# Load environment variables from .env file
load_dotenv()
def demonstrate_truncation():
# Using a smaller context length to clearly see the truncation effect
context_length = 25
print(f"Creating a conversation instance with context length {context_length}")
print(
f"Creating a conversation instance with context length {context_length}"
)
# Using Claude model as the tokenizer model
conversation = Conversation(
context_length=context_length,
tokenizer_model_name="claude-3-7-sonnet-20250219"
tokenizer_model_name="claude-3-7-sonnet-20250219",
)
# Adding first message - short message
short_message = "Hello, I am a user."
print(f"\nAdding short message: '{short_message}'")
conversation.add("user", short_message)
# Display token count
tokens = count_tokens(short_message, conversation.tokenizer_model_name)
tokens = count_tokens(
short_message, conversation.tokenizer_model_name
)
print(f"Short message token count: {tokens}")
# Adding second message - long message, should be truncated
long_message = "I have a question about artificial intelligence. I want to understand how large language models handle long texts, especially under token constraints. This issue is important because it relates to the model's practicality and effectiveness. I hope to get a detailed answer that helps me understand this complex technical problem."
print(f"\nAdding long message:\n'{long_message}'")
conversation.add("assistant", long_message)
# Display long message token count
tokens = count_tokens(long_message, conversation.tokenizer_model_name)
tokens = count_tokens(
long_message, conversation.tokenizer_model_name
)
print(f"Long message token count: {tokens}")
# Display current conversation total token count
total_tokens = sum(count_tokens(msg["content"], conversation.tokenizer_model_name)
for msg in conversation.conversation_history)
total_tokens = sum(
count_tokens(
msg["content"], conversation.tokenizer_model_name
)
for msg in conversation.conversation_history
)
print(f"Total token count before truncation: {total_tokens}")
# Print the complete conversation history before truncation
print("\nConversation history before truncation:")
for i, msg in enumerate(conversation.conversation_history):
print(f"[{i}] {msg['role']}: {msg['content']}")
print(f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}")
print(
f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}"
)
# Execute truncation
print("\nExecuting truncation...")
conversation.truncate_memory_with_tokenizer()
# Print conversation history after truncation
print("\nConversation history after truncation:")
for i, msg in enumerate(conversation.conversation_history):
print(f"[{i}] {msg['role']}: {msg['content']}")
print(f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}")
print(
f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}"
)
# Display total token count after truncation
total_tokens = sum(count_tokens(msg["content"], conversation.tokenizer_model_name)
for msg in conversation.conversation_history)
total_tokens = sum(
count_tokens(
msg["content"], conversation.tokenizer_model_name
)
for msg in conversation.conversation_history
)
print(f"\nTotal token count after truncation: {total_tokens}")
print(f"Context length limit: {context_length}")
# Verify if successfully truncated below the limit
if total_tokens <= context_length:
print("✅ Success: Total token count is now less than or equal to context length limit")
print(
"✅ Success: Total token count is now less than or equal to context length limit"
)
else:
print("❌ Failure: Total token count still exceeds context length limit")
print(
"❌ Failure: Total token count still exceeds context length limit"
)
# Test sentence boundary truncation
print("\n\nTesting sentence boundary truncation:")
sentence_test = Conversation(context_length=15, tokenizer_model_name="claude-3-opus-20240229")
sentence_test = Conversation(
context_length=15,
tokenizer_model_name="claude-3-opus-20240229",
)
test_text = "This is the first sentence. This is the second very long sentence that contains a lot of content. This is the third sentence."
print(f"Original text: '{test_text}'")
print(f"Original token count: {count_tokens(test_text, sentence_test.tokenizer_model_name)}")
print(
f"Original token count: {count_tokens(test_text, sentence_test.tokenizer_model_name)}"
)
# Using binary search for truncation
truncated = sentence_test._binary_search_truncate(test_text, 10, sentence_test.tokenizer_model_name)
truncated = sentence_test._binary_search_truncate(
test_text, 10, sentence_test.tokenizer_model_name
)
print(f"Truncated text: '{truncated}'")
print(f"Truncated token count: {count_tokens(truncated, sentence_test.tokenizer_model_name)}")
print(
f"Truncated token count: {count_tokens(truncated, sentence_test.tokenizer_model_name)}"
)
# Check if truncated at period
if truncated.endswith("."):
print("✅ Success: Text was truncated at sentence boundary")
@ -88,4 +120,4 @@ def demonstrate_truncation():
if __name__ == "__main__":
demonstrate_truncation()
demonstrate_truncation()

@ -9,10 +9,10 @@ def main():
on the specified bill.
"""
senator_simulation = SenatorAssembly()
# senator_simulation.simulate_vote_concurrent(
# "A bill proposing to deregulate the IPO (Initial Public Offering) market in the United States as extensively as possible. The bill seeks to remove or significantly reduce existing regulatory requirements and oversight for companies seeking to go public, with the aim of increasing market efficiency and access to capital. Senators must consider the potential economic, legal, and ethical consequences of such broad deregulation, and cast their votes accordingly.",
# batch_size=10,
# )
senator_simulation.simulate_vote_concurrent(
"A bill proposing to deregulate the IPO (Initial Public Offering) market in the United States as extensively as possible. The bill seeks to remove or significantly reduce existing regulatory requirements and oversight for companies seeking to go public, with the aim of increasing market efficiency and access to capital. Senators must consider the potential economic, legal, and ethical consequences of such broad deregulation, and cast their votes accordingly.",
batch_size=10,
)
if __name__ == "__main__":

@ -1,32 +1,66 @@
from typing import Any, Dict, Optional
class ToolAgentError(Exception):
"""Base exception for all tool agent errors."""
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
def __init__(
self, message: str, details: Optional[Dict[str, Any]] = None
):
self.message = message
self.details = details or {}
super().__init__(self.message)
class ToolExecutionError(ToolAgentError):
"""Raised when a tool fails to execute."""
def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None):
message = f"Failed to execute tool '{tool_name}': {str(error)}"
def __init__(
self,
tool_name: str,
error: Exception,
details: Optional[Dict[str, Any]] = None,
):
message = (
f"Failed to execute tool '{tool_name}': {str(error)}"
)
super().__init__(message, details)
class ToolValidationError(ToolAgentError):
"""Raised when tool parameters fail validation."""
def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None):
def __init__(
self,
tool_name: str,
param_name: str,
error: str,
details: Optional[Dict[str, Any]] = None,
):
message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
super().__init__(message, details)
class ToolNotFoundError(ToolAgentError):
"""Raised when a requested tool is not found."""
def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None):
def __init__(
self, tool_name: str, details: Optional[Dict[str, Any]] = None
):
message = f"Tool '{tool_name}' not found"
super().__init__(message, details)
class ToolParameterError(ToolAgentError):
"""Raised when tool parameters are invalid."""
def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None):
message = f"Invalid parameters for tool '{tool_name}': {error}"
super().__init__(message, details)
def __init__(
self,
tool_name: str,
error: str,
details: Optional[Dict[str, Any]] = None,
):
message = (
f"Invalid parameters for tool '{tool_name}': {error}"
)
super().__init__(message, details)

@ -1,13 +1,13 @@
from typing import List, Optional, Dict, Any, Callable
from loguru import logger
from swarms.agents.exceptions import (
ToolAgentError,
ToolExecutionError,
ToolValidationError,
ToolNotFoundError,
ToolParameterError
ToolParameterError,
)
class ToolAgent:
"""
A wrapper class for vLLM that provides a similar interface to LiteLLM.
@ -68,10 +68,12 @@ class ToolAgent:
raise ToolExecutionError(
"model_initialization",
e,
{"model_name": model_name, "kwargs": kwargs}
{"model_name": model_name, "kwargs": kwargs},
)
def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None:
def _validate_tool(
self, tool_name: str, parameters: Dict[str, Any]
) -> None:
"""
Validate tool parameters before execution.
Args:
@ -84,19 +86,24 @@ class ToolAgent:
raise ToolValidationError(
tool_name,
"parameters",
"No tools available for validation"
"No tools available for validation",
)
tool_spec = next(
(tool for tool in self.tools_list_dictionary if tool["name"] == tool_name),
None
(
tool
for tool in self.tools_list_dictionary
if tool["name"] == tool_name
),
None,
)
if not tool_spec:
raise ToolNotFoundError(tool_name)
required_params = {
param["name"] for param in tool_spec["parameters"]
param["name"]
for param in tool_spec["parameters"]
if param.get("required", True)
}
@ -104,10 +111,12 @@ class ToolAgent:
if missing_params:
raise ToolParameterError(
tool_name,
f"Missing required parameters: {', '.join(missing_params)}"
f"Missing required parameters: {', '.join(missing_params)}",
)
def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any:
def _execute_with_retry(
self, func: Callable, *args, **kwargs
) -> Any:
"""
Execute a function with retry logic.
Args:
@ -134,7 +143,7 @@ class ToolAgent:
raise ToolExecutionError(
func.__name__,
last_error,
{"attempts": self.retry_attempts}
{"attempts": self.retry_attempts},
)
def run(self, task: str, *args, **kwargs) -> str:
@ -154,21 +163,19 @@ class ToolAgent:
raise ToolExecutionError(
"run",
Exception("LLM not initialized"),
{"task": task}
{"task": task},
)
logger.info(f"Running task: {task}")
# Prepare the prompt
prompt = self._prepare_prompt(task)
# Execute with retry logic
outputs = self._execute_with_retry(
self.llm.generate,
prompt,
self.sampling_params
self.llm.generate, prompt, self.sampling_params
)
response = outputs[0].outputs[0].text.strip()
return response
@ -177,7 +184,7 @@ class ToolAgent:
raise ToolExecutionError(
"run",
error,
{"task": task, "args": args, "kwargs": kwargs}
{"task": task, "args": args, "kwargs": kwargs},
)
def _prepare_prompt(self, task: str) -> str:
@ -204,7 +211,9 @@ class ToolAgent:
"""
return self.run(task, *args, **kwargs)
def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
def batched_run(
self, tasks: List[str], batch_size: int = 10
) -> List[str]:
"""
Run the model for multiple tasks in batches.
Args:
@ -215,19 +224,23 @@ class ToolAgent:
Raises:
ToolExecutionError: If an error occurs during batch execution.
"""
logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
logger.info(
f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
)
results = []
try:
for i in range(0, len(tasks), batch_size):
batch = tasks[i:i + batch_size]
batch = tasks[i : i + batch_size]
for task in batch:
logger.info(f"Running task: {task}")
try:
result = self.run(task)
results.append(result)
except ToolExecutionError as e:
logger.error(f"Failed to execute task '{task}': {e}")
logger.error(
f"Failed to execute task '{task}': {e}"
)
results.append(f"Error: {str(e)}")
continue
@ -239,5 +252,5 @@ class ToolAgent:
raise ToolExecutionError(
"batched_run",
error,
{"tasks": tasks, "batch_size": batch_size}
{"tasks": tasks, "batch_size": batch_size},
)

@ -1268,29 +1268,30 @@ class Conversation:
def truncate_memory_with_tokenizer(self):
"""
Truncate conversation history based on the total token count using tokenizer.
This version is more generic, not dependent on a specific LLM model, and can work with any model that provides a counter.
Uses count_tokens function to calculate and truncate by message, ensuring the result is still valid content.
Returns:
None
"""
total_tokens = 0
truncated_history = []
for message in self.conversation_history:
role = message.get("role")
content = message.get("content")
# Convert content to string if it's not already a string
if not isinstance(content, str):
content = str(content)
# Calculate token count for this message
token_count = count_tokens(content, self.tokenizer_model_name)
token_count = count_tokens(
content, self.tokenizer_model_name
)
# Check if adding this message would exceed the limit
if total_tokens + token_count <= self.context_length:
# If not exceeding limit, add the full message
@ -1299,69 +1300,70 @@ class Conversation:
else:
# Calculate remaining tokens we can include
remaining_tokens = self.context_length - total_tokens
# If no token space left, break the loop
if remaining_tokens <= 0:
break
# If we have space left, we need to truncate this message
# Use binary search to find content length that fits remaining token space
truncated_content = self._binary_search_truncate(
content,
content,
remaining_tokens,
self.tokenizer_model_name
self.tokenizer_model_name,
)
# Create the truncated message
truncated_message = {
"role": role,
"content": truncated_content,
}
# Add any other fields from the original message
for key, value in message.items():
if key not in ["role", "content"]:
truncated_message[key] = value
truncated_history.append(truncated_message)
break
# Update conversation history
self.conversation_history = truncated_history
def _binary_search_truncate(self, text, target_tokens, model_name):
def _binary_search_truncate(
self, text, target_tokens, model_name
):
"""
Use binary search to find the maximum text substring that fits the target token count.
Parameters:
text (str): Original text to truncate
target_tokens (int): Target token count
model_name (str): Model name for token counting
Returns:
str: Truncated text with token count not exceeding target_tokens
"""
# If text is empty or target tokens is 0, return empty string
if not text or target_tokens <= 0:
return ""
# If original text token count is already less than or equal to target, return as is
original_tokens = count_tokens(text, model_name)
if original_tokens <= target_tokens:
return text
# Binary search
left, right = 0, len(text)
best_length = 0
best_text = ""
while left <= right:
mid = (left + right) // 2
truncated = text[:mid]
tokens = count_tokens(truncated, model_name)
if tokens <= target_tokens:
# If current truncated text token count is less than or equal to target, try longer text
best_length = mid
@ -1370,18 +1372,23 @@ class Conversation:
else:
# Otherwise try shorter text
right = mid - 1
# Try to truncate at sentence boundaries if possible
sentence_delimiters = ['.', '!', '?', '\n']
sentence_delimiters = [".", "!", "?", "\n"]
for delimiter in sentence_delimiters:
last_pos = best_text.rfind(delimiter)
if last_pos > len(best_text) * 0.75: # Only truncate at sentence boundary if we don't lose too much content
truncated_at_sentence = best_text[:last_pos+1]
if count_tokens(truncated_at_sentence, model_name) <= target_tokens:
if (
last_pos > len(best_text) * 0.75
): # Only truncate at sentence boundary if we don't lose too much content
truncated_at_sentence = best_text[: last_pos + 1]
if (
count_tokens(truncated_at_sentence, model_name)
<= target_tokens
):
return truncated_at_sentence
return best_text
def clear(self):
"""Clear the conversation history."""
if self.backend_instance:
@ -1787,4 +1794,4 @@ class Conversation:
# # # conversation.add("assistant", "I am doing well, thanks.")
# # # # print(conversation.to_json())
# # print(type(conversation.to_dict()))
# # print(conversation.to_yaml())
# # print(conversation.to_yaml())

@ -6,9 +6,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from swarms import ToolAgent
from swarms.agents.exceptions import (
ToolExecutionError,
ToolValidationError,
ToolNotFoundError,
ToolParameterError
ToolParameterError,
)
@ -111,9 +110,7 @@ def test_tool_agent_init_with_kwargs():
def test_tool_agent_initialization():
"""Test tool agent initialization with valid parameters."""
agent = ToolAgent(
model_name="test-model",
temperature=0.7,
max_tokens=1000
model_name="test-model", temperature=0.7, max_tokens=1000
)
assert agent.model_name == "test-model"
assert agent.temperature == 0.7
@ -131,24 +128,26 @@ def test_tool_agent_initialization_error():
def test_tool_validation():
"""Test tool parameter validation."""
tools_list = [{
"name": "test_tool",
"parameters": [
{"name": "required_param", "required": True},
{"name": "optional_param", "required": False}
]
}]
tools_list = [
{
"name": "test_tool",
"parameters": [
{"name": "required_param", "required": True},
{"name": "optional_param", "required": False},
],
}
]
agent = ToolAgent(tools_list_dictionary=tools_list)
# Test missing required parameter
with pytest.raises(ToolParameterError) as exc_info:
agent._validate_tool("test_tool", {})
assert "Missing required parameters" in str(exc_info.value)
# Test valid parameters
agent._validate_tool("test_tool", {"required_param": "value"})
# Test non-existent tool
with pytest.raises(ToolNotFoundError) as exc_info:
agent._validate_tool("non_existent_tool", {})
@ -161,17 +160,17 @@ def test_retry_mechanism():
mock_llm.generate.side_effect = [
Exception("First attempt failed"),
Exception("Second attempt failed"),
Mock(outputs=[Mock(text="Success")])
Mock(outputs=[Mock(text="Success")]),
]
agent = ToolAgent(model_name="test-model")
agent.llm = mock_llm
# Test successful retry
result = agent.run("test task")
assert result == "Success"
assert mock_llm.generate.call_count == 3
# Test all retries failing
mock_llm.generate.side_effect = Exception("All attempts failed")
with pytest.raises(ToolExecutionError) as exc_info:
@ -185,15 +184,15 @@ def test_batched_execution():
mock_llm.generate.side_effect = [
Mock(outputs=[Mock(text="Success 1")]),
Exception("Task 2 failed"),
Mock(outputs=[Mock(text="Success 3")])
Mock(outputs=[Mock(text="Success 3")]),
]
agent = ToolAgent(model_name="test-model")
agent.llm = mock_llm
tasks = ["Task 1", "Task 2", "Task 3"]
results = agent.batched_run(tasks)
assert len(results) == 3
assert results[0] == "Success 1"
assert "Error" in results[1]
@ -206,22 +205,25 @@ def test_prompt_preparation():
agent = ToolAgent()
prompt = agent._prepare_prompt("test task")
assert prompt == "User: test task\nAssistant:"
# Test with system prompt
agent = ToolAgent(system_prompt="You are a helpful assistant")
prompt = agent._prepare_prompt("test task")
assert prompt == "You are a helpful assistant\n\nUser: test task\nAssistant:"
assert (
prompt
== "You are a helpful assistant\n\nUser: test task\nAssistant:"
)
def test_tool_execution_error_handling():
"""Test error handling during tool execution."""
agent = ToolAgent(model_name="test-model")
agent.llm = None # Simulate uninitialized LLM
with pytest.raises(ToolExecutionError) as exc_info:
agent.run("test task")
assert "LLM not initialized" in str(exc_info.value)
# Test with invalid parameters
with pytest.raises(ToolExecutionError) as exc_info:
agent.run("test task", invalid_param="value")

Loading…
Cancel
Save