tasks for memory storages

pull/64/head
Kye 1 year ago
parent 6a799acd87
commit 2599d6b37f

@ -31,3 +31,8 @@ HF_API_KEY="your_huggingface_api_key_here"
REDIS_HOST=
REDIS_PORT=
#dbs
PINECONE_API_KEY=""

@ -0,0 +1 @@
from swarms.embeddings.pegasus import PegasusEmbedding

@ -1,12 +1,38 @@
import logging
from typing import Union
from pegasus import Pegasus
# import oceandb
# from oceandb.utils.embedding_functions import MultiModalEmbeddingfunction
from pegasus import Pegasus
class PegasusEmbedding:
"""
Pegasus
Args:
modality (str): Modality to use for embedding
multi_process (bool, optional): Whether to use multi-process. Defaults to False.
n_processes (int, optional): Number of processes to use. Defaults to 4.
Usage:
--------------
pegasus = PegasusEmbedding(modality="text")
pegasus.embed("Hello world")
vision
--------------
pegasus = PegasusEmbedding(modality="vision")
pegasus.embed("https://i.imgur.com/1qZ0K8r.jpeg")
audio
--------------
pegasus = PegasusEmbedding(modality="audio")
pegasus.embed("https://www2.cs.uic.edu/~i101/SoundFiles/StarWars60.wav")
"""
def __init__(
self, modality: str, multi_process: bool = False, n_processes: int = 4
):
@ -22,6 +48,7 @@ class PegasusEmbedding:
raise
def embed(self, data: Union[str, list[str]]):
"""Embed the data"""
try:
return self.pegasus.embed(data)
except Exception as e:

@ -1,3 +1,4 @@
from swarms.memory.vector_stores.pinecone import PineconeVector
from swarms.memory.vector_stores.base import BaseVectorStore
from swarms.memory.vector_stores.pg import PgVectorVectorStore
from swarms.memory.ocean import OceanDB

@ -1,36 +1,117 @@
# init ocean
# TODO upload ocean to pip and config it to the abstract class
import logging
from typing import Union, List
from typing import List
import oceandb
from oceandb.utils.embedding_function import MultiModalEmbeddingFunction
class OceanDB:
def __init__(self):
"""
A class to interact with OceanDB.
...
Attributes
----------
client : oceandb.Client
a client to interact with OceanDB
Methods
-------
create_collection(collection_name: str, modality: str):
Creates a new collection in OceanDB.
append_document(collection, document: str, id: str):
Appends a document to a collection in OceanDB.
add_documents(collection, documents: List[str], ids: List[str]):
Adds multiple documents to a collection in OceanDB.
query(collection, query_texts: list[str], n_results: int):
Queries a collection in OceanDB.
"""
def __init__(self, client: oceandb.Client = None):
    """
    Set up the OceanDB client used by this wrapper.

    Parameters
    ----------
    client : oceandb.Client, optional
        an existing client to reuse; when None a new ``oceandb.Client()`` is created

    Raises
    ------
    Exception
        any error raised while creating the client or calling ``heartbeat()``
        is logged and re-raised unchanged
    """
    try:
        # Build a default client only when none was injected; the previous
        # unconditional oceandb.Client() call ran even for injected clients.
        self.client = client if client is not None else oceandb.Client()
        print(self.client.heartbeat())
    except Exception as e:
        logging.error(f"Failed to initialize OceanDB client. Error: {e}")
        raise
def create_collection(self, collection_name: str, modality: str):
    """
    Creates a new collection in OceanDB.

    Parameters
    ----------
    collection_name : str
        the name of the new collection
    modality : str
        the modality passed to ``MultiModalEmbeddingFunction``

    Returns
    -------
    collection
        the object returned by ``self.client.create_collection``

    Raises
    ------
    Exception
        any error from building the embedding function or creating the
        collection is logged and re-raised unchanged
    """
    try:
        embedding_function = MultiModalEmbeddingFunction(modality=modality)
        # Create the collection exactly once; a second call with the same
        # name was left behind and would hit the backend twice.
        collection = self.client.create_collection(
            collection_name, embedding_function=embedding_function
        )
        return collection
    except Exception as e:
        logging.error(f"Failed to create collection. Error {e}")
        raise
