parent
							
								
									26033cefd1
								
							
						
					
					
						commit
						a5154aa26f
					
				
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								| @ -0,0 +1,813 @@ | ||||
| import sqlite3 | ||||
| import json | ||||
| import datetime | ||||
| from typing import List, Optional, Union, Dict | ||||
| from pathlib import Path | ||||
| import threading | ||||
| from contextlib import contextmanager | ||||
| import logging | ||||
| from dataclasses import dataclass | ||||
| from enum import Enum | ||||
| import uuid | ||||
| import yaml | ||||
| 
 | ||||
| try: | ||||
|     from loguru import logger | ||||
| 
 | ||||
|     LOGURU_AVAILABLE = True | ||||
| except ImportError: | ||||
|     LOGURU_AVAILABLE = False | ||||
| 
 | ||||
| 
 | ||||
| class MessageType(Enum): | ||||
|     """Enum for different types of messages in the conversation.""" | ||||
| 
 | ||||
|     SYSTEM = "system" | ||||
|     USER = "user" | ||||
|     ASSISTANT = "assistant" | ||||
|     FUNCTION = "function" | ||||
|     TOOL = "tool" | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class Message: | ||||
|     """Data class representing a message in the conversation.""" | ||||
| 
 | ||||
|     role: str | ||||
|     content: Union[str, dict, list] | ||||
|     timestamp: Optional[str] = None | ||||
|     message_type: Optional[MessageType] = None | ||||
|     metadata: Optional[Dict] = None | ||||
|     token_count: Optional[int] = None | ||||
| 
 | ||||
|     class Config: | ||||
|         arbitrary_types_allowed = True | ||||
| 
 | ||||
| 
 | ||||
| class SQLiteConversation: | ||||
|     """ | ||||
|     A production-grade SQLite wrapper class for managing conversation history. | ||||
|     This class provides persistent storage for conversations with various features | ||||
|     like message tracking, timestamps, and metadata support. | ||||
| 
 | ||||
|     Attributes: | ||||
|         db_path (str): Path to the SQLite database file | ||||
|         table_name (str): Name of the table to store conversations | ||||
|         enable_timestamps (bool): Whether to track message timestamps | ||||
|         enable_logging (bool): Whether to enable logging | ||||
|         use_loguru (bool): Whether to use loguru for logging | ||||
|         max_retries (int): Maximum number of retries for database operations | ||||
|         connection_timeout (float): Timeout for database connections | ||||
|         current_conversation_id (str): Current active conversation ID | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         db_path: str = "conversations.db", | ||||
|         table_name: str = "conversations", | ||||
|         enable_timestamps: bool = True, | ||||
|         enable_logging: bool = True, | ||||
|         use_loguru: bool = True, | ||||
|         max_retries: int = 3, | ||||
|         connection_timeout: float = 5.0, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         """ | ||||
|         Initialize the SQLite conversation manager. | ||||
| 
 | ||||
|         Args: | ||||
|             db_path (str): Path to the SQLite database file | ||||
|             table_name (str): Name of the table to store conversations | ||||
|             enable_timestamps (bool): Whether to track message timestamps | ||||
|             enable_logging (bool): Whether to enable logging | ||||
|             use_loguru (bool): Whether to use loguru for logging | ||||
|             max_retries (int): Maximum number of retries for database operations | ||||
|             connection_timeout (float): Timeout for database connections | ||||
|         """ | ||||
|         self.db_path = Path(db_path) | ||||
|         self.table_name = table_name | ||||
|         self.enable_timestamps = enable_timestamps | ||||
|         self.enable_logging = enable_logging | ||||
|         self.use_loguru = use_loguru and LOGURU_AVAILABLE | ||||
|         self.max_retries = max_retries | ||||
|         self.connection_timeout = connection_timeout | ||||
|         self._lock = threading.Lock() | ||||
|         self.current_conversation_id = ( | ||||
|             self._generate_conversation_id() | ||||
|         ) | ||||
| 
 | ||||
|         # Setup logging | ||||
|         if self.enable_logging: | ||||
|             if self.use_loguru: | ||||
|                 self.logger = logger | ||||
|             else: | ||||
|                 self.logger = logging.getLogger(__name__) | ||||
|                 handler = logging.StreamHandler() | ||||
|                 formatter = logging.Formatter( | ||||
|                     "%(asctime)s - %(name)s - %(levelname)s - %(message)s" | ||||
|                 ) | ||||
|                 handler.setFormatter(formatter) | ||||
|                 self.logger.addHandler(handler) | ||||
|                 self.logger.setLevel(logging.INFO) | ||||
| 
 | ||||
|         # Initialize database | ||||
|         self._init_db() | ||||
| 
 | ||||
|     def _generate_conversation_id(self) -> str: | ||||
|         """Generate a unique conversation ID using UUID and timestamp.""" | ||||
|         timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") | ||||
|         unique_id = str(uuid.uuid4())[:8] | ||||
|         return f"conv_{timestamp}_{unique_id}" | ||||
| 
 | ||||
|     def start_new_conversation(self) -> str: | ||||
|         """ | ||||
|         Start a new conversation and return its ID. | ||||
| 
 | ||||
|         Returns: | ||||
|             str: The new conversation ID | ||||
|         """ | ||||
|         self.current_conversation_id = ( | ||||
|             self._generate_conversation_id() | ||||
|         ) | ||||
|         return self.current_conversation_id | ||||
| 
 | ||||
|     def _init_db(self): | ||||
|         """Initialize the database and create necessary tables.""" | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 CREATE TABLE IF NOT EXISTS {self.table_name} ( | ||||
|                     id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||
|                     role TEXT NOT NULL, | ||||
|                     content TEXT NOT NULL, | ||||
|                     timestamp TEXT, | ||||
|                     message_type TEXT, | ||||
|                     metadata TEXT, | ||||
|                     token_count INTEGER, | ||||
|                     conversation_id TEXT, | ||||
|                     created_at TEXT DEFAULT CURRENT_TIMESTAMP | ||||
|                 ) | ||||
|             """ | ||||
|             ) | ||||
|             conn.commit() | ||||
| 
 | ||||
|     @contextmanager | ||||
|     def _get_connection(self): | ||||
|         """Context manager for database connections with retry logic.""" | ||||
|         conn = None | ||||
|         for attempt in range(self.max_retries): | ||||
|             try: | ||||
|                 conn = sqlite3.connect( | ||||
|                     str(self.db_path), timeout=self.connection_timeout | ||||
|                 ) | ||||
|                 conn.row_factory = sqlite3.Row | ||||
|                 yield conn | ||||
|                 break | ||||
|             except sqlite3.Error as e: | ||||
|                 if attempt == self.max_retries - 1: | ||||
|                     raise | ||||
|                 if self.enable_logging: | ||||
|                     self.logger.warning( | ||||
|                         f"Database connection attempt {attempt + 1} failed: {e}" | ||||
|                     ) | ||||
|             finally: | ||||
|                 if conn: | ||||
|                     conn.close() | ||||
| 
 | ||||
