parent
26033cefd1
commit
a5154aa26f
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,813 @@
|
||||
import sqlite3
|
||||
import json
|
||||
import datetime
|
||||
from typing import List, Optional, Union, Dict
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import uuid
|
||||
import yaml
|
||||
|
||||
try:
|
||||
from loguru import logger
|
||||
|
||||
LOGURU_AVAILABLE = True
|
||||
except ImportError:
|
||||
LOGURU_AVAILABLE = False
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Enum for different types of messages in the conversation."""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
FUNCTION = "function"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""Data class representing a message in the conversation."""
|
||||
|
||||
role: str
|
||||
content: Union[str, dict, list]
|
||||
timestamp: Optional[str] = None
|
||||
message_type: Optional[MessageType] = None
|
||||
metadata: Optional[Dict] = None
|
||||
token_count: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class SQLiteConversation:
|
||||
"""
|
||||
A production-grade SQLite wrapper class for managing conversation history.
|
||||
This class provides persistent storage for conversations with various features
|
||||
like message tracking, timestamps, and metadata support.
|
||||
|
||||
Attributes:
|
||||
db_path (str): Path to the SQLite database file
|
||||
table_name (str): Name of the table to store conversations
|
||||
enable_timestamps (bool): Whether to track message timestamps
|
||||
enable_logging (bool): Whether to enable logging
|
||||
use_loguru (bool): Whether to use loguru for logging
|
||||
max_retries (int): Maximum number of retries for database operations
|
||||
connection_timeout (float): Timeout for database connections
|
||||
current_conversation_id (str): Current active conversation ID
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str = "conversations.db",
|
||||
table_name: str = "conversations",
|
||||
enable_timestamps: bool = True,
|
||||
enable_logging: bool = True,
|
||||
use_loguru: bool = True,
|
||||
max_retries: int = 3,
|
||||
connection_timeout: float = 5.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the SQLite conversation manager.
|
||||
|
||||
Args:
|
||||
db_path (str): Path to the SQLite database file
|
||||
table_name (str): Name of the table to store conversations
|
||||
enable_timestamps (bool): Whether to track message timestamps
|
||||
enable_logging (bool): Whether to enable logging
|
||||
use_loguru (bool): Whether to use loguru for logging
|
||||
max_retries (int): Maximum number of retries for database operations
|
||||
connection_timeout (float): Timeout for database connections
|
||||
"""
|
||||
self.db_path = Path(db_path)
|
||||
self.table_name = table_name
|
||||
self.enable_timestamps = enable_timestamps
|
||||
self.enable_logging = enable_logging
|
||||
self.use_loguru = use_loguru and LOGURU_AVAILABLE
|
||||
self.max_retries = max_retries
|
||||
self.connection_timeout = connection_timeout
|
||||
self._lock = threading.Lock()
|
||||
self.current_conversation_id = (
|
||||
self._generate_conversation_id()
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
if self.enable_logging:
|
||||
if self.use_loguru:
|
||||
self.logger = logger
|
||||
else:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
self.logger.addHandler(handler)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
# Initialize database
|
||||
self._init_db()
|
||||
|
||||
def _generate_conversation_id(self) -> str:
|
||||
"""Generate a unique conversation ID using UUID and timestamp."""
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
return f"conv_{timestamp}_{unique_id}"
|
||||
|
||||
def start_new_conversation(self) -> str:
|
||||
"""
|
||||
Start a new conversation and return its ID.
|
||||
|
||||
Returns:
|
||||
str: The new conversation ID
|
||||
"""
|
||||
self.current_conversation_id = (
|
||||
self._generate_conversation_id()
|
||||
)
|
||||
return self.current_conversation_id
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize the database and create necessary tables."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
timestamp TEXT,
|
||||
message_type TEXT,
|
||||
metadata TEXT,
|
||||
token_count INTEGER,
|
||||
conversation_id TEXT,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
@contextmanager
|
||||
def _get_connection(self):
|
||||
"""Context manager for database connections with retry logic."""
|
||||
conn = None
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
conn = sqlite3.connect(
|
||||
str(self.db_path), timeout=self.connection_timeout
|
||||
)
|
||||
conn.row_factory = sqlite3.Row
|
||||
yield conn
|
||||
break
|
||||
except sqlite3.Error as e:
|
||||
if attempt == self.max_retries - 1:
|
||||
raise
|
||||
if self.enable_logging:
|
||||
self.logger.warning(
|
||||
f"Database connection attempt {attempt + 1} failed: {e}"
|
||||
)
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
def add(
|
||||
self,
|
||||
role: str,
|
||||
content: Union[str, dict, list],
|
||||
message_type: Optional[MessageType] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
token_count: Optional[int] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Add a message to the current conversation.
|
||||
|
||||
Args:
|
||||
role (str): The role of the speaker
|
||||
content (Union[str, dict, list]): The content of the message
|
||||
message_type (Optional[MessageType]): Type of the message
|
||||
metadata (Optional[Dict]): Additional metadata for the message
|
||||
token_count (Optional[int]): Number of tokens in the message
|
||||
|
||||
Returns:
|
||||
int: The ID of the inserted message
|
||||
"""
|
||||
timestamp = (
|
||||
datetime.datetime.now().isoformat()
|
||||
if self.enable_timestamps
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(content, (dict, list)):
|
||||
content = json.dumps(content)
|
||||
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT INTO {self.table_name}
|
||||
(role, content, timestamp, message_type, metadata, token_count, conversation_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
role,
|
||||
content,
|
||||
timestamp,
|
||||
message_type.value if message_type else None,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
token_count,
|
||||
self.current_conversation_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
def batch_add(self, messages: List[Message]) -> List[int]:
|
||||
"""
|
||||
Add multiple messages to the current conversation.
|
||||
|
||||
Args:
|
||||
messages (List[Message]): List of messages to add
|
||||
|
||||
Returns:
|
||||
List[int]: List of inserted message IDs
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
message_ids = []
|
||||
|
||||
for message in messages:
|
||||
content = message.content
|
||||
if isinstance(content, (dict, list)):
|
||||
content = json.dumps(content)
|
||||
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT INTO {self.table_name}
|
||||
(role, content, timestamp, message_type, metadata, token_count, conversation_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
message.role,
|
||||
content,
|
||||
(
|
||||
message.timestamp.isoformat()
|
||||
if message.timestamp
|
||||
else None
|
||||
),
|
||||
(
|
||||
message.message_type.value
|
||||
if message.message_type
|
||||
else None
|
||||
),
|
||||
(
|
||||
json.dumps(message.metadata)
|
||||
if message.metadata
|
||||
else None
|
||||
),
|
||||
message.token_count,
|
||||
self.current_conversation_id,
|
||||
),
|
||||
)
|
||||
message_ids.append(cursor.lastrowid)
|
||||
|
||||
conn.commit()
|
||||
return message_ids
|
||||
|
||||
def get_str(self) -> str:
|
||||
"""
|
||||
Get the current conversation history as a formatted string.
|
||||
|
||||
Returns:
|
||||
str: Formatted conversation history
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT * FROM {self.table_name}
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY id ASC
|
||||
""",
|
||||
(self.current_conversation_id,),
|
||||
)
|
||||
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
content = row["content"]
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
timestamp = (
|
||||
f"[{row['timestamp']}] "
|
||||
if row["timestamp"]
|
||||
else ""
|
||||
)
|
||||
messages.append(
|
||||
f"{timestamp}{row['role']}: {content}"
|
||||
)
|
||||
|
||||
return "\n".join(messages)
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Get messages from the current conversation with optional pagination.
|
||||
|
||||
Args:
|
||||
limit (Optional[int]): Maximum number of messages to return
|
||||
offset (Optional[int]): Number of messages to skip
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of message dictionaries
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
query = f"""
|
||||
SELECT * FROM {self.table_name}
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY id ASC
|
||||
"""
|
||||
params = [self.current_conversation_id]
|
||||
|
||||
if limit is not None:
|
||||
query += " LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
if offset is not None:
|
||||
query += " OFFSET ?"
|
||||
params.append(offset)
|
||||
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def delete_current_conversation(self) -> bool:
|
||||
"""
|
||||
Delete the current conversation.
|
||||
|
||||
Returns:
|
||||
bool: True if deletion was successful
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"DELETE FROM {self.table_name} WHERE conversation_id = ?",
|
||||
(self.current_conversation_id,),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def update_message(
|
||||
self,
|
||||
message_id: int,
|
||||
content: Union[str, dict, list],
|
||||
metadata: Optional[Dict] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Update an existing message in the current conversation.
|
||||
|
||||
Args:
|
||||
message_id (int): ID of the message to update
|
||||
content (Union[str, dict, list]): New content for the message
|
||||
metadata (Optional[Dict]): New metadata for the message
|
||||
|
||||
Returns:
|
||||
bool: True if update was successful
|
||||
"""
|
||||
if isinstance(content, (dict, list)):
|
||||
content = json.dumps(content)
|
||||
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
UPDATE {self.table_name}
|
||||
SET content = ?, metadata = ?
|
||||
WHERE id = ? AND conversation_id = ?
|
||||
""",
|
||||
(
|
||||
content,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
message_id,
|
||||
self.current_conversation_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def search_messages(self, query: str) -> List[Dict]:
|
||||
"""
|
||||
Search for messages containing specific text in the current conversation.
|
||||
|
||||
Args:
|
||||
query (str): Text to search for
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of matching messages
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT * FROM {self.table_name}
|
||||
WHERE conversation_id = ? AND content LIKE ?
|
||||
""",
|
||||
(self.current_conversation_id, f"%{query}%"),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""
|
||||
Get statistics about the current conversation.
|
||||
|
||||
Returns:
|
||||
Dict: Statistics about the conversation
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT
|
||||
COUNT(*) as total_messages,
|
||||
COUNT(DISTINCT role) as unique_roles,
|
||||
SUM(token_count) as total_tokens,
|
||||
MIN(timestamp) as first_message,
|
||||
MAX(timestamp) as last_message
|
||||
FROM {self.table_name}
|
||||
WHERE conversation_id = ?
|
||||
""",
|
||||
(self.current_conversation_id,),
|
||||
)
|
||||
return dict(cursor.fetchone())
|
||||
|
||||
def clear_all(self) -> bool:
|
||||
"""
|
||||
Clear all messages from the database.
|
||||
|
||||
Returns:
|
||||
bool: True if clearing was successful
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"DELETE FROM {self.table_name}")
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
def get_conversation_id(self) -> str:
|
||||
"""
|
||||
Get the current conversation ID.
|
||||
|
||||
Returns:
|
||||
str: The current conversation ID
|
||||
"""
|
||||
return self.current_conversation_id
|
||||
|
||||
def to_dict(self) -> List[Dict]:
|
||||
"""
|
||||
Convert the current conversation to a list of dictionaries.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of message dictionaries
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT role, content, timestamp, message_type, metadata, token_count
|
||||
FROM {self.table_name}
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY id ASC
|
||||
""",
|
||||
(self.current_conversation_id,),
|
||||
)
|
||||
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
content = row["content"]
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
message = {"role": row["role"], "content": content}
|
||||
|
||||
if row["timestamp"]:
|
||||
message["timestamp"] = row["timestamp"]
|
||||
if row["message_type"]:
|
||||
message["message_type"] = row["message_type"]
|
||||
if row["metadata"]:
|
||||
message["metadata"] = json.loads(row["metadata"])
|
||||
if row["token_count"]:
|
||||
message["token_count"] = row["token_count"]
|
||||
|
||||
messages.append(message)
|
||||
|
||||
return messages
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""
|
||||
Convert the current conversation to a JSON string.
|
||||
|
||||
Returns:
|
||||
str: JSON string representation of the conversation
|
||||
"""
|
||||
return json.dumps(self.to_dict(), indent=2)
|
||||
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Convert the current conversation to a YAML string.
|
||||
|
||||
Returns:
|
||||
str: YAML string representation of the conversation
|
||||
"""
|
||||
return yaml.dump(self.to_dict())
|
||||
|
||||
def save_as_json(self, filename: str) -> bool:
|
||||
"""
|
||||
Save the current conversation to a JSON file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to save the JSON file
|
||||
|
||||
Returns:
|
||||
bool: True if save was successful
|
||||
"""
|
||||
try:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(self.to_dict(), f, indent=2)
|
||||
return True
|
||||
except Exception as e:
|
||||
if self.enable_logging:
|
||||
self.logger.error(
|
||||
f"Failed to save conversation to JSON: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def save_as_yaml(self, filename: str) -> bool:
|
||||
"""
|
||||
Save the current conversation to a YAML file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to save the YAML file
|
||||
|
||||
Returns:
|
||||
bool: True if save was successful
|
||||
"""
|
||||
try:
|
||||
with open(filename, "w") as f:
|
||||
yaml.dump(self.to_dict(), f)
|
||||
return True
|
||||
except Exception as e:
|
||||
if self.enable_logging:
|
||||
self.logger.error(
|
||||
f"Failed to save conversation to YAML: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def load_from_json(self, filename: str) -> bool:
|
||||
"""
|
||||
Load a conversation from a JSON file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the JSON file
|
||||
|
||||
Returns:
|
||||
bool: True if load was successful
|
||||
"""
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
messages = json.load(f)
|
||||
|
||||
# Start a new conversation
|
||||
self.start_new_conversation()
|
||||
|
||||
# Add all messages
|
||||
for message in messages:
|
||||
self.add(
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
message_type=(
|
||||
MessageType(message["message_type"])
|
||||
if "message_type" in message
|
||||
else None
|
||||
),
|
||||
metadata=message.get("metadata"),
|
||||
token_count=message.get("token_count"),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
if self.enable_logging:
|
||||
self.logger.error(
|
||||
f"Failed to load conversation from JSON: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def load_from_yaml(self, filename: str) -> bool:
|
||||
"""
|
||||
Load a conversation from a YAML file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the YAML file
|
||||
|
||||
Returns:
|
||||
bool: True if load was successful
|
||||
"""
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
messages = yaml.safe_load(f)
|
||||
|
||||
# Start a new conversation
|
||||
self.start_new_conversation()
|
||||
|
||||
# Add all messages
|
||||
for message in messages:
|
||||
self.add(
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
message_type=(
|
||||
MessageType(message["message_type"])
|
||||
if "message_type" in message
|
||||
else None
|
||||
),
|
||||
metadata=message.get("metadata"),
|
||||
token_count=message.get("token_count"),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
if self.enable_logging:
|
||||
self.logger.error(
|
||||
f"Failed to load conversation from YAML: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def get_last_message(self) -> Optional[Dict]:
|
||||
"""
|
||||
Get the last message from the current conversation.
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The last message or None if conversation is empty
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT role, content, timestamp, message_type, metadata, token_count
|
||||
FROM {self.table_name}
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(self.current_conversation_id,),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
content = row["content"]
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
message = {"role": row["role"], "content": content}
|
||||
|
||||
if row["timestamp"]:
|
||||
message["timestamp"] = row["timestamp"]
|
||||
if row["message_type"]:
|
||||
message["message_type"] = row["message_type"]
|
||||
if row["metadata"]:
|
||||
message["metadata"] = json.loads(row["metadata"])
|
||||
if row["token_count"]:
|
||||
message["token_count"] = row["token_count"]
|
||||
|
||||
return message
|
||||
|
||||
def get_last_message_as_string(self) -> str:
|
||||
"""
|
||||
Get the last message as a formatted string.
|
||||
|
||||
Returns:
|
||||
str: Formatted string of the last message
|
||||
"""
|
||||
last_message = self.get_last_message()
|
||||
if not last_message:
|
||||
return ""
|
||||
|
||||
timestamp = (
|
||||
f"[{last_message['timestamp']}] "
|
||||
if "timestamp" in last_message
|
||||
else ""
|
||||
)
|
||||
return f"{timestamp}{last_message['role']}: {last_message['content']}"
|
||||
|
||||
def count_messages_by_role(self) -> Dict[str, int]:
|
||||
"""
|
||||
Count messages by role in the current conversation.
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: Dictionary with role counts
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT role, COUNT(*) as count
|
||||
FROM {self.table_name}
|
||||
WHERE conversation_id = ?
|
||||
GROUP BY role
|
||||
""",
|
||||
(self.current_conversation_id,),
|
||||
)
|
||||
|
||||
return {
|
||||
row["role"]: row["count"] for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
def get_messages_by_role(self, role: str) -> List[Dict]:
|
||||
"""
|
||||
Get all messages from a specific role in the current conversation.
|
||||
|
||||
Args:
|
||||
role (str): Role to filter messages by
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of messages from the specified role
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT role, content, timestamp, message_type, metadata, token_count
|
||||
FROM {self.table_name}
|
||||
WHERE conversation_id = ? AND role = ?
|
||||
ORDER BY id ASC
|
||||
""",
|
||||
(self.current_conversation_id, role),
|
||||
)
|
||||
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
content = row["content"]
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
message = {"role": row["role"], "content": content}
|
||||
|
||||
if row["timestamp"]:
|
||||
message["timestamp"] = row["timestamp"]
|
||||
if row["message_type"]:
|
||||
message["message_type"] = row["message_type"]
|
||||
if row["metadata"]:
|
||||
message["metadata"] = json.loads(row["metadata"])
|
||||
if row["token_count"]:
|
||||
message["token_count"] = row["token_count"]
|
||||
|
||||
messages.append(message)
|
||||
|
||||
return messages
|
||||
|
||||
def get_conversation_summary(self) -> Dict:
|
||||
"""
|
||||
Get a summary of the current conversation.
|
||||
|
||||
Returns:
|
||||
Dict: Summary of the conversation including message counts, roles, and time range
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT
|
||||
COUNT(*) as total_messages,
|
||||
COUNT(DISTINCT role) as unique_roles,
|
||||
MIN(timestamp) as first_message_time,
|
||||
MAX(timestamp) as last_message_time,
|
||||
SUM(token_count) as total_tokens
|
||||
FROM {self.table_name}
|
||||
WHERE conversation_id = ?
|
||||
""",
|
||||
(self.current_conversation_id,),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
return {
|
||||
"conversation_id": self.current_conversation_id,
|
||||
"total_messages": row["total_messages"],
|
||||
"unique_roles": row["unique_roles"],
|
||||
"first_message_time": row["first_message_time"],
|
||||
"last_message_time": row["last_message_time"],
|
||||
"total_tokens": row["total_tokens"],
|
||||
"roles": self.count_messages_by_role(),
|
||||
}
|
@ -0,0 +1,317 @@
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import threading
|
||||
from swarms.communication.duckdb_wrap import (
|
||||
DuckDBConversation,
|
||||
Message,
|
||||
MessageType,
|
||||
)
|
||||
|
||||
def setup_test():
|
||||
"""Set up test environment."""
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
db_path = Path(temp_dir.name) / "test_conversations.duckdb"
|
||||
conversation = DuckDBConversation(
|
||||
db_path=str(db_path),
|
||||
enable_timestamps=True,
|
||||
enable_logging=True,
|
||||
)
|
||||
return temp_dir, db_path, conversation
|
||||
|
||||
def cleanup_test(temp_dir, db_path):
|
||||
"""Clean up test environment."""
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
temp_dir.cleanup()
|
||||
|
||||
def test_initialization():
|
||||
"""Test conversation initialization."""
|
||||
temp_dir, db_path, _ = setup_test()
|
||||
try:
|
||||
conv = DuckDBConversation(db_path=str(db_path))
|
||||
assert conv.db_path == db_path, "Database path mismatch"
|
||||
assert conv.table_name == "conversations", "Table name mismatch"
|
||||
assert conv.enable_timestamps is True, "Timestamps should be enabled"
|
||||
assert conv.current_conversation_id is not None, "Conversation ID should not be None"
|
||||
print("✓ Initialization test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_add_message():
|
||||
"""Test adding a single message."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
msg_id = conversation.add(
|
||||
role="user",
|
||||
content="Hello, world!",
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
assert msg_id is not None, "Message ID should not be None"
|
||||
assert isinstance(msg_id, int), "Message ID should be an integer"
|
||||
print("✓ Add message test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_add_complex_message():
|
||||
"""Test adding a message with complex content."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
complex_content = {
|
||||
"text": "Hello",
|
||||
"data": [1, 2, 3],
|
||||
"nested": {"key": "value"}
|
||||
}
|
||||
msg_id = conversation.add(
|
||||
role="assistant",
|
||||
content=complex_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
metadata={"source": "test"},
|
||||
token_count=10
|
||||
)
|
||||
assert msg_id is not None, "Message ID should not be None"
|
||||
print("✓ Add complex message test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_batch_add():
|
||||
"""Test batch adding messages."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
messages = [
|
||||
Message(
|
||||
role="user",
|
||||
content="First message",
|
||||
message_type=MessageType.USER
|
||||
),
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Second message",
|
||||
message_type=MessageType.ASSISTANT
|
||||
)
|
||||
]
|
||||
msg_ids = conversation.batch_add(messages)
|
||||
assert len(msg_ids) == 2, "Should have 2 message IDs"
|
||||
assert all(isinstance(id, int) for id in msg_ids), "All IDs should be integers"
|
||||
print("✓ Batch add test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_get_str():
|
||||
"""Test getting conversation as string."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi there!")
|
||||
conv_str = conversation.get_str()
|
||||
assert "user: Hello" in conv_str, "User message not found"
|
||||
assert "assistant: Hi there!" in conv_str, "Assistant message not found"
|
||||
print("✓ Get string test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_get_messages():
|
||||
"""Test getting messages with pagination."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
for i in range(5):
|
||||
conversation.add("user", f"Message {i}")
|
||||
|
||||
all_messages = conversation.get_messages()
|
||||
assert len(all_messages) == 5, "Should have 5 messages"
|
||||
|
||||
limited_messages = conversation.get_messages(limit=2)
|
||||
assert len(limited_messages) == 2, "Should have 2 limited messages"
|
||||
|
||||
offset_messages = conversation.get_messages(offset=2)
|
||||
assert len(offset_messages) == 3, "Should have 3 offset messages"
|
||||
print("✓ Get messages test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_search_messages():
|
||||
"""Test searching messages."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello world")
|
||||
conversation.add("assistant", "Hello there")
|
||||
conversation.add("user", "Goodbye world")
|
||||
|
||||
results = conversation.search_messages("world")
|
||||
assert len(results) == 2, "Should find 2 messages with 'world'"
|
||||
assert all("world" in msg["content"] for msg in results), "All results should contain 'world'"
|
||||
print("✓ Search messages test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_get_statistics():
|
||||
"""Test getting conversation statistics."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello", token_count=2)
|
||||
conversation.add("assistant", "Hi", token_count=1)
|
||||
|
||||
stats = conversation.get_statistics()
|
||||
assert stats["total_messages"] == 2, "Should have 2 total messages"
|
||||
assert stats["unique_roles"] == 2, "Should have 2 unique roles"
|
||||
assert stats["total_tokens"] == 3, "Should have 3 total tokens"
|
||||
print("✓ Get statistics test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_json_operations():
|
||||
"""Test JSON save and load operations."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi")
|
||||
|
||||
json_path = Path(temp_dir.name) / "test_conversation.json"
|
||||
conversation.save_as_json(str(json_path))
|
||||
assert json_path.exists(), "JSON file should exist"
|
||||
|
||||
new_conversation = DuckDBConversation(
|
||||
db_path=str(Path(temp_dir.name) / "new.duckdb")
|
||||
)
|
||||
assert new_conversation.load_from_json(str(json_path)), "Should load from JSON"
|
||||
assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load"
|
||||
print("✓ JSON operations test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_yaml_operations():
|
||||
"""Test YAML save and load operations."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi")
|
||||
|
||||
yaml_path = Path(temp_dir.name) / "test_conversation.yaml"
|
||||
conversation.save_as_yaml(str(yaml_path))
|
||||
assert yaml_path.exists(), "YAML file should exist"
|
||||
|
||||
new_conversation = DuckDBConversation(
|
||||
db_path=str(Path(temp_dir.name) / "new.duckdb")
|
||||
)
|
||||
assert new_conversation.load_from_yaml(str(yaml_path)), "Should load from YAML"
|
||||
assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load"
|
||||
print("✓ YAML operations test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_message_types():
|
||||
"""Test different message types."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("system", "System message", message_type=MessageType.SYSTEM)
|
||||
conversation.add("user", "User message", message_type=MessageType.USER)
|
||||
conversation.add("assistant", "Assistant message", message_type=MessageType.ASSISTANT)
|
||||
conversation.add("function", "Function message", message_type=MessageType.FUNCTION)
|
||||
conversation.add("tool", "Tool message", message_type=MessageType.TOOL)
|
||||
|
||||
messages = conversation.get_messages()
|
||||
assert len(messages) == 5, "Should have 5 messages"
|
||||
assert all("message_type" in msg for msg in messages), "All messages should have type"
|
||||
print("✓ Message types test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_delete_operations():
|
||||
"""Test deletion operations."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
conversation.add("user", "Hello")
|
||||
conversation.add("assistant", "Hi")
|
||||
|
||||
assert conversation.delete_current_conversation(), "Should delete conversation"
|
||||
assert len(conversation.get_messages()) == 0, "Should have no messages after delete"
|
||||
|
||||
conversation.add("user", "New message")
|
||||
assert conversation.clear_all(), "Should clear all messages"
|
||||
assert len(conversation.get_messages()) == 0, "Should have no messages after clear"
|
||||
print("✓ Delete operations test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_concurrent_operations():
|
||||
"""Test concurrent operations."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
def add_messages():
|
||||
for i in range(10):
|
||||
conversation.add("user", f"Message {i}")
|
||||
|
||||
threads = [threading.Thread(target=add_messages) for _ in range(5)]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
messages = conversation.get_messages()
|
||||
assert len(messages) == 50, "Should have 50 messages (10 * 5 threads)"
|
||||
print("✓ Concurrent operations test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def test_error_handling():
|
||||
"""Test error handling."""
|
||||
temp_dir, db_path, conversation = setup_test()
|
||||
try:
|
||||
# Test invalid message type
|
||||
try:
|
||||
conversation.add("user", "Message", message_type="invalid")
|
||||
assert False, "Should raise exception for invalid message type"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Test invalid JSON content
|
||||
try:
|
||||
conversation.add("user", {"invalid": object()})
|
||||
assert False, "Should raise exception for invalid JSON content"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Test invalid file operations
|
||||
try:
|
||||
conversation.load_from_json("/nonexistent/path.json")
|
||||
assert False, "Should raise exception for invalid file path"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("✓ Error handling test passed")
|
||||
finally:
|
||||
cleanup_test(temp_dir, db_path)
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all tests."""
|
||||
print("Running DuckDB Conversation tests...")
|
||||
tests = [
|
||||
test_initialization,
|
||||
test_add_message,
|
||||
test_add_complex_message,
|
||||
test_batch_add,
|
||||
test_get_str,
|
||||
test_get_messages,
|
||||
test_search_messages,
|
||||
test_get_statistics,
|
||||
test_json_operations,
|
||||
test_yaml_operations,
|
||||
test_message_types,
|
||||
test_delete_operations,
|
||||
test_concurrent_operations,
|
||||
test_error_handling
|
||||
]
|
||||
|
||||
for test in tests:
|
||||
try:
|
||||
test()
|
||||
except Exception as e:
|
||||
print(f"✗ {test.__name__} failed: {str(e)}")
|
||||
raise
|
||||
|
||||
print("\nAll tests completed successfully!")
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_all_tests()
|
@ -0,0 +1,386 @@
|
||||
import json
|
||||
import datetime
|
||||
import os
|
||||
from typing import Dict, List, Any, Tuple
|
||||
from loguru import logger
|
||||
from swarms.communication.sqlite_wrap import (
|
||||
SQLiteConversation,
|
||||
Message,
|
||||
MessageType,
|
||||
)
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def print_test_header(test_name: str) -> None:
|
||||
"""Print a formatted test header."""
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold blue]Running Test: {test_name}[/bold blue]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def print_test_result(
|
||||
test_name: str, success: bool, message: str, execution_time: float
|
||||
) -> None:
|
||||
"""Print a formatted test result."""
|
||||
status = (
|
||||
"[bold green]PASSED[/bold green]"
|
||||
if success
|
||||
else "[bold red]FAILED[/bold red]"
|
||||
)
|
||||
console.print(f"\n{status} - {test_name}")
|
||||
console.print(f"Message: {message}")
|
||||
console.print(f"Execution time: {execution_time:.3f} seconds\n")
|
||||
|
||||
|
||||
def print_messages(
|
||||
messages: List[Dict], title: str = "Messages"
|
||||
) -> None:
|
||||
"""Print messages in a formatted table."""
|
||||
table = Table(title=title)
|
||||
table.add_column("Role", style="cyan")
|
||||
table.add_column("Content", style="green")
|
||||
table.add_column("Type", style="yellow")
|
||||
table.add_column("Timestamp", style="magenta")
|
||||
|
||||
for msg in messages:
|
||||
content = str(msg.get("content", ""))
|
||||
if isinstance(content, (dict, list)):
|
||||
content = json.dumps(content)
|
||||
table.add_row(
|
||||
msg.get("role", ""),
|
||||
content,
|
||||
str(msg.get("message_type", "")),
|
||||
str(msg.get("timestamp", "")),
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def run_test(
|
||||
test_func: callable, *args, **kwargs
|
||||
) -> Tuple[bool, str, float]:
|
||||
"""
|
||||
Run a test function and return its results.
|
||||
|
||||
Args:
|
||||
test_func: The test function to run
|
||||
*args: Arguments for the test function
|
||||
**kwargs: Keyword arguments for the test function
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, float]: (success, message, execution_time)
|
||||
"""
|
||||
start_time = datetime.datetime.now()
|
||||
try:
|
||||
result = test_func(*args, **kwargs)
|
||||
end_time = datetime.datetime.now()
|
||||
execution_time = (end_time - start_time).total_seconds()
|
||||
return True, str(result), execution_time
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
execution_time = (end_time - start_time).total_seconds()
|
||||
return False, str(e), execution_time
|
||||
|
||||
|
||||
def test_basic_conversation() -> bool:
|
||||
"""Test basic conversation operations."""
|
||||
print_test_header("Basic Conversation Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Test adding messages
|
||||
console.print("\n[bold]Adding messages...[/bold]")
|
||||
msg_id1 = conversation.add("user", "Hello")
|
||||
msg_id2 = conversation.add("assistant", "Hi there!")
|
||||
|
||||
# Test getting messages
|
||||
console.print("\n[bold]Retrieved messages:[/bold]")
|
||||
messages = conversation.get_messages()
|
||||
print_messages(messages)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "user"
|
||||
assert messages[1]["role"] == "assistant"
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_message_types() -> bool:
|
||||
"""Test different message types and content formats."""
|
||||
print_test_header("Message Types Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Test different content types
|
||||
console.print("\n[bold]Adding different message types...[/bold]")
|
||||
conversation.add("user", "Simple text")
|
||||
conversation.add(
|
||||
"assistant", {"type": "json", "content": "Complex data"}
|
||||
)
|
||||
conversation.add("system", ["list", "of", "items"])
|
||||
conversation.add(
|
||||
"function",
|
||||
"Function result",
|
||||
message_type=MessageType.FUNCTION,
|
||||
)
|
||||
|
||||
console.print("\n[bold]Retrieved messages:[/bold]")
|
||||
messages = conversation.get_messages()
|
||||
print_messages(messages)
|
||||
|
||||
assert len(messages) == 4
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_conversation_operations() -> bool:
|
||||
"""Test various conversation operations."""
|
||||
print_test_header("Conversation Operations Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Test batch operations
|
||||
console.print("\n[bold]Adding batch messages...[/bold]")
|
||||
messages = [
|
||||
Message(role="user", content="Message 1"),
|
||||
Message(role="assistant", content="Message 2"),
|
||||
Message(role="user", content="Message 3"),
|
||||
]
|
||||
conversation.batch_add(messages)
|
||||
|
||||
console.print("\n[bold]Retrieved messages:[/bold]")
|
||||
all_messages = conversation.get_messages()
|
||||
print_messages(all_messages)
|
||||
|
||||
# Test statistics
|
||||
console.print("\n[bold]Conversation Statistics:[/bold]")
|
||||
stats = conversation.get_statistics()
|
||||
console.print(json.dumps(stats, indent=2))
|
||||
|
||||
# Test role counting
|
||||
console.print("\n[bold]Role Counts:[/bold]")
|
||||
role_counts = conversation.count_messages_by_role()
|
||||
console.print(json.dumps(role_counts, indent=2))
|
||||
|
||||
assert stats["total_messages"] == 3
|
||||
assert role_counts["user"] == 2
|
||||
assert role_counts["assistant"] == 1
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_file_operations() -> bool:
|
||||
"""Test file operations (JSON/YAML)."""
|
||||
print_test_header("File Operations Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
json_path = "test_conversation.json"
|
||||
yaml_path = "test_conversation.yaml"
|
||||
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
conversation.add("user", "Test message")
|
||||
|
||||
# Test JSON operations
|
||||
console.print("\n[bold]Testing JSON operations...[/bold]")
|
||||
assert conversation.save_as_json(json_path)
|
||||
console.print(f"Saved to JSON: {json_path}")
|
||||
|
||||
conversation.start_new_conversation()
|
||||
assert conversation.load_from_json(json_path)
|
||||
console.print("Loaded from JSON")
|
||||
|
||||
# Test YAML operations
|
||||
console.print("\n[bold]Testing YAML operations...[/bold]")
|
||||
assert conversation.save_as_yaml(yaml_path)
|
||||
console.print(f"Saved to YAML: {yaml_path}")
|
||||
|
||||
conversation.start_new_conversation()
|
||||
assert conversation.load_from_yaml(yaml_path)
|
||||
console.print("Loaded from YAML")
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
os.remove(json_path)
|
||||
os.remove(yaml_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_search_and_filter() -> bool:
|
||||
"""Test search and filter operations."""
|
||||
print_test_header("Search and Filter Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Add test messages
|
||||
console.print("\n[bold]Adding test messages...[/bold]")
|
||||
conversation.add("user", "Hello world")
|
||||
conversation.add("assistant", "Hello there")
|
||||
conversation.add("user", "Goodbye world")
|
||||
|
||||
# Test search
|
||||
console.print("\n[bold]Searching for 'world'...[/bold]")
|
||||
results = conversation.search_messages("world")
|
||||
print_messages(results, "Search Results")
|
||||
|
||||
# Test role filtering
|
||||
console.print("\n[bold]Filtering user messages...[/bold]")
|
||||
user_messages = conversation.get_messages_by_role("user")
|
||||
print_messages(user_messages, "User Messages")
|
||||
|
||||
assert len(results) == 2
|
||||
assert len(user_messages) == 2
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def test_conversation_management() -> bool:
|
||||
"""Test conversation management features."""
|
||||
print_test_header("Conversation Management Test")
|
||||
|
||||
db_path = "test_conversations.db"
|
||||
conversation = SQLiteConversation(db_path=db_path)
|
||||
|
||||
# Test conversation ID generation
|
||||
console.print("\n[bold]Testing conversation IDs...[/bold]")
|
||||
conv_id1 = conversation.get_conversation_id()
|
||||
console.print(f"First conversation ID: {conv_id1}")
|
||||
|
||||
conversation.start_new_conversation()
|
||||
conv_id2 = conversation.get_conversation_id()
|
||||
console.print(f"Second conversation ID: {conv_id2}")
|
||||
|
||||
assert conv_id1 != conv_id2
|
||||
|
||||
# Test conversation deletion
|
||||
console.print("\n[bold]Testing conversation deletion...[/bold]")
|
||||
conversation.add("user", "Test message")
|
||||
assert conversation.delete_current_conversation()
|
||||
console.print("Conversation deleted successfully")
|
||||
|
||||
# Cleanup
|
||||
os.remove(db_path)
|
||||
return True
|
||||
|
||||
|
||||
def generate_test_report(
|
||||
test_results: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a test report in JSON format.
|
||||
|
||||
Args:
|
||||
test_results: List of test results
|
||||
|
||||
Returns:
|
||||
Dict containing the test report
|
||||
"""
|
||||
total_tests = len(test_results)
|
||||
passed_tests = sum(
|
||||
1 for result in test_results if result["success"]
|
||||
)
|
||||
failed_tests = total_tests - passed_tests
|
||||
total_time = sum(
|
||||
result["execution_time"] for result in test_results
|
||||
)
|
||||
|
||||
report = {
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"summary": {
|
||||
"total_tests": total_tests,
|
||||
"passed_tests": passed_tests,
|
||||
"failed_tests": failed_tests,
|
||||
"total_execution_time": total_time,
|
||||
"average_execution_time": (
|
||||
total_time / total_tests if total_tests > 0 else 0
|
||||
),
|
||||
},
|
||||
"test_results": test_results,
|
||||
}
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def run_all_tests() -> None:
|
||||
"""Run all tests and generate a report."""
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold blue]Starting Test Suite[/bold blue]", expand=False
|
||||
)
|
||||
)
|
||||
|
||||
tests = [
|
||||
("Basic Conversation", test_basic_conversation),
|
||||
("Message Types", test_message_types),
|
||||
("Conversation Operations", test_conversation_operations),
|
||||
("File Operations", test_file_operations),
|
||||
("Search and Filter", test_search_and_filter),
|
||||
("Conversation Management", test_conversation_management),
|
||||
]
|
||||
|
||||
test_results = []
|
||||
|
||||
for test_name, test_func in tests:
|
||||
logger.info(f"Running test: {test_name}")
|
||||
success, message, execution_time = run_test(test_func)
|
||||
|
||||
print_test_result(test_name, success, message, execution_time)
|
||||
|
||||
result = {
|
||||
"test_name": test_name,
|
||||
"success": success,
|
||||
"message": message,
|
||||
"execution_time": execution_time,
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
if success:
|
||||
logger.success(f"Test passed: {test_name}")
|
||||
else:
|
||||
logger.error(f"Test failed: {test_name} - {message}")
|
||||
|
||||
test_results.append(result)
|
||||
|
||||
# Generate and save report
|
||||
report = generate_test_report(test_results)
|
||||
report_path = "test_report.json"
|
||||
|
||||
with open(report_path, "w") as f:
|
||||
json.dump(report, f, indent=2)
|
||||
|
||||
# Print final summary
|
||||
console.print("\n[bold blue]Test Suite Summary[/bold blue]")
|
||||
console.print(
|
||||
Panel(
|
||||
f"Total tests: {report['summary']['total_tests']}\n"
|
||||
f"Passed tests: {report['summary']['passed_tests']}\n"
|
||||
f"Failed tests: {report['summary']['failed_tests']}\n"
|
||||
f"Total execution time: {report['summary']['total_execution_time']:.2f} seconds",
|
||||
title="Summary",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Test report saved to {report_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_all_tests()
|
Loading…
Reference in new issue