You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1660 lines
62 KiB
1660 lines
62 KiB
import datetime
|
|
import json
|
|
import logging
|
|
import threading
|
|
import uuid
|
|
from typing import Any, Callable, Dict, List, Optional, Union
|
|
|
|
import yaml
|
|
|
|
from swarms.communication.base_communication import (
|
|
BaseCommunication,
|
|
Message,
|
|
MessageType,
|
|
)
|
|
|
|
# Try to import loguru logger, fallback to standard logging
|
|
try:
|
|
from loguru import logger
|
|
|
|
LOGURU_AVAILABLE = True
|
|
except ImportError:
|
|
LOGURU_AVAILABLE = False
|
|
logger = None
|
|
|
|
|
|
# Custom Exceptions for Supabase Communication
|
|
class SupabaseConnectionError(Exception):
    """Raised when the Supabase client cannot be created or connected."""
|
|
|
|
|
|
class SupabaseOperationError(Exception):
    """Raised when a Supabase table operation (insert, select, ...) fails."""
|
|
|
|
|
|
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that renders date/time objects as ISO-8601 strings.

    Handles ``datetime.datetime`` as well as plain ``datetime.date`` and
    ``datetime.time`` values; every other type is delegated to the base
    encoder, which raises ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*."""
        # datetime.datetime is a subclass of datetime.date, so one check
        # covers both; datetime.time is unrelated and listed explicitly.
        if isinstance(obj, (datetime.date, datetime.time)):
            return obj.isoformat()
        return super().default(obj)
|
|
|
|
|
|
class SupabaseConversation(BaseCommunication):
|
|
"""
|
|
A Supabase-backed implementation of the BaseCommunication class for managing
|
|
conversation history using a Supabase (PostgreSQL) database.
|
|
|
|
Prerequisites:
|
|
- supabase-py library: pip install supabase
|
|
- Valid Supabase project URL and API key
|
|
- Network access to your Supabase instance
|
|
|
|
Attributes:
|
|
supabase_url (str): URL of the Supabase project.
|
|
supabase_key (str): Anon or service key for the Supabase project.
|
|
client (supabase.Client): The Supabase client instance.
|
|
table_name (str): Name of the table in Supabase to store conversations.
|
|
current_conversation_id (Optional[str]): ID of the currently active conversation.
|
|
tokenizer (Any): Tokenizer for counting tokens in messages.
|
|
context_length (int): Maximum number of tokens for context window.
|
|
time_enabled (bool): Flag to prepend timestamps to messages.
|
|
enable_logging (bool): Flag to enable logging.
|
|
logger (logging.Logger | loguru.Logger): Logger instance.
|
|
"""
|
|
|
|
def __init__(
    self,
    supabase_url: str,
    supabase_key: str,
    system_prompt: Optional[str] = None,
    time_enabled: bool = False,
    autosave: bool = False,  # Kept for interface parity; less relevant for a DB-backed store
    save_filepath: Optional[str] = None,  # Used for export/import
    tokenizer: Any = None,
    context_length: int = 8192,
    rules: Optional[str] = None,
    custom_rules_prompt: Optional[str] = None,
    user: str = "User:",
    save_as_yaml: bool = True,  # Default export format
    save_as_json_bool: bool = False,  # Alternative export format
    token_count: bool = True,
    cache_enabled: bool = True,  # Currently for token counting
    table_name: str = "conversations",
    enable_timestamps: bool = True,  # DB schema handles this with DEFAULT NOW()
    enable_logging: bool = True,
    use_loguru: bool = True,
    max_retries: int = 3,  # Stored only; no retry logic is implemented here
    *args,
    **kwargs,
):
    """Create the Supabase client, verify the table and start a conversation.

    Args:
        supabase_url: URL of the Supabase project.
        supabase_key: Anon or service key for the project.
        system_prompt: Optional system message added to the new conversation.
        time_enabled: Prefix timestamps in ``get_str`` output.
        tokenizer: Object exposing ``count_tokens(str)``; optional.
        rules: Optional rules text, stored as a system message.
        custom_rules_prompt: Optional prompt stored as a message from ``user``.
        table_name: Table that stores the messages.
        enable_logging: Emit log records from this instance.
        use_loguru: Prefer loguru when it is installed.

    Raises:
        ImportError: ``supabase`` is missing and could not be installed.
        SupabaseConnectionError: The client could not be created.
        SupabaseOperationError: The table is missing or inaccessible.
    """
    # Lazy-load supabase so importing this module does not require it.
    try:
        from supabase import Client, create_client
    except ImportError:
        # NOTE(review): installing packages from a constructor is a
        # surprising side effect; kept for backward compatibility.
        print(
            "📦 Supabase not found. Installing automatically..."
        )
        try:
            import subprocess
            import sys

            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "supabase"]
            )
            print("✅ Supabase installed successfully!")

            from supabase import Client, create_client

            print("✅ Supabase loaded successfully!")
        except Exception as e:
            self.supabase_available = False
            if logger:
                logger.error(
                    f"Failed to auto-install Supabase. Please install manually with 'pip install supabase': {e}"
                )
            # Chain the cause so the pip/import failure stays visible.
            raise ImportError(
                f"Failed to auto-install Supabase. Please install manually with 'pip install supabase': {e}"
            ) from e

    self.supabase_client = Client
    self.create_client = create_client
    self.supabase_available = True

    # Store initialization parameters - BaseCommunication.__init__ is just pass
    self.system_prompt = system_prompt
    self.time_enabled = time_enabled
    self.autosave = autosave
    self.save_filepath = save_filepath
    self.tokenizer = tokenizer
    self.context_length = context_length
    self.rules = rules
    self.custom_rules_prompt = custom_rules_prompt
    self.user = user
    self.save_as_yaml_on_export = save_as_yaml
    self.save_as_json_on_export = save_as_json_bool
    self.calculate_token_count = token_count
    self.cache_enabled = cache_enabled

    self.supabase_url = supabase_url
    self.supabase_key = supabase_key
    self.table_name = table_name
    self.enable_timestamps = enable_timestamps  # DB handles actual timestamping
    self.enable_logging = enable_logging
    self.use_loguru = use_loguru and LOGURU_AVAILABLE
    self.max_retries = max_retries

    # Setup logging
    if self.enable_logging and self.use_loguru and logger:
        self.logger = logger
    else:
        # The stdlib logger is shared by every instance in the process,
        # so handlers are only added when an equivalent one is missing.
        self.logger = logging.getLogger(__name__)
        handlers = self.logger.handlers
        if self.enable_logging:
            has_real_handler = any(
                not isinstance(h, logging.NullHandler) for h in handlers
            )
            if not has_real_handler:
                handler = logging.StreamHandler()
                handler.setFormatter(
                    logging.Formatter(
                        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
                    )
                )
                self.logger.addHandler(handler)
                self.logger.setLevel(logging.INFO)
        elif not any(
            isinstance(h, logging.NullHandler) for h in handlers
        ):
            self.logger.addHandler(logging.NullHandler())

    self.current_conversation_id: Optional[str] = None
    # For thread-safe operations if any (e.g. token calculation)
    self._lock = threading.Lock()

    try:
        self.client = self.create_client(supabase_url, supabase_key)
        if self.enable_logging:
            self.logger.info(
                f"Successfully initialized Supabase client for URL: {supabase_url}"
            )
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Failed to initialize Supabase client: {e}"
            )
        raise SupabaseConnectionError(
            f"Failed to connect to Supabase: {e}"
        ) from e

    self._init_db()  # Verifies table existence
    self.start_new_conversation()  # Initializes a conversation ID

    # Add initial prompts if provided
    if self.system_prompt:
        self.add(
            role="system",
            content=self.system_prompt,
            message_type=MessageType.SYSTEM,
        )
    if self.rules:
        self.add(
            role="system",
            content=self.rules,
            message_type=MessageType.SYSTEM,
        )
    if self.custom_rules_prompt:
        self.add(
            role=self.user,
            content=self.custom_rules_prompt,
            message_type=MessageType.USER,
        )
|
|
|
|
def _init_db(self):
    """Ensure the conversation table exists.

    First tries to create the table through a user-defined ``exec_sql``
    RPC function. If that fails, probes the table with a one-row select
    and either accepts it, tries ``_create_table_fallback`` or reports
    how to create it via ``_handle_missing_table``.

    Raises:
        SupabaseOperationError: The table is missing or cannot be accessed.
    """
    # NOTE(review): table_name is interpolated into SQL unescaped; it must
    # come from trusted configuration, never from end-user input.
    create_table_sql = f"""
    CREATE TABLE IF NOT EXISTS {self.table_name} (
        id BIGSERIAL PRIMARY KEY,
        conversation_id TEXT NOT NULL,
        role TEXT NOT NULL,
        content TEXT NOT NULL,
        timestamp TIMESTAMPTZ DEFAULT NOW(),
        message_type TEXT,
        metadata JSONB,
        token_count INTEGER,
        created_at TIMESTAMPTZ DEFAULT NOW()
    );
    """
    try:
        try:
            # Requires a custom ``exec_sql`` stored procedure in Supabase.
            self.client.rpc(
                "exec_sql", {"sql": create_table_sql}
            ).execute()
            if self.enable_logging:
                self.logger.info(
                    f"Successfully created or verified table '{self.table_name}' using RPC."
                )
        except Exception as rpc_error:
            if self.enable_logging:
                self.logger.debug(
                    f"RPC table creation failed (expected if no custom function): {rpc_error}"
                )

            # Fallback: probe the table directly.
            try:
                response = (
                    self.client.table(self.table_name)
                    .select("id")
                    .limit(1)
                    .execute()
                )
                # Responses may not carry an ``error`` attribute at all
                # (failures are raised instead), so read it defensively.
                error = getattr(response, "error", None)
                if error and "does not exist" in str(error).lower():
                    self._create_table_fallback()
                elif error:
                    raise SupabaseOperationError(
                        f"Error accessing table: {getattr(error, 'message', error)}"
                    )
                elif self.enable_logging:
                    self.logger.info(
                        f"Successfully verified existing table '{self.table_name}'."
                    )
            except SupabaseOperationError:
                # Already a precise, logged error; do not re-wrap it.
                raise
            except Exception as table_check_error:
                details = str(table_check_error).lower()
                if "does not exist" in details or "relation" in details:
                    self._handle_missing_table()
                else:
                    raise SupabaseOperationError(
                        f"Failed to access or create table: {table_check_error}"
                    ) from table_check_error
    except SupabaseOperationError:
        raise
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Database initialization failed: {e}"
            )
        raise SupabaseOperationError(
            f"Failed to initialize database: {e}"
        ) from e
|
|
|
|
def _create_table_fallback(self):
    """Try once more to create the table, then give up with instructions.

    Calls the ``exec_sql`` RPC through the PostgREST client when it is
    exposed. On success the method returns; on any failure, or when the
    PostgREST client is unavailable, ``_handle_missing_table`` is called,
    which raises with manual setup instructions.
    """
    try:
        # NOTE(review): table_name is interpolated unescaped; it must come
        # from trusted configuration.
        admin_sql = f"""
        CREATE TABLE IF NOT EXISTS {self.table_name} (
            id BIGSERIAL PRIMARY KEY,
            conversation_id TEXT NOT NULL,
            role TEXT NOT NULL,
            content TEXT NOT NULL,
            timestamp TIMESTAMPTZ DEFAULT NOW(),
            message_type TEXT,
            metadata JSONB,
            token_count INTEGER,
            created_at TIMESTAMPTZ DEFAULT NOW()
        );
        CREATE INDEX IF NOT EXISTS idx_{self.table_name}_conversation_id
        ON {self.table_name} (conversation_id);
        """

        # This might not work with all Supabase configurations, but it is
        # attempted anyway.
        postgrest = getattr(self.client, "postgrest", None)
        if postgrest is not None and hasattr(postgrest, "rpc"):
            # The parameter is named "sql" to match the exec_sql(sql TEXT)
            # function this class tells users to create.
            postgrest.rpc("exec_sql", {"sql": admin_sql}).execute()
            if self.enable_logging:
                self.logger.info(
                    f"Successfully created table '{self.table_name}' using admin API."
                )
            return
    except Exception as e:
        if self.enable_logging:
            self.logger.debug(
                f"Admin API table creation failed: {e}"
            )

    # If all else fails, call the missing table handler
    self._handle_missing_table()
|
|
|
|
def _handle_missing_table(self):
    """Report that the table is missing and cannot be auto-created.

    Builds a message containing the SQL needed to create the table by
    hand, plus the definition of the optional ``exec_sql`` RPC function,
    logs it when logging is enabled and always raises.

    Raises:
        SupabaseOperationError: Always.
    """
    table = self.table_name
    table_creation_sql = f"""