|     def add( | ||||
|         self, | ||||
|         role: str, | ||||
|         content: Union[str, dict, list], | ||||
|         message_type: Optional[MessageType] = None, | ||||
|         metadata: Optional[Dict] = None, | ||||
|         token_count: Optional[int] = None, | ||||
|     ) -> int: | ||||
|         """ | ||||
|         Add a message to the current conversation. | ||||
| 
 | ||||
|         Args: | ||||
|             role (str): The role of the speaker | ||||
|             content (Union[str, dict, list]): The content of the message | ||||
|             message_type (Optional[MessageType]): Type of the message | ||||
|             metadata (Optional[Dict]): Additional metadata for the message | ||||
|             token_count (Optional[int]): Number of tokens in the message | ||||
| 
 | ||||
|         Returns: | ||||
|             int: The ID of the inserted message | ||||
|         """ | ||||
|         timestamp = ( | ||||
|             datetime.datetime.now().isoformat() | ||||
|             if self.enable_timestamps | ||||
|             else None | ||||
|         ) | ||||
| 
 | ||||
|         if isinstance(content, (dict, list)): | ||||
|             content = json.dumps(content) | ||||
| 
 | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 INSERT INTO {self.table_name}  | ||||
|                 (role, content, timestamp, message_type, metadata, token_count, conversation_id) | ||||
|                 VALUES (?, ?, ?, ?, ?, ?, ?) | ||||
|             """, | ||||
|                 ( | ||||
|                     role, | ||||
|                     content, | ||||
|                     timestamp, | ||||
|                     message_type.value if message_type else None, | ||||
|                     json.dumps(metadata) if metadata else None, | ||||
|                     token_count, | ||||
|                     self.current_conversation_id, | ||||
|                 ), | ||||
|             ) | ||||
|             conn.commit() | ||||
|             return cursor.lastrowid | ||||
| 
 | ||||
|     def batch_add(self, messages: List[Message]) -> List[int]: | ||||
|         """ | ||||
|         Add multiple messages to the current conversation. | ||||
| 
 | ||||
|         Args: | ||||
|             messages (List[Message]): List of messages to add | ||||
| 
 | ||||
|         Returns: | ||||
|             List[int]: List of inserted message IDs | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             message_ids = [] | ||||
| 
 | ||||
|             for message in messages: | ||||
|                 content = message.content | ||||
|                 if isinstance(content, (dict, list)): | ||||
|                     content = json.dumps(content) | ||||
| 
 | ||||
|                 cursor.execute( | ||||
|                     f""" | ||||
|                     INSERT INTO {self.table_name}  | ||||
|                     (role, content, timestamp, message_type, metadata, token_count, conversation_id) | ||||
|                     VALUES (?, ?, ?, ?, ?, ?, ?) | ||||
|                 """, | ||||
|                     ( | ||||
|                         message.role, | ||||
|                         content, | ||||
|                         ( | ||||
|                             message.timestamp.isoformat() | ||||
|                             if message.timestamp | ||||
|                             else None | ||||
|                         ), | ||||
|                         ( | ||||
|                             message.message_type.value | ||||
|                             if message.message_type | ||||
|                             else None | ||||
|                         ), | ||||
|                         ( | ||||
|                             json.dumps(message.metadata) | ||||
|                             if message.metadata | ||||
|                             else None | ||||
|                         ), | ||||
|                         message.token_count, | ||||
|                         self.current_conversation_id, | ||||
|                     ), | ||||
|                 ) | ||||
|                 message_ids.append(cursor.lastrowid) | ||||
| 
 | ||||
|             conn.commit() | ||||
|             return message_ids | ||||
| 
 | ||||
|     def get_str(self) -> str: | ||||
|         """ | ||||
|         Get the current conversation history as a formatted string. | ||||
| 
 | ||||
|         Returns: | ||||
|             str: Formatted conversation history | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 SELECT * FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? | ||||
|                 ORDER BY id ASC | ||||
|             """, | ||||
|                 (self.current_conversation_id,), | ||||
|             ) | ||||
| 
 | ||||
|             messages = [] | ||||
|             for row in cursor.fetchall(): | ||||
|                 content = row["content"] | ||||
|                 try: | ||||
|                     content = json.loads(content) | ||||
|                 except json.JSONDecodeError: | ||||
|                     pass | ||||
| 
 | ||||
|                 timestamp = ( | ||||
|                     f"[{row['timestamp']}] " | ||||
|                     if row["timestamp"] | ||||
|                     else "" | ||||
|                 ) | ||||
|                 messages.append( | ||||
|                     f"{timestamp}{row['role']}: {content}" | ||||
|                 ) | ||||
| 
 | ||||
|             return "\n".join(messages) | ||||
| 
 | ||||
|     def get_messages( | ||||
|         self, | ||||
|         limit: Optional[int] = None, | ||||
|         offset: Optional[int] = None, | ||||
|     ) -> List[Dict]: | ||||
|         """ | ||||
|         Get messages from the current conversation with optional pagination. | ||||
| 
 | ||||
|         Args: | ||||
|             limit (Optional[int]): Maximum number of messages to return | ||||
|             offset (Optional[int]): Number of messages to skip | ||||
| 
 | ||||
|         Returns: | ||||
|             List[Dict]: List of message dictionaries | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             query = f""" | ||||
|                 SELECT * FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? | ||||
|                 ORDER BY id ASC | ||||
|             """ | ||||
|             params = [self.current_conversation_id] | ||||
| 
 | ||||
|             if limit is not None: | ||||
|                 query += " LIMIT ?" | ||||
|                 params.append(limit) | ||||
| 
 | ||||
|             if offset is not None: | ||||
|                 query += " OFFSET ?" | ||||
|                 params.append(offset) | ||||
| 
 | ||||
|             cursor.execute(query, params) | ||||
|             return [dict(row) for row in cursor.fetchall()] | ||||
| 
 | ||||
|     def delete_current_conversation(self) -> bool: | ||||
|         """ | ||||
|         Delete the current conversation. | ||||
| 
 | ||||
|         Returns: | ||||
|             bool: True if deletion was successful | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f"DELETE FROM {self.table_name} WHERE conversation_id = ?", | ||||
|                 (self.current_conversation_id,), | ||||
|             ) | ||||
|             conn.commit() | ||||
|             return cursor.rowcount > 0 | ||||
| 
 | ||||
