tasks for memory storages

Former-commit-id: 2599d6b37f
discord-bot-framework
Kye 1 year ago
parent 7a504623c4
commit 63ea4187c3

@ -31,3 +31,8 @@ HF_API_KEY="your_huggingface_api_key_here"
REDIS_HOST= REDIS_HOST=
REDIS_PORT= REDIS_PORT=
#dbs
PINECONE_API_KEY=""

@ -0,0 +1 @@
from swarms.embeddings.pegasus import PegasusEmbedding

@ -1,12 +1,38 @@
import logging import logging
from typing import Union from typing import Union
from pegasus import Pegasus
# import oceandb from pegasus import Pegasus
# from oceandb.utils.embedding_functions import MultiModalEmbeddingfunction
class PegasusEmbedding: class PegasusEmbedding:
"""
Pegasus
Args:
modality (str): Modality to use for embedding
multi_process (bool, optional): Whether to use multi-process. Defaults to False.
n_processes (int, optional): Number of processes to use. Defaults to 4.
Usage:
--------------
pegasus = PegasusEmbedding(modality="text")
pegasus.embed("Hello world")
vision
--------------
pegasus = PegasusEmbedding(modality="vision")
pegasus.embed("https://i.imgur.com/1qZ0K8r.jpeg")
audio
--------------
pegasus = PegasusEmbedding(modality="audio")
pegasus.embed("https://www2.cs.uic.edu/~i101/SoundFiles/StarWars60.wav")
"""
def __init__( def __init__(
self, modality: str, multi_process: bool = False, n_processes: int = 4 self, modality: str, multi_process: bool = False, n_processes: int = 4
): ):
@ -22,6 +48,7 @@ class PegasusEmbedding:
raise raise
def embed(self, data: Union[str, list[str]]): def embed(self, data: Union[str, list[str]]):
"""Embed the data"""
try: try:
return self.pegasus.embed(data) return self.pegasus.embed(data)
except Exception as e: except Exception as e:

@ -1,3 +1,4 @@
from swarms.memory.vector_stores.pinecone import PineconeVector from swarms.memory.vector_stores.pinecone import PineconeVector
from swarms.memory.vector_stores.base import BaseVectorStore from swarms.memory.vector_stores.base import BaseVectorStore
from swarms.memory.vector_stores.pg import PgVectorVectorStore from swarms.memory.vector_stores.pg import PgVectorVectorStore
from swarms.memory.ocean import OceanDB

