parent
7a504623c4
commit
63ea4187c3
@ -0,0 +1 @@
|
|||||||
|
from swarms.embeddings.pegasus import PegasusEmbedding
|
@ -1,3 +1,4 @@
|
|||||||
from swarms.memory.vector_stores.pinecone import PineconeVector
|
from swarms.memory.vector_stores.pinecone import PineconeVector
|
||||||
from swarms.memory.vector_stores.base import BaseVectorStore
|
from swarms.memory.vector_stores.base import BaseVectorStore
|
||||||
from swarms.memory.vector_stores.pg import PgVectorVectorStore
|
from swarms.memory.vector_stores.pg import PgVectorVectorStore
|
||||||
|
from swarms.memory.ocean import OceanDB
|
||||||
|
@ -0,0 +1,95 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from swarms.memory.oceandb import OceanDB
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
MockClient.return_value.heartbeat.return_value = "OK"
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
MockClient.assert_called_once()
|
||||||
|
assert db.client == MockClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_exception():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
MockClient.side_effect = Exception("Client error")
|
||||||
|
with pytest.raises(Exception) as e:
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
assert str(e.value) == "Client error"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_collection():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
db.create_collection("test", "modality")
|
||||||
|
MockClient.create_collection.assert_called_once_with(
|
||||||
|
"test", embedding_function=Mock.ANY
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_collection_exception():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
MockClient.create_collection.side_effect = Exception("Create collection error")
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
with pytest.raises(Exception) as e:
|
||||||
|
db.create_collection("test", "modality")
|
||||||
|
assert str(e.value) == "Create collection error"
|
||||||
|
|
||||||
|
|
||||||
|
def test_append_document():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
collection = Mock()
|
||||||
|
db.append_document(collection, "doc", "id")
|
||||||
|
collection.add.assert_called_once_with(documents=["doc"], ids=["id"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_append_document_exception():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
collection = Mock()
|
||||||
|
collection.add.side_effect = Exception("Append document error")
|
||||||
|
with pytest.raises(Exception) as e:
|
||||||
|
db.append_document(collection, "doc", "id")
|
||||||
|
assert str(e.value) == "Append document error"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_documents():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
collection = Mock()
|
||||||
|
db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"])
|
||||||
|
collection.add.assert_called_once_with(
|
||||||
|
documents=["doc1", "doc2"], ids=["id1", "id2"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_documents_exception():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
collection = Mock()
|
||||||
|
collection.add.side_effect = Exception("Add documents error")
|
||||||
|
with pytest.raises(Exception) as e:
|
||||||
|
db.add_documents(collection, ["doc1", "doc2"], ["id1", "id2"])
|
||||||
|
assert str(e.value) == "Add documents error"
|
||||||
|
|
||||||
|
|
||||||
|
def test_query():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
collection = Mock()
|
||||||
|
db.query(collection, ["query1", "query2"], 2)
|
||||||
|
collection.query.assert_called_once_with(
|
||||||
|
query_texts=["query1", "query2"], n_results=2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_exception():
|
||||||
|
with patch("oceandb.Client") as MockClient:
|
||||||
|
db = OceanDB(MockClient)
|
||||||
|
collection = Mock()
|
||||||
|
collection.query.side_effect = Exception("Query error")
|
||||||
|
with pytest.raises(Exception) as e:
|
||||||
|
db.query(collection, ["query1", "query2"], 2)
|
||||||
|
assert str(e.value) == "Query error"
|
@ -0,0 +1,91 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from swarms.memory import PgVectorVectorStore
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
with patch("sqlalchemy.create_engine") as MockEngine:
|
||||||
|
store = PgVectorVectorStore(
|
||||||
|
connection_string="postgresql://postgres:password@localhost:5432/postgres",
|
||||||
|
table_name="test",
|
||||||
|
)
|
||||||
|
MockEngine.assert_called_once()
|
||||||
|
assert store.engine == MockEngine.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_exception():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
PgVectorVectorStore(
|
||||||
|
connection_string="mysql://root:password@localhost:3306/test",
|
||||||
|
table_name="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup():
|
||||||
|
with patch("sqlalchemy.create_engine") as MockEngine:
|
||||||
|
store = PgVectorVectorStore(
|
||||||
|
connection_string="postgresql://postgres:password@localhost:5432/postgres",
|
||||||
|
table_name="test",
|
||||||
|
)
|
||||||
|
store.setup()
|
||||||
|
MockEngine.execute.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_vector():
|
||||||
|
with patch("sqlalchemy.create_engine"), patch(
|
||||||
|
"sqlalchemy.orm.Session"
|
||||||
|
) as MockSession:
|
||||||
|
store = PgVectorVectorStore(
|
||||||
|
connection_string="postgresql://postgres:password@localhost:5432/postgres",
|
||||||
|
table_name="test",
|
||||||
|
)
|
||||||
|
store.upsert_vector(
|
||||||
|
[1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
|
||||||
|
)
|
||||||
|
MockSession.assert_called()
|
||||||
|
MockSession.return_value.merge.assert_called()
|
||||||
|
MockSession.return_value.commit.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_entry():
|
||||||
|
with patch("sqlalchemy.create_engine"), patch(
|
||||||
|
"sqlalchemy.orm.Session"
|
||||||
|
) as MockSession:
|
||||||
|
store = PgVectorVectorStore(
|
||||||
|
connection_string="postgresql://postgres:password@localhost:5432/postgres",
|
||||||
|
table_name="test",
|
||||||
|
)
|
||||||
|
store.load_entry("test_id", "test_namespace")
|
||||||
|
MockSession.assert_called()
|
||||||
|
MockSession.return_value.get.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_entries():
|
||||||
|
with patch("sqlalchemy.create_engine"), patch(
|
||||||
|
"sqlalchemy.orm.Session"
|
||||||
|
) as MockSession:
|
||||||
|
store = PgVectorVectorStore(
|
||||||
|
connection_string="postgresql://postgres:password@localhost:5432/postgres",
|
||||||
|
table_name="test",
|
||||||
|
)
|
||||||
|
store.load_entries("test_namespace")
|
||||||
|
MockSession.assert_called()
|
||||||
|
MockSession.return_value.query.assert_called()
|
||||||
|
MockSession.return_value.query.return_value.filter_by.assert_called()
|
||||||
|
MockSession.return_value.query.return_value.all.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_query():
|
||||||
|
with patch("sqlalchemy.create_engine"), patch(
|
||||||
|
"sqlalchemy.orm.Session"
|
||||||
|
) as MockSession:
|
||||||
|
store = PgVectorVectorStore(
|
||||||
|
connection_string="postgresql://postgres:password@localhost:5432/postgres",
|
||||||
|
table_name="test",
|
||||||
|
)
|
||||||
|
store.query("test_query", 10, "test_namespace")
|
||||||
|
MockSession.assert_called()
|
||||||
|
MockSession.return_value.query.assert_called()
|
||||||
|
MockSession.return_value.query.return_value.filter_by.assert_called()
|
||||||
|
MockSession.return_value.query.return_value.limit.assert_called()
|
||||||
|
MockSession.return_value.query.return_value.all.assert_called()
|
@ -0,0 +1,65 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from swarms.memory import PineconeVectorStore
|
||||||
|
|
||||||
|
api_key = os.getenv("PINECONE_API_KEY") or ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
|
||||||
|
store = PineconeVectorStore(
|
||||||
|
api_key=api_key, index_name="test_index", environment="test_env"
|
||||||
|
)
|
||||||
|
MockInit.assert_called_once()
|
||||||
|
MockIndex.assert_called_once()
|
||||||
|
assert store.index == MockIndex.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_vector():
|
||||||
|
with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
|
||||||
|
store = PineconeVectorStore(
|
||||||
|
api_key=api_key, index_name="test_index", environment="test_env"
|
||||||
|
)
|
||||||
|
store.upsert_vector(
|
||||||
|
[1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
|
||||||
|
)
|
||||||
|
MockIndex.return_value.upsert.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_entry():
|
||||||
|
with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
|
||||||
|
store = PineconeVectorStore(
|
||||||
|
api_key=api_key, index_name="test_index", environment="test_env"
|
||||||
|
)
|
||||||
|
store.load_entry("test_id", "test_namespace")
|
||||||
|
MockIndex.return_value.fetch.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_entries():
|
||||||
|
with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
|
||||||
|
store = PineconeVectorStore(
|
||||||
|
api_key=api_key, index_name="test_index", environment="test_env"
|
||||||
|
)
|
||||||
|
store.load_entries("test_namespace")
|
||||||
|
MockIndex.return_value.query.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_query():
|
||||||
|
with patch("pinecone.init") as MockInit, patch("pinecone.Index") as MockIndex:
|
||||||
|
store = PineconeVectorStore(
|
||||||
|
api_key=api_key, index_name="test_index", environment="test_env"
|
||||||
|
)
|
||||||
|
store.query("test_query", 10, "test_namespace")
|
||||||
|
MockIndex.return_value.query.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_index():
|
||||||
|
with patch("pinecone.init") as MockInit, patch(
|
||||||
|
"pinecone.Index"
|
||||||
|
) as MockIndex, patch("pinecone.create_index") as MockCreateIndex:
|
||||||
|
store = PineconeVectorStore(
|
||||||
|
api_key=api_key, index_name="test_index", environment="test_env"
|
||||||
|
)
|
||||||
|
store.create_index("test_index")
|
||||||
|
MockCreateIndex.assert_called()
|
Loading…
Reference in new issue