|     def update_message( | ||||
|         self, | ||||
|         message_id: int, | ||||
|         content: Union[str, dict, list], | ||||
|         metadata: Optional[Dict] = None, | ||||
|     ) -> bool: | ||||
|         """ | ||||
|         Update an existing message in the current conversation. | ||||
| 
 | ||||
|         Args: | ||||
|             message_id (int): ID of the message to update | ||||
|             content (Union[str, dict, list]): New content for the message | ||||
|             metadata (Optional[Dict]): New metadata for the message | ||||
| 
 | ||||
|         Returns: | ||||
|             bool: True if update was successful | ||||
|         """ | ||||
|         if isinstance(content, (dict, list)): | ||||
|             content = json.dumps(content) | ||||
| 
 | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 UPDATE {self.table_name} | ||||
|                 SET content = ?, metadata = ? | ||||
|                 WHERE id = ? AND conversation_id = ? | ||||
|             """, | ||||
|                 ( | ||||
|                     content, | ||||
|                     json.dumps(metadata) if metadata else None, | ||||
|                     message_id, | ||||
|                     self.current_conversation_id, | ||||
|                 ), | ||||
|             ) | ||||
|             conn.commit() | ||||
|             return cursor.rowcount > 0 | ||||
| 
 | ||||
|     def search_messages(self, query: str) -> List[Dict]: | ||||
|         """ | ||||
|         Search for messages containing specific text in the current conversation. | ||||
| 
 | ||||
|         Args: | ||||
|             query (str): Text to search for | ||||
| 
 | ||||
|         Returns: | ||||
|             List[Dict]: List of matching messages | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 SELECT * FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? AND content LIKE ? | ||||
|             """, | ||||
|                 (self.current_conversation_id, f"%{query}%"), | ||||
|             ) | ||||
|             return [dict(row) for row in cursor.fetchall()] | ||||
| 
 | ||||
|     def get_statistics(self) -> Dict: | ||||
|         """ | ||||
|         Get statistics about the current conversation. | ||||
| 
 | ||||
|         Returns: | ||||
|             Dict: Statistics about the conversation | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 SELECT  | ||||
|                     COUNT(*) as total_messages, | ||||
|                     COUNT(DISTINCT role) as unique_roles, | ||||
|                     SUM(token_count) as total_tokens, | ||||
|                     MIN(timestamp) as first_message, | ||||
|                     MAX(timestamp) as last_message | ||||
|                 FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? | ||||
|             """, | ||||
|                 (self.current_conversation_id,), | ||||
|             ) | ||||
|             return dict(cursor.fetchone()) | ||||
| 
 | ||||
|     def clear_all(self) -> bool: | ||||
|         """ | ||||
|         Clear all messages from the database. | ||||
| 
 | ||||
|         Returns: | ||||
|             bool: True if clearing was successful | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute(f"DELETE FROM {self.table_name}") | ||||
|             conn.commit() | ||||
|             return True | ||||
| 
 | ||||
|     def get_conversation_id(self) -> str: | ||||
|         """ | ||||
|         Get the current conversation ID. | ||||
| 
 | ||||
|         Returns: | ||||
|             str: The current conversation ID | ||||
|         """ | ||||
|         return self.current_conversation_id | ||||
| 
 | ||||
|     def to_dict(self) -> List[Dict]: | ||||
|         """ | ||||
|         Convert the current conversation to a list of dictionaries. | ||||
| 
 | ||||
|         Returns: | ||||
|             List[Dict]: List of message dictionaries | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 SELECT role, content, timestamp, message_type, metadata, token_count | ||||
|                 FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? | ||||
|                 ORDER BY id ASC | ||||
|             """, | ||||
|                 (self.current_conversation_id,), | ||||
|             ) | ||||
| 
 | ||||
|             messages = [] | ||||
|             for row in cursor.fetchall(): | ||||
|                 content = row["content"] | ||||
|                 try: | ||||
|                     content = json.loads(content) | ||||
|                 except json.JSONDecodeError: | ||||
|                     pass | ||||
| 
 | ||||
|                 message = {"role": row["role"], "content": content} | ||||
| 
 | ||||
|                 if row["timestamp"]: | ||||
|                     message["timestamp"] = row["timestamp"] | ||||
|                 if row["message_type"]: | ||||
|                     message["message_type"] = row["message_type"] | ||||
|                 if row["metadata"]: | ||||
|                     message["metadata"] = json.loads(row["metadata"]) | ||||
|                 if row["token_count"]: | ||||
|                     message["token_count"] = row["token_count"] | ||||
| 
 | ||||
|                 messages.append(message) | ||||
| 
 | ||||
|             return messages | ||||
| 
 | ||||
|     def to_json(self) -> str: | ||||
|         """ | ||||
|         Convert the current conversation to a JSON string. | ||||
| 
 | ||||
|         Returns: | ||||
|             str: JSON string representation of the conversation | ||||
|         """ | ||||
|         return json.dumps(self.to_dict(), indent=2) | ||||
| 
 | ||||
|     def to_yaml(self) -> str: | ||||
|         """ | ||||
|         Convert the current conversation to a YAML string. | ||||
| 
 | ||||
|         Returns: | ||||
|             str: YAML string representation of the conversation | ||||
|         """ | ||||
|         return yaml.dump(self.to_dict()) | ||||
| 
 | ||||
|     def save_as_json(self, filename: str) -> bool: | ||||
|         """ | ||||
|         Save the current conversation to a JSON file. | ||||
| 
 | ||||
|         Args: | ||||
|             filename (str): Path to save the JSON file | ||||
| 
 | ||||
|         Returns: | ||||
|             bool: True if save was successful | ||||
|         """ | ||||
|         try: | ||||
|             with open(filename, "w") as f: | ||||
|                 json.dump(self.to_dict(), f, indent=2) | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             if self.enable_logging: | ||||
|                 self.logger.error( | ||||
|                     f"Failed to save conversation to JSON: {e}" | ||||
|                 ) | ||||
|             return False | ||||
| 
 | ||||
|     def save_as_yaml(self, filename: str) -> bool: | ||||
|         """ | ||||
|         Save the current conversation to a YAML file. | ||||
| 
 | ||||
|         Args: | ||||
|             filename (str): Path to save the YAML file | ||||
| 
 | ||||
|         Returns: | ||||
|             bool: True if save was successful | ||||
|         """ | ||||
|         try: | ||||
|             with open(filename, "w") as f: | ||||
|                 yaml.dump(self.to_dict(), f) | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             if self.enable_logging: | ||||
|                 self.logger.error( | ||||
|                     f"Failed to save conversation to YAML: {e}" | ||||
|                 ) | ||||
|             return False | ||||
| 
 | ||||
|     def load_from_json(self, filename: str) -> bool: | ||||
|         """ | ||||
|         Load a conversation from a JSON file. | ||||
| 
 | ||||
|         Args: | ||||
|             filename (str): Path to the JSON file | ||||
| 
 | ||||
|         Returns: | ||||
|             bool: True if load was successful | ||||
|         """ | ||||
|         try: | ||||
|             with open(filename, "r") as f: | ||||
|                 messages = json.load(f) | ||||
| 
 | ||||
|             # Start a new conversation | ||||
|             self.start_new_conversation() | ||||
| 
 | ||||
