From 0d8a71b61982ccf970cb103fd05d7c10a6a4e771 Mon Sep 17 00:00:00 2001 From: Kye Date: Fri, 28 Jul 2023 13:02:18 -0400 Subject: [PATCH] base agent class --- swarms/agents/agent_prompt.py | 78 +++++++++++++++ swarms/agents/base.py | 139 ++++++++++++++++++++++++++ swarms/agents/memory.py | 30 ++++++ swarms/agents/prompts/agent_prompt.py | 78 +++++++++++++++ swarms/agents/utils/output_parser.py | 70 ++++++++++++- 5 files changed, 393 insertions(+), 2 deletions(-) create mode 100644 swarms/agents/agent_prompt.py create mode 100644 swarms/agents/base.py create mode 100644 swarms/agents/memory.py create mode 100644 swarms/agents/prompts/agent_prompt.py diff --git a/swarms/agents/agent_prompt.py b/swarms/agents/agent_prompt.py new file mode 100644 index 00000000..30a1cfb9 --- /dev/null +++ b/swarms/agents/agent_prompt.py @@ -0,0 +1,78 @@ +import json +from typing import List + +class PromptGenerator: + """A class for generating custom prompt strings.""" + + def __init__(self) -> None: + """Initialize the PromptGenerator object.""" + self.constraints: List[str] = [] + self.commands: List[str] = [] + self.resources: List[str] = [] + self.performance_evaluation: List[str] = [] + self.response_format = { + "thoughts": { + "text": "thought", + "reasoning": "reasoning", + "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": "constructive self-criticism", + "speak": "thoughts summary to say to user", + }, + "command": {"name": "command name", "args": {"arg name": "value"}}, + } + + def add_constraint(self, constraint: str) -> None: + """ + Add a constraint to the constraints list. + + Args: + constraint (str): The constraint to be added. + """ + self.constraints.append(constraint) + + def add_command(self, command: str) -> None: + """ + Add a command to the commands list. + + Args: + command (str): The command to be added. + """ + self.commands.append(command) + + def add_resource(self, resource: str) -> None: + """ + Add a resource to the resources list. + + Args: + resource (str): The resource to be added. + """ + self.resources.append(resource) + + def add_performance_evaluation(self, evaluation: str) -> None: + """ + Add a performance evaluation item to the performance_evaluation list. + + Args: + evaluation (str): The evaluation item to be added. + """ + self.performance_evaluation.append(evaluation) + + def generate_prompt_string(self) -> str: + """Generate a prompt string. + + Returns: + str: The generated prompt string. + """ + formatted_response_format = json.dumps(self.response_format, indent=4) + prompt_string = ( + f"Constraints:\n{''.join(self.constraints)}\n\n" + f"Commands:\n{''.join(self.commands)}\n\n" + f"Resources:\n{''.join(self.resources)}\n\n" + f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n" + f"You should only respond in JSON format as described below " + f"\nResponse Format: \n{formatted_response_format} " + f"\nEnsure the response can be parsed by Python json.loads" + ) + + return prompt_string + diff --git a/swarms/agents/base.py b/swarms/agents/base.py new file mode 100644 index 00000000..1de2aca5 --- /dev/null +++ b/swarms/agents/base.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from typing import List, Optional + +from langchain.chains.llm import LLMChain +from langchain.chat_models.base import BaseChatModel +from langchain.memory import ChatMessageHistory +from langchain.schema import ( + BaseChatMessageHistory, + Document, +) +from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage +from langchain.tools.base import BaseTool +from langchain.tools.human.tool import HumanInputRun +from langchain.vectorstores.base import VectorStoreRetriever +from langchain_experimental.autonomous_agents.autogpt.prompt_generator import ( + FINISH_NAME, +) +from pydantic import ValidationError + +from swarms.agents.utils.Agent import AgentOutputParser + + +class Agent: + """Base Agent class""" + + def __init__( + self, + ai_name: str, + memory: VectorStoreRetriever, + chain: LLMChain, + output_parser: BaseAgentOutputParser, + tools: List[BaseTool], + feedback_tool: Optional[HumanInputRun] = None, + chat_history_memory: Optional[BaseChatMessageHistory] = None, + ): + self.ai_name = ai_name + self.memory = memory + self.next_action_count = 0 + self.chain = chain + self.output_parser = output_parser + self.tools = tools + self.feedback_tool = feedback_tool + self.chat_history_memory = chat_history_memory or ChatMessageHistory() + + @classmethod + def from_llm_and_tools( + cls, + ai_name: str, + ai_role: str, + memory: VectorStoreRetriever, + tools: List[BaseTool], + llm: BaseChatModel, + human_in_the_loop: bool = False, + output_parser: Optional[BaseAgentOutputParser] = None, + chat_history_memory: Optional[BaseChatMessageHistory] = None, + ) -> Agent: + prompt = AgentPrompt( + ai_name=ai_name, + ai_role=ai_role, + tools=tools, + input_variables=["memory", "messages", "goals", "user_input"], + token_counter=llm.get_num_tokens, + ) + human_feedback_tool = HumanInputRun() if human_in_the_loop else None + chain = LLMChain(llm=llm, prompt=prompt) + return cls( + ai_name, + memory, + chain, + output_parser or AgentOutputParser(), + tools, + feedback_tool=human_feedback_tool, + chat_history_memory=chat_history_memory, + ) + + def run(self, goals: List[str]) -> str: + user_input = ( + "Determine which next command to use, " + "and respond using the format specified above:" + ) + # Interaction Loop + loop_count = 0 + while True: + # Discontinue if continuous limit is reached + loop_count += 1 + + # Send message to AI, get response + assistant_reply = self.chain.run( + goals=goals, + messages=self.chat_history_memory.messages, + memory=self.memory, + user_input=user_input, + ) + + # Print Assistant thoughts + print(assistant_reply) + self.chat_history_memory.add_message(HumanMessage(content=user_input)) + self.chat_history_memory.add_message(AIMessage(content=assistant_reply)) + + # Get command name and arguments + action = self.output_parser.parse(assistant_reply) + tools = {t.name: t for t in self.tools} + if action.name == FINISH_NAME: + return action.args["response"] + if action.name in tools: + tool = tools[action.name] + try: + observation = tool.run(action.args) + except ValidationError as e: + observation = ( + f"Validation Error in args: {str(e)}, args: {action.args}" + ) + except Exception as e: + observation = ( + f"Error: {str(e)}, {type(e).__name__}, args: {action.args}" + ) + result = f"Command {tool.name} returned: {observation}" + elif action.name == "ERROR": + result = f"Error: {action.args}. " + else: + result = ( + f"Unknown command '{action.name}'. " + f"Please refer to the 'COMMANDS' list for available " + f"commands and only respond in the specified JSON format." + ) + + memory_to_add = ( + f"Assistant Reply: {assistant_reply} " f"\nResult: {result} " + ) + if self.feedback_tool is not None: + feedback = f"\n{self.feedback_tool.run('Input: ')}" + if feedback in {"q", "stop"}: + print("EXITING") + return "EXITING" + memory_to_add += feedback + + self.memory.add_documents([Document(page_content=memory_to_add)]) + self.chat_history_memory.add_message(SystemMessage(content=result)) \ No newline at end of file diff --git a/swarms/agents/memory.py b/swarms/agents/memory.py new file mode 100644 index 00000000..26994da1 --- /dev/null +++ b/swarms/agents/memory.py @@ -0,0 +1,30 @@ +from typing import Any, Dict, List + +from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key +from langchain.vectorstores.base import VectorStoreRetriever + +from pydantic import Field + + +class AutoGPTMemory(BaseChatMemory): + retriever: VectorStoreRetriever = Field(exclude=True) + """VectorStoreRetriever object to connect to.""" + + @property + def memory_variables(self) -> List[str]: + return ["chat_history", "relevant_context"] + + def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str: + """Get the input key for the prompt.""" + if self.input_key is None: + return get_prompt_input_key(inputs, self.memory_variables) + return self.input_key + + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + input_key = self._get_prompt_input_key(inputs) + query = inputs[input_key] + docs = self.retriever.get_relevant_documents(query) + return { + "chat_history": self.chat_memory.messages[-10:], + "relevant_context": docs, + } \ No newline at end of file diff --git a/swarms/agents/prompts/agent_prompt.py b/swarms/agents/prompts/agent_prompt.py new file mode 100644 index 00000000..482dc5c8 --- /dev/null +++ b/swarms/agents/prompts/agent_prompt.py @@ -0,0 +1,78 @@ +import json +from typing import List + +class PromptGenerator: + """A class for generating custom prompt strings.""" + + def __init__(self) -> None: + """Initialize the PromptGenerator object.""" + self.constraints: List[str] = [] + self.commands: List[str] = [] + self.resources: List[str] = [] + self.performance_evaluation: List[str] = [] + self.response_format = { + "thoughts": { + "text": "thought", + "reasoning": "reasoning", + "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": "constructive self-criticism", + "speak": "thoughts summary to say to user", + }, + "command": {"name": "command name", "args": {"arg name": "value"}}, + } + + def add_constraint(self, constraint: str) -> None: + """ + Add a constraint to the constraints list. + + Args: + constraint (str): The constraint to be added. + """ + self.constraints.append(constraint) + + def add_command(self, command: str) -> None: + """ + Add a command to the commands list. + + Args: + command (str): The command to be added. + """ + self.commands.append(command) + + def add_resource(self, resource: str) -> None: + """ + Add a resource to the resources list. + + Args: + resource (str): The resource to be added. + """ + self.resources.append(resource) + + def add_performance_evaluation(self, evaluation: str) -> None: + """ + Add a performance evaluation item to the performance_evaluation list. + + Args: + evaluation (str): The evaluation item to be added. + """ + self.performance_evaluation.append(evaluation) + + def generate_prompt_string(self) -> str: + """Generate a prompt string. + + Returns: + str: The generated prompt string. + """ + formatted_response_format = json.dumps(self.response_format, indent=4) + prompt_string = ( + f"Constraints:\n{''.join(self.constraints)}\n\n" + f"Commands:\n{''.join(self.commands)}\n\n" + f"Resources:\n{''.join(self.resources)}\n\n" + f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n" + f"You should only respond in JSON format as described below " + f"\nResponse Format: \n{formatted_response_format} " + f"\nEnsure the response can be parsed by Python json.loads" + ) + + return prompt_string + diff --git a/swarms/agents/utils/output_parser.py b/swarms/agents/utils/output_parser.py index 4b4b8e71..9156927d 100644 --- a/swarms/agents/utils/output_parser.py +++ b/swarms/agents/utils/output_parser.py @@ -1,10 +1,13 @@ +import json import re -from typing import Dict +from abc import abstractmethod +from typing import Dict, NamedTuple from langchain.schema import BaseOutputParser from swarms.agents.prompts.prompts import EVAL_FORMAT_INSTRUCTIONS + class EvalOutputParser(BaseOutputParser): @staticmethod def parse_all(text: str) -> Dict[str, str]: @@ -39,4 +42,67 @@ class EvalOutputParser(BaseOutputParser): return {"action": parsed["action"], "action_input": parsed["action_input"]} def __str__(self): - return "EvalOutputParser" \ No newline at end of file + return "EvalOutputParser" + + + +class AgentAction(NamedTuple): + """Action for Agent.""" + + name: str + """Name of the action.""" + args: Dict + """Arguments for the action.""" + + +class BaseAgentOutputParser(BaseOutputParser): + """Base class for Agent output parsers.""" + + @abstractmethod + def parse(self, text: str) -> AgentAction: + """Parse text and return AgentAction""" + + +def preprocess_json_input(input_str: str) -> str: + """Preprocesses a string to be parsed as json. + + Replace single backslashes with double backslashes, + while leaving already escaped ones intact. + + Args: + input_str: String to be preprocessed + + Returns: + Preprocessed string + """ + corrected_str = re.sub( + r'(? AgentAction: + try: + parsed = json.loads(text, strict=False) + except json.JSONDecodeError: + preprocessed_text = preprocess_json_input(text) + try: + parsed = json.loads(preprocessed_text, strict=False) + except Exception: + return AgentAction( + name="ERROR", + args={"error": f"Could not parse invalid json: {text}"}, + ) + try: + return AgentAction( + name=parsed["command"]["name"], + args=parsed["command"]["args"], + ) + except (KeyError, TypeError): + # If the command is null or incomplete, return an erroneous tool + return AgentAction( + name="ERROR", args={"error": f"Incomplete command args: {parsed}"} + ) \ No newline at end of file