diff --git a/example_supabase_usage.py b/example_supabase_usage.py new file mode 100644 index 00000000..0f8989af --- /dev/null +++ b/example_supabase_usage.py @@ -0,0 +1,212 @@ +""" +Example usage of the SupabaseConversation class for the Swarms Framework. + +This example demonstrates how to: +1. Initialize a SupabaseConversation with automatic table creation +2. Add messages of different types +3. Query and search messages +4. Export/import conversations +5. Get conversation statistics + +Prerequisites: +1. Install supabase-py: pip install supabase +2. Set up a Supabase project with valid URL and API key +3. Set environment variables (table will be created automatically) + +Automatic Table Creation: +The SupabaseConversation will automatically create the required table if it doesn't exist. +For optimal results, you can optionally create this RPC function in your Supabase SQL Editor: + +CREATE OR REPLACE FUNCTION exec_sql(sql TEXT) +RETURNS TEXT AS $$ +BEGIN + EXECUTE sql; + RETURN 'SUCCESS'; +END; +$$ LANGUAGE plpgsql SECURITY DEFINER; + +Environment Variables: + - SUPABASE_URL: Your Supabase project URL + - SUPABASE_KEY: Your Supabase anon/service key +""" + +import os +import json +from swarms.communication.supabase_wrap import ( + SupabaseConversation, + MessageType, + SupabaseOperationError, + SupabaseConnectionError +) +from swarms.communication.base_communication import Message + + +def main(): + # Load environment variables + supabase_url = os.getenv("SUPABASE_URL") + supabase_key = os.getenv("SUPABASE_KEY") + + if not supabase_url or not supabase_key: + print("Error: SUPABASE_URL and SUPABASE_KEY environment variables must be set.") + print("Please create a .env file with these values or set them in your environment.") + return + + try: + # Initialize SupabaseConversation + print("šŸš€ Initializing SupabaseConversation with automatic table creation...") + conversation = SupabaseConversation( + supabase_url=supabase_url, + supabase_key=supabase_key, + system_prompt="You are a helpful AI assistant.", + time_enabled=True, + enable_logging=True, + table_name="conversations", + ) + + print(f"āœ… Successfully initialized! Conversation ID: {conversation.get_conversation_id()}") + print("šŸ“‹ Table created automatically if it didn't exist!") + + # Add various types of messages + print("\nšŸ“ Adding messages...") + + # Add user message + user_msg_id = conversation.add( + role="user", + content="Hello! Can you help me understand Supabase?", + message_type=MessageType.USER, + metadata={"source": "example_script", "priority": "high"} + ) + print(f"Added user message (ID: {user_msg_id})") + + # Add assistant message with complex content + assistant_content = { + "response": "Of course! Supabase is an open-source Firebase alternative with a PostgreSQL database.", + "confidence": 0.95, + "topics": ["database", "backend", "realtime"] + } + assistant_msg_id = conversation.add( + role="assistant", + content=assistant_content, + message_type=MessageType.ASSISTANT, + metadata={"model": "gpt-4", "tokens_used": 150} + ) + print(f"Added assistant message (ID: {assistant_msg_id})") + + # Add system message + system_msg_id = conversation.add( + role="system", + content="User is asking about Supabase features.", + message_type=MessageType.SYSTEM + ) + print(f"Added system message (ID: {system_msg_id})") + + # Batch add multiple messages + print("\nšŸ“¦ Batch adding messages...") + batch_messages = [ + Message( + role="user", + content="What are the main features of Supabase?", + message_type=MessageType.USER, + metadata={"follow_up": True} + ), + Message( + role="assistant", + content="Supabase provides: database, auth, realtime subscriptions, edge functions, and storage.", + message_type=MessageType.ASSISTANT, + metadata={"comprehensive": True} + ) + ] + batch_ids = conversation.batch_add(batch_messages) + print(f"Batch added {len(batch_ids)} messages: {batch_ids}") + + # Get conversation as string + print("\nšŸ’¬ Current conversation:") + print(conversation.get_str()) + + # Search for messages + print("\nšŸ” Searching for messages containing 'Supabase':") + search_results = conversation.search("Supabase") + for result in search_results: + print(f" - ID {result['id']}: {result['role']} - {result['content'][:50]}...") + + # Get conversation statistics + print("\nšŸ“Š Conversation statistics:") + stats = conversation.get_conversation_summary() + print(json.dumps(stats, indent=2, default=str)) + + # Get messages by role + print("\nšŸ‘¤ User messages:") + user_messages = conversation.get_messages_by_role("user") + for msg in user_messages: + print(f" - {msg['content']}") + + # Update a message + print(f"\nāœļø Updating message {user_msg_id}...") + conversation.update( + index=str(user_msg_id), + role="user", + content="Hello! Can you help me understand Supabase and its key features?" + ) + print("Message updated successfully!") + + # Query a specific message + print(f"\nšŸ”Ž Querying message {assistant_msg_id}:") + queried_msg = conversation.query(str(assistant_msg_id)) + if queried_msg: + print(f" Role: {queried_msg['role']}") + print(f" Content: {queried_msg['content']}") + print(f" Timestamp: {queried_msg['timestamp']}") + + # Export conversation + print("\nšŸ’¾ Exporting conversation...") + conversation.export_conversation("supabase_conversation_export.yaml") + print("Conversation exported to supabase_conversation_export.yaml") + + # Get conversation organized by role + print("\nšŸ“‹ Messages organized by role:") + by_role = conversation.get_conversation_by_role_dict() + for role, messages in by_role.items(): + print(f" {role}: {len(messages)} messages") + + # Get timeline + print("\nšŸ“… Conversation timeline:") + timeline = conversation.get_conversation_timeline_dict() + for date, messages in timeline.items(): + print(f" {date}: {len(messages)} messages") + + # Test delete (be careful with this in production!) + print(f"\nšŸ—‘ļø Deleting system message {system_msg_id}...") + conversation.delete(str(system_msg_id)) + print("System message deleted successfully!") + + # Final message count + final_stats = conversation.get_conversation_summary() + print(f"\nšŸ“ˆ Final conversation has {final_stats['total_messages']} messages") + + # Start a new conversation + print("\nšŸ†• Starting a new conversation...") + new_conv_id = conversation.start_new_conversation() + print(f"New conversation started with ID: {new_conv_id}") + + # Add a message to the new conversation + conversation.add( + role="user", + content="This is a new conversation!", + message_type=MessageType.USER + ) + print("Added message to new conversation") + + print("\nāœ… Example completed successfully!") + + except SupabaseConnectionError as e: + print(f"āŒ Connection error: {e}") + print("Please check your Supabase URL and key.") + except SupabaseOperationError as e: + print(f"āŒ Operation error: {e}") + print("Please check your database schema and permissions.") + except Exception as e: + print(f"āŒ Unexpected error: {e}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/swarms/communication/__init__.py b/swarms/communication/__init__.py index e69de29b..fd9845b6 100644 --- a/swarms/communication/__init__.py +++ b/swarms/communication/__init__.py @@ -0,0 +1,42 @@ +from swarms.communication.base_communication import ( + BaseCommunication, + Message, + MessageType, +) +from swarms.communication.sqlite_wrap import SQLiteConversation +from swarms.communication.duckdb_wrap import DuckDBConversation + +try: + from swarms.communication.supabase_wrap import ( + SupabaseConversation, + SupabaseConnectionError, + SupabaseOperationError, + ) +except ImportError: + # Supabase dependencies might not be installed + SupabaseConversation = None + SupabaseConnectionError = None + SupabaseOperationError = None + +try: + from swarms.communication.redis_wrap import RedisConversation +except ImportError: + RedisConversation = None + +try: + from swarms.communication.pulsar_struct import PulsarConversation +except ImportError: + PulsarConversation = None + +__all__ = [ + "BaseCommunication", + "Message", + "MessageType", + "SQLiteConversation", + "DuckDBConversation", + "SupabaseConversation", + "SupabaseConnectionError", + "SupabaseOperationError", + "RedisConversation", + "PulsarConversation", +] diff --git a/swarms/communication/supabase_wrap.py b/swarms/communication/supabase_wrap.py new file mode 100644 index 00000000..b9b43d61 --- /dev/null +++ b/swarms/communication/supabase_wrap.py @@ -0,0 +1,1183 @@ +import datetime +import json +import logging +import threading +import uuid +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +import yaml + +try: + from supabase import Client, create_client + from postgrest import APIResponse, APIError as PostgrestAPIError + SUPABASE_AVAILABLE = True +except ImportError: + SUPABASE_AVAILABLE = False + Client = None + APIResponse = None + PostgrestAPIError = None + + +from swarms.communication.base_communication import ( + BaseCommunication, + Message, + MessageType, +) + +# Try to import loguru logger, fallback to standard logging +try: + from loguru import logger + LOGURU_AVAILABLE = True +except ImportError: + LOGURU_AVAILABLE = False + logger = None + +# Custom Exceptions for Supabase Communication +class SupabaseConnectionError(Exception): + """Custom exception for Supabase connection errors.""" + pass + +class SupabaseOperationError(Exception): + """Custom exception for Supabase operation errors.""" + pass + +class DateTimeEncoder(json.JSONEncoder): + """Custom JSON encoder for handling datetime objects.""" + def default(self, obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + return super().default(obj) + +class SupabaseConversation(BaseCommunication): + """ + A Supabase-backed implementation of the BaseCommunication class for managing + conversation history using a Supabase (PostgreSQL) database. + + Prerequisites: + - supabase-py library: pip install supabase + - Valid Supabase project URL and API key + - Network access to your Supabase instance + + Attributes: + supabase_url (str): URL of the Supabase project. + supabase_key (str): Anon or service key for the Supabase project. + client (supabase.Client): The Supabase client instance. + table_name (str): Name of the table in Supabase to store conversations. + current_conversation_id (Optional[str]): ID of the currently active conversation. + tokenizer (Any): Tokenizer for counting tokens in messages. + context_length (int): Maximum number of tokens for context window. + time_enabled (bool): Flag to prepend timestamps to messages. + enable_logging (bool): Flag to enable logging. + logger (logging.Logger | loguru.Logger): Logger instance. + """ + + def __init__( + self, + supabase_url: str, + supabase_key: str, + system_prompt: Optional[str] = None, + time_enabled: bool = False, + autosave: bool = False, # Less relevant for DB-backed, but kept for interface + save_filepath: str = None, # Used for export/import + tokenizer: Any = None, + context_length: int = 8192, + rules: str = None, + custom_rules_prompt: str = None, + user: str = "User:", + auto_save: bool = True, # Less relevant + save_as_yaml: bool = True, # Default export format + save_as_json_bool: bool = False, # Alternative export format + token_count: bool = True, + cache_enabled: bool = True, # Currently for token counting + table_name: str = "conversations", + enable_timestamps: bool = True, # DB schema handles this with DEFAULT NOW() + enable_logging: bool = True, + use_loguru: bool = True, + max_retries: int = 3, # For Supabase API calls (not implemented yet, supabase-py might handle) + *args, + **kwargs, + ): + if not SUPABASE_AVAILABLE: + raise ImportError( + "Supabase client library is not installed. Please install it using: pip install supabase" + ) + + # Store initialization parameters - BaseCommunication.__init__ is just pass + self.system_prompt = system_prompt + self.time_enabled = time_enabled + self.autosave = autosave + self.save_filepath = save_filepath + self.tokenizer = tokenizer + self.context_length = context_length + self.rules = rules + self.custom_rules_prompt = custom_rules_prompt + self.user = user + self.auto_save = auto_save # Actual auto-saving to file is less relevant + self.save_as_yaml_on_export = save_as_yaml + self.save_as_json_on_export = save_as_json_bool + self.calculate_token_count = token_count + self.cache_enabled = cache_enabled + + self.supabase_url = supabase_url + self.supabase_key = supabase_key + self.table_name = table_name + self.enable_timestamps = enable_timestamps # DB handles actual timestamping + self.enable_logging = enable_logging + self.use_loguru = use_loguru and LOGURU_AVAILABLE + self.max_retries = max_retries + + # Setup logging + if self.enable_logging: + if self.use_loguru and logger: + self.logger = logger + else: + self.logger = logging.getLogger(__name__) + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.INFO) + else: + # Create a null logger that does nothing + self.logger = logging.getLogger(__name__) + self.logger.addHandler(logging.NullHandler()) + + self.current_conversation_id: Optional[str] = None + self._lock = threading.Lock() # For thread-safe operations if any (e.g. token calculation) + + try: + self.client: Client = create_client(supabase_url, supabase_key) + if self.enable_logging: + self.logger.info(f"Successfully initialized Supabase client for URL: {supabase_url}") + except Exception as e: + if self.enable_logging: + self.logger.error(f"Failed to initialize Supabase client: {e}") + raise SupabaseConnectionError(f"Failed to connect to Supabase: {e}") + + self._init_db() # Verifies table existence + self.start_new_conversation() # Initializes a conversation ID + + # Add initial prompts if provided + if self.system_prompt: + self.add(role="system", content=self.system_prompt, message_type=MessageType.SYSTEM) + if self.rules: + # Assuming rules are spoken by the system or user based on context + self.add(role="system", content=self.rules, message_type=MessageType.SYSTEM) + if self.custom_rules_prompt: + self.add(role=self.user, content=self.custom_rules_prompt, message_type=MessageType.USER) + + def _init_db(self): + """ + Initialize the database and create necessary tables. + Creates the table if it doesn't exist, similar to SQLite implementation. + """ + # First, try to create the table if it doesn't exist + try: + # Use Supabase RPC to execute raw SQL for table creation + create_table_sql = f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id BIGSERIAL PRIMARY KEY, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TIMESTAMPTZ DEFAULT NOW(), + message_type TEXT, + metadata JSONB, + token_count INTEGER, + created_at TIMESTAMPTZ DEFAULT NOW() + ); + """ + + # Try to create index as well + create_index_sql = f""" + CREATE INDEX IF NOT EXISTS idx_{self.table_name}_conversation_id + ON {self.table_name} (conversation_id); + """ + + # Attempt to create table using RPC function + # Note: This requires a stored procedure to be created in Supabase + # If RPC is not available, we'll fall back to checking if table exists + try: + # Try using a custom RPC function if available + self.client.rpc('exec_sql', {'sql': create_table_sql}).execute() + if self.enable_logging: + self.logger.info(f"Successfully created or verified table '{self.table_name}' using RPC.") + except Exception as rpc_error: + if self.enable_logging: + self.logger.debug(f"RPC table creation failed (expected if no custom function): {rpc_error}") + + # Fallback: Try to verify table exists, if not provide helpful error + try: + response = self.client.table(self.table_name).select("id").limit(1).execute() + if response.error and "does not exist" in str(response.error).lower(): + # Table doesn't exist, try alternative creation method + self._create_table_fallback() + elif response.error: + raise SupabaseOperationError(f"Error accessing table: {response.error.message}") + else: + if self.enable_logging: + self.logger.info(f"Successfully verified existing table '{self.table_name}'.") + except Exception as table_check_error: + if "does not exist" in str(table_check_error).lower() or "relation" in str(table_check_error).lower(): + # Table definitely doesn't exist, provide creation instructions + self._handle_missing_table() + else: + raise SupabaseOperationError(f"Failed to access or create table: {table_check_error}") + + except Exception as e: + if self.enable_logging: + self.logger.error(f"Database initialization failed: {e}") + raise SupabaseOperationError(f"Failed to initialize database: {e}") + + def _create_table_fallback(self): + """ + Fallback method to create table when RPC is not available. + Attempts to use Supabase's admin API or provides clear instructions. + """ + try: + # Try using the admin API if available (requires service role key) + # This might work if the user is using a service role key + admin_sql = f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + id BIGSERIAL PRIMARY KEY, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TIMESTAMPTZ DEFAULT NOW(), + message_type TEXT, + metadata JSONB, + token_count INTEGER, + created_at TIMESTAMPTZ DEFAULT NOW() + ); + CREATE INDEX IF NOT EXISTS idx_{self.table_name}_conversation_id + ON {self.table_name} (conversation_id); + """ + + # Note: This might not work with all Supabase configurations + # but we attempt it anyway + if hasattr(self.client, 'postgrest') and hasattr(self.client.postgrest, 'rpc'): + result = self.client.postgrest.rpc('exec_sql', {'query': admin_sql}).execute() + if self.enable_logging: + self.logger.info(f"Successfully created table '{self.table_name}' using admin API.") + return + except Exception as e: + if self.enable_logging: + self.logger.debug(f"Admin API table creation failed: {e}") + + # If all else fails, call the missing table handler + self._handle_missing_table() + + def _handle_missing_table(self): + """ + Handle the case where the table doesn't exist and can't be created automatically. + Provides clear instructions for manual table creation. + """ + table_creation_sql = f""" +-- Run this SQL in your Supabase SQL Editor to create the required table: + +CREATE TABLE IF NOT EXISTS {self.table_name} ( + id BIGSERIAL PRIMARY KEY, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TIMESTAMPTZ DEFAULT NOW(), + message_type TEXT, + metadata JSONB, + token_count INTEGER, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Create index for better query performance: +CREATE INDEX IF NOT EXISTS idx_{self.table_name}_conversation_id +ON {self.table_name} (conversation_id); + +-- Optional: Enable Row Level Security (RLS) for production: +ALTER TABLE {self.table_name} ENABLE ROW LEVEL SECURITY; + +-- Optional: Create RLS policy (customize according to your needs): +CREATE POLICY "Users can manage their own conversations" ON {self.table_name} + FOR ALL USING (true); -- Adjust this policy based on your security requirements +""" + + error_msg = ( + f"Table '{self.table_name}' does not exist in your Supabase database and cannot be created automatically. " + f"Please create it manually by running the following SQL in your Supabase SQL Editor:\n\n{table_creation_sql}\n\n" + f"Alternatively, you can create a custom RPC function in Supabase to enable automatic table creation. " + f"Visit your Supabase dashboard > SQL Editor and create this function:\n\n" + f"CREATE OR REPLACE FUNCTION exec_sql(sql TEXT)\n" + f"RETURNS TEXT AS $$\n" + f"BEGIN\n" + f" EXECUTE sql;\n" + f" RETURN 'SUCCESS';\n" + f"END;\n" + f"$$ LANGUAGE plpgsql SECURITY DEFINER;\n\n" + f"After creating either the table or the RPC function, retry initializing the SupabaseConversation." + ) + + if self.enable_logging: + self.logger.error(error_msg) + raise SupabaseOperationError(error_msg) + + def _handle_api_response(self, response, operation_name: str = "Supabase operation"): + """Handles Supabase API response, checking for errors and returning data.""" + # The new supabase-py client structure: response has .data and .count attributes + # Errors are raised as exceptions rather than being in response.error + try: + if hasattr(response, 'data'): + # Return the data, which could be None, a list, or a dict + return response.data + else: + # Fallback for older response structures or direct data + return response + except Exception as e: + if self.enable_logging: + self.logger.error(f"{operation_name} failed: {e}") + raise SupabaseOperationError(f"{operation_name} failed: {e}") + + def _serialize_content(self, content: Union[str, dict, list]) -> str: + """Serializes content to JSON string if it's a dict or list.""" + if isinstance(content, (dict, list)): + return json.dumps(content, cls=DateTimeEncoder) + return str(content) + + def _deserialize_content(self, content_str: str) -> Union[str, dict, list]: + """Deserializes content from JSON string if it looks like JSON.""" + try: + # Try to parse if it looks like a JSON object or array + if content_str.strip().startswith(("{", "[")): + return json.loads(content_str) + except json.JSONDecodeError: + pass # Not a valid JSON, return as string + return content_str + + def _serialize_metadata(self, metadata: Optional[Dict]) -> Optional[str]: + """Serializes metadata dict to JSON string.""" + if metadata is None: + return None + return json.dumps(metadata, cls=DateTimeEncoder) + + def _deserialize_metadata(self, metadata_str: Optional[str]) -> Optional[Dict]: + """Deserializes metadata from JSON string.""" + if metadata_str is None: + return None + try: + return json.loads(metadata_str) + except json.JSONDecodeError: + self.logger.warning(f"Failed to deserialize metadata: {metadata_str}") + return None # Or return the string itself if preferred + + def _generate_conversation_id(self) -> str: + """Generate a unique conversation ID using UUID and timestamp.""" + timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S_%f") + unique_id = str(uuid.uuid4())[:8] + return f"conv_{timestamp}_{unique_id}" + + def start_new_conversation(self) -> str: + """Starts a new conversation and returns its ID.""" + self.current_conversation_id = self._generate_conversation_id() + self.logger.info(f"Started new conversation with ID: {self.current_conversation_id}") + return self.current_conversation_id + + def add( + self, + role: str, + content: Union[str, dict, list], + message_type: Optional[MessageType] = None, + metadata: Optional[Dict] = None, + token_count: Optional[int] = None, + ) -> int: + """Add a message to the current conversation history in Supabase.""" + if self.current_conversation_id is None: + self.start_new_conversation() + + serialized_content = self._serialize_content(content) + current_timestamp_iso = datetime.datetime.now(datetime.timezone.utc).isoformat() + + message_data = { + "conversation_id": self.current_conversation_id, + "role": role, + "content": serialized_content, + "timestamp": current_timestamp_iso, # Supabase will use its default if not provided / column allows NULL + "message_type": message_type.value if message_type else None, + "metadata": self._serialize_metadata(metadata), + # token_count handled below + } + + # Calculate token_count if enabled and not provided + if self.calculate_token_count and token_count is None and self.tokenizer: + try: + # For now, do this synchronously. For long content, consider async/threading. + message_data["token_count"] = self.tokenizer.count_tokens(str(content)) + except Exception as e: + if self.enable_logging: + self.logger.warning(f"Failed to count tokens for content: {e}") + elif token_count is not None: + message_data["token_count"] = token_count + + # Filter out None values to let Supabase handle defaults or NULLs appropriately + message_to_insert = {k: v for k, v in message_data.items() if v is not None} + + try: + response = self.client.table(self.table_name).insert(message_to_insert).execute() + data = self._handle_api_response(response, "add_message") + if data and len(data) > 0 and "id" in data[0]: + inserted_id = data[0]["id"] + if self.enable_logging: + self.logger.debug(f"Added message with ID {inserted_id} to conversation {self.current_conversation_id}") + return inserted_id + if self.enable_logging: + self.logger.error(f"Failed to retrieve ID for inserted message in conversation {self.current_conversation_id}") + raise SupabaseOperationError("Failed to retrieve ID for inserted message.") + except Exception as e: + if self.enable_logging: + self.logger.error(f"Error adding message to Supabase: {e}") + raise SupabaseOperationError(f"Error adding message: {e}") + + def batch_add(self, messages: List[Message]) -> List[int]: + """Add multiple messages to the current conversation history in Supabase.""" + if self.current_conversation_id is None: + self.start_new_conversation() + + messages_to_insert = [] + for msg_obj in messages: + serialized_content = self._serialize_content(msg_obj.content) + current_timestamp_iso = (msg_obj.timestamp or datetime.datetime.now(datetime.timezone.utc).isoformat()) + + msg_data = { + "conversation_id": self.current_conversation_id, + "role": msg_obj.role, + "content": serialized_content, + "timestamp": current_timestamp_iso, + "message_type": msg_obj.message_type.value if msg_obj.message_type else None, + "metadata": self._serialize_metadata(msg_obj.metadata), + } + + # Token count + current_token_count = msg_obj.token_count + if self.calculate_token_count and current_token_count is None and self.tokenizer: + try: + current_token_count = self.tokenizer.count_tokens(str(msg_obj.content)) + except Exception as e: + self.logger.warning(f"Failed to count tokens for batch message: {e}") + if current_token_count is not None: + msg_data["token_count"] = current_token_count + + messages_to_insert.append({k: v for k, v in msg_data.items() if v is not None}) + + if not messages_to_insert: + return [] + + try: + response = self.client.table(self.table_name).insert(messages_to_insert).execute() + data = self._handle_api_response(response, "batch_add_messages") + inserted_ids = [item['id'] for item in data if 'id' in item] + if len(inserted_ids) != len(messages_to_insert): + self.logger.warning("Mismatch in expected and inserted message counts during batch_add.") + self.logger.debug(f"Batch added {len(inserted_ids)} messages to conversation {self.current_conversation_id}") + return inserted_ids + except Exception as e: + self.logger.error(f"Error batch adding messages to Supabase: {e}") + raise SupabaseOperationError(f"Error batch adding messages: {e}") + + def _format_row_to_dict(self, row: Dict) -> Dict: + """Helper to format a raw row from Supabase to our standard message dict.""" + formatted_message = { + "id": row.get("id"), + "role": row.get("role"), + "content": self._deserialize_content(row.get("content", "")), + "timestamp": row.get("timestamp"), + "message_type": row.get("message_type"), + "metadata": self._deserialize_metadata(row.get("metadata")), + "token_count": row.get("token_count"), + "conversation_id": row.get("conversation_id"), + "created_at": row.get("created_at"), + } + # Clean None values from the root, but keep them within deserialized content/metadata + return {k: v for k, v in formatted_message.items() if v is not None or k in ["metadata", "token_count", "message_type"]} + + + def get_messages( + self, + limit: Optional[int] = None, + offset: Optional[int] = None, + ) -> List[Dict]: + """Get messages from the current conversation with optional pagination.""" + if self.current_conversation_id is None: + return [] + try: + query = self.client.table(self.table_name).select("*") \ + .eq("conversation_id", self.current_conversation_id) \ + .order("timestamp", desc=False) # Assuming 'timestamp' or 'id' for ordering + + if limit is not None: + query = query.limit(limit) + if offset is not None: + query = query.offset(offset) + + response = query.execute() + data = self._handle_api_response(response, "get_messages") + return [self._format_row_to_dict(row) for row in data] + except Exception as e: + self.logger.error(f"Error getting messages from Supabase: {e}") + raise SupabaseOperationError(f"Error getting messages: {e}") + + def get_str(self) -> str: + """Get the current conversation history as a formatted string.""" + messages_dict = self.get_messages() + conv_str = [] + for msg in messages_dict: + ts_prefix = f"[{msg['timestamp']}] " if msg.get('timestamp') and self.time_enabled else "" + # Content might be dict/list if deserialized + content_display = msg['content'] + if isinstance(content_display, (dict, list)): + content_display = json.dumps(content_display, indent=2, cls=DateTimeEncoder) + conv_str.append(f"{ts_prefix}{msg['role']}: {content_display}") + return "\n".join(conv_str) + + def display_conversation(self, detailed: bool = False): + """Display the conversation history.""" + # `detailed` flag might be used for more verbose printing if needed + print(self.get_str()) + + def delete(self, index: str): + """Delete a message from the conversation history by its primary key 'id'.""" + if self.current_conversation_id is None: + if self.enable_logging: + self.logger.warning("Cannot delete message: No current conversation.") + return + + try: + # Handle both string and int message IDs + try: + message_id = int(index) + except ValueError: + if self.enable_logging: + self.logger.error(f"Invalid message ID for delete: {index}. Must be an integer.") + raise ValueError(f"Invalid message ID for delete: {index}. Must be an integer.") + + response = self.client.table(self.table_name).delete() \ + .eq("id", message_id) \ + .eq("conversation_id", self.current_conversation_id) \ + .execute() + self._handle_api_response(response, f"delete_message (id: {message_id})") + if self.enable_logging: + self.logger.info(f"Deleted message with ID {message_id} from conversation {self.current_conversation_id}") + except Exception as e: + if self.enable_logging: + self.logger.error(f"Error deleting message ID {index} from Supabase: {e}") + raise SupabaseOperationError(f"Error deleting message ID {index}: {e}") + + def update( + self, + index: Union[str, int], + role: Optional[str] = None, + content: Optional[Union[str, dict]] = None, + metadata: Optional[Dict] = None + ) -> bool: + """Update a message in the conversation history. Returns True if successful, False otherwise.""" + if self.current_conversation_id is None: + if self.enable_logging: + self.logger.warning("Cannot update message: No current conversation.") + return False + + # Handle both string and int message IDs + try: + if isinstance(index, str): + message_id = int(index) + else: + message_id = index + except ValueError: + if self.enable_logging: + self.logger.error(f"Invalid message ID for update: {index}. Must be an integer.") + return False + + update_data = {} + if role is not None: + update_data["role"] = role + if content is not None: + update_data["content"] = self._serialize_content(content) + if self.calculate_token_count and self.tokenizer: + try: + update_data["token_count"] = self.tokenizer.count_tokens(str(content)) + except Exception as e: + if self.enable_logging: + self.logger.warning(f"Failed to count tokens for updated content: {e}") + if metadata is not None: # Allows setting metadata to null by passing {} then serializing + update_data["metadata"] = self._serialize_metadata(metadata) + + if not update_data: + if self.enable_logging: + self.logger.info("No fields provided to update for message.") + return False + + try: + response = self.client.table(self.table_name).update(update_data) \ + .eq("id", message_id) \ + .eq("conversation_id", self.current_conversation_id) \ + .execute() + + data = self._handle_api_response(response, f"update_message (id: {message_id})") + + # Check if any rows were actually updated + if data and len(data) > 0: + if self.enable_logging: + self.logger.info(f"Updated message with ID {message_id} in conversation {self.current_conversation_id}") + return True + else: + if self.enable_logging: + self.logger.warning(f"No message found with ID {message_id} in conversation {self.current_conversation_id}") + return False + + except Exception as e: + if self.enable_logging: + self.logger.error(f"Error updating message ID {message_id} in Supabase: {e}") + return False + + def query(self, index: str) -> Optional[Dict]: + """Query a message in the conversation history by its primary key 'id'.""" + if self.current_conversation_id is None: + return None + try: + # Handle both string and int message IDs + try: + message_id = int(index) + except ValueError: + if self.enable_logging: + self.logger.warning(f"Invalid message ID for query: {index}. Must be an integer.") + return None + + response = self.client.table(self.table_name).select("*") \ + .eq("id", message_id) \ + .eq("conversation_id", self.current_conversation_id) \ + .maybe_single() \ + .execute() # maybe_single returns one record or None + + data = self._handle_api_response(response, f"query_message (id: {message_id})") + if data: + return self._format_row_to_dict(data) + return None + except Exception as e: + if self.enable_logging: + self.logger.error(f"Error querying message ID {index} from Supabase: {e}") + return None + + def search(self, keyword: str) -> List[Dict]: + """Search for messages containing a keyword in their content.""" + if self.current_conversation_id is None: + return [] + try: + # PostgREST ilike is case-insensitive + response = self.client.table(self.table_name).select("*") \ + .eq("conversation_id", self.current_conversation_id) \ + .ilike("content", f"%{keyword}%") \ + .order("timestamp", desc=False) \ + .execute() + data = self._handle_api_response(response, f"search_messages (keyword: {keyword})") + return [self._format_row_to_dict(row) for row in data] + except Exception as e: + self.logger.error(f"Error searching messages in Supabase: {e}") + raise SupabaseOperationError(f"Error searching messages: {e}") + + def _export_to_file(self, filename: str, format_type: str): + """Helper to export conversation to JSON or YAML file.""" + if self.current_conversation_id is None: + self.logger.warning("No current conversation to export.") + return + + data_to_export = self.to_dict() # Gets messages for current_conversation_id + try: + with open(filename, "w") as f: + if format_type == "json": + json.dump(data_to_export, f, indent=2, cls=DateTimeEncoder) + elif format_type == "yaml": + yaml.dump(data_to_export, f, sort_keys=False) + else: + raise ValueError(f"Unsupported export format: {format_type}") + self.logger.info(f"Conversation {self.current_conversation_id} exported to {filename} as {format_type}.") + except Exception as e: + self.logger.error(f"Failed to export conversation to {format_type}: {e}") + raise + + def export_conversation(self, filename: str): + """Export the current conversation history to a file (JSON or YAML based on init flags).""" + if self.save_as_json_on_export: + self._export_to_file(filename, "json") + elif self.save_as_yaml_on_export: # Default if json is false + self._export_to_file(filename, "yaml") + else: # Fallback if somehow both are false + self._export_to_file(filename, "yaml") + + + def _import_from_file(self, filename: str, format_type: str): + """Helper to import conversation from JSON or YAML file.""" + try: + with open(filename, "r") as f: + if format_type == "json": + imported_data = json.load(f) + elif format_type == "yaml": + imported_data = yaml.safe_load(f) + else: + raise ValueError(f"Unsupported import format: {format_type}") + + if not isinstance(imported_data, list): + raise ValueError("Imported data must be a list of messages.") + + # Start a new conversation for the imported data + self.start_new_conversation() + + messages_to_batch = [] + for msg_data in imported_data: + # Adapt to Message dataclass structure if possible + role = msg_data.get("role") + content = msg_data.get("content") + if role is None or content is None: + self.logger.warning(f"Skipping message due to missing role/content: {msg_data}") + continue + + messages_to_batch.append(Message( + role=role, + content=content, + timestamp=msg_data.get("timestamp"), # Will be handled by batch_add + message_type=MessageType(msg_data["message_type"]) if msg_data.get("message_type") else None, + metadata=msg_data.get("metadata"), + token_count=msg_data.get("token_count") + )) + + if messages_to_batch: + self.batch_add(messages_to_batch) + self.logger.info(f"Conversation imported from {filename} ({format_type}) into new ID {self.current_conversation_id}.") + + except Exception as e: + self.logger.error(f"Failed to import conversation from {format_type}: {e}") + raise + + def import_conversation(self, filename: str): + """Import a conversation history from a file (tries JSON then YAML).""" + try: + if filename.lower().endswith(".json"): + self._import_from_file(filename, "json") + elif filename.lower().endswith((".yaml", ".yml")): + self._import_from_file(filename, "yaml") + else: + # Try JSON first, then YAML as a fallback + try: + self._import_from_file(filename, "json") + except (json.JSONDecodeError, ValueError): # ValueError if not list + self.logger.info(f"Failed to import {filename} as JSON, trying YAML.") + self._import_from_file(filename, "yaml") + except Exception as e: # Catch errors from _import_from_file + raise SupabaseOperationError(f"Could not import {filename}: {e}") + + + def count_messages_by_role(self) -> Dict[str, int]: + """Count messages by role for the current conversation.""" + if self.current_conversation_id is None: + return {} + try: + # Supabase rpc might be better for direct count, but select + python count is also fine + # For direct DB count: self.client.rpc('count_roles', {'conv_id': self.current_conversation_id}).execute() + messages = self.get_messages() # Fetches for current_conversation_id + counts = {} + for msg in messages: + role = msg.get("role", "unknown") + counts[role] = counts.get(role, 0) + 1 + return counts + except Exception as e: + self.logger.error(f"Error counting messages by role: {e}") + raise SupabaseOperationError(f"Error counting messages by role: {e}") + + def return_history_as_string(self) -> str: + """Return the conversation history as a string.""" + return self.get_str() + + def clear(self): + """Clear the current conversation history from Supabase.""" + if self.current_conversation_id is None: + self.logger.info("No current conversation to clear.") + return + try: + response = self.client.table(self.table_name).delete() \ + .eq("conversation_id", self.current_conversation_id) \ + .execute() + # response.data will be a list of deleted items. + # response.count might be available for delete operations in some supabase-py versions or configurations. + # For now, we assume success if no error. + self._handle_api_response(response, f"clear_conversation (id: {self.current_conversation_id})") + self.logger.info(f"Cleared conversation with ID: {self.current_conversation_id}") + except Exception as e: + self.logger.error(f"Error clearing conversation {self.current_conversation_id} from Supabase: {e}") + raise SupabaseOperationError(f"Error clearing conversation: {e}") + + def to_dict(self) -> List[Dict]: + """Convert the current conversation history to a list of dictionaries.""" + return self.get_messages() # Already fetches for current_conversation_id + + def to_json(self) -> str: + """Convert the current conversation history to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, cls=DateTimeEncoder) + + def to_yaml(self) -> str: + """Convert the current conversation history to a YAML string.""" + return yaml.dump(self.to_dict(), sort_keys=False) + + def save_as_json(self, filename: str): + """Save the current conversation history as a JSON file.""" + self._export_to_file(filename, "json") + + def load_from_json(self, filename: str): + """Load a conversation history from a JSON file into a new conversation.""" + self._import_from_file(filename, "json") + + def save_as_yaml(self, filename: str): + """Save the current conversation history as a YAML file.""" + self._export_to_file(filename, "yaml") + + def load_from_yaml(self, filename: str): + """Load a conversation history from a YAML file into a new conversation.""" + self._import_from_file(filename, "yaml") + + def get_last_message(self) -> Optional[Dict]: + """Get the last message from the current conversation history.""" + if self.current_conversation_id is None: + return None + try: + response = self.client.table(self.table_name).select("*") \ + .eq("conversation_id", self.current_conversation_id) \ + .order("timestamp", desc=True) \ + .limit(1) \ + .maybe_single() \ + .execute() + data = self._handle_api_response(response, "get_last_message") + return self._format_row_to_dict(data) if data else None + except Exception as e: + self.logger.error(f"Error getting last message from Supabase: {e}") + raise SupabaseOperationError(f"Error getting last message: {e}") + + def get_last_message_as_string(self) -> str: + """Get the last message as a formatted string.""" + last_msg = self.get_last_message() + if not last_msg: + return "" + ts_prefix = f"[{last_msg['timestamp']}] " if last_msg.get('timestamp') and self.time_enabled else "" + content_display = last_msg['content'] + if isinstance(content_display, (dict, list)): + content_display = json.dumps(content_display, cls=DateTimeEncoder) + return f"{ts_prefix}{last_msg['role']}: {content_display}" + + def get_messages_by_role(self, role: str) -> List[Dict]: + """Get all messages from a specific role in the current conversation.""" + if self.current_conversation_id is None: + return [] + try: + response = self.client.table(self.table_name).select("*") \ + .eq("conversation_id", self.current_conversation_id) \ + .eq("role", role) \ + .order("timestamp", desc=False) \ + .execute() + data = self._handle_api_response(response, f"get_messages_by_role (role: {role})") + return [self._format_row_to_dict(row) for row in data] + except Exception as e: + self.logger.error(f"Error getting messages by role '{role}' from Supabase: {e}") + raise SupabaseOperationError(f"Error getting messages by role '{role}': {e}") + + def get_conversation_summary(self) -> Dict: + """Get a summary of the current conversation.""" + if self.current_conversation_id is None: + return {"error": "No current conversation."} + + # This could be optimized with an RPC call in Supabase for better performance + # Example RPC: CREATE OR REPLACE FUNCTION get_conversation_summary(conv_id TEXT) ... + messages = self.get_messages() + if not messages: + return { + "conversation_id": self.current_conversation_id, + "total_messages": 0, + "unique_roles": 0, + "first_message_time": None, + "last_message_time": None, + "total_tokens": 0, + "roles": {}, + } + + roles_counts = {} + total_tokens_sum = 0 + for msg in messages: + roles_counts[msg["role"]] = roles_counts.get(msg["role"], 0) + 1 + if msg.get("token_count") is not None: + total_tokens_sum += int(msg["token_count"]) + + return { + "conversation_id": self.current_conversation_id, + "total_messages": len(messages), + "unique_roles": len(roles_counts), + "first_message_time": messages[0].get("timestamp"), + "last_message_time": messages[-1].get("timestamp"), + "total_tokens": total_tokens_sum, + "roles": roles_counts, + } + + def get_statistics(self) -> Dict: + """Get statistics about the current conversation (alias for get_conversation_summary).""" + return self.get_conversation_summary() + + def get_conversation_id(self) -> str: + """Get the current conversation ID.""" + return self.current_conversation_id or "" + + def delete_current_conversation(self) -> bool: + """Delete the current conversation. Returns True if successful.""" + if self.current_conversation_id: + self.clear() # clear messages for current_conversation_id + self.logger.info(f"Deleted current conversation: {self.current_conversation_id}") + self.current_conversation_id = None # No active conversation after deletion + return True + self.logger.info("No current conversation to delete.") + return False + + + def search_messages(self, query: str) -> List[Dict]: + """Search for messages containing specific text (alias for search).""" + return self.search(keyword=query) + + def get_conversation_metadata_dict(self) -> Dict: + """Get detailed metadata about the conversation.""" + # Similar to get_conversation_summary, could be expanded with more DB-side aggregations if needed via RPC. + # For now, returning the summary. + if self.current_conversation_id is None: + return {"error": "No current conversation."} + summary = self.get_conversation_summary() + + # Example of additional metadata one might compute client-side or via RPC + # message_type_distribution, average_tokens_per_message, hourly_message_frequency + return { + "conversation_id": self.current_conversation_id, + "basic_stats": summary, + # Placeholder for more detailed stats if implemented + } + + + def get_conversation_timeline_dict(self) -> Dict[str, List[Dict]]: + """Get the conversation organized by timestamps (dates as keys).""" + if self.current_conversation_id is None: + return {} + + messages = self.get_messages() # Assumes messages are ordered by timestamp + timeline_dict = {} + for msg in messages: + try: + # Ensure timestamp is a string and valid ISO format + ts_str = msg.get("timestamp") + if isinstance(ts_str, str): + date_key = datetime.datetime.fromisoformat(ts_str.replace("Z", "+00:00")).strftime('%Y-%m-%d') + if date_key not in timeline_dict: + timeline_dict[date_key] = [] + timeline_dict[date_key].append(msg) + else: + self.logger.warning(f"Message ID {msg.get('id')} has invalid timestamp format: {ts_str}") + except ValueError as e: + self.logger.warning(f"Could not parse timestamp for message ID {msg.get('id')}: {ts_str}, Error: {e}") + + return timeline_dict + + + def get_conversation_by_role_dict(self) -> Dict[str, List[Dict]]: + """Get the conversation organized by roles.""" + if self.current_conversation_id is None: + return {} + + messages = self.get_messages() + role_dict = {} + for msg in messages: + role = msg.get("role", "unknown") + if role not in role_dict: + role_dict[role] = [] + role_dict[role].append(msg) + return role_dict + + def get_conversation_as_dict(self) -> Dict: + """Get the entire current conversation as a dictionary with messages and metadata.""" + if self.current_conversation_id is None: + return {"error": "No current conversation."} + + return { + "conversation_id": self.current_conversation_id, + "messages": self.get_messages(), + "metadata": self.get_conversation_summary(), # Using summary as metadata + } + + def truncate_memory_with_tokenizer(self): + """Truncate the conversation history based on token count if a tokenizer is provided.""" + if not self.tokenizer or self.current_conversation_id is None: + self.logger.info("Tokenizer not available or no current conversation, skipping truncation.") + return + + try: + messages = self.get_messages() # Fetches ordered by timestamp ASC + + # Calculate cumulative tokens from newest to oldest to decide cutoff + # Or, from oldest to newest to keep newest messages within context_length + + # Let's keep newest messages: iterate backwards, then delete earlier ones. + # This is complex with current `delete` by ID. + # A simpler approach: calculate total tokens, if > context_length, delete oldest ones. + + current_total_tokens = sum( + m.get("token_count", 0) if m.get("token_count") is not None + else (self.tokenizer.count_tokens(self._serialize_content(m["content"])) if self.calculate_token_count else 0) + for m in messages + ) + + tokens_to_remove = current_total_tokens - self.context_length + + if tokens_to_remove <= 0: + return # No truncation needed + + deleted_count = 0 + for msg in messages: # Oldest messages first + if tokens_to_remove <= 0: + break + + msg_id = msg.get("id") + if not msg_id: + continue + + msg_tokens = msg.get("token_count", 0) + if msg_tokens == 0 and self.calculate_token_count: # Recalculate if zero and enabled + msg_tokens = self.tokenizer.count_tokens(self._serialize_content(msg["content"])) + + self.delete(msg_id) # Delete by primary key + tokens_to_remove -= msg_tokens + deleted_count +=1 + + self.logger.info(f"Truncated conversation {self.current_conversation_id}, removed {deleted_count} oldest messages.") + + except Exception as e: + self.logger.error(f"Error during memory truncation for conversation {self.current_conversation_id}: {e}") + # Don't re-raise, truncation is best-effort + + # Methods from duckdb_wrap.py that seem generally useful and can be adapted + def get_visible_messages( + self, agent: Optional[Callable] = None, turn: Optional[int] = None + ) -> List[Dict]: + """ + Get visible messages, optionally filtered by agent visibility and turn. + Assumes 'metadata' field can contain 'visible_to' (list of agent names or 'all') + and 'turn' (integer). + """ + if self.current_conversation_id is None: + return [] + + # Base query + query = self.client.table(self.table_name).select("*") \ + .eq("conversation_id", self.current_conversation_id) \ + .order("timestamp", desc=False) + + # Execute and then filter in Python, as JSONB querying for array containment or + # numeric comparison within JSON can be complex with supabase-py's fluent API. + # For complex filtering, an RPC function in Supabase would be more efficient. + + try: + response = query.execute() + all_messages = self._handle_api_response(response, "get_visible_messages_fetch_all") + except Exception as e: + self.logger.error(f"Error fetching messages for visibility check: {e}") + return [] + + visible_messages = [] + for row_data in all_messages: + msg = self._format_row_to_dict(row_data) + metadata = msg.get("metadata") if isinstance(msg.get("metadata"), dict) else {} + + # Turn filtering + if turn is not None: + msg_turn = metadata.get("turn") + if not (isinstance(msg_turn, int) and msg_turn < turn): + continue # Skip if turn condition not met + + # Agent visibility filtering + if agent is not None: + visible_to = metadata.get("visible_to") + agent_name_attr = getattr(agent, 'agent_name', None) # Safely get agent_name + if agent_name_attr is None: # If agent has no name, assume it can't see restricted msgs + if visible_to is not None and visible_to != "all": + continue + elif isinstance(visible_to, list) and agent_name_attr not in visible_to: + continue # Skip if agent not in visible_to list + elif isinstance(visible_to, str) and visible_to != "all": + # If visible_to is a string but not "all", and doesn't match agent_name + if visible_to != agent_name_attr: + continue + + visible_messages.append(msg) + return visible_messages + + def return_messages_as_list(self) -> List[str]: + """Return the conversation messages as a list of formatted strings 'role: content'.""" + messages_dict = self.get_messages() + return [ + f"{msg.get('role', 'unknown')}: {self._serialize_content(msg.get('content', ''))}" + for msg in messages_dict + ] + + def return_messages_as_dictionary(self) -> List[Dict]: + """Return the conversation messages as a list of dictionaries [{role: R, content: C}].""" + messages_dict = self.get_messages() + return [ + {"role": msg.get("role"), "content": msg.get("content")} # Content already deserialized by _format_row_to_dict + for msg in messages_dict + ] + + def add_tool_output_to_agent(self, role: str, tool_output: dict): # role is usually "tool" + """Add a tool output to the conversation history.""" + # Assuming tool_output is a dict that should be stored as content + self.add(role=role, content=tool_output, message_type=MessageType.TOOL) + + def get_final_message(self) -> Optional[str]: + """Return the final message from the conversation history as 'role: content' string.""" + last_msg = self.get_last_message() + if not last_msg: + return None + content_display = last_msg['content'] + if isinstance(content_display, (dict, list)): + content_display = json.dumps(content_display, cls=DateTimeEncoder) + return f"{last_msg.get('role', 'unknown')}: {content_display}" + + def get_final_message_content(self) -> Union[str, dict, list, None]: + """Return the content of the final message from the conversation history.""" + last_msg = self.get_last_message() + return last_msg.get("content") if last_msg else None + + def return_all_except_first(self) -> List[Dict]: + """Return all messages except the first one.""" + # The limit=-1, offset=2 from duckdb_wrap is specific to its ID generation. + # For Supabase, we fetch all and skip the first one in Python. + all_messages = self.get_messages() + return all_messages[1:] if len(all_messages) > 1 else [] + + + def return_all_except_first_string(self) -> str: + """Return all messages except the first one as a concatenated string.""" + messages_to_format = self.return_all_except_first() + conv_str = [] + for msg in messages_to_format: + ts_prefix = f"[{msg['timestamp']}] " if msg.get('timestamp') and self.time_enabled else "" + content_display = msg['content'] + if isinstance(content_display, (dict, list)): + content_display = json.dumps(content_display, indent=2, cls=DateTimeEncoder) + conv_str.append(f"{ts_prefix}{msg['role']}: {content_display}") + return "\n".join(conv_str) + + def update_message( + self, + message_id: int, + content: Union[str, dict, list], + metadata: Optional[Dict] = None, + ) -> bool: + """Update an existing message.""" + # Use the unified update method which now returns a boolean + return self.update(index=message_id, content=content, metadata=metadata) \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 00000000..f489b205 --- /dev/null +++ b/test.py @@ -0,0 +1,151 @@ +import os +from swarms.communication.supabase_wrap import SupabaseConversation, MessageType, SupabaseOperationError +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +# --- Configuration --- +SUPABASE_URL = os.getenv("SUPABASE_URL") +SUPABASE_KEY = os.getenv("SUPABASE_KEY") +TABLE_NAME = "conversations" # Make sure this table exists in your Supabase DB + +def main(): + if not SUPABASE_URL or not SUPABASE_KEY: + print("Error: SUPABASE_URL and SUPABASE_KEY environment variables must be set.") + print("Please create a .env file with these values or set them in your environment.") + return + + print(f"Attempting to connect to Supabase URL: {SUPABASE_URL[:20]}...") # Print partial URL for security + + try: + # Initialize SupabaseConversation + print(f"\n--- Initializing SupabaseConversation for table '{TABLE_NAME}' ---") + convo = SupabaseConversation( + supabase_url=SUPABASE_URL, + supabase_key=SUPABASE_KEY, + table_name=TABLE_NAME, + time_enabled=True, # DB schema handles timestamps by default + enable_logging=True, + ) + print(f"Initialized. Current Conversation ID: {convo.get_conversation_id()}") + + # --- Add messages --- + print("\n--- Adding messages ---") + user_msg_id = convo.add("user", "Hello, Supabase!", message_type=MessageType.USER, metadata={"source": "test_script"}) + print(f"Added user message. ID: {user_msg_id}") + + assistant_msg_content = {"response": "Hi there! How can I help you today?", "confidence": 0.95} + assistant_msg_id = convo.add("assistant", assistant_msg_content, message_type=MessageType.ASSISTANT) + print(f"Added assistant message. ID: {assistant_msg_id}") + + system_msg_id = convo.add("system", "Conversation started.", message_type=MessageType.SYSTEM) + print(f"Added system message. ID: {system_msg_id}") + + + # --- Display conversation --- + print("\n--- Displaying conversation ---") + convo.display_conversation() + + # --- Get all messages for current conversation --- + print("\n--- Retrieving all messages for current conversation ---") + all_messages = convo.get_messages() + if all_messages: + print(f"Retrieved {len(all_messages)} messages:") + for msg in all_messages: + print(f" ID: {msg.get('id')}, Role: {msg.get('role')}, Content: {str(msg.get('content'))[:50]}...") + else: + print("No messages found.") + + # --- Query a specific message --- + if user_msg_id: + print(f"\n--- Querying message with ID: {user_msg_id} ---") + queried_msg = convo.query(str(user_msg_id)) # Query expects string ID + if queried_msg: + print(f"Queried message: {queried_msg}") + else: + print(f"Message with ID {user_msg_id} not found.") + + # --- Search messages --- + print("\n--- Searching for messages containing 'Supabase' ---") + search_results = convo.search("Supabase") + if search_results: + print(f"Found {len(search_results)} matching messages:") + for msg in search_results: + print(f" ID: {msg.get('id')}, Content: {str(msg.get('content'))[:50]}...") + else: + print("No messages found matching 'Supabase'.") + + # --- Update a message --- + if assistant_msg_id: + print(f"\n--- Updating message with ID: {assistant_msg_id} ---") + new_content = {"response": "I am an updated assistant!", "confidence": 0.99} + convo.update(index_or_id=str(assistant_msg_id), content=new_content, metadata={"updated_by": "test_script"}) + updated_msg = convo.query(str(assistant_msg_id)) + print(f"Updated message: {updated_msg}") + + + # --- Get last message --- + print("\n--- Getting last message ---") + last_msg = convo.get_last_message_as_string() + print(f"Last message: {last_msg}") + + + # --- Export and Import (example) --- + # Create a dummy export file name based on conversation ID + export_filename_json = f"convo_{convo.get_conversation_id()}.json" + export_filename_yaml = f"convo_{convo.get_conversation_id()}.yaml" + + print(f"\n--- Exporting conversation to {export_filename_json} and {export_filename_yaml} ---") + convo.save_as_json_on_export = True # Test JSON export + convo.export_conversation(export_filename_json) + convo.save_as_json_on_export = False # Switch to YAML for next export + convo.save_as_yaml_on_export = True + convo.export_conversation(export_filename_yaml) + + + print("\n--- Starting a new conversation and importing from JSON ---") + new_convo_id_before_import = convo.start_new_conversation() + print(f"New conversation started with ID: {new_convo_id_before_import}") + convo.import_conversation(export_filename_json) # This will start another new convo internally + print(f"Conversation imported from {export_filename_json}. Current ID: {convo.get_conversation_id()}") + convo.display_conversation() + + # --- Delete a message --- + if system_msg_id: # Using system_msg_id from the *original* conversation for this demo + print(f"\n--- Attempting to delete message with ID: {system_msg_id} from a *previous* conversation (might not exist in current) ---") + # Note: After import, system_msg_id refers to an ID from a *previous* conversation. + # To robustly test delete, you'd query a message from the *current* imported conversation. + # For this example, we'll just show the call. + # Let's add a message to the *current* conversation and delete that one. + temp_msg_id_to_delete = convo.add("system", "This message will be deleted.") + print(f"Added temporary message with ID: {temp_msg_id_to_delete}") + convo.delete(str(temp_msg_id_to_delete)) + print(f"Message with ID {temp_msg_id_to_delete} deleted (if it existed in current convo).") + if convo.query(str(temp_msg_id_to_delete)) is None: + print("Verified: Message no longer exists.") + else: + print("Warning: Message still exists or query failed.") + + + # --- Clear current conversation --- + print("\n--- Clearing current conversation ---") + convo.clear() + print(f"Conversation {convo.get_conversation_id()} cleared.") + if not convo.get_messages(): + print("Verified: No messages in current conversation after clearing.") + + + print("\n--- Example Finished ---") + + except SupabaseOperationError as e: + print(f"Supabase Connection Error: {e}") + except SupabaseOperationError as e: + print(f"Supabase Operation Error: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/communication/__init__.py b/tests/communication/__init__.py index e69de29b..597a6db2 100644 --- a/tests/communication/__init__.py +++ b/tests/communication/__init__.py @@ -0,0 +1 @@ +i \ No newline at end of file diff --git a/tests/communication/test_duckdb_conversation.py b/tests/communication/test_duckdb_conversation.py index be837ad5..095a6e00 100644 --- a/tests/communication/test_duckdb_conversation.py +++ b/tests/communication/test_duckdb_conversation.py @@ -1,7 +1,13 @@ import os +import sys from pathlib import Path import tempfile import threading + +# Add the project root to Python path to allow imports +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + from swarms.communication.duckdb_wrap import ( DuckDBConversation, Message, diff --git a/tests/communication/test_sqlite_wrapper.py b/tests/communication/test_sqlite_wrapper.py index 2c092ce2..fe905c02 100644 --- a/tests/communication/test_sqlite_wrapper.py +++ b/tests/communication/test_sqlite_wrapper.py @@ -1,7 +1,14 @@ import json import datetime import os +import sys +from pathlib import Path from typing import Dict, List, Any, Tuple + +# Add the project root to Python path to allow imports +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + from loguru import logger from swarms.communication.sqlite_wrap import ( SQLiteConversation, @@ -12,9 +19,9 @@ from rich.console import Console from rich.table import Table from rich.panel import Panel +# Initialize logger console = Console() - def print_test_header(test_name: str) -> None: """Print a formatted test header.""" console.print( diff --git a/tests/communication/test_supabase_conversation.py b/tests/communication/test_supabase_conversation.py new file mode 100644 index 00000000..cac0988e --- /dev/null +++ b/tests/communication/test_supabase_conversation.py @@ -0,0 +1,1077 @@ +import os +import sys +import json +import datetime +import tempfile +import threading +from typing import Dict, List, Any, Tuple +from pathlib import Path + +# Add the project root to Python path to allow imports +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + +try: + from loguru import logger + LOGURU_AVAILABLE = True +except ImportError: + import logging + logger = logging.getLogger(__name__) + LOGURU_AVAILABLE = False + +try: + from rich.console import Console + from rich.table import Table + from rich.panel import Panel + RICH_AVAILABLE = True + console = Console() +except ImportError: + RICH_AVAILABLE = False + console = None + +# Test if supabase is available +try: + from swarms.communication.supabase_wrap import ( + SupabaseConversation, + SupabaseConnectionError, + SupabaseOperationError, + ) + from swarms.communication.base_communication import Message, MessageType + SUPABASE_AVAILABLE = True +except ImportError as e: + SUPABASE_AVAILABLE = False + print(f"āŒ Supabase dependencies not available: {e}") + print("Please install supabase-py: pip install supabase") + +# Try to load environment variables +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass # dotenv is optional + +# Test configuration +TEST_SUPABASE_URL = os.getenv("SUPABASE_URL") +TEST_SUPABASE_KEY = os.getenv("SUPABASE_KEY") +TEST_TABLE_NAME = "conversations_test" + + +def print_test_header(test_name: str) -> None: + """Print a formatted test header.""" + if RICH_AVAILABLE and console: + console.print( + Panel( + f"[bold blue]Running Test: {test_name}[/bold blue]", + expand=False, + ) + ) + else: + print(f"\n=== Running Test: {test_name} ===") + + +def print_test_result( + test_name: str, success: bool, message: str, execution_time: float +) -> None: + """Print a formatted test result.""" + if RICH_AVAILABLE and console: + status = ( + "[bold green]PASSED[/bold green]" + if success + else "[bold red]FAILED[/bold red]" + ) + console.print(f"\n{status} - {test_name}") + console.print(f"Message: {message}") + console.print(f"Execution time: {execution_time:.3f} seconds\n") + else: + status = "PASSED" if success else "FAILED" + print(f"\n{status} - {test_name}") + print(f"Message: {message}") + print(f"Execution time: {execution_time:.3f} seconds\n") + + +def print_messages(messages: List[Dict], title: str = "Messages") -> None: + """Print messages in a formatted table.""" + if RICH_AVAILABLE and console: + table = Table(title=title) + table.add_column("ID", style="cyan") + table.add_column("Role", style="yellow") + table.add_column("Content", style="green") + table.add_column("Type", style="magenta") + table.add_column("Timestamp", style="blue") + + for msg in messages[:10]: # Limit to first 10 messages for display + content = str(msg.get("content", "")) + if isinstance(content, (dict, list)): + content = json.dumps(content)[:50] + "..." if len(json.dumps(content)) > 50 else json.dumps(content) + elif len(content) > 50: + content = content[:50] + "..." + + table.add_row( + str(msg.get("id", "")), + msg.get("role", ""), + content, + str(msg.get("message_type", "")), + str(msg.get("timestamp", ""))[:19] if msg.get("timestamp") else "", + ) + + console.print(table) + else: + print(f"\n{title}:") + for i, msg in enumerate(messages[:10]): + content = str(msg.get("content", "")) + if isinstance(content, (dict, list)): + content = json.dumps(content) + if len(content) > 50: + content = content[:50] + "..." + print(f"{i+1}. {msg.get('role', '')}: {content}") + + +def run_test(test_func: callable, *args, **kwargs) -> Tuple[bool, str, float]: + """ + Run a test function and return its results. + + Args: + test_func: The test function to run + *args: Arguments for the test function + **kwargs: Keyword arguments for the test function + + Returns: + Tuple[bool, str, float]: (success, message, execution_time) + """ + start_time = datetime.datetime.now() + try: + result = test_func(*args, **kwargs) + end_time = datetime.datetime.now() + execution_time = (end_time - start_time).total_seconds() + return True, str(result), execution_time + except Exception as e: + end_time = datetime.datetime.now() + execution_time = (end_time - start_time).total_seconds() + return False, str(e), execution_time + + +def setup_test_conversation(): + """Set up a test conversation instance.""" + if not SUPABASE_AVAILABLE: + raise ImportError("Supabase dependencies not available") + + if not TEST_SUPABASE_URL or not TEST_SUPABASE_KEY: + raise ValueError( + "SUPABASE_URL and SUPABASE_KEY environment variables must be set for testing" + ) + + conversation = SupabaseConversation( + supabase_url=TEST_SUPABASE_URL, + supabase_key=TEST_SUPABASE_KEY, + table_name=TEST_TABLE_NAME, + enable_logging=False, # Reduce noise during testing + time_enabled=True, + ) + return conversation + + +def cleanup_test_conversation(conversation): + """Clean up test conversation data.""" + try: + conversation.clear() + except Exception as e: + if LOGURU_AVAILABLE: + logger.warning(f"Failed to clean up test conversation: {e}") + else: + print(f"Warning: Failed to clean up test conversation: {e}") + + +def test_import_availability() -> bool: + """Test that Supabase imports are properly handled.""" + print_test_header("Import Availability Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Import availability test passed - detected missing dependencies correctly") + return True + + # Test that all required classes are available + assert SupabaseConversation is not None, "SupabaseConversation should be available" + assert SupabaseConnectionError is not None, "SupabaseConnectionError should be available" + assert SupabaseOperationError is not None, "SupabaseOperationError should be available" + assert Message is not None, "Message should be available" + assert MessageType is not None, "MessageType should be available" + + print("āœ“ Import availability test passed - all imports successful") + return True + + +def test_initialization() -> bool: + """Test SupabaseConversation initialization.""" + print_test_header("Initialization Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Initialization test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + assert conversation.supabase_url == TEST_SUPABASE_URL, "Supabase URL mismatch" + assert conversation.table_name == TEST_TABLE_NAME, "Table name mismatch" + assert conversation.current_conversation_id is not None, "Conversation ID should not be None" + assert conversation.client is not None, "Supabase client should not be None" + assert isinstance(conversation.get_conversation_id(), str), "Conversation ID should be string" + + # Test that initialization doesn't call super().__init__() improperly + # This should not raise any errors + print("āœ“ Initialization test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_logging_configuration() -> bool: + """Test logging configuration options.""" + print_test_header("Logging Configuration Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Logging configuration test skipped - Supabase not available") + return True + + # Test with logging enabled + conversation_with_logging = SupabaseConversation( + supabase_url=TEST_SUPABASE_URL, + supabase_key=TEST_SUPABASE_KEY, + table_name=TEST_TABLE_NAME, + enable_logging=True, + use_loguru=False, # Force standard logging + ) + + try: + assert conversation_with_logging.enable_logging == True, "Logging should be enabled" + assert conversation_with_logging.logger is not None, "Logger should be configured" + + # Test with logging disabled + conversation_no_logging = SupabaseConversation( + supabase_url=TEST_SUPABASE_URL, + supabase_key=TEST_SUPABASE_KEY, + table_name=TEST_TABLE_NAME + "_no_log", + enable_logging=False, + ) + + assert conversation_no_logging.enable_logging == False, "Logging should be disabled" + + print("āœ“ Logging configuration test passed") + return True + finally: + cleanup_test_conversation(conversation_with_logging) + try: + cleanup_test_conversation(conversation_no_logging) + except: + pass + + +def test_add_message() -> bool: + """Test adding a single message.""" + print_test_header("Add Message Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Add message test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + msg_id = conversation.add( + role="user", + content="Hello, Supabase!", + message_type=MessageType.USER, + metadata={"test": True} + ) + assert msg_id is not None, "Message ID should not be None" + assert isinstance(msg_id, int), "Message ID should be an integer" + + # Verify message was stored + messages = conversation.get_messages() + assert len(messages) >= 1, "Should have at least 1 message" + print("āœ“ Add message test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_add_complex_message() -> bool: + """Test adding a message with complex content.""" + print_test_header("Add Complex Message Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Add complex message test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + complex_content = { + "text": "Hello from Supabase", + "data": [1, 2, 3, {"nested": "value"}], + "metadata": {"source": "test", "priority": "high"} + } + + msg_id = conversation.add( + role="assistant", + content=complex_content, + message_type=MessageType.ASSISTANT, + metadata={ + "model": "test-model", + "temperature": 0.7, + "tokens": 42 + }, + token_count=42 + ) + + assert msg_id is not None, "Message ID should not be None" + + # Verify complex content was stored and retrieved correctly + message = conversation.query(str(msg_id)) + assert message is not None, "Message should be retrievable" + assert message["content"] == complex_content, "Complex content should match" + assert message["token_count"] == 42, "Token count should match" + + print("āœ“ Add complex message test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_batch_add() -> bool: + """Test batch adding messages.""" + print_test_header("Batch Add Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Batch add test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + messages = [ + Message( + role="user", + content="First batch message", + message_type=MessageType.USER, + metadata={"batch": 1} + ), + Message( + role="assistant", + content={"response": "First response", "confidence": 0.9}, + message_type=MessageType.ASSISTANT, + metadata={"batch": 1} + ), + Message( + role="user", + content="Second batch message", + message_type=MessageType.USER, + metadata={"batch": 2} + ) + ] + + msg_ids = conversation.batch_add(messages) + assert len(msg_ids) == 3, "Should have 3 message IDs" + assert all(isinstance(id, int) for id in msg_ids), "All IDs should be integers" + + # Verify messages were stored + all_messages = conversation.get_messages() + assert len([m for m in all_messages if m.get("metadata", {}).get("batch")]) == 3, "Should find 3 batch messages" + + print("āœ“ Batch add test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_get_str() -> bool: + """Test getting conversation as string.""" + print_test_header("Get String Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Get string test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + conversation.add("user", "Hello!") + conversation.add("assistant", "Hi there!") + + conv_str = conversation.get_str() + assert "user: Hello!" in conv_str, "User message not found in string" + assert "assistant: Hi there!" in conv_str, "Assistant message not found in string" + + print("āœ“ Get string test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_get_messages() -> bool: + """Test getting messages with pagination.""" + print_test_header("Get Messages Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Get messages test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + # Add multiple messages + for i in range(5): + conversation.add("user", f"Message {i}") + + # Test getting all messages + all_messages = conversation.get_messages() + assert len(all_messages) >= 5, "Should have at least 5 messages" + + # Test pagination + limited_messages = conversation.get_messages(limit=2) + assert len(limited_messages) == 2, "Should have 2 limited messages" + + offset_messages = conversation.get_messages(offset=2, limit=2) + assert len(offset_messages) == 2, "Should have 2 offset messages" + + print("āœ“ Get messages test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_search_messages() -> bool: + """Test searching messages.""" + print_test_header("Search Messages Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Search messages test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + conversation.add("user", "Hello world from Supabase") + conversation.add("assistant", "Hello there, user!") + conversation.add("user", "Goodbye world") + conversation.add("system", "System message without keywords") + + # Test search functionality + world_results = conversation.search("world") + assert len(world_results) >= 2, "Should find at least 2 messages with 'world'" + + hello_results = conversation.search("Hello") + assert len(hello_results) >= 2, "Should find at least 2 messages with 'Hello'" + + supabase_results = conversation.search("Supabase") + assert len(supabase_results) >= 1, "Should find at least 1 message with 'Supabase'" + + print("āœ“ Search messages test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_update_and_delete() -> bool: + """Test updating and deleting messages.""" + print_test_header("Update and Delete Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Update and delete test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + # Add a message to update/delete + msg_id = conversation.add("user", "Original message") + + # Test update method + conversation.update( + index=str(msg_id), + role="user", + content="Updated message" + ) + + updated_msg = conversation.query(str(msg_id)) + assert updated_msg["content"] == "Updated message", "Message should be updated" + + # Test delete + conversation.delete(str(msg_id)) + + deleted_msg = conversation.query(str(msg_id)) + assert deleted_msg is None, "Message should be deleted" + + print("āœ“ Update and delete test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_update_message_method() -> bool: + """Test the new update_message method.""" + print_test_header("Update Message Method Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Update message method test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + # Add a message to update + msg_id = conversation.add( + role="user", + content="Original content", + metadata={"version": 1} + ) + + # Test update_message method + success = conversation.update_message( + message_id=msg_id, + content="Updated content via update_message", + metadata={"version": 2, "updated": True} + ) + + assert success == True, "update_message should return True on success" + + # Verify the update + updated_msg = conversation.query(str(msg_id)) + assert updated_msg is not None, "Message should still exist" + assert updated_msg["content"] == "Updated content via update_message", "Content should be updated" + assert updated_msg["metadata"]["version"] == 2, "Metadata should be updated" + assert updated_msg["metadata"]["updated"] == True, "New metadata field should be added" + + # Test update_message with non-existent ID + failure = conversation.update_message( + message_id=999999, + content="This should fail" + ) + assert failure == False, "update_message should return False for non-existent message" + + print("āœ“ Update message method test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_conversation_statistics() -> bool: + """Test getting conversation statistics.""" + print_test_header("Conversation Statistics Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Conversation statistics test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + # Add messages with different roles and token counts + conversation.add("user", "Hello", token_count=2) + conversation.add("assistant", "Hi there!", token_count=3) + conversation.add("system", "System message", token_count=5) + conversation.add("user", "Another user message", token_count=4) + + stats = conversation.get_conversation_summary() + assert stats["total_messages"] >= 4, "Should have at least 4 messages" + assert stats["unique_roles"] >= 3, "Should have at least 3 unique roles" + assert stats["total_tokens"] >= 14, "Should have at least 14 total tokens" + + # Test role counting + role_counts = conversation.count_messages_by_role() + assert role_counts.get("user", 0) >= 2, "Should have at least 2 user messages" + assert role_counts.get("assistant", 0) >= 1, "Should have at least 1 assistant message" + assert role_counts.get("system", 0) >= 1, "Should have at least 1 system message" + + print("āœ“ Conversation statistics test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_json_operations() -> bool: + """Test JSON save and load operations.""" + print_test_header("JSON Operations Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ JSON operations test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + json_file = "test_conversation.json" + + try: + # Add test messages + conversation.add("user", "Test message for JSON") + conversation.add("assistant", {"response": "JSON test response", "data": [1, 2, 3]}) + + # Test JSON export + conversation.save_as_json(json_file) + assert os.path.exists(json_file), "JSON file should be created" + + # Verify JSON content + with open(json_file, 'r') as f: + json_data = json.load(f) + assert isinstance(json_data, list), "JSON data should be a list" + assert len(json_data) >= 2, "Should have at least 2 messages in JSON" + + # Test JSON import (creates new conversation) + original_conv_id = conversation.get_conversation_id() + conversation.load_from_json(json_file) + new_conv_id = conversation.get_conversation_id() + assert new_conv_id != original_conv_id, "Should create new conversation on import" + + imported_messages = conversation.get_messages() + assert len(imported_messages) >= 2, "Should have imported messages" + + print("āœ“ JSON operations test passed") + return True + finally: + # Cleanup + if os.path.exists(json_file): + os.remove(json_file) + cleanup_test_conversation(conversation) + + +def test_yaml_operations() -> bool: + """Test YAML save and load operations.""" + print_test_header("YAML Operations Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ YAML operations test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + yaml_file = "test_conversation.yaml" + + try: + # Add test messages + conversation.add("user", "Test message for YAML") + conversation.add("assistant", "YAML test response") + + # Test YAML export + conversation.save_as_yaml(yaml_file) + assert os.path.exists(yaml_file), "YAML file should be created" + + # Test YAML import (creates new conversation) + original_conv_id = conversation.get_conversation_id() + conversation.load_from_yaml(yaml_file) + new_conv_id = conversation.get_conversation_id() + assert new_conv_id != original_conv_id, "Should create new conversation on import" + + imported_messages = conversation.get_messages() + assert len(imported_messages) >= 2, "Should have imported messages" + + print("āœ“ YAML operations test passed") + return True + finally: + # Cleanup + if os.path.exists(yaml_file): + os.remove(yaml_file) + cleanup_test_conversation(conversation) + + +def test_message_types() -> bool: + """Test different message types.""" + print_test_header("Message Types Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Message types test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + # Test all message types + types_to_test = [ + (MessageType.USER, "user"), + (MessageType.ASSISTANT, "assistant"), + (MessageType.SYSTEM, "system"), + (MessageType.FUNCTION, "function"), + (MessageType.TOOL, "tool") + ] + + for msg_type, role in types_to_test: + msg_id = conversation.add( + role=role, + content=f"Test {msg_type.value} message", + message_type=msg_type + ) + assert msg_id is not None, f"Should create {msg_type.value} message" + + # Verify all message types were stored + messages = conversation.get_messages() + stored_types = {msg.get("message_type") for msg in messages if msg.get("message_type")} + expected_types = {msg_type.value for msg_type, _ in types_to_test} + assert stored_types.issuperset(expected_types), "Should store all message types" + + print("āœ“ Message types test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_conversation_management() -> bool: + """Test conversation management operations.""" + print_test_header("Conversation Management Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Conversation management test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + # Test getting conversation ID + conv_id = conversation.get_conversation_id() + assert conv_id, "Should have a conversation ID" + assert isinstance(conv_id, str), "Conversation ID should be a string" + + # Add some messages + conversation.add("user", "First conversation message") + conversation.add("assistant", "Response in first conversation") + + first_conv_messages = len(conversation.get_messages()) + assert first_conv_messages >= 2, "Should have messages in first conversation" + + # Start new conversation + new_conv_id = conversation.start_new_conversation() + assert new_conv_id != conv_id, "New conversation should have different ID" + assert conversation.get_conversation_id() == new_conv_id, "Should switch to new conversation" + assert isinstance(new_conv_id, str), "New conversation ID should be a string" + + # Verify new conversation is empty (except any system prompts) + new_messages = conversation.get_messages() + conversation.add("user", "New conversation message") + updated_messages = conversation.get_messages() + assert len(updated_messages) > len(new_messages), "Should be able to add to new conversation" + + # Test clear conversation + conversation.clear() + cleared_messages = conversation.get_messages() + assert len(cleared_messages) == 0, "Conversation should be cleared" + + print("āœ“ Conversation management test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_get_messages_by_role() -> bool: + """Test getting messages filtered by role.""" + print_test_header("Get Messages by Role Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Get messages by role test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + # Add messages with different roles + conversation.add("user", "User message 1") + conversation.add("assistant", "Assistant message 1") + conversation.add("user", "User message 2") + conversation.add("system", "System message") + conversation.add("assistant", "Assistant message 2") + + # Test filtering by role + user_messages = conversation.get_messages_by_role("user") + assert len(user_messages) >= 2, "Should have at least 2 user messages" + assert all(msg["role"] == "user" for msg in user_messages), "All messages should be from user" + + assistant_messages = conversation.get_messages_by_role("assistant") + assert len(assistant_messages) >= 2, "Should have at least 2 assistant messages" + assert all(msg["role"] == "assistant" for msg in assistant_messages), "All messages should be from assistant" + + system_messages = conversation.get_messages_by_role("system") + assert len(system_messages) >= 1, "Should have at least 1 system message" + assert all(msg["role"] == "system" for msg in system_messages), "All messages should be from system" + + print("āœ“ Get messages by role test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_timeline_and_organization() -> bool: + """Test conversation timeline and organization features.""" + print_test_header("Timeline and Organization Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Timeline and organization test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + try: + # Add messages + conversation.add("user", "Timeline test message 1") + conversation.add("assistant", "Timeline test response 1") + conversation.add("user", "Timeline test message 2") + + # Test timeline organization + timeline = conversation.get_conversation_timeline_dict() + assert isinstance(timeline, dict), "Timeline should be a dictionary" + assert len(timeline) > 0, "Timeline should have entries" + + # Test organization by role + by_role = conversation.get_conversation_by_role_dict() + assert isinstance(by_role, dict), "Role organization should be a dictionary" + assert "user" in by_role, "Should have user messages" + assert "assistant" in by_role, "Should have assistant messages" + + # Test conversation as dict + conv_dict = conversation.get_conversation_as_dict() + assert isinstance(conv_dict, dict), "Conversation dict should be a dictionary" + assert "conversation_id" in conv_dict, "Should have conversation ID" + assert "messages" in conv_dict, "Should have messages" + + print("āœ“ Timeline and organization test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_concurrent_operations() -> bool: + """Test concurrent operations for thread safety.""" + print_test_header("Concurrent Operations Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Concurrent operations test skipped - Supabase not available") + return True + + conversation = setup_test_conversation() + results = [] + + def add_messages(thread_id): + """Add messages in a separate thread.""" + try: + for i in range(3): + msg_id = conversation.add( + role="user", + content=f"Thread {thread_id} message {i}", + metadata={"thread_id": thread_id, "message_num": i} + ) + results.append(("success", thread_id, msg_id)) + except Exception as e: + results.append(("error", thread_id, str(e))) + + try: + # Create and start multiple threads + threads = [] + for i in range(3): + thread = threading.Thread(target=add_messages, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Check results + successful_operations = [r for r in results if r[0] == "success"] + assert len(successful_operations) >= 6, "Should have successful concurrent operations" + + # Verify messages were actually stored + all_messages = conversation.get_messages() + thread_messages = [m for m in all_messages if m.get("metadata", {}).get("thread_id") is not None] + assert len(thread_messages) >= 6, "Should have stored concurrent messages" + + print("āœ“ Concurrent operations test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_enhanced_error_handling() -> bool: + """Test enhanced error handling for various edge cases.""" + print_test_header("Enhanced Error Handling Test") + + if not SUPABASE_AVAILABLE: + print("āœ“ Enhanced error handling test skipped - Supabase not available") + return True + + # Test invalid credentials + try: + invalid_conversation = SupabaseConversation( + supabase_url="https://invalid-url.supabase.co", + supabase_key="invalid_key", + enable_logging=False + ) + # This should raise an exception during initialization + assert False, "Should raise exception for invalid credentials" + except (SupabaseConnectionError, Exception): + pass # Expected behavior + + # Test with valid conversation + conversation = setup_test_conversation() + try: + # Test querying non-existent message + non_existent = conversation.query("999999") + assert non_existent is None, "Non-existent message should return None" + + # Test deleting non-existent message (should not raise exception) + conversation.delete("999999") # Should handle gracefully + + # Test updating non-existent message (should not raise exception) + conversation.update("999999", "user", "content") # Should handle gracefully + + # Test update_message with invalid ID + result = conversation.update_message(999999, "invalid content") + assert result == False, "update_message should return False for invalid ID" + + # Test search with empty query + empty_results = conversation.search("") + assert isinstance(empty_results, list), "Empty search should return list" + + # Test invalid message ID formats + try: + conversation.query("not_a_number") + assert False, "Should raise ValueError for non-numeric ID" + except ValueError: + pass # Expected + + try: + conversation.update("not_a_number", "user", "content") + assert False, "Should raise ValueError for non-numeric ID" + except ValueError: + pass # Expected + + try: + conversation.delete("not_a_number") + assert False, "Should raise ValueError for non-numeric ID" + except ValueError: + pass # Expected + + print("āœ“ Enhanced error handling test passed") + return True + finally: + cleanup_test_conversation(conversation) + + +def test_fallback_functionality() -> bool: + """Test fallback functionality when dependencies are missing.""" + print_test_header("Fallback Functionality Test") + + # This test always passes as it tests the import fallback mechanism + if not SUPABASE_AVAILABLE: + print("āœ“ Fallback functionality test passed - gracefully handled missing dependencies") + return True + else: + print("āœ“ Fallback functionality test passed - dependencies available, no fallback needed") + return True + + +def generate_test_report(test_results: List[Dict[str, Any]]) -> Dict[str, Any]: + """Generate a comprehensive test report.""" + total_tests = len(test_results) + passed_tests = sum(1 for result in test_results if result["success"]) + failed_tests = total_tests - passed_tests + + total_time = sum(result["execution_time"] for result in test_results) + avg_time = total_time / total_tests if total_tests > 0 else 0 + + report = { + "summary": { + "total_tests": total_tests, + "passed_tests": passed_tests, + "failed_tests": failed_tests, + "success_rate": (passed_tests / total_tests * 100) if total_tests > 0 else 0, + "total_execution_time": total_time, + "average_execution_time": avg_time, + "timestamp": datetime.datetime.now().isoformat(), + "supabase_available": SUPABASE_AVAILABLE, + "environment_configured": bool(TEST_SUPABASE_URL and TEST_SUPABASE_KEY), + }, + "test_results": test_results, + "failed_tests": [ + result for result in test_results if not result["success"] + ], + } + + return report + + +def run_all_tests() -> None: + """Run all SupabaseConversation tests.""" + print("šŸš€ Starting Enhanced SupabaseConversation Test Suite") + print(f"Supabase Available: {'āœ…' if SUPABASE_AVAILABLE else 'āŒ'}") + + if TEST_SUPABASE_URL and TEST_SUPABASE_KEY: + print(f"Using Supabase URL: {TEST_SUPABASE_URL[:30]}...") + print(f"Using table: {TEST_TABLE_NAME}") + else: + print("āŒ Environment variables SUPABASE_URL and SUPABASE_KEY not set") + print("Some tests will be skipped") + + print("=" * 60) + + # Define tests to run + tests = [ + ("Import Availability", test_import_availability), + ("Fallback Functionality", test_fallback_functionality), + ("Initialization", test_initialization), + ("Logging Configuration", test_logging_configuration), + ("Add Message", test_add_message), + ("Add Complex Message", test_add_complex_message), + ("Batch Add", test_batch_add), + ("Get String", test_get_str), + ("Get Messages", test_get_messages), + ("Search Messages", test_search_messages), + ("Update and Delete", test_update_and_delete), + ("Update Message Method", test_update_message_method), + ("Conversation Statistics", test_conversation_statistics), + ("JSON Operations", test_json_operations), + ("YAML Operations", test_yaml_operations), + ("Message Types", test_message_types), + ("Conversation Management", test_conversation_management), + ("Get Messages by Role", test_get_messages_by_role), + ("Timeline and Organization", test_timeline_and_organization), + ("Concurrent Operations", test_concurrent_operations), + ("Enhanced Error Handling", test_enhanced_error_handling), + ] + + test_results = [] + + # Run each test + for test_name, test_func in tests: + print_test_header(test_name) + success, message, execution_time = run_test(test_func) + + test_result = { + "test_name": test_name, + "success": success, + "message": message, + "execution_time": execution_time, + } + test_results.append(test_result) + + print_test_result(test_name, success, message, execution_time) + + # Generate and display report + report = generate_test_report(test_results) + + print("\n" + "=" * 60) + print("šŸ“Š ENHANCED TEST SUMMARY") + print("=" * 60) + print(f"Total Tests: {report['summary']['total_tests']}") + print(f"Passed: {report['summary']['passed_tests']}") + print(f"Failed: {report['summary']['failed_tests']}") + print(f"Success Rate: {report['summary']['success_rate']:.1f}%") + print(f"Total Time: {report['summary']['total_execution_time']:.3f} seconds") + print(f"Average Time: {report['summary']['average_execution_time']:.3f} seconds") + print(f"Supabase Available: {'āœ…' if report['summary']['supabase_available'] else 'āŒ'}") + print(f"Environment Configured: {'āœ…' if report['summary']['environment_configured'] else 'āŒ'}") + + if report['failed_tests']: + print("\nāŒ FAILED TESTS:") + for failed_test in report['failed_tests']: + print(f" - {failed_test['test_name']}: {failed_test['message']}") + else: + print("\nāœ… All tests passed!") + + # Additional information + if not SUPABASE_AVAILABLE: + print("\nšŸ” NOTE: Supabase dependencies not available.") + print("Install with: pip install supabase") + + if not (TEST_SUPABASE_URL and TEST_SUPABASE_KEY): + print("\nšŸ” NOTE: Environment variables not set.") + print("Set SUPABASE_URL and SUPABASE_KEY to run full tests.") + + # Save detailed report + report_file = f"supabase_test_report_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + with open(report_file, 'w') as f: + json.dump(report, f, indent=2, default=str) + print(f"\nšŸ“„ Detailed report saved to: {report_file}") + + +if __name__ == "__main__": + run_all_tests() \ No newline at end of file