From eeac81e6af155ffde4aba8e31f65698e627290a6 Mon Sep 17 00:00:00 2001
From: Kye <kye@apacmediasolutions.com>
Date: Sat, 14 Oct 2023 19:29:14 -0400
Subject: [PATCH] tasks for memory storages

Former-commit-id: 07e3c9256108dccbead2210b30b48f4d2ee4751f
---
 .env.example                            |   7 +-
 swarms/embeddings/__init__.py           |   1 +
 swarms/embeddings/pegasus.py            |  33 ++++++-
 swarms/memory/__init__.py               |   1 +
 swarms/memory/ocean.py                  | 114 ++++++++++++++++++++++--
 swarms/memory/vector_stores/pg.py       |   2 +-
 swarms/memory/vector_stores/pinecone.py |   4 +-
 swarms/models/mistral.py                |   4 +-
 tests/memory/oceandb.py                 |  95 ++++++++++++++++++++
 tests/memory/pg.py                      |  91 +++++++++++++++++++
 tests/memory/pinecone.py                |  65 ++++++++++++++
 11 files changed, 400 insertions(+), 17 deletions(-)
 create mode 100644 tests/memory/oceandb.py
 create mode 100644 tests/memory/pg.py
 create mode 100644 tests/memory/pinecone.py

diff --git a/.env.example b/.env.example
index eee83ce0..f13ce77f 100644
--- a/.env.example
+++ b/.env.example
@@ -30,4 +30,9 @@ HF_API_KEY="your_huggingface_api_key_here"
 
 
 REDIS_HOST=
-REDIS_PORT=
\ No newline at end of file
+REDIS_PORT=
+
+
+
+#dbs
+PINECONE_API_KEY=""
\ No newline at end of file
diff --git a/swarms/embeddings/__init__.py b/swarms/embeddings/__init__.py
index e69de29b..3d968663 100644
--- a/swarms/embeddings/__init__.py
+++ b/swarms/embeddings/__init__.py
@@ -0,0 +1 @@
+from swarms.embeddings.pegasus import PegasusEmbedding
diff --git a/swarms/embeddings/pegasus.py b/swarms/embeddings/pegasus.py
index a517135e..e388d40c 100644
--- a/swarms/embeddings/pegasus.py
+++ b/swarms/embeddings/pegasus.py
@@ -1,12 +1,38 @@
 import logging
 from typing import Union
-from pegasus import Pegasus
 
-# import oceandb
-# from oceandb.utils.embedding_functions import MultiModalEmbeddingfunction
+from pegasus import Pegasus
 
 
 class PegasusEmbedding:
+    """
+    Pegasus
+
+    Args:
+        modality (str): Modality to use for embedding
+        multi_process (bool, optional): Whether to use multi-process. Defaults to False.
+        n_processes (int, optional): Number of processes to use. Defaults to 4.
+
+    Usage:
+    --------------
+    pegasus = PegasusEmbedding(modality="text")
+    pegasus.embed("Hello world")
+
+
+    vision
+    --------------
+    pegasus = PegasusEmbedding(modality="vision")
+    pegasus.embed("https://i.imgur.com/1qZ0K8r.jpeg")
+
+    audio
+    --------------
+    pegasus = PegasusEmbedding(modality="audio")
+    pegasus.embed("https://www2.cs.uic.edu/~i101/SoundFiles/StarWars60.wav")
+
+
+
+    """
+
     def __init__(
         self, modality: str, multi_process: bool = False, n_processes: int = 4
     ):
@@ -22,6 +48,7 @@ class PegasusEmbedding:
             raise
 
     def embed(self, data: Union[str, list[str]]):
+        """Embed the data"""
         try:
             return self.pegasus.embed(data)
         except Exception as e:
diff --git a/swarms/memory/__init__.py b/swarms/memory/__init__.py
index 99a738e8..dccc5965 100644
--- a/swarms/memory/__init__.py
+++ b/swarms/memory/__init__.py
@@ -1,3 +1,4 @@
 from swarms.memory.vector_stores.pinecone import PineconeVector
 from swarms.memory.vector_stores.base import BaseVectorStore
 from swarms.memory.vector_stores.pg import PgVectorVectorStore
+from swarms.memory.ocean import OceanDB
diff --git a/swarms/memory/ocean.py b/swarms/memory/ocean.py
index a4534d45..0b1a9daf 100644
--- a/swarms/memory/ocean.py
+++ b/swarms/memory/ocean.py
@@ -1,36 +1,117 @@
-# init ocean
-# TODO upload ocean to pip and config it to the abstract class
 import logging
-from typing import Union, List
+from typing import List
 
 import oceandb
 from oceandb.utils.embedding_function import MultiModalEmbeddingFunction
 
 
 class OceanDB:
