You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
3.1 KiB
92 lines
3.1 KiB
1 year ago
|
import pytest
|
||
|
from unittest.mock import patch
|
||
|
from swarms.memory import PgVectorVectorStore
|
||
|
|
||
|
|
||
|
def test_init():
    """Constructing the store from a Postgres URL builds its engine via create_engine."""
    # Patch create_engine where PgVectorVectorStore looks it up (its defining
    # module), not on the sqlalchemy package. The class is imported at module
    # import time, so a `from sqlalchemy import create_engine` in the
    # implementation would already hold the real function and a patch on
    # "sqlalchemy.create_engine" would never be observed.
    # NOTE(review): assumes the implementation module imports create_engine by
    # name; confirm against the module backing PgVectorVectorStore.
    create_engine_target = f"{PgVectorVectorStore.__module__}.create_engine"
    with patch(create_engine_target) as MockEngine:
        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        MockEngine.assert_called_once()
        assert store.engine == MockEngine.return_value
|
||
|
|
||
|
|
||
|
def test_init_exception():
    """A non-Postgres (MySQL) connection string must be rejected with ValueError."""
    mysql_url = "mysql://root:password@localhost:3306/test"
    with pytest.raises(ValueError):
        PgVectorVectorStore(connection_string=mysql_url, table_name="test")
|
||
|
|
||
|
|
||
|
def test_setup():
    """setup() issues at least one SQL execute through the engine from create_engine."""
    # Patch create_engine in the module that defines PgVectorVectorStore, because
    # a by-name import there would bypass a patch on "sqlalchemy.create_engine".
    # NOTE(review): assumes a by-name import of create_engine; confirm.
    create_engine_target = f"{PgVectorVectorStore.__module__}.create_engine"
    with patch(create_engine_target) as MockEngine:
        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.setup()
        # The store holds the *engine* (create_engine's return value), not the
        # create_engine function, so asserting on `MockEngine.execute` can never
        # succeed. Instead look for an `execute` call anywhere beneath the engine
        # mock. mock_calls records calls on children and return values, so this
        # matches engine.execute(...), engine.connect().execute(...), and
        # `with engine.begin() as conn: conn.execute(...)` alike.
        engine = MockEngine.return_value
        executed = [
            name
            for name, _args, _kwargs in engine.mock_calls
            if name.endswith("execute")
        ]
        assert executed, "setup() did not execute any SQL through the engine"
|
||
|
|
||
|
|
||
|
def test_upsert_vector():
    """upsert_vector() opens a Session, merges the row and commits it."""
    # Patch collaborators where PgVectorVectorStore looks them up (its defining
    # module). Patches on "sqlalchemy.create_engine" / "sqlalchemy.orm.Session"
    # would not intercept names imported into that module before the test runs.
    # NOTE(review): assumes both names are imported by name there; confirm.
    module = PgVectorVectorStore.__module__
    with patch(f"{module}.create_engine"), patch(f"{module}.Session") as MockSession:
        # Whether the store uses `Session(engine)` directly or as
        # `with Session(engine) as session:`, route both to one mock so the
        # assertions below see the calls. By default MagicMock.__enter__ returns
        # a *different* child mock.
        session = MockSession.return_value
        session.__enter__.return_value = session

        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.upsert_vector(
            [1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
        )

        MockSession.assert_called()
        session.merge.assert_called()
        session.commit.assert_called()
|
||
|
|
||
|
|
||
|
def test_load_entry():
    """load_entry() opens a Session and fetches the row with Session.get."""
    # Patch collaborators in PgVectorVectorStore's defining module so that
    # by-name imports there are intercepted.
    # NOTE(review): assumes create_engine and Session are imported by name; confirm.
    module = PgVectorVectorStore.__module__
    with patch(f"{module}.create_engine"), patch(f"{module}.Session") as MockSession:
        # Make `with Session(...) as session:` yield the same mock as
        # `Session(...)`, so the assertions hold for either usage style.
        session = MockSession.return_value
        session.__enter__.return_value = session

        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.load_entry("test_id", "test_namespace")

        MockSession.assert_called()
        session.get.assert_called()
|
||
|
|
||
|
|
||
|
def test_load_entries():
    """load_entries() queries the table, filters by namespace and fetches all rows."""
    # Patch collaborators in PgVectorVectorStore's defining module so that
    # by-name imports there are intercepted.
    # NOTE(review): assumes create_engine and Session are imported by name; confirm.
    module = PgVectorVectorStore.__module__
    with patch(f"{module}.create_engine"), patch(f"{module}.Session") as MockSession:
        # Make `with Session(...) as session:` yield the same mock as `Session(...)`.
        session = MockSession.return_value
        session.__enter__.return_value = session
        # Make the query builder fluent: filter_by() returns the same query mock.
        # Without this, a chained `query(...).filter_by(...).all()` calls `all`
        # on filter_by's return value, and asserting that both filter_by and all
        # were called on `query.return_value` could never pass together.
        query = session.query.return_value
        query.filter_by.return_value = query

        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.load_entries("test_namespace")

        MockSession.assert_called()
        session.query.assert_called()
        query.filter_by.assert_called()
        query.all.assert_called()
|
||
|
|
||
|
|
||
|
def test_query():
    """query() builds a filtered, limited query through a Session and fetches all rows."""
    # Patch collaborators in PgVectorVectorStore's defining module so that
    # by-name imports there are intercepted.
    # NOTE(review): assumes create_engine and Session are imported by name; confirm.
    # NOTE(review): if query() embeds the query text through an embedding
    # driver, that driver may also need stubbing; confirm against the implementation.
    module = PgVectorVectorStore.__module__
    with patch(f"{module}.create_engine"), patch(f"{module}.Session") as MockSession:
        # Make `with Session(...) as session:` yield the same mock as `Session(...)`.
        session = MockSession.return_value
        session.__enter__.return_value = session
        # Make every builder step return the same query mock. Without this, a
        # chained call such as `query(...).filter_by(...).limit(n).all()` puts
        # `limit` / `all` on intermediate return values, and the assertions on
        # `query.return_value` below could not all pass.
        query = session.query.return_value
        for builder_method in ("filter_by", "order_by", "limit"):
            getattr(query, builder_method).return_value = query

        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.query("test_query", 10, "test_namespace")

        MockSession.assert_called()
        session.query.assert_called()
        query.filter_by.assert_called()
        query.limit.assert_called()
        query.all.assert_called()
|