|             # Add all messages | ||||
|             for message in messages: | ||||
|                 self.add( | ||||
|                     role=message["role"], | ||||
|                     content=message["content"], | ||||
|                     message_type=( | ||||
|                         MessageType(message["message_type"]) | ||||
|                         if "message_type" in message | ||||
|                         else None | ||||
|                     ), | ||||
|                     metadata=message.get("metadata"), | ||||
|                     token_count=message.get("token_count"), | ||||
|                 ) | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             if self.enable_logging: | ||||
|                 self.logger.error( | ||||
|                     f"Failed to load conversation from JSON: {e}" | ||||
|                 ) | ||||
|             return False | ||||
| 
 | ||||
|     def load_from_yaml(self, filename: str) -> bool: | ||||
|         """ | ||||
|         Load a conversation from a YAML file. | ||||
| 
 | ||||
|         Args: | ||||
|             filename (str): Path to the YAML file | ||||
| 
 | ||||
|         Returns: | ||||
|             bool: True if load was successful | ||||
|         """ | ||||
|         try: | ||||
|             with open(filename, "r") as f: | ||||
|                 messages = yaml.safe_load(f) | ||||
| 
 | ||||
|             # Start a new conversation | ||||
|             self.start_new_conversation() | ||||
| 
 | ||||
|             # Add all messages | ||||
|             for message in messages: | ||||
|                 self.add( | ||||
|                     role=message["role"], | ||||
|                     content=message["content"], | ||||
|                     message_type=( | ||||
|                         MessageType(message["message_type"]) | ||||
|                         if "message_type" in message | ||||
|                         else None | ||||
|                     ), | ||||
|                     metadata=message.get("metadata"), | ||||
|                     token_count=message.get("token_count"), | ||||
|                 ) | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             if self.enable_logging: | ||||
|                 self.logger.error( | ||||
|                     f"Failed to load conversation from YAML: {e}" | ||||
|                 ) | ||||
|             return False | ||||
| 
 | ||||
|     def get_last_message(self) -> Optional[Dict]: | ||||
|         """ | ||||
|         Get the last message from the current conversation. | ||||
| 
 | ||||
|         Returns: | ||||
|             Optional[Dict]: The last message or None if conversation is empty | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 SELECT role, content, timestamp, message_type, metadata, token_count | ||||
|                 FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? | ||||
|                 ORDER BY id DESC | ||||
|                 LIMIT 1 | ||||
|             """, | ||||
|                 (self.current_conversation_id,), | ||||
|             ) | ||||
| 
 | ||||
|             row = cursor.fetchone() | ||||
|             if not row: | ||||
|                 return None | ||||
| 
 | ||||
|             content = row["content"] | ||||
|             try: | ||||
|                 content = json.loads(content) | ||||
|             except json.JSONDecodeError: | ||||
|                 pass | ||||
| 
 | ||||
|             message = {"role": row["role"], "content": content} | ||||
| 
 | ||||
|             if row["timestamp"]: | ||||
|                 message["timestamp"] = row["timestamp"] | ||||
|             if row["message_type"]: | ||||
|                 message["message_type"] = row["message_type"] | ||||
|             if row["metadata"]: | ||||
|                 message["metadata"] = json.loads(row["metadata"]) | ||||
|             if row["token_count"]: | ||||
|                 message["token_count"] = row["token_count"] | ||||
| 
 | ||||
|             return message | ||||
| 
 | ||||
|     def get_last_message_as_string(self) -> str: | ||||
|         """ | ||||
|         Get the last message as a formatted string. | ||||
| 
 | ||||
|         Returns: | ||||
|             str: Formatted string of the last message | ||||
|         """ | ||||
|         last_message = self.get_last_message() | ||||
|         if not last_message: | ||||
|             return "" | ||||
| 
 | ||||
|         timestamp = ( | ||||
|             f"[{last_message['timestamp']}] " | ||||
|             if "timestamp" in last_message | ||||
|             else "" | ||||
|         ) | ||||
|         return f"{timestamp}{last_message['role']}: {last_message['content']}" | ||||
| 
 | ||||
|     def count_messages_by_role(self) -> Dict[str, int]: | ||||
|         """ | ||||
|         Count messages by role in the current conversation. | ||||
| 
 | ||||
|         Returns: | ||||
|             Dict[str, int]: Dictionary with role counts | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 SELECT role, COUNT(*) as count | ||||
|                 FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? | ||||
|                 GROUP BY role | ||||
|             """, | ||||
|                 (self.current_conversation_id,), | ||||
|             ) | ||||
| 
 | ||||
|             return { | ||||
|                 row["role"]: row["count"] for row in cursor.fetchall() | ||||
|             } | ||||
| 
 | ||||
|     def get_messages_by_role(self, role: str) -> List[Dict]: | ||||
|         """ | ||||
|         Get all messages from a specific role in the current conversation. | ||||
| 
 | ||||
|         Args: | ||||
|             role (str): Role to filter messages by | ||||
| 
 | ||||
|         Returns: | ||||
|             List[Dict]: List of messages from the specified role | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 SELECT role, content, timestamp, message_type, metadata, token_count | ||||
|                 FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? AND role = ? | ||||
|                 ORDER BY id ASC | ||||
|             """, | ||||
|                 (self.current_conversation_id, role), | ||||
|             ) | ||||
| 
 | ||||
|             messages = [] | ||||
|             for row in cursor.fetchall(): | ||||
|                 content = row["content"] | ||||
|                 try: | ||||
|                     content = json.loads(content) | ||||
|                 except json.JSONDecodeError: | ||||
|                     pass | ||||
| 
 | ||||
|                 message = {"role": row["role"], "content": content} | ||||
| 
 | ||||
|                 if row["timestamp"]: | ||||
|                     message["timestamp"] = row["timestamp"] | ||||
|                 if row["message_type"]: | ||||
|                     message["message_type"] = row["message_type"] | ||||
|                 if row["metadata"]: | ||||
|                     message["metadata"] = json.loads(row["metadata"]) | ||||
|                 if row["token_count"]: | ||||
|                     message["token_count"] = row["token_count"] | ||||
| 
 | ||||
|                 messages.append(message) | ||||
| 
 | ||||
|             return messages | ||||
| 
 | ||||