-    def __init__(self):
+    """
+    A class to interact with OceanDB.
+
+    ...
+
+    Attributes
+    ----------
+    client : oceandb.Client
+        a client to interact with OceanDB
+
+    Methods
+    -------
+    create_collection(collection_name: str, modality: str):
+        Creates a new collection in OceanDB.
+    append_document(collection, document: str, id: str):
+        Appends a document to a collection in OceanDB.
+    add_documents(collection, documents: List[str], ids: List[str]):
+        Adds multiple documents to a collection in OceanDB.
+    query(collection, query_texts: list[str], n_results: int):
+        Queries a collection in OceanDB.
+    """
+
+    def __init__(self, client: oceandb.Client = None):
+        """
+        Constructs all the necessary attributes for the OceanDB object.
+
+        Parameters
+        ----------
+            client : oceandb.Client, optional
+                a client to interact with OceanDB (default is None, which creates a new client)
+        """
         try:
-            self.client = oceandb.Client()
+            self.client = client if client else oceandb.Client()
             print(self.client.heartbeat())
         except Exception as e:
             logging.error(f"Failed to initialize OceanDB client. Error: {e}")
+            raise
 
     def create_collection(self, collection_name: str, modality: str):
+        """
+        Creates a new collection in OceanDB.
+
+        Parameters
+        ----------
+            collection_name : str
+                the name of the new collection
+            modality : str
+                the modality of the new collection
+
+        Returns
+        -------
+            collection
+                the created collection
+        """
         try:
             embedding_function = MultiModalEmbeddingFunction(modality=modality)
-            collection = self.client.create_collection(collection_name, embedding_function=embedding_function)
+            collection = self.client.create_collection(
+                collection_name, embedding_function=embedding_function
+            )
             return collection
         except Exception as e:
             logging.error(f"Failed to create collection. Error {e}")
+            raise
 
     def append_document(self, collection, document: str, id: str):
+        """
+        Appends a document to a collection in OceanDB.
+
+        Parameters
+        ----------
+            collection
+                the collection to append the document to
+            document : str
+                the document to append
+            id : str
+                the id of the document
+
+        Returns
+        -------
+            result
+                the result of the append operation
+        """
         try:
-            return collection.add(documents=[document], ids[id])
+            return collection.add(documents=[document], ids=[id])
         except Exception as e:
-            logging.error(f"Faield to append document to the collection. Error {e}")
+            logging.error(f"Failed to append document to the collection. Error {e}")
             raise
 
     def add_documents(self, collection, documents: List[str], ids: List[str]):
+        """
+        Adds multiple documents to a collection in OceanDB.
+
+        Parameters
+        ----------
+            collection
+                the collection to add the documents to
+            documents : List[str]
+                the documents to add
+            ids : List[str]
+                the ids of the documents
+
+        Returns
+        -------
+            result
+                the result of the add operation
+        """
         try:
             return collection.add(documents=documents, ids=ids)
         except Exception as e:
@@ -38,6 +119,23 @@ class OceanDB:
             raise
 
     def query(self, collection, query_texts: list[str], n_results: int):
+        """
+        Queries a collection in OceanDB.
+
+        Parameters
+        ----------
+            collection
+                the collection to query
+            query_texts : list[str]
+                the texts to query
+            n_results : int
+                the number of results to return
+
+        Returns
+        -------
+            results
+                the results of the query
+        """
         try:
             results = collection.query(query_texts=query_texts, n_results=n_results)
             return results
diff --git a/swarms/memory/vector_stores/pg.py b/swarms/memory/vector_stores/pg.py
index 4065fc4e..bd768459 100644
--- a/swarms/memory/vector_stores/pg.py
+++ b/swarms/memory/vector_stores/pg.py
@@ -81,7 +81,7 @@ class PgVectorVectorStore(BaseVectorStore):
     >>>     namespace="your-namespace"
     >>> )
 
