parent
20ad31543a
commit
51a271172f
@ -1,4 +1,5 @@
|
|||||||
from swarms.memory.base_vectordb import VectorDatabase
|
from swarms.memory.base_vectordb import VectorDatabase
|
||||||
from swarms.memory.short_term_memory import ShortTermMemory
|
from swarms.memory.short_term_memory import ShortTermMemory
|
||||||
|
from swarms.memory.sqlite import SQLiteDB
|
||||||
|
|
||||||
__all__ = ["VectorDatabase", "ShortTermMemory"]
|
__all__ = ["VectorDatabase", "ShortTermMemory", "SQLiteDB"]
|
||||||
|
@ -0,0 +1,120 @@
|
|||||||
|
from typing import List, Tuple, Any, Optional
|
||||||
|
from swarms.memory.base_vectordb import VectorDatabase
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sqlite3
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install sqlite3 to use the SQLiteDB class."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteDB(VectorDatabase):
|
||||||
|
"""
|
||||||
|
A reusable class for SQLite database operations with methods for adding,
|
||||||
|
deleting, updating, and querying data.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
db_path (str): The file path to the SQLite database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str):
|
||||||
|
"""
|
||||||
|
Initializes the SQLiteDB class with the given database path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path (str): The file path to the SQLite database.
|
||||||
|
"""
|
||||||
|
self.db_path = db_path
|
||||||
|
|
||||||
|
def execute_query(
|
||||||
|
self, query: str, params: Optional[Tuple[Any, ...]] = None
|
||||||
|
) -> List[Tuple]:
|
||||||
|
"""
|
||||||
|
Executes a SQL query and returns fetched results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The SQL query to execute.
|
||||||
|
params (Tuple[Any, ...], optional): The parameters to substitute into the query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple]: The results fetched from the database.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params or ())
|
||||||
|
return cursor.fetchall()
|
||||||
|
except Exception as error:
|
||||||
|
print(f"Error executing query: {error}")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def add(self, query: str, params: Tuple[Any, ...]) -> None:
|
||||||
|
"""
|
||||||
|
Adds a new entry to the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The SQL query for insertion.
|
||||||
|
params (Tuple[Any, ...]): The parameters to substitute into the query.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params)
|
||||||
|
conn.commit()
|
||||||
|
except Exception as error:
|
||||||
|
print(f"Error adding new entry: {error}")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def delete(self, query: str, params: Tuple[Any, ...]) -> None:
|
||||||
|
"""
|
||||||
|
Deletes an entry from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The SQL query for deletion.
|
||||||
|
params (Tuple[Any, ...]): The parameters to substitute into the query.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params)
|
||||||
|
conn.commit()
|
||||||
|
except Exception as error:
|
||||||
|
print(f"Error deleting entry: {error}")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def update(self, query: str, params: Tuple[Any, ...]) -> None:
|
||||||
|
"""
|
||||||
|
Updates an entry in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The SQL query for updating.
|
||||||
|
params (Tuple[Any, ...]): The parameters to substitute into the query.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(query, params)
|
||||||
|
conn.commit()
|
||||||
|
except Exception as error:
|
||||||
|
print(f"Error updating entry: {error}")
|
||||||
|
raise error
|
||||||
|
|
||||||
|
def query(
|
||||||
|
self, query: str, params: Optional[Tuple[Any, ...]] = None
|
||||||
|
) -> List[Tuple]:
|
||||||
|
"""
|
||||||
|
Fetches data from the database based on a query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The SQL query to execute.
|
||||||
|
params (Tuple[Any, ...], optional): The parameters to substitute into the query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple]: The results fetched from the database.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self.execute_query(query, params)
|
||||||
|
except Exception as error:
|
||||||
|
print(f"Error querying database: {error}")
|
||||||
|
raise error
|
@ -0,0 +1,104 @@
|
|||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
from swarms.memory.sqlite import SQLiteDB
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db():
|
||||||
|
conn = sqlite3.connect(":memory:")
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return SQLiteDB(":memory:")
|
||||||
|
|
||||||
|
|
||||||
|
def test_add(db):
|
||||||
|
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||||
|
result = db.query("SELECT * FROM test")
|
||||||
|
assert result == [(1, "test")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete(db):
|
||||||
|
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||||
|
db.delete("DELETE FROM test WHERE name = ?", ("test",))
|
||||||
|
result = db.query("SELECT * FROM test")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_update(db):
|
||||||
|
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||||
|
db.update(
|
||||||
|
"UPDATE test SET name = ? WHERE name = ?", ("new", "test")
|
||||||
|
)
|
||||||
|
result = db.query("SELECT * FROM test")
|
||||||
|
assert result == [(1, "new")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_query(db):
|
||||||
|
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||||
|
result = db.query("SELECT * FROM test WHERE name = ?", ("test",))
|
||||||
|
assert result == [(1, "test")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_query(db):
|
||||||
|
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||||
|
result = db.execute_query(
|
||||||
|
"SELECT * FROM test WHERE name = ?", ("test",)
|
||||||
|
)
|
||||||
|
assert result == [(1, "test")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_without_params(db):
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
db.add("INSERT INTO test (name) VALUES (?)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_without_params(db):
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
db.delete("DELETE FROM test WHERE name = ?")
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_without_params(db):
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
db.update("UPDATE test SET name = ? WHERE name = ?")
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_without_params(db):
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
db.query("SELECT * FROM test WHERE name = ?")
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_query_without_params(db):
|
||||||
|
with pytest.raises(sqlite3.ProgrammingError):
|
||||||
|
db.execute_query("SELECT * FROM test WHERE name = ?")
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_with_wrong_query(db):
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.add("INSERT INTO wrong (name) VALUES (?)", ("test",))
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_with_wrong_query(db):
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.delete("DELETE FROM wrong WHERE name = ?", ("test",))
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_with_wrong_query(db):
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.update(
|
||||||
|
"UPDATE wrong SET name = ? WHERE name = ?",
|
||||||
|
("new", "test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_with_wrong_query(db):
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.query("SELECT * FROM wrong WHERE name = ?", ("test",))
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_query_with_wrong_query(db):
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.execute_query(
|
||||||
|
"SELECT * FROM wrong WHERE name = ?", ("test",)
|
||||||
|
)
|
Loading…
Reference in new issue