From 2599d6b37f721cd89dca0492121f6bed861e3a73 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 14 Oct 2023 19:29:14 -0400 Subject: [PATCH] tasks for memory storages --- .env.example | 7 +- swarms/embeddings/__init__.py | 1 + swarms/embeddings/pegasus.py | 33 ++++++- swarms/memory/__init__.py | 1 + swarms/memory/ocean.py | 114 ++++++++++++++++++++++-- swarms/memory/vector_stores/pg.py | 2 +- swarms/memory/vector_stores/pinecone.py | 4 +- swarms/models/mistral.py | 4 +- tests/memory/oceandb.py | 95 ++++++++++++++++++++ tests/memory/pg.py | 91 +++++++++++++++++++ tests/memory/pinecone.py | 65 ++++++++++++++ 11 files changed, 400 insertions(+), 17 deletions(-) create mode 100644 tests/memory/oceandb.py create mode 100644 tests/memory/pg.py create mode 100644 tests/memory/pinecone.py diff --git a/.env.example b/.env.example index eee83ce0..f13ce77f 100644 --- a/.env.example +++ b/.env.example @@ -30,4 +30,9 @@ HF_API_KEY="your_huggingface_api_key_here" REDIS_HOST= -REDIS_PORT= \ No newline at end of file +REDIS_PORT= + + + +#dbs +PINECONE_API_KEY="" \ No newline at end of file diff --git a/swarms/embeddings/__init__.py b/swarms/embeddings/__init__.py index e69de29b..3d968663 100644 --- a/swarms/embeddings/__init__.py +++ b/swarms/embeddings/__init__.py @@ -0,0 +1 @@ +from swarms.embeddings.pegasus import PegasusEmbedding diff --git a/swarms/embeddings/pegasus.py b/swarms/embeddings/pegasus.py index a517135e..e388d40c 100644 --- a/swarms/embeddings/pegasus.py +++ b/swarms/embeddings/pegasus.py @@ -1,12 +1,38 @@ import logging from typing import Union -from pegasus import Pegasus -# import oceandb -# from oceandb.utils.embedding_functions import MultiModalEmbeddingfunction +from pegasus import Pegasus class PegasusEmbedding: + """ + Pegasus + + Args: + modality (str): Modality to use for embedding + multi_process (bool, optional): Whether to use multi-process. Defaults to False. + n_processes (int, optional): Number of processes to use. Defaults to 4. + + Usage: + -------------- + pegasus = PegasusEmbedding(modality="text") + pegasus.embed("Hello world") + + + vision + -------------- + pegasus = PegasusEmbedding(modality="vision") + pegasus.embed("https://i.imgur.com/1qZ0K8r.jpeg") + + audio + -------------- + pegasus = PegasusEmbedding(modality="audio") + pegasus.embed("https://www2.cs.uic.edu/~i101/SoundFiles/StarWars60.wav") + + + + """ + def __init__( self, modality: str, multi_process: bool = False, n_processes: int = 4 ): @@ -22,6 +48,7 @@ class PegasusEmbedding: raise def embed(self, data: Union[str, list[str]]): + """Embed the data""" try: return self.pegasus.embed(data) except Exception as e: diff --git a/swarms/memory/__init__.py b/swarms/memory/__init__.py index 99a738e8..dccc5965 100644 --- a/swarms/memory/__init__.py +++ b/swarms/memory/__init__.py @@ -1,3 +1,4 @@ from swarms.memory.vector_stores.pinecone import PineconeVector from swarms.memory.vector_stores.base import BaseVectorStore from swarms.memory.vector_stores.pg import PgVectorVectorStore +from swarms.memory.ocean import OceanDB diff --git a/swarms/memory/ocean.py b/swarms/memory/ocean.py index a4534d45..0b1a9daf 100644 --- a/swarms/memory/ocean.py +++ b/swarms/memory/ocean.py @@ -1,36 +1,117 @@ -# init ocean -# TODO upload ocean to pip and config it to the abstract class import logging -from typing import Union, List +from typing import List import oceandb from oceandb.utils.embedding_function import MultiModalEmbeddingFunction class OceanDB: - def __init__(self): + """ + A class to interact with OceanDB. + + ... + + Attributes + ---------- + client : oceandb.Client + a client to interact with OceanDB + + Methods + ------- + create_collection(collection_name: str, modality: str): + Creates a new collection in OceanDB. + append_document(collection, document: str, id: str): + Appends a document to a collection in OceanDB. + add_documents(collection, documents: List[str], ids: List[str]): + Adds multiple documents to a collection in OceanDB. + query(collection, query_texts: list[str], n_results: int): + Queries a collection in OceanDB. + """ + + def __init__(self, client: oceandb.Client = None): + """ + Constructs all the necessary attributes for the OceanDB object. + + Parameters + ---------- + client : oceandb.Client, optional + a client to interact with OceanDB (default is None, which creates a new client) + """ try: - self.client = oceandb.Client() + self.client = client if client else oceandb.Client() print(self.client.heartbeat()) except Exception as e: logging.error(f"Failed to initialize OceanDB client. Error: {e}") + raise def create_collection(self, collection_name: str, modality: str): + """ + Creates a new collection in OceanDB. + + Parameters + ---------- + collection_name : str + the name of the new collection + modality : str + the modality of the new collection + + Returns + ------- + collection + the created collection + """ try: embedding_function = MultiModalEmbeddingFunction(modality=modality) - collection = self.client.create_collection(collection_name, embedding_function=embedding_function) + collection = self.client.create_collection( + collection_name, embedding_function=embedding_function + ) return collection except Exception as e: logging.error(f"Failed to create collection. Error {e}") + raise def append_document(self, collection, document: str, id: str): + """ + Appends a document to a collection in OceanDB. + + Parameters + ---------- + collection + the collection to append the document to + document : str + the document to append + id : str + the id of the document + + Returns + ------- + result + the result of the append operation + """ try: - return collection.add(documents=[document], ids[id]) + return collection.add(documents=[document], ids=[id]) except Exception as e: - logging.error(f"Faield to append document to the collection. Error {e}") + logging.error(f"Failed to append document to the collection. Error {e}") raise def add_documents(self, collection, documents: List[str], ids: List[str]): + """ + Adds multiple documents to a collection in OceanDB. + + Parameters + ---------- + collection + the collection to add the documents to + documents : List[str] + the documents to add + ids : List[str] + the ids of the documents + + Returns + ------- + result + the result of the add operation + """ try: return collection.add(documents=documents, ids=ids) except Exception as e: @@ -38,6 +119,23 @@ class OceanDB: raise def query(self, collection, query_texts: list[str], n_results: int): + """ + Queries a collection in OceanDB. + + Parameters + ---------- + collection + the collection to query + query_texts : list[str] + the texts to query + n_results : int + the number of results to return + + Returns + ------- + results + the results of the query + """ try: results = collection.query(query_texts=query_texts, n_results=n_results) return results diff --git a/swarms/memory/vector_stores/pg.py b/swarms/memory/vector_stores/pg.py index 4065fc4e..bd768459 100644 --- a/swarms/memory/vector_stores/pg.py +++ b/swarms/memory/vector_stores/pg.py @@ -81,7 +81,7 @@ class PgVectorVectorStore(BaseVectorStore): >>> namespace="your-namespace" >>> ) - + """ connection_string: Optional[str] = field(default=None, kw_only=True) diff --git a/swarms/memory/vector_stores/pinecone.py b/swarms/memory/vector_stores/pinecone.py index 21f71621..2374f12a 100644 --- a/swarms/memory/vector_stores/pinecone.py +++ b/swarms/memory/vector_stores/pinecone.py @@ -93,7 +93,7 @@ class PineconeVectorStoreStore(BaseVector): index: pinecone.Index = field(init=False) def __attrs_post_init__(self) -> None: - """ Post init""" + """Post init""" pinecone.init( api_key=self.api_key, environment=self.environment, @@ -122,7 +122,7 @@ class PineconeVectorStoreStore(BaseVector): def load_entry( self, vector_id: str, namespace: Optional[str] = None ) -> Optional[BaseVector.Entry]: - """Load entry """ + """Load entry""" result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() vectors = list(result["vectors"].values()) diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py index 34d7264f..7f48a0d6 100644 --- a/swarms/models/mistral.py +++ b/swarms/models/mistral.py @@ -17,12 +17,12 @@ class Mistral: temperature (float, optional): Temperature. Defaults to 1.0. max_length (int, optional): Max length. Defaults to 100. do_sample (bool, optional): Whether to sample. Defaults to True. - + Usage: from swarms.models import Mistral model = Mistral(device="cuda", use_flash_attention=True, temperature=0.7, max_length=200) - + task = "My favourite condiment is" result = model.run(task) print(result) diff --git a/tests/memory/oceandb.py b/tests/memory/oceandb.py new file mode 100644 index 00000000..978a46db --- /dev/null +++ b/tests/memory/oceandb.py @@ -0,0 +1,95 @@ +import pytest +from unittest.mock import Mock, patch +from swarms.memory.oceandb import OceanDB + + +def test_init(): + with patch("oceandb.Client") as MockClient: + MockClient.return_value.heartbeat.return_value = "OK" + db = OceanDB(MockClient) + MockClient.assert_called_once() + assert db.client == MockClient + + +def test_init_exception(): + with patch("oceandb.Client") as MockClient: + MockClient.side_effect = Exception("Client error") + with pytest.raises(Exception) as e: + db = OceanDB(MockClient) + assert str(e.value) == "Client error" + + +def test_create_collection(): + with patch("oceandb.Client") as MockClient: + db = OceanDB(MockClient) + db.create_collection("test", "modality") + MockClient.create_collection.assert_called_once_with( + "test", embedding_function=Mock.ANY + ) + + +def test_create_collection_exception(): + with patch("oceandb.Client") as MockClient: + MockClient.create_collection.side_effect = Exception("Create collection error") + db = OceanDB(MockClient) + with pytest.raises(Exception) as e: + db.create_collection("test", "modality") + assert str(e.value) == "Create collection error" + + +def test_append_document(): + with patch("oceandb.Client") as MockClient: + db = OceanDB(MockClient) + collection = Mock() + db.append_document(collection, "doc", "id") + collection.add.assert_called_once_with(documents=["doc"], ids=["id"]) + + +def test_append_document_exception(): + with patch("oceandb.Client") as MockClient: + db = OceanDB(MockClient) + collection = Mock() + collection.add.side_effect = Exception("Append document error") + with pytest.raises(Exception) as e: + db.append_document(collection, "doc", "id") + assert str(e.value) == "Append document error" + + +def test_add_documents(): + with patch("oceandb.Client") as MockClient: + db = OceanDB(MockClient) + collection = Mock() + db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"]) + collection.add.assert_called_once_with( + documents=["doc1", "doc2"], ids=["id1", "id2"] + ) + + +def test_add_documents_exception(): + with patch("oceandb.Client") as MockClient: + db = OceanDB(MockClient) + collection = Mock() + collection.add.side_effect = Exception("Add documents error") + with pytest.raises(Exception) as e: + db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"]) + assert str(e.value) == "Add documents error" + + +def test_query(): + with patch("oceandb.Client") as MockClient: + db = OceanDB(MockClient) + collection = Mock() + db.query(collection, ["query1", "query2"], 2) + collection.query.assert_called_once_with( + query_texts=["query1", "query2"], n_results=2 + ) + + +def test_query_exception(): + with patch("oceandb.Client") as MockClient: + db = OceanDB(MockClient) + collection = Mock() + collection.query.side_effect = Exception("Query error") + with pytest.raises(Exception) as e: + db.query(collection, ["query1", "query2"], 2) + assert str(e.value) == "Query error" diff --git a/tests/memory/pg.py b/tests/memory/pg.py new file mode 100644 index 00000000..6ba33077 --- /dev/null +++ b/tests/memory/pg.py @@ -0,0 +1,91 @@ +import pytest +from unittest.mock import patch +from swarms.memory import PgVectorVectorStore + + +def test_init(): + with patch("sqlalchemy.create_engine") as MockEngine: + store = PgVectorVectorStore( + connection_string="postgresql://postgres:password@localhost:5432/postgres", + table_name="test", + ) + MockEngine.assert_called_once() + assert store.engine == MockEngine.return_value + + +def test_init_exception(): + with pytest.raises(ValueError): + PgVectorVectorStore( + connection_string="mysql://root:password@localhost:3306/test", + table_name="test", + ) + + +def test_setup(): + with patch("sqlalchemy.create_engine") as MockEngine: + store = PgVectorVectorStore( + connection_string="postgresql://postgres:password@localhost:5432/postgres", + table_name="test", + ) + store.setup() + MockEngine.execute.assert_called() + + +def test_upsert_vector(): + with patch("sqlalchemy.create_engine"), patch( + "sqlalchemy.orm.Session" + ) as MockSession: + store = PgVectorVectorStore( + connection_string="postgresql://postgres:password@localhost:5432/postgres", + table_name="test", + ) + store.upsert_vector( + [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"} + ) + MockSession.assert_called() + MockSession.return_value.merge.assert_called() + MockSession.return_value.commit.assert_called() + + +def test_load_entry(): + with patch("sqlalchemy.create_engine"), patch( + "sqlalchemy.orm.Session" + ) as MockSession: + store = PgVectorVectorStore( + connection_string="postgresql://postgres:password@localhost:5432/postgres", + table_name="test", + ) + store.load_entry("test_id", "test_namespace") + MockSession.assert_called() + MockSession.return_value.get.assert_called() + + +def test_load_entries(): + with patch("sqlalchemy.create_engine"), patch( + "sqlalchemy.orm.Session" + ) as MockSession: + store = PgVectorVectorStore( + connection_string="postgresql://postgres:password@localhost:5432/postgres", + table_name="test", + ) + store.load_entries("test_namespace") + MockSession.assert_called() + MockSession.return_value.query.assert_called() + MockSession.return_value.query.return_value.filter_by.assert_called() + MockSession.return_value.query.return_value.all.assert_called() + + +def test_query(): + with patch("sqlalchemy.create_engine"), patch( + "sqlalchemy.orm.Session" + ) as MockSession: + store = PgVectorVectorStore( + connection_string="postgresql://postgres:password@localhost:5432/postgres", + table_name="test", + ) + store.query("test_query", 10, "test_namespace") + MockSession.assert_called() + MockSession.return_value.query.assert_called() + MockSession.return_value.query.return_value.filter_by.assert_called() + MockSession.return_value.query.return_value.limit.assert_called() + MockSession.return_value.query.return_value.all.assert_called() diff --git a/tests/memory/pinecone.py b/tests/memory/pinecone.py new file mode 100644 index 00000000..54b9026b --- /dev/null +++ b/tests/memory/pinecone.py @@ -0,0 +1,65 @@ +import os +import pytest +from unittest.mock import Mock, patch +from swarms.memory import PineconeVectorStore + +api_key = os.getenv("PINECONE_API_KEY") or "" + + +def test_init(): + with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex: + store = PineconeVectorStore( + api_key=api_key, index_name="test_index", environment="test_env" + ) + MockInit.assert_called_once() + MockIndex.assert_called_once() + assert store.index == MockIndex.return_value + + +def test_upsert_vector(): + with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex: + store = PineconeVectorStore( + api_key=api_key, index_name="test_index", environment="test_env" + ) + store.upsert_vector( + [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"} + ) + MockIndex.return_value.upsert.assert_called() + + +def test_load_entry(): + with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex: + store = PineconeVectorStore( + api_key=api_key, index_name="test_index", environment="test_env" + ) + store.load_entry("test_id", "test_namespace") + MockIndex.return_value.fetch.assert_called() + + +def test_load_entries(): + with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex: + store = PineconeVectorStore( + api_key=api_key, index_name="test_index", environment="test_env" + ) + store.load_entries("test_namespace") + MockIndex.return_value.query.assert_called() + + +def test_query(): + with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex: + store = PineconeVectorStore( + api_key=api_key, index_name="test_index", environment="test_env" + ) + store.query("test_query", 10, "test_namespace") + MockIndex.return_value.query.assert_called() + + +def test_create_index(): + with patch("pinecone.init") as MockInit, patch( + "pinecone.Index" + ) as MockIndex, patch("pinecone.create_index") as MockCreateIndex: + store = PineconeVectorStore( + api_key=api_key, index_name="test_index", environment="test_env" + ) + store.create_index("test_index") + MockCreateIndex.assert_called()