Enhance tool execution handling for streaming responses, adding support for partial JSON and tool call detection

pull/938/head
harshalmore31 1 week ago
parent cebcd454c2
commit 5e35951e4c

@ -1116,7 +1116,7 @@ class Agent:
) )
# Check and execute callable tools # Check and execute callable tools
if exists(self.tools): if exists(self.tools) and self._response_contains_tool_calls(response):
self.tool_execution_retry( self.tool_execution_retry(
response, loop_count response, loop_count
) )
@ -3001,20 +3001,185 @@ class Agent:
) )
return return
# Check if this is a streaming response
if self.streaming_on and hasattr(response, '__iter__') and not isinstance(response, (str, dict)):
self._execute_tools_streaming(response, loop_count)
else:
self._execute_tools_non_streaming(response, loop_count)
def _execute_tools_streaming(self, streaming_response, loop_count: int):
    """
    Execute tools in real time while consuming a streaming LLM response.

    Tool-call fragments arrive spread across chunks: the first chunk for a
    call carries the id and (usually) the name, while continuation chunks
    carry only argument deltas plus the call's ``index``. Continuation
    chunks typically have ``id=None``, so fragments are correlated by
    ``index`` — the stable key in streamed responses — and each call is
    executed as soon as its accumulated arguments parse as complete JSON.

    Args:
        streaming_response: Iterable of streamed chat-completion chunks.
        loop_count (int): Current agent loop iteration, forwarded to the
            per-tool executor for logging/summary context.
    """
    # Accumulators keyed by tool-call index. Keying by id (as before) split a
    # single call across two entries whenever a continuation chunk arrived
    # with id=None, because getattr returned that None instead of a fallback.
    tool_call_accumulators = {}
    executed_tools = set()  # indexes already dispatched, to avoid re-running

    try:
        if self.print_on:
            logger.info(f"🔧 Starting streaming tool execution for agent {self.agent_name}")

        for chunk in streaming_response:
            if not (hasattr(chunk, 'choices') and len(chunk.choices) > 0):
                continue
            delta = chunk.choices[0].delta
            if not (hasattr(delta, 'tool_calls') and delta.tool_calls):
                continue

            for tool_call in delta.tool_calls:
                # Index distinguishes parallel tool calls within one response.
                tool_index = getattr(tool_call, 'index', 0)

                # First fragment for this index: create the accumulator.
                if tool_index not in tool_call_accumulators:
                    tool_call_accumulators[tool_index] = {
                        'name': '',
                        'arguments': '',
                        # id only appears on the first fragment; synthesize
                        # one so downstream code always has a value.
                        'id': getattr(tool_call, 'id', None) or f"tool_{tool_index}",
                        'index': tool_index,
                        'complete': False
                    }
                acc = tool_call_accumulators[tool_index]

                function = getattr(tool_call, 'function', None)
                if function is None:
                    continue

                # Keep the latest non-empty name fragment.
                if getattr(function, 'name', None):
                    acc['name'] = function.name
                    if self.print_on and self.verbose:
                        logger.info(f"🛠️ Tool call detected: {function.name}")

                # Accumulate argument deltas; fire as soon as they form
                # complete, valid JSON and the tool name is known.
                if getattr(function, 'arguments', None):
                    acc['arguments'] += function.arguments
                    try:
                        json.loads(acc['arguments'])
                        if (not acc['complete']
                                and tool_index not in executed_tools
                                and acc['name']):
                            acc['complete'] = True
                            executed_tools.add(tool_index)
                            # Execute tool immediately
                            self._execute_single_tool_streaming(acc, loop_count)
                    except json.JSONDecodeError:
                        # Arguments not complete yet, continue accumulating.
                        if self.verbose:
                            logger.debug(f"Accumulating args for {acc['name']}: {acc['arguments'][:50]}...")

        # Flush any calls whose JSON only became complete at stream end.
        for tool_index, tool_data in tool_call_accumulators.items():
            if not tool_data['complete'] and tool_data['arguments'] and tool_index not in executed_tools:
                try:
                    json.loads(tool_data['arguments'])
                    self._execute_single_tool_streaming(tool_data, loop_count)
                    executed_tools.add(tool_index)
                except json.JSONDecodeError:
                    logger.warning(f"Tool {tool_data['name']} had incomplete arguments: {tool_data['arguments'][:100]}...")
    except Exception as e:
        logger.error(f"Error during streaming tool execution: {e}")
        # Fallback to non-streaming execution if something goes wrong
        logger.info("Falling back to non-streaming tool execution")
        self._execute_tools_non_streaming(streaming_response, loop_count)