def append_document(self, collection, document: str, id: str):
    """
    Appends a single document to a collection in OceanDB.

    Parameters
    ----------
    collection
        the collection to append the document to; must expose ``add(documents=..., ids=...)``
    document : str
        the document to append
    id : str
        the id of the document

    Returns
    -------
    result
        whatever ``collection.add`` returns

    Raises
    ------
    Exception
        any error raised by ``collection.add`` is logged and re-raised unchanged
    """
    try:
        # collection.add takes lists, so wrap the single document and id
        return collection.add(documents=[document], ids=[id])
    except Exception as e:
        logging.error(f"Failed to append document to the collection. Error {e}")
        raise
def add_documents(self, collection, documents: List[str], ids: List[str]):
"""
Adds multiple documents to a collection in OceanDB.
Parameters
----------
collection
the collection to add the documents to
documents : List[str]
the documents to add
ids : List[str]
the ids of the documents
Returns
-------
result
the result of the add operation
"""
try:
return collection.add(documents=documents, ids=ids)
except Exception as e:
@ -38,6 +119,23 @@ class OceanDB:
raise
def query(self, collection, query_texts: list[str], n_results: int):
"""
Queries a collection in OceanDB.
Parameters
----------
collection
the collection to query
query_texts : list[str]
the texts to query
n_results : int
the number of results to return
Returns
-------
results
the results of the query
"""
try:
results = collection.query(query_texts=query_texts, n_results=n_results)
return results

