diff --git a/playground/memory/chroma_usage.py b/playground/memory/chroma_usage.py
new file mode 100644
index 00000000..21ae475a
--- /dev/null
+++ b/playground/memory/chroma_usage.py
@@ -0,0 +1,10 @@
+from swarms.memory import chroma
+
+chromadbcl = chroma.ChromaClient()
+
+chromadbcl.add_vectors(["This is a document", "BONSAIIIIIII", "the walking dead"])
+
+results = chromadbcl.search_vectors("zombie", limit=1)
+
+print(results)
+
diff --git a/swarms/memory/base_vector_db.py b/swarms/memory/base_vector_db.py
index fc58bf36..991bc8b5 100644
--- a/swarms/memory/base_vector_db.py
+++ b/swarms/memory/base_vector_db.py
@@ -18,13 +18,13 @@ class VectorDatabase(ABC):
 
     @abstractmethod
     def query(
-        self, vector: Dict[str, Any], num_results: int
+        self, text: str, num_results: int
     ) -> Dict[str, Any]:
         """
         Query the database for vectors similar to the given vector.
 
         Args:
-            vector (Dict[str, Any]): The vector to compare against.
+            text (str): The query text to compare against.
             num_results (int): The number of similar vectors to return.
 
         Returns:
diff --git a/swarms/memory/chroma.py b/swarms/memory/chroma.py
index e69de29b..6fedc6f4 100644
--- a/swarms/memory/chroma.py
+++ b/swarms/memory/chroma.py
@@ -0,0 +1,106 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+from chromadb.utils import embedding_functions
+from httpx import RequestError
+import chromadb
+
+from swarms.memory.base_vector_db import VectorDatabase
+
+
+class ChromaClient(VectorDatabase):
+    def __init__(
+        self,
+        collection_name: str = "chromadb-collection",
+        model_name: str = "BAAI/bge-small-en-v1.5",
+    ):
+        try:
+            self.client = chromadb.Client()
+            self.collection_name = collection_name
+            self.model = None
+            self.collection = None
+            self._load_embedding_model(model_name)
+            self._setup_collection()
+        except RequestError as e:
+            print(f"Error setting up ChromaClient: {e}")
+
+    def _load_embedding_model(self, model_name: str):
+        """
+        Loads the sentence embedding model specified by the model name.
+
+        Args:
+            model_name (str): The name of the model to load for generating embeddings.
+        """
+        try:
+            self.model = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
+        except Exception as e:
+            print(f"Error loading embedding model: {e}")
+
+    def _setup_collection(self):
+        try:
+            self.collection = self.client.get_collection(name=self.collection_name, embedding_function=self.model)
+        except Exception as e:
+            # Collection does not exist yet; create it instead.
+            print(f"{e}. Creating new collection: {self.collection_name}")
+            self.collection = self.client.create_collection(name=self.collection_name, embedding_function=self.model)
+
+    def add_vectors(self, docs: List[str]):
+        """
+        Adds documents to the Chroma collection; embeddings are computed
+        by the collection's embedding function.
+
+        Args:
+            docs (List[str]): A list of document strings to add.
+
+        Returns:
+            None: Errors are printed rather than raised.
+        """
+        points = []
+        ids = []
+        for i, doc in enumerate(docs):
+            try:
+                points.append(doc)
+                ids.append("id" + str(i))
+            except Exception as e:
+                print(f"Error processing document at index {i}: {e}")
+
+        try:
+            self.collection.add(
+                documents=points,
+                ids=ids
+            )
+        except Exception as e:
+            print(f"Error adding vectors: {e}")
+            return None
+
+    def search_vectors(self, query: str, limit: int = 2):
+        """
+        Searches the collection for documents similar to the query text.
+
+        Args:
+            query (str): The query string to embed and use for searching.
+            limit (int): The number of search results to return. Defaults to 2.
+
+        Returns:
+            dict or None: The search results if successful, otherwise None.
+        """
+        try:
+            search_result = self.collection.query(
+                query_texts=query,
+                n_results=limit,
+            )
+            return search_result
+        except Exception as e:
+            print(f"Error searching vectors: {e}")
+            return None
+
+    def add(self, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
+        pass
+
+    def query(self, text: str, num_results: int = 2) -> Dict[str, Any]:
+        # Delegate to search_vectors so callers of the VectorDatabase
+        # interface (e.g. Agent.agent_memory_prompt) get real results
+        # instead of None from a bare stub.
+        return self.search_vectors(text, num_results)
+
+    def delete(self, vector_id: str) -> None:
+        pass
+
+    def update(self, vector_id: str, vector: Dict[str, Any], metadata: Dict[str, Any]) -> None:
+        pass
diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py
index fbfd4620..c2846a24 100644
--- a/swarms/structs/agent.py
+++ b/swarms/structs/agent.py
@@ -36,7 +36,7 @@ from swarms.utils.token_count_tiktoken import limit_tokens_from_string
 # Custom stopping condition
 def stop_when_repeats(response: str) -> bool:
     # Stop if the word stop appears in the response
-    return "Stop" in response.lower()
+    return "stop" in response.lower()
 
 
 # Parse done token
@@ -489,6 +489,7 @@ class Agent:
             Interactive: {self.interactive}
             Dashboard: {self.dashboard}
             Dynamic Temperature: {self.dynamic_temperature_enabled}
+            Temperature: {self.llm.model_kwargs.get('temperature')}
             Autosave: {self.autosave}
             Saved State: {self.saved_state_path}
             Model Configuration: {model_config}
@@ -638,6 +639,10 @@ class Agent:
                 AGENT_SYSTEM_PROMPT_3, response
             )
 
+        # Retrieve long-term memory
+        if self.memory:
+            task = self.agent_memory_prompt(response, task)
+
         attempt = 0
         while attempt < self.retry_attempts:
             try:
@@ -758,6 +763,33 @@ class Agent:
         """
         return agent_history_prompt
 
+    def agent_memory_prompt(self, query, prompt):
+        """
+        Generate the agent's long-term memory prompt by injecting
+        retrieved context into the given prompt.
+
+        Args:
+            query (str): The text used to query long-term memory.
+            prompt (str): The prompt to inject the retrieved context into.
+
+        Returns:
+            str: The context-injected prompt.
+        """
+        context_injected_prompt = prompt
+        if self.memory:
+            ltr = self.memory.query(query)
+
+            context_injected_prompt = f"""{prompt}
+            ################ CONTEXT ####################
+            {ltr}
+            """
+
+        return context_injected_prompt
+
     async def run_concurrent(self, tasks: List[str], **kwargs):
         """
         Run a batch of tasks concurrently and handle an infinite level of task inputs.
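
A minimal end-to-end sketch of how these pieces are meant to compose (not part of the diff). It assumes the Agent constructor accepts a `memory` argument satisfying VectorDatabase, as the `if self.memory:` check above implies, and that `swarms.models.OpenAIChat` is available as the LLM wrapper; both are assumptions rather than guarantees of this changeset.

    # Hypothetical usage sketch -- wiring ChromaClient into an Agent as
    # long-term memory. Assumes Agent(...) accepts `memory=` and that
    # swarms.models.OpenAIChat exists (reads OPENAI_API_KEY from the env).
    from swarms.memory.chroma import ChromaClient
    from swarms.models import OpenAIChat
    from swarms.structs.agent import Agent

    # Long-term memory backed by the new ChromaClient.
    memory = ChromaClient(collection_name="agent-long-term-memory")
    memory.add_vectors([
        "The project stores embeddings in ChromaDB.",
        "Answers should cite retrieved context when available.",
    ])

    llm = OpenAIChat()
    agent = Agent(llm=llm, memory=memory, max_loops=1)

    # run() now invokes agent_memory_prompt(), which queries memory and
    # injects the retrieved documents under the "CONTEXT" header.
    print(agent.run("Where does the project store its embeddings?"))

With the `query()` stub above delegating to `search_vectors()`, the injected context is the raw Chroma result for the query; callers that need only the matched documents may prefer to format that result before injection.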