diff --git a/swarms/memory/__init__.py b/swarms/memory/__init__.py index d2eed0d5..1d52d718 100644 --- a/swarms/memory/__init__.py +++ b/swarms/memory/__init__.py @@ -1,11 +1,11 @@ -from swarms.memory.base_vectordb import VectorDatabase +from swarms.memory.base_vectordb import AbstractDatabase from swarms.memory.short_term_memory import ShortTermMemory from swarms.memory.sqlite import SQLiteDB from swarms.memory.weaviate_db import WeaviateDB from swarms.memory.visual_memory import VisualShortTermMemory __all__ = [ - "VectorDatabase", + "AbstractDatabase", "ShortTermMemory", "SQLiteDB", "WeaviateDB", diff --git a/swarms/memory/base_db.py b/swarms/memory/base_db.py index 0501def7..bb0a2961 100644 --- a/swarms/memory/base_db.py +++ b/swarms/memory/base_db.py @@ -156,4 +156,4 @@ class AbstractDatabase(ABC): """ - pass + pass \ No newline at end of file diff --git a/swarms/memory/base_vectordatabase.py b/swarms/memory/base_vectordatabase.py new file mode 100644 index 00000000..734c872a --- /dev/null +++ b/swarms/memory/base_vectordatabase.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod + + +class AbstractVectorDatabase(ABC): + """ + Abstract base class for a database. + + This class defines the interface for interacting with a database. + Subclasses must implement the abstract methods to provide the + specific implementation details for connecting to a database, + executing queries, and performing CRUD operations. + + """ + + @abstractmethod + def connect(self): + """ + Connect to the database. + + This method establishes a connection to the database. + + """ + + pass + + @abstractmethod + def close(self): + """ + Close the database connection. + + This method closes the connection to the database. + + """ + + pass + + @abstractmethod + def query(self, query: str): + """ + Execute a database query. + + This method executes the given query on the database. + + Parameters: + query (str): The query to be executed. + + """ + + pass + + @abstractmethod + def fetch_all(self): + """ + Fetch all rows from the result set. + + This method retrieves all rows from the result set of a query. + + Returns: + list: A list of dictionaries representing the rows. + + """ + + pass + + @abstractmethod + def fetch_one(self): + """ + Fetch one row from the result set. + + This method retrieves one row from the result set of a query. + + Returns: + dict: A dictionary representing the row. + + """ + + pass + + @abstractmethod + def add(self, doc: str): + """ + Add a new record to the database. + + This method adds a new record to the specified table in the database. + + Parameters: + table (str): The name of the table. + data (dict): A dictionary representing the data to be added. + + """ + + pass + + + @abstractmethod + def get(self, query: str): + """ + Get a record from the database. + + This method retrieves a record from the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be retrieved. + + Returns: + dict: A dictionary representing the retrieved record. + + """ + + pass + + @abstractmethod + def update(self, doc): + """ + Update a record in the database. + + This method updates a record in the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be updated. + data (dict): A dictionary representing the updated data. + + """ + + pass + + @abstractmethod + def delete(self, message): + """ + Delete a record from the database. + + This method deletes a record from the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be deleted. + + """ + + pass \ No newline at end of file diff --git a/swarms/memory/base_vectordb.py b/swarms/memory/base_vectordb.py deleted file mode 100644 index 841c6147..00000000 --- a/swarms/memory/base_vectordb.py +++ /dev/null @@ -1,58 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict - - -class VectorDatabase(ABC): - @abstractmethod - def add( - self, vector: Dict[str, Any], metadata: Dict[str, Any] - ) -> None: - """ - add a vector into the database. - - Args: - vector (Dict[str, Any]): The vector to add. - metadata (Dict[str, Any]): Metadata associated with the vector. - """ - pass - - @abstractmethod - def query(self, text: str, num_results: int) -> Dict[str, Any]: - """ - Query the database for vectors similar to the given vector. - - Args: - text (Dict[str, Any]): The vector to compare against. - num_results (int): The number of similar vectors to return. - - Returns: - Dict[str, Any]: The most similar vectors and their associated metadata. - """ - pass - - @abstractmethod - def delete(self, vector_id: str) -> None: - """ - Delete a vector from the database. - - Args: - vector_id (str): The ID of the vector to delete. - """ - pass - - @abstractmethod - def update( - self, - vector_id: str, - vector: Dict[str, Any], - metadata: Dict[str, Any], - ) -> None: - """ - Update a vector in the database. - - Args: - vector_id (str): The ID of the vector to update. - vector (Dict[str, Any]): The new vector. - metadata (Dict[str, Any]): The new metadata. - """ - pass diff --git a/swarms/memory/pinecone.py b/swarms/memory/pinecone.py index 164cb334..bf073d3e 100644 --- a/swarms/memory/pinecone.py +++ b/swarms/memory/pinecone.py @@ -1,12 +1,12 @@ from typing import Optional -from swarms.memory.base_vectordb import VectorDatabase +from swarms.memory.base_vectordb import AbstractDatabase import pinecone from attr import define, field from swarms.utils.hash import str_to_hash @define -class PineconeDB(VectorDatabase): +class PineconeDB(AbstractDatabase): """ PineconeDB is a vector storage driver that uses Pinecone as the underlying storage engine. diff --git a/swarms/memory/sqlite.py b/swarms/memory/sqlite.py index eed4ee2c..2c4f2740 100644 --- a/swarms/memory/sqlite.py +++ b/swarms/memory/sqlite.py @@ -1,5 +1,5 @@ from typing import List, Tuple, Any, Optional -from swarms.memory.base_vectordb import VectorDatabase +from swarms.memory.base_vectordb import AbstractDatabase try: import sqlite3 @@ -9,7 +9,7 @@ except ImportError: ) -class SQLiteDB(VectorDatabase): +class SQLiteDB(AbstractDatabase): """ A reusable class for SQLite database operations with methods for adding, deleting, updating, and querying data. diff --git a/swarms/memory/weaviate_db.py b/swarms/memory/weaviate_db.py index 0c0b09a2..fec1199e 100644 --- a/swarms/memory/weaviate_db.py +++ b/swarms/memory/weaviate_db.py @@ -4,7 +4,7 @@ Weaviate API Client from typing import Any, Dict, List, Optional -from swarms.memory.base_vectordb import VectorDatabase +from swarms.memory.base_vectordb import AbstractDatabase try: import weaviate @@ -12,7 +12,7 @@ except ImportError: print("pip install weaviate-client") -class WeaviateDB(VectorDatabase): +class WeaviateDB(AbstractDatabase): """ Weaviate API Client diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 3b298e3a..bd245ba7 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple from termcolor import colored -from swarms.memory.base_vectordb import VectorDatabase +from swarms.memory.base_vectordb import AbstractDatabase from swarms.prompts.agent_system_prompts import ( AGENT_SYSTEM_PROMPT_3, ) @@ -83,7 +83,7 @@ class Agent: pdf_path (str): The path to the pdf list_of_pdf (str): The list of pdf tokenizer (Any): The tokenizer - memory (VectorDatabase): The memory + memory (AbstractDatabase): The memory preset_stopping_token (bool): Enable preset stopping token traceback (Any): The traceback traceback_handlers (Any): The traceback handlers @@ -168,7 +168,7 @@ class Agent: pdf_path: Optional[str] = None, list_of_pdf: Optional[str] = None, tokenizer: Optional[Any] = None, - long_term_memory: Optional[VectorDatabase] = None, + long_term_memory: Optional[AbstractDatabase] = None, preset_stopping_token: Optional[bool] = False, traceback: Any = None, traceback_handlers: Any = None, @@ -657,7 +657,7 @@ class Agent: """ return agent_history_prompt - def long_term_memory_prompt(self, query: str, prompt: str): + def long_term_memory_prompt(self, query: str): """ Generate the agent long term memory prompt @@ -671,7 +671,7 @@ class Agent: ltr = self.long_term_memory.query(query) context = f""" - {prompt} + {query} ####### Long Term Memory ################ {ltr} """ diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py index 441ff3d9..9a2224a4 100644 --- a/swarms/structs/conversation.py +++ b/swarms/structs/conversation.py @@ -60,6 +60,7 @@ class Conversation(BaseStructure): def __init__( self, + system_prompt: str, time_enabled: bool = False, database: AbstractDatabase = None, autosave: bool = False, @@ -68,11 +69,16 @@ class Conversation(BaseStructure): **kwargs, ): super().__init__() + self.system_prompt = system_prompt self.time_enabled = time_enabled self.database = database self.autosave = autosave self.save_filepath = save_filepath self.conversation_history = [] + + # If system prompt is not None, add it to the conversation history + if self.system_prompt: + self.add("system", self.system_prompt) def add(self, role: str, content: str, *args, **kwargs): """Add a message to the conversation history diff --git a/swarms/structs/multi_agent_rag.py b/swarms/structs/multi_agent_rag.py new file mode 100644 index 00000000..7b51332e --- /dev/null +++ b/swarms/structs/multi_agent_rag.py @@ -0,0 +1,97 @@ +from dataclasses import dataclass +from typing import List, Optional + +from swarms.memory.base_vectordatabase import AbstractVectorDatabase +from swarms.structs.agent import Agent + + +@dataclass +class MultiAgentRag: + """ + Represents a multi-agent RAG (Relational Agent Graph) structure. + + Attributes: + agents (List[Agent]): List of agents in the multi-agent RAG. + db (AbstractVectorDatabase): Database used for querying. + verbose (bool): Flag indicating whether to print verbose output. + """ + agents: List[Agent] + db: AbstractVectorDatabase + verbose: bool = False + + + def query_database(self, query: str): + """ + Queries the database using the given query string. + + Args: + query (str): The query string. + + Returns: + List: The list of results from the database. + """ + results = [] + for agent in self.agents: + agent_results = agent.long_term_memory_prompt(query) + results.extend(agent_results) + return results + + + def get_agent_by_id(self, agent_id) -> Optional[Agent]: + """ + Retrieves an agent from the multi-agent RAG by its ID. + + Args: + agent_id: The ID of the agent to retrieve. + + Returns: + Agent or None: The agent with the specified ID, or None if not found. + """ + for agent in self.agents: + if agent.agent_id == agent_id: + return agent + return None + + def add_message( + self, + sender: Agent, + message: str, + *args, + **kwargs + ): + """ + Adds a message to the database. + + Args: + sender (Agent): The agent sending the message. + message (str): The message to add. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + int: The ID of the added message. + """ + doc = f"{sender.ai_name}: {message}" + + return self.db.add(doc) + + def query( + self, + message: str, + *args, + **kwargs + ): + """ + Queries the database using the given message. + + Args: + message (str): The message to query. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + List: The list of results from the database. + """ + return self.db.query(message) + +