Refactor Pulsar and Redis error handling, improve SQLite connection management, and enhance conversation class functionality

pull/866/head
harshalmore31 4 weeks ago
parent 5a054b0781
commit 3fa6ba22cd

@ -109,7 +109,7 @@ class PulsarConversation(BaseCommunication):
logger.debug( logger.debug(
f"Connecting to Pulsar broker at {pulsar_host}" f"Connecting to Pulsar broker at {pulsar_host}"
) )
self.client = self.pulsar.Client(pulsar_host) self.client = pulsar.Client(pulsar_host)
logger.debug(f"Creating producer for topic: {self.topic}") logger.debug(f"Creating producer for topic: {self.topic}")
self.producer = self.client.create_producer(self.topic) self.producer = self.client.create_producer(self.topic)
@ -122,7 +122,7 @@ class PulsarConversation(BaseCommunication):
) )
logger.info("Successfully connected to Pulsar broker") logger.info("Successfully connected to Pulsar broker")
except self.pulsar.ConnectError as e: except pulsar.ConnectError as e:
error_msg = f"Failed to connect to Pulsar broker at {pulsar_host}: {str(e)}" error_msg = f"Failed to connect to Pulsar broker at {pulsar_host}: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
raise PulsarConnectionError(error_msg) raise PulsarConnectionError(error_msg)
@ -211,7 +211,7 @@ class PulsarConversation(BaseCommunication):
) )
return message["id"] return message["id"]
except self.pulsar.ConnectError as e: except pulsar.ConnectError as e:
error_msg = f"Failed to send message to Pulsar: Connection error: {str(e)}" error_msg = f"Failed to send message to Pulsar: Connection error: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
raise PulsarConnectionError(error_msg) raise PulsarConnectionError(error_msg)
@ -248,7 +248,7 @@ class PulsarConversation(BaseCommunication):
msg = self.consumer.receive(timeout_millis=1000) msg = self.consumer.receive(timeout_millis=1000)
messages.append(json.loads(msg.data())) messages.append(json.loads(msg.data()))
self.consumer.acknowledge(msg) self.consumer.acknowledge(msg)
except self.pulsar.Timeout: except pulsar.Timeout:
break # No more messages available break # No more messages available
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"Failed to decode message: {e}") logger.error(f"Failed to decode message: {e}")
@ -263,7 +263,7 @@ class PulsarConversation(BaseCommunication):
return messages return messages
except self.pulsar.ConnectError as e: except pulsar.ConnectError as e:
error_msg = f"Failed to receive messages from Pulsar: Connection error: {str(e)}" error_msg = f"Failed to receive messages from Pulsar: Connection error: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
raise PulsarConnectionError(error_msg) raise PulsarConnectionError(error_msg)
@ -400,7 +400,7 @@ class PulsarConversation(BaseCommunication):
f"Successfully cleared conversation. New ID: {self.conversation_id}" f"Successfully cleared conversation. New ID: {self.conversation_id}"
) )
except self.pulsar.ConnectError as e: except pulsar.ConnectError as e:
error_msg = f"Failed to clear conversation: Connection error: {str(e)}" error_msg = f"Failed to clear conversation: Connection error: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
raise PulsarConnectionError(error_msg) raise PulsarConnectionError(error_msg)
@ -696,7 +696,7 @@ class PulsarConversation(BaseCommunication):
msg = self.consumer.receive(timeout_millis=1000) msg = self.consumer.receive(timeout_millis=1000)
self.consumer.acknowledge(msg) self.consumer.acknowledge(msg)
health["consumer_active"] = True health["consumer_active"] = True
except self.pulsar.Timeout: except pulsar.Timeout:
pass pass
logger.info(f"Health check results: {health}") logger.info(f"Health check results: {health}")