def _execute_single_tool_streaming(self, tool_data: dict, loop_count: int):
    """
    Execute one tool call whose arguments were accumulated during streaming.

    Wraps the accumulated call in an OpenAI-shaped response dict so the
    existing ``tool_struct`` executor can run it unchanged, records the
    output (or the failure) in short-term memory, and optionally prints
    panels and generates a tool summary.

    Args:
        tool_data (dict): Accumulated call data with keys 'id', 'name',
            'arguments' (JSON string), 'index', and 'complete'.
        loop_count (int): Current agent loop iteration, forwarded to the
            optional tool-summary generation.
    """
    try:
        if self.print_on:
            formatter.print_panel(
                f"🚀 Executing tool: {tool_data['name']}\nArguments: {tool_data['arguments']}",
                f"Real-time Tool Execution [{time.strftime('%H:%M:%S')}]",
                style="cyan"
            )

        # Create a mock response object that the existing tool_struct can handle
        mock_response = {
            'choices': [{
                'message': {
                    'tool_calls': [{
                        'id': tool_data['id'],
                        'type': 'function',
                        'function': {
                            'name': tool_data['name'],
                            'arguments': tool_data['arguments']
                        }
                    }]
                }
            }]
        }

        # Execute the tool with streaming mode enabled so partial/odd input
        # is handled gracefully by the executor.
        output = self.tool_struct.execute_function_calls_from_api_response(
            mock_response,
            is_streaming=True
        )

        if output:
            # Add tool output to memory immediately so the next loop turn
            # can reference it.
            tool_output_content = f"Tool '{tool_data['name']}' executed successfully:\n{format_data_structure(output)}"
            self.short_memory.add(
                role="Tool Executor",
                content=tool_output_content,
            )
            if self.print_on:
                formatter.print_panel(
                    format_data_structure(output),
                    f"✅ Tool '{tool_data['name']}' Output [{time.strftime('%H:%M:%S')}]",
                    style="green"
                )
            # Generate tool summary if enabled
            if self.tool_call_summary:
                self._generate_tool_summary(output, loop_count, tool_data['name'])
        else:
            logger.warning(f"Tool {tool_data['name']} returned no output")
    except Exception as e:
        error_msg = f"Error executing streaming tool {tool_data['name']}: {str(e)}"
        logger.error(error_msg)
        # Record the failure in memory so the agent can see it next turn.
        self.short_memory.add(
            role="Tool Executor",
            content=f"Tool execution failed: {error_msg}",
        )
        if self.print_on:
            formatter.print_panel(
                error_msg,
                f"❌ Tool Execution Error [{time.strftime('%H:%M:%S')}]",
                style="red"
            )
