From 2eccf04a308dfaec76f64b8135f87320cd17871a Mon Sep 17 00:00:00 2001
From: Kye <kye@apacmediasolutions.com>
Date: Sun, 30 Jul 2023 15:15:52 -0400
Subject: [PATCH] clean up

Former-commit-id: c2019717e26d6a58d92b409bff8de51dd2e4ad9b
---
 swarms/agents/base.py                      | 17 ++--
 swarms/agents/prompts/agent_prompt_auto.py | 95 ++++++++++++++++++++++
 2 files changed, 102 insertions(+), 10 deletions(-)
 create mode 100644 swarms/agents/prompts/agent_prompt_auto.py

diff --git a/swarms/agents/base.py b/swarms/agents/base.py
index f8244703..b8a692ee 100644
--- a/swarms/agents/base.py
+++ b/swarms/agents/base.py
@@ -9,6 +9,7 @@ from swarms.agents.utils.human_input import HumanInputRun
 from swarms.agents.prompts.prompt_generator import FINISH_NAME
 from swarms.agents.models.base import AbstractModel
 from swarms.agents.prompts.agent_output_parser import AgentOutputParser
+from swarms.agents.prompts.agent_prompt_auto import PromptConstructor, MessageFormatter
 
 
 
@@ -19,7 +20,6 @@ from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
 from langchain.tools.base import BaseTool
 from langchain.vectorstores.base import VectorStoreRetriever
 
-
 class Agent:
     """Base Agent class"""
 
@@ -54,15 +54,12 @@ class Agent:
         output_parser: Optional[AgentOutputParser] = None,
         chat_history_memory: Optional[BaseChatMessageHistory] = None,
     ) -> Agent:
-        prompt = AgentPrompt(
-            ai_name=ai_name,
-            ai_role=ai_role,
-            tools=tools,
-            input_variables=["memory", "messages", "goals", "user_input"],
-            token_counter=llm.get_num_tokens,
-        )
+        prompt_constructor = PromptConstructor(ai_name=ai_name,
+                                               ai_role=ai_role,
+                                               tools=tools)
+        message_formatter = MessageFormatter()
         human_feedback_tool = HumanInputRun() if human_in_the_loop else None
-        chain = LLMChain(llm=llm, prompt=prompt)
+        chain = LLMChain(llm=llm, prompt_constructor=prompt_constructor, message_formatter=message_formatter)  # NOTE(review): langchain LLMChain requires prompt= and may reject these kwargs -- confirm
         return cls(
             ai_name,
             memory,
@@ -135,4 +132,4 @@ class Agent:
                 memory_to_add += feedback
 
             self.memory.add_documents([Document(page_content=memory_to_add)])
-            self.chat_history_memory.add_message(SystemMessage(content=result))
\ No newline at end of file
+            self.chat_history_memory.add_message(SystemMessage(content=result))
diff --git a/swarms/agents/prompts/agent_prompt_auto.py b/swarms/agents/prompts/agent_prompt_auto.py
new file mode 100644
index 00000000..cdf8fd33
--- /dev/null
+++ b/swarms/agents/prompts/agent_prompt_auto.py
@@ -0,0 +1,95 @@
+import time
+from typing import Any, List; from swarms.agents.prompts.prompt_generator import get_prompt  # get_prompt is called below; Callable was unused
+from pydantic import BaseModel
+
+class TokenUtils:
+    @staticmethod
+    def count_tokens(text: str) -> int:
+        return len(text.split())
+
+
+class PromptConstructor:
+    def __init__(self, ai_name: str, ai_role: str, tools: List["BaseTool"]):  # forward ref: BaseTool is not imported in this module
+        self.ai_name = ai_name
+        self.ai_role = ai_role
+        self.tools = tools
+
+    def construct_full_prompt(self, goals: List[str]) -> str:
+        prompt_start = (
+            "Your decisions must always be made independently "
+            "without seeking user assistance.\n"
+            "Play to your strengths as an LLM and pursue simple "
+            "strategies with no legal complications.\n"
+            "If you have completed all your tasks, make sure to "
+            'use the "finish" command.'
+        )
+        # Construct full prompt
+        full_prompt = (
+            f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
+        )
+        for i, goal in enumerate(goals):
+            full_prompt += f"{i+1}. {goal}\n"
+        full_prompt += f"\n\n{get_prompt(self.tools)}"
+        return full_prompt
+
+
+class Message(BaseModel):
+    content: str
+
+    def count_tokens(self) -> int:
+        return TokenUtils.count_tokens(self.content)
+
+    def format_content(self) -> str:
+        return self.content
+
+
+class SystemMessage(Message):
+    pass
+
+
+class HumanMessage(Message):
+    pass
+
+
+class MessageFormatter:
+    send_token_limit: int = 4196  # NOTE(review): 4196 looks like a typo for 4096 -- confirm intended limit
+
+    def format_messages(self, **kwargs: Any) -> List[Message]:
+        prompt_constructor = PromptConstructor(ai_name=kwargs["ai_name"],
+                                               ai_role=kwargs["ai_role"],
+                                               tools=kwargs["tools"])
+        base_prompt = SystemMessage(content=prompt_constructor.construct_full_prompt(kwargs["goals"]))
+        time_prompt = SystemMessage(
+            content=f"The current time and date is {time.strftime('%c')}"
+        )
+        used_tokens = base_prompt.count_tokens() + time_prompt.count_tokens()
+        memory: "VectorStoreRetriever" = kwargs["memory"]  # forward ref: VectorStoreRetriever is not imported in this module
+        previous_messages = kwargs["messages"]
+        relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
+        relevant_memory = [d.page_content for d in relevant_docs]
+        relevant_memory_tokens = sum(
+            [TokenUtils.count_tokens(doc) for doc in relevant_memory]
+        )
+        while relevant_memory and used_tokens + relevant_memory_tokens > 2500:  # guard: without the emptiness check this loops forever when used_tokens alone exceeds the budget
+            relevant_memory = relevant_memory[:-1]
+            relevant_memory_tokens = sum(
+                [TokenUtils.count_tokens(doc) for doc in relevant_memory]
+            )
+        content_format = (
+            f"This reminds you of these events "
+            f"from your past:\n{relevant_memory}\n\n"
+        )
+        memory_message = SystemMessage(content=content_format)
+        used_tokens += memory_message.count_tokens()
+        historical_messages: List[Message] = []
+        for message in previous_messages[-10:][::-1]:
+            message_tokens = message.count_tokens()
+            if used_tokens + message_tokens > self.send_token_limit - 1000:
+                break
+            historical_messages = [message] + historical_messages
+            used_tokens += message_tokens
+        input_message = HumanMessage(content=kwargs["user_input"])
+        messages: List[Message] = [base_prompt, time_prompt, memory_message]
+        messages += historical_messages
+        messages.append(input_message)
+        return messages