fix examples and code formatting

pull/1018/head
Kye Gomez 4 weeks ago
parent 63a998bb1f
commit 5be0ab609e

```diff
@@ -5,15 +5,18 @@ from swarms.utils.litellm_tokenizer import count_tokens
 # Load environment variables from .env file
 load_dotenv()


 def demonstrate_truncation():
     # Using a smaller context length to clearly see the truncation effect
     context_length = 25
-    print(f"Creating a conversation instance with context length {context_length}")
+    print(
+        f"Creating a conversation instance with context length {context_length}"
+    )

     # Using Claude model as the tokenizer model
     conversation = Conversation(
         context_length=context_length,
-        tokenizer_model_name="claude-3-7-sonnet-20250219"
+        tokenizer_model_name="claude-3-7-sonnet-20250219",
     )

     # Adding first message - short message
@@ -23,7 +26,9 @@ def demonstrate_truncation():
     # Display token count
-    tokens = count_tokens(short_message, conversation.tokenizer_model_name)
+    tokens = count_tokens(
+        short_message, conversation.tokenizer_model_name
+    )
     print(f"Short message token count: {tokens}")

     # Adding second message - long message, should be truncated
@@ -32,19 +37,27 @@ def demonstrate_truncation():
     conversation.add("assistant", long_message)

     # Display long message token count
-    tokens = count_tokens(long_message, conversation.tokenizer_model_name)
+    tokens = count_tokens(
+        long_message, conversation.tokenizer_model_name
+    )
     print(f"Long message token count: {tokens}")

     # Display current conversation total token count
-    total_tokens = sum(count_tokens(msg["content"], conversation.tokenizer_model_name)
-                       for msg in conversation.conversation_history)
+    total_tokens = sum(
+        count_tokens(
+            msg["content"], conversation.tokenizer_model_name
+        )
+        for msg in conversation.conversation_history
+    )
     print(f"Total token count before truncation: {total_tokens}")

     # Print the complete conversation history before truncation
     print("\nConversation history before truncation:")
     for i, msg in enumerate(conversation.conversation_history):
         print(f"[{i}] {msg['role']}: {msg['content']}")
-        print(f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}")
+        print(
+            f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}"
+        )

     # Execute truncation
     print("\nExecuting truncation...")
@@ -54,31 +67,50 @@ def demonstrate_truncation():
     print("\nConversation history after truncation:")
     for i, msg in enumerate(conversation.conversation_history):
         print(f"[{i}] {msg['role']}: {msg['content']}")
-        print(f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}")
+        print(
+            f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}"
+        )

     # Display total token count after truncation
-    total_tokens = sum(count_tokens(msg["content"], conversation.tokenizer_model_name)
-                       for msg in conversation.conversation_history)
+    total_tokens = sum(
+        count_tokens(
+            msg["content"], conversation.tokenizer_model_name
+        )
+        for msg in conversation.conversation_history
+    )
     print(f"\nTotal token count after truncation: {total_tokens}")
     print(f"Context length limit: {context_length}")

     # Verify if successfully truncated below the limit
     if total_tokens <= context_length:
-        print("✅ Success: Total token count is now less than or equal to context length limit")
+        print(
+            "✅ Success: Total token count is now less than or equal to context length limit"
+        )
     else:
-        print("❌ Failure: Total token count still exceeds context length limit")
+        print(
+            "❌ Failure: Total token count still exceeds context length limit"
+        )

     # Test sentence boundary truncation
     print("\n\nTesting sentence boundary truncation:")
-    sentence_test = Conversation(context_length=15, tokenizer_model_name="claude-3-opus-20240229")
+    sentence_test = Conversation(
+        context_length=15,
+        tokenizer_model_name="claude-3-opus-20240229",
+    )

     test_text = "This is the first sentence. This is the second very long sentence that contains a lot of content. This is the third sentence."
     print(f"Original text: '{test_text}'")
-    print(f"Original token count: {count_tokens(test_text, sentence_test.tokenizer_model_name)}")
+    print(
+        f"Original token count: {count_tokens(test_text, sentence_test.tokenizer_model_name)}"
+    )

     # Using binary search for truncation
-    truncated = sentence_test._binary_search_truncate(test_text, 10, sentence_test.tokenizer_model_name)
+    truncated = sentence_test._binary_search_truncate(
+        test_text, 10, sentence_test.tokenizer_model_name
+    )
     print(f"Truncated text: '{truncated}'")
-    print(f"Truncated token count: {count_tokens(truncated, sentence_test.tokenizer_model_name)}")
+    print(
+        f"Truncated token count: {count_tokens(truncated, sentence_test.tokenizer_model_name)}"
+    )

     # Check if truncated at period
     if truncated.endswith("."):
```
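The demo above is long; condensed, the core flow looks like the sketch below. This is a minimal sketch, not code from the PR: the import path and the public truncation entry point `truncate_memory_with_tokenizer()` are assumptions inferred from the "Executing truncation..." step, so verify them against the `Conversation` class.

```python
from dotenv import load_dotenv
from swarms.structs.conversation import Conversation  # assumed import path

load_dotenv()  # the demo loads tokenizer/model credentials from .env

# Tiny context window so truncation visibly kicks in
conversation = Conversation(
    context_length=25,
    tokenizer_model_name="claude-3-7-sonnet-20250219",
)
conversation.add("user", "Short message.")
conversation.add("assistant", "A much longer reply. " * 20)

# Assumed public entry point for the "Executing truncation..." step
conversation.truncate_memory_with_tokenizer()
```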

```diff
@@ -9,10 +9,10 @@ def main():
     on the specified bill.
     """
     senator_simulation = SenatorAssembly()
-    # senator_simulation.simulate_vote_concurrent(
-    #     "A bill proposing to deregulate the IPO (Initial Public Offering) market in the United States as extensively as possible. The bill seeks to remove or significantly reduce existing regulatory requirements and oversight for companies seeking to go public, with the aim of increasing market efficiency and access to capital. Senators must consider the potential economic, legal, and ethical consequences of such broad deregulation, and cast their votes accordingly.",
-    #     batch_size=10,
-    # )
+    senator_simulation.simulate_vote_concurrent(
+        "A bill proposing to deregulate the IPO (Initial Public Offering) market in the United States as extensively as possible. The bill seeks to remove or significantly reduce existing regulatory requirements and oversight for companies seeking to go public, with the aim of increasing market efficiency and access to capital. Senators must consider the potential economic, legal, and ethical consequences of such broad deregulation, and cast their votes accordingly.",
+        batch_size=10,
+    )


 if __name__ == "__main__":
```

```diff
@@ -1,32 +1,66 @@
 from typing import Any, Dict, Optional


 class ToolAgentError(Exception):
     """Base exception for all tool agent errors."""

-    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self, message: str, details: Optional[Dict[str, Any]] = None
+    ):
         self.message = message
         self.details = details or {}
         super().__init__(self.message)


 class ToolExecutionError(ToolAgentError):
     """Raised when a tool fails to execute."""

-    def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None):
-        message = f"Failed to execute tool '{tool_name}': {str(error)}"
+    def __init__(
+        self,
+        tool_name: str,
+        error: Exception,
+        details: Optional[Dict[str, Any]] = None,
+    ):
+        message = (
+            f"Failed to execute tool '{tool_name}': {str(error)}"
+        )
         super().__init__(message, details)


 class ToolValidationError(ToolAgentError):
     """Raised when tool parameters fail validation."""

-    def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self,
+        tool_name: str,
+        param_name: str,
+        error: str,
+        details: Optional[Dict[str, Any]] = None,
+    ):
         message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
         super().__init__(message, details)


 class ToolNotFoundError(ToolAgentError):
     """Raised when a requested tool is not found."""

-    def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self, tool_name: str, details: Optional[Dict[str, Any]] = None
+    ):
         message = f"Tool '{tool_name}' not found"
         super().__init__(message, details)


 class ToolParameterError(ToolAgentError):
     """Raised when tool parameters are invalid."""

-    def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None):
-        message = f"Invalid parameters for tool '{tool_name}': {error}"
+    def __init__(
+        self,
+        tool_name: str,
+        error: str,
+        details: Optional[Dict[str, Any]] = None,
+    ):
+        message = (
+            f"Invalid parameters for tool '{tool_name}': {error}"
+        )
         super().__init__(message, details)
```
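For context, a minimal sketch of how callers are meant to consume this hierarchy. Only the exception classes and their constructor signatures come from the diff; the `call_tool` helper and the `"web_search"` tool name are hypothetical.

```python
from swarms.agents.exceptions import (
    ToolExecutionError,
    ToolNotFoundError,
)


def call_tool(name: str):
    # Hypothetical dispatch stub; a real agent would look the tool up first
    raise ToolNotFoundError(name, details={"available_tools": []})


try:
    call_tool("web_search")
except ToolNotFoundError as e:
    # .message and .details are populated by ToolAgentError.__init__
    print(e.message, e.details)
except ToolExecutionError as e:
    # .details carries structured context, e.g. {"attempts": 3}
    print("execution failed:", e.message, e.details)
```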

```diff
@@ -1,13 +1,13 @@
 from typing import List, Optional, Dict, Any, Callable
 from loguru import logger
 from swarms.agents.exceptions import (
-    ToolAgentError,
     ToolExecutionError,
     ToolValidationError,
     ToolNotFoundError,
-    ToolParameterError
+    ToolParameterError,
 )


 class ToolAgent:
     """
     A wrapper class for vLLM that provides a similar interface to LiteLLM.
@@ -68,10 +68,12 @@ class ToolAgent:
             raise ToolExecutionError(
                 "model_initialization",
                 e,
-                {"model_name": model_name, "kwargs": kwargs}
+                {"model_name": model_name, "kwargs": kwargs},
             )

-    def _validate_tool(self, tool_name: str, parameters: Dict[str, Any]) -> None:
+    def _validate_tool(
+        self, tool_name: str, parameters: Dict[str, Any]
+    ) -> None:
         """
         Validate tool parameters before execution.

         Args:
@@ -84,19 +86,24 @@ class ToolAgent:
             raise ToolValidationError(
                 tool_name,
                 "parameters",
-                "No tools available for validation"
+                "No tools available for validation",
             )

         tool_spec = next(
-            (tool for tool in self.tools_list_dictionary if tool["name"] == tool_name),
-            None
+            (
+                tool
+                for tool in self.tools_list_dictionary
+                if tool["name"] == tool_name
+            ),
+            None,
         )

         if not tool_spec:
             raise ToolNotFoundError(tool_name)

         required_params = {
-            param["name"] for param in tool_spec["parameters"]
+            param["name"]
+            for param in tool_spec["parameters"]
             if param.get("required", True)
         }
@@ -104,10 +111,12 @@ class ToolAgent:
         if missing_params:
             raise ToolParameterError(
                 tool_name,
-                f"Missing required parameters: {', '.join(missing_params)}"
+                f"Missing required parameters: {', '.join(missing_params)}",
             )

-    def _execute_with_retry(self, func: Callable, *args, **kwargs) -> Any:
+    def _execute_with_retry(
+        self, func: Callable, *args, **kwargs
+    ) -> Any:
         """
         Execute a function with retry logic.

         Args:
@@ -134,7 +143,7 @@ class ToolAgent:
         raise ToolExecutionError(
             func.__name__,
             last_error,
-            {"attempts": self.retry_attempts}
+            {"attempts": self.retry_attempts},
         )

     def run(self, task: str, *args, **kwargs) -> str:
```
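Most of the `_execute_with_retry` body falls outside these hunks; only the signature and the final `ToolExecutionError` raise are visible. A self-contained sketch of the retry pattern it appears to implement, with `retry_attempts` and `retry_delay` as assumed names:

```python
import time
from typing import Any, Callable


def execute_with_retry(
    func: Callable,
    *args,
    retry_attempts: int = 3,
    retry_delay: float = 1.0,
    **kwargs,
) -> Any:
    """Retry func up to retry_attempts times, re-raising the last error."""
    last_error = None
    for attempt in range(retry_attempts):
        try:
            return func(*args, **kwargs)
        except Exception as e:  # the real method likely catches narrower errors
            last_error = e
            if attempt < retry_attempts - 1:
                time.sleep(retry_delay)  # back off before the next attempt
    raise RuntimeError(
        f"{func.__name__} failed after {retry_attempts} attempts"
    ) from last_error
```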
```diff
@@ -154,7 +163,7 @@ class ToolAgent:
             raise ToolExecutionError(
                 "run",
                 Exception("LLM not initialized"),
-                {"task": task}
+                {"task": task},
             )

         logger.info(f"Running task: {task}")
@@ -164,9 +173,7 @@
             # Execute with retry logic
             outputs = self._execute_with_retry(
-                self.llm.generate,
-                prompt,
-                self.sampling_params
+                self.llm.generate, prompt, self.sampling_params
             )

             response = outputs[0].outputs[0].text.strip()
@@ -177,7 +184,7 @@
             raise ToolExecutionError(
                 "run",
                 error,
-                {"task": task, "args": args, "kwargs": kwargs}
+                {"task": task, "args": args, "kwargs": kwargs},
             )

     def _prepare_prompt(self, task: str) -> str:
@@ -204,7 +211,9 @@
         """
         return self.run(task, *args, **kwargs)

-    def batched_run(self, tasks: List[str], batch_size: int = 10) -> List[str]:
+    def batched_run(
+        self, tasks: List[str], batch_size: int = 10
+    ) -> List[str]:
         """
         Run the model for multiple tasks in batches.

         Args:
@@ -215,19 +224,23 @@
         Raises:
             ToolExecutionError: If an error occurs during batch execution.
         """
-        logger.info(f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}")
+        logger.info(
+            f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
+        )

         results = []

         try:
             for i in range(0, len(tasks), batch_size):
-                batch = tasks[i:i + batch_size]
+                batch = tasks[i : i + batch_size]
                 for task in batch:
                     logger.info(f"Running task: {task}")
                     try:
                         result = self.run(task)
                         results.append(result)
                     except ToolExecutionError as e:
-                        logger.error(f"Failed to execute task '{task}': {e}")
+                        logger.error(
+                            f"Failed to execute task '{task}': {e}"
+                        )
                         results.append(f"Error: {str(e)}")
                         continue
@@ -239,5 +252,5 @@
             raise ToolExecutionError(
                 "batched_run",
                 error,
-                {"tasks": tasks, "batch_size": batch_size}
+                {"tasks": tasks, "batch_size": batch_size},
             )
```
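A usage sketch for the reflowed `batched_run`, under stated assumptions: the model name and task strings are placeholders, and running it requires a vLLM-capable environment since `ToolAgent` wraps vLLM.

```python
from swarms import ToolAgent

agent = ToolAgent(model_name="Qwen/Qwen2.5-7B-Instruct")  # placeholder model
tasks = [f"Summarize document {i}" for i in range(25)]

# Tasks are processed in slices of batch_size; a failing task is logged
# and recorded as an "Error: ..." string instead of aborting the batch.
results = agent.batched_run(tasks, batch_size=10)
for task, result in zip(tasks, results):
    print(task, "->", result)
```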

```diff
@@ -1276,7 +1276,6 @@ class Conversation:
             None
         """
         total_tokens = 0
-
         truncated_history = []
@@ -1289,7 +1288,9 @@ class Conversation:
                 content = str(content)

             # Calculate token count for this message
-            token_count = count_tokens(content, self.tokenizer_model_name)
+            token_count = count_tokens(
+                content, self.tokenizer_model_name
+            )

             # Check if adding this message would exceed the limit
             if total_tokens + token_count <= self.context_length:
@@ -1309,7 +1310,7 @@ class Conversation:
                     truncated_content = self._binary_search_truncate(
                         content,
                         remaining_tokens,
-                        self.tokenizer_model_name
+                        self.tokenizer_model_name,
                     )

                     # Create the truncated message
@@ -1329,7 +1330,9 @@ class Conversation:
         # Update conversation history
         self.conversation_history = truncated_history

-    def _binary_search_truncate(self, text, target_tokens, model_name):
+    def _binary_search_truncate(
+        self, text, target_tokens, model_name
+    ):
         """
         Use binary search to find the maximum text substring that fits the target token count.
@@ -1342,7 +1345,6 @@ class Conversation:
             str: Truncated text with token count not exceeding target_tokens
         """
         # If text is empty or target tokens is 0, return empty string
-
         if not text or target_tokens <= 0:
             return ""
@@ -1372,12 +1374,17 @@ class Conversation:
                 right = mid - 1

         # Try to truncate at sentence boundaries if possible
-        sentence_delimiters = ['.', '!', '?', '\n']
+        sentence_delimiters = [".", "!", "?", "\n"]
         for delimiter in sentence_delimiters:
             last_pos = best_text.rfind(delimiter)
-            if last_pos > len(best_text) * 0.75:  # Only truncate at sentence boundary if we don't lose too much content
-                truncated_at_sentence = best_text[:last_pos+1]
-                if count_tokens(truncated_at_sentence, model_name) <= target_tokens:
+            if (
+                last_pos > len(best_text) * 0.75
+            ):  # Only truncate at sentence boundary if we don't lose too much content
+                truncated_at_sentence = best_text[: last_pos + 1]
+                if (
+                    count_tokens(truncated_at_sentence, model_name)
+                    <= target_tokens
+                ):
                     return truncated_at_sentence

         return best_text
```
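To make the `_binary_search_truncate` hunks easier to follow, here is a dependency-free sketch of the same prefix search. It substitutes a whitespace token count for `count_tokens` so it runs anywhere; the real method counts model tokens and then prefers a sentence boundary.

```python
def binary_search_truncate(text: str, target_tokens: int) -> str:
    def n_tokens(s: str) -> int:
        return len(s.split())  # stand-in for count_tokens(s, model_name)

    if not text or target_tokens <= 0:
        return ""
    if n_tokens(text) <= target_tokens:
        return text

    left, right, best = 0, len(text), ""
    while left <= right:
        mid = (left + right) // 2
        if n_tokens(text[:mid]) <= target_tokens:
            best = text[:mid]  # candidate fits: try a longer prefix
            left = mid + 1
        else:
            right = mid - 1  # too long: shrink the prefix

    return best  # longest prefix within the token budget


print(binary_search_truncate("one two three four five six", 3))
```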

```diff
@@ -6,9 +6,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from swarms import ToolAgent
 from swarms.agents.exceptions import (
     ToolExecutionError,
-    ToolValidationError,
     ToolNotFoundError,
-    ToolParameterError
+    ToolParameterError,
 )
@@ -111,9 +110,7 @@ def test_tool_agent_init_with_kwargs():
 def test_tool_agent_initialization():
     """Test tool agent initialization with valid parameters."""
     agent = ToolAgent(
-        model_name="test-model",
-        temperature=0.7,
-        max_tokens=1000
+        model_name="test-model", temperature=0.7, max_tokens=1000
     )

     assert agent.model_name == "test-model"
     assert agent.temperature == 0.7
@@ -131,13 +128,15 @@ def test_tool_agent_initialization_error():
 def test_tool_validation():
     """Test tool parameter validation."""
-    tools_list = [{
-        "name": "test_tool",
-        "parameters": [
-            {"name": "required_param", "required": True},
-            {"name": "optional_param", "required": False}
-        ]
-    }]
+    tools_list = [
+        {
+            "name": "test_tool",
+            "parameters": [
+                {"name": "required_param", "required": True},
+                {"name": "optional_param", "required": False},
+            ],
+        }
+    ]

     agent = ToolAgent(tools_list_dictionary=tools_list)
@@ -161,7 +160,7 @@ def test_retry_mechanism():
     mock_llm.generate.side_effect = [
         Exception("First attempt failed"),
         Exception("Second attempt failed"),
-        Mock(outputs=[Mock(text="Success")])
+        Mock(outputs=[Mock(text="Success")]),
     ]

     agent = ToolAgent(model_name="test-model")
@@ -185,7 +184,7 @@ def test_batched_execution():
     mock_llm.generate.side_effect = [
         Mock(outputs=[Mock(text="Success 1")]),
         Exception("Task 2 failed"),
-        Mock(outputs=[Mock(text="Success 3")])
+        Mock(outputs=[Mock(text="Success 3")]),
     ]

     agent = ToolAgent(model_name="test-model")
@@ -210,7 +209,10 @@ def test_prompt_preparation():
     # Test with system prompt
     agent = ToolAgent(system_prompt="You are a helpful assistant")
     prompt = agent._prepare_prompt("test task")
-    assert prompt == "You are a helpful assistant\n\nUser: test task\nAssistant:"
+    assert (
+        prompt
+        == "You are a helpful assistant\n\nUser: test task\nAssistant:"
+    )


 def test_tool_execution_error_handling():
```