|     def get_conversation_summary(self) -> Dict: | ||||
|         """ | ||||
|         Get a summary of the current conversation. | ||||
| 
 | ||||
|         Returns: | ||||
|             Dict: Summary of the conversation including message counts, roles, and time range | ||||
|         """ | ||||
|         with self._get_connection() as conn: | ||||
|             cursor = conn.cursor() | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 SELECT  | ||||
|                     COUNT(*) as total_messages, | ||||
|                     COUNT(DISTINCT role) as unique_roles, | ||||
|                     MIN(timestamp) as first_message_time, | ||||
|                     MAX(timestamp) as last_message_time, | ||||
|                     SUM(token_count) as total_tokens | ||||
|                 FROM {self.table_name} | ||||
|                 WHERE conversation_id = ? | ||||
|             """, | ||||
|                 (self.current_conversation_id,), | ||||
|             ) | ||||
| 
 | ||||
|             row = cursor.fetchone() | ||||
|             return { | ||||
|                 "conversation_id": self.current_conversation_id, | ||||
|                 "total_messages": row["total_messages"], | ||||
|                 "unique_roles": row["unique_roles"], | ||||
|                 "first_message_time": row["first_message_time"], | ||||
|                 "last_message_time": row["last_message_time"], | ||||
|                 "total_tokens": row["total_tokens"], | ||||
|                 "roles": self.count_messages_by_role(), | ||||
|             } | ||||
| @ -0,0 +1,317 @@ | ||||
| import os | ||||
| import json | ||||
| from datetime import datetime | ||||
| from pathlib import Path | ||||
| import tempfile | ||||
| import threading | ||||
| from swarms.communication.duckdb_wrap import ( | ||||
|     DuckDBConversation, | ||||
|     Message, | ||||
|     MessageType, | ||||
| ) | ||||
| 
 | ||||
| def setup_test(): | ||||
|     """Set up test environment.""" | ||||
|     temp_dir = tempfile.TemporaryDirectory() | ||||
|     db_path = Path(temp_dir.name) / "test_conversations.duckdb" | ||||
|     conversation = DuckDBConversation( | ||||
|         db_path=str(db_path), | ||||
|         enable_timestamps=True, | ||||
|         enable_logging=True, | ||||
|     ) | ||||
|     return temp_dir, db_path, conversation | ||||
| 
 | ||||
| def cleanup_test(temp_dir, db_path): | ||||
|     """Clean up test environment.""" | ||||
|     if os.path.exists(db_path): | ||||
|         os.remove(db_path) | ||||
|     temp_dir.cleanup() | ||||
| 
 | ||||
| def test_initialization(): | ||||
|     """Test conversation initialization.""" | ||||
|     temp_dir, db_path, _ = setup_test() | ||||
|     try: | ||||
|         conv = DuckDBConversation(db_path=str(db_path)) | ||||
|         assert conv.db_path == db_path, "Database path mismatch" | ||||
|         assert conv.table_name == "conversations", "Table name mismatch" | ||||
|         assert conv.enable_timestamps is True, "Timestamps should be enabled" | ||||
|         assert conv.current_conversation_id is not None, "Conversation ID should not be None" | ||||
|         print("✓ Initialization test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_add_message(): | ||||
|     """Test adding a single message.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         msg_id = conversation.add( | ||||
|             role="user", | ||||
|             content="Hello, world!", | ||||
|             message_type=MessageType.USER, | ||||
|         ) | ||||
|         assert msg_id is not None, "Message ID should not be None" | ||||
|         assert isinstance(msg_id, int), "Message ID should be an integer" | ||||
|         print("✓ Add message test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_add_complex_message(): | ||||
|     """Test adding a message with complex content.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         complex_content = { | ||||
|             "text": "Hello", | ||||
|             "data": [1, 2, 3], | ||||
|             "nested": {"key": "value"} | ||||
|         } | ||||
|         msg_id = conversation.add( | ||||
|             role="assistant", | ||||
|             content=complex_content, | ||||
|             message_type=MessageType.ASSISTANT, | ||||
|             metadata={"source": "test"}, | ||||
|             token_count=10 | ||||
|         ) | ||||
|         assert msg_id is not None, "Message ID should not be None" | ||||
|         print("✓ Add complex message test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_batch_add(): | ||||
|     """Test batch adding messages.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         messages = [ | ||||
|             Message( | ||||
|                 role="user", | ||||
|                 content="First message", | ||||
|                 message_type=MessageType.USER | ||||
|             ), | ||||
|             Message( | ||||
|                 role="assistant", | ||||
|                 content="Second message", | ||||
|                 message_type=MessageType.ASSISTANT | ||||
|             ) | ||||
|         ] | ||||
|         msg_ids = conversation.batch_add(messages) | ||||
|         assert len(msg_ids) == 2, "Should have 2 message IDs" | ||||
|         assert all(isinstance(id, int) for id in msg_ids), "All IDs should be integers" | ||||
|         print("✓ Batch add test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_get_str(): | ||||
|     """Test getting conversation as string.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         conversation.add("user", "Hello") | ||||
|         conversation.add("assistant", "Hi there!") | ||||
|         conv_str = conversation.get_str() | ||||
|         assert "user: Hello" in conv_str, "User message not found" | ||||
|         assert "assistant: Hi there!" in conv_str, "Assistant message not found" | ||||
|         print("✓ Get string test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_get_messages(): | ||||
|     """Test getting messages with pagination.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         for i in range(5): | ||||
|             conversation.add("user", f"Message {i}") | ||||
|          | ||||
|         all_messages = conversation.get_messages() | ||||
|         assert len(all_messages) == 5, "Should have 5 messages" | ||||
|          | ||||
|         limited_messages = conversation.get_messages(limit=2) | ||||
|         assert len(limited_messages) == 2, "Should have 2 limited messages" | ||||
|          | ||||
|         offset_messages = conversation.get_messages(offset=2) | ||||
|         assert len(offset_messages) == 3, "Should have 3 offset messages" | ||||
|         print("✓ Get messages test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_search_messages(): | ||||
|     """Test searching messages.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         conversation.add("user", "Hello world") | ||||
|         conversation.add("assistant", "Hello there") | ||||
|         conversation.add("user", "Goodbye world") | ||||
|          | ||||
|         results = conversation.search_messages("world") | ||||
|         assert len(results) == 2, "Should find 2 messages with 'world'" | ||||
|         assert all("world" in msg["content"] for msg in results), "All results should contain 'world'" | ||||
|         print("✓ Search messages test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_get_statistics(): | ||||
|     """Test getting conversation statistics.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         conversation.add("user", "Hello", token_count=2) | ||||
|         conversation.add("assistant", "Hi", token_count=1) | ||||
|          | ||||
|         stats = conversation.get_statistics() | ||||
|         assert stats["total_messages"] == 2, "Should have 2 total messages" | ||||
|         assert stats["unique_roles"] == 2, "Should have 2 unique roles" | ||||
|         assert stats["total_tokens"] == 3, "Should have 3 total tokens" | ||||
|         print("✓ Get statistics test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_json_operations(): | ||||
|     """Test JSON save and load operations.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         conversation.add("user", "Hello") | ||||
|         conversation.add("assistant", "Hi") | ||||
|          | ||||
|         json_path = Path(temp_dir.name) / "test_conversation.json" | ||||
|         conversation.save_as_json(str(json_path)) | ||||
|         assert json_path.exists(), "JSON file should exist" | ||||
|          | ||||
|         new_conversation = DuckDBConversation( | ||||
|             db_path=str(Path(temp_dir.name) / "new.duckdb") | ||||
|         ) | ||||
|         assert new_conversation.load_from_json(str(json_path)), "Should load from JSON" | ||||
|         assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load" | ||||
|         print("✓ JSON operations test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_yaml_operations(): | ||||
|     """Test YAML save and load operations.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         conversation.add("user", "Hello") | ||||
|         conversation.add("assistant", "Hi") | ||||
|          | ||||
|         yaml_path = Path(temp_dir.name) / "test_conversation.yaml" | ||||
|         conversation.save_as_yaml(str(yaml_path)) | ||||
|         assert yaml_path.exists(), "YAML file should exist" | ||||
|          | ||||
|         new_conversation = DuckDBConversation( | ||||
|             db_path=str(Path(temp_dir.name) / "new.duckdb") | ||||
|         ) | ||||
|         assert new_conversation.load_from_yaml(str(yaml_path)), "Should load from YAML" | ||||
|         assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load" | ||||
|         print("✓ YAML operations test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_message_types(): | ||||
|     """Test different message types.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         conversation.add("system", "System message", message_type=MessageType.SYSTEM) | ||||
|         conversation.add("user", "User message", message_type=MessageType.USER) | ||||
|         conversation.add("assistant", "Assistant message", message_type=MessageType.ASSISTANT) | ||||
|         conversation.add("function", "Function message", message_type=MessageType.FUNCTION) | ||||
|         conversation.add("tool", "Tool message", message_type=MessageType.TOOL) | ||||
|          | ||||
|         messages = conversation.get_messages() | ||||
|         assert len(messages) == 5, "Should have 5 messages" | ||||
|         assert all("message_type" in msg for msg in messages), "All messages should have type" | ||||
|         print("✓ Message types test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_delete_operations(): | ||||
|     """Test deletion operations.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         conversation.add("user", "Hello") | ||||
|         conversation.add("assistant", "Hi") | ||||
|          | ||||
|         assert conversation.delete_current_conversation(), "Should delete conversation" | ||||
|         assert len(conversation.get_messages()) == 0, "Should have no messages after delete" | ||||
|          | ||||
|         conversation.add("user", "New message") | ||||
|         assert conversation.clear_all(), "Should clear all messages" | ||||
|         assert len(conversation.get_messages()) == 0, "Should have no messages after clear" | ||||
|         print("✓ Delete operations test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_concurrent_operations(): | ||||
|     """Test concurrent operations.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         def add_messages(): | ||||
|             for i in range(10): | ||||
|                 conversation.add("user", f"Message {i}") | ||||
|          | ||||
|         threads = [threading.Thread(target=add_messages) for _ in range(5)] | ||||
|         for thread in threads: | ||||
|             thread.start() | ||||
|         for thread in threads: | ||||
|             thread.join() | ||||
|          | ||||
|         messages = conversation.get_messages() | ||||
|         assert len(messages) == 50, "Should have 50 messages (10 * 5 threads)" | ||||
|         print("✓ Concurrent operations test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def test_error_handling(): | ||||
|     """Test error handling.""" | ||||
|     temp_dir, db_path, conversation = setup_test() | ||||
|     try: | ||||
|         # Test invalid message type | ||||
|         try: | ||||
|             conversation.add("user", "Message", message_type="invalid") | ||||
|             assert False, "Should raise exception for invalid message type" | ||||
|         except Exception: | ||||
|             pass | ||||
|          | ||||
|         # Test invalid JSON content | ||||
|         try: | ||||
|             conversation.add("user", {"invalid": object()}) | ||||
|             assert False, "Should raise exception for invalid JSON content" | ||||
|         except Exception: | ||||
|             pass | ||||
|          | ||||
|         # Test invalid file operations | ||||
|         try: | ||||
|             conversation.load_from_json("/nonexistent/path.json") | ||||
|             assert False, "Should raise exception for invalid file path" | ||||
|         except Exception: | ||||
|             pass | ||||
|          | ||||
|         print("✓ Error handling test passed") | ||||
|     finally: | ||||
|         cleanup_test(temp_dir, db_path) | ||||
| 
 | ||||
| def run_all_tests(): | ||||
|     """Run all tests.""" | ||||
|     print("Running DuckDB Conversation tests...") | ||||
|     tests = [ | ||||
|         test_initialization, | ||||
|         test_add_message, | ||||
|         test_add_complex_message, | ||||
|         test_batch_add, | ||||
|         test_get_str, | ||||
|         test_get_messages, | ||||
|         test_search_messages, | ||||
|         test_get_statistics, | ||||
|         test_json_operations, | ||||
|         test_yaml_operations, | ||||
|         test_message_types, | ||||
|         test_delete_operations, | ||||
|         test_concurrent_operations, | ||||
|         test_error_handling | ||||
|     ] | ||||
|      | ||||
|     for test in tests: | ||||
|         try: | ||||
|             test() | ||||
|         except Exception as e: | ||||
|             print(f"✗ {test.__name__} failed: {str(e)}") | ||||
|             raise | ||||
|      | ||||
|     print("\nAll tests completed successfully!") | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     run_all_tests()  | ||||
| @ -0,0 +1,386 @@ | ||||
| import json | ||||
| import datetime | ||||
| import os | ||||
| from typing import Dict, List, Any, Tuple | ||||
| from loguru import logger | ||||
| from swarms.communication.sqlite_wrap import ( | ||||
|     SQLiteConversation, | ||||
|     Message, | ||||
|     MessageType, | ||||
| ) | ||||
| from rich.console import Console | ||||
| from rich.table import Table | ||||
| from rich.panel import Panel | ||||
| 
 | ||||
| console = Console() | ||||
| 
 | ||||
| 
 | ||||
| def print_test_header(test_name: str) -> None: | ||||
|     """Print a formatted test header.""" | ||||
|     console.print( | ||||
|         Panel( | ||||
|             f"[bold blue]Running Test: {test_name}[/bold blue]", | ||||
|             expand=False, | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def print_test_result( | ||||
|     test_name: str, success: bool, message: str, execution_time: float | ||||
| ) -> None: | ||||
|     """Print a formatted test result.""" | ||||
|     status = ( | ||||
|         "[bold green]PASSED[/bold green]" | ||||
|         if success | ||||
|         else "[bold red]FAILED[/bold red]" | ||||
|     ) | ||||
|     console.print(f"\n{status} - {test_name}") | ||||
|     console.print(f"Message: {message}") | ||||
|     console.print(f"Execution time: {execution_time:.3f} seconds\n") | ||||
| 
 | ||||
| 
 | ||||
| def print_messages( | ||||
|     messages: List[Dict], title: str = "Messages" | ||||
| ) -> None: | ||||
|     """Print messages in a formatted table.""" | ||||
|     table = Table(title=title) | ||||
|     table.add_column("Role", style="cyan") | ||||
|     table.add_column("Content", style="green") | ||||
|     table.add_column("Type", style="yellow") | ||||
|     table.add_column("Timestamp", style="magenta") | ||||
| 
 | ||||
|     for msg in messages: | ||||
|         content = str(msg.get("content", "")) | ||||
|         if isinstance(content, (dict, list)): | ||||
|             content = json.dumps(content) | ||||
|         table.add_row( | ||||
|             msg.get("role", ""), | ||||
|             content, | ||||
|             str(msg.get("message_type", "")), | ||||
|             str(msg.get("timestamp", "")), | ||||
|         ) | ||||
| 
 | ||||
|     console.print(table) | ||||
| 
 | ||||
| 
 | ||||
| def run_test( | ||||
|     test_func: callable, *args, **kwargs | ||||
| ) -> Tuple[bool, str, float]: | ||||
|     """ | ||||
|     Run a test function and return its results. | ||||
| 
 | ||||
|     Args: | ||||
|         test_func: The test function to run | ||||
|         *args: Arguments for the test function | ||||
|         **kwargs: Keyword arguments for the test function | ||||
| 
 | ||||
|     Returns: | ||||
|         Tuple[bool, str, float]: (success, message, execution_time) | ||||
|     """ | ||||
|     start_time = datetime.datetime.now() | ||||
|     try: | ||||
|         result = test_func(*args, **kwargs) | ||||
|         end_time = datetime.datetime.now() | ||||
|         execution_time = (end_time - start_time).total_seconds() | ||||
|         return True, str(result), execution_time | ||||
|     except Exception as e: | ||||
|         end_time = datetime.datetime.now() | ||||
|         execution_time = (end_time - start_time).total_seconds() | ||||
|         return False, str(e), execution_time | ||||
| 
 | ||||
| 
 | ||||
| def test_basic_conversation() -> bool: | ||||
|     """Test basic conversation operations.""" | ||||
|     print_test_header("Basic Conversation Test") | ||||
| 
 | ||||
|     db_path = "test_conversations.db" | ||||
|     conversation = SQLiteConversation(db_path=db_path) | ||||
| 
 | ||||
|     # Test adding messages | ||||
|     console.print("\n[bold]Adding messages...[/bold]") | ||||
|     msg_id1 = conversation.add("user", "Hello") | ||||
|     msg_id2 = conversation.add("assistant", "Hi there!") | ||||
| 
 | ||||
|     # Test getting messages | ||||
|     console.print("\n[bold]Retrieved messages:[/bold]") | ||||
|     messages = conversation.get_messages() | ||||
|     print_messages(messages) | ||||
| 
 | ||||
|     assert len(messages) == 2 | ||||
|     assert messages[0]["role"] == "user" | ||||
|     assert messages[1]["role"] == "assistant" | ||||
| 
 | ||||
|     # Cleanup | ||||
|     os.remove(db_path) | ||||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| def test_message_types() -> bool: | ||||
|     """Test different message types and content formats.""" | ||||
|     print_test_header("Message Types Test") | ||||
| 
 | ||||
|     db_path = "test_conversations.db" | ||||
|     conversation = SQLiteConversation(db_path=db_path) | ||||
| 
 | ||||
|     # Test different content types | ||||
|     console.print("\n[bold]Adding different message types...[/bold]") | ||||
|     conversation.add("user", "Simple text") | ||||
|     conversation.add( | ||||
|         "assistant", {"type": "json", "content": "Complex data"} | ||||
|     ) | ||||
|     conversation.add("system", ["list", "of", "items"]) | ||||
|     conversation.add( | ||||
|         "function", | ||||
|         "Function result", | ||||
|         message_type=MessageType.FUNCTION, | ||||
|     ) | ||||
| 
 | ||||
|     console.print("\n[bold]Retrieved messages:[/bold]") | ||||
|     messages = conversation.get_messages() | ||||
|     print_messages(messages) | ||||
| 
 | ||||
|     assert len(messages) == 4 | ||||
| 
 | ||||
|     # Cleanup | ||||
|     os.remove(db_path) | ||||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| def test_conversation_operations() -> bool: | ||||
|     """Test various conversation operations.""" | ||||
|     print_test_header("Conversation Operations Test") | ||||
| 
 | ||||
|     db_path = "test_conversations.db" | ||||
|     conversation = SQLiteConversation(db_path=db_path) | ||||
| 
 | ||||
|     # Test batch operations | ||||
|     console.print("\n[bold]Adding batch messages...[/bold]") | ||||
|     messages = [ | ||||
|         Message(role="user", content="Message 1"), | ||||
|         Message(role="assistant", content="Message 2"), | ||||
|         Message(role="user", content="Message 3"), | ||||
|     ] | ||||
|     conversation.batch_add(messages) | ||||
| 
 | ||||
|     console.print("\n[bold]Retrieved messages:[/bold]") | ||||
|     all_messages = conversation.get_messages() | ||||
|     print_messages(all_messages) | ||||
| 
 | ||||
|     # Test statistics | ||||
|     console.print("\n[bold]Conversation Statistics:[/bold]") | ||||
|     stats = conversation.get_statistics() | ||||
|     console.print(json.dumps(stats, indent=2)) | ||||
| 
 | ||||
|     # Test role counting | ||||
|     console.print("\n[bold]Role Counts:[/bold]") | ||||
|     role_counts = conversation.count_messages_by_role() | ||||
|     console.print(json.dumps(role_counts, indent=2)) | ||||
| 
 | ||||
|     assert stats["total_messages"] == 3 | ||||
|     assert role_counts["user"] == 2 | ||||
|     assert role_counts["assistant"] == 1 | ||||
| 
 | ||||
|     # Cleanup | ||||
|     os.remove(db_path) | ||||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| def test_file_operations() -> bool: | ||||
|     """Test file operations (JSON/YAML).""" | ||||
|     print_test_header("File Operations Test") | ||||
| 
 | ||||
|     db_path = "test_conversations.db" | ||||
|     json_path = "test_conversation.json" | ||||
|     yaml_path = "test_conversation.yaml" | ||||
| 
 | ||||
|     conversation = SQLiteConversation(db_path=db_path) | ||||
|     conversation.add("user", "Test message") | ||||
| 
 | ||||
|     # Test JSON operations | ||||
|     console.print("\n[bold]Testing JSON operations...[/bold]") | ||||
|     assert conversation.save_as_json(json_path) | ||||
|     console.print(f"Saved to JSON: {json_path}") | ||||
| 
 | ||||
|     conversation.start_new_conversation() | ||||
|     assert conversation.load_from_json(json_path) | ||||
|     console.print("Loaded from JSON") | ||||
| 
 | ||||
|     # Test YAML operations | ||||
|     console.print("\n[bold]Testing YAML operations...[/bold]") | ||||
|     assert conversation.save_as_yaml(yaml_path) | ||||
|     console.print(f"Saved to YAML: {yaml_path}") | ||||
| 
 | ||||
|     conversation.start_new_conversation() | ||||
|     assert conversation.load_from_yaml(yaml_path) | ||||
|     console.print("Loaded from YAML") | ||||
| 
 | ||||
|     # Cleanup | ||||
|     os.remove(db_path) | ||||
|     os.remove(json_path) | ||||
|     os.remove(yaml_path) | ||||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| def test_search_and_filter() -> bool: | ||||
|     """Test search and filter operations.""" | ||||
|     print_test_header("Search and Filter Test") | ||||
| 
 | ||||
|     db_path = "test_conversations.db" | ||||
|     conversation = SQLiteConversation(db_path=db_path) | ||||
| 
 | ||||
|     # Add test messages | ||||
|     console.print("\n[bold]Adding test messages...[/bold]") | ||||
|     conversation.add("user", "Hello world") | ||||
|     conversation.add("assistant", "Hello there") | ||||
|     conversation.add("user", "Goodbye world") | ||||
| 
 | ||||
|     # Test search | ||||
|     console.print("\n[bold]Searching for 'world'...[/bold]") | ||||
|     results = conversation.search_messages("world") | ||||
|     print_messages(results, "Search Results") | ||||
| 
 | ||||
|     # Test role filtering | ||||
|     console.print("\n[bold]Filtering user messages...[/bold]") | ||||
|     user_messages = conversation.get_messages_by_role("user") | ||||
|     print_messages(user_messages, "User Messages") | ||||
| 
 | ||||
|     assert len(results) == 2 | ||||
|     assert len(user_messages) == 2 | ||||
| 
 | ||||
|     # Cleanup | ||||
|     os.remove(db_path) | ||||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| def test_conversation_management() -> bool: | ||||
|     """Test conversation management features.""" | ||||
|     print_test_header("Conversation Management Test") | ||||
| 
 | ||||
|     db_path = "test_conversations.db" | ||||
|     conversation = SQLiteConversation(db_path=db_path) | ||||
| 
 | ||||
|     # Test conversation ID generation | ||||
|     console.print("\n[bold]Testing conversation IDs...[/bold]") | ||||
|     conv_id1 = conversation.get_conversation_id() | ||||
|     console.print(f"First conversation ID: {conv_id1}") | ||||
| 
 | ||||
|     conversation.start_new_conversation() | ||||
|     conv_id2 = conversation.get_conversation_id() | ||||
|     console.print(f"Second conversation ID: {conv_id2}") | ||||
| 
 | ||||
|     assert conv_id1 != conv_id2 | ||||
| 
 | ||||
|     # Test conversation deletion | ||||
|     console.print("\n[bold]Testing conversation deletion...[/bold]") | ||||
|     conversation.add("user", "Test message") | ||||
|     assert conversation.delete_current_conversation() | ||||
|     console.print("Conversation deleted successfully") | ||||
| 
 | ||||
|     # Cleanup | ||||
|     os.remove(db_path) | ||||
|     return True | ||||
| 
 | ||||
| 
 | ||||
| def generate_test_report( | ||||
|     test_results: List[Dict[str, Any]] | ||||
| ) -> Dict[str, Any]: | ||||
|     """ | ||||
|     Generate a test report in JSON format. | ||||
| 
 | ||||
|     Args: | ||||
|         test_results: List of test results | ||||
| 
 | ||||
|     Returns: | ||||
|         Dict containing the test report | ||||
|     """ | ||||
|     total_tests = len(test_results) | ||||
|     passed_tests = sum( | ||||
|         1 for result in test_results if result["success"] | ||||
|     ) | ||||
|     failed_tests = total_tests - passed_tests | ||||
|     total_time = sum( | ||||
|         result["execution_time"] for result in test_results | ||||
|     ) | ||||
| 
 | ||||
|     report = { | ||||
|         "timestamp": datetime.datetime.now().isoformat(), | ||||
|         "summary": { | ||||
|             "total_tests": total_tests, | ||||
|             "passed_tests": passed_tests, | ||||
|             "failed_tests": failed_tests, | ||||
|             "total_execution_time": total_time, | ||||
|             "average_execution_time": ( | ||||
|                 total_time / total_tests if total_tests > 0 else 0 | ||||
|             ), | ||||
|         }, | ||||
|         "test_results": test_results, | ||||
|     } | ||||
| 
 | ||||
|     return report | ||||
| 
 | ||||
| 
 | ||||
| def run_all_tests() -> None: | ||||
|     """Run all tests and generate a report.""" | ||||
|     console.print( | ||||
|         Panel( | ||||
|             "[bold blue]Starting Test Suite[/bold blue]", expand=False | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     tests = [ | ||||
|         ("Basic Conversation", test_basic_conversation), | ||||
|         ("Message Types", test_message_types), | ||||
|         ("Conversation Operations", test_conversation_operations), | ||||
|         ("File Operations", test_file_operations), | ||||
|         ("Search and Filter", test_search_and_filter), | ||||
|         ("Conversation Management", test_conversation_management), | ||||
|     ] | ||||
| 
 | ||||
|     test_results = [] | ||||
| 
 | ||||
|     for test_name, test_func in tests: | ||||
|         logger.info(f"Running test: {test_name}") | ||||
|         success, message, execution_time = run_test(test_func) | ||||
| 
 | ||||
|         print_test_result(test_name, success, message, execution_time) | ||||
| 
 | ||||
|         result = { | ||||
|             "test_name": test_name, | ||||
|             "success": success, | ||||
|             "message": message, | ||||
|             "execution_time": execution_time, | ||||
|             "timestamp": datetime.datetime.now().isoformat(), | ||||
|         } | ||||
| 
 | ||||
|         if success: | ||||
|             logger.success(f"Test passed: {test_name}") | ||||
|         else: | ||||
|             logger.error(f"Test failed: {test_name} - {message}") | ||||
| 
 | ||||
|         test_results.append(result) | ||||
| 
 | ||||
|     # Generate and save report | ||||
|     report = generate_test_report(test_results) | ||||
|     report_path = "test_report.json" | ||||
| 
 | ||||
|     with open(report_path, "w") as f: | ||||
|         json.dump(report, f, indent=2) | ||||
| 
 | ||||
|     # Print final summary | ||||
|     console.print("\n[bold blue]Test Suite Summary[/bold blue]") | ||||
|     console.print( | ||||
|         Panel( | ||||
|             f"Total tests: {report['summary']['total_tests']}\n" | ||||
|             f"Passed tests: {report['summary']['passed_tests']}\n" | ||||
|             f"Failed tests: {report['summary']['failed_tests']}\n" | ||||
|             f"Total execution time: {report['summary']['total_execution_time']:.2f} seconds", | ||||
|             title="Summary", | ||||
|             expand=False, | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     logger.info(f"Test report saved to {report_path}") | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     run_all_tests() | ||||
					Loading…
					
					
				
		Reference in new issue