[FEAT][Conversation]

pull/362/head^2
Kye 1 year ago
parent 8113d6ddbc
commit d5c0ca0128

@ -1,11 +1,11 @@
from swarms.memory.base_vectordb import VectorDatabase
from swarms.memory.base_vectordb import AbstractDatabase
from swarms.memory.short_term_memory import ShortTermMemory
from swarms.memory.sqlite import SQLiteDB
from swarms.memory.weaviate_db import WeaviateDB
from swarms.memory.visual_memory import VisualShortTermMemory
__all__ = [
"VectorDatabase",
"AbstractDatabase",
"ShortTermMemory",
"SQLiteDB",
"WeaviateDB",

@ -156,4 +156,4 @@ class AbstractDatabase(ABC):
"""
pass
pass

@ -0,0 +1,142 @@
from abc import ABC, abstractmethod
class AbstractVectorDatabase(ABC):
"""
Abstract base class for a database.
This class defines the interface for interacting with a database.
Subclasses must implement the abstract methods to provide the
specific implementation details for connecting to a database,
executing queries, and performing CRUD operations.
"""
@abstractmethod
def connect(self):
"""
Connect to the database.
This method establishes a connection to the database.
"""
pass
@abstractmethod
def close(self):
"""
Close the database connection.
This method closes the connection to the database.
"""
pass
@abstractmethod
def query(self, query: str):
"""
Execute a database query.
This method executes the given query on the database.
Parameters:
query (str): The query to be executed.
"""
pass
@abstractmethod
def fetch_all(self):
"""
Fetch all rows from the result set.
This method retrieves all rows from the result set of a query.
Returns:
list: A list of dictionaries representing the rows.
"""
pass
@abstractmethod
def fetch_one(self):
"""
Fetch one row from the result set.
This method retrieves one row from the result set of a query.
Returns:
dict: A dictionary representing the row.
"""
pass
@abstractmethod
def add(self, doc: str):
"""
Add a new record to the database.
This method adds a new record to the specified table in the database.
Parameters:
table (str): The name of the table.
data (dict): A dictionary representing the data to be added.
"""
pass
@abstractmethod
def get(self, query: str):
"""
Get a record from the database.
This method retrieves a record from the specified table in the database based on the given ID.
Parameters:
table (str): The name of the table.
id (int): The ID of the record to be retrieved.
Returns:
dict: A dictionary representing the retrieved record.
"""
pass
@abstractmethod
def update(self, doc):
"""
Update a record in the database.
This method updates a record in the specified table in the database based on the given ID.
Parameters:
table (str): The name of the table.
id (int): The ID of the record to be updated.
data (dict): A dictionary representing the updated data.
"""
pass
@abstractmethod
def delete(self, message):
"""
Delete a record from the database.
This method deletes a record from the specified table in the database based on the given ID.
Parameters:
table (str): The name of the table.
id (int): The ID of the record to be deleted.
"""
pass

@ -1,58 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
class VectorDatabase(ABC):
@abstractmethod
def add(
self, vector: Dict[str, Any], metadata: Dict[str, Any]
) -> None:
"""
add a vector into the database.
Args:
vector (Dict[str, Any]): The vector to add.
metadata (Dict[str, Any]): Metadata associated with the vector.
"""
pass
@abstractmethod
def query(self, text: str, num_results: int) -> Dict[str, Any]:
"""
Query the database for vectors similar to the given vector.
Args:
text (Dict[str, Any]): The vector to compare against.
num_results (int): The number of similar vectors to return.
Returns:
Dict[str, Any]: The most similar vectors and their associated metadata.
"""
pass
@abstractmethod
def delete(self, vector_id: str) -> None:
"""
Delete a vector from the database.
Args:
vector_id (str): The ID of the vector to delete.
"""
pass
@abstractmethod
def update(
self,
vector_id: str,
vector: Dict[str, Any],
metadata: Dict[str, Any],
) -> None:
"""
Update a vector in the database.
Args:
vector_id (str): The ID of the vector to update.
vector (Dict[str, Any]): The new vector.
metadata (Dict[str, Any]): The new metadata.
"""
pass

@ -1,12 +1,12 @@
from typing import Optional
from swarms.memory.base_vectordb import VectorDatabase
from swarms.memory.base_vectordb import AbstractDatabase
import pinecone
from attr import define, field
from swarms.utils.hash import str_to_hash
@define
class PineconeDB(VectorDatabase):
class PineconeDB(AbstractDatabase):
"""
PineconeDB is a vector storage driver that uses Pinecone as the underlying storage engine.

@ -1,5 +1,5 @@
from typing import List, Tuple, Any, Optional
from swarms.memory.base_vectordb import VectorDatabase
from swarms.memory.base_vectordb import AbstractDatabase
try:
import sqlite3
@ -9,7 +9,7 @@ except ImportError:
)
class SQLiteDB(VectorDatabase):
class SQLiteDB(AbstractDatabase):
"""
A reusable class for SQLite database operations with methods for adding,
deleting, updating, and querying data.

@ -4,7 +4,7 @@ Weaviate API Client
from typing import Any, Dict, List, Optional
from swarms.memory.base_vectordb import VectorDatabase
from swarms.memory.base_vectordb import AbstractDatabase
try:
import weaviate
@ -12,7 +12,7 @@ except ImportError:
print("pip install weaviate-client")
class WeaviateDB(VectorDatabase):
class WeaviateDB(AbstractDatabase):
"""
Weaviate API Client

