diff --git a/docs/examples/omni_agent.md b/docs/examples/omni_agent.md index f3a9b3cf..56a6c996 100644 --- a/docs/examples/omni_agent.md +++ b/docs/examples/omni_agent.md @@ -50,11 +50,12 @@ Let’s embark on an exciting journey with OmniModalAgent: **i. Basic Interaction**: ```python -from swarms import OmniModalAgent, OpenAIChat +from swarms.agents import OmniModalAgent +from swarms.models import OpenAIChat -llm = OpenAIChat() +llm = OpenAIChat(openai_api_key="sk-") agent = OmniModalAgent(llm) -response = agent.run("Hello, how are you? Create an image of how you are doing!") +response = agent.run("Create an video of a swarm of fish concept art, game art") print(response) ``` diff --git a/swarms/memory/vector_stores/__init__.py b/swarms/memory/vector_stores/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/swarms/memory/vector_stores/base.py b/swarms/memory/vector_stores/base.py new file mode 100644 index 00000000..f23ac87a --- /dev/null +++ b/swarms/memory/vector_stores/base.py @@ -0,0 +1,119 @@ +from abc import ABC, abstractmethod +from concurrent import futures +from dataclasses import dataclass +from typing import Optional, Any +from attr import define, field, Factory +from swarms.utils.futures import execute_futures_dict +from griptape.artifacts import TextArtifact + + + +@define +class BaseVectorStore(ABC): + """ + """ + DEFAULT_QUERY_COUNT = 5 + + @dataclass + class QueryResult: + id: str + vector: list[float] + score: float + meta: Optional[dict] = None + namespace: Optional[str] = None + + @dataclass + class Entry: + id: str + vector: list[float] + meta: Optional[dict] = None + namespace: Optional[str] = None + + embedding_driver: Any + futures_executor: futures.Executor = field( + default=Factory(lambda: futures.ThreadPoolExecutor()), + kw_only=True + ) + + def upsert_text_artifacts( + self, + artifacts: dict[str, list[TextArtifact]], + meta: Optional[dict] = None, + **kwargs + ) -> None: + execute_futures_dict({ + namespace: + self.futures_executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs) + for namespace, artifact_list in artifacts.items() for a in artifact_list + }) + + def upsert_text_artifact( + self, + artifact: TextArtifact, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: + if not meta: + meta = {} + + meta["artifact"] = artifact.to_json() + + if artifact.embedding: + vector = artifact.embedding + else: + vector = artifact.generate_embedding(self.embedding_driver) + + return self.upsert_vector( + vector, + vector_id=artifact.id, + namespace=namespace, + meta=meta, + **kwargs + ) + + def upsert_text( + self, + string: str, + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: + return self.upsert_vector( + self.embedding_driver.embed_string(string), + vector_id=vector_id, + namespace=namespace, + meta=meta if meta else {}, + **kwargs + ) + + @abstractmethod + def upsert_vector( + self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: + ... + + @abstractmethod + def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry: + ... + + @abstractmethod + def load_entries(self, namespace: Optional[str] = None) -> list[Entry]: + ... + + @abstractmethod + def query( + self, + query: str, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + **kwargs + ) -> list[QueryResult]: + ... \ No newline at end of file diff --git a/swarms/memory/vector_stores/pg.py b/swarms/memory/vector_stores/pg.py new file mode 100644 index 00000000..21ec919d --- /dev/null +++ b/swarms/memory/vector_stores/pg.py @@ -0,0 +1,192 @@ +import uuid +from typing import Optional +from attr import define, field, Factory +from dataclasses import dataclass +from swarms.memory.vector_stores.base import BaseVectorStoreDriver +from sqlalchemy.engine import Engine +from sqlalchemy import create_engine, Column, String, JSON +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Session +from pgvector.sqlalchemy import Vector + + +@define +class PgVectorVectorStoreDriver(BaseVectorStoreDriver): + """A vector store driver to Postgres using the PGVector extension. + + Attributes: + connection_string: An optional string describing the target Postgres database instance. + create_engine_params: Additional configuration params passed when creating the database connection. + engine: An optional sqlalchemy Postgres engine to use. + table_name: Optionally specify the name of the table to used to store vectors. + """ + + connection_string: Optional[str] = field(default=None, kw_only=True) + create_engine_params: dict = field(factory=dict, kw_only=True) + engine: Optional[Engine] = field(default=None, kw_only=True) + table_name: str = field(kw_only=True) + _model: any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True)) + + @connection_string.validator + def validate_connection_string(self, _, connection_string: Optional[str]) -> None: + # If an engine is provided, the connection string is not used. + if self.engine is not None: + return + + # If an engine is not provided, a connection string is required. + if connection_string is None: + raise ValueError("An engine or connection string is required") + + if not connection_string.startswith("postgresql://"): + raise ValueError("The connection string must describe a Postgres database connection") + + @engine.validator + def validate_engine(self, _, engine: Optional[Engine]) -> None: + # If a connection string is provided, an engine does not need to be provided. + if self.connection_string is not None: + return + + # If a connection string is not provided, an engine is required. + if engine is None: + raise ValueError("An engine or connection string is required") + + def __attrs_post_init__(self) -> None: + """If a an engine is provided, it will be used to connect to the database. + If not, a connection string is used to create a new database connection here. + """ + if self.engine is None: + self.engine = create_engine(self.connection_string, **self.create_engine_params) + + def setup( + self, + create_schema: bool = True, + install_uuid_extension: bool = True, + install_vector_extension: bool = True, + ) -> None: + """Provides a mechanism to initialize the database schema and extensions.""" + if install_uuid_extension: + self.engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + + if install_vector_extension: + self.engine.execute('CREATE EXTENSION IF NOT EXISTS "vector";') + + if create_schema: + self._model.metadata.create_all(self.engine) + + def upsert_vector( + self, + vector: list[float], + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs + ) -> str: + """Inserts or updates a vector in the collection.""" + with Session(self.engine) as session: + obj = self._model( + id=vector_id, + vector=vector, + namespace=namespace, + meta=meta, + ) + + obj = session.merge(obj) + session.commit() + + return str(obj.id) + + def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry: + """Retrieves a specific vector entry from the collection based on its identifier and optional namespace.""" + with Session(self.engine) as session: + result = session.get(self._model, vector_id) + + return BaseVectorStoreDriver.Entry( + id=result.id, + vector=result.vector, + namespace=result.namespace, + meta=result.meta, + ) + + def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: + """Retrieves all vector entries from the collection, optionally filtering to only + those that match the provided namespace. + """ + with Session(self.engine) as session: + query = session.query(self._model) + if namespace: + query = query.filter_by(namespace=namespace) + + results = query.all() + + return [ + BaseVectorStoreDriver.Entry( + id=str(result.id), + vector=result.vector, + namespace=result.namespace, + meta=result.meta, + ) + for result in results + ] + + def query( + self, + query: str, + count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, + namespace: Optional[str] = None, + include_vectors: bool = False, + distance_metric: str = "cosine_distance", + **kwargs + ) -> list[BaseVectorStoreDriver.QueryResult]: + """Performs a search on the collection to find vectors similar to the provided input vector, + optionally filtering to only those that match the provided namespace. + """ + distance_metrics = { + "cosine_distance": self._model.vector.cosine_distance, + "l2_distance": self._model.vector.l2_distance, + "inner_product": self._model.vector.max_inner_product, + } + + if distance_metric not in distance_metrics: + raise ValueError("Invalid distance metric provided") + + op = distance_metrics[distance_metric] + + with Session(self.engine) as session: + vector = self.embedding_driver.embed_string(query) + + # The query should return both the vector and the distance metric score. + query = session.query( + self._model, + op(vector).label("score"), + ).order_by(op(vector)) + + if namespace: + query = query.filter_by(namespace=namespace) + + results = query.limit(count).all() + + return [ + BaseVectorStoreDriver.QueryResult( + id=str(result[0].id), + vector=result[0].vector if include_vectors else None, + score=result[1], + meta=result[0].meta, + namespace=result[0].namespace, + ) + for result in results + ] + + def default_vector_model(self) -> any: + Base = declarative_base() + + @dataclass + class VectorModel(Base): + __tablename__ = self.table_name + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False) + vector = Column(Vector()) + namespace = Column(String) + meta = Column(JSON) + + return VectorModel \ No newline at end of file diff --git a/swarms/utils/futures.py b/swarms/utils/futures.py new file mode 100644 index 00000000..d8719672 --- /dev/null +++ b/swarms/utils/futures.py @@ -0,0 +1,10 @@ +from concurrent import futures +from typing import TypeVar + +T = TypeVar("T") + + +def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]: + futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED) + + return {key: future.result() for key, future in fs_dict.items()} \ No newline at end of file