"""DuckDB-backed conversation history store (swarms/communication/duckdb_wrap.py).

Each message is one row in a single table; rows are grouped into conversations
by a ``conversation_id`` string generated per ``DuckDBConversation`` instance.
"""

import datetime
import json
import logging
import threading
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

# Third-party dependencies are guarded the same way as loguru below so that the
# module can be imported without them; a clear ImportError is raised on use.
try:
    import duckdb

    DUCKDB_AVAILABLE = True
except ImportError:
    duckdb = None
    DUCKDB_AVAILABLE = False

try:
    import yaml
except ImportError:
    yaml = None

try:
    from loguru import logger

    LOGURU_AVAILABLE = True
except ImportError:
    LOGURU_AVAILABLE = False


class MessageType(Enum):
    """Enum for different types of messages in the conversation."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    FUNCTION = "function"
    TOOL = "tool"


@dataclass
class Message:
    """Data class representing a message in the conversation."""

    role: str
    content: Union[str, dict, list]
    # Either an ISO-8601 string or a datetime; both are accepted by batch_add.
    timestamp: Optional[Union[str, datetime.datetime]] = None
    message_type: Optional[MessageType] = None
    metadata: Optional[Dict] = None
    token_count: Optional[int] = None

    class Config:
        # Kept for backward compatibility; a plain dataclass does not read it.
        arbitrary_types_allowed = True


class DateTimeEncoder(json.JSONEncoder):
    """Custom JSON encoder that serialises datetime objects as ISO-8601."""

    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            return obj.isoformat()
        return super().default(obj)


class DuckDBConversation:
    """
    DuckDB wrapper class for managing conversation history.

    Messages are persisted in one table and scoped to the instance's current
    conversation ID. A new connection is opened (and closed) per operation.

    Attributes:
        db_path (Path): Path to the DuckDB database file
        table_name (str): Name of the table to store conversations. It is
            interpolated into SQL text, so it must be a trusted identifier.
        enable_timestamps (bool): Whether ``add`` stamps messages with "now"
        enable_logging (bool): Whether to enable logging
        use_loguru (bool): Whether loguru is used (requested and importable)
        max_retries (int): Maximum number of attempts to open a connection
        connection_timeout (float): Stored for API parity; not used by any
            DuckDB call in this class
        current_conversation_id (str): Current active conversation ID
    """

    # Column order shared by every query whose rows feed ``_row_to_message``.
    _MESSAGE_COLUMNS = (
        "role, content, timestamp, message_type, metadata, token_count"
    )

    def __init__(
        self,
        db_path: Union[str, Path] = "conversations.duckdb",
        table_name: str = "conversations",
        enable_timestamps: bool = True,
        enable_logging: bool = True,
        use_loguru: bool = True,
        max_retries: int = 3,
        connection_timeout: float = 5.0,
    ):
        self.db_path = Path(db_path)
        self.table_name = table_name
        self.enable_timestamps = enable_timestamps
        self.enable_logging = enable_logging
        self.use_loguru = use_loguru and LOGURU_AVAILABLE
        self.max_retries = max_retries
        self.connection_timeout = connection_timeout
        self.current_conversation_id = None
        self._lock = threading.Lock()

        # Setup logging
        if self.enable_logging:
            if self.use_loguru:
                self.logger = logger
            else:
                self.logger = logging.getLogger(__name__)
                # getLogger returns a shared object: attach the handler only
                # once, otherwise every new instance duplicates each log line.
                if not self.logger.handlers:
                    handler = logging.StreamHandler()
                    handler.setFormatter(
                        logging.Formatter(
                            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
                        )
                    )
                    self.logger.addHandler(handler)
                self.logger.setLevel(logging.INFO)

        self._init_db()
        self.start_new_conversation()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _generate_conversation_id(self) -> str:
        """Generate a conversation ID from the current time and a UUID fragment."""
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_id = str(uuid.uuid4())[:8]
        return f"conv_{timestamp}_{unique_id}"

    def _init_db(self):
        """Create the conversation table if it does not exist yet."""
        with self._get_connection() as conn:
            conn.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    id BIGINT PRIMARY KEY,
                    role VARCHAR NOT NULL,
                    content VARCHAR NOT NULL,
                    timestamp TIMESTAMP,
                    message_type VARCHAR,
                    metadata VARCHAR,
                    token_count INTEGER,
                    conversation_id VARCHAR,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )

    @contextmanager
    def _get_connection(self):
        """
        Yield an open DuckDB connection and always close it afterwards.

        Only *opening* the connection is retried (up to ``max_retries``
        attempts). Exceptions raised by the caller inside the ``with`` body
        propagate unchanged; a generator-based context manager must not
        yield a second time after an exception was thrown into it.

        Raises:
            ImportError: If the ``duckdb`` package is not installed
            RuntimeError: If ``max_retries`` allows no connection attempt
            Exception: Whatever ``duckdb.connect`` raised on the last attempt
        """
        if not DUCKDB_AVAILABLE:
            raise ImportError(
                "duckdb is required for DuckDBConversation; "
                "install it with `pip install duckdb`"
            )

        conn = None
        for attempt in range(1, self.max_retries + 1):
            try:
                conn = duckdb.connect(str(self.db_path))
                break
            except Exception as e:
                if attempt >= self.max_retries:
                    raise
                if self.enable_logging:
                    self.logger.warning(
                        f"Database connection attempt {attempt} failed: {e}"
                    )

        if conn is None:
            raise RuntimeError(
                "max_retries must be at least 1 to open a database connection"
            )

        try:
            yield conn
        finally:
            conn.close()

    @staticmethod
    def _normalize_timestamp(
        value: Optional[Union[str, datetime.datetime, datetime.date]],
    ) -> Optional[str]:
        """Return ``value`` as an ISO-8601 string, or None when it is empty."""
        if not value:
            return None
        if isinstance(value, (datetime.datetime, datetime.date)):
            return value.isoformat()
        return str(value)

    @staticmethod
    def _decode_content(raw: Any) -> Any:
        """Return ``raw`` parsed as JSON when possible, else ``raw`` unchanged."""
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return raw

    @classmethod
    def _row_to_message(cls, row: Sequence) -> Dict:
        """
        Convert a row selected with ``_MESSAGE_COLUMNS`` into a message dict.

        ``role`` and ``content`` are always present; the optional columns are
        included only when they are not NULL (so ``token_count == 0`` is kept).
        """
        role, content, timestamp, message_type, metadata, token_count = row
        message = {"role": role, "content": cls._decode_content(content)}
        if timestamp is not None:
            message["timestamp"] = timestamp
        if message_type is not None:
            message["message_type"] = message_type
        if metadata is not None:
            message["metadata"] = json.loads(metadata)
        if token_count is not None:
            message["token_count"] = token_count
        return message

    def _insert_rows(self, rows: Sequence[Sequence]) -> List[int]:
        """
        Insert ``(role, content, timestamp, message_type, metadata,
        token_count)`` tuples into the current conversation.

        ID allocation (``MAX(id) + 1``) and the inserts run under the instance
        lock so that threads sharing this instance cannot pick the same ID.
        """
        inserted_ids = []
        with self._lock, self._get_connection() as conn:
            next_id = conn.execute(
                f"SELECT COALESCE(MAX(id), 0) + 1 as next_id FROM {self.table_name}"
            ).fetchone()[0]

            for position, row in enumerate(rows):
                role, content, timestamp, message_type, metadata, token_count = row
                if isinstance(content, (dict, list)):
                    content = json.dumps(content)
                if isinstance(message_type, MessageType):
                    message_type = message_type.value

                message_id = next_id + position
                conn.execute(
                    f"""
                    INSERT INTO {self.table_name}
                    (id, role, content, timestamp, message_type, metadata, token_count, conversation_id)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        message_id,
                        role,
                        content,
                        timestamp,
                        message_type if message_type else None,
                        json.dumps(metadata) if metadata else None,
                        token_count,
                        self.current_conversation_id,
                    ),
                )
                inserted_ids.append(message_id)
        return inserted_ids

    def _fetch_messages(
        self,
        extra_where: str = "",
        params: Sequence = (),
        tail: str = "",
    ) -> List[Dict]:
        """
        Select messages of the current conversation ordered by ID.

        Args:
            extra_where (str): SQL appended to the WHERE clause (e.g. `` AND role = ?``)
            params (Sequence): Bind values for ``extra_where`` and ``tail``, in order
            tail (str): SQL appended after ORDER BY (e.g. `` LIMIT ?``)
        """
        query = (
            f"SELECT {self._MESSAGE_COLUMNS} FROM {self.table_name} "
            f"WHERE conversation_id = ?{extra_where} ORDER BY id ASC{tail}"
        )
        with self._get_connection() as conn:
            rows = conn.execute(
                query, [self.current_conversation_id, *params]
            ).fetchall()
        return [self._row_to_message(row) for row in rows]

    def _load_messages(self, messages: List[Dict]) -> None:
        """Start a new conversation and add every message dict from an export."""
        self.start_new_conversation()
        for message in messages:
            raw_type = message.get("message_type")
            self.add(
                role=message["role"],
                content=message["content"],
                message_type=MessageType(raw_type) if raw_type else None,
                metadata=message.get("metadata"),
                token_count=message.get("token_count"),
                # Preserve the exported timestamp instead of re-stamping "now".
                timestamp=message.get("timestamp"),
            )

    # ------------------------------------------------------------------
    # Conversation lifecycle
    # ------------------------------------------------------------------

    def start_new_conversation(self) -> str:
        """
        Start a new conversation and return its ID.

        Returns:
            str: The new conversation ID
        """
        self.current_conversation_id = self._generate_conversation_id()
        return self.current_conversation_id

    def get_conversation_id(self) -> str:
        """
        Get the current conversation ID.

        Returns:
            str: The current conversation ID
        """
        return self.current_conversation_id

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def add(
        self,
        role: str,
        content: Union[str, dict, list],
        message_type: Optional[MessageType] = None,
        metadata: Optional[Dict] = None,
        token_count: Optional[int] = None,
        timestamp: Optional[Union[str, datetime.datetime]] = None,
    ) -> int:
        """
        Add a message to the current conversation.

        Args:
            role (str): The role of the speaker
            content (Union[str, dict, list]): The content of the message;
                dicts and lists are stored as JSON text
            message_type (Optional[MessageType]): Type of the message
            metadata (Optional[Dict]): Additional metadata for the message
            token_count (Optional[int]): Number of tokens in the message
            timestamp (Optional[Union[str, datetime]]): Explicit timestamp.
                When omitted, the current time is used if ``enable_timestamps``
                is set, otherwise no timestamp is stored.

        Returns:
            int: The ID of the inserted message
        """
        stamp = self._normalize_timestamp(timestamp)
        if stamp is None and self.enable_timestamps:
            stamp = datetime.datetime.now().isoformat()

        row = (role, content, stamp, message_type, metadata, token_count)
        return self._insert_rows([row])[0]

    def batch_add(self, messages: List[Message]) -> List[int]:
        """
        Add multiple messages to the current conversation.

        Message timestamps may be ISO strings or datetime objects; messages
        without a timestamp are stored without one.

        Args:
            messages (List[Message]): List of messages to add

        Returns:
            List[int]: List of inserted message IDs
        """
        rows = [
            (
                message.role,
                message.content,
                self._normalize_timestamp(message.timestamp),
                message.message_type,
                message.metadata,
                message.token_count,
            )
            for message in messages
        ]
        return self._insert_rows(rows)

    def delete_current_conversation(self) -> bool:
        """
        Delete the current conversation.

        Returns:
            bool: True if at least one row was reported as deleted
        """
        with self._get_connection() as conn:
            result = conn.execute(
                f"DELETE FROM {self.table_name} WHERE conversation_id = ?",
                (self.current_conversation_id,),
            )
            # NOTE(review): relies on duckdb populating ``rowcount`` for DML —
            # confirm for the duckdb version this project pins.
            return result.rowcount > 0

    def update_message(
        self,
        message_id: int,
        content: Union[str, dict, list],
        metadata: Optional[Dict] = None,
    ) -> bool:
        """
        Update an existing message in the current conversation.

        Args:
            message_id (int): ID of the message to update
            content (Union[str, dict, list]): New content for the message
            metadata (Optional[Dict]): New metadata for the message

        Returns:
            bool: True if at least one row was reported as updated
        """
        if isinstance(content, (dict, list)):
            content = json.dumps(content)

        with self._get_connection() as conn:
            result = conn.execute(
                f"""
                UPDATE {self.table_name}
                SET content = ?, metadata = ?
                WHERE id = ? AND conversation_id = ?
                """,
                (
                    content,
                    json.dumps(metadata) if metadata else None,
                    message_id,
                    self.current_conversation_id,
                ),
            )
            # NOTE(review): same ``rowcount`` caveat as delete_current_conversation.
            return result.rowcount > 0

    def clear_all(self) -> bool:
        """
        Clear all messages (of every conversation) from the table.

        Returns:
            bool: True if clearing was successful
        """
        with self._get_connection() as conn:
            conn.execute(f"DELETE FROM {self.table_name}")
        return True

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def get_str(self) -> str:
        """
        Get the current conversation history as a formatted string.

        Returns:
            str: One ``[timestamp] role: content`` line per message
        """
        with self._get_connection() as conn:
            rows = conn.execute(
                f"""
                SELECT role, content, timestamp FROM {self.table_name}
                WHERE conversation_id = ?
                ORDER BY id ASC
                """,
                (self.current_conversation_id,),
            ).fetchall()

        lines = []
        for role, content, timestamp in rows:
            prefix = f"[{timestamp}] " if timestamp else ""
            lines.append(f"{prefix}{role}: {self._decode_content(content)}")
        return "\n".join(lines)

    def get_messages(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> List[Dict]:
        """
        Get messages from the current conversation with optional pagination.

        Args:
            limit (Optional[int]): Maximum number of messages to return
            offset (Optional[int]): Number of messages to skip

        Returns:
            List[Dict]: List of message dictionaries
        """
        tail = ""
        params = []
        if limit is not None:
            tail += " LIMIT ?"
            params.append(limit)
        if offset is not None:
            tail += " OFFSET ?"
            params.append(offset)
        return self._fetch_messages(params=params, tail=tail)

    def search_messages(self, query: str) -> List[Dict]:
        """
        Search for messages containing specific text in the current conversation.

        The text is matched with SQL ``LIKE``, so ``%`` and ``_`` inside
        ``query`` act as wildcards.

        Args:
            query (str): Text to search for

        Returns:
            List[Dict]: List of matching messages
        """
        return self._fetch_messages(" AND content LIKE ?", [f"%{query}%"])

    def get_messages_by_role(self, role: str) -> List[Dict]:
        """
        Get all messages from a specific role in the current conversation.

        Args:
            role (str): Role to filter messages by

        Returns:
            List[Dict]: List of messages from the specified role
        """
        return self._fetch_messages(" AND role = ?", [role])

    def get_last_message(self) -> Optional[Dict]:
        """
        Get the last message from the current conversation.

        Returns:
            Optional[Dict]: The last message or None if conversation is empty
        """
        with self._get_connection() as conn:
            row = conn.execute(
                f"""
                SELECT {self._MESSAGE_COLUMNS} FROM {self.table_name}
                WHERE conversation_id = ?
                ORDER BY id DESC
                LIMIT 1
                """,
                (self.current_conversation_id,),
            ).fetchone()

        return self._row_to_message(row) if row else None

    def get_last_message_as_string(self) -> str:
        """
        Get the last message as a formatted string.

        Returns:
            str: ``[timestamp] role: content``, or "" if there is no message
        """
        last_message = self.get_last_message()
        if not last_message:
            return ""

        timestamp = (
            f"[{last_message['timestamp']}] "
            if "timestamp" in last_message
            else ""
        )
        return f"{timestamp}{last_message['role']}: {last_message['content']}"

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    def get_statistics(self) -> Dict:
        """
        Get statistics about the current conversation.

        Returns:
            Dict: total_messages, unique_roles, total_tokens, first_message
            and last_message for the current conversation
        """
        with self._get_connection() as conn:
            row = conn.execute(
                f"""
                SELECT
                    COUNT(*) as total_messages,
                    COUNT(DISTINCT role) as unique_roles,
                    SUM(token_count) as total_tokens,
                    MIN(timestamp) as first_message,
                    MAX(timestamp) as last_message
                FROM {self.table_name}
                WHERE conversation_id = ?
                """,
                (self.current_conversation_id,),
            ).fetchone()

        keys = (
            "total_messages",
            "unique_roles",
            "total_tokens",
            "first_message",
            "last_message",
        )
        return dict(zip(keys, row))

    def count_messages_by_role(self) -> Dict[str, int]:
        """
        Count messages by role in the current conversation.

        Returns:
            Dict[str, int]: Dictionary with role counts
        """
        with self._get_connection() as conn:
            rows = conn.execute(
                f"""
                SELECT role, COUNT(*) as count
                FROM {self.table_name}
                WHERE conversation_id = ?
                GROUP BY role
                """,
                (self.current_conversation_id,),
            ).fetchall()

        return {role: count for role, count in rows}

    def get_conversation_summary(self) -> Dict:
        """
        Get a summary of the current conversation.

        Returns:
            Dict: Conversation ID, message counts, roles, and time range
        """
        stats = self.get_statistics()
        return {
            "conversation_id": self.current_conversation_id,
            "total_messages": stats["total_messages"],
            "unique_roles": stats["unique_roles"],
            "first_message_time": stats["first_message"],
            "last_message_time": stats["last_message"],
            "total_tokens": stats["total_tokens"],
            "roles": self.count_messages_by_role(),
        }

    def get_conversation_as_dict(self) -> Dict:
        """
        Get the entire conversation as a dictionary with messages and metadata.

        Returns:
            Dict: Dictionary containing conversation ID, messages, and metadata
        """
        messages = self.get_messages()
        stats = self.get_statistics()

        return {
            "conversation_id": self.current_conversation_id,
            "messages": messages,
            "metadata": {
                "total_messages": stats["total_messages"],
                "unique_roles": stats["unique_roles"],
                "total_tokens": stats["total_tokens"],
                "first_message": stats["first_message"],
                "last_message": stats["last_message"],
                "roles": self.count_messages_by_role(),
            },
        }

    def get_conversation_by_role_dict(self) -> Dict[str, List[Dict]]:
        """
        Get the conversation organized by roles.

        Unlike ``get_messages``, every message dict here carries all optional
        keys, with None for missing values.

        Returns:
            Dict[str, List[Dict]]: Roles as keys, lists of messages as values
        """
        with self._get_connection() as conn:
            rows = conn.execute(
                f"""
                SELECT {self._MESSAGE_COLUMNS}
                FROM {self.table_name}
                WHERE conversation_id = ?
                ORDER BY id ASC
                """,
                (self.current_conversation_id,),
            ).fetchall()

        role_dict = {}
        for role, content, timestamp, message_type, metadata, token_count in rows:
            role_dict.setdefault(role, []).append(
                {
                    "content": self._decode_content(content),
                    "timestamp": timestamp,
                    "message_type": message_type,
                    "metadata": json.loads(metadata) if metadata else None,
                    "token_count": token_count,
                }
            )
        return role_dict

    def get_conversation_timeline_dict(self) -> Dict[str, List[Dict]]:
        """
        Get the conversation organized by the date of each message.

        Keys are whatever DuckDB returns for ``DATE(timestamp)`` (None for
        messages without a timestamp).

        Returns:
            Dict[str, List[Dict]]: Dates as keys, lists of messages as values
        """
        with self._get_connection() as conn:
            rows = conn.execute(
                f"""
                SELECT
                    DATE(timestamp) as date,
                    role,
                    content,
                    timestamp,
                    message_type,
                    metadata,
                    token_count
                FROM {self.table_name}
                WHERE conversation_id = ?
                ORDER BY timestamp ASC
                """,
                (self.current_conversation_id,),
            ).fetchall()

        timeline_dict = {}
        for date, role, content, timestamp, message_type, metadata, token_count in rows:
            timeline_dict.setdefault(date, []).append(
                {
                    "role": role,
                    "content": self._decode_content(content),
                    "timestamp": timestamp,
                    "message_type": message_type,
                    "metadata": json.loads(metadata) if metadata else None,
                    "token_count": token_count,
                }
            )
        return timeline_dict

    def get_conversation_metadata_dict(self) -> Dict:
        """
        Get detailed metadata about the conversation.

        Returns:
            Dict: Basic statistics, message-type distribution, average tokens
            per message, per-hour message counts and role distribution
        """
        stats = self.get_statistics()

        with self._get_connection() as conn:
            type_dist = conn.execute(
                f"""
                SELECT message_type, COUNT(*) as count
                FROM {self.table_name}
                WHERE conversation_id = ?
                GROUP BY message_type
                """,
                (self.current_conversation_id,),
            ).fetchall()

            avg_tokens = conn.execute(
                f"""
                SELECT AVG(token_count) as avg_tokens
                FROM {self.table_name}
                WHERE conversation_id = ? AND token_count IS NOT NULL
                """,
                (self.current_conversation_id,),
            ).fetchone()

            hourly_freq = conn.execute(
                f"""
                SELECT
                    EXTRACT(HOUR FROM timestamp) as hour,
                    COUNT(*) as count
                FROM {self.table_name}
                WHERE conversation_id = ?
                GROUP BY hour
                ORDER BY hour
                """,
                (self.current_conversation_id,),
            ).fetchall()

        return {
            "conversation_id": self.current_conversation_id,
            "basic_stats": stats,
            "message_type_distribution": {
                row[0]: row[1] for row in type_dist
            },
            "average_tokens_per_message": (
                avg_tokens[0] if avg_tokens[0] is not None else 0
            ),
            "hourly_message_frequency": {
                row[0]: row[1] for row in hourly_freq
            },
            "role_distribution": self.count_messages_by_role(),
        }

    # ------------------------------------------------------------------
    # Export / import
    # ------------------------------------------------------------------

    def to_dict(self) -> List[Dict]:
        """
        Convert the current conversation to a list of dictionaries.

        Returns:
            List[Dict]: List of message dictionaries
        """
        return self._fetch_messages()

    def to_json(self) -> str:
        """
        Convert the current conversation to a JSON string.

        Returns:
            str: JSON string representation of the conversation
        """
        return json.dumps(self.to_dict(), indent=2, cls=DateTimeEncoder)

    def to_yaml(self) -> str:
        """
        Convert the current conversation to a YAML string.

        Returns:
            str: YAML string representation of the conversation

        Raises:
            ImportError: If PyYAML is not installed
        """
        if yaml is None:
            raise ImportError("PyYAML is required for YAML export")
        return yaml.dump(self.to_dict())

    def save_as_json(self, filename: str) -> bool:
        """
        Save the current conversation to a JSON file.

        Args:
            filename (str): Path to save the JSON file

        Returns:
            bool: True if save was successful, False if any error occurred
        """
        try:
            with open(filename, "w") as f:
                json.dump(self.to_dict(), f, indent=2, cls=DateTimeEncoder)
            return True
        except Exception as e:
            if self.enable_logging:
                self.logger.error(
                    f"Failed to save conversation to JSON: {e}"
                )
            return False

    def load_from_json(self, filename: str) -> bool:
        """
        Load a conversation from a JSON file into a new conversation.

        Message timestamps found in the file are preserved.

        Args:
            filename (str): Path to the JSON file

        Returns:
            bool: True if load was successful, False if any error occurred
        """
        try:
            with open(filename, "r") as f:
                messages = json.load(f)
            self._load_messages(messages)
            return True
        except Exception as e:
            if self.enable_logging:
                self.logger.error(
                    f"Failed to load conversation from JSON: {e}"
                )
            return False

    def save_as_yaml(self, filename: str) -> bool:
        """
        Save the current conversation to a YAML file.

        Args:
            filename (str): Path to save the YAML file

        Returns:
            bool: True if save was successful, False if any error occurred
        """
        try:
            if yaml is None:
                raise ImportError("PyYAML is required for YAML export")
            with open(filename, "w") as f:
                yaml.dump(self.to_dict(), f)
            return True
        except Exception as e:
            if self.enable_logging:
                self.logger.error(
                    f"Failed to save conversation to YAML: {e}"
                )
            return False

    def load_from_yaml(self, filename: str) -> bool:
        """
        Load a conversation from a YAML file into a new conversation.

        The file is parsed with ``yaml.safe_load``; message timestamps found
        in the file are preserved.

        Args:
            filename (str): Path to the YAML file

        Returns:
            bool: True if load was successful, False if any error occurred
        """
        try:
            if yaml is None:
                raise ImportError("PyYAML is required for YAML import")
            with open(filename, "r") as f:
                messages = yaml.safe_load(f)
            self._load_messages(messages)
            return True
        except Exception as e:
            if self.enable_logging:
                self.logger.error(
                    f"Failed to load conversation from YAML: {e}"
                )
            return False
+ This class provides persistent storage for conversations with various features + like message tracking, timestamps, and metadata support. + + Attributes: + db_path (str): Path to the SQLite database file + table_name (str): Name of the table to store conversations + enable_timestamps (bool): Whether to track message timestamps + enable_logging (bool): Whether to enable logging + use_loguru (bool): Whether to use loguru for logging + max_retries (int): Maximum number of retries for database operations + connection_timeout (float): Timeout for database connections + current_conversation_id (str): Current active conversation ID + """ + + def __init__( + self, + db_path: str = "conversations.db", + table_name: str = "conversations", + enable_timestamps: bool = True, + enable_logging: bool = True, + use_loguru: bool = True, + max_retries: int = 3, + connection_timeout: float = 5.0, + **kwargs, + ): + """ + Initialize the SQLite conversation manager. + + Args: + db_path (str): Path to the SQLite database file + table_name (str): Name of the table to store conversations + enable_timestamps (bool): Whether to track message timestamps + enable_logging (bool): Whether to enable logging + use_loguru (bool): Whether to use loguru for logging + max_retries (int): Maximum number of retries for database operations + connection_timeout (float): Timeout for database connections + """ + self.db_path = Path(db_path) + self.table_name = table_name + self.enable_timestamps = enable_timestamps + self.enable_logging = enable_logging + self.use_loguru = use_loguru and LOGURU_AVAILABLE + self.max_retries = max_retries + self.connection_timeout = connection_timeout + self._lock = threading.Lock() + self.current_conversation_id = ( + self._generate_conversation_id() + ) + + # Setup logging + if self.enable_logging: + if self.use_loguru: + self.logger = logger + else: + self.logger = logging.getLogger(__name__) + handler = logging.StreamHandler() + formatter = logging.Formatter( + 
"%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.INFO) + + # Initialize database + self._init_db() + + def _generate_conversation_id(self) -> str: + """Generate a unique conversation ID using UUID and timestamp.""" + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + unique_id = str(uuid.uuid4())[:8] + return f"conv_{timestamp}_{unique_id}" + + def start_new_conversation(self) -> str: + """ + Start a new conversation and return its ID. + + Returns: + str: The new conversation ID + """ + self.current_conversation_id = ( + self._generate_conversation_id() + ) + return self.current_conversation_id + + def _init_db(self): + """Initialize the database and create necessary tables.""" + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TEXT, + message_type TEXT, + metadata TEXT, + token_count INTEGER, + conversation_id TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + conn.commit() + + @contextmanager + def _get_connection(self): + """Context manager for database connections with retry logic.""" + conn = None + for attempt in range(self.max_retries): + try: + conn = sqlite3.connect( + str(self.db_path), timeout=self.connection_timeout + ) + conn.row_factory = sqlite3.Row + yield conn + break + except sqlite3.Error as e: + if attempt == self.max_retries - 1: + raise + if self.enable_logging: + self.logger.warning( + f"Database connection attempt {attempt + 1} failed: {e}" + ) + finally: + if conn: + conn.close() + + def add( + self, + role: str, + content: Union[str, dict, list], + message_type: Optional[MessageType] = None, + metadata: Optional[Dict] = None, + token_count: Optional[int] = None, + ) -> int: + """ + Add a message to the current 
conversation. + + Args: + role (str): The role of the speaker + content (Union[str, dict, list]): The content of the message + message_type (Optional[MessageType]): Type of the message + metadata (Optional[Dict]): Additional metadata for the message + token_count (Optional[int]): Number of tokens in the message + + Returns: + int: The ID of the inserted message + """ + timestamp = ( + datetime.datetime.now().isoformat() + if self.enable_timestamps + else None + ) + + if isinstance(content, (dict, list)): + content = json.dumps(content) + + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + f""" + INSERT INTO {self.table_name} + (role, content, timestamp, message_type, metadata, token_count, conversation_id) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + role, + content, + timestamp, + message_type.value if message_type else None, + json.dumps(metadata) if metadata else None, + token_count, + self.current_conversation_id, + ), + ) + conn.commit() + return cursor.lastrowid + + def batch_add(self, messages: List[Message]) -> List[int]: + """ + Add multiple messages to the current conversation. + + Args: + messages (List[Message]): List of messages to add + + Returns: + List[int]: List of inserted message IDs + """ + with self._get_connection() as conn: + cursor = conn.cursor() + message_ids = [] + + for message in messages: + content = message.content + if isinstance(content, (dict, list)): + content = json.dumps(content) + + cursor.execute( + f""" + INSERT INTO {self.table_name} + (role, content, timestamp, message_type, metadata, token_count, conversation_id) + VALUES (?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + message.role, + content, + ( + message.timestamp.isoformat() + if message.timestamp + else None + ), + ( + message.message_type.value + if message.message_type + else None + ), + ( + json.dumps(message.metadata) + if message.metadata + else None + ), + message.token_count, + self.current_conversation_id, + ), + ) + message_ids.append(cursor.lastrowid) + + conn.commit() + return message_ids + + def get_str(self) -> str: + """ + Get the current conversation history as a formatted string. + + Returns: + str: Formatted conversation history + """ + with self._get_connection() as conn: + cursor = conn.cursor() + cursor.execute( + f""" + SELECT * FROM {self.table_name} + WHERE conversation_id = ? + ORDER BY id ASC + """, + (self.current_conversation_id,), + ) + + messages = [] + for row in cursor.fetchall(): + content = row["content"] + try: + content = json.loads(content) + except json.JSONDecodeError: + pass + + timestamp = ( + f"[{row['timestamp']}] " + if row["timestamp"] + else "" + ) + messages.append( + f"{timestamp}{row['role']}: {content}" + ) + + return "\n".join(messages) + + def get_messages( + self, + limit: Optional[int] = None, + offset: Optional[int] = None, + ) -> List[Dict]: + """ + Get messages from the current conversation with optional pagination. + + Args: + limit (Optional[int]): Maximum number of messages to return + offset (Optional[int]): Number of messages to skip + + Returns: + List[Dict]: List of message dictionaries + """ + with self._get_connection() as conn: + cursor = conn.cursor() + query = f""" + SELECT * FROM {self.table_name} + WHERE conversation_id = ? + ORDER BY id ASC + """ + params = [self.current_conversation_id] + + if limit is not None: + query += " LIMIT ?" + params.append(limit) + + if offset is not None: + query += " OFFSET ?" 
def search_messages(self, query: str) -> List[Dict]:
    """
    Search for messages containing specific text in the current conversation.

    The text is matched literally: ``%``, ``_`` and ``\\`` in ``query`` are
    escaped so they are not interpreted as LIKE wildcards.

    Args:
        query (str): Text to search for

    Returns:
        List[Dict]: Matching messages as dictionaries, ordered by ``id``
    """
    # Escape the escape character first so the later replacements are not
    # themselves re-escaped.
    escaped = (
        query.replace("\\", "\\\\")
        .replace("%", "\\%")
        .replace("_", "\\_")
    )

    with self._get_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            f"""
            SELECT * FROM {self.table_name}
            WHERE conversation_id = ? AND content LIKE ? ESCAPE '\\'
            ORDER BY id ASC
            """,
            (self.current_conversation_id, f"%{escaped}%"),
        )
        return [dict(row) for row in cursor.fetchall()]
def clear_all(self) -> bool:
    """
    Delete every row in the conversation table, not just the current
    conversation's rows.

    Returns:
        bool: True once the delete statement has been committed
    """
    wipe_statement = f"DELETE FROM {self.table_name}"
    with self._get_connection() as connection:
        connection.cursor().execute(wipe_statement)
        connection.commit()
        return True
def load_from_json(self, filename: str) -> bool:
    """
    Load a conversation from a JSON file into a new conversation.

    The file is parsed and every entry is validated before a new
    conversation is started, so an unreadable or malformed file leaves the
    current conversation selected and untouched.

    Args:
        filename (str): Path to the JSON file. It must contain a list of
            message objects, each with at least ``role`` and ``content``.

    Returns:
        bool: True if load was successful, False if any step raised
    """
    try:
        with open(filename, "r") as f:
            messages = json.load(f)

        if not isinstance(messages, list):
            raise ValueError(
                f"expected a list of messages, got {type(messages).__name__}"
            )

        # Build all keyword sets up front; any bad entry aborts here,
        # before instance state is changed.
        prepared = []
        for message in messages:
            raw_type = message.get("message_type")
            prepared.append(
                {
                    "role": message["role"],
                    "content": message["content"],
                    # A missing key and an explicit null both mean "no type".
                    "message_type": (
                        MessageType(raw_type)
                        if raw_type is not None
                        else None
                    ),
                    "metadata": message.get("metadata"),
                    "token_count": message.get("token_count"),
                }
            )

        # Start a new conversation only once the payload is known-good
        self.start_new_conversation()

        # Add all messages
        for fields in prepared:
            self.add(**fields)
        return True
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Failed to load conversation from JSON: {e}"
            )
        return False
def get_last_message_as_string(self) -> str:
    """
    Render the most recent message as ``"[timestamp] role: content"``.

    The bracketed timestamp prefix is included only when the message
    dictionary has a ``timestamp`` key.

    Returns:
        str: The formatted message, or an empty string when
        ``get_last_message()`` returns nothing
    """
    message = self.get_last_message()
    if not message:
        return ""

    prefix = ""
    if "timestamp" in message:
        prefix = f"[{message['timestamp']}] "

    return f"{prefix}{message['role']}: {message['content']}"
def get_conversation_summary(self) -> Dict:
    """
    Get a summary of the current conversation.

    Returns:
        Dict: ``conversation_id``, the aggregate columns
        ``total_messages``, ``unique_roles``, ``first_message_time``,
        ``last_message_time`` and ``total_tokens``, plus ``roles`` as
        returned by ``count_messages_by_role()``
    """
    with self._get_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            f"""
            SELECT
                COUNT(*) as total_messages,
                COUNT(DISTINCT role) as unique_roles,
                MIN(timestamp) as first_message_time,
                MAX(timestamp) as last_message_time,
                SUM(token_count) as total_tokens
            FROM {self.table_name}
            WHERE conversation_id = ?
            """,
            (self.current_conversation_id,),
        )
        row = cursor.fetchone()

    # count_messages_by_role() enters _get_connection() itself, so call it
    # only after the block above has released its connection.
    # NOTE(review): if _get_connection() serialises access with the
    # non-reentrant threading.Lock created in __init__, the previous nested
    # call would block forever -- confirm against _get_connection.
    roles = self.count_messages_by_role()

    return {
        "conversation_id": self.current_conversation_id,
        "total_messages": row["total_messages"],
        "unique_roles": row["unique_roles"],
        "first_message_time": row["first_message_time"],
        "last_message_time": row["last_message_time"],
        "total_tokens": row["total_tokens"],
        "roles": roles,
    }
def cleanup_test(temp_dir, db_path):
    """Remove the database file if it exists, then dispose of the temp directory."""
    database_present = os.path.exists(db_path)
    if database_present:
        os.remove(db_path)
    temp_dir.cleanup()
def test_batch_add():
    """Insert two messages through batch_add and check the returned IDs."""
    temp_dir, db_path, conversation = setup_test()
    try:
        specs = (
            ("user", "First message", MessageType.USER),
            ("assistant", "Second message", MessageType.ASSISTANT),
        )
        batch = [
            Message(role=role, content=text, message_type=kind)
            for role, text, kind in specs
        ]

        inserted_ids = conversation.batch_add(batch)

        assert len(inserted_ids) == 2, "Should have 2 message IDs"
        assert all(
            isinstance(message_id, int) for message_id in inserted_ids
        ), "All IDs should be integers"
        print("✓ Batch add test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_get_messages():
    """Check get_messages with no paging, with a limit, and with an offset."""
    temp_dir, db_path, conversation = setup_test()
    try:
        for index in range(5):
            conversation.add("user", f"Message {index}")

        # (paging kwargs, expected row count, assertion message)
        cases = (
            ({}, 5, "Should have 5 messages"),
            ({"limit": 2}, 2, "Should have 2 limited messages"),
            ({"offset": 2}, 3, "Should have 3 offset messages"),
        )
        for paging, expected, failure_text in cases:
            fetched = conversation.get_messages(**paging)
            assert len(fetched) == expected, failure_text

        print("✓ Get messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_yaml_operations():
    """Test YAML save and load operations, including the round-tripped content."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        yaml_path = Path(temp_dir.name) / "test_conversation.yaml"
        # save_as_yaml reports failure through its return value rather than
        # raising, so the result has to be asserted explicitly.
        assert conversation.save_as_yaml(
            str(yaml_path)
        ), "Should save to YAML"
        assert yaml_path.exists(), "YAML file should exist"

        new_conversation = DuckDBConversation(
            db_path=str(Path(temp_dir.name) / "new.duckdb")
        )
        assert new_conversation.load_from_yaml(
            str(yaml_path)
        ), "Should load from YAML"
        assert (
            len(new_conversation.get_messages()) == 2
        ), "Should have 2 messages after load"

        # Verify what was loaded, not only how many rows there are.
        loaded = [
            (message["role"], message["content"])
            for message in new_conversation.to_dict()
        ]
        assert loaded == [
            ("user", "Hello"),
            ("assistant", "Hi"),
        ], "Loaded messages should match the saved ones"
        print("✓ YAML operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)
def test_error_handling():
    """
    Test error handling.

    Each expectation uses ``try/except/else`` so the failure is raised
    outside the ``try`` body. An ``assert False`` placed inside a
    ``try ... except Exception: pass`` is itself swallowed, because
    ``AssertionError`` is an ``Exception``, and such a check can never fail.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        # Test invalid message type: add() reads ``.value`` from it
        try:
            conversation.add("user", "Message", message_type="invalid")
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid message type"
            )

        # Test invalid JSON content: dict content is passed to json.dumps
        try:
            conversation.add("user", {"invalid": object()})
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid JSON content"
            )

        # Test invalid file operations: load_from_json catches the error
        # and signals failure by returning False instead of raising
        assert (
            conversation.load_from_json("/nonexistent/path.json") is False
        ), "Should return False for invalid file path"

        print("✓ Error handling test passed")
    finally:
        cleanup_test(temp_dir, db_path)
title: str = "Messages" +) -> None: + """Print messages in a formatted table.""" + table = Table(title=title) + table.add_column("Role", style="cyan") + table.add_column("Content", style="green") + table.add_column("Type", style="yellow") + table.add_column("Timestamp", style="magenta") + + for msg in messages: + content = str(msg.get("content", "")) + if isinstance(content, (dict, list)): + content = json.dumps(content) + table.add_row( + msg.get("role", ""), + content, + str(msg.get("message_type", "")), + str(msg.get("timestamp", "")), + ) + + console.print(table) + + +def run_test( + test_func: callable, *args, **kwargs +) -> Tuple[bool, str, float]: + """ + Run a test function and return its results. + + Args: + test_func: The test function to run + *args: Arguments for the test function + **kwargs: Keyword arguments for the test function + + Returns: + Tuple[bool, str, float]: (success, message, execution_time) + """ + start_time = datetime.datetime.now() + try: + result = test_func(*args, **kwargs) + end_time = datetime.datetime.now() + execution_time = (end_time - start_time).total_seconds() + return True, str(result), execution_time + except Exception as e: + end_time = datetime.datetime.now() + execution_time = (end_time - start_time).total_seconds() + return False, str(e), execution_time + + +def test_basic_conversation() -> bool: + """Test basic conversation operations.""" + print_test_header("Basic Conversation Test") + + db_path = "test_conversations.db" + conversation = SQLiteConversation(db_path=db_path) + + # Test adding messages + console.print("\n[bold]Adding messages...[/bold]") + msg_id1 = conversation.add("user", "Hello") + msg_id2 = conversation.add("assistant", "Hi there!") + + # Test getting messages + console.print("\n[bold]Retrieved messages:[/bold]") + messages = conversation.get_messages() + print_messages(messages) + + assert len(messages) == 2 + assert messages[0]["role"] == "user" + assert messages[1]["role"] == "assistant" + + # 
def test_message_types() -> bool:
    """
    Test different message types and content formats.

    The database file has a fixed name in the working directory, so it is
    removed in a ``finally`` block; otherwise a failed assertion would
    leave stale rows for later tests that reuse the same path.
    """
    print_test_header("Message Types Test")

    db_path = "test_conversations.db"
    try:
        conversation = SQLiteConversation(db_path=db_path)

        # Test different content types
        console.print("\n[bold]Adding different message types...[/bold]")
        conversation.add("user", "Simple text")
        conversation.add(
            "assistant", {"type": "json", "content": "Complex data"}
        )
        conversation.add("system", ["list", "of", "items"])
        conversation.add(
            "function",
            "Function result",
            message_type=MessageType.FUNCTION,
        )

        console.print("\n[bold]Retrieved messages:[/bold]")
        messages = conversation.get_messages()
        print_messages(messages)

        assert len(messages) == 4
    finally:
        # Cleanup, guarded in case the database file was never created
        if os.path.exists(db_path):
            os.remove(db_path)
    return True
def test_file_operations() -> bool:
    """
    Test file operations (JSON/YAML).

    All three artefacts use fixed names in the working directory, so they
    are removed in a ``finally`` block; a failing assertion must not leave
    them behind for later tests.
    """
    print_test_header("File Operations Test")

    db_path = "test_conversations.db"
    json_path = "test_conversation.json"
    yaml_path = "test_conversation.yaml"

    try:
        conversation = SQLiteConversation(db_path=db_path)
        conversation.add("user", "Test message")

        # Test JSON operations
        console.print("\n[bold]Testing JSON operations...[/bold]")
        assert conversation.save_as_json(json_path)
        console.print(f"Saved to JSON: {json_path}")

        conversation.start_new_conversation()
        assert conversation.load_from_json(json_path)
        console.print("Loaded from JSON")

        # Test YAML operations
        console.print("\n[bold]Testing YAML operations...[/bold]")
        assert conversation.save_as_yaml(yaml_path)
        console.print(f"Saved to YAML: {yaml_path}")

        conversation.start_new_conversation()
        assert conversation.load_from_yaml(yaml_path)
        console.print("Loaded from YAML")
    finally:
        # Cleanup, guarded because an early failure may mean that some of
        # the files were never written
        for leftover in (db_path, json_path, yaml_path):
            if os.path.exists(leftover):
                os.remove(leftover)
    return True
def generate_test_report(
    test_results: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Summarise a list of test results as a JSON-serialisable report.

    Args:
        test_results: Result dictionaries; each must provide ``success``
            and ``execution_time``

    Returns:
        Dict with a ``timestamp`` (ISO string taken when the report is
        built), a ``summary`` of counts and timings, and the original
        ``test_results`` list
    """
    total = len(test_results)
    passed = len([entry for entry in test_results if entry["success"]])
    elapsed = sum(entry["execution_time"] for entry in test_results)

    # Avoid dividing by zero when no tests were run.
    if total > 0:
        average = elapsed / total
    else:
        average = 0

    summary = {
        "total_tests": total,
        "passed_tests": passed,
        "failed_tests": total - passed,
        "total_execution_time": elapsed,
        "average_execution_time": average,
    }

    return {
        "timestamp": datetime.datetime.now().isoformat(),
        "summary": summary,
        "test_results": test_results,
    }
blue]Test Suite Summary[/bold blue]") + console.print( + Panel( + f"Total tests: {report['summary']['total_tests']}\n" + f"Passed tests: {report['summary']['passed_tests']}\n" + f"Failed tests: {report['summary']['failed_tests']}\n" + f"Total execution time: {report['summary']['total_execution_time']:.2f} seconds", + title="Summary", + expand=False, + ) + ) + + logger.info(f"Test report saved to {report_path}") + + +if __name__ == "__main__": + run_all_tests()