parent 4393f70504
commit 6238abe6b4
@@ -1,612 +0,0 @@
"""Chain that takes in an input and produces an action and action input."""
from __future__ import annotations

import json
import logging
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import yaml
from langchain.agents.agent_types import AgentType
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
    Callbacks,
)
from langchain.chains.llm import LLMChain
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    BaseOutputParser,
    BasePromptTemplate,
)
from langchain.schema.messages import BaseMessage
from langchain.tools.base import BaseTool
from pydantic import BaseModel, root_validator

logger = logging.getLogger(__name__)


class BaseSingleActionAgent(BaseModel):
    """Base Agent class."""

    @property
    def return_values(self) -> List[str]:
        """Return values of the agent."""
        return ["output"]

    def get_allowed_tools(self) -> Optional[List[str]]:
        return None

    @abstractmethod
    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """

    @abstractmethod
    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        """Return response when agent has been stopped due to max iterations."""
        if early_stopping_method == "force":
            # `force` just returns a constant string
            return AgentFinish(
                {"output": "Agent stopped due to iteration limit or time limit."}, ""
            )
        else:
            raise ValueError(
                f"Got unsupported early_stopping_method `{early_stopping_method}`"
            )

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        raise NotImplementedError

    @property
    def _agent_type(self) -> str:
        """Return identifier of agent type."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of agent."""
        _dict = super().dict()
        _type = self._agent_type
        if isinstance(_type, AgentType):
            _dict["_type"] = str(_type.value)
        else:
            _dict["_type"] = _type
        return _dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the agent.

        Args:
            file_path: Path to file to save the agent to.

        Example:
        .. code-block:: python

            # If working with agent executor
            agent.agent.save(file_path="path/agent.yaml")
        """
        # Convert file to Path object.
        if isinstance(file_path, str):
            save_path = Path(file_path)
        else:
            save_path = file_path

        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        agent_dict = self.dict()

        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(agent_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(file_path, "w") as f:
                yaml.dump(agent_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")

    def tool_run_logging_kwargs(self) -> Dict:
        return {}


class BaseMultiActionAgent(BaseModel):
    """Base Agent class."""

    @property
    def return_values(self) -> List[str]:
        """Return values of the agent."""
        return ["output"]

    def get_allowed_tools(self) -> Optional[List[str]]:
        return None

    @abstractmethod
    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[List[AgentAction], AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Actions specifying what tool to use.
        """

    @abstractmethod
    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[List[AgentAction], AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Actions specifying what tool to use.
        """

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        """Return response when agent has been stopped due to max iterations."""
        if early_stopping_method == "force":
            # `force` just returns a constant string
            return AgentFinish({"output": "Agent stopped due to max iterations."}, "")
        else:
            raise ValueError(
                f"Got unsupported early_stopping_method `{early_stopping_method}`"
            )

    @property
    def _agent_type(self) -> str:
        """Return identifier of agent type."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of agent."""
        _dict = super().dict()
        _dict["_type"] = str(self._agent_type)
        return _dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the agent.

        Args:
            file_path: Path to file to save the agent to.

        Example:
        .. code-block:: python

            # If working with agent executor
            agent.agent.save(file_path="path/agent.yaml")
        """
        # Convert file to Path object.
        if isinstance(file_path, str):
            save_path = Path(file_path)
        else:
            save_path = file_path

        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        agent_dict = self.dict()

        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(agent_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(file_path, "w") as f:
                yaml.dump(agent_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")

    def tool_run_logging_kwargs(self) -> Dict:
        return {}


class AgentOutputParser(BaseOutputParser):
    @abstractmethod
    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse text into agent action/finish."""
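For reference, a concrete `AgentOutputParser` only has to implement `parse`. The sketch below is hypothetical (not part of this commit): a minimal ReAct-style parser that maps a `Final Answer:` marker to `AgentFinish` and an `Action:`/`Action Input:` pair to `AgentAction`, under the same langchain version these modules import from.

import re
from typing import Union

from langchain.schema import AgentAction, AgentFinish


class SimpleReActParser(AgentOutputParser):
    """Illustrative parser for 'Action: <tool>\\nAction Input: <input>' completions."""

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        # A final answer ends the run.
        if "Final Answer:" in text:
            answer = text.split("Final Answer:")[-1].strip()
            return AgentFinish({"output": answer}, log=text)
        # Otherwise extract the tool name and its input.
        match = re.search(r"Action: (.*?)[\n]Action Input: (.*)", text, re.DOTALL)
        if not match:
            raise ValueError(f"Could not parse agent output: {text!r}")
        return AgentAction(
            tool=match.group(1).strip(), tool_input=match.group(2).strip(), log=text
        )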
class LLMSingleActionAgent(BaseSingleActionAgent):
    llm_chain: LLMChain
    output_parser: AgentOutputParser
    stop: List[str]

    @property
    def input_keys(self) -> List[str]:
        return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of agent."""
        _dict = super().dict()
        del _dict["output_parser"]
        return _dict

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        output = self.llm_chain.run(
            intermediate_steps=intermediate_steps,
            stop=self.stop,
            callbacks=callbacks,
            **kwargs,
        )
        return self.output_parser.parse(output)

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        output = await self.llm_chain.arun(
            intermediate_steps=intermediate_steps,
            stop=self.stop,
            callbacks=callbacks,
            **kwargs,
        )
        return self.output_parser.parse(output)

    def tool_run_logging_kwargs(self) -> Dict:
        return {
            "llm_prefix": "",
            "observation_prefix": "" if len(self.stop) == 0 else self.stop[0],
        }


class Agent(BaseSingleActionAgent):
    """Class responsible for calling the language model and deciding the action.

    This is driven by an LLMChain. The prompt in the LLMChain MUST include
    a variable called "agent_scratchpad" where the agent can put its
    intermediary work.
    """

    llm_chain: LLMChain
    output_parser: AgentOutputParser
    allowed_tools: Optional[List[str]] = None

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of agent."""
        _dict = super().dict()
        del _dict["output_parser"]
        return _dict

    def get_allowed_tools(self) -> Optional[List[str]]:
        return self.allowed_tools

    @property
    def return_values(self) -> List[str]:
        return ["output"]

    def _fix_text(self, text: str) -> str:
        """Fix the text."""
        raise ValueError("fix_text not implemented for this agent.")

    @property
    def _stop(self) -> List[str]:
        return [
            f"\n{self.observation_prefix.rstrip()}",
            f"\n\t{self.observation_prefix.rstrip()}",
        ]

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> Union[str, List[BaseMessage]]:
        """Construct the scratchpad that lets the agent continue its thought process."""
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
        return thoughts

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        return self.output_parser.parse(full_output)

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
        full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
        return self.output_parser.parse(full_output)

    def get_full_inputs(
        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Dict[str, Any]:
        """Create the full inputs for the LLMChain from intermediate steps."""
        thoughts = self._construct_scratchpad(intermediate_steps)
        new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
        full_inputs = {**kwargs, **new_inputs}
        return full_inputs

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})

    @root_validator()
    def validate_prompt(cls, values: Dict) -> Dict:
        """Validate that prompt matches format."""
        prompt = values["llm_chain"].prompt
        if "agent_scratchpad" not in prompt.input_variables:
            logger.warning(
                "`agent_scratchpad` should be a variable in prompt.input_variables."
                " Did not find it, so adding it at the end."
            )
            prompt.input_variables.append("agent_scratchpad")
            if isinstance(prompt, PromptTemplate):
                prompt.template += "\n{agent_scratchpad}"
            elif isinstance(prompt, FewShotPromptTemplate):
                prompt.suffix += "\n{agent_scratchpad}"
            else:
                raise ValueError(f"Got unexpected prompt type {type(prompt)}")
        return values

    @property
    @abstractmethod
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""

    @property
    @abstractmethod
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""

    @classmethod
    @abstractmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Create a prompt for this class."""

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Validate that appropriate tools are passed in."""
        pass

    @classmethod
    @abstractmethod
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        """Get default output parser for this class."""

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        llm_chain = LLMChain(
            llm=llm,
            prompt=cls.create_prompt(tools),
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser or cls._get_default_output_parser()
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )

    def return_stopped_response(
        self,
        early_stopping_method: str,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> AgentFinish:
        """Return response when agent has been stopped due to max iterations."""
        if early_stopping_method == "force":
            # `force` just returns a constant string
            return AgentFinish(
                {"output": "Agent stopped due to iteration limit or time limit."}, ""
            )
        elif early_stopping_method == "generate":
            # Generate does one final forward pass
            thoughts = ""
            for action, observation in intermediate_steps:
                thoughts += action.log
                thoughts += (
                    f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
                )
            # Adding to the previous steps, we now tell the LLM to make a final prediction
            thoughts += (
                "\n\nI now need to return a final answer based on the previous steps:"
            )
            new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
            full_inputs = {**kwargs, **new_inputs}
            full_output = self.llm_chain.predict(**full_inputs)
            # We try to extract a final answer
            parsed_output = self.output_parser.parse(full_output)
            if isinstance(parsed_output, AgentFinish):
                # If we can extract, we send the correct stuff
                return parsed_output
            else:
                # If we can extract, but the tool is not the final tool,
                # we just return the full output
                return AgentFinish({"output": full_output}, full_output)
        else:
            raise ValueError(
                "early_stopping_method should be one of `force` or `generate`, "
                f"got {early_stopping_method}"
            )

    def tool_run_logging_kwargs(self) -> Dict:
        return {
            "llm_prefix": self.llm_prefix,
            "observation_prefix": self.observation_prefix,
        }


class ExceptionTool(BaseTool):
    name = "_Exception"
    description = "Exception tool"

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        return query

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        return query
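A quick way to sanity-check the interface above without any model in the loop is a stub agent. The snippet below is a hypothetical smoke-test helper, not part of the deleted file; it only assumes the `BaseSingleActionAgent` class defined above and the langchain schema types it already imports.

from typing import Any, List, Tuple, Union

from langchain.schema import AgentAction, AgentFinish


class CannedAgent(BaseSingleActionAgent):
    """Always finishes immediately with a fixed answer (for wiring tests)."""

    answer: str = "ok"

    @property
    def input_keys(self) -> List[str]:
        return ["input"]

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Any = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        # Never calls a tool; returns the canned output straight away.
        return AgentFinish({"output": self.answer}, log="canned")

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Any = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        return AgentFinish({"output": self.answer}, log="canned")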
@@ -1,197 +0,0 @@
from typing import Any, Dict, List, Optional, Union

from celery import Task
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

from swarms.utils.logger import logger
from swarms.utils.main import ANSI, Color, Style, dim_multiline


class EVALCallbackHandler(BaseCallbackHandler):
    @property
    def ignore_llm(self) -> bool:
        return False

    def set_parser(self, parser) -> None:
        self.parser = parser

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        text = response.generations[0][0].text

        parsed = self.parser.parse_all(text)

        logger.info(ANSI("Plan").to(Color.blue().bright()) + ": " + parsed["plan"])
        logger.info(ANSI("What I Did").to(Color.blue()) + ": " + parsed["what_i_did"])
        logger.info(
            ANSI("Action").to(Color.cyan())
            + ": "
            + ANSI(parsed["action"]).to(Style.bold())
        )
        logger.info(
            ANSI("Input").to(Color.cyan())
            + ": "
            + dim_multiline(parsed["action_input"])
        )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        logger.info(ANSI(f"on_llm_new_token {token}").to(Color.green(), Style.italic()))

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        logger.info(ANSI("Entering new chain.").to(Color.green(), Style.italic()))
        logger.info(ANSI("Prompted Text").to(Color.yellow()) + f': {inputs["input"]}\n')

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        logger.info(ANSI("Finished chain.").to(Color.green(), Style.italic()))

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logger.error(
            ANSI("Chain Error").to(Color.red()) + ": " + dim_multiline(str(error))
        )

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        pass

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        logger.info(
            ANSI("Observation").to(Color.magenta()) + ": " + dim_multiline(output)
        )
        logger.info(ANSI("Thinking...").to(Color.green(), Style.italic()))

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        logger.error(ANSI("Tool Error").to(Color.red()) + f": {error}")

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        logger.info(
            ANSI("Final Answer").to(Color.yellow())
            + ": "
            + dim_multiline(finish.return_values.get("output", ""))
        )


class ExecutionTracingCallbackHandler(BaseCallbackHandler):
    def __init__(self, execution: Task):
        self.execution = execution
        self.index = 0

    def set_parser(self, parser) -> None:
        self.parser = parser

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        text = response.generations[0][0].text
        parsed = self.parser.parse_all(text)
        self.index += 1
        parsed["index"] = self.index
        self.execution.update_state(state="LLM_END", meta=parsed)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        pass

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        self.execution.update_state(state="CHAIN_ERROR", meta={"error": str(error)})

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        pass

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        previous = self.execution.AsyncResult(self.execution.request.id)
        self.execution.update_state(
            state="TOOL_END", meta={**previous.info, "observation": output}
        )

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        previous = self.execution.AsyncResult(self.execution.request.id)
        self.execution.update_state(
            state="TOOL_ERROR", meta={**previous.info, "error": str(error)}
        )

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        pass
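Both handlers require `set_parser` to be called before the first LLM call completes, because `on_llm_end` immediately runs `parser.parse_all` on the raw completion and raises if the text is not EVAL-formatted. A minimal wiring sketch (assuming the `EvalOutputParser` defined later in this commit and its `swarms.agents.utils.output_parser` module path):

from langchain.callbacks.manager import CallbackManager

from swarms.agents.utils.output_parser import EvalOutputParser

handler = EVALCallbackHandler()
handler.set_parser(EvalOutputParser())

# Pass this manager to any chain or LLM; on_llm_end will pretty-print
# the Plan / What I Did / Action / Action Input fields it parses.
manager = CallbackManager([handler])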
@@ -1 +0,0 @@
"""Agents"""
@@ -1,94 +0,0 @@
import logging
from typing import Dict, Optional

from celery import Task
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.manager import CallbackManager
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.memory.chat_memory import BaseChatMemory

from swarms.tools.main import BaseToolSet, ToolsFactory

from swarms.agents.utils.agent_setup import AgentSetup
from swarms.agents.utils.Calback import (
    EVALCallbackHandler,
    ExecutionTracingCallbackHandler,
)

# CallbackManager expects a list of handlers.
callback_manager_instance = CallbackManager([EVALCallbackHandler()])


class AgentCreator:
    def __init__(self, toolsets: Optional[list[BaseToolSet]] = None):
        # Default to a fresh list to avoid a shared mutable default argument.
        toolsets = toolsets if toolsets is not None else []
        if not isinstance(toolsets, list):
            raise TypeError("Toolsets must be a list")
        self.toolsets: list[BaseToolSet] = toolsets
        self.memories: Dict[str, BaseChatMemory] = {}
        self.executors: Dict[str, AgentExecutor] = {}

    def create_memory(self) -> BaseChatMemory:
        return ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    def get_or_create_memory(self, session: str) -> BaseChatMemory:
        if not isinstance(session, str):
            raise TypeError("Session must be a string")
        if not session:
            raise ValueError("Session is empty")
        if session not in self.memories:
            self.memories[session] = self.create_memory()
        return self.memories[session]

    def create_executor(
        self,
        session: str,
        execution: Optional[Task] = None,
        openai_api_key: Optional[str] = None,
    ) -> AgentExecutor:
        try:
            builder = AgentSetup(self.toolsets)
            builder.setup_parser()

            callbacks = []
            eval_callback = EVALCallbackHandler()
            eval_callback.set_parser(builder.get_parser())
            callbacks.append(eval_callback)

            if execution:
                execution_callback = ExecutionTracingCallbackHandler(execution)
                execution_callback.set_parser(builder.get_parser())
                callbacks.append(execution_callback)

            callback_manager = CallbackManager(callbacks)
            builder.setup_llm(callback_manager, openai_api_key)
            if builder.llm is None:
                raise ValueError("LLM not created")

            builder.setup_global_tools()

            agent = builder.get_agent()
            if not agent:
                raise ValueError("Agent not created")

            memory: BaseChatMemory = self.get_or_create_memory(session)
            tools = [
                *builder.get_global_tools(),
                *ToolsFactory.create_per_session_tools(
                    self.toolsets,
                    get_session=lambda: (session, self.executors[session]),
                ),
            ]

            for tool in tools:
                tool.callback_manager = callback_manager

            executor = AgentExecutor.from_agent_and_tools(
                agent=agent,
                tools=tools,
                memory=memory,
                callback_manager=callback_manager,
                verbose=True,
            )

            # Make sure the executor exposes the agent attribute.
            if "agent" not in executor.__dict__:
                executor.__dict__["agent"] = agent
            self.executors[session] = executor

            return executor
        except Exception as e:
            logging.error(f"Error while creating executor: {str(e)}")
            raise e

    @staticmethod
    def create(toolsets: list[BaseToolSet]) -> "AgentCreator":
        if not isinstance(toolsets, list):
            raise TypeError("Toolsets must be a list")
        return AgentCreator(toolsets=toolsets)
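Typical lifecycle, as an illustrative sketch (the session id is a placeholder and OPENAI_API_KEY is assumed to be set in the environment): one `AgentCreator` holds per-session memories and executors, so repeated calls for the same session reuse the same `ConversationBufferMemory`.

creator = AgentCreator.create(toolsets=[])

executor = creator.create_executor(session="demo-session")
executor.run("What tools do you have available?")

# Same session key -> same memory object, so the conversation continues.
assert creator.get_or_create_memory("demo-session") is creator.memories["demo-session"]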
@@ -1,92 +0,0 @@
import os
from typing import Optional

from langchain.callbacks.base import BaseCallbackManager

# from .ChatOpenAI import ChatOpenAI
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseOutputParser

from swarms.models.prompts.prebuild.multi_modal_prompts import EVAL_PREFIX, EVAL_SUFFIX
from swarms.tools.main import BaseToolSet, ToolsFactory

from .ConversationalChatAgent import ConversationalChatAgent
from .output_parser import EvalOutputParser


class AgentSetup:
    def __init__(
        self,
        toolsets: Optional[list[BaseToolSet]] = None,
        openai_api_key: Optional[str] = None,
        serpapi_api_key: Optional[str] = None,
        bing_search_url: Optional[str] = None,
        bing_subscription_key: Optional[str] = None,
    ):
        self.llm: Optional[BaseChatModel] = None
        self.parser: Optional[BaseOutputParser] = None
        self.global_tools: Optional[list] = None
        # Default to a fresh list to avoid a shared mutable default argument.
        self.toolsets = toolsets if toolsets is not None else []
        self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        self.serpapi_api_key = serpapi_api_key or os.getenv("SERPAPI_API_KEY")
        self.bing_search_url = bing_search_url or os.getenv("BING_SEARCH_URL")
        self.bing_subscription_key = bing_subscription_key or os.getenv(
            "BING_SUBSCRIPTION_KEY"
        )
        if not self.openai_api_key:
            raise ValueError(
                "OpenAI key is missing; it should either be set as an environment"
                " variable or passed as a parameter."
            )

    def setup_llm(
        self,
        callback_manager: Optional[BaseCallbackManager] = None,
        openai_api_key: Optional[str] = None,
    ):
        if openai_api_key is None:
            openai_api_key = os.getenv("OPENAI_API_KEY")
            if openai_api_key is None:
                raise ValueError(
                    "OpenAI API key is missing. It should either be set as an"
                    " environment variable or passed as a parameter."
                )

        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            temperature=0.5,
            callback_manager=callback_manager,
            verbose=True,
        )

    def setup_parser(self):
        self.parser = EvalOutputParser()

    def setup_global_tools(self):
        if self.llm is None:
            raise ValueError("LLM must be initialized before tools")

        toolnames = ["wikipedia"]

        if self.serpapi_api_key:
            toolnames.append("serpapi")

        if self.bing_search_url and self.bing_subscription_key:
            toolnames.append("bing-search")

        self.global_tools = [
            *ToolsFactory.create_global_tools_from_names(toolnames, llm=self.llm),
            *ToolsFactory.create_global_tools(self.toolsets),
        ]

    def get_parser(self):
        if self.parser is None:
            raise ValueError("Parser is not initialized yet")

        return self.parser

    def get_global_tools(self):
        if self.global_tools is None:
            raise ValueError("Global tools are not initialized yet")

        return self.global_tools

    def get_agent(self):
        if self.llm is None:
            raise ValueError("LLM must be initialized before agent")

        if self.parser is None:
            raise ValueError("Parser must be initialized before agent")

        if self.global_tools is None:
            raise ValueError("Global tools must be initialized before agent")

        # Use .get so a missing BOT_NAME falls back instead of raising KeyError.
        bot_name = os.environ.get("BOT_NAME") or "WorkerUltraNode"
        return ConversationalChatAgent.from_llm_and_tools(
            llm=self.llm,
            tools=[
                *self.global_tools,
                *ToolsFactory.create_per_session_tools(
                    self.toolsets
                ),  # for names and descriptions
            ],
            system_message=EVAL_PREFIX.format(bot_name=bot_name),
            human_message=EVAL_SUFFIX.format(bot_name=bot_name),
            output_parser=self.parser,
            max_iterations=30,
        )
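The guards above impose a strict initialization order: `get_agent` refuses to run until the LLM, parser, and global tools are all set up, and `setup_global_tools` itself requires the LLM. A sketch of the expected call order (API key and empty toolsets are placeholders):

setup = AgentSetup(toolsets=[], openai_api_key="sk-placeholder")
setup.setup_parser()                    # needed by callbacks and by get_agent
setup.setup_llm(callback_manager=None)  # needed before global tools
setup.setup_global_tools()              # requires setup.llm
agent = setup.get_agent()               # requires llm, parser, and global tools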
@@ -1,108 +0,0 @@
import json
import re
from abc import abstractmethod
from typing import Dict, NamedTuple

from langchain.schema import BaseOutputParser

from swarms.models.prompts.prebuild.multi_modal_prompts import EVAL_FORMAT_INSTRUCTIONS


class EvalOutputParser(BaseOutputParser):
    @staticmethod
    def parse_all(text: str) -> Dict[str, str]:
        regex = r"Action: (.*?)[\n]Plan:(.*)[\n]What I Did:(.*)[\n]Action Input: (.*)"
        match = re.search(regex, text, re.DOTALL)
        if not match:
            raise Exception("parse error")

        action = match.group(1).strip()
        plan = match.group(2)
        what_i_did = match.group(3)
        action_input = match.group(4).strip(" ")

        return {
            "action": action,
            "plan": plan,
            "what_i_did": what_i_did,
            "action_input": action_input,
        }

    def get_format_instructions(self) -> str:
        return EVAL_FORMAT_INSTRUCTIONS

    def parse(self, text: str) -> Dict[str, str]:
        # parse_all runs the same regex and raises the same way on failure,
        # so delegate directly instead of matching twice.
        parsed = EvalOutputParser.parse_all(text)
        return {"action": parsed["action"], "action_input": parsed["action_input"]}

    def __str__(self):
        return "EvalOutputParser"


class AgentAction(NamedTuple):
    """Action for Agent."""

    name: str
    """Name of the action."""
    args: Dict
    """Arguments for the action."""


class BaseAgentOutputParser(BaseOutputParser):
    """Base class for Agent output parsers."""

    @abstractmethod
    def parse(self, text: str) -> AgentAction:
        """Parse text and return AgentAction"""


def preprocess_json_input(input_str: str) -> str:
    """Preprocess a string to be parsed as JSON.

    Replace single backslashes with double backslashes,
    while leaving already escaped ones intact.

    Args:
        input_str: String to be preprocessed

    Returns:
        Preprocessed string
    """
    corrected_str = re.sub(
        r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str
    )
    return corrected_str


class AgentOutputParser(BaseAgentOutputParser):
    """Output parser for Agent."""

    def parse(self, text: str) -> AgentAction:
        try:
            parsed = json.loads(text, strict=False)
        except json.JSONDecodeError:
            preprocessed_text = preprocess_json_input(text)
            try:
                parsed = json.loads(preprocessed_text, strict=False)
            except Exception:
                return AgentAction(
                    name="ERROR",
                    args={"error": f"Could not parse invalid json: {text}"},
                )
        try:
            return AgentAction(
                name=parsed["command"]["name"],
                args=parsed["command"]["args"],
            )
        except (KeyError, TypeError):
            # If the command is null or incomplete, return an erroneous tool
            return AgentAction(
                name="ERROR", args={"error": f"Incomplete command args: {parsed}"}
            )
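The regex in `parse_all` anchors on the four literal labels, so a well-formed EVAL completion round-trips as follows (the sample text is illustrative):

sample = (
    "Action: terminal\n"
    "Plan: Inspect the working directory first.\n"
    "What I Did: Nothing yet.\n"
    "Action Input: ls -la"
)

parsed = EvalOutputParser.parse_all(sample)
# Only "action" and "action_input" are stripped; "plan" and
# "what_i_did" keep the leading space after their colon.
assert parsed["action"] == "terminal"
assert parsed["action_input"] == "ls -la"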