You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/tests/memory/pg.py

98 lines
3.0 KiB

import pytest
from unittest.mock import patch
from swarms.memory import PgVectorVectorStore
from dotenv import load_dotenv
import os
load_dotenv()
PSG_CONNECTION_STRING = os.getenv("PSG_CONNECTION_STRING")
def test_init():
with patch("sqlalchemy.create_engine") as MockEngine:
store = PgVectorVectorStore(
connection_string=PSG_CONNECTION_STRING,
table_name="test",
)
MockEngine.assert_called_once()
assert store.engine == MockEngine.return_value
def test_init_exception():
with pytest.raises(ValueError):
PgVectorVectorStore(
connection_string="mysql://root:password@localhost:3306/test",
table_name="test",
)
def test_setup():
with patch("sqlalchemy.create_engine") as MockEngine:
store = PgVectorVectorStore(
connection_string=PSG_CONNECTION_STRING,
table_name="test",
)
store.setup()
MockEngine.execute.assert_called()
def test_upsert_vector():
with patch("sqlalchemy.create_engine"), patch(
"sqlalchemy.orm.Session"
) as MockSession:
store = PgVectorVectorStore(
connection_string=PSG_CONNECTION_STRING,
table_name="test",
)
store.upsert_vector(
[1.0, 2.0, 3.0], "test_id", "test_namespace", {"meta": "data"}
)
MockSession.assert_called()
MockSession.return_value.merge.assert_called()
MockSession.return_value.commit.assert_called()
def test_load_entry():
with patch("sqlalchemy.create_engine"), patch(
"sqlalchemy.orm.Session"
) as MockSession:
store = PgVectorVectorStore(
connection_string=PSG_CONNECTION_STRING,
table_name="test",
)
store.load_entry("test_id", "test_namespace")
MockSession.assert_called()
MockSession.return_value.get.assert_called()
def test_load_entries():
with patch("sqlalchemy.create_engine"), patch(
"sqlalchemy.orm.Session"
) as MockSession:
store = PgVectorVectorStore(
connection_string=PSG_CONNECTION_STRING,
table_name="test",
)
store.load_entries("test_namespace")
MockSession.assert_called()
MockSession.return_value.query.assert_called()
MockSession.return_value.query.return_value.filter_by.assert_called()
MockSession.return_value.query.return_value.all.assert_called()
def test_query():
with patch("sqlalchemy.create_engine"), patch(
"sqlalchemy.orm.Session"
) as MockSession:
store = PgVectorVectorStore(
connection_string=PSG_CONNECTION_STRING,
table_name="test",
)
store.query("test_query", 10, "test_namespace")
MockSession.assert_called()
MockSession.return_value.query.assert_called()
MockSession.return_value.query.return_value.filter_by.assert_called()
MockSession.return_value.query.return_value.limit.assert_called()
MockSession.return_value.query.return_value.all.assert_called()