You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.4 KiB
81 lines
2.4 KiB
5 months ago
|
import os
|
||
|
from unittest.mock import patch
|
||
|
|
||
|
from dotenv import load_dotenv
|
||
|
|
||
|
from playground.memory.pg import PostgresDB
|
||
|
|
||
|
# Load variables from a .env file into the process environment so the
# connection string below can be configured without exporting it by hand.
load_dotenv()

# Connection string handed to every PostgresDB constructed by the tests.
# This is None when the variable is not set; the tests patch the engine
# factory, so (assuming those patches take effect) no real connection is made.
PSG_CONNECTION_STRING = os.environ.get("PSG_CONNECTION_STRING")
|
||
|
|
||
|
|
||
|
def test_init():
    """Constructing PostgresDB should call the engine factory once and keep its result."""
    # NOTE(review): this replaces ``create_engine`` on the ``sqlalchemy``
    # module only. If playground.memory.pg binds the name directly
    # (``from sqlalchemy import create_engine``), this patch will not
    # intercept the call and the target would need to be
    # "playground.memory.pg.create_engine" instead -- confirm against pg.py.
    with patch("sqlalchemy.create_engine") as fake_create_engine:
        store = PostgresDB(
            connection_string=PSG_CONNECTION_STRING,
            table_name="test",
        )
        fake_create_engine.assert_called_once()
        # The instance should hold exactly what the (mocked) factory returned.
        produced_engine = fake_create_engine.return_value
        assert store.engine == produced_engine
|
||
|
|
||
|
|
||
|
def test_create_vector_model():
    """The ORM model built by ``_create_vector_model`` should target the configured table."""
    table = "test"
    with patch("sqlalchemy.create_engine"):
        store = PostgresDB(
            connection_string=PSG_CONNECTION_STRING,
            table_name=table,
        )
        vector_model = store._create_vector_model()
        assert vector_model.__tablename__ == table
|
||
|
|
||
|
|
||
|
def test_add_or_update_vector():
    """Upserting a vector should open a session, merge the row, and commit."""
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as fake_session_cls:
            store = PostgresDB(
                connection_string=PSG_CONNECTION_STRING,
                table_name="test",
            )
            store.add_or_update_vector(
                "test_vector",
                "test_id",
                "test_namespace",
                {"meta": "data"},
            )

            fake_session_cls.assert_called()
            # NOTE(review): assumes the method uses the object returned by
            # Session(...) directly. If it is used as ``with Session(...) as s``,
            # calls land on ``return_value.__enter__.return_value`` instead --
            # confirm against pg.py.
            fake_session = fake_session_cls.return_value
            fake_session.merge.assert_called()
            fake_session.commit.assert_called()
|
||
|
|
||
|
|
||
|
def test_query_vectors():
    """Querying should build a session query, apply both filters, and fetch all rows."""
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as fake_session_cls:
            store = PostgresDB(
                connection_string=PSG_CONNECTION_STRING,
                table_name="test",
            )
            store.query_vectors("test_query", "test_namespace")

            fake_session_cls.assert_called()
            fake_session = fake_session_cls.return_value
            fake_session.query.assert_called()

            # NOTE(review): these checks expect filter_by/filter/all to be
            # called on the same object returned by ``query(...)``. If
            # query_vectors chains them (query().filter_by().filter()), the
            # later calls live on each step's ``return_value`` -- confirm.
            fake_query = fake_session.query.return_value
            for step in (fake_query.filter_by, fake_query.filter, fake_query.all):
                step.assert_called()
|
||
|
|
||
|
|
||
|
def test_delete_vector():
    """Deleting should look the row up, delete it, and commit the session."""
    with patch("sqlalchemy.create_engine"):
        with patch("sqlalchemy.orm.Session") as fake_session_cls:
            store = PostgresDB(
                connection_string=PSG_CONNECTION_STRING,
                table_name="test",
            )
            store.delete_vector("test_id")

            fake_session_cls.assert_called()
            # Checked in order: lookup, removal, then commit.
            fake_session = fake_session_cls.return_value
            for step in (fake_session.get, fake_session.delete, fake_session.commit):
                step.assert_called()
|