Merge pull request #267 from kyegomez/memory

Memory with Chroma
Eternal Reclaimer authored 1 year ago, committed via GitHub
commit a4aab51655

@@ -0,0 +1,10 @@
# Example: store documents in Chroma and run a similarity search
from swarms.memory import chroma

chromadbcl = chroma.ChromaClient()
# Embed and store a few example documents
chromadbcl.add_vectors(["This is a document", "BONSAIIIIIII", "the walking dead"])
# Retrieve the single closest match to the query
results = chromadbcl.search_vectors("zombie", limit=1)
print(results)
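For reference, `search_vectors` passes through Chroma's raw query result, which in recent chromadb versions is a dict of parallel lists (one inner list per query). A sketch of what the printed output roughly looks like; the ids and distances here are illustrative values, not real output:

# Illustrative shape only -- "id2" and 0.73 are hypothetical
{
    "ids": [["id2"]],
    "documents": [["the walking dead"]],
    "metadatas": [[None]],
    "distances": [[0.73]],
}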

@@ -18,13 +18,13 @@ class VectorDatabase(ABC):
     @abstractmethod
     def query(
-        self, vector: Dict[str, Any], num_results: int
+        self, text: str, num_results: int
     ) -> Dict[str, Any]:
         """
         Query the database for vectors similar to the given vector.
         Args:
-            vector (Dict[str, Any]): The vector to compare against.
+            text (str): The query text to compare against.
             num_results (int): The number of similar vectors to return.
         Returns:

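The new signature hands the backend raw query text instead of a precomputed vector, so each concrete store owns its embedding step. A minimal sketch of a conforming subclass; the `KeywordDB` name and its naive word-overlap scoring are hypothetical, purely to show the contract (the full method set is inferred from the stubs ChromaClient implements below):

from typing import Any, Dict

from swarms.memory.base_vector_db import VectorDatabase


class KeywordDB(VectorDatabase):  # hypothetical toy backend
    def __init__(self):
        self.docs: Dict[str, str] = {}

    def add(self, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        # Assumes the caller passes the raw text under a "text" key
        self.docs[metadata["id"]] = vector["text"]

    def query(self, text: str, num_results: int) -> Dict[str, Any]:
        # Naive relevance: rank stored docs by shared-word count with the query
        words = set(text.lower().split())
        ranked = sorted(
            self.docs.items(),
            key=lambda kv: len(words & set(kv[1].lower().split())),
            reverse=True,
        )
        return dict(ranked[:num_results])

    def delete(self, vector_id: str) -> None:
        self.docs.pop(vector_id, None)

    def update(self, vector_id: str, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        self.docs[vector_id] = vector["text"]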
@@ -0,0 +1,106 @@
from typing import Any, Dict, List

import chromadb
from chromadb.utils import embedding_functions
from httpx import RequestError

from swarms.memory.base_vector_db import VectorDatabase


class ChromaClient(VectorDatabase):
    def __init__(
        self,
        collection_name: str = "chromadb-collection",
        model_name: str = "BAAI/bge-small-en-v1.5",
    ):
        try:
            self.client = chromadb.Client()
            self.collection_name = collection_name
            self.model = None
            self.collection = None
            self._load_embedding_model(model_name)
            self._setup_collection()
        except RequestError as e:
            print(f"Error setting up ChromaClient: {e}")

    def _load_embedding_model(self, model_name: str):
        """
        Loads the sentence embedding model specified by the model name.

        Args:
            model_name (str): The name of the model to load for generating embeddings.
        """
        try:
            self.model = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=model_name
            )
        except Exception as e:
            print(f"Error loading embedding model: {e}")

    def _setup_collection(self):
        # Reuse the collection if it already exists; otherwise create it
        try:
            self.collection = self.client.get_collection(
                name=self.collection_name, embedding_function=self.model
            )
        except Exception as e:
            print(f"{e}. Creating new collection: {self.collection_name}")
            self.collection = self.client.create_collection(
                name=self.collection_name, embedding_function=self.model
            )

    def add_vectors(self, docs: List[str]):
        """
        Adds vector representations of documents to the Chroma collection.

        Args:
            docs (List[str]): A list of document strings to embed and store.

        Returns:
            None. Errors are logged rather than raised.
        """
        points = []
        ids = []
        for i, doc in enumerate(docs):
            try:
                points.append(doc)
                ids.append("id" + str(i))
            except Exception as e:
                print(f"Error processing document at index {i}: {e}")

        try:
            self.collection.add(
                documents=points,
                ids=ids,
            )
        except Exception as e:
            print(f"Error adding vectors: {e}")
            return None

    def search_vectors(self, query: str, limit: int = 2):
        """
        Searches the collection for vectors similar to the query vector.

        Args:
            query (str): The query string to be converted into a vector and used for searching.
            limit (int): The number of search results to return. Defaults to 2.

        Returns:
            dict or None: The Chroma query result if successful, otherwise None.
        """
        try:
            search_result = self.collection.query(
                query_texts=query,
                n_results=limit,
            )
            return search_result
        except Exception as e:
            print(f"Error searching vectors: {e}")
            return None

    # Stubs satisfying the VectorDatabase interface; not yet implemented
    def add(self, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
        pass

    def query(self, text: str, num_results: int) -> Dict[str, Any]:
        pass

    def delete(self, vector_id: str) -> None:
        pass

    def update(
        self, vector_id: str, vector: Dict[str, Any], metadata: Dict[str, Any]
    ) -> None:
        pass
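One design note: `chromadb.Client()` creates an ephemeral in-memory instance, so stored vectors vanish when the process exits. If durability matters, a persistent client could be swapped in. A minimal sketch, assuming chromadb's `PersistentClient` API (available in 0.4+) and a hypothetical local path:

import chromadb

# Hypothetical on-disk location; data survives process restarts
client = chromadb.PersistentClient(path="./chroma_data")
collection = client.get_or_create_collection(name="chromadb-collection")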

@@ -36,7 +36,7 @@ from swarms.utils.token_count_tiktoken import limit_tokens_from_string
 # Custom stopping condition
 def stop_when_repeats(response: str) -> bool:
     # Stop if the word stop appears in the response
-    return "Stop" in response.lower()
+    return "stop" in response.lower()
 # Parse done token
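The one-character change fixes a real bug: `response.lower()` can only ever contain lowercase letters, so the capitalized needle `"Stop"` never matched and the stopping condition was always False:

>>> "Stop" in "Please STOP now".lower()   # old check: always False
False
>>> "stop" in "Please STOP now".lower()   # fixed check
True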
@@ -489,6 +489,7 @@ class Agent:
             Interactive: {self.interactive}
             Dashboard: {self.dashboard}
             Dynamic Temperature: {self.dynamic_temperature_enabled}
+            Temperature: {self.llm.model_kwargs.get('temperature')}
             Autosave: {self.autosave}
             Saved State: {self.saved_state_path}
             Model Configuration: {model_config}
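Note that the added dashboard line assumes the wrapped LLM exposes a `model_kwargs` dict; a model class without that attribute would raise `AttributeError` while the dashboard renders. A more defensive variant (a sketch, not the committed code):

# Falls back cleanly when the LLM has no model_kwargs attribute
temperature = getattr(self.llm, "model_kwargs", {}).get("temperature", "n/a")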
@@ -638,6 +639,10 @@ class Agent:
                 AGENT_SYSTEM_PROMPT_3, response
             )
+        # Retrieving long term memory
+        if self.memory:
+            task = self.agent_memory_prompt(response, task)
+
         attempt = 0
         while attempt < self.retry_attempts:
             try:
@@ -758,6 +763,33 @@ class Agent:
         """
         return agent_history_prompt

+    def agent_memory_prompt(self, query, prompt):
+        """
+        Generate the agent's long term memory prompt by injecting
+        retrieved context into the given prompt.
+
+        Args:
+            query (str): The text used to query long term memory
+            prompt (str): The prompt to inject retrieved context into
+
+        Returns:
+            str: The context-injected prompt
+        """
+        context_injected_prompt = prompt
+        if self.memory:
+            ltr = self.memory.query(query)
+            context_injected_prompt = f"""{prompt}
+################ CONTEXT ####################
+{ltr}
+"""
+        return context_injected_prompt
+
     async def run_concurrent(self, tasks: List[str], **kwargs):
         """
         Run a batch of tasks concurrently and handle an infinite level of task inputs.
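Putting the pieces together, the intended flow appears to be: after each response, the agent queries long term memory with that response and appends the retrieved context beneath the next task prompt. A minimal wiring sketch; the `llm` object and the `memory=` keyword are assumptions (consistent with the `self.memory` checks above), not confirmed Agent API:

from swarms.memory.chroma import ChromaClient

# Hypothetical wiring; `llm` is whatever model wrapper Agent accepts
agent = Agent(llm=llm, memory=ChromaClient(collection_name="agent-ltm"))

# Inside the run loop the effect is, roughly:
#   task = agent.agent_memory_prompt(response, task)
# i.e. retrieved context is injected beneath the original prompt.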

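Since `run_concurrent` is a coroutine, callers drive it from an event loop. A minimal usage sketch, assuming it returns one output per submitted task (the task strings are placeholders):

import asyncio

# Hypothetical tasks; `agent` as constructed above
results = asyncio.run(
    agent.run_concurrent(["Summarize document A", "Summarize document B"])
)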