fix examples and code formatting

pull/1018/head
Kye Gomez 3 weeks ago
parent 63a998bb1f
commit 5be0ab609e

@@ -19,4 +19,4 @@ print(
    conversation.export_and_count_categories(
        tokenizer_model_name="claude-3-5-sonnet-20240620"
    )
)

@@ -5,81 +5,113 @@ from swarms.utils.litellm_tokenizer import count_tokens

# Load environment variables from .env file
load_dotenv()


def demonstrate_truncation():
    # Using a smaller context length to clearly see the truncation effect
    context_length = 25
    print(
        f"Creating a conversation instance with context length {context_length}"
    )

    # Using Claude model as the tokenizer model
    conversation = Conversation(
        context_length=context_length,
        tokenizer_model_name="claude-3-7-sonnet-20250219",
    )

    # Adding first message - short message
    short_message = "Hello, I am a user."
    print(f"\nAdding short message: '{short_message}'")
    conversation.add("user", short_message)

    # Display token count
    tokens = count_tokens(
        short_message, conversation.tokenizer_model_name
    )
    print(f"Short message token count: {tokens}")

    # Adding second message - long message, should be truncated
    long_message = "I have a question about artificial intelligence. I want to understand how large language models handle long texts, especially under token constraints. This issue is important because it relates to the model's practicality and effectiveness. I hope to get a detailed answer that helps me understand this complex technical problem."
    print(f"\nAdding long message:\n'{long_message}'")
    conversation.add("assistant", long_message)

    # Display long message token count
    tokens = count_tokens(
        long_message, conversation.tokenizer_model_name
    )
    print(f"Long message token count: {tokens}")

    # Display current conversation total token count
    total_tokens = sum(
        count_tokens(
            msg["content"], conversation.tokenizer_model_name
        )
        for msg in conversation.conversation_history
    )
    print(f"Total token count before truncation: {total_tokens}")

    # Print the complete conversation history before truncation
    print("\nConversation history before truncation:")
    for i, msg in enumerate(conversation.conversation_history):
        print(f"[{i}] {msg['role']}: {msg['content']}")
        print(
            f"  Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}"
        )

    # Execute truncation
    print("\nExecuting truncation...")
    conversation.truncate_memory_with_tokenizer()

    # Print conversation history after truncation
    print("\nConversation history after truncation:")
    for i, msg in enumerate(conversation.conversation_history):
        print(f"[{i}] {msg['role']}: {msg['content']}")
        print(
            f"  Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}"
        )

    # Display total token count after truncation
    total_tokens = sum(
        count_tokens(
            msg["content"], conversation.tokenizer_model_name
        )
        for msg in conversation.conversation_history
    )
    print(f"\nTotal token count after truncation: {total_tokens}")
    print(f"Context length limit: {context_length}")

    # Verify if successfully truncated below the limit
    if total_tokens <= context_length:
        print(
            "✅ Success: Total token count is now less than or equal to context length limit"
        )
    else:
        print(
            "❌ Failure: Total token count still exceeds context length limit"
        )

    # Test sentence boundary truncation
    print("\n\nTesting sentence boundary truncation:")
    sentence_test = Conversation(
        context_length=15,
        tokenizer_model_name="claude-3-opus-20240229",
    )
    test_text = "This is the first sentence. This is the second very long sentence that contains a lot of content. This is the third sentence."
    print(f"Original text: '{test_text}'")
    print(
        f"Original token count: {count_tokens(test_text, sentence_test.tokenizer_model_name)}"
    )

    # Using binary search for truncation
    truncated = sentence_test._binary_search_truncate(
        test_text, 10, sentence_test.tokenizer_model_name
    )
    print(f"Truncated text: '{truncated}'")
    print(
        f"Truncated token count: {count_tokens(truncated, sentence_test.tokenizer_model_name)}"
    )

    # Check if truncated at period
    if truncated.endswith("."):
        print("✅ Success: Text was truncated at sentence boundary")

@@ -88,4 +120,4 @@ def demonstrate_truncation():

if __name__ == "__main__":
    demonstrate_truncation()

@@ -9,10 +9,10 @@ def main():
    on the specified bill.
    """
    senator_simulation = SenatorAssembly()
    senator_simulation.simulate_vote_concurrent(
        "A bill proposing to deregulate the IPO (Initial Public Offering) market in the United States as extensively as possible. The bill seeks to remove or significantly reduce existing regulatory requirements and oversight for companies seeking to go public, with the aim of increasing market efficiency and access to capital. Senators must consider the potential economic, legal, and ethical consequences of such broad deregulation, and cast their votes accordingly.",
        batch_size=10,
    )


if __name__ == "__main__":

@@ -1,32 +1,66 @@
from typing import Any, Dict, Optional


class ToolAgentError(Exception):
    """Base exception for all tool agent errors."""

    def __init__(
        self, message: str, details: Optional[Dict[str, Any]] = None
    ):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)


class ToolExecutionError(ToolAgentError):
    """Raised when a tool fails to execute."""

    def __init__(
        self,
        tool_name: str,
        error: Exception,
        details: Optional[Dict[str, Any]] = None,
    ):
        message = (
            f"Failed to execute tool '{tool_name}': {str(error)}"
        )
        super().__init__(message, details)


class ToolValidationError(ToolAgentError):
    """Raised when tool parameters fail validation."""

    def __init__(
        self,
        tool_name: str,
        param_name: str,
        error: str,
        details: Optional[Dict[str, Any]] = None,
    ):
        message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
        super().__init__(message, details)


class ToolNotFoundError(ToolAgentError):
    """Raised when a requested tool is not found."""

    def __init__(
        self, tool_name: str, details: Optional[Dict[str, Any]] = None
    ):
        message = f"Tool '{tool_name}' not found"
        super().__init__(message, details)


class ToolParameterError(ToolAgentError):
    """Raised when tool parameters are invalid."""

    def __init__(
        self,
        tool_name: str,
        error: str,
        details: Optional[Dict[str, Any]] = None,
    ):
        message = (
            f"Invalid parameters for tool '{tool_name}': {error}"
        )
        super().__init__(message, details)
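For orientation, a minimal sketch of how calling code might catch this exception hierarchy. The agent object, tool name, and task string are illustrative assumptions, not part of this diff; the attributes used (message, details) come from the ToolAgentError base class above.

# Hypothetical caller - assumes `agent` is a ToolAgent instance (defined below)
try:
    agent._validate_tool("search", {"query": "swarms"})
    result = agent.run("Look up the swarms docs")
except ToolNotFoundError as err:
    # The requested tool is not present in tools_list_dictionary
    print(f"Unknown tool: {err.message}")
except ToolParameterError as err:
    # A required parameter was missing or invalid
    print(f"Bad parameters: {err.message}", err.details)
except ToolExecutionError as err:
    # Execution failed (after any configured retries)
    print(f"Execution failed: {err.message}", err.details)
except ToolAgentError as err:
    # Catch-all for any other tool agent error
    print(f"Tool agent error: {err.message}")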

@@ -1,13 +1,13 @@
from typing import List, Optional, Dict, Any, Callable
from loguru import logger

from swarms.agents.exceptions import (
    ToolExecutionError,
    ToolValidationError,
    ToolNotFoundError,
    ToolParameterError,
)


class ToolAgent:
    """
    A wrapper class for vLLM that provides a similar interface to LiteLLM.

@@ -68,10 +68,12 @@ class ToolAgent:
            raise ToolExecutionError(
                "model_initialization",
                e,
                {"model_name": model_name, "kwargs": kwargs},
            )

    def _validate_tool(
        self, tool_name: str, parameters: Dict[str, Any]
    ) -> None:
        """
        Validate tool parameters before execution.

        Args:

@@ -84,19 +86,24 @@ class ToolAgent:
            raise ToolValidationError(
                tool_name,
                "parameters",
                "No tools available for validation",
            )

        tool_spec = next(
            (
                tool
                for tool in self.tools_list_dictionary
                if tool["name"] == tool_name
            ),
            None,
        )

        if not tool_spec:
            raise ToolNotFoundError(tool_name)

        required_params = {
            param["name"]
            for param in tool_spec["parameters"]
            if param.get("required", True)
        }

@@ -104,10 +111,12 @@ class ToolAgent:
        if missing_params:
            raise ToolParameterError(
                tool_name,
                f"Missing required parameters: {', '.join(missing_params)}",
            )

    def _execute_with_retry(
        self, func: Callable, *args, **kwargs
    ) -> Any:
        """
        Execute a function with retry logic.

        Args:

@@ -134,7 +143,7 @@ class ToolAgent:
        raise ToolExecutionError(
            func.__name__,
            last_error,
            {"attempts": self.retry_attempts},
        )

    def run(self, task: str, *args, **kwargs) -> str:

@@ -154,21 +163,19 @@ class ToolAgent:
                raise ToolExecutionError(
                    "run",
                    Exception("LLM not initialized"),
                    {"task": task},
                )

            logger.info(f"Running task: {task}")

            # Prepare the prompt
            prompt = self._prepare_prompt(task)

            # Execute with retry logic
            outputs = self._execute_with_retry(
                self.llm.generate, prompt, self.sampling_params
            )

            response = outputs[0].outputs[0].text.strip()
            return response

@@ -177,7 +184,7 @@ class ToolAgent:
            raise ToolExecutionError(
                "run",
                error,
                {"task": task, "args": args, "kwargs": kwargs},
            )

    def _prepare_prompt(self, task: str) -> str:

@@ -204,7 +211,9 @@ class ToolAgent:
        """
        return self.run(task, *args, **kwargs)

    def batched_run(
        self, tasks: List[str], batch_size: int = 10
    ) -> List[str]:
        """
        Run the model for multiple tasks in batches.

        Args:

@@ -215,19 +224,23 @@ class ToolAgent:
        Raises:
            ToolExecutionError: If an error occurs during batch execution.
        """
        logger.info(
            f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
        )
        results = []

        try:
            for i in range(0, len(tasks), batch_size):
                batch = tasks[i : i + batch_size]
                for task in batch:
                    logger.info(f"Running task: {task}")
                    try:
                        result = self.run(task)
                        results.append(result)
                    except ToolExecutionError as e:
                        logger.error(
                            f"Failed to execute task '{task}': {e}"
                        )
                        results.append(f"Error: {str(e)}")
                        continue

@@ -239,5 +252,5 @@ class ToolAgent:
            raise ToolExecutionError(
                "batched_run",
                error,
                {"tasks": tasks, "batch_size": batch_size},
            )

@@ -1268,29 +1268,30 @@ class Conversation:
    def truncate_memory_with_tokenizer(self):
        """
        Truncate conversation history based on the total token count using tokenizer.
        This version is more generic, not dependent on a specific LLM model, and can work with any model that provides a counter.
        Uses the count_tokens function to calculate and truncate by message, ensuring the result is still valid content.

        Returns:
            None
        """
        total_tokens = 0
        truncated_history = []

        for message in self.conversation_history:
            role = message.get("role")
            content = message.get("content")

            # Convert content to string if it's not already a string
            if not isinstance(content, str):
                content = str(content)

            # Calculate token count for this message
            token_count = count_tokens(
                content, self.tokenizer_model_name
            )

            # Check if adding this message would exceed the limit
            if total_tokens + token_count <= self.context_length:
                # If not exceeding limit, add the full message

@@ -1299,69 +1300,70 @@ class Conversation:
            else:
                # Calculate remaining tokens we can include
                remaining_tokens = self.context_length - total_tokens

                # If no token space left, break the loop
                if remaining_tokens <= 0:
                    break

                # If we have space left, we need to truncate this message
                # Use binary search to find content length that fits remaining token space
                truncated_content = self._binary_search_truncate(
                    content,
                    remaining_tokens,
                    self.tokenizer_model_name,
                )

                # Create the truncated message
                truncated_message = {
                    "role": role,
                    "content": truncated_content,
                }

                # Add any other fields from the original message
                for key, value in message.items():
                    if key not in ["role", "content"]:
                        truncated_message[key] = value

                truncated_history.append(truncated_message)
                break

        # Update conversation history
        self.conversation_history = truncated_history

    def _binary_search_truncate(
        self, text, target_tokens, model_name
    ):
        """
        Use binary search to find the maximum text substring that fits the target token count.

        Parameters:
            text (str): Original text to truncate
            target_tokens (int): Target token count
            model_name (str): Model name for token counting

        Returns:
            str: Truncated text with token count not exceeding target_tokens
        """
        # If text is empty or target tokens is 0, return empty string
        if not text or target_tokens <= 0:
            return ""

        # If original text token count is already less than or equal to target, return as is
        original_tokens = count_tokens(text, model_name)
        if original_tokens <= target_tokens:
            return text

        # Binary search
        left, right = 0, len(text)
        best_length = 0
        best_text = ""

        while left <= right:
            mid = (left + right) // 2
            truncated = text[:mid]
            tokens = count_tokens(truncated, model_name)

            if tokens <= target_tokens:
                # If current truncated text token count is less than or equal to target, try longer text
                best_length = mid

@@ -1370,18 +1372,23 @@ class Conversation:
            else:
                # Otherwise try shorter text
                right = mid - 1

        # Try to truncate at sentence boundaries if possible
        sentence_delimiters = [".", "!", "?", "\n"]
        for delimiter in sentence_delimiters:
            last_pos = best_text.rfind(delimiter)
            if (
                last_pos > len(best_text) * 0.75
            ):  # Only truncate at sentence boundary if we don't lose too much content
                truncated_at_sentence = best_text[: last_pos + 1]
                if (
                    count_tokens(truncated_at_sentence, model_name)
                    <= target_tokens
                ):
                    return truncated_at_sentence

        return best_text

    def clear(self):
        """Clear the conversation history."""
        if self.backend_instance:

@@ -1787,4 +1794,4 @@ class Conversation:
# # # conversation.add("assistant", "I am doing well, thanks.")
# # # # print(conversation.to_json())
# # print(type(conversation.to_dict()))
# # print(conversation.to_yaml())
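A compact sketch of the truncation path shown above, assuming the Conversation constructor accepts these keyword arguments and that count_tokens is imported from swarms.utils.litellm_tokenizer, as in the example script earlier in this commit. The message contents are placeholders.

conversation = Conversation(
    context_length=50,
    tokenizer_model_name="claude-3-7-sonnet-20250219",
)
conversation.add("user", "Short question.")
conversation.add("assistant", "A much longer answer " * 30)

# Keeps full messages while they fit, then binary-search-truncates the last
# one so the total stays within context_length, preferring sentence boundaries.
conversation.truncate_memory_with_tokenizer()

total = sum(
    count_tokens(msg["content"], conversation.tokenizer_model_name)
    for msg in conversation.conversation_history
)
assert total <= conversation.context_length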

@@ -6,9 +6,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from swarms import ToolAgent
from swarms.agents.exceptions import (
    ToolExecutionError,
    ToolNotFoundError,
    ToolParameterError,
)

@@ -111,9 +110,7 @@ def test_tool_agent_init_with_kwargs():
def test_tool_agent_initialization():
    """Test tool agent initialization with valid parameters."""
    agent = ToolAgent(
        model_name="test-model", temperature=0.7, max_tokens=1000
    )

    assert agent.model_name == "test-model"
    assert agent.temperature == 0.7

@@ -131,24 +128,26 @@ def test_tool_agent_initialization_error():
def test_tool_validation():
    """Test tool parameter validation."""
    tools_list = [
        {
            "name": "test_tool",
            "parameters": [
                {"name": "required_param", "required": True},
                {"name": "optional_param", "required": False},
            ],
        }
    ]

    agent = ToolAgent(tools_list_dictionary=tools_list)

    # Test missing required parameter
    with pytest.raises(ToolParameterError) as exc_info:
        agent._validate_tool("test_tool", {})
    assert "Missing required parameters" in str(exc_info.value)

    # Test valid parameters
    agent._validate_tool("test_tool", {"required_param": "value"})

    # Test non-existent tool
    with pytest.raises(ToolNotFoundError) as exc_info:
        agent._validate_tool("non_existent_tool", {})

@@ -161,17 +160,17 @@ def test_retry_mechanism():
    mock_llm.generate.side_effect = [
        Exception("First attempt failed"),
        Exception("Second attempt failed"),
        Mock(outputs=[Mock(text="Success")]),
    ]

    agent = ToolAgent(model_name="test-model")
    agent.llm = mock_llm

    # Test successful retry
    result = agent.run("test task")
    assert result == "Success"
    assert mock_llm.generate.call_count == 3

    # Test all retries failing
    mock_llm.generate.side_effect = Exception("All attempts failed")
    with pytest.raises(ToolExecutionError) as exc_info:

@@ -185,15 +184,15 @@ def test_batched_execution():
    mock_llm.generate.side_effect = [
        Mock(outputs=[Mock(text="Success 1")]),
        Exception("Task 2 failed"),
        Mock(outputs=[Mock(text="Success 3")]),
    ]

    agent = ToolAgent(model_name="test-model")
    agent.llm = mock_llm

    tasks = ["Task 1", "Task 2", "Task 3"]
    results = agent.batched_run(tasks)

    assert len(results) == 3
    assert results[0] == "Success 1"
    assert "Error" in results[1]

@@ -206,22 +205,25 @@ def test_prompt_preparation():
    agent = ToolAgent()
    prompt = agent._prepare_prompt("test task")
    assert prompt == "User: test task\nAssistant:"

    # Test with system prompt
    agent = ToolAgent(system_prompt="You are a helpful assistant")
    prompt = agent._prepare_prompt("test task")
    assert (
        prompt
        == "You are a helpful assistant\n\nUser: test task\nAssistant:"
    )


def test_tool_execution_error_handling():
    """Test error handling during tool execution."""
    agent = ToolAgent(model_name="test-model")
    agent.llm = None  # Simulate uninitialized LLM

    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task")
    assert "LLM not initialized" in str(exc_info.value)

    # Test with invalid parameters
    with pytest.raises(ToolExecutionError) as exc_info:
        agent.run("test task", invalid_param="value")