@ -93,7 +93,7 @@ class PineconeVectorStoreStore(BaseVector):
index: pinecone.Index = field(init=False)
def __attrs_post_init__(self) -> None:
""" Post init"""
"""Post init"""
pinecone.init(
api_key=self.api_key,
environment=self.environment,
@ -122,7 +122,7 @@ class PineconeVectorStoreStore(BaseVector):
def load_entry(
self, vector_id: str, namespace: Optional[str] = None
) -> Optional[BaseVector.Entry]:
"""Load entry """
"""Load entry"""
result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()
vectors = list(result["vectors"].values())

@ -0,0 +1,95 @@
import pytest
from unittest.mock import Mock, patch
from swarms.memory.oceandb import OceanDB
def test_init():
    """OceanDB() without an injected client builds one via oceandb.Client and pings it."""
    with patch("oceandb.Client") as MockClient:
        MockClient.return_value.heartbeat.return_value = "OK"
        # Construct with no argument so the default-client path is exercised;
        # passing the patched class as the client would bypass oceandb.Client().
        db = OceanDB()
        MockClient.assert_called_once()
        MockClient.return_value.heartbeat.assert_called_once()
        assert db.client == MockClient.return_value
def test_init_exception():
    """A failure while creating the default client propagates out of OceanDB()."""
    with patch("oceandb.Client") as MockClient:
        MockClient.side_effect = Exception("Client error")
        with pytest.raises(Exception) as e:
            # No client is injected, so the constructor must call the
            # patched oceandb.Client and trigger the side effect.
            OceanDB()
        # Checked after the raises block; code following the raising call
        # inside the block would never run.
        assert str(e.value) == "Client error"
def test_create_collection():
    """create_collection forwards the name and an embedding function to the client."""
    # ANY is a module-level helper in unittest.mock, not an attribute of the
    # Mock class, so Mock.ANY raises AttributeError.
    from unittest.mock import ANY

    with patch("oceandb.Client") as MockClient:
        db = OceanDB(MockClient)
        db.create_collection("test", "modality")
        MockClient.create_collection.assert_called_once_with(
            "test", embedding_function=ANY
        )
def test_create_collection_exception():
    """An error raised by the client's create_collection reaches the caller."""
    with patch("oceandb.Client") as client_double:
        client_double.create_collection.side_effect = Exception(
            "Create collection error"
        )
        store = OceanDB(client_double)
        with pytest.raises(Exception) as excinfo:
            store.create_collection("test", "modality")
        message = str(excinfo.value)
        assert message == "Create collection error"
def test_append_document():
    """append_document wraps the document and id in lists for collection.add."""
    with patch("oceandb.Client") as client_double:
        store = OceanDB(client_double)
        target = Mock()
        store.append_document(target, "doc", "id")
        expected_kwargs = {"documents": ["doc"], "ids": ["id"]}
        target.add.assert_called_once_with(**expected_kwargs)
def test_append_document_exception():
    """An error raised by collection.add is re-raised by append_document."""
    with patch("oceandb.Client") as client_double:
        store = OceanDB(client_double)
        target = Mock()
        target.add.side_effect = Exception("Append document error")
        with pytest.raises(Exception) as excinfo:
            store.append_document(target, "doc", "id")
        message = str(excinfo.value)
        assert message == "Append document error"
def test_add_documents():
    """add_documents passes the document and id lists straight to collection.add."""
    docs = ["doc1", "doc2"]
    doc_ids = ["id1", "id2"]
    with patch("oceandb.Client") as client_double:
        store = OceanDB(client_double)
        target = Mock()
        store.add_documents(target, docs, doc_ids)
        target.add.assert_called_once_with(
            documents=["doc1", "doc2"], ids=["id1", "id2"]
        )
def test_add_documents_exception():
    """An error raised by collection.add is re-raised by add_documents."""
    with patch("oceandb.Client") as client_double:
        store = OceanDB(client_double)
        target = Mock()
        target.add.side_effect = Exception("Add documents error")
        with pytest.raises(Exception) as excinfo:
            store.add_documents(target, ["doc1", "doc2"], ["id1", "id2"])
        message = str(excinfo.value)
        assert message == "Add documents error"
def test_query():
    """query forwards the query texts and result count to collection.query."""
    texts = ["query1", "query2"]
    with patch("oceandb.Client") as client_double:
        store = OceanDB(client_double)
        target = Mock()
        store.query(target, texts, 2)
        target.query.assert_called_once_with(
            query_texts=["query1", "query2"], n_results=2
        )
def test_query_exception():
    """An error raised by collection.query is re-raised by query."""
    with patch("oceandb.Client") as client_double:
        store = OceanDB(client_double)
        target = Mock()
        target.query.side_effect = Exception("Query error")
        with pytest.raises(Exception) as excinfo:
            store.query(target, ["query1", "query2"], 2)
        message = str(excinfo.value)
        assert message == "Query error"

@ -0,0 +1,91 @@
import pytest
from unittest.mock import patch
from swarms.memory import PgVectorVectorStore
def test_init():
    """The store builds its engine through sqlalchemy.create_engine."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine") as engine_factory:
        pg_store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        engine_factory.assert_called_once()
        expected_engine = engine_factory.return_value
        assert pg_store.engine == expected_engine
def test_init_exception():
    """Constructing the store with a mysql:// connection string raises ValueError."""
    init_kwargs = {
        "connection_string": "mysql://root:password@localhost:3306/test",
        "table_name": "test",
    }
    with pytest.raises(ValueError):
        PgVectorVectorStore(**init_kwargs)
def test_setup():
    """setup() is expected to issue at least one statement on the store's engine."""
    with patch("sqlalchemy.create_engine") as MockEngine:
        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.setup()
        # The patch replaces the create_engine factory; the engine the store
        # holds is its return value (see test_init), so assert on that object.
        # NOTE(review): assumes setup() calls engine.execute — confirm against
        # the PgVectorVectorStore implementation.
        MockEngine.return_value.execute.assert_called()
def test_upsert_vector():
    """upsert_vector is expected to open a Session, merge the row and commit."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"), patch(
        "sqlalchemy.orm.Session"
    ) as session_cls:
        pg_store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        vector = [1.0, 2.0, 3.0]
        pg_store.upsert_vector(vector, "test_id", "test_namespace", {"meta": "data"})
        session_cls.assert_called()
        session = session_cls.return_value
        session.merge.assert_called()
        session.commit.assert_called()
def test_load_entry():
    """load_entry is expected to open a Session and call get on it."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"), patch(
        "sqlalchemy.orm.Session"
    ) as session_cls:
        pg_store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        pg_store.load_entry("test_id", "test_namespace")
        session_cls.assert_called()
        session = session_cls.return_value
        session.get.assert_called()