def _execute_tools_non_streaming(self, response: any, loop_count: int):
"""Handle tool execution for non-streaming responses (existing logic)"""
try:
output = self.tool_struct.execute_function_calls_from_api_response(
response,
is_streaming=False
) )
except Exception as e: except Exception as e:
# Retry the tool call # Retry the tool call
try:
output = self.tool_struct.execute_function_calls_from_api_response( output = self.tool_struct.execute_function_calls_from_api_response(
response response,
is_streaming=False
) )
except Exception as retry_error:
logger.error(f"Error executing tools after retry: {retry_error}")
raise retry_error
if output is None: if output is None:
logger.error(f"Error executing tools: {e}") logger.error(f"Error executing tools: {e}")
raise e raise e
if output:
self.short_memory.add( self.short_memory.add(
role="Tool Executor", role="Tool Executor",
content=format_data_structure(output), content=format_data_structure(output),
@ -3026,10 +3191,12 @@ class Agent:
loop_count, loop_count,
) )
# Now run the LLM again without tools - create a temporary LLM instance
# instead of modifying the cached one
# Create a temporary LLM instance without tools for the follow-up call
if self.tool_call_summary is True: if self.tool_call_summary is True:
self._generate_tool_summary(output, loop_count)
def _generate_tool_summary(self, output, loop_count: int, tool_name: str = ""):
"""Generate tool execution summary"""
try:
temp_llm = self.temp_llm_instance_for_tool_summary() temp_llm = self.temp_llm_instance_for_tool_summary()
tool_response = temp_llm.run( tool_response = temp_llm.run(
@ -3038,6 +3205,7 @@ class Agent:
Focus on the key information and insights that would be most relevant to the user's original request. Focus on the key information and insights that would be most relevant to the user's original request.
If there are any errors or issues, highlight them prominently. If there are any errors or issues, highlight them prominently.
Tool Name: {tool_name}
Tool Output: Tool Output:
{output} {output}
""" """
@ -3053,6 +3221,8 @@ class Agent:
tool_response, tool_response,
loop_count, loop_count,
) )
except Exception as e:
logger.error(f"Error generating tool summary: {e}")
def list_output_types(self): def list_output_types(self):
return OutputType return OutputType
@ -3188,6 +3358,86 @@ class Agent:
f"Failed to find correct answer '{correct_answer}' after {max_attempts} attempts" f"Failed to find correct answer '{correct_answer}' after {max_attempts} attempts"
) )
def _response_contains_tool_calls(self, response: any) -> bool:
    """
    Determine whether an LLM response carries tool calls to execute.

    Accepts raw strings (parsed as JSON), pydantic BaseModel objects,
    dictionaries in OpenAI or Anthropic shapes, and lists of tool-call
    entries.

    Args:
        response: The response from the LLM.

    Returns:
        bool: True if the response contains tool calls, False otherwise.
    """
    if response is None:
        return False

    def looks_like_call(entry: dict) -> bool:
        # One OpenAI-style function call or Anthropic-style tool_use block.
        if entry.get("type") == "function" and "function" in entry:
            return True
        return entry.get("type") == "tool_use" and "name" in entry

    try:
        # A string can only describe tool calls if it decodes as JSON;
        # plain text (or empty/whitespace) never does.
        if isinstance(response, str):
            if not response.strip():
                return False
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                return False

        # Normalize pydantic models into plain dictionaries.
        if isinstance(response, BaseModel):
            response = response.model_dump()

        if isinstance(response, dict):
            # OpenAI chat-completion shape: choices -> message -> tool_calls.
            for choice in response.get("choices", []):
                if choice.get("message", {}).get("tool_calls"):
                    return True
            # Flat tool_calls key on the top-level dict.
            if response.get("tool_calls"):
                return True
            # Anthropic shape: content list containing tool_use blocks.
            content = response.get("content", [])
            if isinstance(content, list):
                if any(isinstance(part, dict) and part.get("type") == "tool_use"
                       for part in content):
                    return True
            # A bare single tool-call dict.
            if looks_like_call(response):
                return True

        if isinstance(response, list):
            for entry in response:
                if isinstance(entry, BaseModel):
                    entry = entry.model_dump()
                if isinstance(entry, dict) and looks_like_call(entry):
                    return True

        return False
    except Exception as e:
        logger.debug(f"Error checking for tool calls in response: {e}")
        return False
def tool_execution_retry(self, response: any, loop_count: int): def tool_execution_retry(self, response: any, loop_count: int):
""" """
Execute tools with retry logic for handling failures. Execute tools with retry logic for handling failures.