@ -1,36 +1,117 @@
# init ocean
# TODO upload ocean to pip and config it to the abstract class
import logging import logging
from typing import Union, List from typing import List
import oceandb import oceandb
from oceandb.utils.embedding_function import MultiModalEmbeddingFunction from oceandb.utils.embedding_function import MultiModalEmbeddingFunction
class OceanDB: class OceanDB:
def __init__(self): """
A class to interact with OceanDB.
...
Attributes
----------
client : oceandb.Client
a client to interact with OceanDB
Methods
-------
create_collection(collection_name: str, modality: str):
Creates a new collection in OceanDB.
append_document(collection, document: str, id: str):
Appends a document to a collection in OceanDB.
add_documents(collection, documents: List[str], ids: List[str]):
Adds multiple documents to a collection in OceanDB.
query(collection, query_texts: list[str], n_results: int):
Queries a collection in OceanDB.
"""
def __init__(self, client: oceandb.Client = None):
"""
Constructs all the necessary attributes for the OceanDB object.
Parameters
----------
client : oceandb.Client, optional
a client to interact with OceanDB (default is None, which creates a new client)
"""
try: try:
self.client = oceandb.Client() self.client = client if client else oceandb.Client()
print(self.client.heartbeat()) print(self.client.heartbeat())
except Exception as e: except Exception as e:
logging.error(f"Failed to initialize OceanDB client. Error: {e}") logging.error(f"Failed to initialize OceanDB client. Error: {e}")
raise
def create_collection(self, collection_name: str, modality: str): def create_collection(self, collection_name: str, modality: str):
"""
Creates a new collection in OceanDB.
Parameters
----------
collection_name : str
the name of the new collection
modality : str
the modality of the new collection
Returns
-------
collection
the created collection
"""
try: try:
embedding_function = MultiModalEmbeddingFunction(modality=modality) embedding_function = MultiModalEmbeddingFunction(modality=modality)
collection = self.client.create_collection(collection_name, embedding_function=embedding_function) collection = self.client.create_collection(
collection_name, embedding_function=embedding_function
)
return collection return collection
except Exception as e: except Exception as e:
logging.error(f"Failed to create collection. Error {e}") logging.error(f"Failed to create collection. Error {e}")
raise
def append_document(self, collection, document: str, id: str): def append_document(self, collection, document: str, id: str):
"""
Appends a document to a collection in OceanDB.
Parameters
----------
collection
the collection to append the document to
document : str
the document to append
id : str
the id of the document
Returns
-------
result
the result of the append operation
"""
try: try:
return collection.add(documents=[document], ids[id]) return collection.add(documents=[document], ids=[id])
except Exception as e: except Exception as e:
logging.error(f"Faield to append document to the collection. Error {e}") logging.error(f"Failed to append document to the collection. Error {e}")
raise raise
def add_documents(self, collection, documents: List[str], ids: List[str]): def add_documents(self, collection, documents: List[str], ids: List[str]):
"""
Adds multiple documents to a collection in OceanDB.
Parameters
----------
collection
the collection to add the documents to
documents : List[str]
the documents to add
ids : List[str]
the ids of the documents
Returns
-------
result
the result of the add operation
"""
try: try:
return collection.add(documents=documents, ids=ids) return collection.add(documents=documents, ids=ids)
except Exception as e: except Exception as e:
@ -38,6 +119,23 @@ class OceanDB:
raise raise
def query(self, collection, query_texts: list[str], n_results: int): def query(self, collection, query_texts: list[str], n_results: int):
"""
Queries a collection in OceanDB.
Parameters
----------
collection
the collection to query
query_texts : list[str]
the texts to query
n_results : int
the number of results to return
Returns
-------
results
the results of the query
"""
try: try:
results = collection.query(query_texts=query_texts, n_results=n_results) results = collection.query(query_texts=query_texts, n_results=n_results)
return results return results

