From 3fa6ba22cdce7382155031b6ce38cf5148de9221 Mon Sep 17 00:00:00 2001 From: harshalmore31 Date: Fri, 6 Jun 2025 09:35:36 +0530 Subject: [PATCH] Refactor Pulsar and Redis error handling, improve SQLite connection management, and enhance conversation class functionality --- swarms/communication/pulsar_struct.py | 14 +- swarms/communication/redis_wrap.py | 12 +- swarms/communication/sqlite_wrap.py | 17 +-- swarms/structs/conversation.py | 196 +++++++++++++++++++------- tests/communication/test_redis.py | 71 +++------- 5 files changed, 181 insertions(+), 129 deletions(-) diff --git a/swarms/communication/pulsar_struct.py b/swarms/communication/pulsar_struct.py index 93ea00e4..cecac121 100644 --- a/swarms/communication/pulsar_struct.py +++ b/swarms/communication/pulsar_struct.py @@ -109,7 +109,7 @@ class PulsarConversation(BaseCommunication): logger.debug( f"Connecting to Pulsar broker at {pulsar_host}" ) - self.client = self.pulsar.Client(pulsar_host) + self.client = pulsar.Client(pulsar_host) logger.debug(f"Creating producer for topic: {self.topic}") self.producer = self.client.create_producer(self.topic) @@ -122,7 +122,7 @@ class PulsarConversation(BaseCommunication): ) logger.info("Successfully connected to Pulsar broker") - except self.pulsar.ConnectError as e: + except pulsar.ConnectError as e: error_msg = f"Failed to connect to Pulsar broker at {pulsar_host}: {str(e)}" logger.error(error_msg) raise PulsarConnectionError(error_msg) @@ -211,7 +211,7 @@ class PulsarConversation(BaseCommunication): ) return message["id"] - except self.pulsar.ConnectError as e: + except pulsar.ConnectError as e: error_msg = f"Failed to send message to Pulsar: Connection error: {str(e)}" logger.error(error_msg) raise PulsarConnectionError(error_msg) @@ -248,7 +248,7 @@ class PulsarConversation(BaseCommunication): msg = self.consumer.receive(timeout_millis=1000) messages.append(json.loads(msg.data())) self.consumer.acknowledge(msg) - except self.pulsar.Timeout: + except pulsar.Timeout: break # No more messages available except json.JSONDecodeError as e: logger.error(f"Failed to decode message: {e}") @@ -263,7 +263,7 @@ class PulsarConversation(BaseCommunication): return messages - except self.pulsar.ConnectError as e: + except pulsar.ConnectError as e: error_msg = f"Failed to receive messages from Pulsar: Connection error: {str(e)}" logger.error(error_msg) raise PulsarConnectionError(error_msg) @@ -400,7 +400,7 @@ class PulsarConversation(BaseCommunication): f"Successfully cleared conversation. New ID: {self.conversation_id}" ) - except self.pulsar.ConnectError as e: + except pulsar.ConnectError as e: error_msg = f"Failed to clear conversation: Connection error: {str(e)}" logger.error(error_msg) raise PulsarConnectionError(error_msg) @@ -696,7 +696,7 @@ class PulsarConversation(BaseCommunication): msg = self.consumer.receive(timeout_millis=1000) self.consumer.acknowledge(msg) health["consumer_active"] = True - except self.pulsar.Timeout: + except pulsar.Timeout: pass logger.info(f"Health check results: {health}") diff --git a/swarms/communication/redis_wrap.py b/swarms/communication/redis_wrap.py index 829a92d6..40fc1505 100644 --- a/swarms/communication/redis_wrap.py +++ b/swarms/communication/redis_wrap.py @@ -435,7 +435,7 @@ class RedisConversation(BaseStructure): if custom_rules_prompt is not None: self.add(user or "User", custom_rules_prompt) - except Exception as e: + except RedisError as e: logger.error( f"Failed to initialize conversation: {str(e)}" ) @@ -500,10 +500,10 @@ class RedisConversation(BaseStructure): ) return except ( - redis.ConnectionError, - redis.TimeoutError, - redis.AuthenticationError, - redis.BusyLoadingError, + ConnectionError, + TimeoutError, + AuthenticationError, + BusyLoadingError, ) as e: if attempt < retry_attempts - 1: logger.warning( @@ -560,7 +560,7 @@ class RedisConversation(BaseStructure): """ try: return operation_func(*args, **kwargs) - except redis.RedisError as e: + except RedisError as e: error_msg = ( f"Redis operation '{operation_name}' failed: {str(e)}" ) diff --git a/swarms/communication/sqlite_wrap.py b/swarms/communication/sqlite_wrap.py index 3b1d190d..443a456e 100644 --- a/swarms/communication/sqlite_wrap.py +++ b/swarms/communication/sqlite_wrap.py @@ -1,3 +1,4 @@ +import sqlite3 import json import datetime from typing import List, Optional, Union, Dict, Any @@ -64,16 +65,6 @@ class SQLiteConversation(BaseCommunication): connection_timeout: float = 5.0, **kwargs, ): - # Lazy load sqlite3 - try: - import sqlite3 - self.sqlite3 = sqlite3 - self.sqlite3_available = True - except ImportError as e: - raise ImportError( - f"SQLite3 is not available: {e}" - ) - super().__init__( system_prompt=system_prompt, time_enabled=time_enabled, @@ -171,13 +162,13 @@ class SQLiteConversation(BaseCommunication): conn = None for attempt in range(self.max_retries): try: - conn = self.sqlite3.connect( + conn = sqlite3.connect( str(self.db_path), timeout=self.connection_timeout ) - conn.row_factory = self.sqlite3.Row + conn.row_factory = sqlite3.Row yield conn break - except self.sqlite3.Error as e: + except sqlite3.Error as e: if attempt == self.max_retries - 1: raise if self.enable_logging: diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py index 1458bdc0..5cd8a9f7 100644 --- a/swarms/structs/conversation.py +++ b/swarms/structs/conversation.py @@ -55,7 +55,6 @@ def get_conversation_dir(): # Define available providers providers = Literal["mem0", "in-memory", "supabase", "redis", "sqlite", "duckdb", "pulsar"] - def _create_backend_conversation(backend: str, **kwargs): """ Create a backend conversation instance based on the specified backend type. @@ -153,12 +152,6 @@ class Conversation(BaseStructure): save_as_json_bool (bool): Flag to save conversation history as JSON. token_count (bool): Flag to enable token counting for messages. conversation_history (list): List to store the history of messages. - cache_enabled (bool): Flag to enable prompt caching. - cache_stats (dict): Statistics about cache usage. - cache_lock (threading.Lock): Lock for thread-safe cache operations. - conversations_dir (str): Directory to store cached conversations. - backend (str): The storage backend to use. - backend_instance: The actual backend instance (for non-memory backends). """ def __init__( @@ -179,6 +172,7 @@ class Conversation(BaseStructure): save_as_yaml: bool = False, save_as_json_bool: bool = False, token_count: bool = True, + message_id_on: bool = False, provider: providers = "in-memory", backend: Optional[str] = None, # Backend-specific parameters @@ -195,6 +189,8 @@ class Conversation(BaseStructure): persist_redis: bool = True, auto_persist: bool = True, redis_data_dir: Optional[str] = None, + conversations_dir: Optional[str] = None, + *args, **kwargs, ): @@ -243,15 +239,7 @@ class Conversation(BaseStructure): self.save_as_yaml = save_as_yaml self.save_as_json_bool = save_as_json_bool self.token_count = token_count - self.cache_enabled = cache_enabled self.provider = provider # Keep for backwards compatibility - self.cache_stats = { - "hits": 0, - "misses": 0, - "cached_tokens": 0, - "total_tokens": 0, - } - self.cache_lock = threading.Lock() self.conversations_dir = conversations_dir # Initialize backend if using persistent storage @@ -302,11 +290,9 @@ class Conversation(BaseStructure): "rules": self.rules, "custom_rules_prompt": self.custom_rules_prompt, "user": self.user, - "auto_save": self.auto_save, "save_as_yaml": self.save_as_yaml, "save_as_json_bool": self.save_as_json_bool, "token_count": self.token_count, - "cache_enabled": self.cache_enabled, } # Add backend-specific parameters @@ -665,7 +651,12 @@ class Conversation(BaseStructure): print("\n" + "=" * 50) def export_conversation(self, filename: str, *args, **kwargs): - """Export the conversation history to a file.""" + """Export the conversation history to a file. + + Args: + filename (str): Filename to export to. + """ + if self.backend_instance: try: return self.backend_instance.export_conversation(filename, *args, **kwargs) @@ -680,7 +671,11 @@ class Conversation(BaseStructure): self.save_as_json(filename) def import_conversation(self, filename: str): - """Import a conversation history from a file.""" + """Import a conversation history from a file. + + Args: + filename (str): Filename to import from. + """ if self.backend_instance: try: return self.backend_instance.import_conversation(filename) @@ -694,22 +689,34 @@ class Conversation(BaseStructure): """Count the number of messages by role. Returns: - dict: A dictionary mapping roles to message counts. + dict: A dictionary with counts of messages by role. """ + # Check backend instance first if self.backend_instance: try: return self.backend_instance.count_messages_by_role() except Exception as e: logger.error(f"Backend count_messages_by_role failed: {e}") - # Fallback to in-memory count + # Fallback to local implementation below pass - - role_counts = {} + # Initialize counts with expected roles + counts = { + "system": 0, + "user": 0, + "assistant": 0, + "function": 0, + } + + # Count messages by role for message in self.conversation_history: role = message["role"] - role_counts[role] = role_counts.get(role, 0) + 1 - return role_counts - + if role in counts: + counts[role] += 1 + else: + # Handle unexpected roles dynamically + counts[role] = counts.get(role, 0) + 1 + + return counts def return_history_as_string(self): """Return the conversation history as a string. @@ -753,17 +760,55 @@ class Conversation(BaseStructure): Args: filename (str): Filename to save the conversation history. """ + # Check backend instance first if self.backend_instance: try: return self.backend_instance.save_as_json(filename) except Exception as e: logger.error(f"Backend save_as_json failed: {e}") - # Fallback to in-memory save - pass - - if filename is not None: - with open(filename, "w") as f: - json.dump(self.conversation_history, f) + # Fallback to local save implementation below + + # Don't save if saving is disabled + if not self.save_enabled: + return + + save_path = filename or self.save_filepath + if save_path is not None: + try: + # Prepare metadata + metadata = { + "id": self.id, + "name": self.name, + "created_at": datetime.datetime.now().isoformat(), + "system_prompt": self.system_prompt, + "rules": self.rules, + "custom_rules_prompt": self.custom_rules_prompt, + } + + # Prepare save data + save_data = { + "metadata": metadata, + "history": self.conversation_history, + } + + # Create directory if it doesn't exist + os.makedirs( + os.path.dirname(save_path), + mode=0o755, + exist_ok=True, + ) + + # Write directly to file + with open(save_path, "w") as f: + json.dump(save_data, f, indent=2) + + # Only log explicit saves, not autosaves + if not self.autosave: + logger.info( + f"Successfully saved conversation to {save_path}" + ) + except Exception as e: + logger.error(f"Failed to save conversation: {str(e)}") def load_from_json(self, filename: str): """Load the conversation history from a JSON file. @@ -771,17 +816,32 @@ class Conversation(BaseStructure): Args: filename (str): Filename to load from. """ - if self.backend_instance: + if filename is not None and os.path.exists(filename): try: - return self.backend_instance.load_from_json(filename) + with open(filename) as f: + data = json.load(f) + + # Load metadata + metadata = data.get("metadata", {}) + self.id = metadata.get("id", self.id) + self.name = metadata.get("name", self.name) + self.system_prompt = metadata.get( + "system_prompt", self.system_prompt + ) + self.rules = metadata.get("rules", self.rules) + self.custom_rules_prompt = metadata.get( + "custom_rules_prompt", self.custom_rules_prompt + ) + + # Load conversation history + self.conversation_history = data.get("history", []) + + logger.info( + f"Successfully loaded conversation from {filename}" + ) except Exception as e: - logger.error(f"Backend load_from_json failed: {e}") - # Fallback to in-memory load - pass - - if filename is not None: - with open(filename) as f: - self.conversation_history = json.load(f) + logger.error(f"Failed to load conversation: {str(e)}") + raise def search_keyword_in_conversation(self, keyword: str): """Search for a keyword in the conversation history. @@ -865,7 +925,7 @@ class Conversation(BaseStructure): """Convert the conversation history to a dictionary. Returns: - dict: The conversation history as a dictionary. + list: The conversation history as a list of dictionaries. """ if self.backend_instance: try: @@ -1154,12 +1214,52 @@ class Conversation(BaseStructure): return [] conversations = [] - for file in os.listdir(conversations_dir): - if file.endswith(".json"): - conversations.append( - file[:-5] - ) # Remove .json extension - return conversations + seen_ids = ( + set() + ) # Track seen conversation IDs to avoid duplicates + + for filename in os.listdir(conv_dir): + if filename.endswith(".json"): + try: + filepath = os.path.join(conv_dir, filename) + with open(filepath) as f: + data = json.load(f) + metadata = data.get("metadata", {}) + conv_id = metadata.get("id") + name = metadata.get("name") + created_at = metadata.get("created_at") + + # Skip if we've already seen this ID or if required fields are missing + if ( + not all([conv_id, name, created_at]) + or conv_id in seen_ids + ): + continue + + seen_ids.add(conv_id) + conversations.append( + { + "id": conv_id, + "name": name, + "created_at": created_at, + "filepath": filepath, + } + ) + except json.JSONDecodeError: + logger.warning( + f"Skipping corrupted conversation file: {filename}" + ) + continue + except Exception as e: + logger.error( + f"Failed to read conversation {filename}: {str(e)}" + ) + continue + + # Sort by creation date, newest first + return sorted( + conversations, key=lambda x: x["created_at"], reverse=True + ) def clear_memory(self): """Clear the memory of the conversation.""" diff --git a/tests/communication/test_redis.py b/tests/communication/test_redis.py index e0b0b988..3e72c01b 100644 --- a/tests/communication/test_redis.py +++ b/tests/communication/test_redis.py @@ -85,7 +85,6 @@ class RedisConversationTester: def setup(self): """Initialize Redis server and conversation for testing.""" try: - # Try first with external Redis (if available) logger.info("Trying to connect to external Redis server...") self.conversation = RedisConversation( system_prompt="Test System Prompt", @@ -99,7 +98,6 @@ class RedisConversationTester: except Exception as external_error: logger.info(f"External Redis connection failed: {external_error}") logger.info("Trying to start embedded Redis server...") - try: # Fallback to embedded Redis self.conversation = RedisConversation( @@ -119,16 +117,8 @@ class RedisConversationTester: def cleanup(self): """Cleanup resources after tests.""" - if self.conversation: - try: - # Check if we have an embedded server to stop - if hasattr(self.conversation, 'embedded_server') and self.conversation.embedded_server is not None: - self.conversation.embedded_server.stop() - # Close Redis client if it exists - if hasattr(self.conversation, 'redis_client') and self.conversation.redis_client: - self.conversation.redis_client.close() - except Exception as e: - logger.warning(f"Error during cleanup: {str(e)}") + if self.redis_server: + self.redis_server.stop() def test_initialization(self): """Test basic initialization.""" @@ -151,8 +141,6 @@ class RedisConversationTester: json_content = {"key": "value", "nested": {"data": 123}} self.conversation.add("system", json_content) last_message = self.conversation.get_final_message_content() - - # Parse the JSON string back to dict for comparison if isinstance(last_message, str): try: parsed_content = json.loads(last_message) @@ -173,29 +161,20 @@ class RedisConversationTester: initial_count = len( self.conversation.return_messages_as_list() ) - if initial_count > 0: - self.conversation.delete(0) - new_count = len(self.conversation.return_messages_as_list()) - assert ( - new_count == initial_count - 1 - ), "Failed to delete message" + self.conversation.delete(0) + new_count = len(self.conversation.return_messages_as_list()) + assert ( + new_count == initial_count - 1 + ), "Failed to delete message" def test_update(self): """Test message update.""" # Add initial message self.conversation.add("user", "original message") - - # Get all messages to find the last message ID all_messages = self.conversation.return_messages_as_list() if len(all_messages) > 0: - # Update the last message (index 0 in this case means the first message) - # Note: This test may need adjustment based on how Redis stores messages - self.conversation.update(0, "user", "updated message") - - # Get the message directly using query + self.conversation.update(0, "user", "updated message") updated_message = self.conversation.query(0) - - # Since Redis might store content differently, just check that update didn't crash assert True, "Update method executed successfully" def test_clear(self): @@ -207,28 +186,14 @@ class RedisConversationTester: def test_export_import(self): """Test export and import functionality.""" - try: - self.conversation.add("user", "export test") - self.conversation.export_conversation("test_export.txt") - - # Clear conversation - self.conversation.clear() - - # Import back - self.conversation.import_conversation("test_export.txt") - messages = self.conversation.return_messages_as_list() - assert ( - len(messages) > 0 - ), "Failed to export/import conversation" - - # Cleanup test file - import os - if os.path.exists("test_export.txt"): - os.remove("test_export.txt") - except Exception as e: - logger.warning(f"Export/import test failed: {e}") - # Don't fail the test entirely, just log the warning - assert True, "Export/import test completed with warnings" + self.conversation.add("user", "export test") + self.conversation.export_conversation("test_export.txt") + self.conversation.clear() + self.conversation.import_conversation("test_export.txt") + messages = self.conversation.return_messages_as_list() + assert ( + len(messages) > 0 + ), "Failed to export/import conversation" def test_json_operations(self): """Test JSON operations.""" @@ -249,7 +214,6 @@ class RedisConversationTester: self.conversation.add("user", "token test message") time.sleep(1) # Wait for async token counting messages = self.conversation.to_dict() - # Token counting may not be implemented in Redis version, so just check it doesn't crash assert isinstance(messages, list), "Token counting test completed" def test_cache_operations(self): @@ -322,7 +286,6 @@ class RedisConversationTester: def main(): """Main function to run tests and save results.""" logger.info(f"Starting Redis tests. REDIS_AVAILABLE: {REDIS_AVAILABLE}") - tester = RedisConversationTester() markdown_results = tester.run_all_tests() @@ -333,8 +296,6 @@ def main(): logger.info("Test results have been saved to redis_test_results.md") except Exception as e: logger.error(f"Failed to save test results: {e}") - - # Also print results to console print(markdown_results)