diff --git a/swarms/agents/base.py b/swarms/agents/base.py index ed30fa1e..f8244703 100644 --- a/swarms/agents/base.py +++ b/swarms/agents/base.py @@ -8,6 +8,8 @@ from swarms.agents.utils.Agent import AgentOutputParser from swarms.agents.utils.human_input import HumanInputRun from swarms.agents.prompts.prompt_generator import FINISH_NAME from swarms.agents.models.base import AbstractModel +from swarms.agents.prompts.agent_output_parser import AgentOutputParser + from langchain.chains.llm import LLMChain @@ -26,7 +28,7 @@ class Agent: ai_name: str, chain: LLMChain, memory: VectorStoreRetriever, - output_parser: BaseAgentOutputParser, + output_parser: AgentOutputParser, tools: List[BaseTool], feedback_tool: Optional[HumanInputRun] = None, chat_history_memory: Optional[BaseChatMessageHistory] = None, @@ -49,7 +51,7 @@ class Agent: tools: List[BaseTool], llm: AbstractModel, human_in_the_loop: bool = False, - output_parser: Optional[BaseAgentOutputParser] = None, + output_parser: Optional[AgentOutputParser] = None, chat_history_memory: Optional[BaseChatMessageHistory] = None, ) -> Agent: prompt = AgentPrompt( diff --git a/swarms/agents/prompts/agent_output_parser.py b/swarms/agents/prompts/agent_output_parser.py new file mode 100644 index 00000000..3e0934da --- /dev/null +++ b/swarms/agents/prompts/agent_output_parser.py @@ -0,0 +1,47 @@ +import json +import re +from abc import abstractmethod +from typing import Dict, NamedTuple + +class AgentAction(NamedTuple): + """Action returned by AgentOutputParser.""" + name: str + args: Dict + +class BaseAgentOutputParser: + """Base Output parser for Agent.""" + + @abstractmethod + def parse(self, text: str) -> AgentAction: + """Return AgentAction""" + +class AgentOutputParser(BaseAgentOutputParser): + """Output parser for Agent.""" + + @staticmethod + def _preprocess_json_input(input_str: str) -> str: + corrected_str = re.sub( + r'(? dict: + try: + parsed = json.loads(text, strict=False) + except json.JSONDecodeError: + preprocessed_text = self._preprocess_json_input(text) + parsed = json.loads(preprocessed_text, strict=False) + return parsed + + def parse(self, text: str) -> AgentAction: + try: + parsed = self._parse_json(text) + return AgentAction( + name=parsed["command"]["name"], + args=parsed["command"]["args"], + ) + except (KeyError, TypeError, json.JSONDecodeError) as e: + return AgentAction( + name="ERROR", + args={"error": f"Error in parsing: {e}"}, + )