parent
							
								
									3a20b51f4c
								
							
						
					
					
						commit
						c3ccc69725
					
				@ -0,0 +1,119 @@
 | 
				
			|||||||
 | 
					from abc import ABC, abstractmethod
 | 
				
			||||||
 | 
					from concurrent import futures
 | 
				
			||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					from typing import Optional, Any
 | 
				
			||||||
 | 
					from attr import define, field, Factory
 | 
				
			||||||
 | 
					from swarms.utils.futures import execute_futures_dict
 | 
				
			||||||
 | 
					from griptape.artifacts import TextArtifact
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@define
 | 
				
			||||||
 | 
					class BaseVectorStore(ABC):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    DEFAULT_QUERY_COUNT = 5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class QueryResult:
 | 
				
			||||||
 | 
					        id: str
 | 
				
			||||||
 | 
					        vector: list[float]
 | 
				
			||||||
 | 
					        score: float
 | 
				
			||||||
 | 
					        meta: Optional[dict] = None
 | 
				
			||||||
 | 
					        namespace: Optional[str] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class Entry:
 | 
				
			||||||
 | 
					        id: str
 | 
				
			||||||
 | 
					        vector: list[float]
 | 
				
			||||||
 | 
					        meta: Optional[dict] = None
 | 
				
			||||||
 | 
					        namespace: Optional[str] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    embedding_driver: Any 
 | 
				
			||||||
 | 
					    futures_executor: futures.Executor = field(
 | 
				
			||||||
 | 
					        default=Factory(lambda: futures.ThreadPoolExecutor()),
 | 
				
			||||||
 | 
					        kw_only=True
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def upsert_text_artifacts(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            artifacts: dict[str, list[TextArtifact]],
 | 
				
			||||||
 | 
					            meta: Optional[dict] = None,
 | 
				
			||||||
 | 
					            **kwargs
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        execute_futures_dict({
 | 
				
			||||||
 | 
					            namespace:
 | 
				
			||||||
 | 
					                self.futures_executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
 | 
				
			||||||
 | 
					            for namespace, artifact_list in artifacts.items() for a in artifact_list
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def upsert_text_artifact(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            artifact: TextArtifact,
 | 
				
			||||||
 | 
					            namespace: Optional[str] = None,
 | 
				
			||||||
 | 
					            meta: Optional[dict] = None,
 | 
				
			||||||
 | 
					            **kwargs
 | 
				
			||||||
 | 
					    ) -> str:
 | 
				
			||||||
 | 
					        if not meta:
 | 
				
			||||||
 | 
					            meta = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        meta["artifact"] = artifact.to_json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if artifact.embedding:
 | 
				
			||||||
 | 
					            vector = artifact.embedding
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            vector = artifact.generate_embedding(self.embedding_driver)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.upsert_vector(
 | 
				
			||||||
 | 
					            vector,
 | 
				
			||||||
 | 
					            vector_id=artifact.id,
 | 
				
			||||||
 | 
					            namespace=namespace,
 | 
				
			||||||
 | 
					            meta=meta,
 | 
				
			||||||
 | 
					            **kwargs
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def upsert_text(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            string: str,
 | 
				
			||||||
 | 
					            vector_id: Optional[str] = None,
 | 
				
			||||||
 | 
					            namespace: Optional[str] = None,
 | 
				
			||||||
 | 
					            meta: Optional[dict] = None,
 | 
				
			||||||
 | 
					            **kwargs
 | 
				
			||||||
 | 
					    ) -> str:
 | 
				
			||||||
 | 
					        return self.upsert_vector(
 | 
				
			||||||
 | 
					            self.embedding_driver.embed_string(string),
 | 
				
			||||||
 | 
					            vector_id=vector_id,
 | 
				
			||||||
 | 
					            namespace=namespace,
 | 
				
			||||||
 | 
					            meta=meta if meta else {},
 | 
				
			||||||
 | 
					            **kwargs
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def upsert_vector(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            vector: list[float],
 | 
				
			||||||
 | 
					            vector_id: Optional[str] = None,
 | 
				
			||||||
 | 
					            namespace: Optional[str] = None,
 | 
				
			||||||
 | 
					            meta: Optional[dict] = None,
 | 
				
			||||||
 | 
					            **kwargs
 | 
				
			||||||
 | 
					    ) -> str:
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry:
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def load_entries(self, namespace: Optional[str] = None) -> list[Entry]:
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def query(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            query: str,
 | 
				
			||||||
 | 
					            count: Optional[int] = None,
 | 
				
			||||||
 | 
					            namespace: Optional[str] = None,
 | 
				
			||||||
 | 
					            include_vectors: bool = False,
 | 
				
			||||||
 | 
					            **kwargs
 | 
				
			||||||
 | 
					    ) -> list[QueryResult]:
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
@ -0,0 +1,192 @@
 | 
				
			|||||||
 | 
					import uuid
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from attr import define, field, Factory
 | 
				
			||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					from swarms.memory.vector_stores.base import BaseVectorStoreDriver
 | 
				
			||||||
 | 
					from sqlalchemy.engine import Engine
 | 
				
			||||||
 | 
					from sqlalchemy import create_engine, Column, String, JSON
 | 
				
			||||||
 | 
					from sqlalchemy.ext.declarative import declarative_base
 | 
				
			||||||
 | 
					from sqlalchemy.dialects.postgresql import UUID
 | 
				
			||||||
 | 
					from sqlalchemy.orm import Session
 | 
				
			||||||
 | 
					from pgvector.sqlalchemy import Vector
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@define
 | 
				
			||||||
 | 
					class PgVectorVectorStoreDriver(BaseVectorStoreDriver):
 | 
				
			||||||
 | 
					    """A vector store driver to Postgres using the PGVector extension.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Attributes:
 | 
				
			||||||
 | 
					        connection_string: An optional string describing the target Postgres database instance.
 | 
				
			||||||
 | 
					        create_engine_params: Additional configuration params passed when creating the database connection.
 | 
				
			||||||
 | 
					        engine: An optional sqlalchemy Postgres engine to use.
 | 
				
			||||||
 | 
					        table_name: Optionally specify the name of the table to used to store vectors.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    connection_string: Optional[str] = field(default=None, kw_only=True)
 | 
				
			||||||
 | 
					    create_engine_params: dict = field(factory=dict, kw_only=True)
 | 
				
			||||||
 | 
					    engine: Optional[Engine] = field(default=None, kw_only=True)
 | 
				
			||||||
 | 
					    table_name: str = field(kw_only=True)
 | 
				
			||||||
 | 
					    _model: any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @connection_string.validator
 | 
				
			||||||
 | 
					    def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
 | 
				
			||||||
 | 
					        # If an engine is provided, the connection string is not used.
 | 
				
			||||||
 | 
					        if self.engine is not None:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # If an engine is not provided, a connection string is required.
 | 
				
			||||||
 | 
					        if connection_string is None:
 | 
				
			||||||
 | 
					            raise ValueError("An engine or connection string is required")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not connection_string.startswith("postgresql://"):
 | 
				
			||||||
 | 
					            raise ValueError("The connection string must describe a Postgres database connection")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @engine.validator
 | 
				
			||||||
 | 
					    def validate_engine(self, _, engine: Optional[Engine]) -> None:
 | 
				
			||||||
 | 
					        # If a connection string is provided, an engine does not need to be provided.
 | 
				
			||||||
 | 
					        if self.connection_string is not None:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # If a connection string is not provided, an engine is required.
 | 
				
			||||||
 | 
					        if engine is None:
 | 
				
			||||||
 | 
					            raise ValueError("An engine or connection string is required")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __attrs_post_init__(self) -> None:
 | 
				
			||||||
 | 
					        """If a an engine is provided, it will be used to connect to the database.
 | 
				
			||||||
 | 
					        If not, a connection string is used to create a new database connection here.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.engine is None:
 | 
				
			||||||
 | 
					            self.engine = create_engine(self.connection_string, **self.create_engine_params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setup(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        create_schema: bool = True,
 | 
				
			||||||
 | 
					        install_uuid_extension: bool = True,
 | 
				
			||||||
 | 
					        install_vector_extension: bool = True,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        """Provides a mechanism to initialize the database schema and extensions."""
 | 
				
			||||||
 | 
					        if install_uuid_extension:
 | 
				
			||||||
 | 
					            self.engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if install_vector_extension:
 | 
				
			||||||
 | 
					            self.engine.execute('CREATE EXTENSION IF NOT EXISTS "vector";')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if create_schema:
 | 
				
			||||||
 | 
					            self._model.metadata.create_all(self.engine)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def upsert_vector(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        vector: list[float],
 | 
				
			||||||
 | 
					        vector_id: Optional[str] = None,
 | 
				
			||||||
 | 
					        namespace: Optional[str] = None,
 | 
				
			||||||
 | 
					        meta: Optional[dict] = None,
 | 
				
			||||||
 | 
					        **kwargs
 | 
				
			||||||
 | 
					    ) -> str:
 | 
				
			||||||
 | 
					        """Inserts or updates a vector in the collection."""
 | 
				
			||||||
 | 
					        with Session(self.engine) as session:
 | 
				
			||||||
 | 
					            obj = self._model(
 | 
				
			||||||
 | 
					                id=vector_id,
 | 
				
			||||||
 | 
					                vector=vector,
 | 
				
			||||||
 | 
					                namespace=namespace,
 | 
				
			||||||
 | 
					                meta=meta,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            obj = session.merge(obj)
 | 
				
			||||||
 | 
					            session.commit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return str(obj.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
 | 
				
			||||||
 | 
					        """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
 | 
				
			||||||
 | 
					        with Session(self.engine) as session:
 | 
				
			||||||
 | 
					            result = session.get(self._model, vector_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return BaseVectorStoreDriver.Entry(
 | 
				
			||||||
 | 
					                id=result.id,
 | 
				
			||||||
 | 
					                vector=result.vector,
 | 
				
			||||||
 | 
					                namespace=result.namespace,
 | 
				
			||||||
 | 
					                meta=result.meta,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
 | 
				
			||||||
 | 
					        """Retrieves all vector entries from the collection, optionally filtering to only
 | 
				
			||||||
 | 
					        those that match the provided namespace.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with Session(self.engine) as session:
 | 
				
			||||||
 | 
					            query = session.query(self._model)
 | 
				
			||||||
 | 
					            if namespace:
 | 
				
			||||||
 | 
					                query = query.filter_by(namespace=namespace)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            results = query.all()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return [
 | 
				
			||||||
 | 
					                BaseVectorStoreDriver.Entry(
 | 
				
			||||||
 | 
					                    id=str(result.id),
 | 
				
			||||||
 | 
					                    vector=result.vector,
 | 
				
			||||||
 | 
					                    namespace=result.namespace,
 | 
				
			||||||
 | 
					                    meta=result.meta,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for result in results
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def query(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        query: str,
 | 
				
			||||||
 | 
					        count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
 | 
				
			||||||
 | 
					        namespace: Optional[str] = None,
 | 
				
			||||||
 | 
					        include_vectors: bool = False,
 | 
				
			||||||
 | 
					        distance_metric: str = "cosine_distance",
 | 
				
			||||||
 | 
					        **kwargs
 | 
				
			||||||
 | 
					    ) -> list[BaseVectorStoreDriver.QueryResult]:
 | 
				
			||||||
 | 
					        """Performs a search on the collection to find vectors similar to the provided input vector,
 | 
				
			||||||
 | 
					        optionally filtering to only those that match the provided namespace.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        distance_metrics = {
 | 
				
			||||||
 | 
					            "cosine_distance": self._model.vector.cosine_distance,
 | 
				
			||||||
 | 
					            "l2_distance": self._model.vector.l2_distance,
 | 
				
			||||||
 | 
					            "inner_product": self._model.vector.max_inner_product,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if distance_metric not in distance_metrics:
 | 
				
			||||||
 | 
					            raise ValueError("Invalid distance metric provided")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        op = distance_metrics[distance_metric]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with Session(self.engine) as session:
 | 
				
			||||||
 | 
					            vector = self.embedding_driver.embed_string(query)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # The query should return both the vector and the distance metric score.
 | 
				
			||||||
 | 
					            query = session.query(
 | 
				
			||||||
 | 
					                self._model,
 | 
				
			||||||
 | 
					                op(vector).label("score"),
 | 
				
			||||||
 | 
					            ).order_by(op(vector))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if namespace:
 | 
				
			||||||
 | 
					                query = query.filter_by(namespace=namespace)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            results = query.limit(count).all()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return [
 | 
				
			||||||
 | 
					                BaseVectorStoreDriver.QueryResult(
 | 
				
			||||||
 | 
					                    id=str(result[0].id),
 | 
				
			||||||
 | 
					                    vector=result[0].vector if include_vectors else None,
 | 
				
			||||||
 | 
					                    score=result[1],
 | 
				
			||||||
 | 
					                    meta=result[0].meta,
 | 
				
			||||||
 | 
					                    namespace=result[0].namespace,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for result in results
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def default_vector_model(self) -> any:
 | 
				
			||||||
 | 
					        Base = declarative_base()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @dataclass
 | 
				
			||||||
 | 
					        class VectorModel(Base):
 | 
				
			||||||
 | 
					            __tablename__ = self.table_name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
 | 
				
			||||||
 | 
					            vector = Column(Vector())
 | 
				
			||||||
 | 
					            namespace = Column(String)
 | 
				
			||||||
 | 
					            meta = Column(JSON)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return VectorModel
 | 
				
			||||||
@ -0,0 +1,10 @@
 | 
				
			|||||||
 | 
					from concurrent import futures
 | 
				
			||||||
 | 
					from typing import TypeVar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					T = TypeVar("T")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]:
 | 
				
			||||||
 | 
					    futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return {key: future.result() for key, future in fs_dict.items()}
 | 
				
			||||||
					Loading…
					
					
				
		Reference in new issue