import pytest
from unittest.mock import patch
from swarms.memory import PgVectorVectorStore


def test_init():
    """Building the store should create exactly one SQLAlchemy engine and keep it on ``engine``."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    # NOTE(review): patching "sqlalchemy.create_engine" only takes effect if the
    # store looks the name up via the ``sqlalchemy`` module at call time; a
    # ``from sqlalchemy import create_engine`` in the store's module would bypass
    # this patch — confirm against the implementation.
    with patch("sqlalchemy.create_engine") as mock_create_engine:
        vector_store = PgVectorVectorStore(
            connection_string=dsn,
            table_name="test",
        )
        mock_create_engine.assert_called_once()
        assert vector_store.engine == mock_create_engine.return_value


def test_init_exception():
    """A non-PostgreSQL (MySQL) connection string is expected to raise ``ValueError``."""
    mysql_dsn = "mysql://root:password@localhost:3306/test"
    with pytest.raises(ValueError):
        PgVectorVectorStore(connection_string=mysql_dsn, table_name="test")


def test_setup():
    """``setup()`` should issue at least one ``execute`` through the store's engine.

    The engine held by the store is the *return value* of the patched
    ``create_engine``, not the patched function itself, so the assertion must
    inspect ``create_engine.return_value``. Any ``execute`` call recorded
    anywhere under that mock is accepted. This covers ``engine.execute(...)``
    as well as calls made on a connection or context the engine hands out,
    such as ``engine.begin().__enter__().execute(...)``.
    """
    with patch("sqlalchemy.create_engine") as mock_create_engine:
        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.setup()
        mock_engine = mock_create_engine.return_value
        # mock_calls entries are (name, args, kwargs); names look like
        # "execute" or "begin().__enter__().execute" for chained calls.
        call_names = [name for name, _args, _kwargs in mock_engine.mock_calls]
        assert any(
            name.rsplit(".", 1)[-1] == "execute" for name in call_names
        ), f"setup() issued no execute() through the engine; calls: {call_names}"


def test_upsert_vector():
    """Upserting a vector should open a session, merge the row into it and commit."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"), patch(
        "sqlalchemy.orm.Session"
    ) as mock_session_cls:
        vector_store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        vector_store.upsert_vector(
            [1.0, 2.0, 3.0],
            "test_id",
            "test_namespace",
            {"meta": "data"},
        )
        # NOTE(review): this inspects Session(...)'s return value directly; if
        # upsert_vector uses ``with Session(...) as session:``, the session
        # would be ``return_value.__enter__.return_value`` instead — confirm.
        session = mock_session_cls.return_value
        mock_session_cls.assert_called()
        session.merge.assert_called()
        session.commit.assert_called()


def test_load_entry():
    """Loading a single entry should open a session and fetch the row with ``get``."""
    dsn = "postgresql://postgres:password@localhost:5432/postgres"
    with patch("sqlalchemy.create_engine"), patch(
        "sqlalchemy.orm.Session"
    ) as mock_session_cls:
        vector_store = PgVectorVectorStore(connection_string=dsn, table_name="test")
        vector_store.load_entry("test_id", "test_namespace")
        # NOTE(review): assumes Session(...) is used directly rather than as a
        # context manager — confirm against the implementation.
        session = mock_session_cls.return_value
        mock_session_cls.assert_called()
        session.get.assert_called()


def test_load_entries():
    """Loading a namespace should query, filter by namespace, then fetch all rows.

    SQLAlchemy's ``Query.filter_by()`` is generative: it returns a *new* Query
    rather than mutating the original. The rows therefore have to be fetched
    with ``.all()`` on the filtered query (``filter_by.return_value``), not on
    the unfiltered one. The previous assertion checked the unfiltered query,
    which only passes if the namespace filter is discarded.
    """
    with patch("sqlalchemy.create_engine"), patch(
        "sqlalchemy.orm.Session"
    ) as MockSession:
        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.load_entries("test_namespace")
        # NOTE(review): assumes Session(...) is used directly rather than as a
        # context manager — confirm against the implementation.
        session = MockSession.return_value
        MockSession.assert_called()
        session.query.assert_called()
        base_query = session.query.return_value
        base_query.filter_by.assert_called()
        # .all() must run on the filtered query returned by filter_by().
        base_query.filter_by.return_value.all.assert_called()


def test_query():
    """A similarity query should filter by namespace, limit the result count and fetch rows.

    SQLAlchemy Query methods such as ``filter_by`` and ``limit`` are
    generative: each returns a new Query. The assertions therefore follow the
    chain ``query(...).filter_by(...).limit(...).all()``. The previous version
    asserted every step on the unfiltered ``query(...)`` result, which a
    correctly chained implementation never satisfies.
    """
    with patch("sqlalchemy.create_engine"), patch(
        "sqlalchemy.orm.Session"
    ) as MockSession:
        store = PgVectorVectorStore(
            connection_string="postgresql://postgres:password@localhost:5432/postgres",
            table_name="test",
        )
        store.query("test_query", 10, "test_namespace")
        # NOTE(review): assumes Session(...) is used directly (not via ``with``)
        # and that filter_by is applied directly on the query(...) result before
        # limit; any extra step in between (e.g. order_by) would shift the
        # chain — confirm against the implementation.
        session = MockSession.return_value
        MockSession.assert_called()
        session.query.assert_called()
        base_query = session.query.return_value
        base_query.filter_by.assert_called()
        filtered_query = base_query.filter_by.return_value
        filtered_query.limit.assert_called()
        filtered_query.limit.return_value.all.assert_called()