parent 3a20b51f4c
commit c3ccc69725
				@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
from concurrent import futures
from dataclasses import dataclass
from typing import Optional, Any
from attr import define, field, Factory
from swarms.utils.futures import execute_futures_dict
from griptape.artifacts import TextArtifact


@define
class BaseVectorStoreDriver(ABC):
    """Abstract base class for vector store drivers.

    Concrete drivers implement upsert_vector, load_entry, load_entries, and query;
    the text-oriented helpers below build on those primitives.
    """
    DEFAULT_QUERY_COUNT = 5

    @dataclass
    class QueryResult:
        id: str
        vector: list[float]
        score: float
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    @dataclass
    class Entry:
        id: str
        vector: list[float]
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    embedding_driver: Any
    futures_executor: futures.Executor = field(
        default=Factory(lambda: futures.ThreadPoolExecutor()),
        kw_only=True
    )

    def upsert_text_artifacts(
            self,
            artifacts: dict[str, list[TextArtifact]],
            meta: Optional[dict] = None,
            **kwargs
    ) -> None:
        # Fan the upserts out across the futures executor, one future per artifact.
        execute_futures_dict({
            namespace:
                self.futures_executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
            for namespace, artifact_list in artifacts.items() for a in artifact_list
        })

    def upsert_text_artifact(
            self,
            artifact: TextArtifact,
            namespace: Optional[str] = None,
            meta: Optional[dict] = None,
            **kwargs
    ) -> str:
        if not meta:
            meta = {}

        meta["artifact"] = artifact.to_json()

        if artifact.embedding:
            vector = artifact.embedding
        else:
            vector = artifact.generate_embedding(self.embedding_driver)

        return self.upsert_vector(
            vector,
            vector_id=artifact.id,
            namespace=namespace,
            meta=meta,
            **kwargs
        )

    def upsert_text(
            self,
            string: str,
            vector_id: Optional[str] = None,
            namespace: Optional[str] = None,
            meta: Optional[dict] = None,
            **kwargs
    ) -> str:
        return self.upsert_vector(
            self.embedding_driver.embed_string(string),
            vector_id=vector_id,
            namespace=namespace,
            meta=meta if meta else {},
            **kwargs
        )

    @abstractmethod
    def upsert_vector(
            self,
            vector: list[float],
            vector_id: Optional[str] = None,
            namespace: Optional[str] = None,
            meta: Optional[dict] = None,
            **kwargs
    ) -> str:
        ...

    @abstractmethod
    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry:
        ...

    @abstractmethod
    def load_entries(self, namespace: Optional[str] = None) -> list[Entry]:
        ...

    @abstractmethod
    def query(
            self,
            query: str,
            count: Optional[int] = None,
            namespace: Optional[str] = None,
            include_vectors: bool = False,
            **kwargs
    ) -> list[QueryResult]:
        ...
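
# Illustrative sketch: a minimal in-memory subclass showing how the abstract API above
# fits together. InMemoryVectorStoreDriver and LengthEmbeddingDriver are hypothetical
# names invented for this sketch; they are not part of swarms or griptape.
import uuid


class LengthEmbeddingDriver:
    # Toy embedding driver: anything exposing embed_string(str) -> list[float] works.
    def embed_string(self, string: str) -> list[float]:
        return [float(len(string))]


@define
class InMemoryVectorStoreDriver(BaseVectorStoreDriver):
    entries: dict = field(factory=dict, kw_only=True)

    def upsert_vector(self, vector, vector_id=None, namespace=None, meta=None, **kwargs) -> str:
        vector_id = vector_id if vector_id else uuid.uuid4().hex
        self.entries[vector_id] = self.Entry(id=vector_id, vector=vector, meta=meta, namespace=namespace)
        return vector_id

    def load_entry(self, vector_id, namespace=None):
        return self.entries[vector_id]

    def load_entries(self, namespace=None):
        return [e for e in self.entries.values() if namespace is None or e.namespace == namespace]

    def query(self, query, count=None, namespace=None, include_vectors=False, **kwargs):
        # Toy similarity: distance between the one-dimensional length embeddings.
        q = self.embedding_driver.embed_string(query)[0]
        ordered = sorted(self.load_entries(namespace), key=lambda e: abs(e.vector[0] - q))
        return [
            self.QueryResult(
                id=e.id,
                vector=e.vector if include_vectors else [],
                score=abs(e.vector[0] - q),
                meta=e.meta,
                namespace=e.namespace,
            )
            for e in ordered[:count or self.DEFAULT_QUERY_COUNT]
        ]


# Example usage of the sketch above:
# driver = InMemoryVectorStoreDriver(embedding_driver=LengthEmbeddingDriver())
# driver.upsert_text("hello pgvector", namespace="docs")
# print(driver.query("hello", count=1, namespace="docs"))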
@ -0,0 +1,192 @@
import uuid
from typing import Optional, Any
from attr import define, field, Factory
from swarms.memory.vector_stores.base import BaseVectorStoreDriver
from sqlalchemy.engine import Engine
from sqlalchemy import create_engine, Column, String, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Session
from pgvector.sqlalchemy import Vector


@define
class PgVectorVectorStoreDriver(BaseVectorStoreDriver):
    """A vector store driver for Postgres using the PGVector extension.

    Attributes:
        connection_string: An optional string describing the target Postgres database instance.
        create_engine_params: Additional configuration params passed when creating the database connection.
        engine: An optional sqlalchemy Postgres engine to use.
        table_name: Optionally specify the name of the table used to store vectors.
    """

    connection_string: Optional[str] = field(default=None, kw_only=True)
    create_engine_params: dict = field(factory=dict, kw_only=True)
    engine: Optional[Engine] = field(default=None, kw_only=True)
    table_name: str = field(kw_only=True)
    _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))

    @connection_string.validator
    def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
        # If an engine is provided, the connection string is not used.
        if self.engine is not None:
            return

        # If an engine is not provided, a connection string is required.
        if connection_string is None:
            raise ValueError("An engine or connection string is required")

        if not connection_string.startswith("postgresql://"):
            raise ValueError("The connection string must describe a Postgres database connection")

    @engine.validator
    def validate_engine(self, _, engine: Optional[Engine]) -> None:
        # If a connection string is provided, an engine does not need to be provided.
        if self.connection_string is not None:
            return

        # If a connection string is not provided, an engine is required.
        if engine is None:
            raise ValueError("An engine or connection string is required")

    def __attrs_post_init__(self) -> None:
        """If an engine is provided, it will be used to connect to the database.
        If not, a connection string is used to create a new database connection here.
        """
        if self.engine is None:
            self.engine = create_engine(self.connection_string, **self.create_engine_params)

    def setup(
        self,
        create_schema: bool = True,
        install_uuid_extension: bool = True,
        install_vector_extension: bool = True,
    ) -> None:
        """Provides a mechanism to initialize the database schema and extensions."""
        if install_uuid_extension:
            self.engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')

        if install_vector_extension:
            self.engine.execute('CREATE EXTENSION IF NOT EXISTS "vector";')

        if create_schema:
            self._model.metadata.create_all(self.engine)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs
    ) -> str:
        """Inserts or updates a vector in the collection."""
        with Session(self.engine) as session:
            obj = self._model(
                id=vector_id,
                vector=vector,
                namespace=namespace,
                meta=meta,
            )

            obj = session.merge(obj)
            session.commit()

            return str(obj.id)

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
        """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
        with Session(self.engine) as session:
            result = session.get(self._model, vector_id)

            return BaseVectorStoreDriver.Entry(
                id=result.id,
                vector=result.vector,
                namespace=result.namespace,
                meta=result.meta,
            )

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from the collection, optionally filtering to only
        those that match the provided namespace.
        """
        with Session(self.engine) as session:
            query = session.query(self._model)
            if namespace:
                query = query.filter_by(namespace=namespace)

            results = query.all()

            return [
                BaseVectorStoreDriver.Entry(
                    id=str(result.id),
                    vector=result.vector,
                    namespace=result.namespace,
                    meta=result.meta,
                )
                for result in results
            ]

    def query(
        self,
        query: str,
        count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        distance_metric: str = "cosine_distance",
        **kwargs
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Performs a search on the collection to find vectors similar to the provided input vector,
        optionally filtering to only those that match the provided namespace.
        """
        distance_metrics = {
            "cosine_distance": self._model.vector.cosine_distance,
            "l2_distance": self._model.vector.l2_distance,
            "inner_product": self._model.vector.max_inner_product,
        }

        if distance_metric not in distance_metrics:
            raise ValueError("Invalid distance metric provided")

        op = distance_metrics[distance_metric]

        with Session(self.engine) as session:
            vector = self.embedding_driver.embed_string(query)

            # The query should return both the vector and the distance metric score.
            query = session.query(
                self._model,
                op(vector).label("score"),
            ).order_by(op(vector))

            if namespace:
                query = query.filter_by(namespace=namespace)

            results = query.limit(count).all()

            return [
                BaseVectorStoreDriver.QueryResult(
                    id=str(result[0].id),
                    vector=result[0].vector if include_vectors else None,
                    score=result[1],
                    meta=result[0].meta,
                    namespace=result[0].namespace,
                )
                for result in results
            ]

    def default_vector_model(self) -> Any:
        Base = declarative_base()

        class VectorModel(Base):
            __tablename__ = self.table_name

            id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
            vector = Column(Vector())
            namespace = Column(String)
            meta = Column(JSON)

        return VectorModel
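
# Illustrative sketch: wiring the driver above to a local Postgres instance. The
# connection string, table name, and FakeEmbeddingDriver below are placeholders for
# this example only; any object exposing embed_string(str) -> list[float] can serve
# as the embedding driver.
class FakeEmbeddingDriver:
    def embed_string(self, string: str) -> list[float]:
        # Stand-in embedding: a real deployment would call an embedding model here.
        return [float(ord(c)) for c in string[:8].ljust(8)]


if __name__ == "__main__":
    driver = PgVectorVectorStoreDriver(
        embedding_driver=FakeEmbeddingDriver(),
        connection_string="postgresql://user:password@localhost:5432/vectors",
        table_name="swarms_vectors",
    )

    driver.setup()  # installs the uuid-ossp and vector extensions, then creates the table
    vector_id = driver.upsert_text("pgvector stores embeddings inside Postgres", namespace="docs")

    for result in driver.query("where are embeddings stored?", count=3, namespace="docs"):
        print(result.id, result.score, result.meta)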
@ -0,0 +1,10 @@
from concurrent import futures
from typing import TypeVar

T = TypeVar("T")


def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]:
    """Waits for every future in the dict to complete, then returns their results keyed by the same names."""
    futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED)

    return {key: future.result() for key, future in fs_dict.items()}
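
# Illustrative usage sketch: fan a few jobs out to a ThreadPoolExecutor keyed by name,
# then gather every result at once with execute_futures_dict.
if __name__ == "__main__":
    with futures.ThreadPoolExecutor() as executor:
        jobs = {
            "square": executor.submit(pow, 7, 2),
            "cube": executor.submit(pow, 2, 3),
        }

        print(execute_futures_dict(jobs))  # {'square': 49, 'cube': 8}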