You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
141 lines
4.4 KiB
141 lines
4.4 KiB
6 months ago
|
import uuid
|
||
|
from typing import Any, List, Optional
|
||
|
|
||
|
from sqlalchemy import JSON, Column, String, create_engine
|
||
|
from sqlalchemy.dialects.postgresql import UUID
|
||
|
from sqlalchemy.ext.declarative import declarative_base
|
||
|
from sqlalchemy.orm import Session
|
||
|
from swarms.memory.base_vectordb import BaseVectorDatabase
|
||
|
|
||
|
|
||
|
class PostgresDB(BaseVectorDatabase):
    """
    A vector store backed by a single Postgres table, accessed via SQLAlchemy.

    Every instance builds its own declarative model mapped to ``table_name``.
    Database errors raised by ``add``, ``query`` and ``delete_vector`` are
    caught, printed to stdout and not re-raised.

    NOTE(review): nothing in this class creates the table (there is no
    ``metadata.create_all`` call), so it presumably has to exist already --
    confirm against the deployment/migration setup.

    Args:
        connection_string (str): The connection string for the Postgres database.
        table_name (str): The name of the table in the database.

    Attributes:
        engine: The SQLAlchemy engine for connecting to the database.
        table_name (str): The name of the table in the database.
        VectorModel: The SQLAlchemy model representing the vector table.

    """

    def __init__(
        self, connection_string: str, table_name: str, *args, **kwargs
    ):
        """
        Initializes a new instance of the PostgresDB class.

        Args:
            connection_string (str): The connection string for the Postgres database.
            table_name (str): The name of the table in the database.
            *args: Extra positional arguments forwarded to ``create_engine``.
            **kwargs: Extra keyword arguments forwarded to ``create_engine``.

        """
        self.engine = create_engine(connection_string, *args, **kwargs)
        self.table_name = table_name
        # Must run after table_name is set: the model reads it for
        # __tablename__.
        self.VectorModel = self._create_vector_model()

    @staticmethod
    def _coerce_uuid(vector_id: Any) -> Any:
        """
        Normalizes a string ID to ``uuid.UUID``; other values pass through.

        The ``id`` column is declared with ``as_uuid=True``, so rows loaded
        from the database carry ``uuid.UUID`` keys. Converting string IDs up
        front keeps primary-key lookups consistent with those loaded rows.

        Args:
            vector_id: The ID as given by the caller (``str``, ``uuid.UUID``
                or ``None``).

        Returns:
            A ``uuid.UUID`` when ``vector_id`` is a string, otherwise
            ``vector_id`` unchanged.

        Raises:
            ValueError: If ``vector_id`` is a string that is not a valid UUID.

        """
        if isinstance(vector_id, str):
            return uuid.UUID(vector_id)
        return vector_id

    def _create_vector_model(self):
        """
        Creates the SQLAlchemy model for the vector table.

        A fresh declarative base is created on every call, so each
        PostgresDB instance gets its own model class and metadata.

        Returns:
            The SQLAlchemy model representing the vector table.

        """
        Base = declarative_base()

        class VectorModel(Base):
            __tablename__ = self.table_name

            id = Column(
                UUID(as_uuid=True),
                primary_key=True,
                default=uuid.uuid4,
                unique=True,
                nullable=False,
            )
            vector = Column(
                String
            )  # Assuming vector is stored as a string
            namespace = Column(String)
            meta = Column(JSON)

        return VectorModel

    def add(
        self,
        vector: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
    ) -> None:
        """
        Adds or updates a vector in the database.

        The row is written with ``Session.merge``. Any exception (including
        an invalid ``vector_id`` string) is caught and printed, not raised.

        Args:
            vector (str): The vector to be added or updated.
            vector_id (str, optional): The ID of the vector. If not provided, a new ID will be generated.
            namespace (str, optional): The namespace of the vector.
            meta (dict, optional): Additional metadata associated with the vector.

        """
        try:
            with Session(self.engine) as session:
                obj = self.VectorModel(
                    # Convert string IDs so the primary key matches the
                    # uuid.UUID values the column yields on load.
                    id=self._coerce_uuid(vector_id),
                    vector=vector,
                    namespace=namespace,
                    meta=meta,
                )
                session.merge(obj)
                session.commit()
        except Exception as e:
            print(f"Error adding or updating vector: {e}")

    def query(
        self, query: Any, namespace: Optional[str] = None
    ) -> List[Any]:
        """
        Queries vectors from the database based on the given query and namespace.

        Args:
            query (Any): The query or condition to filter the vectors; it is
                passed as-is to SQLAlchemy's ``Query.filter``.
            namespace (str, optional): The namespace of the vectors to be queried.
                A falsy value (``None`` or ``""``) disables namespace filtering.

        Returns:
            List[Any]: A list of vectors that match the query and namespace,
            or an empty list if an error occurred (the error is printed).

        """
        try:
            with Session(self.engine) as session:
                q = session.query(self.VectorModel)
                if namespace:
                    q = q.filter_by(namespace=namespace)
                # Assuming 'query' is a condition or filter
                q = q.filter(query)
                return q.all()
        except Exception as e:
            print(f"Error querying vectors: {e}")
            return []

    def delete_vector(self, vector_id):
        """
        Deletes a vector from the database based on the given vector ID.

        Does nothing when no row has that ID. Any exception (including an
        invalid ``vector_id`` string) is caught and printed, not raised.

        Args:
            vector_id: The ID of the vector to be deleted (``str`` or
                ``uuid.UUID``).

        """
        try:
            with Session(self.engine) as session:
                obj = session.get(
                    self.VectorModel, self._coerce_uuid(vector_id)
                )
                if obj is not None:
                    session.delete(obj)
                    session.commit()
        except Exception as e:
            print(f"Error deleting vector: {e}")