def test_load_entries():
    """load_entries is expected to query through a Session, filtering and fetching all rows."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"), patch(
        "sqlalchemy.orm.Session"
    ) as session_cls:
        pg_store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        pg_store.load_entries("test_namespace")
        session_cls.assert_called()
        query_mock = session_cls.return_value.query
        query_mock.assert_called()
        query_result = query_mock.return_value
        query_result.filter_by.assert_called()
        query_result.all.assert_called()
def test_query():
    """query is expected to go through a Session query with filter, limit and all."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"), patch(
        "sqlalchemy.orm.Session"
    ) as session_cls:
        pg_store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        pg_store.query("test_query", 10, "test_namespace")
        session_cls.assert_called()
        query_mock = session_cls.return_value.query
        query_mock.assert_called()
        query_result = query_mock.return_value
        query_result.filter_by.assert_called()
        query_result.limit.assert_called()
        query_result.all.assert_called()

@ -0,0 +1,65 @@
import os
import pytest
from unittest.mock import Mock, patch
from swarms.memory import PineconeVectorStore
api_key = os.getenv("PINECONE_API_KEY") or ""
def test_init():
    """Constructing the store calls pinecone.init and pinecone.Index once each."""
    with patch("pinecone.init") as init_mock, patch("pinecone.Index") as index_cls:
        pinecone_store = PineconeVectorStore(
            api_key=api_key, index_name="test_index", environment="test_env"
        )
        init_mock.assert_called_once()
        index_cls.assert_called_once()
        expected_index = index_cls.return_value
        assert pinecone_store.index == expected_index
def test_upsert_vector():
    """upsert_vector is expected to call upsert on the pinecone index."""
    with patch("pinecone.init"), patch("pinecone.Index") as index_cls:
        pinecone_store = PineconeVectorStore(
            api_key=api_key, index_name="test_index", environment="test_env"
        )
        vector = [1.0, 2.0, 3.0]
        pinecone_store.upsert_vector(
            vector, "test_id", "test_namespace", {"meta": "data"}
        )
        index_cls.return_value.upsert.assert_called()
def test_load_entry():
    """load_entry is expected to call fetch on the pinecone index."""
    with patch("pinecone.init"), patch("pinecone.Index") as index_cls:
        pinecone_store = PineconeVectorStore(
            api_key=api_key, index_name="test_index", environment="test_env"
        )
        pinecone_store.load_entry("test_id", "test_namespace")
        index_mock = index_cls.return_value
        index_mock.fetch.assert_called()
def test_load_entries():
    """load_entries is expected to call query on the pinecone index."""
    with patch("pinecone.init"), patch("pinecone.Index") as index_cls:
        pinecone_store = PineconeVectorStore(
            api_key=api_key, index_name="test_index", environment="test_env"
        )
        pinecone_store.load_entries("test_namespace")
        index_mock = index_cls.return_value
        index_mock.query.assert_called()
def test_query():
    """query is expected to call query on the pinecone index."""
    with patch("pinecone.init"), patch("pinecone.Index") as index_cls:
        pinecone_store = PineconeVectorStore(
            api_key=api_key, index_name="test_index", environment="test_env"
        )
        pinecone_store.query("test_query", 10, "test_namespace")
        index_mock = index_cls.return_value
        index_mock.query.assert_called()
def test_create_index():
    """create_index is expected to call pinecone.create_index."""
    with patch("pinecone.init"), patch("pinecone.Index"), patch(
        "pinecone.create_index"
    ) as create_index_mock:
        pinecone_store = PineconeVectorStore(
            api_key=api_key, index_name="test_index", environment="test_env"
        )
        pinecone_store.create_index("test_index")
        create_index_mock.assert_called()
Loading…
Cancel
Save