parent
cb77ac4e2b
commit
a4d953ec08
@ -1,28 +0,0 @@
|
|||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
from swarms.memory.base_memory import BaseChatMemory, get_prompt_input_key
|
|
||||||
from swarms.memory.base import VectorStoreRetriever
|
|
||||||
|
|
||||||
|
|
||||||
class AgentMemory(BaseChatMemory):
|
|
||||||
retriever: VectorStoreRetriever
|
|
||||||
"""VectorStoreRetriever object to connect to."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def memory_variables(self) -> List[str]:
|
|
||||||
return ["chat_history", "relevant_context"]
|
|
||||||
|
|
||||||
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
|
|
||||||
"""Get the input key for the prompt."""
|
|
||||||
if self.input_key is None:
|
|
||||||
return get_prompt_input_key(inputs, self.memory_variables)
|
|
||||||
return self.input_key
|
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
input_key = self._get_prompt_input_key(inputs)
|
|
||||||
query = inputs[input_key]
|
|
||||||
docs = self.retriever.get_relevant_documents(query)
|
|
||||||
return {
|
|
||||||
"chat_history": self.chat_memory.messages[-10:],
|
|
||||||
"relevant_context": docs,
|
|
||||||
}
|
|
Loading…
Reference in new issue