You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
99 lines
3.0 KiB
99 lines
3.0 KiB
import pytest
|
|
from unittest.mock import patch
|
|
from swarms.memory import PgVectorVectorStore
|
|
from dotenv import load_dotenv
|
|
import os
|
|
|
|
load_dotenv()
|
|
|
|
|
|
PSG_CONNECTION_STRING = os.getenv("PSG_CONNECTION_STRING")
|
|
|
|
|
|
def test_init():
|
|
with patch("sqlalchemy.create_engine") as MockEngine:
|
|
store = PgVectorVectorStore(
|
|
connection_string=PSG_CONNECTION_STRING,
|
|
table_name="test",
|
|
)
|
|
MockEngine.assert_called_once()
|
|
assert store.engine == MockEngine.return_value
|
|
|
|
|
|
def test_init_exception():
|
|
with pytest.raises(ValueError):
|
|
PgVectorVectorStore(
|
|
connection_string="mysql://root:password@localhost:3306/test",
|
|
table_name="test",
|
|
)
|
|
|
|
|
|
def test_setup():
|
|
with patch("sqlalchemy.create_engine") as MockEngine:
|
|
store = PgVectorVectorStore(
|
|
connection_string=PSG_CONNECTION_STRING,
|
|
table_name="test",
|
|
)
|
|
store.setup()
|
|
MockEngine.execute.assert_called()
|
|
|
|
|
|
def test_upsert_vector():
|
|
with patch("sqlalchemy.create_engine"), patch(
|
|
"sqlalchemy.orm.Session"
|
|
) as MockSession:
|
|
store = PgVectorVectorStore(
|
|
connection_string=PSG_CONNECTION_STRING,
|
|
table_name="test",
|
|
)
|
|
store.upsert_vector(
|
|
[1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
|
|
)
|
|
MockSession.assert_called()
|
|
MockSession.return_value.merge.assert_called()
|
|
MockSession.return_value.commit.assert_called()
|
|
|
|
|
|
def test_load_entry():
|
|
with patch("sqlalchemy.create_engine"), patch(
|
|
"sqlalchemy.orm.Session"
|
|
) as MockSession:
|
|
store = PgVectorVectorStore(
|
|
connection_string=PSG_CONNECTION_STRING,
|
|
table_name="test",
|
|
)
|
|
store.load_entry("test_id", "test_namespace")
|
|
MockSession.assert_called()
|
|
MockSession.return_value.get.assert_called()
|
|
|
|
|
|
def test_load_entries():
|
|
with patch("sqlalchemy.create_engine"), patch(
|
|
"sqlalchemy.orm.Session"
|
|
) as MockSession:
|
|
store = PgVectorVectorStore(
|
|
connection_string=PSG_CONNECTION_STRING,
|
|
table_name="test",
|
|
)
|
|
store.load_entries("test_namespace")
|
|
MockSession.assert_called()
|
|
MockSession.return_value.query.assert_called()
|
|
MockSession.return_value.query.return_value.filter_by.assert_called()
|
|
MockSession.return_value.query.return_value.all.assert_called()
|
|
|
|
|
|
def test_query():
|
|
with patch("sqlalchemy.create_engine"), patch(
|
|
"sqlalchemy.orm.Session"
|
|
) as MockSession:
|
|
store = PgVectorVectorStore(
|
|
connection_string=PSG_CONNECTION_STRING,
|
|
table_name="test",
|
|
)
|
|
store.query("test_query", 10, "test_namespace")
|
|
MockSession.assert_called()
|
|
MockSession.return_value.query.assert_called()
|
|
MockSession.return_value.query.return_value.filter_by.assert_called()
|
|
MockSession.return_value.query.return_value.limit.assert_called()
|
|
MockSession.return_value.query.return_value.all.assert_called()
|