|
|
|
@ -20,8 +20,16 @@ from langchain.schema import (
|
|
|
|
|
HumanMessage,
|
|
|
|
|
)
|
|
|
|
|
from langchain.tools.base import BaseTool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.agents.agent import AgentOutputParser
|
|
|
|
|
from langchain.schema import AgentAction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from swarms.prompts.prompts import EVAL_TOOL_RESPONSE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
|
|
class ConversationalChatAgent(Agent):
|
|
|
|
@ -142,7 +150,12 @@ class ConversationalChatAgent(Agent):
|
|
|
|
|
logging.error(f"Error while creating agent from LLM and tools: {str(e)}")
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
class OutputParser(AgentOutputParser):
|
|
|
|
|
def parse(self, full_output: str) -> AgentAction:
|
|
|
|
|
return AgentAction(action="chat", details={'message': full_output})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatAgent(ConversationalChatAgent):
|
|
|
|
|
def _get_default_output_parser(self):
|
|
|
|
|
"""Get default output parser for this class."""
|
|
|
|
|
return OutputParser()
|
|
|
|
|