parent
20ad31543a
commit
51a271172f
@ -1,4 +1,5 @@
|
||||
from swarms.memory.base_vectordb import VectorDatabase
|
||||
from swarms.memory.short_term_memory import ShortTermMemory
|
||||
from swarms.memory.sqlite import SQLiteDB
|
||||
|
||||
__all__ = ["VectorDatabase", "ShortTermMemory"]
|
||||
__all__ = ["VectorDatabase", "ShortTermMemory", "SQLiteDB"]
|
||||
|
@ -0,0 +1,120 @@
|
||||
from typing import List, Tuple, Any, Optional
|
||||
from swarms.memory.base_vectordb import VectorDatabase
|
||||
|
||||
try:
|
||||
import sqlite3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install sqlite3 to use the SQLiteDB class."
|
||||
)
|
||||
|
||||
|
||||
class SQLiteDB(VectorDatabase):
|
||||
"""
|
||||
A reusable class for SQLite database operations with methods for adding,
|
||||
deleting, updating, and querying data.
|
||||
|
||||
Attributes:
|
||||
db_path (str): The file path to the SQLite database.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
"""
|
||||
Initializes the SQLiteDB class with the given database path.
|
||||
|
||||
Args:
|
||||
db_path (str): The file path to the SQLite database.
|
||||
"""
|
||||
self.db_path = db_path
|
||||
|
||||
def execute_query(
|
||||
self, query: str, params: Optional[Tuple[Any, ...]] = None
|
||||
) -> List[Tuple]:
|
||||
"""
|
||||
Executes a SQL query and returns fetched results.
|
||||
|
||||
Args:
|
||||
query (str): The SQL query to execute.
|
||||
params (Tuple[Any, ...], optional): The parameters to substitute into the query.
|
||||
|
||||
Returns:
|
||||
List[Tuple]: The results fetched from the database.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params or ())
|
||||
return cursor.fetchall()
|
||||
except Exception as error:
|
||||
print(f"Error executing query: {error}")
|
||||
raise error
|
||||
|
||||
def add(self, query: str, params: Tuple[Any, ...]) -> None:
|
||||
"""
|
||||
Adds a new entry to the database.
|
||||
|
||||
Args:
|
||||
query (str): The SQL query for insertion.
|
||||
params (Tuple[Any, ...]): The parameters to substitute into the query.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
conn.commit()
|
||||
except Exception as error:
|
||||
print(f"Error adding new entry: {error}")
|
||||
raise error
|
||||
|
||||
def delete(self, query: str, params: Tuple[Any, ...]) -> None:
|
||||
"""
|
||||
Deletes an entry from the database.
|
||||
|
||||
Args:
|
||||
query (str): The SQL query for deletion.
|
||||
params (Tuple[Any, ...]): The parameters to substitute into the query.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
conn.commit()
|
||||
except Exception as error:
|
||||
print(f"Error deleting entry: {error}")
|
||||
raise error
|
||||
|
||||
def update(self, query: str, params: Tuple[Any, ...]) -> None:
|
||||
"""
|
||||
Updates an entry in the database.
|
||||
|
||||
Args:
|
||||
query (str): The SQL query for updating.
|
||||
params (Tuple[Any, ...]): The parameters to substitute into the query.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
conn.commit()
|
||||
except Exception as error:
|
||||
print(f"Error updating entry: {error}")
|
||||
raise error
|
||||
|
||||
def query(
|
||||
self, query: str, params: Optional[Tuple[Any, ...]] = None
|
||||
) -> List[Tuple]:
|
||||
"""
|
||||
Fetches data from the database based on a query.
|
||||
|
||||
Args:
|
||||
query (str): The SQL query to execute.
|
||||
params (Tuple[Any, ...], optional): The parameters to substitute into the query.
|
||||
|
||||
Returns:
|
||||
List[Tuple]: The results fetched from the database.
|
||||
"""
|
||||
try:
|
||||
return self.execute_query(query, params)
|
||||
except Exception as error:
|
||||
print(f"Error querying database: {error}")
|
||||
raise error
|
@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
import sqlite3
|
||||
from swarms.memory.sqlite import SQLiteDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
conn = sqlite3.connect(":memory:")
|
||||
conn.execute(
|
||||
"CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
|
||||
)
|
||||
conn.commit()
|
||||
return SQLiteDB(":memory:")
|
||||
|
||||
|
||||
def test_add(db):
|
||||
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||
result = db.query("SELECT * FROM test")
|
||||
assert result == [(1, "test")]
|
||||
|
||||
|
||||
def test_delete(db):
|
||||
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||
db.delete("DELETE FROM test WHERE name = ?", ("test",))
|
||||
result = db.query("SELECT * FROM test")
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_update(db):
|
||||
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||
db.update(
|
||||
"UPDATE test SET name = ? WHERE name = ?", ("new", "test")
|
||||
)
|
||||
result = db.query("SELECT * FROM test")
|
||||
assert result == [(1, "new")]
|
||||
|
||||
|
||||
def test_query(db):
|
||||
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||
result = db.query("SELECT * FROM test WHERE name = ?", ("test",))
|
||||
assert result == [(1, "test")]
|
||||
|
||||
|
||||
def test_execute_query(db):
|
||||
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
|
||||
result = db.execute_query(
|
||||
"SELECT * FROM test WHERE name = ?", ("test",)
|
||||
)
|
||||
assert result == [(1, "test")]
|
||||
|
||||
|
||||
def test_add_without_params(db):
|
||||
with pytest.raises(sqlite3.ProgrammingError):
|
||||
db.add("INSERT INTO test (name) VALUES (?)")
|
||||
|
||||
|
||||
def test_delete_without_params(db):
|
||||
with pytest.raises(sqlite3.ProgrammingError):
|
||||
db.delete("DELETE FROM test WHERE name = ?")
|
||||
|
||||
|
||||
def test_update_without_params(db):
|
||||
with pytest.raises(sqlite3.ProgrammingError):
|
||||
db.update("UPDATE test SET name = ? WHERE name = ?")
|
||||
|
||||
|
||||
def test_query_without_params(db):
|
||||
with pytest.raises(sqlite3.ProgrammingError):
|
||||
db.query("SELECT * FROM test WHERE name = ?")
|
||||
|
||||
|
||||
def test_execute_query_without_params(db):
|
||||
with pytest.raises(sqlite3.ProgrammingError):
|
||||
db.execute_query("SELECT * FROM test WHERE name = ?")
|
||||
|
||||
|
||||
def test_add_with_wrong_query(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.add("INSERT INTO wrong (name) VALUES (?)", ("test",))
|
||||
|
||||
|
||||
def test_delete_with_wrong_query(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.delete("DELETE FROM wrong WHERE name = ?", ("test",))
|
||||
|
||||
|
||||
def test_update_with_wrong_query(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.update(
|
||||
"UPDATE wrong SET name = ? WHERE name = ?",
|
||||
("new", "test"),
|
||||
)
|
||||
|
||||
|
||||
def test_query_with_wrong_query(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.query("SELECT * FROM wrong WHERE name = ?", ("test",))
|
||||
|
||||
|
||||
def test_execute_query_with_wrong_query(db):
|
||||
with pytest.raises(sqlite3.OperationalError):
|
||||
db.execute_query(
|
||||
"SELECT * FROM wrong WHERE name = ?", ("test",)
|
||||
)
|
Loading…
Reference in new issue