pull/989/head
王祥宇 2 months ago
parent ac91191c4b
commit c53723bed4

@@ -1,8 +1,8 @@
import traceback
import concurrent.futures
import datetime
import json
import os
import threading
import uuid
from typing import (
TYPE_CHECKING,
@@ -15,10 +15,9 @@ from typing import (
)
import yaml
import inspect
from swarms.structs.base_structure import BaseStructure
from swarms.utils.any_to_str import any_to_str
from swarms.utils.formatter import formatter
from swarms.utils.litellm_tokenizer import count_tokens
if TYPE_CHECKING:
@@ -143,7 +142,7 @@ def _create_backend_conversation(backend: str, **kwargs):
raise
class Conversation(BaseStructure):
class Conversation:
"""
A class to manage a conversation history, allowing for the addition, deletion,
and retrieval of messages, as well as saving and loading the conversation
@@ -167,13 +166,12 @@ class Conversation(BaseStructure):
time_enabled (bool): Flag to enable time tracking for messages.
autosave (bool): Flag to enable automatic saving of conversation history.
save_filepath (str): File path for saving the conversation history.
tokenizer (Any): Tokenizer for counting tokens in messages.
context_length (int): Maximum number of tokens allowed in the conversation history.
rules (str): Rules for the conversation.
custom_rules_prompt (str): Custom prompt for rules.
user (str): The user identifier for messages.
auto_save (bool): Flag to enable auto-saving of conversation history.
save_as_yaml (bool): Flag to save conversation history as YAML.
save_as_yaml_on (bool): Flag to save conversation history as YAML.
save_as_json_bool (bool): Flag to save conversation history as JSON.
token_count (bool): Flag to enable token counting for messages.
conversation_history (list): List to store the history of messages.
@@ -182,7 +180,7 @@ class Conversation(BaseStructure):
def __init__(
self,
id: str = generate_conversation_id(),
name: str = None,
name: str = "conversation-test",
system_prompt: Optional[str] = None,
time_enabled: bool = False,
autosave: bool = False, # Changed default to False
@@ -193,7 +191,7 @@ class Conversation(BaseStructure):
rules: str = None,
custom_rules_prompt: str = None,
user: str = "User",
save_as_yaml: bool = False,
save_as_yaml_on: bool = False,
save_as_json_bool: bool = False,
token_count: bool = False,
message_id_on: bool = False,
@@ -201,6 +199,7 @@ class Conversation(BaseStructure):
backend: Optional[str] = None,
supabase_url: Optional[str] = None,
supabase_key: Optional[str] = None,
tokenizer_model_name: str = "gpt-4.1",
redis_host: str = "localhost",
redis_port: int = 6379,
redis_db: int = 0,
@@ -212,26 +211,28 @@ class Conversation(BaseStructure):
auto_persist: bool = True,
redis_data_dir: Optional[str] = None,
conversations_dir: Optional[str] = None,
export_method: str = "json",
*args,
**kwargs,
):
super().__init__()
# Initialize all attributes first
self.id = id
self.name = name or id
self.name = name
self.save_filepath = save_filepath
self.system_prompt = system_prompt
self.time_enabled = time_enabled
self.autosave = autosave
self.save_enabled = save_enabled
self.conversations_dir = conversations_dir
self.tokenizer_model_name = tokenizer_model_name
self.message_id_on = message_id_on
self.load_filepath = load_filepath
self.context_length = context_length
self.rules = rules
self.custom_rules_prompt = custom_rules_prompt
self.user = user
self.save_as_yaml = save_as_yaml
self.save_as_yaml_on = save_as_yaml_on
self.save_as_json_bool = save_as_json_bool
self.token_count = token_count
self.provider = provider # Keep for backwards compatibility
@@ -249,23 +250,75 @@ class Conversation(BaseStructure):
self.persist_redis = persist_redis
self.auto_persist = auto_persist
self.redis_data_dir = redis_data_dir
self.export_method = export_method
if self.name is None:
self.name = id
self.conversation_history = []
# Handle save filepath
if save_enabled and save_filepath:
self.save_filepath = save_filepath
elif save_enabled and conversations_dir:
self.save_filepath = os.path.join(
conversations_dir, f"{self.id}.json"
self.setup_file_path()
self.backend_setup(backend, provider)
def setup_file_path(self):
"""Set up the file path for saving the conversation and load existing data if available."""
# Validate export method
if self.export_method not in ["json", "yaml"]:
raise ValueError(
f"Invalid export_method: {self.export_method}. Must be 'json' or 'yaml'"
)
# Set default save filepath if not provided
if not self.save_filepath:
# Ensure extension matches export method
extension = (
".json" if self.export_method == "json" else ".yaml"
)
self.save_filepath = (
f"conversation_{self.name}{extension}"
)
logger.debug(
f"Setting default save filepath to: {self.save_filepath}"
)
else:
self.save_filepath = None
# Validate that provided filepath extension matches export method
file_ext = os.path.splitext(self.save_filepath)[1].lower()
expected_ext = (
".json" if self.export_method == "json" else ".yaml"
)
if file_ext != expected_ext:
logger.warning(
f"Save filepath extension ({file_ext}) does not match export_method ({self.export_method}). "
f"Updating filepath extension to match export method."
)
base_name = os.path.splitext(self.save_filepath)[0]
self.save_filepath = f"{base_name}{expected_ext}"
# Support both 'provider' and 'backend' parameters for backwards compatibility
# 'backend' takes precedence if both are provided
self.created_at = datetime.datetime.now().strftime(
"%Y-%m-%d_%H-%M-%S"
)
self.backend_setup(backend, provider)
# Check if file exists and load it
if os.path.exists(self.save_filepath):
logger.debug(
f"Found existing conversation file at: {self.save_filepath}"
)
try:
self.load(self.save_filepath)
logger.info(
f"Loaded existing conversation from {self.save_filepath}"
)
except Exception as e:
logger.error(
f"Failed to load existing conversation from {self.save_filepath}: {str(e)}"
)
# Keep the empty conversation_history initialized in __init__
else:
logger.debug(
f"No existing conversation file found at: {self.save_filepath}"
)
def backend_setup(
self, backend: str = None, provider: str = None
@@ -341,7 +394,7 @@ class Conversation(BaseStructure):
"rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt,
"user": self.user,
"save_as_yaml": self.save_as_yaml,
"save_as_yaml_on": self.save_as_yaml_on,
"save_as_json_bool": self.save_as_json_bool,
"token_count": self.token_count,
}
@@ -466,13 +519,7 @@ class Conversation(BaseStructure):
def _autosave(self):
"""Automatically save the conversation if autosave is enabled."""
if self.autosave and self.save_filepath:
try:
self.save_as_json(self.save_filepath)
except Exception as e:
logger.error(
f"Failed to autosave conversation: {str(e)}"
)
return self.export()
def mem0_provider(self):
try:
@@ -503,6 +550,7 @@ class Conversation(BaseStructure):
Args:
role (str): The role of the speaker (e.g., 'User', 'System').
content (Union[str, dict, list]): The content of the message to be added.
category (Optional[str]): Optional category for the message.
"""
# Base message with role and timestamp
message = {
@@ -522,20 +570,18 @@ class Conversation(BaseStructure):
# Add message to conversation history
self.conversation_history.append(message)
# Handle token counting in a separate thread if enabled
if self.token_count is True:
self._count_tokens(content, message)
tokens = count_tokens(
text=any_to_str(content),
model=self.tokenizer_model_name,
)
message["token_count"] = tokens
# Autosave after adding message, but only if saving is enabled
if self.autosave and self.save_enabled and self.save_filepath:
try:
self.save_as_json(self.save_filepath)
except Exception as e:
logger.error(
f"Failed to autosave conversation: {str(e)}"
)
return message
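With the threaded _count_tokens helper removed, token counting now happens inline in add_in_memory; a rough sketch, assuming the default in-memory provider:
from swarms.structs.conversation import Conversation

conv = Conversation(token_count=True, tokenizer_model_name="gpt-4.1")
msg = conv.add("User", "What is the weather like today?")
# Under this change the returned message dict carries the synchronously
# computed count instead of being updated later by a background thread.
print(msg["token_count"])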
def export_and_count_categories(
self, tokenizer_model_name: Optional[str] = "gpt-4.1-mini"
self,
) -> Dict[str, int]:
"""Export all messages with category 'input' and 'output' and count their tokens.
@@ -580,12 +626,16 @@ class Conversation(BaseStructure):
# Count tokens only if there is text
input_tokens = (
count_tokens(all_input_text, tokenizer_model_name)
count_tokens(
all_input_text, self.tokenizer_model_name
)
if all_input_text.strip()
else 0
)
output_tokens = (
count_tokens(all_output_text, tokenizer_model_name)
count_tokens(
all_output_text, self.tokenizer_model_name
)
if all_output_text.strip()
else 0
)
@@ -637,56 +687,57 @@ class Conversation(BaseStructure):
metadata: Optional[dict] = None,
category: Optional[str] = None,
):
"""Add a message to the conversation history."""
"""Add a message to the conversation history.
Args:
role (str): The role of the speaker (e.g., 'User', 'System').
content (Union[str, dict, list]): The content of the message to be added.
metadata (Optional[dict]): Optional metadata for the message.
category (Optional[str]): Optional category for the message.
"""
result = None
# If using a persistent backend, delegate to it
if self.backend_instance:
try:
return self.backend_instance.add(
result = self.backend_instance.add(
role=role, content=content, metadata=metadata
)
except Exception as e:
logger.error(
f"Backend add failed: {e}. Falling back to in-memory."
)
return self.add_in_memory(role, content)
result = self.add_in_memory(
role=role, content=content, category=category
)
elif self.provider == "in-memory":
return self.add_in_memory(
result = self.add_in_memory(
role=role, content=content, category=category
)
elif self.provider == "mem0":
return self.add_mem0(
result = self.add_mem0(
role=role, content=content, metadata=metadata
)
else:
raise ValueError(f"Invalid provider: {self.provider}")
raise ValueError(
f"Conversation '{self.name}': invalid provider '{self.provider}'"
)
# Ensure autosave happens after the message is added
if self.autosave:
self._autosave()
return result
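A hedged sketch of how the category argument flows from add() into export_and_count_categories(); the exact keys of the returned dict are not shown in this hunk and are assumed:
from swarms.structs.conversation import Conversation

conv = Conversation(token_count=True)
conv.add("User", "Summarize the quarterly report.", category="input")
conv.add("Assistant", "Revenue grew 12% quarter over quarter.", category="output")
counts = conv.export_and_count_categories()
print(counts)  # presumably per-category token totals (keys assumed)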
def add_multiple_messages(
self, roles: List[str], contents: List[Union[str, dict, list]]
):
return self.add_multiple(roles, contents)
def _count_tokens(self, content: str, message: dict):
# If token counting is enabled, do it in a separate thread
if self.token_count is True:
# Define a function to count tokens and update the message
def count_tokens_thread():
tokens = count_tokens(any_to_str(content))
# Update the message that's already in the conversation history
message["token_count"] = int(tokens)
added = self.add_multiple(roles, contents)
# If autosave is enabled, save after token count is updated
if self.autosave:
self.save_as_json(self.save_filepath)
if self.autosave:
self._autosave()
# Start a new thread for token counting
token_thread = threading.Thread(
target=count_tokens_thread
)
token_thread.daemon = (
True # Make thread terminate when main program exits
)
token_thread.start()
return added
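add_multiple_messages now returns whatever add_multiple produces and triggers a single autosave afterwards; a small illustrative call:
from swarms.structs.conversation import Conversation

conv = Conversation(name="batch-demo")
conv.add_multiple_messages(
    roles=["User", "Assistant"],
    contents=["Hi there", {"status": "ok", "answer": 42}],
)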
def add_multiple(
self,
@@ -785,45 +836,6 @@ class Conversation(BaseStructure):
if keyword in str(message["content"])
]
def display_conversation(self, detailed: bool = False):
"""Display the conversation history.
Args:
detailed (bool, optional): Flag to display detailed information. Defaults to False.
"""
if self.backend_instance:
try:
return self.backend_instance.display_conversation(
detailed
)
except Exception as e:
logger.error(f"Backend display failed: {e}")
# Fallback to in-memory display
pass
# In-memory display implementation with proper formatting
for message in self.conversation_history:
content = message.get("content", "")
role = message.get("role", "Unknown")
# Format the message content
if isinstance(content, (dict, list)):
content = json.dumps(content, indent=2)
# Create the display string
display_str = f"{role}: {content}"
# Add details if requested
if detailed:
display_str += f"\nTimestamp: {message.get('timestamp', 'Unknown')}"
display_str += f"\nMessage ID: {message.get('message_id', 'Unknown')}"
if "token_count" in message:
display_str += (
f"\nTokens: {message['token_count']}"
)
formatter.print_panel(display_str)
def export_conversation(self, filename: str, *args, **kwargs):
"""Export the conversation history to a file.
@@ -844,7 +856,7 @@ class Conversation(BaseStructure):
# In-memory export implementation
# If the filename ends with .json, use save_as_json
if filename.endswith(".json"):
self.save_as_json(filename)
self.save_as_json(force=True)
else:
# Simple text export for non-JSON files
with open(filename, "w", encoding="utf-8") as f:
@@ -946,99 +958,307 @@ class Conversation(BaseStructure):
pass
return self.return_history_as_string()
def save_as_json(self, filename: str = None):
"""Save the conversation history as a JSON file.
def to_dict(self) -> Dict[str, Any]:
"""
Converts all attributes of the class into a dictionary, including all __init__ parameters
and conversation history. Automatically extracts parameters from __init__ signature.
Returns:
Dict[str, Any]: A dictionary containing:
- metadata: All initialization parameters and their current values
- conversation_history: The list of conversation messages
"""
# Get all parameters from __init__ signature
init_signature = inspect.signature(self.__class__.__init__)
init_params = [
param
for param in init_signature.parameters
if param not in ["self", "args", "kwargs"]
]
# Build metadata dictionary from init parameters
metadata = {}
for param in init_params:
# Get the current value of the parameter from instance
value = getattr(self, param, None)
# Special handling for certain types
if value is not None:
if isinstance(
value, (str, int, float, bool, list, dict)
):
metadata[param] = value
elif hasattr(value, "to_dict"):
metadata[param] = value.to_dict()
else:
try:
# Try to convert to string if not directly serializable
metadata[param] = str(value)
except Exception:
# Skip if we can't serialize
continue
# Add created_at if it exists
if hasattr(self, "created_at"):
metadata["created_at"] = self.created_at
return {
"metadata": metadata,
"conversation_history": self.conversation_history,
}
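Since to_dict() mirrors the __init__ signature via inspect, the snapshot always has the two top-level keys shown above; a quick illustration:
from swarms.structs.conversation import Conversation

conv = Conversation(name="snapshot-demo")
conv.add("User", "hello")
snapshot = conv.to_dict()
print(sorted(snapshot.keys()))                    # ['conversation_history', 'metadata']
print(snapshot["metadata"].get("export_method"))  # e.g. 'json'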
def save_as_json(self, force: bool = True):
"""Save the conversation history and metadata to a JSON file.
Args:
filename (str): Filename to save the conversation history.
force (bool, optional): If True, saves regardless of autosave setting. Defaults to True.
"""
# Check backend instance first
if self.backend_instance:
try:
return self.backend_instance.save_as_json(filename)
except Exception as e:
logger.error(f"Backend save_as_json failed: {e}")
# Fallback to local save implementation below
try:
# Check if saving is allowed
if not self.autosave and not force:
logger.warning(
"Autosave is disabled. To save anyway, call save_as_json(force=True) "
"or enable autosave by setting autosave=True when creating the Conversation."
)
return
# Don't save if saving is disabled
if not self.save_enabled:
logger.warning(
"An attempt to save the conversation failed: save_enabled is False."
"Please set save_enabled=True when creating a Conversation object to enable saving."
# Don't save if saving is disabled (your PR code)
if not self.save_enabled:
logger.warning(
"An attempt to save the conversation failed: save_enabled is False. "
"Please set save_enabled=True when creating a Conversation object to enable saving."
)
return
# Get the full data including metadata and conversation history
data = self.get_init_params()
# Ensure we have a valid save path
if not self.save_filepath:
self.save_filepath = os.path.join(
self.conversations_dir or os.getcwd(),
f"conversation_{self.name}.json",
)
# Create directory if it doesn't exist
save_dir = os.path.dirname(self.save_filepath)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
# Save with proper formatting
with open(self.save_filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4, default=str)
logger.info(f"Conversation saved to {self.save_filepath}")
except Exception as e:
logger.error(
f"Failed to save conversation: {str(e)}\nTraceback: {traceback.format_exc()}"
)
return
raise # Re-raise to ensure the error is visible to the caller
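Because force defaults to True, an explicit save_as_json() call writes even when autosave is off, provided save_enabled is set; a sketch of the intended gating:
from swarms.structs.conversation import Conversation

conv = Conversation(name="notes", save_enabled=True, autosave=False)
conv.add("User", "Remember this decision.")
conv.save_as_json()              # writes despite autosave=False (force=True by default)
conv.save_as_json(force=False)   # only logs a warning, since autosave is disabled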
def get_init_params(self):
data = {
"metadata": {
"id": self.id,
"name": self.name,
"system_prompt": self.system_prompt,
"time_enabled": self.time_enabled,
"autosave": self.autosave,
"save_filepath": self.save_filepath,
"load_filepath": self.load_filepath,
"context_length": self.context_length,
"rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt,
"user": self.user,
"save_as_yaml_on": self.save_as_yaml_on,
"save_as_json_bool": self.save_as_json_bool,
"token_count": self.token_count,
"message_id_on": self.message_id_on,
"provider": self.provider,
"backend": self.backend,
"tokenizer_model_name": self.tokenizer_model_name,
"conversations_dir": self.conversations_dir,
"export_method": self.export_method,
"created_at": self.created_at,
},
"conversation_history": self.conversation_history,
}
save_path = filename or self.save_filepath
if save_path is not None:
try:
# Prepare metadata
metadata = {
"id": self.id,
"name": self.name,
"created_at": datetime.datetime.now().isoformat(),
"system_prompt": self.system_prompt,
"rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt,
}
return data
# Prepare save data
save_data = {
"metadata": metadata,
"history": self.conversation_history,
}
def save_as_yaml(self, force: bool = True):
"""Save the conversation history and metadata to a YAML file.
# Create directory if it doesn't exist
os.makedirs(
os.path.dirname(save_path),
mode=0o755,
exist_ok=True,
Args:
force (bool, optional): If True, saves regardless of autosave setting. Defaults to True.
"""
try:
# Check if saving is allowed
if not self.autosave and not force:
logger.warning(
"Autosave is disabled. To save anyway, call save_as_yaml(force=True) "
"or enable autosave by setting autosave=True when creating the Conversation."
)
return
# Get the full data including metadata and conversation history
data = self.get_init_params()
# Create directory if it doesn't exist
save_dir = os.path.dirname(self.save_filepath)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
# Save with proper formatting
with open(self.save_filepath, "w", encoding="utf-8") as f:
yaml.dump(
data,
f,
indent=4,
default_flow_style=False,
sort_keys=False,
)
logger.info(
f"Conversation saved to {self.save_filepath}"
)
# Write directly to file
with open(save_path, "w") as f:
json.dump(save_data, f, indent=2)
except Exception as e:
logger.error(
f"Failed to save conversation to {self.save_filepath}: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise # Re-raise the exception to handle it in the calling method
# Only log explicit saves, not autosaves
if not self.autosave:
logger.info(
f"Successfully saved conversation to {save_path}"
)
except Exception as e:
logger.error(f"Failed to save conversation: {str(e)}")
def export(self, force: bool = True):
"""Export the conversation to a file based on the export method.
Args:
force (bool, optional): If True, saves regardless of autosave setting. Defaults to True.
"""
try:
# Validate export method
if self.export_method not in ["json", "yaml"]:
raise ValueError(
f"Invalid export_method: {self.export_method}. Must be 'json' or 'yaml'"
)
# Create directory if it doesn't exist
save_dir = os.path.dirname(self.save_filepath)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
# Ensure filepath extension matches export method
file_ext = os.path.splitext(self.save_filepath)[1].lower()
expected_ext = (
".json" if self.export_method == "json" else ".yaml"
)
if file_ext != expected_ext:
base_name = os.path.splitext(self.save_filepath)[0]
self.save_filepath = f"{base_name}{expected_ext}"
logger.warning(
f"Updated save filepath to match export method: {self.save_filepath}"
)
if self.export_method == "json":
self.save_as_json(force=force)
elif self.export_method == "yaml":
self.save_as_yaml(force=force)
except Exception as e:
logger.error(
f"Failed to export conversation to {self.save_filepath}: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise # Re-raise to ensure the error is visible
def load_from_json(self, filename: str):
"""Load the conversation history from a JSON file.
"""Load the conversation history and metadata from a JSON file.
Args:
filename (str): Filename to load from.
"""
if filename is not None and os.path.exists(filename):
try:
with open(filename) as f:
with open(filename, "r", encoding="utf-8") as f:
data = json.load(f)
# Load metadata
metadata = data.get("metadata", {})
self.id = metadata.get("id", self.id)
self.name = metadata.get("name", self.name)
self.system_prompt = metadata.get(
"system_prompt", self.system_prompt
# Update all metadata attributes
for key, value in metadata.items():
if hasattr(self, key):
setattr(self, key, value)
# Load conversation history
self.conversation_history = data.get(
"conversation_history", []
)
self.rules = metadata.get("rules", self.rules)
self.custom_rules_prompt = metadata.get(
"custom_rules_prompt", self.custom_rules_prompt
logger.info(
f"Successfully loaded conversation from {filename}"
)
except Exception as e:
logger.error(
f"Failed to load conversation: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise
def load_from_yaml(self, filename: str):
"""Load the conversation history and metadata from a YAML file.
Args:
filename (str): Filename to load from.
"""
if filename is not None and os.path.exists(filename):
try:
with open(filename, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
# Load metadata
metadata = data.get("metadata", {})
# Update all metadata attributes
for key, value in metadata.items():
if hasattr(self, key):
setattr(self, key, value)
# Load conversation history
self.conversation_history = data.get("history", [])
self.conversation_history = data.get(
"conversation_history", []
)
logger.info(
f"Successfully loaded conversation from {filename}"
)
except Exception as e:
logger.error(f"Failed to load conversation: {str(e)}")
logger.error(
f"Failed to load conversation: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise
def load(self, filename: str):
"""Load the conversation history and metadata from a file.
Automatically detects the file format based on extension.
Args:
filename (str): Filename to load from.
"""
if filename is None or not os.path.exists(filename):
logger.warning(f"File not found: {filename}")
return
file_ext = os.path.splitext(filename)[1].lower()
try:
if file_ext == ".json":
self.load_from_json(filename)
elif file_ext == ".yaml" or file_ext == ".yml":
self.load_from_yaml(filename)
else:
raise ValueError(
f"Unsupported file format: {file_ext}. Must be .json, .yaml, or .yml"
)
except Exception as e:
logger.error(
f"Failed to load conversation from {filename}: {str(e)}\nTraceback: {traceback.format_exc()}"
)
raise
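A round-trip sketch combining export() with the extension-aware load(); the file name follows the default pattern from setup_file_path() and is an assumption:
from swarms.structs.conversation import Conversation

conv = Conversation(name="session", export_method="yaml", save_enabled=True)
conv.add("User", "ping")
conv.export()                                  # dispatches to save_as_yaml()

restored = Conversation(name="scratch")
restored.load("conversation_session.yaml")     # .yaml/.yml routes to load_from_yaml
print(len(restored.conversation_history))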
def search_keyword_in_conversation(self, keyword: str):
"""Search for a keyword in the conversation history.
@@ -1067,7 +1287,7 @@ class Conversation(BaseStructure):
for message in self.conversation_history:
role = message.get("role")
content = message.get("content")
tokens = count_tokens(content)
tokens = count_tokens(content, self.tokenizer_model_name)
count = tokens # Assign the token count
total_tokens += count
@@ -1130,21 +1350,6 @@ class Conversation(BaseStructure):
pass
return self.conversation_history
def to_yaml(self):
"""Convert the conversation history to a YAML string.
Returns:
str: The conversation history as a YAML string.
"""
if self.backend_instance:
try:
return self.backend_instance.to_yaml()
except Exception as e:
logger.error(f"Backend to_yaml failed: {e}")
# Fallback to in-memory implementation
pass
return yaml.dump(self.conversation_history)
def get_visible_messages(self, agent: "Agent", turn: int):
"""
Get the visible messages for a given agent and turn.
@@ -1359,10 +1564,6 @@ class Conversation(BaseStructure):
pass
self.conversation_history.extend(messages)
def clear_memory(self):
"""Clear the memory of the conversation."""
self.conversation_history = []
@classmethod
def load_conversation(
cls,
@@ -1381,35 +1582,33 @@ class Conversation(BaseStructure):
Conversation: The loaded conversation object
"""
if load_filepath:
return cls(
name=name,
load_filepath=load_filepath,
save_enabled=False, # Don't enable saving when loading specific file
)
conversation = cls(name=name)
conversation.load(load_filepath)
return conversation
conv_dir = conversations_dir or get_conversation_dir()
# Try loading by name first
filepath = os.path.join(conv_dir, f"{name}.json")
# If not found by name, try loading by ID
if not os.path.exists(filepath):
filepath = os.path.join(conv_dir, f"{name}")
if not os.path.exists(filepath):
logger.warning(
f"No conversation found with name or ID: {name}"
# Try loading by name with different extensions
for ext in [".json", ".yaml", ".yml"]:
filepath = os.path.join(conv_dir, f"{name}{ext}")
if os.path.exists(filepath):
conversation = cls(
name=name, conversations_dir=conv_dir
)
return cls(
name=name,
conversations_dir=conv_dir,
save_enabled=True,
)
return cls(
name=name,
conversations_dir=conv_dir,
load_filepath=filepath,
save_enabled=True,
conversation.load(filepath)
return conversation
# If not found by name with extensions, try loading by ID
filepath = os.path.join(conv_dir, name)
if os.path.exists(filepath):
conversation = cls(name=name, conversations_dir=conv_dir)
conversation.load(filepath)
return conversation
logger.warning(
f"No conversation found with name or ID: {name}"
)
return cls(name=name, conversations_dir=conv_dir)
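The classmethod now resolves files by trying .json, .yaml, and .yml in turn; both call styles below are illustrative, with a hypothetical directory name:
from swarms.structs.conversation import Conversation

# Load by explicit path (format detected from the extension).
conv = Conversation.load_conversation(
    name="restored",
    load_filepath="conversation_demo.yaml",
)

# Or look the name up inside a conversations directory.
conv = Conversation.load_conversation(
    name="session",
    conversations_dir="./conversations",   # hypothetical location
)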
def return_dict_final(self):
"""Return the final message as a dictionary."""
