You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/examples/memory/pg.py

141 lines
4.4 KiB

6 months ago
import uuid
from typing import Any, List, Optional
from sqlalchemy import JSON, Column, String, create_engine
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from swarms.memory.base_vectordb import BaseVectorDatabase
class PostgresDB(BaseVectorDatabase):
    """
    A vector store backed by a single Postgres table, accessed via SQLAlchemy.

    Each row holds a UUID primary key, the vector itself (stored as a
    string), an optional namespace and optional JSON metadata.

    Every public operation opens its own short-lived ``Session``. Database
    errors are caught and printed rather than raised, so callers cannot
    detect a failure through an exception.

    NOTE(review): nothing in this class issues ``CREATE TABLE``; the table
    is presumably expected to exist already -- confirm against the
    deployment.

    Args:
        connection_string (str): The connection string for the Postgres database.
        table_name (str): The name of the table in the database.

    Attributes:
        engine: The SQLAlchemy engine for connecting to the database.
        table_name (str): The name of the table in the database.
        VectorModel: The SQLAlchemy model representing the vector table.
    """

    def __init__(
        self, connection_string: str, table_name: str, *args, **kwargs
    ):
        """
        Initializes a new instance of the PostgresDB class.

        Args:
            connection_string (str): The connection string for the Postgres database.
            table_name (str): The name of the table in the database.
            *args: Extra positional arguments forwarded to ``create_engine``.
            **kwargs: Extra keyword arguments forwarded to ``create_engine``.
        """
        self.engine = create_engine(connection_string, *args, **kwargs)
        self.table_name = table_name
        self.VectorModel = self._create_vector_model()

    @staticmethod
    def _coerce_id(vector_id: Any) -> Any:
        """
        Normalizes a caller-supplied ID to the type used by the ``id`` column.

        The ``id`` column is declared ``UUID(as_uuid=True)``, so IDs are
        handled as ``uuid.UUID`` objects; string IDs are converted here so
        that inserts and lookups use one consistent key type.

        Args:
            vector_id: ``None``, a ``uuid.UUID``, or a value whose string
                form is a valid UUID.

        Returns:
            ``None`` or a ``uuid.UUID``.

        Raises:
            ValueError: If the value cannot be parsed as a UUID.
        """
        if vector_id is None or isinstance(vector_id, uuid.UUID):
            return vector_id
        return uuid.UUID(str(vector_id))

    def _create_vector_model(self):
        """
        Creates the SQLAlchemy model for the vector table.

        A fresh declarative base is created on every call, so each
        PostgresDB instance gets its own model class and metadata.

        Returns:
            The SQLAlchemy model representing the vector table.
        """
        Base = declarative_base()
        table_name = self.table_name

        class VectorModel(Base):
            __tablename__ = table_name

            id = Column(
                UUID(as_uuid=True),
                primary_key=True,
                default=uuid.uuid4,
                unique=True,
                nullable=False,
            )
            vector = Column(
                String
            )  # Assuming vector is stored as a string
            namespace = Column(String)
            meta = Column(JSON)

        return VectorModel

    def add(
        self,
        vector: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
    ) -> None:
        """
        Adds or updates a vector in the database.

        The row is written with ``Session.merge``, so an existing row with
        the same ID is updated and otherwise a new row is inserted. Any
        exception (including an unparsable ``vector_id``) is caught and
        printed; nothing is raised to the caller.

        Args:
            vector (str): The vector to be added or updated.
            vector_id (str, optional): The ID of the vector, as a UUID or a
                UUID string. If not provided, a new ID will be generated.
            namespace (str, optional): The namespace of the vector.
            meta (dict, optional): Additional metadata associated with the vector.
        """
        try:
            # Coerce inside the try so a malformed ID is reported the same
            # way as a database failure.
            row_id = self._coerce_id(vector_id)
            with Session(self.engine) as session:
                obj = self.VectorModel(
                    id=row_id,
                    vector=vector,
                    namespace=namespace,
                    meta=meta,
                )
                session.merge(obj)
                session.commit()
        except Exception as e:
            print(f"Error adding or updating vector: {e}")

    def query(
        self, query: Any, namespace: Optional[str] = None
    ) -> List[Any]:
        """
        Queries vectors from the database based on the given query and namespace.

        Args:
            query (Any): The condition passed to ``Query.filter`` to select
                the vectors.
            namespace (str, optional): The namespace of the vectors to be
                queried. A falsy value (``None`` or ``""``) applies no
                namespace filter.

        Returns:
            List[Any]: The matching ``VectorModel`` rows, returned after the
            session has been closed. An empty list is returned if any
            exception occurs (the error is printed).
        """
        try:
            with Session(self.engine) as session:
                q = session.query(self.VectorModel)
                if namespace:
                    q = q.filter_by(namespace=namespace)
                # Assuming 'query' is a condition or filter
                q = q.filter(query)
                return q.all()
        except Exception as e:
            print(f"Error querying vectors: {e}")
            return []

    def delete_vector(self, vector_id: Any) -> None:
        """
        Deletes a vector from the database based on the given vector ID.

        Does nothing if no row with that ID exists. Any exception
        (including an unparsable ``vector_id``) is caught and printed;
        nothing is raised to the caller.

        Args:
            vector_id: The ID of the vector to be deleted, as a UUID or a
                UUID string.
        """
        try:
            row_id = self._coerce_id(vector_id)
            with Session(self.engine) as session:
                obj = session.get(self.VectorModel, row_id)
                if obj:
                    session.delete(obj)
                    session.commit()
        except Exception as e:
            print(f"Error deleting vector: {e}")