@ -2185,6 +2185,7 @@ class BaseTool(BaseModel):
sequential: bool = False, sequential: bool = False,
max_workers: int = 4, max_workers: int = 4,
return_as_string: bool = True, return_as_string: bool = True,
is_streaming: bool = False,
) -> Union[List[Any], List[str]]: ) -> Union[List[Any], List[str]]:
""" """
Automatically detect and execute function calls from OpenAI or Anthropic API responses. Automatically detect and execute function calls from OpenAI or Anthropic API responses.
@ -2196,12 +2197,14 @@ class BaseTool(BaseModel):
- Pydantic BaseModel objects from Anthropic responses - Pydantic BaseModel objects from Anthropic responses
- Parallel function execution with concurrent.futures or sequential execution - Parallel function execution with concurrent.futures or sequential execution
- Multiple function calls in a single response - Multiple function calls in a single response
- Streaming responses with partial JSON chunks
Args: Args:
api_response (Union[Dict[str, Any], str, List[Any]]): The API response containing function calls api_response (Union[Dict[str, Any], str, List[Any]]): The API response containing function calls
sequential (bool): If True, execute functions sequentially. If False, execute in parallel (default) sequential (bool): If True, execute functions sequentially. If False, execute in parallel (default)
max_workers (int): Maximum number of worker threads for parallel execution (default: 4) max_workers (int): Maximum number of worker threads for parallel execution (default: 4)
return_as_string (bool): If True, return results as formatted strings (default: True) return_as_string (bool): If True, return results as formatted strings (default: True)
is_streaming (bool): If True, handle partial/incomplete streaming responses gracefully (default: False)
Returns: Returns:
Union[List[Any], List[str]]: List of results from executed functions Union[List[Any], List[str]]: List of results from executed functions
@ -2222,6 +2225,9 @@ class BaseTool(BaseModel):
>>> # Direct tool calls list (including BaseModel objects) >>> # Direct tool calls list (including BaseModel objects)
>>> tool_calls = [ChatCompletionMessageToolCall(...), ...] >>> tool_calls = [ChatCompletionMessageToolCall(...), ...]
>>> results = tool.execute_function_calls_from_api_response(tool_calls) >>> results = tool.execute_function_calls_from_api_response(tool_calls)
>>> # Streaming response handling
>>> results = tool.execute_function_calls_from_api_response(partial_response, is_streaming=True)
""" """
# Handle None API response gracefully by returning empty results # Handle None API response gracefully by returning empty results
if api_response is None: if api_response is None:
@ -2231,6 +2237,26 @@ class BaseTool(BaseModel):
) )
return [] if not return_as_string else [] return [] if not return_as_string else []
# Handle streaming mode with empty or partial responses
if is_streaming:
# For streaming, we may get empty strings or partial JSON - handle gracefully
if isinstance(api_response, str) and api_response.strip() == "":
self._log_if_verbose(
"debug",
"Empty streaming response, returning empty results",
)
return [] if not return_as_string else []
# If it's a string that looks like incomplete JSON, return empty results
if isinstance(api_response, str):
stripped_response = api_response.strip()
if stripped_response and not self._is_likely_complete_json(stripped_response):
self._log_if_verbose(
"debug",
f"Incomplete streaming JSON detected: '{stripped_response[:50]}...', returning empty results",
)
return [] if not return_as_string else []
# Handle direct list of tool call objects (e.g., from OpenAI ChatCompletionMessageToolCall or Anthropic BaseModels) # Handle direct list of tool call objects (e.g., from OpenAI ChatCompletionMessageToolCall or Anthropic BaseModels)
if isinstance(api_response, list): if isinstance(api_response, list):
self._log_if_verbose( self._log_if_verbose(
@ -2261,6 +2287,14 @@ class BaseTool(BaseModel):
try: try:
api_response = json.loads(api_response) api_response = json.loads(api_response)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# In streaming mode, this is expected for partial responses
if is_streaming:
self._log_if_verbose(
"debug",
f"JSON parsing failed in streaming mode (expected for partial responses): {e}. Response: '{api_response[:100]}...'",
)
return [] if not return_as_string else []
else:
self._log_if_verbose( self._log_if_verbose(
"error", "error",
f"Failed to parse JSON from API response: {e}. Response: '{api_response[:100]}...'", f"Failed to parse JSON from API response: {e}. Response: '{api_response[:100]}...'",
@ -2966,6 +3000,65 @@ class BaseTool(BaseModel):
return function_calls return function_calls
def _is_likely_complete_json(self, json_str: str) -> bool:
"""
Check if a JSON string appears to be complete by examining its structure.
This is a heuristic method for streaming responses to avoid parsing incomplete JSON.
Args:
json_str (str): JSON string to check
Returns:
bool: True if the JSON appears complete, False otherwise
"""
if not json_str or not isinstance(json_str, str):
return False
json_str = json_str.strip()
if not json_str:
return False
try:
# Try to parse - if it succeeds, it's complete
json.loads(json_str)
return True
except json.JSONDecodeError:
# If parsing fails, use heuristics to check if it might be incomplete
# Check for basic structural completeness
if json_str.startswith('{'):
# Count braces to see if they're balanced
open_braces = json_str.count('{')
close_braces = json_str.count('}')
if open_braces > close_braces:
return False # Likely incomplete
elif json_str.startswith('['):
# Count brackets to see if they're balanced
open_brackets = json_str.count('[')
close_brackets = json_str.count(']')
if open_brackets > close_brackets:
return False # Likely incomplete
# Check for incomplete strings (odd number of unescaped quotes)
quote_count = 0
escaped = False
for char in json_str:
if char == '\\' and not escaped:
escaped = True
elif char == '"' and not escaped:
quote_count += 1
else:
escaped = False
if quote_count % 2 != 0:
return False # Incomplete string
# If we get here, the JSON might be malformed but not necessarily incomplete
# Return False to be safe
return False
def _format_results_as_strings( def _format_results_as_strings(
self, results: List[Any], function_calls: List[Dict[str, Any]] self, results: List[Any], function_calls: List[Dict[str, Any]]
) -> List[str]: ) -> List[str]:

Loading…
Cancel
Save