From 6db796e885de36aa68e5b66cddc220695a8d9e54 Mon Sep 17 00:00:00 2001 From: Kye Date: Thu, 13 Jul 2023 13:25:37 -0400 Subject: [PATCH] clean up with agents abstract class Former-commit-id: dbe2e82584bfe6d2dc493e05d2fe6cbe0c269364 --- swarms/agents/utils/Agent.py | 617 ++++++++++++++++++ swarms/agents/utils/AgentBuilder.py | 4 +- .../agents/utils/ConversationalChatAgent.py | 2 +- 3 files changed, 620 insertions(+), 3 deletions(-) create mode 100644 swarms/agents/utils/Agent.py diff --git a/swarms/agents/utils/Agent.py b/swarms/agents/utils/Agent.py new file mode 100644 index 00000000..e6e6871a --- /dev/null +++ b/swarms/agents/utils/Agent.py @@ -0,0 +1,617 @@ +"""Chain that takes in an input and produces an action and action input.""" +from __future__ import annotations + +import asyncio +import json +import logging +import time +from abc import abstractmethod +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import yaml +from pydantic import BaseModel, root_validator + +from langchain.agents.agent_types import AgentType +from langchain.agents.tools import InvalidTool +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + AsyncCallbackManagerForToolRun, + CallbackManagerForChainRun, + CallbackManagerForToolRun, + Callbacks, +) +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.input import get_color_mapping +from langchain.prompts.few_shot import FewShotPromptTemplate +from langchain.prompts.prompt import PromptTemplate +from langchain.schema import ( + AgentAction, + AgentFinish, + BaseOutputParser, + BasePromptTemplate, + OutputParserException, +) +from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.messages import BaseMessage +from langchain.tools.base import BaseTool +from langchain.utilities.asyncio import asyncio_timeout + +logger = logging.getLogger(__name__) + + +class BaseSingleActionAgent(BaseModel): + """Base Agent class.""" + + @property + def return_values(self) -> List[str]: + """Return values of the agent.""" + return ["output"] + + def get_allowed_tools(self) -> Optional[List[str]]: + return None + + @abstractmethod + def plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + + @abstractmethod + async def aplan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + + @property + @abstractmethod + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + + def return_stopped_response( + self, + early_stopping_method: str, + intermediate_steps: List[Tuple[AgentAction, str]], + **kwargs: Any, + ) -> AgentFinish: + """Return response when agent has been stopped due to max iterations.""" + if early_stopping_method == "force": + # `force` just returns a constant string + return AgentFinish( + {"output": "Agent stopped due to iteration limit or time limit."}, "" + ) + else: + raise ValueError( + f"Got unsupported early_stopping_method `{early_stopping_method}`" + ) + + @classmethod + def from_llm_and_tools( + cls, + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + callback_manager: Optional[BaseCallbackManager] = None, + **kwargs: Any, + ) -> BaseSingleActionAgent: + raise NotImplementedError + + @property + def _agent_type(self) -> str: + """Return Identifier of agent type.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of agent.""" + _dict = super().dict() + _type = self._agent_type + if isinstance(_type, AgentType): + _dict["_type"] = str(_type.value) + else: + _dict["_type"] = _type + return _dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the agent. + + Args: + file_path: Path to file to save the agent to. + + Example: + .. code-block:: python + + # If working with agent executor + agent.agent.save(file_path="path/agent.yaml") + """ + # Convert file to Path object. + if isinstance(file_path, str): + save_path = Path(file_path) + else: + save_path = file_path + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) + + # Fetch dictionary to save + agent_dict = self.dict() + + if save_path.suffix == ".json": + with open(file_path, "w") as f: + json.dump(agent_dict, f, indent=4) + elif save_path.suffix == ".yaml": + with open(file_path, "w") as f: + yaml.dump(agent_dict, f, default_flow_style=False) + else: + raise ValueError(f"{save_path} must be json or yaml") + + def tool_run_logging_kwargs(self) -> Dict: + return {} + + +class BaseMultiActionAgent(BaseModel): + """Base Agent class.""" + + @property + def return_values(self) -> List[str]: + """Return values of the agent.""" + return ["output"] + + def get_allowed_tools(self) -> Optional[List[str]]: + return None + + @abstractmethod + def plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[List[AgentAction], AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Actions specifying what tool to use. + """ + + @abstractmethod + async def aplan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[List[AgentAction], AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Actions specifying what tool to use. + """ + + @property + @abstractmethod + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + + def return_stopped_response( + self, + early_stopping_method: str, + intermediate_steps: List[Tuple[AgentAction, str]], + **kwargs: Any, + ) -> AgentFinish: + """Return response when agent has been stopped due to max iterations.""" + if early_stopping_method == "force": + # `force` just returns a constant string + return AgentFinish({"output": "Agent stopped due to max iterations."}, "") + else: + raise ValueError( + f"Got unsupported early_stopping_method `{early_stopping_method}`" + ) + + @property + def _agent_type(self) -> str: + """Return Identifier of agent type.""" + raise NotImplementedError + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of agent.""" + _dict = super().dict() + _dict["_type"] = str(self._agent_type) + return _dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the agent. + + Args: + file_path: Path to file to save the agent to. + + Example: + .. code-block:: python + + # If working with agent executor + agent.agent.save(file_path="path/agent.yaml") + """ + # Convert file to Path object. + if isinstance(file_path, str): + save_path = Path(file_path) + else: + save_path = file_path + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) + + # Fetch dictionary to save + agent_dict = self.dict() + + if save_path.suffix == ".json": + with open(file_path, "w") as f: + json.dump(agent_dict, f, indent=4) + elif save_path.suffix == ".yaml": + with open(file_path, "w") as f: + yaml.dump(agent_dict, f, default_flow_style=False) + else: + raise ValueError(f"{save_path} must be json or yaml") + + def tool_run_logging_kwargs(self) -> Dict: + return {} + + +class AgentOutputParser(BaseOutputParser): + @abstractmethod + def parse(self, text: str) -> Union[AgentAction, AgentFinish]: + """Parse text into agent action/finish.""" + + +class LLMSingleActionAgent(BaseSingleActionAgent): + llm_chain: LLMChain + output_parser: AgentOutputParser + stop: List[str] + + @property + def input_keys(self) -> List[str]: + return list(set(self.llm_chain.input_keys) - {"intermediate_steps"}) + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of agent.""" + _dict = super().dict() + del _dict["output_parser"] + return _dict + + def plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + output = self.llm_chain.run( + intermediate_steps=intermediate_steps, + stop=self.stop, + callbacks=callbacks, + **kwargs, + ) + return self.output_parser.parse(output) + + async def aplan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + output = await self.llm_chain.arun( + intermediate_steps=intermediate_steps, + stop=self.stop, + callbacks=callbacks, + **kwargs, + ) + return self.output_parser.parse(output) + + def tool_run_logging_kwargs(self) -> Dict: + return { + "llm_prefix": "", + "observation_prefix": "" if len(self.stop) == 0 else self.stop[0], + } + + +class Agent(BaseSingleActionAgent): + """Class responsible for calling the language model and deciding the action. + + This is driven by an LLMChain. The prompt in the LLMChain MUST include + a variable called "agent_scratchpad" where the agent can put its + intermediary work. + """ + + llm_chain: LLMChain + output_parser: AgentOutputParser + allowed_tools: Optional[List[str]] = None + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of agent.""" + _dict = super().dict() + del _dict["output_parser"] + return _dict + + def get_allowed_tools(self) -> Optional[List[str]]: + return self.allowed_tools + + @property + def return_values(self) -> List[str]: + return ["output"] + + def _fix_text(self, text: str) -> str: + """Fix the text.""" + raise ValueError("fix_text not implemented for this agent.") + + @property + def _stop(self) -> List[str]: + return [ + f"\n{self.observation_prefix.rstrip()}", + f"\n\t{self.observation_prefix.rstrip()}", + ] + + def _construct_scratchpad( + self, intermediate_steps: List[Tuple[AgentAction, str]] + ) -> Union[str, List[BaseMessage]]: + """Construct the scratchpad that lets the agent continue its thought process.""" + thoughts = "" + for action, observation in intermediate_steps: + thoughts += action.log + thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" + return thoughts + + def plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) + full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) + return self.output_parser.parse(full_output) + + async def aplan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, + along with observations + callbacks: Callbacks to run. + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. + """ + full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) + full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs) + return self.output_parser.parse(full_output) + + def get_full_inputs( + self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + ) -> Dict[str, Any]: + """Create the full inputs for the LLMChain from intermediate steps.""" + thoughts = self._construct_scratchpad(intermediate_steps) + new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} + full_inputs = {**kwargs, **new_inputs} + return full_inputs + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"}) + + @root_validator() + def validate_prompt(cls, values: Dict) -> Dict: + """Validate that prompt matches format.""" + prompt = values["llm_chain"].prompt + if "agent_scratchpad" not in prompt.input_variables: + logger.warning( + "`agent_scratchpad` should be a variable in prompt.input_variables." + " Did not find it, so adding it at the end." + ) + prompt.input_variables.append("agent_scratchpad") + if isinstance(prompt, PromptTemplate): + prompt.template += "\n{agent_scratchpad}" + elif isinstance(prompt, FewShotPromptTemplate): + prompt.suffix += "\n{agent_scratchpad}" + else: + raise ValueError(f"Got unexpected prompt type {type(prompt)}") + return values + + @property + @abstractmethod + def observation_prefix(self) -> str: + """Prefix to append the observation with.""" + + @property + @abstractmethod + def llm_prefix(self) -> str: + """Prefix to append the LLM call with.""" + + @classmethod + @abstractmethod + def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: + """Create a prompt for this class.""" + + @classmethod + def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: + """Validate that appropriate tools are passed in.""" + pass + + @classmethod + @abstractmethod + def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: + """Get default output parser for this class.""" + + @classmethod + def from_llm_and_tools( + cls, + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + callback_manager: Optional[BaseCallbackManager] = None, + output_parser: Optional[AgentOutputParser] = None, + **kwargs: Any, + ) -> Agent: + """Construct an agent from an LLM and tools.""" + cls._validate_tools(tools) + llm_chain = LLMChain( + llm=llm, + prompt=cls.create_prompt(tools), + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + _output_parser = output_parser or cls._get_default_output_parser() + return cls( + llm_chain=llm_chain, + allowed_tools=tool_names, + output_parser=_output_parser, + **kwargs, + ) + + def return_stopped_response( + self, + early_stopping_method: str, + intermediate_steps: List[Tuple[AgentAction, str]], + **kwargs: Any, + ) -> AgentFinish: + """Return response when agent has been stopped due to max iterations.""" + if early_stopping_method == "force": + # `force` just returns a constant string + return AgentFinish( + {"output": "Agent stopped due to iteration limit or time limit."}, "" + ) + elif early_stopping_method == "generate": + # Generate does one final forward pass + thoughts = "" + for action, observation in intermediate_steps: + thoughts += action.log + thoughts += ( + f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" + ) + # Adding to the previous steps, we now tell the LLM to make a final pred + thoughts += ( + "\n\nI now need to return a final answer based on the previous steps:" + ) + new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop} + full_inputs = {**kwargs, **new_inputs} + full_output = self.llm_chain.predict(**full_inputs) + # We try to extract a final answer + parsed_output = self.output_parser.parse(full_output) + if isinstance(parsed_output, AgentFinish): + # If we can extract, we send the correct stuff + return parsed_output + else: + # If we can extract, but the tool is not the final tool, + # we just return the full output + return AgentFinish({"output": full_output}, full_output) + else: + raise ValueError( + "early_stopping_method should be one of `force` or `generate`, " + f"got {early_stopping_method}" + ) + + def tool_run_logging_kwargs(self) -> Dict: + return { + "llm_prefix": self.llm_prefix, + "observation_prefix": self.observation_prefix, + } + + +class ExceptionTool(BaseTool): + name = "_Exception" + description = "Exception tool" + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + return query + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + return query diff --git a/swarms/agents/utils/AgentBuilder.py b/swarms/agents/utils/AgentBuilder.py index eeb7eec8..fb8ef625 100644 --- a/swarms/agents/utils/AgentBuilder.py +++ b/swarms/agents/utils/AgentBuilder.py @@ -87,8 +87,8 @@ class AgentBuilder: self.toolsets ), # for names and descriptions ], - system_message=EVAL_PREFIX.format(bot_name=os.environ["BOT_NAME"]), - human_message=EVAL_SUFFIX.format(bot_name=os.environ["BOT_NAME"]), + system_message=EVAL_PREFIX.format(bot_name=os.environ["BOT_NAME"] or 'WorkerUltraNode'), + human_message=EVAL_SUFFIX.format(bot_name=os.environ["BOT_NAME"] or 'WorkerUltraNode'), output_parser=self.parser, max_iterations=30, ) \ No newline at end of file diff --git a/swarms/agents/utils/ConversationalChatAgent.py b/swarms/agents/utils/ConversationalChatAgent.py index b95c49ea..68d9a81c 100644 --- a/swarms/agents/utils/ConversationalChatAgent.py +++ b/swarms/agents/utils/ConversationalChatAgent.py @@ -1,7 +1,7 @@ from typing import Any, List, Optional, Sequence, Tuple import logging -from langchain.agents.agent import Agent +from swarms.agents.utils.Agent import Agent from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain from langchain.schema import BaseOutputParser