-    
+
     """
 
     connection_string: Optional[str] = field(default=None, kw_only=True)
diff --git a/swarms/memory/vector_stores/pinecone.py b/swarms/memory/vector_stores/pinecone.py
index 21f71621..2374f12a 100644
--- a/swarms/memory/vector_stores/pinecone.py
+++ b/swarms/memory/vector_stores/pinecone.py
@@ -93,7 +93,7 @@ class PineconeVectorStoreStore(BaseVector):
     index: pinecone.Index = field(init=False)
 
     def __attrs_post_init__(self) -> None:
-        """ Post init"""
+        """Post init"""
         pinecone.init(
             api_key=self.api_key,
             environment=self.environment,
@@ -122,7 +122,7 @@ class PineconeVectorStoreStore(BaseVector):
     def load_entry(
         self, vector_id: str, namespace: Optional[str] = None
     ) -> Optional[BaseVector.Entry]:
-        """Load entry """
+        """Load entry"""
         result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()
         vectors = list(result["vectors"].values())
 
diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py
index 34d7264f..7f48a0d6 100644
--- a/swarms/models/mistral.py
+++ b/swarms/models/mistral.py
@@ -17,12 +17,12 @@ class Mistral:
         temperature (float, optional): Temperature. Defaults to 1.0.
         max_length (int, optional): Max length. Defaults to 100.
         do_sample (bool, optional): Whether to sample. Defaults to True.
-    
+
     Usage:
     from swarms.models import Mistral
 
     model = Mistral(device="cuda", use_flash_attention=True, temperature=0.7, max_length=200)
-    
+
     task = "My favourite condiment is"
     result = model.run(task)
     print(result)
diff --git a/tests/memory/oceandb.py b/tests/memory/oceandb.py
new file mode 100644
index 00000000..978a46db
--- /dev/null
+++ b/tests/memory/oceandb.py
@@ -0,0 +1,95 @@
+import pytest
+from unittest.mock import ANY, Mock, patch
+from swarms.memory.ocean import OceanDB
+
+
+def test_init():
+    with patch("oceandb.Client") as MockClient:
+        MockClient.return_value.heartbeat.return_value = "OK"
+        db = OceanDB()
+        MockClient.assert_called_once()
+        assert db.client == MockClient.return_value
+
+
+def test_init_exception():
+    with patch("oceandb.Client") as MockClient:
+        MockClient.side_effect = Exception("Client error")
+        with pytest.raises(Exception) as e:
+            OceanDB(None)
+        assert str(e.value) == "Client error"
+
+
+def test_create_collection():
+    with patch("oceandb.Client") as MockClient:
+        db = OceanDB(MockClient)
+        db.create_collection("test", "modality")
+        MockClient.create_collection.assert_called_once_with(
+            "test", embedding_function=ANY
+        )
+
+
+def test_create_collection_exception():
+    with patch("oceandb.Client") as MockClient:
+        MockClient.create_collection.side_effect = Exception("Create collection error")
+        db = OceanDB(MockClient)
+        with pytest.raises(Exception) as e:
+            db.create_collection("test", "modality")
+        assert str(e.value) == "Create collection error"
+
+
+def test_append_document():
+    with patch("oceandb.Client") as MockClient:
+        db = OceanDB(MockClient)
+        collection = Mock()
+        db.append_document(collection, "doc", "id")
+        collection.add.assert_called_once_with(documents=["doc"], ids=["id"])
+
+
+def test_append_document_exception():
+    with patch("oceandb.Client") as MockClient:
+        db = OceanDB(MockClient)
+        collection = Mock()
+        collection.add.side_effect = Exception("Append document error")
+        with pytest.raises(Exception) as e:
+            db.append_document(collection, "doc", "id")
+        assert str(e.value) == "Append document error"
+
+
+def test_add_documents():
+    with patch("oceandb.Client") as MockClient:
+        db = OceanDB(MockClient)
+        collection = Mock()
+        db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"])
+        collection.add.assert_called_once_with(
+            documents=["doc1", "doc2"], ids=["id1", "id2"]
+        )
+
+
+def test_add_documents_exception():
+    with patch("oceandb.Client") as MockClient:
+        db = OceanDB(MockClient)
+        collection = Mock()
+        collection.add.side_effect = Exception("Add documents error")
+        with pytest.raises(Exception) as e:
+            db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"])
+        assert str(e.value) == "Add documents error"
+
+
+def test_query():
+    with patch("oceandb.Client") as MockClient:
+        db = OceanDB(MockClient)
+        collection = Mock()
+        db.query(collection, ["query1", "query2"], 2)
+        collection.query.assert_called_once_with(
+            query_texts=["query1", "query2"], n_results=2
+        )
+
+
+def test_query_exception():
+    with patch("oceandb.Client") as MockClient:
+        db = OceanDB(MockClient)
+        collection = Mock()
+        collection.query.side_effect = Exception("Query error")
+        with pytest.raises(Exception) as e:
+            db.query(collection, ["query1", "query2"], 2)
+        assert str(e.value) == "Query error"
diff --git a/tests/memory/pg.py b/tests/memory/pg.py
new file mode 100644
index 00000000..6ba33077
--- /dev/null
+++ b/tests/memory/pg.py
@@ -0,0 +1,91 @@
+import pytest
+from unittest.mock import patch
+from swarms.memory import PgVectorVectorStore
+
+
+def test_init():
+    with patch("sqlalchemy.create_engine") as MockEngine:
+        store = PgVectorVectorStore(
+            connection_string="postgresql://postgres:password@localhost:5432/postgres",
+            table_name="test",
+        )
+        MockEngine.assert_called_once()
+        assert store.engine == MockEngine.return_value
+
+
+def test_init_exception():
+    with pytest.raises(ValueError):
+        PgVectorVectorStore(
+            connection_string="mysql://root:password@localhost:3306/test",
+            table_name="test",
+        )
+
+
+def test_setup():
+    with patch("sqlalchemy.create_engine") as MockEngine:
+        store = PgVectorVectorStore(
+            connection_string="postgresql://postgres:password@localhost:5432/postgres",
+            table_name="test",
+        )
+        store.setup()
+        MockEngine.execute.assert_called()
+
+
+def test_upsert_vector():
+    with patch("sqlalchemy.create_engine"), patch(
+        "sqlalchemy.orm.Session"
+    ) as MockSession:
+        store = PgVectorVectorStore(
+            connection_string="postgresql://postgres:password@localhost:5432/postgres",
+            table_name="test",
+        )
+        store.upsert_vector(
+            [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
+        )
+        MockSession.assert_called()
+        MockSession.return_value.merge.assert_called()
+        MockSession.return_value.commit.assert_called()
+
+
+def test_load_entry():
+    with patch("sqlalchemy.create_engine"), patch(
+        "sqlalchemy.orm.Session"
+    ) as MockSession:
+        store = PgVectorVectorStore(
+            connection_string="postgresql://postgres:password@localhost:5432/postgres",
+            table_name="test",
+        )
+        store.load_entry("test_id", "test_namespace")
+        MockSession.assert_called()
+        MockSession.return_value.get.assert_called()
+
+
+def test_load_entries():
+    with patch("sqlalchemy.create_engine"), patch(
+        "sqlalchemy.orm.Session"
+    ) as MockSession:
+        store = PgVectorVectorStore(
+            connection_string="postgresql://postgres:password@localhost:5432/postgres",
+            table_name="test",
+        )
+        store.load_entries("test_namespace")
+        MockSession.assert_called()
+        MockSession.return_value.query.assert_called()
+        MockSession.return_value.query.return_value.filter_by.assert_called()
+        MockSession.return_value.query.return_value.all.assert_called()
+
+
+def test_query():
+    with patch("sqlalchemy.create_engine"), patch(
+        "sqlalchemy.orm.Session"
+    ) as MockSession:
+        store = PgVectorVectorStore(
+            connection_string="postgresql://postgres:password@localhost:5432/postgres",
+            table_name="test",
+        )
+        store.query("test_query", 10, "test_namespace")
+        MockSession.assert_called()
+        MockSession.return_value.query.assert_called()
+        MockSession.return_value.query.return_value.filter_by.assert_called()
+        MockSession.return_value.query.return_value.limit.assert_called()
+        MockSession.return_value.query.return_value.all.assert_called()
diff --git a/tests/memory/pinecone.py b/tests/memory/pinecone.py
new file mode 100644
index 00000000..54b9026b
--- /dev/null
+++ b/tests/memory/pinecone.py
@@ -0,0 +1,65 @@
+import os
+import pytest
+from unittest.mock import Mock, patch
+from swarms.memory import PineconeVector as PineconeVectorStore
+
+api_key = os.getenv("PINECONE_API_KEY") or ""
+
+
+def test_init():
+    with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
+        store = PineconeVectorStore(
+            api_key=api_key, index_name="test_index", environment="test_env"
+        )
+        MockInit.assert_called_once()
+        MockIndex.assert_called_once()
+        assert store.index == MockIndex.return_value
+
+
+def test_upsert_vector():
+    with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
+        store = PineconeVectorStore(
+            api_key=api_key, index_name="test_index", environment="test_env"
+        )
+        store.upsert_vector(
+            [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
+        )
+        MockIndex.return_value.upsert.assert_called()
+
+
+def test_load_entry():
+    with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
+        store = PineconeVectorStore(
+            api_key=api_key, index_name="test_index", environment="test_env"
+        )
+        store.load_entry("test_id", "test_namespace")
+        MockIndex.return_value.fetch.assert_called()
+
+
+def test_load_entries():
+    with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
+        store = PineconeVectorStore(
+            api_key=api_key, index_name="test_index", environment="test_env"
+        )
+        store.load_entries("test_namespace")
+        MockIndex.return_value.query.assert_called()
+
+
+def test_query():
+    with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
+        store = PineconeVectorStore(
+            api_key=api_key, index_name="test_index", environment="test_env"
+        )
+        store.query("test_query", 10, "test_namespace")
+        MockIndex.return_value.query.assert_called()
+
+
+def test_create_index():
+    with patch("pinecone.init") as MockInit, patch(
+        "pinecone.Index"
+    ) as MockIndex, patch("pinecone.create_index") as MockCreateIndex:
+        store = PineconeVectorStore(
+            api_key=api_key, index_name="test_index", environment="test_env"
+        )
+        store.create_index("test_index")
+        MockCreateIndex.assert_called()