You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/memory/test_pq_db.py

81 lines
2.4 KiB

6 months ago
import os
from unittest.mock import patch
from dotenv import load_dotenv
from playground.memory.pg import PostgresDB
# Load variables from a .env file (python-dotenv) before reading settings.
load_dotenv()

# Connection string passed to PostgresDB by the tests in this module;
# None when the PSG_CONNECTION_STRING environment variable is not set.
PSG_CONNECTION_STRING = os.environ.get("PSG_CONNECTION_STRING")
def test_init():
    """Constructing PostgresDB calls create_engine once and stores the result.

    ``sqlalchemy.create_engine`` is patched, so no real engine is created.
    """
    # NOTE(review): the patch targets the ``sqlalchemy`` namespace; if the
    # module under test does ``from sqlalchemy import create_engine`` the
    # patch would not intercept the call — confirm against playground.memory.pg.
    with patch("sqlalchemy.create_engine") as create_engine_mock:
        store = PostgresDB(
            connection_string=PSG_CONNECTION_STRING,
            table_name="test",
        )

        create_engine_mock.assert_called_once()

        expected_engine = create_engine_mock.return_value
        assert store.engine == expected_engine
def test_create_vector_model():
    """The model built by ``_create_vector_model`` uses the given table name.

    ``sqlalchemy.create_engine`` is patched for the duration of the test.
    """
    with patch("sqlalchemy.create_engine"):
        store = PostgresDB(
            connection_string=PSG_CONNECTION_STRING,
            table_name="test",
        )
        vector_model = store._create_vector_model()
        assert vector_model.__tablename__ == "test"
def test_add_or_update_vector():
    """``add_or_update_vector`` opens a session, merges and commits.

    Both ``sqlalchemy.create_engine`` and ``sqlalchemy.orm.Session`` are
    patched; the assertions inspect the patched Session class and the
    instance it returns.
    """
    # NOTE(review): assertions are made on ``Session.return_value`` directly;
    # if the code under test uses the session as a context manager, the calls
    # would land on ``__enter__.return_value`` instead — confirm.
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as session_cls:
            store = PostgresDB(
                connection_string=PSG_CONNECTION_STRING,
                table_name="test",
            )
            # Four positional arguments (presumably vector, id, namespace and
            # metadata — verify against the PostgresDB signature).
            store.add_or_update_vector(
                "test_vector",
                "test_id",
                "test_namespace",
                {"meta": "data"},
            )

            session_cls.assert_called()

            session = session_cls.return_value
            session.merge.assert_called()
            session.commit.assert_called()
def test_query_vectors():
    """``query_vectors`` opens a session and runs a filtered query.

    With ``sqlalchemy.create_engine`` and ``sqlalchemy.orm.Session`` patched,
    the test checks that the session's ``query`` was called and that
    ``filter_by``, ``filter`` and ``all`` were each called on the object
    returned by ``query``.
    """
    # NOTE(review): ``filter_by``, ``filter`` and ``all`` are all asserted on
    # ``query.return_value`` itself, not on a chained result — confirm this
    # matches how the implementation builds its query.
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as session_cls:
            store = PostgresDB(
                connection_string=PSG_CONNECTION_STRING,
                table_name="test",
            )
            store.query_vectors("test_query", "test_namespace")

            session_cls.assert_called()

            session = session_cls.return_value
            session.query.assert_called()

            query_result = session.query.return_value
            query_result.filter_by.assert_called()
            query_result.filter.assert_called()
            query_result.all.assert_called()
def test_delete_vector():
    """``delete_vector`` opens a session, fetches, deletes and commits.

    ``sqlalchemy.create_engine`` and ``sqlalchemy.orm.Session`` are patched;
    the test asserts that ``get``, ``delete`` and ``commit`` were each called
    on the instance returned by the patched Session class.
    """
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as session_cls:
            store = PostgresDB(
                connection_string=PSG_CONNECTION_STRING,
                table_name="test",
            )
            store.delete_vector("test_id")

            session_cls.assert_called()

            session = session_cls.return_value
            session.get.assert_called()
            session.delete.assert_called()
            session.commit.assert_called()