pull/989/head
王祥宇 2 months ago
parent ac91191c4b
commit c53723bed4

@ -1,8 +1,8 @@
import traceback
import concurrent.futures import concurrent.futures
import datetime import datetime
import json import json
import os import os
import threading
import uuid import uuid
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -15,10 +15,9 @@ from typing import (
) )
import yaml import yaml
import inspect
from swarms.structs.base_structure import BaseStructure
from swarms.utils.any_to_str import any_to_str from swarms.utils.any_to_str import any_to_str
from swarms.utils.formatter import formatter
from swarms.utils.litellm_tokenizer import count_tokens from swarms.utils.litellm_tokenizer import count_tokens
if TYPE_CHECKING: if TYPE_CHECKING:
@ -143,7 +142,7 @@ def _create_backend_conversation(backend: str, **kwargs):
raise raise
class Conversation(BaseStructure): class Conversation:
""" """
A class to manage a conversation history, allowing for the addition, deletion, A class to manage a conversation history, allowing for the addition, deletion,
and retrieval of messages, as well as saving and loading the conversation and retrieval of messages, as well as saving and loading the conversation
@ -167,13 +166,12 @@ class Conversation(BaseStructure):
time_enabled (bool): Flag to enable time tracking for messages. time_enabled (bool): Flag to enable time tracking for messages.
autosave (bool): Flag to enable automatic saving of conversation history. autosave (bool): Flag to enable automatic saving of conversation history.
save_filepath (str): File path for saving the conversation history. save_filepath (str): File path for saving the conversation history.
tokenizer (Any): Tokenizer for counting tokens in messages.
context_length (int): Maximum number of tokens allowed in the conversation history. context_length (int): Maximum number of tokens allowed in the conversation history.
rules (str): Rules for the conversation. rules (str): Rules for the conversation.
custom_rules_prompt (str): Custom prompt for rules. custom_rules_prompt (str): Custom prompt for rules.
user (str): The user identifier for messages. user (str): The user identifier for messages.
auto_save (bool): Flag to enable auto-saving of conversation history. auto_save (bool): Flag to enable auto-saving of conversation history.
save_as_yaml (bool): Flag to save conversation history as YAML. save_as_yaml_on (bool): Flag to save conversation history as YAML.
save_as_json_bool (bool): Flag to save conversation history as JSON. save_as_json_bool (bool): Flag to save conversation history as JSON.
token_count (bool): Flag to enable token counting for messages. token_count (bool): Flag to enable token counting for messages.
conversation_history (list): List to store the history of messages. conversation_history (list): List to store the history of messages.
@ -182,7 +180,7 @@ class Conversation(BaseStructure):
def __init__( def __init__(
self, self,
id: str = generate_conversation_id(), id: str = generate_conversation_id(),
name: str = None, name: str = "conversation-test",
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
time_enabled: bool = False, time_enabled: bool = False,
autosave: bool = False, # Changed default to False autosave: bool = False, # Changed default to False
@ -193,7 +191,7 @@ class Conversation(BaseStructure):
rules: str = None, rules: str = None,
custom_rules_prompt: str = None, custom_rules_prompt: str = None,
user: str = "User", user: str = "User",
save_as_yaml: bool = False, save_as_yaml_on: bool = False,
save_as_json_bool: bool = False, save_as_json_bool: bool = False,
token_count: bool = False, token_count: bool = False,
message_id_on: bool = False, message_id_on: bool = False,
@ -201,6 +199,7 @@ class Conversation(BaseStructure):
backend: Optional[str] = None, backend: Optional[str] = None,
supabase_url: Optional[str] = None, supabase_url: Optional[str] = None,
supabase_key: Optional[str] = None, supabase_key: Optional[str] = None,
tokenizer_model_name: str = "gpt-4.1",
redis_host: str = "localhost", redis_host: str = "localhost",
redis_port: int = 6379, redis_port: int = 6379,
redis_db: int = 0, redis_db: int = 0,
@ -212,26 +211,28 @@ class Conversation(BaseStructure):
auto_persist: bool = True, auto_persist: bool = True,
redis_data_dir: Optional[str] = None, redis_data_dir: Optional[str] = None,
conversations_dir: Optional[str] = None, conversations_dir: Optional[str] = None,
export_method: str = "json",
*args, *args,
**kwargs, **kwargs,
): ):
super().__init__()
# Initialize all attributes first # Initialize all attributes first
self.id = id self.id = id
self.name = name or id self.name = name
self.save_filepath = save_filepath
self.system_prompt = system_prompt self.system_prompt = system_prompt
self.time_enabled = time_enabled self.time_enabled = time_enabled
self.autosave = autosave self.autosave = autosave
self.save_enabled = save_enabled self.save_enabled = save_enabled
self.conversations_dir = conversations_dir self.conversations_dir = conversations_dir
self.tokenizer_model_name = tokenizer_model_name
self.message_id_on = message_id_on self.message_id_on = message_id_on
self.load_filepath = load_filepath self.load_filepath = load_filepath
self.context_length = context_length self.context_length = context_length
self.rules = rules self.rules = rules
self.custom_rules_prompt = custom_rules_prompt self.custom_rules_prompt = custom_rules_prompt
self.user = user self.user = user
self.save_as_yaml = save_as_yaml self.save_as_yaml_on = save_as_yaml_on
self.save_as_json_bool = save_as_json_bool self.save_as_json_bool = save_as_json_bool
self.token_count = token_count self.token_count = token_count
self.provider = provider # Keep for backwards compatibility self.provider = provider # Keep for backwards compatibility
@ -249,23 +250,75 @@ class Conversation(BaseStructure):
self.persist_redis = persist_redis self.persist_redis = persist_redis
self.auto_persist = auto_persist self.auto_persist = auto_persist
self.redis_data_dir = redis_data_dir self.redis_data_dir = redis_data_dir
self.export_method = export_method
if self.name is None:
self.name = id
self.conversation_history = [] self.conversation_history = []
# Handle save filepath self.setup_file_path()
if save_enabled and save_filepath:
self.save_filepath = save_filepath self.backend_setup(backend, provider)
elif save_enabled and conversations_dir:
self.save_filepath = os.path.join( def setup_file_path(self):
conversations_dir, f"{self.id}.json" """Set up the file path for saving the conversation and load existing data if available."""
# Validate export method
if self.export_method not in ["json", "yaml"]:
raise ValueError(
f"Invalid export_method: {self.export_method}. Must be 'json' or 'yaml'"
)
# Set default save filepath if not provided
if not self.save_filepath:
# Ensure extension matches export method
extension = (
".json" if self.export_method == "json" else ".yaml"
)
self.save_filepath = (
f"conversation_{self.name}{extension}"
)
logger.debug(
f"Setting default save filepath to: {self.save_filepath}"
) )
else: else:
self.save_filepath = None # Validate that provided filepath extension matches export method
file_ext = os.path.splitext(self.save_filepath)[1].lower()
expected_ext = (
".json" if self.export_method == "json" else ".yaml"
)
if file_ext != expected_ext:
logger.warning(
f"Save filepath extension ({file_ext}) does not match export_method ({self.export_method}). "
f"Updating filepath extension to match export method."
)
base_name = os.path.splitext(self.save_filepath)[0]
self.save_filepath = f"{base_name}{expected_ext}"
# Support both 'provider' and 'backend' parameters for backwards compatibility self.created_at = datetime.datetime.now().strftime(
# 'backend' takes precedence if both are provided "%Y-%m-%d_%H-%M-%S"
)
self.backend_setup(backend, provider) # Check if file exists and load it
if os.path.exists(self.save_filepath):
logger.debug(
f"Found existing conversation file at: {self.save_filepath}"
)
try:
self.load(self.save_filepath)
logger.info(
f"Loaded existing conversation from {self.save_filepath}"
)
except Exception as e:
logger.error(
f"Failed to load existing conversation from {self.save_filepath}: {str(e)}"
)
# Keep the empty conversation_history initialized in __init__
else:
logger.debug(
f"No existing conversation file found at: {self.save_filepath}"
)
def backend_setup( def backend_setup(
self, backend: str = None, provider: str = None self, backend: str = None, provider: str = None
@ -341,7 +394,7 @@ class Conversation(BaseStructure):
"rules": self.rules, "rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt, "custom_rules_prompt": self.custom_rules_prompt,
"user": self.user, "user": self.user,
"save_as_yaml": self.save_as_yaml, "save_as_yaml_on": self.save_as_yaml_on,
"save_as_json_bool": self.save_as_json_bool, "save_as_json_bool": self.save_as_json_bool,
"token_count": self.token_count, "token_count": self.token_count,
} }
@ -466,13 +519,7 @@ class Conversation(BaseStructure):
def _autosave(self): def _autosave(self):
"""Automatically save the conversation if autosave is enabled.""" """Automatically save the conversation if autosave is enabled."""
if self.autosave and self.save_filepath: return self.export()
try:
self.save_as_json(self.save_filepath)
except Exception as e:
logger.error(
f"Failed to autosave conversation: {str(e)}"
)
def mem0_provider(self): def mem0_provider(self):
try: try:
@ -503,6 +550,7 @@ class Conversation(BaseStructure):
Args: Args:
role (str): The role of the speaker (e.g., 'User', 'System'). role (str): The role of the speaker (e.g., 'User', 'System').
content (Union[str, dict, list]): The content of the message to be added. content (Union[str, dict, list]): The content of the message to be added.
category (Optional[str]): Optional category for the message.
""" """
# Base message with role and timestamp # Base message with role and timestamp
message = { message = {
@ -522,20 +570,18 @@ class Conversation(BaseStructure):
# Add message to conversation history # Add message to conversation history
self.conversation_history.append(message) self.conversation_history.append(message)
# Handle token counting in a separate thread if enabled
if self.token_count is True: if self.token_count is True:
self._count_tokens(content, message) tokens = count_tokens(
text=any_to_str(content),
model=self.tokenizer_model_name,
)
message["token_count"] = tokens
# Autosave after adding message, but only if saving is enabled return message
if self.autosave and self.save_enabled and self.save_filepath:
try:
self.save_as_json(self.save_filepath)
except Exception as e:
logger.error(
f"Failed to autosave conversation: {str(e)}"
)
def export_and_count_categories( def export_and_count_categories(
self, tokenizer_model_name: Optional[str] = "gpt-4.1-mini" self,
) -> Dict[str, int]: ) -> Dict[str, int]:
"""Export all messages with category 'input' and 'output' and count their tokens. """Export all messages with category 'input' and 'output' and count their tokens.
@ -580,12 +626,16 @@ class Conversation(BaseStructure):
# Count tokens only if there is text # Count tokens only if there is text
input_tokens = ( input_tokens = (
count_tokens(all_input_text, tokenizer_model_name) count_tokens(
all_input_text, self.tokenizer_model_name
)
if all_input_text.strip() if all_input_text.strip()
else 0 else 0
) )
output_tokens = ( output_tokens = (
count_tokens(all_output_text, tokenizer_model_name) count_tokens(
all_output_text, self.tokenizer_model_name
)
if all_output_text.strip() if all_output_text.strip()
else 0 else 0
) )
@ -637,56 +687,57 @@ class Conversation(BaseStructure):
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
category: Optional[str] = None, category: Optional[str] = None,
): ):
"""Add a message to the conversation history.""" """Add a message to the conversation history.
Args:
role (str): The role of the speaker (e.g., 'User', 'System').
content (Union[str, dict, list]): The content of the message to be added.
metadata (Optional[dict]): Optional metadata for the message.
category (Optional[str]): Optional category for the message.
"""
result = None
# If using a persistent backend, delegate to it # If using a persistent backend, delegate to it
if self.backend_instance: if self.backend_instance:
try: try:
return self.backend_instance.add( result = self.backend_instance.add(
role=role, content=content, metadata=metadata role=role, content=content, metadata=metadata
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Backend add failed: {e}. Falling back to in-memory." f"Backend add failed: {e}. Falling back to in-memory."
) )
return self.add_in_memory(role, content) result = self.add_in_memory(
role=role, content=content, category=category
)
elif self.provider == "in-memory": elif self.provider == "in-memory":
return self.add_in_memory( result = self.add_in_memory(
role=role, content=content, category=category role=role, content=content, category=category
) )
elif self.provider == "mem0": elif self.provider == "mem0":
return self.add_mem0( result = self.add_mem0(
role=role, content=content, metadata=metadata role=role, content=content, metadata=metadata
) )
else: else:
raise ValueError(f"Invalid provider: {self.provider}") raise ValueError(
f"Error: Conversation: {self.name} Invalid provider: {self.provider} Traceback: {traceback.format_exc()}"
)
# Ensure autosave happens after the message is added
if self.autosave:
self._autosave()
return result
def add_multiple_messages( def add_multiple_messages(
self, roles: List[str], contents: List[Union[str, dict, list]] self, roles: List[str], contents: List[Union[str, dict, list]]
): ):
return self.add_multiple(roles, contents) added = self.add_multiple(roles, contents)
def _count_tokens(self, content: str, message: dict):
# If token counting is enabled, do it in a separate thread
if self.token_count is True:
# Define a function to count tokens and update the message
def count_tokens_thread():
tokens = count_tokens(any_to_str(content))
# Update the message that's already in the conversation history
message["token_count"] = int(tokens)
# If autosave is enabled, save after token count is updated if self.autosave:
if self.autosave: self._autosave()
self.save_as_json(self.save_filepath)
# Start a new thread for token counting return added
token_thread = threading.Thread(
target=count_tokens_thread
)
token_thread.daemon = (
True # Make thread terminate when main program exits
)
token_thread.start()
def add_multiple( def add_multiple(
self, self,
@ -785,45 +836,6 @@ class Conversation(BaseStructure):
if keyword in str(message["content"]) if keyword in str(message["content"])
] ]
def display_conversation(self, detailed: bool = False):
"""Display the conversation history.
Args:
detailed (bool, optional): Flag to display detailed information. Defaults to False.
"""
if self.backend_instance:
try:
return self.backend_instance.display_conversation(
detailed
)
except Exception as e:
logger.error(f"Backend display failed: {e}")
# Fallback to in-memory display
pass
# In-memory display implementation with proper formatting
for message in self.conversation_history:
content = message.get("content", "")
role = message.get("role", "Unknown")
# Format the message content
if isinstance(content, (dict, list)):
content = json.dumps(content, indent=2)
# Create the display string
display_str = f"{role}: {content}"
# Add details if requested
if detailed:
display_str += f"\nTimestamp: {message.get('timestamp', 'Unknown')}"
display_str += f"\nMessage ID: {message.get('message_id', 'Unknown')}"
if "token_count" in message:
display_str += (
f"\nTokens: {message['token_count']}"
)
formatter.print_panel(display_str)
def export_conversation(self, filename: str, *args, **kwargs): def export_conversation(self, filename: str, *args, **kwargs):
"""Export the conversation history to a file. """Export the conversation history to a file.
@ -844,7 +856,7 @@ class Conversation(BaseStructure):
# In-memory export implementation # In-memory export implementation
# If the filename ends with .json, use save_as_json # If the filename ends with .json, use save_as_json
if filename.endswith(".json"): if filename.endswith(".json"):
self.save_as_json(filename) self.save_as_json(force=True)
else: else:
# Simple text export for non-JSON files # Simple text export for non-JSON files
with open(filename, "w", encoding="utf-8") as f: with open(filename, "w", encoding="utf-8") as f:
@ -946,99 +958,307 @@ class Conversation(BaseStructure):
pass pass
return self.return_history_as_string() return self.return_history_as_string()
def save_as_json(self, filename: str = None): def to_dict(self) -> Dict[str, Any]:
"""Save the conversation history as a JSON file. """
Converts all attributes of the class into a dictionary, including all __init__ parameters
and conversation history. Automatically extracts parameters from __init__ signature.
Returns:
Dict[str, Any]: A dictionary containing:
- metadata: All initialization parameters and their current values
- conversation_history: The list of conversation messages
"""
# Get all parameters from __init__ signature
init_signature = inspect.signature(self.__class__.__init__)
init_params = [
param
for param in init_signature.parameters
if param not in ["self", "args", "kwargs"]
]
# Build metadata dictionary from init parameters
metadata = {}
for param in init_params:
# Get the current value of the parameter from instance
value = getattr(self, param, None)
# Special handling for certain types
if value is not None:
if isinstance(
value, (str, int, float, bool, list, dict)
):
metadata[param] = value
elif hasattr(value, "to_dict"):
metadata[param] = value.to_dict()
else:
try:
# Try to convert to string if not directly serializable
metadata[param] = str(value)
except:
# Skip if we can't serialize
continue
# Add created_at if it exists
if hasattr(self, "created_at"):
metadata["created_at"] = self.created_at
return {
"metadata": metadata,
"conversation_history": self.conversation_history,
}
def save_as_json(self, force: bool = True):
"""Save the conversation history and metadata to a JSON file.
Args: Args:
filename (str): Filename to save the conversation history. force (bool, optional): If True, saves regardless of autosave setting. Defaults to True.
""" """
# Check backend instance first try:
if self.backend_instance: # Check if saving is allowed
try: if not self.autosave and not force:
return self.backend_instance.save_as_json(filename) logger.warning(
except Exception as e: "Autosave is disabled. To save anyway, call save_as_json(force=True) "
logger.error(f"Backend save_as_json failed: {e}") "or enable autosave by setting autosave=True when creating the Conversation."
# Fallback to local save implementation below )
return
# Don't save if saving is disabled # Don't save if saving is disabled (你的PR代码)
if not self.save_enabled: if not self.save_enabled:
logger.warning( logger.warning(
"An attempt to save the conversation failed: save_enabled is False." "An attempt to save the conversation failed: save_enabled is False."
"Please set save_enabled=True when creating a Conversation object to enable saving." "Please set save_enabled=True when creating a Conversation object to enable saving."
)
return
# Get the full data including metadata and conversation history
data = self.get_init_params()
# Ensure we have a valid save path
if not self.save_filepath:
self.save_filepath = os.path.join(
self.conversations_dir or os.getcwd(),
f"conversation_{self.name}.json",
)
# Create directory if it doesn't exist
save_dir = os.path.dirname(self.save_filepath)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
# Save with proper formatting
with open(self.save_filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4, default=str)
logger.info(f"Conversation saved to {self.save_filepath}")
except Exception as e:
logger.error(
f"Failed to save conversation: {str(e)}\nTraceback: {traceback.format_exc()}"
) )
return raise # Re-raise to ensure the error is visible to the caller
def get_init_params(self):
data = {
"metadata": {
"id": self.id,
"name": self.name,
"system_prompt": self.system_prompt,
"time_enabled": self.time_enabled,
"autosave": self.autosave,
"save_filepath": self.save_filepath,
"load_filepath": self.load_filepath,
"context_length": self.context_length,
"rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt,
"user": self.user,
"save_as_yaml_on": self.save_as_yaml_on,
"save_as_json_bool": self.save_as_json_bool,
"token_count": self.token_count,
"message_id_on": self.message_id_on,
"provider": self.provider,
"backend": self.backend,
"tokenizer_model_name": self.tokenizer_model_name,
"conversations_dir": self.conversations_dir,
"export_method": self.export_method,
"created_at": self.created_at,
},
"conversation_history": self.conversation_history,
}
save_path = filename or self.save_filepath return data
if save_path is not None:
try:
# Prepare metadata
metadata = {
"id": self.id,
"name": self.name,
"created_at": datetime.datetime.now().isoformat(),
"system_prompt": self.system_prompt,
"rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt,
}
# Prepare save data def save_as_yaml(self, force: bool = True):
save_data = { """Save the conversation history and metadata to a YAML file.
"metadata": metadata,
"history": self.conversation_history,
}
# Create directory if it doesn't exist Args:
os.makedirs( force (bool, optional): If True, saves regardless of autosave setting. Defaults to True.
os.path.dirname(save_path), """
mode=0o755, try:
exist_ok=True, # Check if saving is allowed
if not self.autosave and not force:
logger.warning(
"Autosave is disabled. To save anyway, call save_as_yaml(force=True) "
"or enable autosave by setting autosave=True when creating the Conversation."
)
return
# Get the full data including metadata and conversation history
data = self.get_init_params()
# Create directory if it doesn't exist
save_dir = os.path.dirname(self.save_filepath)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
# Save with proper formatting
with open(self.save_filepath, "w", encoding="utf-8") as f:
yaml.dump(
data,
f,
indent=4,
default_flow_style=False,
sort_keys=False,
)
logger.info(
f"Conversation saved to {self.save_filepath}"
) )
# Write directly to file except Exception as e:
with open(save_path, "w") as f: logger.error(
json.dump(save_data, f, indent=2) f"Failed to save conversation to {self.save_filepath}: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise # Re-raise the exception to handle it in the calling method
# Only log explicit saves, not autosaves def export(self, force: bool = True):
if not self.autosave: """Export the conversation to a file based on the export method.
logger.info(
f"Successfully saved conversation to {save_path}" Args:
) force (bool, optional): If True, saves regardless of autosave setting. Defaults to True.
except Exception as e: """
logger.error(f"Failed to save conversation: {str(e)}") try:
# Validate export method
if self.export_method not in ["json", "yaml"]:
raise ValueError(
f"Invalid export_method: {self.export_method}. Must be 'json' or 'yaml'"
)
# Create directory if it doesn't exist
save_dir = os.path.dirname(self.save_filepath)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
# Ensure filepath extension matches export method
file_ext = os.path.splitext(self.save_filepath)[1].lower()
expected_ext = (
".json" if self.export_method == "json" else ".yaml"
)
if file_ext != expected_ext:
base_name = os.path.splitext(self.save_filepath)[0]
self.save_filepath = f"{base_name}{expected_ext}"
logger.warning(
f"Updated save filepath to match export method: {self.save_filepath}"
)
if self.export_method == "json":
self.save_as_json(force=force)
elif self.export_method == "yaml":
self.save_as_yaml(force=force)
except Exception as e:
logger.error(
f"Failed to export conversation to {self.save_filepath}: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise # Re-raise to ensure the error is visible
def load_from_json(self, filename: str): def load_from_json(self, filename: str):
"""Load the conversation history from a JSON file. """Load the conversation history and metadata from a JSON file.
Args: Args:
filename (str): Filename to load from. filename (str): Filename to load from.
""" """
if filename is not None and os.path.exists(filename): if filename is not None and os.path.exists(filename):
try: try:
with open(filename) as f: with open(filename, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
# Load metadata # Load metadata
metadata = data.get("metadata", {}) metadata = data.get("metadata", {})
self.id = metadata.get("id", self.id) # Update all metadata attributes
self.name = metadata.get("name", self.name) for key, value in metadata.items():
self.system_prompt = metadata.get( if hasattr(self, key):
"system_prompt", self.system_prompt setattr(self, key, value)
# Load conversation history
self.conversation_history = data.get(
"conversation_history", []
) )
self.rules = metadata.get("rules", self.rules)
self.custom_rules_prompt = metadata.get( logger.info(
"custom_rules_prompt", self.custom_rules_prompt f"Successfully loaded conversation from {filename}"
) )
except Exception as e:
logger.error(
f"Failed to load conversation: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise
def load_from_yaml(self, filename: str):
"""Load the conversation history and metadata from a YAML file.
Args:
filename (str): Filename to load from.
"""
if filename is not None and os.path.exists(filename):
try:
with open(filename, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
# Load metadata
metadata = data.get("metadata", {})
# Update all metadata attributes
for key, value in metadata.items():
if hasattr(self, key):
setattr(self, key, value)
# Load conversation history # Load conversation history
self.conversation_history = data.get("history", []) self.conversation_history = data.get(
"conversation_history", []
)
logger.info( logger.info(
f"Successfully loaded conversation from {filename}" f"Successfully loaded conversation from {filename}"
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to load conversation: {str(e)}") logger.error(
f"Failed to load conversation: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise raise
def load(self, filename: str):
"""Load the conversation history and metadata from a file.
Automatically detects the file format based on extension.
Args:
filename (str): Filename to load from.
"""
if filename is None or not os.path.exists(filename):
logger.warning(f"File not found: {filename}")
return
file_ext = os.path.splitext(filename)[1].lower()
try:
if file_ext == ".json":
self.load_from_json(filename)
elif file_ext == ".yaml" or file_ext == ".yml":
self.load_from_yaml(filename)
else:
raise ValueError(
f"Unsupported file format: {file_ext}. Must be .json, .yaml, or .yml"
)
except Exception as e:
logger.error(
f"Failed to load conversation from {filename}: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise
def search_keyword_in_conversation(self, keyword: str): def search_keyword_in_conversation(self, keyword: str):
"""Search for a keyword in the conversation history. """Search for a keyword in the conversation history.
@ -1067,7 +1287,7 @@ class Conversation(BaseStructure):
for message in self.conversation_history: for message in self.conversation_history:
role = message.get("role") role = message.get("role")
content = message.get("content") content = message.get("content")
tokens = count_tokens(content) tokens = count_tokens(content, self.tokenizer_model_name)
count = tokens # Assign the token count count = tokens # Assign the token count
total_tokens += count total_tokens += count
@ -1130,21 +1350,6 @@ class Conversation(BaseStructure):
pass pass
return self.conversation_history return self.conversation_history
def to_yaml(self):
"""Convert the conversation history to a YAML string.
Returns:
str: The conversation history as a YAML string.
"""
if self.backend_instance:
try:
return self.backend_instance.to_yaml()
except Exception as e:
logger.error(f"Backend to_yaml failed: {e}")
# Fallback to in-memory implementation
pass
return yaml.dump(self.conversation_history)
def get_visible_messages(self, agent: "Agent", turn: int): def get_visible_messages(self, agent: "Agent", turn: int):
""" """
Get the visible messages for a given agent and turn. Get the visible messages for a given agent and turn.
@ -1359,10 +1564,6 @@ class Conversation(BaseStructure):
pass pass
self.conversation_history.extend(messages) self.conversation_history.extend(messages)
def clear_memory(self):
"""Clear the memory of the conversation."""
self.conversation_history = []
@classmethod @classmethod
def load_conversation( def load_conversation(
cls, cls,
@ -1381,35 +1582,33 @@ class Conversation(BaseStructure):
Conversation: The loaded conversation object Conversation: The loaded conversation object
""" """
if load_filepath: if load_filepath:
return cls( conversation = cls(name=name)
name=name, conversation.load(load_filepath)
load_filepath=load_filepath, return conversation
save_enabled=False, # Don't enable saving when loading specific file
)
conv_dir = conversations_dir or get_conversation_dir() conv_dir = conversations_dir or get_conversation_dir()
# Try loading by name first
filepath = os.path.join(conv_dir, f"{name}.json")
# If not found by name, try loading by ID # Try loading by name with different extensions
if not os.path.exists(filepath): for ext in [".json", ".yaml", ".yml"]:
filepath = os.path.join(conv_dir, f"{name}") filepath = os.path.join(conv_dir, f"{name}{ext}")
if not os.path.exists(filepath): if os.path.exists(filepath):
logger.warning( conversation = cls(
f"No conversation found with name or ID: {name}" name=name, conversations_dir=conv_dir
) )
return cls( conversation.load(filepath)
name=name, return conversation
conversations_dir=conv_dir,
save_enabled=True, # If not found by name with extensions, try loading by ID
) filepath = os.path.join(conv_dir, name)
if os.path.exists(filepath):
return cls( conversation = cls(name=name, conversations_dir=conv_dir)
name=name, conversation.load(filepath)
conversations_dir=conv_dir, return conversation
load_filepath=filepath,
save_enabled=True, logger.warning(
f"No conversation found with name or ID: {name}"
) )
return cls(name=name, conversations_dir=conv_dir)
def return_dict_final(self): def return_dict_final(self):
"""Return the final message as a dictionary.""" """Return the final message as a dictionary."""

Loading…
Cancel
Save