-- Run this SQL in your Supabase SQL Editor to create the required table:

CREATE TABLE IF NOT EXISTS {table} (
id BIGSERIAL PRIMARY KEY,
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp TIMESTAMPTZ DEFAULT NOW(),
message_type TEXT,
metadata JSONB,
token_count INTEGER,
created_at TIMESTAMPTZ DEFAULT NOW()
);

-- Create index for better query performance:
CREATE INDEX IF NOT EXISTS idx_{table}_conversation_id
ON {table} (conversation_id);

-- Optional: Enable Row Level Security (RLS) for production:
ALTER TABLE {table} ENABLE ROW LEVEL SECURITY;

-- Optional: Create RLS policy (customize according to your needs):
CREATE POLICY "Users can manage their own conversations" ON {table}
FOR ALL USING (true); -- Adjust this policy based on your security requirements
"""

    # Definition of the helper function users may install to enable
    # automatic table creation.
    rpc_function_sql = "".join(
        [
            "CREATE OR REPLACE FUNCTION exec_sql(sql TEXT)\n",
            "RETURNS TEXT AS $$\n",
            "BEGIN\n",
            " EXECUTE sql;\n",
            " RETURN 'SUCCESS';\n",
            "END;\n",
            "$$ LANGUAGE plpgsql SECURITY DEFINER;\n\n",
        ]
    )

    error_msg = "".join(
        [
            f"Table '{table}' does not exist in your Supabase database and cannot be created automatically. ",
            f"Please create it manually by running the following SQL in your Supabase SQL Editor:\n\n{table_creation_sql}\n\n",
            "Alternatively, you can create a custom RPC function in Supabase to enable automatic table creation. ",
            "Visit your Supabase dashboard > SQL Editor and create this function:\n\n",
            rpc_function_sql,
            "After creating either the table or the RPC function, retry initializing the SupabaseConversation.",
        ]
    )

    if self.enable_logging:
        self.logger.error(error_msg)
    raise SupabaseOperationError(error_msg)
|
|
|
|
def _handle_api_response(
    self, response, operation_name: str = "Supabase operation"
):
    """Extract the payload from a Supabase API response.

    Returns ``response.data`` when the attribute exists (it may be None,
    a list or a dict); otherwise returns *response* unchanged. Any error
    raised while inspecting the response is logged and re-raised as
    ``SupabaseOperationError`` labelled with *operation_name*.
    """
    try:
        if not hasattr(response, "data"):
            # Older response structures or direct data.
            return response
        return response.data
    except Exception as e:
        failure = f"{operation_name} failed: {e}"
        if self.enable_logging:
            self.logger.error(failure)
        raise SupabaseOperationError(failure)
|
|
|
|
def _serialize_content(
    self, content: Union[str, dict, list]
) -> str:
    """Return *content* as text: JSON for dicts and lists, ``str()`` otherwise."""
    is_structured = isinstance(content, (dict, list))
    if not is_structured:
        return str(content)
    return json.dumps(content, cls=DateTimeEncoder)
|
|
|
|
def _deserialize_content(
    self, content_str: str
) -> Union[str, dict, list]:
    """Turn stored content back into a dict/list when it was JSON-encoded.

    Only JSON objects and arrays are decoded, mirroring
    ``_serialize_content``, which JSON-encodes dicts and lists only.
    Plain text that happens to be a valid JSON scalar (``"null"``,
    ``"123"``, ``"true"``) is returned unchanged, so a text message never
    silently becomes ``None``, a number or a bool.

    Args:
        content_str: Raw value of the ``content`` column.

    Returns:
        The decoded dict or list, or *content_str* itself.
    """
    if not content_str:
        return content_str

    try:
        decoded = json.loads(content_str)
    except (json.JSONDecodeError, TypeError):
        # Not valid JSON (or not a string at all): return as-is.
        return content_str

    if isinstance(decoded, (dict, list)):
        return decoded
    return content_str
|
|
|
|
def _serialize_metadata(
    self, metadata: Optional[Dict]
) -> Optional[str]:
    """Encode *metadata* as a JSON string.

    Values that are not JSON-native are converted with ``str``. Returns
    None when *metadata* is None or when encoding fails; a failure is
    logged as a warning if logging is enabled.
    """
    if metadata is None:
        return None
    try:
        encoded = json.dumps(metadata, ensure_ascii=False, default=str)
    except (TypeError, ValueError) as e:
        if self.enable_logging:
            self.logger.warning(
                f"Failed to serialize metadata: {e}"
            )
        return None
    return encoded
|
|
|
|
def _deserialize_metadata(
    self, metadata_str: Optional[Union[str, Dict]]
) -> Optional[Dict]:
    """Decode stored metadata into a dict.

    Args:
        metadata_str: Raw value of the ``metadata`` column. Usually a
            JSON string; a dict is accepted and returned unchanged.

    Returns:
        The decoded metadata, or None when the value is None or cannot
        be decoded (a warning is logged if logging is enabled).
    """
    if metadata_str is None:
        return None
    if isinstance(metadata_str, dict):
        # Already decoded; nothing to parse.
        return metadata_str
    try:
        return json.loads(metadata_str)
    except (json.JSONDecodeError, TypeError) as e:
        if self.enable_logging:
            # str() first: the value may not be sliceable (e.g. an int),
            # and the log call must not raise inside the handler.
            self.logger.warning(
                f"Failed to deserialize metadata: {str(metadata_str)[:50]}... Error: {e}"
            )
        return None
|
|
|
|
def _generate_conversation_id(self) -> str:
    """Build an ID of the form ``conv_<UTC timestamp>_<8 hex chars>``."""
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    # The first 8 characters of a UUID4 string are its first 8 hex digits.
    random_suffix = uuid.uuid4().hex[:8]
    return "conv_{}_{}".format(
        now_utc.strftime("%Y%m%d_%H%M%S_%f"), random_suffix
    )
|
|
|
|
def start_new_conversation(self) -> str:
    """Switch to a freshly generated conversation ID and return it.

    Nothing is written to the database here; the ID is stored on the
    instance and used by subsequent operations.
    """
    self.current_conversation_id = self._generate_conversation_id()
    # Respect enable_logging like the other methods do: the stdlib logger
    # is shared and may still propagate records when "disabled".
    if self.enable_logging:
        self.logger.info(
            f"Started new conversation with ID: {self.current_conversation_id}"
        )
    return self.current_conversation_id
|
|
|
|
def add(
|
|
self,
|
|
role: str,
|
|
content: Union[str, dict, list],
|
|
message_type: Optional[MessageType] = None,
|
|
metadata: Optional[Dict] = None,
|
|
token_count: Optional[int] = None,
|
|
) -> int:
|
|
"""Add a message to the current conversation history in Supabase."""
|
|
if self.current_conversation_id is None:
|
|
self.start_new_conversation()
|
|
|
|
serialized_content = self._serialize_content(content)
|
|
current_timestamp_iso = datetime.datetime.now(
|
|
datetime.timezone.utc
|
|
).isoformat()
|
|
|
|
message_data = {
|
|
"conversation_id": self.current_conversation_id,
|
|
"role": role,
|
|
"content": serialized_content,
|
|
"timestamp": current_timestamp_iso, # Supabase will use its default if not provided / column allows NULL
|
|
"message_type": (
|
|
message_type.value if message_type else None
|
|
),
|
|
"metadata": self._serialize_metadata(metadata),
|
|
# token_count handled below
|
|
}
|
|
|
|
# Calculate token_count if enabled and not provided
|
|
if (
|
|
self.calculate_token_count
|
|
and token_count is None
|
|
and self.tokenizer
|
|
):
|
|
try:
|
|
# For now, do this synchronously. For long content, consider async/threading.
|
|
message_data["token_count"] = (
|
|
self.tokenizer.count_tokens(str(content))
|
|
)
|
|
except Exception as e:
|
|
if self.enable_logging:
|
|
self.logger.warning(
|
|
f"Failed to count tokens for content: {e}"
|
|
)
|
|
elif token_count is not None:
|
|
message_data["token_count"] = token_count
|
|
|
|
# Filter out None values to let Supabase handle defaults or NULLs appropriately
|
|
message_to_insert = {
|
|
k: v for k, v in message_data.items() if v is not None
|
|
}
|
|
|
|
try:
|
|
response = (
|
|
self.client.table(self.table_name)
|
|
.insert(message_to_insert)
|
|
.execute()
|
|
)
|
|
data = self._handle_api_response(response, "add_message")
|
|
if data and len(data) > 0 and "id" in data[0]:
|
|
inserted_id = data[0]["id"]
|
|
if self.enable_logging:
|
|
self.logger.debug(
|
|
f"Added message with ID {inserted_id} to conversation {self.current_conversation_id}"
|
|
)
|
|
return inserted_id
|
|
if self.enable_logging:
|
|
self.logger.error(
|
|
f"Failed to retrieve ID for inserted message in conversation {self.current_conversation_id}"
|
|
)
|
|
raise SupabaseOperationError(
|
|
"Failed to retrieve ID for inserted message."
|
|
)
|
|
except Exception as e:
|
|
if self.enable_logging:
|
|
self.logger.error(
|
|
f"Error adding message to Supabase: {e}"
|
|
)
|
|
raise SupabaseOperationError(f"Error adding message: {e}")
|
|
|
|
def batch_add(self, messages: List[Message]) -> List[int]:
|
|
"""Add multiple messages to the current conversation history in Supabase."""
|
|
if self.current_conversation_id is None:
|
|
self.start_new_conversation()
|
|
|
|
messages_to_insert = []
|
|
for msg_obj in messages:
|
|
serialized_content = self._serialize_content(
|
|
msg_obj.content
|
|
)
|
|
current_timestamp_iso = (
|
|
msg_obj.timestamp
|
|
or datetime.datetime.now(
|
|
datetime.timezone.utc
|
|
).isoformat()
|
|
)
|
|
|
|
msg_data = {
|
|
"conversation_id": self.current_conversation_id,
|
|
"role": msg_obj.role,
|
|
"content": serialized_content,
|
|
"timestamp": current_timestamp_iso,
|
|
"message_type": (
|
|
msg_obj.message_type.value
|
|
if msg_obj.message_type
|
|
else None
|
|
),
|
|
"metadata": self._serialize_metadata(
|
|
msg_obj.metadata
|
|
),
|
|
}
|
|
|
|
# Token count
|
|
current_token_count = msg_obj.token_count
|
|
if (
|
|
self.calculate_token_count
|
|
and current_token_count is None
|
|
and self.tokenizer
|
|
):
|
|
try:
|
|
current_token_count = self.tokenizer.count_tokens(
|
|
str(msg_obj.content)
|
|
)
|
|
except Exception as e:
|
|
self.logger.warning(
|
|
f"Failed to count tokens for batch message: {e}"
|
|
)
|
|
if current_token_count is not None:
|
|
msg_data["token_count"] = current_token_count
|
|
|
|
messages_to_insert.append(
|
|
{k: v for k, v in msg_data.items() if v is not None}
|
|
)
|
|
|
|
if not messages_to_insert:
|
|
return []
|
|
|
|
try:
|
|
response = (
|
|
self.client.table(self.table_name)
|
|
.insert(messages_to_insert)
|
|
.execute()
|
|
)
|
|
data = self._handle_api_response(
|
|
response, "batch_add_messages"
|
|
)
|
|
inserted_ids = [
|
|
item["id"] for item in data if "id" in item
|
|
]
|
|
if len(inserted_ids) != len(messages_to_insert):
|
|
self.logger.warning(
|
|
"Mismatch in expected and inserted message counts during batch_add."
|
|
)
|
|
self.logger.debug(
|
|
f"Batch added {len(inserted_ids)} messages to conversation {self.current_conversation_id}"
|
|
)
|
|
return inserted_ids
|
|
except Exception as e:
|
|
self.logger.error(
|
|
f"Error batch adding messages to Supabase: {e}"
|
|
)
|
|
raise SupabaseOperationError(
|
|
f"Error batch adding messages: {e}"
|
|
)
|
|
|
|
def _format_row_to_dict(self, row: Dict) -> Dict:
    """Convert a raw table row into the standard message dict.

    ``content`` and ``metadata`` are deserialized. Keys whose value is
    None are omitted, except ``metadata``, ``token_count`` and
    ``message_type``, which are always present.
    """
    always_present = ("metadata", "token_count", "message_type")
    ordered_keys = (
        "id",
        "role",
        "content",
        "timestamp",
        "message_type",
        "metadata",
        "token_count",
        "conversation_id",
        "created_at",
    )

    message = {}
    for key in ordered_keys:
        if key == "content":
            value = self._deserialize_content(row.get("content", ""))
        elif key == "metadata":
            value = self._deserialize_metadata(row.get("metadata"))
        else:
            value = row.get(key)
        if value is None and key not in always_present:
            continue
        message[key] = value
    return message
|
|
|
|
def get_messages(
    self,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> List[Dict]:
    """Return the current conversation's messages, oldest first.

    Args:
        limit: Maximum number of rows to return, if given.
        offset: Number of rows to skip, if given.

    Returns:
        Formatted message dicts; empty when no conversation is active.

    Raises:
        SupabaseOperationError: The query failed.
    """
    if self.current_conversation_id is None:
        return []
    try:
        query = (
            self.client.table(self.table_name)
            .select("*")
            .eq("conversation_id", self.current_conversation_id)
            .order("timestamp", desc=False)
        )
        if limit is not None:
            query = query.limit(limit)
        if offset is not None:
            query = query.offset(offset)

        data = self._handle_api_response(
            query.execute(), "get_messages"
        )
        # ``data`` may be None when the API returns no rows.
        return [self._format_row_to_dict(row) for row in (data or [])]
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Error getting messages from Supabase: {e}"
            )
        raise SupabaseOperationError(
            f"Error getting messages: {e}"
        ) from e
|
|
|
|
def get_str(self) -> str:
    """Render the current conversation as ``role: content`` lines.

    Each line is prefixed with ``[timestamp] `` when ``time_enabled`` is
    set and the message has a timestamp. Dict and list content is
    pretty-printed as JSON. A message without a ``content`` key is
    rendered with empty content.
    """
    lines = []
    for msg in self.get_messages():
        timestamp = msg.get("timestamp")
        ts_prefix = (
            f"[{timestamp}] "
            if timestamp and self.time_enabled
            else ""
        )
        # .get(): formatted rows omit keys whose value is None, so
        # "content" is not guaranteed to be present.
        content_display = msg.get("content", "")
        if isinstance(content_display, (dict, list)):
            content_display = json.dumps(
                content_display, indent=2, cls=DateTimeEncoder
            )
        lines.append(f"{ts_prefix}{msg['role']}: {content_display}")
    return "\n".join(lines)
|
|
|
|
def display_conversation(self, detailed: bool = False):
    """Print the conversation as rendered by ``get_str``.

    ``detailed`` is accepted but currently has no effect.
    """
    rendered = self.get_str()
    print(rendered)
|
|
|
|
def delete(self, index: str):
    """Delete a message of the current conversation by its primary key.

    Does nothing (beyond a warning) when no conversation is active.

    Args:
        index: Message ``id``; a string or int convertible with ``int()``.

    Raises:
        SupabaseOperationError: *index* is not an integer or the delete
            request failed.
    """
    if self.current_conversation_id is None:
        if self.enable_logging:
            self.logger.warning(
                "Cannot delete message: No current conversation."
            )
        return

    try:
        message_id = int(index)
    except (TypeError, ValueError) as err:
        reason = f"Invalid message ID for delete: {index}. Must be an integer."
        if self.enable_logging:
            self.logger.error(reason)
        # Same exception type callers already receive for this case.
        raise SupabaseOperationError(
            f"Error deleting message ID {index}: {reason}"
        ) from err

    try:
        response = (
            self.client.table(self.table_name)
            .delete()
            .eq("id", message_id)
            .eq("conversation_id", self.current_conversation_id)
            .execute()
        )
        self._handle_api_response(
            response, f"delete_message (id: {message_id})"
        )
        if self.enable_logging:
            self.logger.info(
                f"Deleted message with ID {message_id} from conversation {self.current_conversation_id}"
            )
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Error deleting message ID {index} from Supabase: {e}"
            )
        raise SupabaseOperationError(
            f"Error deleting message ID {index}: {e}"
        ) from e
|
|
|
|
def update(
    self, index: str, role: str, content: Union[str, dict]
):
    """Update a message's role and content by ID.

    Thin wrapper that forwards to ``_update_flexible`` and returns its
    result.
    """
    outcome = self._update_flexible(
        index=index,
        role=role,
        content=content,
    )
    return outcome
|
|
|
|
def _update_flexible(
    self,
    index: Union[str, int],
    role: Optional[str] = None,
    content: Optional[Union[str, dict]] = None,
    metadata: Optional[Dict] = None,
) -> bool:
    """Update selected fields of a message in the current conversation.

    Only the arguments that are not None are written. When *content*
    changes and token counting is configured, ``token_count`` is
    recomputed as well.

    Args:
        index: Message ``id``; must be convertible with ``int()``.
        role: New role, if any.
        content: New content, if any.
        metadata: New metadata, if any. An empty dict is stored as
            ``"{}"``; this method cannot reset metadata to NULL.

    Returns:
        True when a row was updated. False when there is no active
        conversation, the ID is invalid, no fields were given, no row
        matched, or the request failed (failures are logged, not raised).
    """
    if self.current_conversation_id is None:
        if self.enable_logging:
            self.logger.warning(
                "Cannot update message: No current conversation."
            )
        return False

    # Validate every index type the same way delete() and query() do;
    # previously only strings were checked.
    try:
        message_id = int(index)
    except (TypeError, ValueError):
        if self.enable_logging:
            self.logger.error(
                f"Invalid message ID for update: {index}. Must be an integer."
            )
        return False

    update_data = {}
    if role is not None:
        update_data["role"] = role
    if content is not None:
        update_data["content"] = self._serialize_content(content)
        if self.calculate_token_count and self.tokenizer:
            try:
                update_data["token_count"] = (
                    self.tokenizer.count_tokens(str(content))
                )
            except Exception as e:
                # Best-effort: the stored token_count is left unchanged.
                if self.enable_logging:
                    self.logger.warning(
                        f"Failed to count tokens for updated content: {e}"
                    )
    if metadata is not None:
        update_data["metadata"] = self._serialize_metadata(metadata)

    if not update_data:
        if self.enable_logging:
            self.logger.info(
                "No fields provided to update for message."
            )
        return False

    try:
        response = (
            self.client.table(self.table_name)
            .update(update_data)
            .eq("id", message_id)
            .eq("conversation_id", self.current_conversation_id)
            .execute()
        )
        data = self._handle_api_response(
            response, f"update_message (id: {message_id})"
        )
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Error updating message ID {message_id} in Supabase: {e}"
            )
        return False

    # An empty result means no row matched the ID in this conversation.
    if data and len(data) > 0:
        if self.enable_logging:
            self.logger.info(
                f"Updated message with ID {message_id} in conversation {self.current_conversation_id}"
            )
        return True

    if self.enable_logging:
        self.logger.warning(
            f"No message found with ID {message_id} in conversation {self.current_conversation_id}"
        )
    return False
|
|
|
|
def query(self, index: str) -> Dict:
    """Fetch one message of the current conversation by its ``id``.

    Returns the formatted message, or an empty dict when no conversation
    is active, the ID is not an integer, nothing matches, or the request
    fails (failures are logged, not raised).
    """
    not_found: Dict = {}
    if self.current_conversation_id is None:
        return not_found

    try:
        try:
            message_id = int(index)
        except ValueError:
            if self.enable_logging:
                self.logger.warning(
                    f"Invalid message ID for query: {index}. Must be an integer."
                )
            return not_found

        request = (
            self.client.table(self.table_name)
            .select("*")
            .eq("id", message_id)
            .eq("conversation_id", self.current_conversation_id)
            .maybe_single()  # one record or None
        )
        row = self._handle_api_response(
            request.execute(), f"query_message (id: {message_id})"
        )
        if not row:
            return not_found
        return self._format_row_to_dict(row)
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Error querying message ID {index} from Supabase: {e}"
            )
        return not_found
|
|
|
|
def query_optional(self, index: str) -> Optional[Dict]:
    """Like ``query``, but return None instead of an empty dict."""
    found = self.query(index)
    if not found:
        return None
    return found
|
|
|
|
def search(self, keyword: str) -> List[Dict]:
    """Find messages of the current conversation whose content contains *keyword*.

    Matching is a case-insensitive substring match on the stored
    (serialized) content; results are ordered by timestamp, oldest first.

    Args:
        keyword: Text to look for.

    Returns:
        Formatted message dicts; empty when no conversation is active.

    Raises:
        SupabaseOperationError: The query failed.
    """
    if self.current_conversation_id is None:
        return []
    try:
        # NOTE(review): *keyword* is placed into the pattern unescaped,
        # so "%" and "_" inside it presumably act as wildcards - confirm
        # whether callers rely on that before escaping.
        response = (
            self.client.table(self.table_name)
            .select("*")
            .eq("conversation_id", self.current_conversation_id)
            .ilike("content", f"%{keyword}%")
            .order("timestamp", desc=False)
            .execute()
        )
        data = self._handle_api_response(
            response, f"search_messages (keyword: {keyword})"
        )
        # ``data`` may be None when the API returns no rows.
        return [self._format_row_to_dict(row) for row in (data or [])]
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Error searching messages in Supabase: {e}"
            )
        raise SupabaseOperationError(
            f"Error searching messages: {e}"
        ) from e
|
|
|
|
def _export_to_file(self, filename: str, format_type: str):
    """Write the current conversation to *filename* as JSON or YAML.

    Does nothing (beyond a warning) when no conversation is active.

    Args:
        filename: Destination path; overwritten if it exists.
        format_type: ``"json"`` or ``"yaml"``.

    Raises:
        ValueError: *format_type* is not supported. The destination file
            is left untouched in that case.
        Exception: Any error from writing the file is logged and re-raised.
    """
    if self.current_conversation_id is None:
        if self.enable_logging:
            self.logger.warning("No current conversation to export.")
        return

    # Gets messages for current_conversation_id
    data_to_export = self.to_dict()
    try:
        # Validate before opening: mode "w" truncates the file, and a bad
        # format must not destroy an existing export.
        if format_type not in ("json", "yaml"):
            raise ValueError(
                f"Unsupported export format: {format_type}"
            )
        with open(filename, "w", encoding="utf-8") as f:
            if format_type == "json":
                json.dump(
                    data_to_export,
                    f,
                    indent=2,
                    cls=DateTimeEncoder,
                )
            else:
                yaml.dump(data_to_export, f, sort_keys=False)
        if self.enable_logging:
            self.logger.info(
                f"Conversation {self.current_conversation_id} exported to {filename} as {format_type}."
            )
    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Failed to export conversation to {format_type}: {e}"
            )
        raise
|
|
|
|
def export_conversation(self, filename: str):
    """Export the current conversation to *filename*.

    JSON is used when ``save_as_json_on_export`` is set; in every other
    case the export is written as YAML.
    """
    chosen_format = "json" if self.save_as_json_on_export else "yaml"
    self._export_to_file(filename, chosen_format)
|
|
|
|
def _import_from_file(self, filename: str, format_type: str):
|
|
"""Helper to import conversation from JSON or YAML file."""
|
|
try:
|
|
with open(filename, "r") as f:
|
|
if format_type == "json":
|
|
imported_data = json.load(f)
|
|
elif format_type == "yaml":
|
|
imported_data = yaml.safe_load(f)
|
|
else:
|
|
raise ValueError(
|
|
f"Unsupported import format: {format_type}"
|
|
)
|
|
|
|
if not isinstance(imported_data, list):
|
|
raise ValueError(
|
|
"Imported data must be a list of messages."
|
|
)
|
|
|
|
# Start a new conversation for the imported data
|
|
self.start_new_conversation()
|
|
|
|
messages_to_batch = []
|
|
for msg_data in imported_data:
|
|
# Adapt to Message dataclass structure if possible
|
|
role = msg_data.get("role")
|
|
content = msg_data.get("content")
|
|
if role is None or content is None:
|
|
self.logger.warning(
|
|
f"Skipping message due to missing role/content: {msg_data}"
|
|
)
|
|
continue
|
|
|
|
messages_to_batch.append(
|
|
Message(
|
|
role=role,
|
|
content=content,
|
|
timestamp=msg_data.get(
|
|
"timestamp"
|
|
), # Will be handled by batch_add
|
|
message_type=(
|
|
MessageType(msg_data["message_type"])
|
|
if msg_data.get("message_type")
|
|
else None
|
|
),
|
|
metadata=msg_data.get("metadata"),
|
|
token_count=msg_data.get("token_count"),
|
|
)
|
|
)
|
|
|
|
if messages_to_batch:
|
|
self.batch_add(messages_to_batch)
|
|
self.logger.info(
|
|
f"Conversation imported from {filename} ({format_type}) into new ID {self.current_conversation_id}."
|
|
)
|
|
|
|
except Exception as e:
|
|
self.logger.error(
|
|
f"Failed to import conversation from {format_type}: {e}"
|
|
)
|
|
raise
|
|
|
|
def import_conversation(self, filename: str):
    """Import a conversation history from a file into a new conversation.

    The format is chosen from the file extension (``.json``, ``.yaml`` or
    ``.yml``, case-insensitive). For any other extension the file is tried as
    JSON first and, if that fails to decode or validate, as YAML.

    Args:
        filename: Path of the file to import.

    Raises:
        SupabaseOperationError: If the import fails. The underlying error is
            attached as the exception's cause.
    """
    lowered = filename.lower()
    try:
        if lowered.endswith(".json"):
            self._import_from_file(filename, "json")
        elif lowered.endswith((".yaml", ".yml")):
            self._import_from_file(filename, "yaml")
        else:
            try:
                self._import_from_file(filename, "json")
            except (json.JSONDecodeError, ValueError):
                # ValueError also covers "file content is not a list".
                self.logger.info(
                    f"Failed to import {filename} as JSON, trying YAML."
                )
                self._import_from_file(filename, "yaml")
    except Exception as e:
        # Chain explicitly so the original traceback is kept as the cause.
        raise SupabaseOperationError(
            f"Could not import {filename}: {e}"
        ) from e
|
|
|
|
def count_messages_by_role(self) -> Dict[str, int]:
    """Tally how many messages each role has in the current conversation.

    Returns an empty dict when no conversation is active. Messages without a
    ``role`` key are counted under ``"unknown"``.

    Raises:
        SupabaseOperationError: If fetching or counting the messages fails.
    """
    if self.current_conversation_id is None:
        return {}
    try:
        # Counting is done client-side over the rows returned by get_messages().
        tally = {}
        for entry in self.get_messages():
            role_name = entry.get("role", "unknown")
            tally[role_name] = tally.get(role_name, 0) + 1
        return tally
    except Exception as e:
        self.logger.error(f"Error counting messages by role: {e}")
        raise SupabaseOperationError(
            f"Error counting messages by role: {e}"
        )
|
|
|
|
def return_history_as_string(self) -> str:
    """Return the conversation history as one string (delegates to ``get_str``)."""
    history_text = self.get_str()
    return history_text
|
|
|
|
def clear(self):
    """Delete every stored message of the current conversation from Supabase.

    Only logs a notice when no conversation is active.

    Raises:
        SupabaseOperationError: If the delete request or the handling of its
            response fails.
    """
    if self.current_conversation_id is None:
        self.logger.info("No current conversation to clear.")
        return

    try:
        delete_request = (
            self.client.table(self.table_name)
            .delete()
            .eq("conversation_id", self.current_conversation_id)
        )
        response = delete_request.execute()
        # The deleted rows themselves are not inspected; the response is only
        # passed through the shared handler.
        self._handle_api_response(
            response,
            f"clear_conversation (id: {self.current_conversation_id})",
        )
        self.logger.info(
            f"Cleared conversation with ID: {self.current_conversation_id}"
        )
    except Exception as e:
        self.logger.error(
            f"Error clearing conversation {self.current_conversation_id} from Supabase: {e}"
        )
        raise SupabaseOperationError(f"Error clearing conversation: {e}")
|
|
|
|
def to_dict(self) -> List[Dict]:
    """Return the current conversation as a list of message dictionaries.

    Delegates to ``get_messages``.
    """
    history = self.get_messages()
    return history
|
|
|
|
def to_json(self) -> str:
    """Serialize the current conversation history to an indented JSON string.

    ``DateTimeEncoder`` is used so ``datetime`` values are written in ISO format.
    """
    history = self.to_dict()
    return json.dumps(history, indent=2, cls=DateTimeEncoder)
|
|
|
|
def to_yaml(self) -> str:
    """Serialize the current conversation history to a YAML string.

    Keys keep their existing order (``sort_keys=False``).
    """
    history = self.to_dict()
    return yaml.dump(history, sort_keys=False)
|
|
|
|
def save_as_json(self, filename: str):
    """Write the current conversation history to ``filename`` as JSON."""
    json_format = "json"
    self._export_to_file(filename, json_format)
|
|
|
|
def load_from_json(self, filename: str):
    """Read a JSON file and load its messages into a new conversation."""
    json_format = "json"
    self._import_from_file(filename, json_format)
|
|
|
|
def save_as_yaml(self, filename: str):
    """Write the current conversation history to ``filename`` as YAML."""
    yaml_format = "yaml"
    self._export_to_file(filename, yaml_format)
|
|
|
|
def load_from_yaml(self, filename: str):
    """Read a YAML file and load its messages into a new conversation."""
    yaml_format = "yaml"
    self._import_from_file(filename, yaml_format)
|
|
|
|
def get_last_message(self) -> Optional[Dict]:
    """Fetch the most recent message of the current conversation.

    Rows are ordered by ``timestamp`` descending and limited to one.

    Returns:
        The formatted message dict, or ``None`` when no conversation is
        active or the query yields no row.

    Raises:
        SupabaseOperationError: If the query or response handling fails.
    """
    if self.current_conversation_id is None:
        return None

    try:
        newest_first = (
            self.client.table(self.table_name)
            .select("*")
            .eq("conversation_id", self.current_conversation_id)
            .order("timestamp", desc=True)
        )
        response = newest_first.limit(1).maybe_single().execute()
        row = self._handle_api_response(response, "get_last_message")
        if not row:
            return None
        return self._format_row_to_dict(row)
    except Exception as e:
        self.logger.error(f"Error getting last message from Supabase: {e}")
        raise SupabaseOperationError(f"Error getting last message: {e}")
|
|
|
|
def get_last_message_as_string(self) -> str:
    """Render the last message as ``"[timestamp] role: content"``.

    The timestamp prefix appears only when the message has a timestamp and
    ``time_enabled`` is set. Dict or list content is JSON-encoded. Returns an
    empty string when there is no last message.
    """
    latest = self.get_last_message()
    if not latest:
        return ""

    if latest.get("timestamp") and self.time_enabled:
        prefix = f"[{latest['timestamp']}] "
    else:
        prefix = ""

    body = latest["content"]
    if isinstance(body, (dict, list)):
        body = json.dumps(body, cls=DateTimeEncoder)
    return f"{prefix}{latest['role']}: {body}"
|
|
|
|
def get_messages_by_role(self, role: str) -> List[Dict]:
    """Fetch the current conversation's messages written by ``role``.

    Results are ordered by ``timestamp`` ascending. Returns an empty list
    when no conversation is active.

    Raises:
        SupabaseOperationError: If the query or response handling fails.
    """
    if self.current_conversation_id is None:
        return []

    try:
        role_query = (
            self.client.table(self.table_name)
            .select("*")
            .eq("conversation_id", self.current_conversation_id)
            .eq("role", role)
            .order("timestamp", desc=False)
        )
        rows = self._handle_api_response(
            role_query.execute(), f"get_messages_by_role (role: {role})"
        )
        formatted = []
        for row in rows:
            formatted.append(self._format_row_to_dict(row))
        return formatted
    except Exception as e:
        self.logger.error(
            f"Error getting messages by role '{role}' from Supabase: {e}"
        )
        raise SupabaseOperationError(
            f"Error getting messages by role '{role}': {e}"
        )
|
|
|
|
def get_conversation_summary(self) -> Dict:
    """Summarize the current conversation.

    Returns:
        ``{"error": ...}`` when no conversation is active; otherwise a dict
        with the conversation ID, message count, number of distinct roles,
        first and last message timestamps, the sum of the stored token
        counts and a per-role message count.
    """
    if self.current_conversation_id is None:
        return {"error": "No current conversation."}

    # Aggregation is done client-side over all fetched messages.
    history = self.get_messages()
    summary = {
        "conversation_id": self.current_conversation_id,
        "total_messages": 0,
        "unique_roles": 0,
        "first_message_time": None,
        "last_message_time": None,
        "total_tokens": 0,
        "roles": {},
    }
    if not history:
        return summary

    per_role = {}
    token_total = 0
    for entry in history:
        author = entry["role"]
        per_role[author] = per_role.get(author, 0) + 1
        stored_tokens = entry.get("token_count")
        if stored_tokens is not None:
            token_total += int(stored_tokens)

    summary["total_messages"] = len(history)
    summary["unique_roles"] = len(per_role)
    summary["first_message_time"] = history[0].get("timestamp")
    summary["last_message_time"] = history[-1].get("timestamp")
    summary["total_tokens"] = token_total
    summary["roles"] = per_role
    return summary
|
|
|
|
def get_statistics(self) -> Dict:
    """Return conversation statistics; an alias for ``get_conversation_summary``."""
    stats = self.get_conversation_summary()
    return stats
|
|
|
|
def get_conversation_id(self) -> str:
    """Return the active conversation ID, or ``""`` when none is set."""
    active_id = self.current_conversation_id
    return active_id if active_id else ""
|
|
|
|
def delete_current_conversation(self) -> bool:
    """Delete the active conversation and reset the active conversation ID.

    Returns:
        ``True`` after the conversation's messages were cleared, ``False``
        when there was no active conversation to delete.
    """
    if not self.current_conversation_id:
        self.logger.info("No current conversation to delete.")
        return False

    self.clear()  # Removes the messages stored for the active conversation
    self.logger.info(
        f"Deleted current conversation: {self.current_conversation_id}"
    )
    # Nothing is active once the conversation has been removed.
    self.current_conversation_id = None
    return True
|
|
|
|
def search_messages(self, query: str) -> List[Dict]:
    """Find messages containing ``query``; an alias for ``search``."""
    matches = self.search(keyword=query)
    return matches
|
|
|
|
def get_conversation_metadata_dict(self) -> Dict:
    """Return metadata about the current conversation.

    Currently this wraps the result of ``get_conversation_summary`` under
    ``"basic_stats"`` next to the conversation ID. Returns ``{"error": ...}``
    when no conversation is active.
    """
    if self.current_conversation_id is None:
        return {"error": "No current conversation."}

    basic_stats = self.get_conversation_summary()
    # More detailed statistics (message-type distribution, average tokens per
    # message, ...) are not implemented; only the summary is returned.
    metadata = {}
    metadata["conversation_id"] = self.current_conversation_id
    metadata["basic_stats"] = basic_stats
    return metadata
|
|
|
|
def get_conversation_timeline_dict(self) -> Dict[str, List[Dict]]:
    """Group the current conversation's messages by calendar date.

    Keys are ``YYYY-MM-DD`` strings taken from each message's ISO timestamp
    (a ``Z`` in the timestamp is read as ``+00:00``). Messages whose
    timestamp is not a string or cannot be parsed are left out and a warning
    is logged. Returns an empty dict when no conversation is active.
    """
    if self.current_conversation_id is None:
        return {}

    by_day = {}
    for entry in self.get_messages():
        stamp = entry.get("timestamp")
        if not isinstance(stamp, str):
            self.logger.warning(
                f"Message ID {entry.get('id')} has invalid timestamp format: {stamp}"
            )
            continue
        try:
            day = datetime.datetime.fromisoformat(
                stamp.replace("Z", "+00:00")
            ).strftime("%Y-%m-%d")
        except ValueError as e:
            self.logger.warning(
                f"Could not parse timestamp for message ID {entry.get('id')}: {stamp}, Error: {e}"
            )
            continue
        by_day.setdefault(day, []).append(entry)

    return by_day
|
|
|
|
def get_conversation_by_role_dict(self) -> Dict[str, List[Dict]]:
    """Group the current conversation's messages by role.

    Messages without a ``role`` key are grouped under ``"unknown"``. Returns
    an empty dict when no conversation is active.
    """
    if self.current_conversation_id is None:
        return {}

    grouped = {}
    for entry in self.get_messages():
        grouped.setdefault(entry.get("role", "unknown"), []).append(entry)
    return grouped
|
|
|
|
def get_conversation_as_dict(self) -> Dict:
    """Return the current conversation as one dict of ID, messages and metadata.

    The ``"metadata"`` entry holds the result of ``get_conversation_summary``.
    Returns ``{"error": ...}`` when no conversation is active.
    """
    conversation_id = self.current_conversation_id
    if conversation_id is None:
        return {"error": "No current conversation."}

    history = self.get_messages()
    summary = self.get_conversation_summary()  # The summary doubles as metadata
    return {
        "conversation_id": conversation_id,
        "messages": history,
        "metadata": summary,
    }
|
|
|
|
def truncate_memory_with_tokenizer(self):
    """Delete the oldest messages until the conversation fits ``context_length``.

    Skipped when no tokenizer is configured or no conversation is active.
    Stored ``token_count`` values are used where present; a missing count is
    recomputed with the tokenizer only if ``calculate_token_count`` is set,
    otherwise it counts as zero. Errors are logged (when logging is enabled)
    and never raised, because truncation is best-effort.
    """
    if not self.tokenizer or self.current_conversation_id is None:
        if self.enable_logging:
            self.logger.info(
                "Tokenizer not available or no current conversation, skipping truncation."
            )
        return

    try:
        # Request only the columns needed to weigh and delete rows.
        fetch_response = (
            self.client.table(self.table_name)
            .select("id, content, token_count")
            .eq("conversation_id", self.current_conversation_id)
            .order("timestamp", desc=False)
            .execute()
        )
        rows = self._handle_api_response(
            fetch_response, "fetch_messages_for_truncation"
        )
        if not rows:
            return

        # Pair every row ID with its token weight, oldest first.
        weighted = []
        running_total = 0
        for row in rows:
            weight = row.get("token_count")
            if weight is None and self.calculate_token_count:
                text = self._deserialize_content(row.get("content", ""))
                weight = self.tokenizer.count_tokens(str(text))
            weight = weight or 0
            weighted.append((row["id"], weight))
            running_total += weight

        excess = running_total - self.context_length
        if excess <= 0:
            return  # Already within the context budget

        # Walk from the oldest message until enough tokens are freed.
        doomed_ids = []
        for row_id, weight in weighted:
            if excess <= 0:
                break
            doomed_ids.append(row_id)
            excess -= weight

        if not doomed_ids:
            return

        deletion = self.client.table(self.table_name).delete()
        if len(doomed_ids) == 1:
            deletion = deletion.eq("id", doomed_ids[0])
        else:
            # Several rows are removed in one request via the "in" filter.
            deletion = deletion.in_("id", doomed_ids)
        delete_response = deletion.eq(
            "conversation_id", self.current_conversation_id
        ).execute()

        self._handle_api_response(
            delete_response, "truncate_conversation_batch_delete"
        )

        if self.enable_logging:
            self.logger.info(
                f"Truncated conversation {self.current_conversation_id}, removed {len(doomed_ids)} oldest messages."
            )

    except Exception as e:
        if self.enable_logging:
            self.logger.error(
                f"Error during memory truncation for conversation {self.current_conversation_id}: {e}"
            )
        # Deliberately not re-raised: truncation is best-effort.
|
|
|
|
# Methods from duckdb_wrap.py that seem generally useful and can be adapted
|
|
def get_visible_messages(
    self,
    agent: Optional[Callable] = None,
    turn: Optional[int] = None,
) -> List[Dict]:
    """Return the current conversation's messages, filtered by agent and turn.

    Filtering happens in Python on each message's ``metadata`` dict:

    * ``turn``: when given, only messages whose ``metadata["turn"]`` is an
      integer strictly below ``turn`` are kept.
    * ``agent``: when given, ``metadata["visible_to"]`` is compared with the
      agent's ``agent_name``. A list must contain the name, and a string must
      be ``"all"`` or equal the name. An agent without a name only sees
      messages whose ``visible_to`` is absent or ``"all"``.

    Returns an empty list when no conversation is active or when fetching
    the messages fails (the error is logged, not raised).
    """
    if self.current_conversation_id is None:
        return []

    # All rows are fetched and then filtered client-side.
    fetch_all = (
        self.client.table(self.table_name)
        .select("*")
        .eq("conversation_id", self.current_conversation_id)
        .order("timestamp", desc=False)
    )

    try:
        rows = self._handle_api_response(
            fetch_all.execute(), "get_visible_messages_fetch_all"
        )
    except Exception as e:
        self.logger.error(
            f"Error fetching messages for visibility check: {e}"
        )
        return []

    kept = []
    for row in rows:
        message = self._format_row_to_dict(row)
        meta = message.get("metadata")
        if not isinstance(meta, dict):
            meta = {}

        if turn is not None:
            message_turn = meta.get("turn")
            if not isinstance(message_turn, int) or message_turn >= turn:
                continue

        if agent is not None:
            audience = meta.get("visible_to")
            agent_name = getattr(agent, "agent_name", None)
            if agent_name is None:
                # A nameless agent cannot be matched against a restriction.
                hidden = audience is not None and audience != "all"
            elif isinstance(audience, list):
                hidden = agent_name not in audience
            elif isinstance(audience, str):
                hidden = audience != "all" and audience != agent_name
            else:
                hidden = False
            if hidden:
                continue

        kept.append(message)
    return kept
|
|
|
|
def return_messages_as_list(self) -> List[str]:
    """Return every message as a ``"role: content"`` string.

    Content is passed through ``_serialize_content``; a missing role is shown
    as ``"unknown"`` and missing content is treated as ``""``.
    """
    lines = []
    for entry in self.get_messages():
        lines.append(
            f"{entry.get('role', 'unknown')}: {self._serialize_content(entry.get('content', ''))}"
        )
    return lines
|
|
|
|
def return_messages_as_dictionary(self) -> List[Dict]:
    """Return every message reduced to its ``role`` and ``content`` keys.

    Missing keys yield ``None`` values.
    """
    trimmed = []
    for entry in self.get_messages():
        # The content is taken as returned by get_messages(), unchanged.
        trimmed.append(
            {"role": entry.get("role"), "content": entry.get("content")}
        )
    return trimmed
|
|
|
|
def add_tool_output_to_agent(
    self, role: str, tool_output: dict
):  # role is usually "tool"
    """Store a tool's output as a message of type ``MessageType.TOOL``.

    The ``tool_output`` dict is passed to ``add`` as the message content.
    """
    tool_message = {
        "role": role,
        "content": tool_output,
        "message_type": MessageType.TOOL,
    }
    self.add(**tool_message)
|
|
|
|
def get_final_message(self) -> Optional[str]:
    """Return the last message as a ``"role: content"`` string.

    Dict or list content is JSON-encoded and a missing role is shown as
    ``"unknown"``. Returns ``None`` when there is no last message.
    """
    latest = self.get_last_message()
    if not latest:
        return None

    body = latest["content"]
    if isinstance(body, (dict, list)):
        body = json.dumps(body, cls=DateTimeEncoder)
    speaker = latest.get("role", "unknown")
    return f"{speaker}: {body}"
|
|
|
|
def get_final_message_content(
    self,
) -> Union[str, dict, list, None]:
    """Return only the content of the last message, or ``None`` if there is none."""
    latest = self.get_last_message()
    if not latest:
        return None
    return latest.get("content")
|
|
|
|
def return_all_except_first(self) -> List[Dict]:
    """Return every message of the conversation except the first.

    All messages are fetched and the first is dropped in Python. An empty
    list is returned when there are fewer than two messages.
    """
    history = self.get_messages()
    if len(history) <= 1:
        return []
    return history[1:]
|
|
|
|
def return_all_except_first_string(self) -> str:
    """Return every message except the first, joined into one string.

    Each message becomes one ``"[timestamp] role: content"`` line. The
    timestamp prefix appears only when the message has a timestamp and
    ``time_enabled`` is set. Dict or list content is written as indented JSON.
    """
    rendered = []
    for entry in self.return_all_except_first():
        if entry.get("timestamp") and self.time_enabled:
            prefix = f"[{entry['timestamp']}] "
        else:
            prefix = ""

        body = entry["content"]
        if isinstance(body, (dict, list)):
            body = json.dumps(body, indent=2, cls=DateTimeEncoder)
        rendered.append(f"{prefix}{entry['role']}: {body}")
    return "\n".join(rendered)
|
|
|
|
def update_message(
    self,
    message_id: int,
    content: Union[str, dict, list],
    metadata: Optional[Dict] = None,
) -> bool:
    """Update an existing message's content and, optionally, its metadata.

    Delegates to ``_update_flexible``, passing ``message_id`` as its ``index``
    argument, and returns that call's result.
    """
    update_args = {
        "index": message_id,
        "content": content,
        "metadata": metadata,
    }
    return self._update_flexible(**update_args)
|