minor refactor

- minor refactor
pull/615/head
Sambhav Dixit 2 months ago committed by GitHub
parent 2a8cf252a7
commit b40c76ee5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -61,6 +61,8 @@ from dataclasses import asdict
# Utils # Utils
# Custom stopping condition # Custom stopping condition
def stop_when_repeats(response: str) -> bool: def stop_when_repeats(response: str) -> bool:
@ -812,6 +814,9 @@ class Agent:
) )
response = self.call_llm(*response_args, **kwargs) response = self.call_llm(*response_args, **kwargs)
# Log step metadata
step_meta = self.log_step_metadata(loop_count, task_prompt, response)
# Check if response is a dictionary and has 'choices' key # Check if response is a dictionary and has 'choices' key
if isinstance(response, dict) and 'choices' in response: if isinstance(response, dict) and 'choices' in response:
response = response['choices'][0]['message']['content'] response = response['choices'][0]['message']['content']
@ -825,9 +830,7 @@ class Agent:
# Check and execute tools # Check and execute tools
if self.tools is not None: if self.tools is not None:
print(f"self.tools is not None: {response}")
tool_result = self.parse_and_execute_tools(response) tool_result = self.parse_and_execute_tools(response)
self.parse_and_execute_tools(response)
if tool_result: if tool_result:
self.update_tool_usage( self.update_tool_usage(
step_meta["step_id"], step_meta["step_id"],
@ -836,7 +839,7 @@ class Agent:
tool_result["response"] tool_result["response"]
) )
# Update agent output history # Update agent output history
self.agent_output.full_history = self.short_memory.return_history_as_string() self.agent_output.full_history = self.short_memory.return_history_as_string()
@ -850,9 +853,6 @@ class Agent:
# Convert to a str if the response is not a str # Convert to a str if the response is not a str
response = self.llm_output_parser(response) response = self.llm_output_parser(response)
# Log step metadata
step_meta = self.log_step_metadata(loop_count, task_prompt, response)
# Print # Print
if self.streaming_on is True: if self.streaming_on is True:
@ -1003,6 +1003,7 @@ class Agent:
else: else:
return concat_strings(all_responses) return concat_strings(all_responses)
except Exception as error: except Exception as error:
logger.info( logger.info(
f"Error running agent: {error} optimize your input parameters" f"Error running agent: {error} optimize your input parameters"
@ -1948,6 +1949,7 @@ class Agent:
# Add step to agent output tracking # Add step to agent output tracking
return self.step_pool.append(step_log) return self.step_pool.append(step_log)
def update_tool_usage(self, step_id: str, tool_name: str, tool_args: dict, tool_response: Any): def update_tool_usage(self, step_id: str, tool_name: str, tool_args: dict, tool_response: Any):
"""Update tool usage information for a specific step.""" """Update tool usage information for a specific step."""
for step in self.agent_output.steps: for step in self.agent_output.steps:

Loading…
Cancel
Save