@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from termcolor import colored
from swarms.memory.base_vectordb import VectorDatabase
from swarms.memory.base_vectordb import AbstractDatabase
from swarms.prompts.agent_system_prompts import (
AGENT_SYSTEM_PROMPT_3,
)
@ -83,7 +83,7 @@ class Agent:
pdf_path (str): The path to the pdf
list_of_pdf (str): The list of pdf
tokenizer (Any): The tokenizer
memory (VectorDatabase): The memory
memory (AbstractDatabase): The memory
preset_stopping_token (bool): Enable preset stopping token
traceback (Any): The traceback
traceback_handlers (Any): The traceback handlers
@ -168,7 +168,7 @@ class Agent:
pdf_path: Optional[str] = None,
list_of_pdf: Optional[str] = None,
tokenizer: Optional[Any] = None,
long_term_memory: Optional[VectorDatabase] = None,
long_term_memory: Optional[AbstractDatabase] = None,
preset_stopping_token: Optional[bool] = False,
traceback: Any = None,
traceback_handlers: Any = None,
@ -657,7 +657,7 @@ class Agent:
"""
return agent_history_prompt
def long_term_memory_prompt(self, query: str, prompt: str):
def long_term_memory_prompt(self, query: str):
"""
Generate the agent long term memory prompt
@ -671,7 +671,7 @@ class Agent:
ltr = self.long_term_memory.query(query)
context = f"""
{prompt}
{query}
####### Long Term Memory ################
{ltr}
"""

@ -60,6 +60,7 @@ class Conversation(BaseStructure):
def __init__(
self,
system_prompt: str,
time_enabled: bool = False,
database: AbstractDatabase = None,
autosave: bool = False,
@ -68,11 +69,16 @@ class Conversation(BaseStructure):
**kwargs,
):
super().__init__()
self.system_prompt = system_prompt
self.time_enabled = time_enabled
self.database = database
self.autosave = autosave
self.save_filepath = save_filepath
self.conversation_history = []
# If system prompt is not None, add it to the conversation history
if self.system_prompt:
self.add("system", self.system_prompt)
def add(self, role: str, content: str, *args, **kwargs):
"""Add a message to the conversation history

@ -0,0 +1,97 @@
from dataclasses import dataclass
from typing import List, Optional
from swarms.memory.base_vectordatabase import AbstractVectorDatabase
from swarms.structs.agent import Agent
@dataclass
class MultiAgentRag:
"""
Represents a multi-agent RAG (Relational Agent Graph) structure.
Attributes:
agents (List[Agent]): List of agents in the multi-agent RAG.
db (AbstractVectorDatabase): Database used for querying.
verbose (bool): Flag indicating whether to print verbose output.
"""
agents: List[Agent]
db: AbstractVectorDatabase
verbose: bool = False
def query_database(self, query: str):
"""
Queries the database using the given query string.
Args:
query (str): The query string.
Returns:
List: The list of results from the database.
"""
results = []
for agent in self.agents:
agent_results = agent.long_term_memory_prompt(query)
results.extend(agent_results)
return results
def get_agent_by_id(self, agent_id) -> Optional[Agent]:
"""
Retrieves an agent from the multi-agent RAG by its ID.
Args:
agent_id: The ID of the agent to retrieve.
Returns:
Agent or None: The agent with the specified ID, or None if not found.
"""
for agent in self.agents:
if agent.agent_id == agent_id:
return agent
return None
def add_message(
self,
sender: Agent,
message: str,
*args,
**kwargs
):
"""
Adds a message to the database.
Args:
sender (Agent): The agent sending the message.
message (str): The message to add.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
int: The ID of the added message.
"""
doc = f"{sender.ai_name}: {message}"
return self.db.add(doc)
def query(
self,
message: str,
*args,
**kwargs
):
"""
Queries the database using the given message.
Args:
message (str): The message to query.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
List: The list of results from the database.
"""
return self.db.query(message)
Loading…
Cancel
Save