Refactor Pulsar and Redis error handling, improve SQLite connection management, and enhance conversation class functionality

pull/866/head
harshalmore31 4 weeks ago
parent 5a054b0781
commit 3fa6ba22cd

@@ -109,7 +109,7 @@ class PulsarConversation(BaseCommunication):
logger.debug(
f"Connecting to Pulsar broker at {pulsar_host}"
)
-self.client = self.pulsar.Client(pulsar_host)
+self.client = pulsar.Client(pulsar_host)
logger.debug(f"Creating producer for topic: {self.topic}")
self.producer = self.client.create_producer(self.topic)
@@ -122,7 +122,7 @@ class PulsarConversation(BaseCommunication):
)
logger.info("Successfully connected to Pulsar broker")
-except self.pulsar.ConnectError as e:
+except pulsar.ConnectError as e:
error_msg = f"Failed to connect to Pulsar broker at {pulsar_host}: {str(e)}"
logger.error(error_msg)
raise PulsarConnectionError(error_msg)
@@ -211,7 +211,7 @@ class PulsarConversation(BaseCommunication):
)
return message["id"]
-except self.pulsar.ConnectError as e:
+except pulsar.ConnectError as e:
error_msg = f"Failed to send message to Pulsar: Connection error: {str(e)}"
logger.error(error_msg)
raise PulsarConnectionError(error_msg)
@@ -248,7 +248,7 @@ class PulsarConversation(BaseCommunication):
msg = self.consumer.receive(timeout_millis=1000)
messages.append(json.loads(msg.data()))
self.consumer.acknowledge(msg)
-except self.pulsar.Timeout:
+except pulsar.Timeout:
break # No more messages available
except json.JSONDecodeError as e:
logger.error(f"Failed to decode message: {e}")
@@ -263,7 +263,7 @@ class PulsarConversation(BaseCommunication):
return messages
-except self.pulsar.ConnectError as e:
+except pulsar.ConnectError as e:
error_msg = f"Failed to receive messages from Pulsar: Connection error: {str(e)}"
logger.error(error_msg)
raise PulsarConnectionError(error_msg)
@@ -400,7 +400,7 @@ class PulsarConversation(BaseCommunication):
f"Successfully cleared conversation. New ID: {self.conversation_id}"
)
-except self.pulsar.ConnectError as e:
+except pulsar.ConnectError as e:
error_msg = f"Failed to clear conversation: Connection error: {str(e)}"
logger.error(error_msg)
raise PulsarConnectionError(error_msg)
@@ -696,7 +696,7 @@ class PulsarConversation(BaseCommunication):
msg = self.consumer.receive(timeout_millis=1000)
self.consumer.acknowledge(msg)
health["consumer_active"] = True
-except self.pulsar.Timeout:
+except pulsar.Timeout:
pass
logger.info(f"Health check results: {health}")

@@ -435,7 +435,7 @@ class RedisConversation(BaseStructure):
if custom_rules_prompt is not None:
self.add(user or "User", custom_rules_prompt)
-except Exception as e:
+except RedisError as e:
logger.error(
f"Failed to initialize conversation: {str(e)}"
)
@@ -500,10 +500,10 @@ class RedisConversation(BaseStructure):
)
return
except (
-    redis.ConnectionError,
-    redis.TimeoutError,
-    redis.AuthenticationError,
-    redis.BusyLoadingError,
+    ConnectionError,
+    TimeoutError,
+    AuthenticationError,
+    BusyLoadingError,
) as e:
if attempt < retry_attempts - 1:
logger.warning(
@@ -560,7 +560,7 @@ class RedisConversation(BaseStructure):
"""
try:
return operation_func(*args, **kwargs)
-except redis.RedisError as e:
+except RedisError as e:
error_msg = (
f"Redis operation '{operation_name}' failed: {str(e)}"
)
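Note: the bare exception names in the new `except` clauses assume module-level imports from redis-py. The exact import statement is not shown in the diff; a sketch of what it would be (all of these classes exist in redis.exceptions):

from redis.exceptions import (
    AuthenticationError,
    BusyLoadingError,
    ConnectionError,
    RedisError,
    TimeoutError,
)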

@@ -1,3 +1,4 @@
+import sqlite3
import json
import datetime
from typing import List, Optional, Union, Dict, Any
@@ -64,16 +65,6 @@ class SQLiteConversation(BaseCommunication):
connection_timeout: float = 5.0,
**kwargs,
):
-# Lazy load sqlite3
-try:
-    import sqlite3
-    self.sqlite3 = sqlite3
-    self.sqlite3_available = True
-except ImportError as e:
-    raise ImportError(
-        f"SQLite3 is not available: {e}"
-    )
super().__init__(
system_prompt=system_prompt,
time_enabled=time_enabled,
@@ -171,13 +162,13 @@ class SQLiteConversation(BaseCommunication):
conn = None
for attempt in range(self.max_retries):
try:
-conn = self.sqlite3.connect(
+conn = sqlite3.connect(
str(self.db_path), timeout=self.connection_timeout
)
-conn.row_factory = self.sqlite3.Row
+conn.row_factory = sqlite3.Row
yield conn
break
-except self.sqlite3.Error as e:
+except sqlite3.Error as e:
if attempt == self.max_retries - 1:
raise
if self.enable_logging:
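Note: with the lazy loader gone, the connection helper uses the module-level sqlite3 import directly. A self-contained sketch of the same retry pattern; the helper name and backoff delay are hypothetical, not taken from the diff:

import sqlite3
import time
from contextlib import contextmanager

@contextmanager
def get_connection(db_path, connection_timeout=5.0, max_retries=3):
    """Yield a SQLite connection, retrying transient errors."""
    conn = None
    try:
        for attempt in range(max_retries):
            try:
                conn = sqlite3.connect(str(db_path), timeout=connection_timeout)
                conn.row_factory = sqlite3.Row  # enables name-based column access
                yield conn
                break
            except sqlite3.Error:
                if attempt == max_retries - 1:
                    raise
                time.sleep(0.1 * (attempt + 1))  # hypothetical linear backoff
    finally:
        if conn is not None:
            conn.close()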

@@ -55,7 +55,6 @@ def get_conversation_dir():
# Define available providers
providers = Literal["mem0", "in-memory", "supabase", "redis", "sqlite", "duckdb", "pulsar"]
def _create_backend_conversation(backend: str, **kwargs):
"""
Create a backend conversation instance based on the specified backend type.
@@ -153,12 +152,6 @@ class Conversation(BaseStructure):
save_as_json_bool (bool): Flag to save conversation history as JSON.
token_count (bool): Flag to enable token counting for messages.
conversation_history (list): List to store the history of messages.
-cache_enabled (bool): Flag to enable prompt caching.
-cache_stats (dict): Statistics about cache usage.
-cache_lock (threading.Lock): Lock for thread-safe cache operations.
-conversations_dir (str): Directory to store cached conversations.
-backend (str): The storage backend to use.
-backend_instance: The actual backend instance (for non-memory backends).
"""
def __init__(
@@ -179,6 +172,7 @@ class Conversation(BaseStructure):
save_as_yaml: bool = False,
save_as_json_bool: bool = False,
token_count: bool = True,
+message_id_on: bool = False,
provider: providers = "in-memory",
backend: Optional[str] = None,
# Backend-specific parameters
@@ -195,6 +189,8 @@ class Conversation(BaseStructure):
persist_redis: bool = True,
auto_persist: bool = True,
redis_data_dir: Optional[str] = None,
+conversations_dir: Optional[str] = None,
+*args,
**kwargs,
):
@@ -243,15 +239,7 @@ class Conversation(BaseStructure):
self.save_as_yaml = save_as_yaml
self.save_as_json_bool = save_as_json_bool
self.token_count = token_count
-self.cache_enabled = cache_enabled
self.provider = provider # Keep for backwards compatibility
-self.cache_stats = {
-    "hits": 0,
-    "misses": 0,
-    "cached_tokens": 0,
-    "total_tokens": 0,
-}
-self.cache_lock = threading.Lock()
+self.conversations_dir = conversations_dir
# Initialize backend if using persistent storage
@@ -302,11 +290,9 @@ class Conversation(BaseStructure):
"rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt,
"user": self.user,
"auto_save": self.auto_save,
"save_as_yaml": self.save_as_yaml,
"save_as_json_bool": self.save_as_json_bool,
"token_count": self.token_count,
"cache_enabled": self.cache_enabled,
}
# Add backend-specific parameters
@@ -665,7 +651,12 @@ class Conversation(BaseStructure):
print("\n" + "=" * 50)
def export_conversation(self, filename: str, *args, **kwargs):
"""Export the conversation history to a file."""
"""Export the conversation history to a file.
Args:
filename (str): Filename to export to.
"""
if self.backend_instance:
try:
return self.backend_instance.export_conversation(filename, *args, **kwargs)
@@ -680,7 +671,11 @@ class Conversation(BaseStructure):
self.save_as_json(filename)
def import_conversation(self, filename: str):
"""Import a conversation history from a file."""
"""Import a conversation history from a file.
Args:
filename (str): Filename to import from.
"""
if self.backend_instance:
try:
return self.backend_instance.import_conversation(filename)
@@ -694,22 +689,34 @@
"""Count the number of messages by role.
Returns:
-    dict: A dictionary mapping roles to message counts.
+    dict: A dictionary with counts of messages by role.
"""
+# Check backend instance first
if self.backend_instance:
try:
return self.backend_instance.count_messages_by_role()
except Exception as e:
logger.error(f"Backend count_messages_by_role failed: {e}")
-# Fallback to in-memory count
+# Fallback to local implementation below
pass
-role_counts = {}
+# Initialize counts with expected roles
+counts = {
+    "system": 0,
+    "user": 0,
+    "assistant": 0,
+    "function": 0,
+}
+# Count messages by role
for message in self.conversation_history:
role = message["role"]
-    role_counts[role] = role_counts.get(role, 0) + 1
-return role_counts
+    if role in counts:
+        counts[role] += 1
+    else:
+        # Handle unexpected roles dynamically
+        counts[role] = counts.get(role, 0) + 1
+return counts
def return_history_as_string(self):
"""Return the conversation history as a string.
@@ -753,17 +760,55 @@
Args:
filename (str): Filename to save the conversation history.
"""
+# Check backend instance first
if self.backend_instance:
try:
return self.backend_instance.save_as_json(filename)
except Exception as e:
logger.error(f"Backend save_as_json failed: {e}")
-# Fallback to in-memory save
pass
-if filename is not None:
-    with open(filename, "w") as f:
-        json.dump(self.conversation_history, f)
+# Fallback to local save implementation below
+# Don't save if saving is disabled
+if not self.save_enabled:
+    return
+save_path = filename or self.save_filepath
+if save_path is not None:
+    try:
+        # Prepare metadata
+        metadata = {
+            "id": self.id,
+            "name": self.name,
+            "created_at": datetime.datetime.now().isoformat(),
+            "system_prompt": self.system_prompt,
+            "rules": self.rules,
+            "custom_rules_prompt": self.custom_rules_prompt,
+        }
+        # Prepare save data
+        save_data = {
+            "metadata": metadata,
+            "history": self.conversation_history,
+        }
+        # Create directory if it doesn't exist
+        os.makedirs(
+            os.path.dirname(save_path),
+            mode=0o755,
+            exist_ok=True,
+        )
+        # Write directly to file
+        with open(save_path, "w") as f:
+            json.dump(save_data, f, indent=2)
+        # Only log explicit saves, not autosaves
+        if not self.autosave:
+            logger.info(
+                f"Successfully saved conversation to {save_path}"
+            )
+    except Exception as e:
+        logger.error(f"Failed to save conversation: {str(e)}")
def load_from_json(self, filename: str):
"""Load the conversation history from a JSON file.
@@ -771,17 +816,32 @@
Args:
filename (str): Filename to load from.
"""
-if self.backend_instance:
+if filename is not None and os.path.exists(filename):
try:
-    return self.backend_instance.load_from_json(filename)
+    with open(filename) as f:
+        data = json.load(f)
+    # Load metadata
+    metadata = data.get("metadata", {})
+    self.id = metadata.get("id", self.id)
+    self.name = metadata.get("name", self.name)
+    self.system_prompt = metadata.get(
+        "system_prompt", self.system_prompt
+    )
+    self.rules = metadata.get("rules", self.rules)
+    self.custom_rules_prompt = metadata.get(
+        "custom_rules_prompt", self.custom_rules_prompt
+    )
+    # Load conversation history
+    self.conversation_history = data.get("history", [])
+    logger.info(
+        f"Successfully loaded conversation from {filename}"
+    )
except Exception as e:
logger.error(f"Backend load_from_json failed: {e}")
# Fallback to in-memory load
pass
if filename is not None:
with open(filename) as f:
self.conversation_history = json.load(f)
logger.error(f"Failed to load conversation: {str(e)}")
raise
def search_keyword_in_conversation(self, keyword: str):
"""Search for a keyword in the conversation history.
@@ -865,7 +925,7 @@
"""Convert the conversation history to a dictionary.
Returns:
-    dict: The conversation history as a dictionary.
+    list: The conversation history as a list of dictionaries.
"""
if self.backend_instance:
try:
@@ -1154,12 +1214,52 @@
return []
conversations = []
-for file in os.listdir(conversations_dir):
-    if file.endswith(".json"):
-        conversations.append(
-            file[:-5]
-        )  # Remove .json extension
-return conversations
+seen_ids = (
+    set()
+)  # Track seen conversation IDs to avoid duplicates
+for filename in os.listdir(conv_dir):
+    if filename.endswith(".json"):
+        try:
+            filepath = os.path.join(conv_dir, filename)
+            with open(filepath) as f:
+                data = json.load(f)
+                metadata = data.get("metadata", {})
+                conv_id = metadata.get("id")
+                name = metadata.get("name")
+                created_at = metadata.get("created_at")
+                # Skip if we've already seen this ID or if required fields are missing
+                if (
+                    not all([conv_id, name, created_at])
+                    or conv_id in seen_ids
+                ):
+                    continue
+                seen_ids.add(conv_id)
+                conversations.append(
+                    {
+                        "id": conv_id,
+                        "name": name,
+                        "created_at": created_at,
+                        "filepath": filepath,
+                    }
+                )
+        except json.JSONDecodeError:
+            logger.warning(
+                f"Skipping corrupted conversation file: {filename}"
+            )
+            continue
+        except Exception as e:
+            logger.error(
+                f"Failed to read conversation {filename}: {str(e)}"
+            )
+            continue
+# Sort by creation date, newest first
+return sorted(
+    conversations, key=lambda x: x["created_at"], reverse=True
+)
def clear_memory(self):
"""Clear the memory of the conversation."""

@@ -85,7 +85,6 @@ class RedisConversationTester:
def setup(self):
"""Initialize Redis server and conversation for testing."""
try:
-# Try first with external Redis (if available)
logger.info("Trying to connect to external Redis server...")
self.conversation = RedisConversation(
system_prompt="Test System Prompt",
@@ -99,7 +98,6 @@ class RedisConversationTester:
except Exception as external_error:
logger.info(f"External Redis connection failed: {external_error}")
logger.info("Trying to start embedded Redis server...")
try:
# Fallback to embedded Redis
self.conversation = RedisConversation(
@@ -119,16 +117,8 @@ class RedisConversationTester:
def cleanup(self):
"""Cleanup resources after tests."""
if self.conversation:
-try:
-    # Check if we have an embedded server to stop
-    if hasattr(self.conversation, 'embedded_server') and self.conversation.embedded_server is not None:
-        self.conversation.embedded_server.stop()
-    # Close Redis client if it exists
-    if hasattr(self.conversation, 'redis_client') and self.conversation.redis_client:
-        self.conversation.redis_client.close()
-except Exception as e:
-    logger.warning(f"Error during cleanup: {str(e)}")
if self.redis_server:
self.redis_server.stop()
def test_initialization(self):
"""Test basic initialization."""
@@ -151,8 +141,6 @@ class RedisConversationTester:
json_content = {"key": "value", "nested": {"data": 123}}
self.conversation.add("system", json_content)
last_message = self.conversation.get_final_message_content()
-# Parse the JSON string back to dict for comparison
if isinstance(last_message, str):
try:
parsed_content = json.loads(last_message)
@@ -173,29 +161,20 @@ class RedisConversationTester:
initial_count = len(
self.conversation.return_messages_as_list()
)
-if initial_count > 0:
-    self.conversation.delete(0)
-    new_count = len(self.conversation.return_messages_as_list())
-    assert (
-        new_count == initial_count - 1
-    ), "Failed to delete message"
+self.conversation.delete(0)
+new_count = len(self.conversation.return_messages_as_list())
+assert (
+    new_count == initial_count - 1
+), "Failed to delete message"
def test_update(self):
"""Test message update."""
-# Add initial message
self.conversation.add("user", "original message")
-# Get all messages to find the last message ID
-all_messages = self.conversation.return_messages_as_list()
-if len(all_messages) > 0:
-    # Update the last message (index 0 in this case means the first message)
-    # Note: This test may need adjustment based on how Redis stores messages
-    self.conversation.update(0, "user", "updated message")
-# Get the message directly using query
+self.conversation.update(0, "user", "updated message")
updated_message = self.conversation.query(0)
-# Since Redis might store content differently, just check that update didn't crash
assert True, "Update method executed successfully"
def test_clear(self):
@@ -207,28 +186,14 @@ class RedisConversationTester:
def test_export_import(self):
"""Test export and import functionality."""
-try:
-    self.conversation.add("user", "export test")
-    self.conversation.export_conversation("test_export.txt")
-    # Clear conversation
-    self.conversation.clear()
-    # Import back
-    self.conversation.import_conversation("test_export.txt")
-    messages = self.conversation.return_messages_as_list()
-    assert (
-        len(messages) > 0
-    ), "Failed to export/import conversation"
-    # Cleanup test file
-    import os
-    if os.path.exists("test_export.txt"):
-        os.remove("test_export.txt")
-except Exception as e:
-    logger.warning(f"Export/import test failed: {e}")
-    # Don't fail the test entirely, just log the warning
-    assert True, "Export/import test completed with warnings"
+self.conversation.add("user", "export test")
+self.conversation.export_conversation("test_export.txt")
+self.conversation.clear()
+self.conversation.import_conversation("test_export.txt")
+messages = self.conversation.return_messages_as_list()
+assert (
+    len(messages) > 0
+), "Failed to export/import conversation"
def test_json_operations(self):
"""Test JSON operations."""
@@ -249,7 +214,6 @@ class RedisConversationTester:
self.conversation.add("user", "token test message")
time.sleep(1) # Wait for async token counting
messages = self.conversation.to_dict()
-# Token counting may not be implemented in Redis version, so just check it doesn't crash
assert isinstance(messages, list), "Token counting test completed"
def test_cache_operations(self):
@@ -322,7 +286,6 @@
def main():
"""Main function to run tests and save results."""
logger.info(f"Starting Redis tests. REDIS_AVAILABLE: {REDIS_AVAILABLE}")
tester = RedisConversationTester()
markdown_results = tester.run_all_tests()
@@ -333,8 +296,6 @@
logger.info("Test results have been saved to redis_test_results.md")
except Exception as e:
logger.error(f"Failed to save test results: {e}")
-# Also print results to console
print(markdown_results)