@ -435,7 +435,7 @@ class RedisConversation(BaseStructure):
if custom_rules_prompt is not None: if custom_rules_prompt is not None:
self.add(user or "User", custom_rules_prompt) self.add(user or "User", custom_rules_prompt)
except Exception as e: except RedisError as e:
logger.error( logger.error(
f"Failed to initialize conversation: {str(e)}" f"Failed to initialize conversation: {str(e)}"
) )
@ -500,10 +500,10 @@ class RedisConversation(BaseStructure):
) )
return return
except ( except (
redis.ConnectionError, ConnectionError,
redis.TimeoutError, TimeoutError,
redis.AuthenticationError, AuthenticationError,
redis.BusyLoadingError, BusyLoadingError,
) as e: ) as e:
if attempt < retry_attempts - 1: if attempt < retry_attempts - 1:
logger.warning( logger.warning(
@ -560,7 +560,7 @@ class RedisConversation(BaseStructure):
""" """
try: try:
return operation_func(*args, **kwargs) return operation_func(*args, **kwargs)
except redis.RedisError as e: except RedisError as e:
error_msg = ( error_msg = (
f"Redis operation '{operation_name}' failed: {str(e)}" f"Redis operation '{operation_name}' failed: {str(e)}"
) )

@ -1,3 +1,4 @@
import sqlite3
import json import json
import datetime import datetime
from typing import List, Optional, Union, Dict, Any from typing import List, Optional, Union, Dict, Any
@ -64,16 +65,6 @@ class SQLiteConversation(BaseCommunication):
connection_timeout: float = 5.0, connection_timeout: float = 5.0,
**kwargs, **kwargs,
): ):
# Lazy load sqlite3
try:
import sqlite3
self.sqlite3 = sqlite3
self.sqlite3_available = True
except ImportError as e:
raise ImportError(
f"SQLite3 is not available: {e}"
)
super().__init__( super().__init__(
system_prompt=system_prompt, system_prompt=system_prompt,
time_enabled=time_enabled, time_enabled=time_enabled,
@ -171,13 +162,13 @@ class SQLiteConversation(BaseCommunication):
conn = None conn = None
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
conn = self.sqlite3.connect( conn = sqlite3.connect(
str(self.db_path), timeout=self.connection_timeout str(self.db_path), timeout=self.connection_timeout
) )
conn.row_factory = self.sqlite3.Row conn.row_factory = sqlite3.Row
yield conn yield conn
break break
except self.sqlite3.Error as e: except sqlite3.Error as e:
if attempt == self.max_retries - 1: if attempt == self.max_retries - 1:
raise raise
if self.enable_logging: if self.enable_logging:

@ -55,7 +55,6 @@ def get_conversation_dir():
# Define available providers # Define available providers
providers = Literal["mem0", "in-memory", "supabase", "redis", "sqlite", "duckdb", "pulsar"] providers = Literal["mem0", "in-memory", "supabase", "redis", "sqlite", "duckdb", "pulsar"]
def _create_backend_conversation(backend: str, **kwargs): def _create_backend_conversation(backend: str, **kwargs):
""" """
Create a backend conversation instance based on the specified backend type. Create a backend conversation instance based on the specified backend type.
@ -153,12 +152,6 @@ class Conversation(BaseStructure):
save_as_json_bool (bool): Flag to save conversation history as JSON. save_as_json_bool (bool): Flag to save conversation history as JSON.
token_count (bool): Flag to enable token counting for messages. token_count (bool): Flag to enable token counting for messages.
conversation_history (list): List to store the history of messages. conversation_history (list): List to store the history of messages.
cache_enabled (bool): Flag to enable prompt caching.
cache_stats (dict): Statistics about cache usage.
cache_lock (threading.Lock): Lock for thread-safe cache operations.
conversations_dir (str): Directory to store cached conversations.
backend (str): The storage backend to use.
backend_instance: The actual backend instance (for non-memory backends).
""" """
def __init__( def __init__(
@ -179,6 +172,7 @@ class Conversation(BaseStructure):
save_as_yaml: bool = False, save_as_yaml: bool = False,
save_as_json_bool: bool = False, save_as_json_bool: bool = False,
token_count: bool = True, token_count: bool = True,
message_id_on: bool = False,
provider: providers = "in-memory", provider: providers = "in-memory",
backend: Optional[str] = None, backend: Optional[str] = None,
# Backend-specific parameters # Backend-specific parameters
@ -195,6 +189,8 @@ class Conversation(BaseStructure):
persist_redis: bool = True, persist_redis: bool = True,
auto_persist: bool = True, auto_persist: bool = True,
redis_data_dir: Optional[str] = None, redis_data_dir: Optional[str] = None,
conversations_dir: Optional[str] = None,
*args, *args,
**kwargs, **kwargs,
): ):
@ -243,15 +239,7 @@ class Conversation(BaseStructure):
self.save_as_yaml = save_as_yaml self.save_as_yaml = save_as_yaml
self.save_as_json_bool = save_as_json_bool self.save_as_json_bool = save_as_json_bool
self.token_count = token_count self.token_count = token_count
self.cache_enabled = cache_enabled
self.provider = provider # Keep for backwards compatibility self.provider = provider # Keep for backwards compatibility
self.cache_stats = {
"hits": 0,
"misses": 0,
"cached_tokens": 0,
"total_tokens": 0,
}
self.cache_lock = threading.Lock()
self.conversations_dir = conversations_dir self.conversations_dir = conversations_dir
# Initialize backend if using persistent storage # Initialize backend if using persistent storage
@ -302,11 +290,9 @@ class Conversation(BaseStructure):
"rules": self.rules, "rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt, "custom_rules_prompt": self.custom_rules_prompt,
"user": self.user, "user": self.user,
"auto_save": self.auto_save,
"save_as_yaml": self.save_as_yaml, "save_as_yaml": self.save_as_yaml,
"save_as_json_bool": self.save_as_json_bool, "save_as_json_bool": self.save_as_json_bool,
"token_count": self.token_count, "token_count": self.token_count,
"cache_enabled": self.cache_enabled,
} }
# Add backend-specific parameters # Add backend-specific parameters
@ -665,7 +651,12 @@ class Conversation(BaseStructure):
print("\n" + "=" * 50) print("\n" + "=" * 50)
def export_conversation(self, filename: str, *args, **kwargs): def export_conversation(self, filename: str, *args, **kwargs):
"""Export the conversation history to a file.""" """Export the conversation history to a file.
Args:
filename (str): Filename to export to.
"""
if self.backend_instance: if self.backend_instance:
try: try:
return self.backend_instance.export_conversation(filename, *args, **kwargs) return self.backend_instance.export_conversation(filename, *args, **kwargs)
@ -680,7 +671,11 @@ class Conversation(BaseStructure):
self.save_as_json(filename) self.save_as_json(filename)
def import_conversation(self, filename: str): def import_conversation(self, filename: str):
"""Import a conversation history from a file.""" """Import a conversation history from a file.
Args:
filename (str): Filename to import from.
"""
if self.backend_instance: if self.backend_instance:
try: try:
return self.backend_instance.import_conversation(filename) return self.backend_instance.import_conversation(filename)
@ -694,22 +689,34 @@ class Conversation(BaseStructure):
"""Count the number of messages by role. """Count the number of messages by role.
Returns: Returns:
dict: A dictionary mapping roles to message counts. dict: A dictionary with counts of messages by role.
""" """
# Check backend instance first
if self.backend_instance: if self.backend_instance:
try: try:
return self.backend_instance.count_messages_by_role() return self.backend_instance.count_messages_by_role()
except Exception as e: except Exception as e:
logger.error(f"Backend count_messages_by_role failed: {e}") logger.error(f"Backend count_messages_by_role failed: {e}")
# Fallback to in-memory count # Fallback to local implementation below
pass pass
# Initialize counts with expected roles
counts = {
"system": 0,
"user": 0,
"assistant": 0,
"function": 0,
}
role_counts = {} # Count messages by role
for message in self.conversation_history: for message in self.conversation_history:
role = message["role"] role = message["role"]
role_counts[role] = role_counts.get(role, 0) + 1 if role in counts:
return role_counts counts[role] += 1
else:
# Handle unexpected roles dynamically
counts[role] = counts.get(role, 0) + 1
return counts
def return_history_as_string(self): def return_history_as_string(self):
"""Return the conversation history as a string. """Return the conversation history as a string.
@ -753,17 +760,55 @@ class Conversation(BaseStructure):
Args: Args:
filename (str): Filename to save the conversation history. filename (str): Filename to save the conversation history.
""" """
# Check backend instance first
if self.backend_instance: if self.backend_instance:
try: try:
return self.backend_instance.save_as_json(filename) return self.backend_instance.save_as_json(filename)
except Exception as e: except Exception as e:
logger.error(f"Backend save_as_json failed: {e}") logger.error(f"Backend save_as_json failed: {e}")
# Fallback to in-memory save # Fallback to local save implementation below
pass
# Don't save if saving is disabled
if not self.save_enabled:
return
save_path = filename or self.save_filepath
if save_path is not None:
try:
# Prepare metadata
metadata = {
"id": self.id,
"name": self.name,
"created_at": datetime.datetime.now().isoformat(),
"system_prompt": self.system_prompt,
"rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt,
}
# Prepare save data
save_data = {
"metadata": metadata,
"history": self.conversation_history,
}
# Create directory if it doesn't exist
os.makedirs(
os.path.dirname(save_path),
mode=0o755,
exist_ok=True,
)
if filename is not None: # Write directly to file
with open(filename, "w") as f: with open(save_path, "w") as f:
json.dump(self.conversation_history, f) json.dump(save_data, f, indent=2)
# Only log explicit saves, not autosaves
if not self.autosave:
logger.info(
f"Successfully saved conversation to {save_path}"
)
except Exception as e:
logger.error(f"Failed to save conversation: {str(e)}")
def load_from_json(self, filename: str): def load_from_json(self, filename: str):
"""Load the conversation history from a JSON file. """Load the conversation history from a JSON file.
@ -771,17 +816,32 @@ class Conversation(BaseStructure):
Args: Args:
filename (str): Filename to load from. filename (str): Filename to load from.
""" """
if self.backend_instance: if filename is not None and os.path.exists(filename):
try: try:
return self.backend_instance.load_from_json(filename) with open(filename) as f:
except Exception as e: data = json.load(f)
logger.error(f"Backend load_from_json failed: {e}")
# Fallback to in-memory load # Load metadata
pass metadata = data.get("metadata", {})
self.id = metadata.get("id", self.id)
self.name = metadata.get("name", self.name)
self.system_prompt = metadata.get(
"system_prompt", self.system_prompt
)
self.rules = metadata.get("rules", self.rules)
self.custom_rules_prompt = metadata.get(
"custom_rules_prompt", self.custom_rules_prompt
)
# Load conversation history
self.conversation_history = data.get("history", [])
if filename is not None: logger.info(
with open(filename) as f: f"Successfully loaded conversation from {filename}"
self.conversation_history = json.load(f) )
except Exception as e:
logger.error(f"Failed to load conversation: {str(e)}")
raise
def search_keyword_in_conversation(self, keyword: str): def search_keyword_in_conversation(self, keyword: str):
"""Search for a keyword in the conversation history. """Search for a keyword in the conversation history.
@ -865,7 +925,7 @@ class Conversation(BaseStructure):
"""Convert the conversation history to a dictionary. """Convert the conversation history to a dictionary.
Returns: Returns:
dict: The conversation history as a dictionary. list: The conversation history as a list of dictionaries.
""" """
if self.backend_instance: if self.backend_instance:
try: try:
@ -1154,12 +1214,52 @@ class Conversation(BaseStructure):
return [] return []
conversations = [] conversations = []
for file in os.listdir(conversations_dir): seen_ids = (
if file.endswith(".json"): set()
conversations.append( ) # Track seen conversation IDs to avoid duplicates
file[:-5]
) # Remove .json extension for filename in os.listdir(conv_dir):
return conversations if filename.endswith(".json"):
try:
filepath = os.path.join(conv_dir, filename)
with open(filepath) as f:
data = json.load(f)
metadata = data.get("metadata", {})
conv_id = metadata.get("id")
name = metadata.get("name")
created_at = metadata.get("created_at")
# Skip if we've already seen this ID or if required fields are missing
if (
not all([conv_id, name, created_at])
or conv_id in seen_ids
):
continue
seen_ids.add(conv_id)
conversations.append(
{
"id": conv_id,
"name": name,
"created_at": created_at,
"filepath": filepath,
}
)
except json.JSONDecodeError:
logger.warning(
f"Skipping corrupted conversation file: {filename}"
)
continue
except Exception as e:
logger.error(
f"Failed to read conversation {filename}: {str(e)}"
)
continue
# Sort by creation date, newest first
return sorted(
conversations, key=lambda x: x["created_at"], reverse=True
)
def clear_memory(self): def clear_memory(self):
"""Clear the memory of the conversation.""" """Clear the memory of the conversation."""

