From c2019717e26d6a58d92b409bff8de51dd2e4ad9b Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 30 Jul 2023 15:15:52 -0400 Subject: [PATCH] clean up --- swarms/agents/base.py | 17 ++-- swarms/agents/prompts/agent_prompt_auto.py | 95 ++++++++++++++++++++++ 2 files changed, 102 insertions(+), 10 deletions(-) create mode 100644 swarms/agents/prompts/agent_prompt_auto.py diff --git a/swarms/agents/base.py b/swarms/agents/base.py index f8244703..b8a692ee 100644 --- a/swarms/agents/base.py +++ b/swarms/agents/base.py @@ -9,6 +9,7 @@ from swarms.agents.utils.human_input import HumanInputRun from swarms.agents.prompts.prompt_generator import FINISH_NAME from swarms.agents.models.base import AbstractModel from swarms.agents.prompts.agent_output_parser import AgentOutputParser +from swarms.agents.prompts.agent_prompt_auto import PromptConstructor, MessageFormatter @@ -19,7 +20,6 @@ from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage from langchain.tools.base import BaseTool from langchain.vectorstores.base import VectorStoreRetriever - class Agent: """Base Agent class""" @@ -54,15 +54,12 @@ class Agent: output_parser: Optional[AgentOutputParser] = None, chat_history_memory: Optional[BaseChatMessageHistory] = None, ) -> Agent: - prompt = AgentPrompt( - ai_name=ai_name, - ai_role=ai_role, - tools=tools, - input_variables=["memory", "messages", "goals", "user_input"], - token_counter=llm.get_num_tokens, - ) + prompt_constructor = PromptConstructor(ai_name=ai_name, + ai_role=ai_role, + tools=tools) + message_formatter = MessageFormatter() human_feedback_tool = HumanInputRun() if human_in_the_loop else None - chain = LLMChain(llm=llm, prompt=prompt) + chain = LLMChain(llm=llm, prompt_constructor=prompt_constructor, message_formatter=message_formatter) return cls( ai_name, memory, @@ -135,4 +132,4 @@ class Agent: memory_to_add += feedback self.memory.add_documents([Document(page_content=memory_to_add)]) - self.chat_history_memory.add_message(SystemMessage(content=result)) \ No newline at end of file + self.chat_history_memory.add_message(SystemMessage(content=result)) diff --git a/swarms/agents/prompts/agent_prompt_auto.py b/swarms/agents/prompts/agent_prompt_auto.py new file mode 100644 index 00000000..cdf8fd33 --- /dev/null +++ b/swarms/agents/prompts/agent_prompt_auto.py @@ -0,0 +1,95 @@ +import time +from typing import Any, Callable, List +from pydantic import BaseModel + +class TokenUtils: + @staticmethod + def count_tokens(text: str) -> int: + return len(text.split()) + + +class PromptConstructor: + def __init__(self, ai_name: str, ai_role: str, tools: List[BaseTool]): + self.ai_name = ai_name + self.ai_role = ai_role + self.tools = tools + + def construct_full_prompt(self, goals: List[str]) -> str: + prompt_start = ( + "Your decisions must always be made independently " + "without seeking user assistance.\n" + "Play to your strengths as an LLM and pursue simple " + "strategies with no legal complications.\n" + "If you have completed all your tasks, make sure to " + 'use the "finish" command.' + ) + # Construct full prompt + full_prompt = ( + f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n" + ) + for i, goal in enumerate(goals): + full_prompt += f"{i+1}. {goal}\n" + full_prompt += f"\n\n{get_prompt(self.tools)}" + return full_prompt + + +class Message(BaseModel): + content: str + + def count_tokens(self) -> int: + return TokenUtils.count_tokens(self.content) + + def format_content(self) -> str: + return self.content + + +class SystemMessage(Message): + pass + + +class HumanMessage(Message): + pass + + +class MessageFormatter: + send_token_limit: int = 4196 + + def format_messages(self, **kwargs: Any) -> List[Message]: + prompt_constructor = PromptConstructor(ai_name=kwargs["ai_name"], + ai_role=kwargs["ai_role"], + tools=kwargs["tools"]) + base_prompt = SystemMessage(content=prompt_constructor.construct_full_prompt(kwargs["goals"])) + time_prompt = SystemMessage( + content=f"The current time and date is {time.strftime('%c')}" + ) + used_tokens = base_prompt.count_tokens() + time_prompt.count_tokens() + memory: VectorStoreRetriever = kwargs["memory"] + previous_messages = kwargs["messages"] + relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:])) + relevant_memory = [d.page_content for d in relevant_docs] + relevant_memory_tokens = sum( + [TokenUtils.count_tokens(doc) for doc in relevant_memory] + ) + while used_tokens + relevant_memory_tokens > 2500: + relevant_memory = relevant_memory[:-1] + relevant_memory_tokens = sum( + [TokenUtils.count_tokens(doc) for doc in relevant_memory] + ) + content_format = ( + f"This reminds you of these events " + f"from your past:\n{relevant_memory}\n\n" + ) + memory_message = SystemMessage(content=content_format) + used_tokens += memory_message.count_tokens() + historical_messages: List[Message] = [] + for message in previous_messages[-10:][::-1]: + message_tokens = message.count_tokens() + if used_tokens + message_tokens > self.send_token_limit - 1000: + break + historical_messages = [message] + historical_messages + used_tokens += message_tokens + input_message = HumanMessage(content=kwargs["user_input"]) + messages: List[Message] = [base_prompt, time_prompt, memory_message] + messages += historical_messages + messages.append(input_message) + return messages