@ -93,7 +93,7 @@ class PineconeVectorStoreStore(BaseVector):
index: pinecone.Index = field(init=False) index: pinecone.Index = field(init=False)
def __attrs_post_init__(self) -> None: def __attrs_post_init__(self) -> None:
""" Post init""" """Post init"""
pinecone.init( pinecone.init(
api_key=self.api_key, api_key=self.api_key,
environment=self.environment, environment=self.environment,
@ -122,7 +122,7 @@ class PineconeVectorStoreStore(BaseVector):
def load_entry( def load_entry(
self, vector_id: str, namespace: Optional[str] = None self, vector_id: str, namespace: Optional[str] = None
) -> Optional[BaseVector.Entry]: ) -> Optional[BaseVector.Entry]:
"""Load entry """ """Load entry"""
result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()
vectors = list(result["vectors"].values()) vectors = list(result["vectors"].values())

@ -0,0 +1,95 @@
import pytest
from unittest.mock import Mock, patch
from swarms.memory.oceandb import OceanDB
def test_init():
    """OceanDB() with no client builds one via ``oceandb.Client`` and stores it."""
    with patch("oceandb.Client") as MockClient:
        MockClient.return_value.heartbeat.return_value = "OK"
        # Pass no client on purpose: handing the patched class in as ``client``
        # would skip the default construction, so ``oceandb.Client`` would
        # never be called and ``assert_called_once`` could not pass.
        db = OceanDB()
        MockClient.assert_called_once()
        # The stored client is the instance produced by the patched class.
        assert db.client == MockClient.return_value
def test_init_exception():
    """A failure while building the default client propagates out of OceanDB()."""
    with patch("oceandb.Client") as MockClient:
        MockClient.side_effect = Exception("Client error")
        with pytest.raises(Exception) as e:
            # No client argument: the constructor must call the patched
            # ``oceandb.Client`` itself for the side effect to fire.
            OceanDB()
    # Checked after the ``raises`` block so the assertion actually runs.
    assert str(e.value) == "Client error"
def test_create_collection():
    """create_collection passes the name and some embedding function to the client."""
    # ``Mock`` has no ``ANY`` attribute; the wildcard sentinel lives on the
    # ``unittest.mock`` module. Imported locally to keep this test self-contained.
    from unittest.mock import ANY

    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        db.create_collection("test", "modality")
        MockClient.create_collection.assert_called_once_with(
            "test", embedding_function=ANY
        )
    # NOTE(review): MultiModalEmbeddingFunction is not patched here, so it is
    # presumably constructed for real with modality="modality" - confirm that
    # this is cheap and does not raise.
def test_create_collection_exception():
    """An error raised by the client's ``create_collection`` reaches the caller."""
    with patch("oceandb.Client") as client_cls:
        client_cls.create_collection.side_effect = Exception(
            "Create collection error"
        )
        store = OceanDB(client_cls)
        with pytest.raises(Exception) as excinfo:
            store.create_collection("test", "modality")
    assert str(excinfo.value) == "Create collection error"
def test_append_document():
    """append_document hands the document and id to ``collection.add`` as one-item lists."""
    with patch("oceandb.Client") as client_cls:
        store = OceanDB(client_cls)
        target = Mock()
        store.append_document(target, "doc", "id")
        target.add.assert_called_once_with(documents=["doc"], ids=["id"])
def test_append_document_exception():
    """An error raised by ``collection.add`` reaches the caller of append_document."""
    with patch("oceandb.Client") as client_cls:
        store = OceanDB(client_cls)
        target = Mock()
        target.add.side_effect = Exception("Append document error")
        with pytest.raises(Exception) as excinfo:
            store.append_document(target, "doc", "id")
    assert str(excinfo.value) == "Append document error"
def test_add_documents():
    """add_documents forwards both lists unchanged to ``collection.add``."""
    docs = ["doc1", "doc2"]
    doc_ids = ["id1", "id2"]
    with patch("oceandb.Client") as client_cls:
        store = OceanDB(client_cls)
        target = Mock()
        store.add_documents(target, docs, doc_ids)
        target.add.assert_called_once_with(
            documents=["doc1", "doc2"], ids=["id1", "id2"]
        )
def test_add_documents_exception():
    """An error raised by ``collection.add`` reaches the caller of add_documents."""
    with patch("oceandb.Client") as client_cls:
        store = OceanDB(client_cls)
        target = Mock()
        target.add.side_effect = Exception("Add documents error")
        with pytest.raises(Exception) as excinfo:
            store.add_documents(target, ["doc1", "doc2"], ["id1", "id2"])
    assert str(excinfo.value) == "Add documents error"
def test_query():
    """query forwards the texts and result count to ``collection.query``."""
    with patch("oceandb.Client") as client_cls:
        store = OceanDB(client_cls)
        target = Mock()
        store.query(target, ["query1", "query2"], 2)
        target.query.assert_called_once_with(
            query_texts=["query1", "query2"], n_results=2
        )
def test_query_exception():
    """An error raised by ``collection.query`` is expected to reach the caller."""
    with patch("oceandb.Client") as client_cls:
        store = OceanDB(client_cls)
        target = Mock()
        target.query.side_effect = Exception("Query error")
        with pytest.raises(Exception) as excinfo:
            store.query(target, ["query1", "query2"], 2)
    assert str(excinfo.value) == "Query error"

@ -0,0 +1,91 @@
import pytest
from unittest.mock import patch
from swarms.memory import PgVectorVectorStore
def test_init():
    """Check that construction calls the patched ``create_engine`` once and keeps its result."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine") as engine_factory:
        store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        engine_factory.assert_called_once()
        assert store.engine == engine_factory.return_value
def test_init_exception():
    """Expect ``ValueError`` when the store is built with a ``mysql://`` connection string."""
    bad_kwargs = {
        "connection_string": "mysql://root:password@localhost:3306/test",
        "table_name": "test",
    }
    with pytest.raises(ValueError):
        PgVectorVectorStore(**bad_kwargs)
def test_setup():
    """Check that ``setup()`` results in an ``execute`` call on the patched engine factory."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine") as engine_factory:
        store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        store.setup()
        # NOTE(review): this asserts on the factory mock's own ``execute``, not
        # on ``return_value.execute`` - confirm which object setup() drives.
        engine_factory.execute.assert_called()
def test_upsert_vector():
    """Check that ``upsert_vector`` opens a session, then merges and commits on it."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as session_cls:
            store = PgVectorVectorStore(connection_string=dsn, table_name="test")
            store.upsert_vector(
                [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
            )
            session_cls.assert_called()
            session = session_cls.return_value
            session.merge.assert_called()
            session.commit.assert_called()
def test_load_entry():
    """Check that ``load_entry`` opens a session and calls ``get`` on it."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as session_cls:
            store = PgVectorVectorStore(connection_string=dsn, table_name="test")
            store.load_entry("test_id", "test_namespace")
            session_cls.assert_called()
            session_cls.return_value.get.assert_called()
def test_load_entries():
    """Check that ``load_entries`` opens a session and runs query / filter_by / all."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as session_cls:
            store = PgVectorVectorStore(connection_string=dsn, table_name="test")
            store.load_entries("test_namespace")
            session_cls.assert_called()
            query_mock = session_cls.return_value.query
            query_mock.assert_called()
            query_mock.return_value.filter_by.assert_called()
            query_mock.return_value.all.assert_called()
def test_query():
    """Check that ``query`` opens a session and runs query / filter_by / limit / all."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as session_cls:
            store = PgVectorVectorStore(connection_string=dsn, table_name="test")
            store.query("test_query", 10, "test_namespace")
            session_cls.assert_called()
            query_mock = session_cls.return_value.query
            query_mock.assert_called()
            query_mock.return_value.filter_by.assert_called()
            query_mock.return_value.limit.assert_called()
            query_mock.return_value.all.assert_called()

@ -0,0 +1,65 @@
import os
import pytest
from unittest.mock import Mock, patch
from swarms.memory import PineconeVectorStore
# Pinecone key taken from the environment; falls back to "" when the variable
# is unset or empty.
api_key = os.getenv("PINECONE_API_KEY") or ""
def test_init():
    """Check that construction calls ``pinecone.init`` and ``pinecone.Index`` once each."""
    with patch("pinecone.init") as init_mock:
        with patch("pinecone.Index") as index_cls:
            store = PineconeVectorStore(
                api_key=api_key, index_name="test_index", environment="test_env"
            )
            init_mock.assert_called_once()
            index_cls.assert_called_once()
            assert store.index == index_cls.return_value
def test_upsert_vector():
    """Check that ``upsert_vector`` calls ``upsert`` on the patched index."""
    with patch("pinecone.init"):
        with patch("pinecone.Index") as index_cls:
            store = PineconeVectorStore(
                api_key=api_key, index_name="test_index", environment="test_env"
            )
            store.upsert_vector(
                [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
            )
            index_cls.return_value.upsert.assert_called()
def test_load_entry():
    """Check that ``load_entry`` calls ``fetch`` on the patched index."""
    with patch("pinecone.init"):
        with patch("pinecone.Index") as index_cls:
            store = PineconeVectorStore(
                api_key=api_key, index_name="test_index", environment="test_env"
            )
            store.load_entry("test_id", "test_namespace")
            index_cls.return_value.fetch.assert_called()
def test_load_entries():
    """Check that ``load_entries`` calls ``query`` on the patched index."""
    with patch("pinecone.init"):
        with patch("pinecone.Index") as index_cls:
            store = PineconeVectorStore(
                api_key=api_key, index_name="test_index", environment="test_env"
            )
            store.load_entries("test_namespace")
            index_cls.return_value.query.assert_called()
def test_query():
    """Check that ``query`` calls ``query`` on the patched index."""
    with patch("pinecone.init"):
        with patch("pinecone.Index") as index_cls:
            store = PineconeVectorStore(
                api_key=api_key, index_name="test_index", environment="test_env"
            )
            store.query("test_query", 10, "test_namespace")
            index_cls.return_value.query.assert_called()
def test_create_index():
    """Check that ``create_index`` calls the patched ``pinecone.create_index``."""
    with patch("pinecone.init"):
        with patch("pinecone.Index"):
            with patch("pinecone.create_index") as create_index_mock:
                store = PineconeVectorStore(
                    api_key=api_key,
                    index_name="test_index",
                    environment="test_env",
                )
                store.create_index("test_index")
                create_index_mock.assert_called()
Loading…
Cancel
Save