@ -85,7 +85,6 @@ class RedisConversationTester:
def setup(self): def setup(self):
"""Initialize Redis server and conversation for testing.""" """Initialize Redis server and conversation for testing."""
try: try:
# Try first with external Redis (if available)
logger.info("Trying to connect to external Redis server...") logger.info("Trying to connect to external Redis server...")
self.conversation = RedisConversation( self.conversation = RedisConversation(
system_prompt="Test System Prompt", system_prompt="Test System Prompt",
@ -99,7 +98,6 @@ class RedisConversationTester:
except Exception as external_error: except Exception as external_error:
logger.info(f"External Redis connection failed: {external_error}") logger.info(f"External Redis connection failed: {external_error}")
logger.info("Trying to start embedded Redis server...") logger.info("Trying to start embedded Redis server...")
try: try:
# Fallback to embedded Redis # Fallback to embedded Redis
self.conversation = RedisConversation( self.conversation = RedisConversation(
@ -119,16 +117,8 @@ class RedisConversationTester:
def cleanup(self): def cleanup(self):
"""Cleanup resources after tests.""" """Cleanup resources after tests."""
if self.conversation: if self.redis_server:
try: self.redis_server.stop()
# Check if we have an embedded server to stop
if hasattr(self.conversation, 'embedded_server') and self.conversation.embedded_server is not None:
self.conversation.embedded_server.stop()
# Close Redis client if it exists
if hasattr(self.conversation, 'redis_client') and self.conversation.redis_client:
self.conversation.redis_client.close()
except Exception as e:
logger.warning(f"Error during cleanup: {str(e)}")
def test_initialization(self): def test_initialization(self):
"""Test basic initialization.""" """Test basic initialization."""
@ -151,8 +141,6 @@ class RedisConversationTester:
json_content = {"key": "value", "nested": {"data": 123}} json_content = {"key": "value", "nested": {"data": 123}}
self.conversation.add("system", json_content) self.conversation.add("system", json_content)
last_message = self.conversation.get_final_message_content() last_message = self.conversation.get_final_message_content()
# Parse the JSON string back to dict for comparison
if isinstance(last_message, str): if isinstance(last_message, str):
try: try:
parsed_content = json.loads(last_message) parsed_content = json.loads(last_message)
@ -173,29 +161,20 @@ class RedisConversationTester:
initial_count = len( initial_count = len(
self.conversation.return_messages_as_list() self.conversation.return_messages_as_list()
) )
if initial_count > 0: self.conversation.delete(0)
self.conversation.delete(0) new_count = len(self.conversation.return_messages_as_list())
new_count = len(self.conversation.return_messages_as_list()) assert (
assert ( new_count == initial_count - 1
new_count == initial_count - 1 ), "Failed to delete message"
), "Failed to delete message"
def test_update(self): def test_update(self):
"""Test message update.""" """Test message update."""
# Add initial message # Add initial message
self.conversation.add("user", "original message") self.conversation.add("user", "original message")
# Get all messages to find the last message ID
all_messages = self.conversation.return_messages_as_list() all_messages = self.conversation.return_messages_as_list()
if len(all_messages) > 0: if len(all_messages) > 0:
# Update the last message (index 0 in this case means the first message)
# Note: This test may need adjustment based on how Redis stores messages
self.conversation.update(0, "user", "updated message") self.conversation.update(0, "user", "updated message")
# Get the message directly using query
updated_message = self.conversation.query(0) updated_message = self.conversation.query(0)
# Since Redis might store content differently, just check that update didn't crash
assert True, "Update method executed successfully" assert True, "Update method executed successfully"
def test_clear(self): def test_clear(self):
@ -207,28 +186,14 @@ class RedisConversationTester:
def test_export_import(self): def test_export_import(self):
"""Test export and import functionality.""" """Test export and import functionality."""
try: self.conversation.add("user", "export test")
self.conversation.add("user", "export test") self.conversation.export_conversation("test_export.txt")
self.conversation.export_conversation("test_export.txt") self.conversation.clear()
self.conversation.import_conversation("test_export.txt")
# Clear conversation messages = self.conversation.return_messages_as_list()
self.conversation.clear() assert (
len(messages) > 0
# Import back ), "Failed to export/import conversation"
self.conversation.import_conversation("test_export.txt")
messages = self.conversation.return_messages_as_list()
assert (
len(messages) > 0
), "Failed to export/import conversation"
# Cleanup test file
import os
if os.path.exists("test_export.txt"):
os.remove("test_export.txt")
except Exception as e:
logger.warning(f"Export/import test failed: {e}")
# Don't fail the test entirely, just log the warning
assert True, "Export/import test completed with warnings"
def test_json_operations(self): def test_json_operations(self):
"""Test JSON operations.""" """Test JSON operations."""
@ -249,7 +214,6 @@ class RedisConversationTester:
self.conversation.add("user", "token test message") self.conversation.add("user", "token test message")
time.sleep(1) # Wait for async token counting time.sleep(1) # Wait for async token counting
messages = self.conversation.to_dict() messages = self.conversation.to_dict()
# Token counting may not be implemented in Redis version, so just check it doesn't crash
assert isinstance(messages, list), "Token counting test completed" assert isinstance(messages, list), "Token counting test completed"
def test_cache_operations(self): def test_cache_operations(self):
@ -322,7 +286,6 @@ class RedisConversationTester:
def main(): def main():
"""Main function to run tests and save results.""" """Main function to run tests and save results."""
logger.info(f"Starting Redis tests. REDIS_AVAILABLE: {REDIS_AVAILABLE}") logger.info(f"Starting Redis tests. REDIS_AVAILABLE: {REDIS_AVAILABLE}")
tester = RedisConversationTester() tester = RedisConversationTester()
markdown_results = tester.run_all_tests() markdown_results = tester.run_all_tests()
@ -333,8 +296,6 @@ def main():
logger.info("Test results have been saved to redis_test_results.md") logger.info("Test results have been saved to redis_test_results.md")
except Exception as e: except Exception as e:
logger.error(f"Failed to save test results: {e}") logger.error(f"Failed to save test results: {e}")
# Also print results to console
print(markdown_results) print(markdown_results)

Loading…
Cancel
Save